import functools

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import splu

def upwind(u, a, dt, dx):
    """Advance ``u`` in place by one explicit first-order upwind step.

    Periodic grid.  The interface value is the left cell value,
    F_{i+1/2} = u_i (the upwind side when a > 0), and
    u_i <- u_i - (a dt / dx) (F_{i+1/2} - F_{i-1/2}).
    """
    left = np.roll(u, 1)  # left[i] == u[i-1], index 0 wraps to the last cell
    u -= a * dt / dx * (u - left)

def centre(u, a, dt, dx):
    """Advance ``u`` in place by one explicit centred-flux step.

    Periodic grid.  Each interface value is the mean of its two neighbouring
    cells, F_{i+1/2} = (u_i + u_{i+1}) / 2, and
    u_i <- u_i - (a dt / dx) (F_{i+1/2} - F_{i-1/2}).
    """
    nu = a * dt / dx
    face = 0.5 * (u + np.roll(u, -1))  # face[i] is F_{i+1/2}
    u -= nu * (face - np.roll(face, 1))

def downwind(u, a, dt, dx):
    """Advance ``u`` in place by one explicit downwind step.

    Periodic grid.  The interface value is the right cell value,
    F_{i+1/2} = u_{i+1}, and u_i <- u_i - (a dt / dx) (F_{i+1/2} - F_{i-1/2}).
    """
    right = np.roll(u, -1)  # right[i] == u[i+1], the last cell wraps to index 0
    u -= a * dt / dx * (right - u)

def Lax_Friedrichs(u, a, dt, dx):
    """Advance ``u`` in place by one Lax-Friedrichs step.

    Periodic grid.  The interface value (flux divided by ``a``) is
        F_{i+1/2} = (u_i + u_{i+1}) / 2 + dx / (2 a dt) (u_i - u_{i+1}),
    and u_i <- u_i - (a dt / dx) (F_{i+1/2} - F_{i-1/2}).  ``a`` must be non-zero.
    """
    right = np.roll(u, -1)
    face = 0.5 * (u + right) + dx / (2 * a * dt) * (u - right)  # face[i] is F_{i+1/2}
    u -= a * dt / dx * (face - np.roll(face, 1))

def Rusanov(u, a, dt, dx, c):
    """Advance ``u`` in place by one Rusanov (local Lax-Friedrichs) step.

    Periodic grid.  With numerical viscosity speed ``c``, the interface value
    (flux divided by ``a``) is
        F_{i+1/2} = (u_i + u_{i+1}) / 2 + c / (2 a) (u_i - u_{i+1}),
    and u_i <- u_i - (a dt / dx) (F_{i+1/2} - F_{i-1/2}).  ``a`` must be non-zero.
    """
    right = np.roll(u, -1)
    face = 0.5 * (u + right) + c / (2 * a) * (u - right)  # face[i] is F_{i+1/2}
    u -= a * dt / dx * (face - np.roll(face, 1))

def Lax_Wendroff(u, a, dt, dx):
    """Advance ``u`` in place by one Lax-Wendroff step.

    Periodic grid.  The interface value (flux divided by ``a``) is
        F_{i+1/2} = (u_i + u_{i+1}) / 2 - a dt / (2 dx) (u_{i+1} - u_i),
    and u_i <- u_i - (a dt / dx) (F_{i+1/2} - F_{i-1/2}).
    """
    right = np.roll(u, -1)
    face = 0.5 * (u + right) - a * dt / (2 * dx) * (right - u)  # face[i] is F_{i+1/2}
    u -= a * dt / dx * (face - np.roll(face, 1))

@functools.lru_cache(maxsize=32)
def _upwind_implicit_lu(n, cfl):
    """Return the sparse LU factorization of the periodic implicit-upwind matrix.

    Row ``i`` of the ``n x n`` matrix encodes ``(1 + cfl) * u[i] - cfl * u[i-1]``,
    where index ``-1`` wraps to ``n - 1``.  The matrix is assembled from COO
    triplets, whose duplicate entries are summed on conversion, so ``n == 1``
    gives the correct 1x1 matrix ``[1]`` instead of overwriting the diagonal.

    Results are cached per ``(n, cfl)`` pair.
    """
    idx = np.arange(n)
    rows = np.concatenate((idx, idx))
    cols = np.concatenate((idx, (idx - 1) % n))
    data = np.concatenate((np.full(n, 1.0 + cfl), np.full(n, -cfl)))
    matrix = sp.coo_matrix((data, (rows, cols)), shape=(n, n)).tocsc()
    return splu(matrix)


def upwind_implicit(u, a, dt, dx, mat=None):
    """Advance ``u`` in place by one implicit (backward Euler) upwind step.

    Solves ``(1 + c) u_i^{n+1} - c u_{i-1}^{n+1} = u_i^n`` on a periodic grid,
    with ``c = a * dt / dx``.  The left neighbour is the upwind side only for
    ``a > 0``.

    The LU factorization is cached per ``(u.size, c)``.  Caching on the grid size
    alone would reuse a matrix built for another CFL number whenever ``a``, ``dt``
    or ``dx`` change at a fixed size.

    Parameters
    ----------
    u : ndarray
        1-D float cell values, overwritten with the solution at the next time level.
    a, dt, dx : float
        Advection speed, time step and cell width.
    mat : ignored
        Formerly a mutable-default cache slot; kept only so that existing call
        signatures remain valid.
    """
    lu = _upwind_implicit_lu(u.size, a * dt / dx)
    np.copyto(u, lu.solve(u))
    
    
def initial_condition_nonsmooth(x):
    """Return the discontinuous box profile sampled at the 1-D points ``x``.

    The value is 1.0 on the open interval (0.15, 0.35) and 0.0 elsewhere.
    """
    inside_box = (0.15 < x) & (x < 0.35)
    profile = np.zeros(x.size)
    profile[inside_box] = 1
    return profile

def initial_condition_lipschitz(x):
    """Return a continuous, piecewise-linear plateau sampled at the 1-D points ``x``.

    The profile is:
    - 0 for x < 0.1;
    - a linear rise to 1 on [0.1, 0.3];
    - 1 on (0.3, 0.4);
    - a linear fall back to 0 on [0.4, 0.6];
    - 0 for x > 0.6.
    """
    a1, b1 = 0.1, 0.3
    a2, b2 = 0.4, 0.6

    rise = (x - a1) / (b1 - a1)
    fall = (a2 - x) / (b2 - a2) + 1
    # np.select picks the first matching condition, so each test only needs
    # its upper bound
    return np.select(
        [x < a1, x <= b1, x < a2, x <= b2],
        [0.0, rise, 1.0, fall],
        default=0.0,
    )

def initial_condition_smooth(x):
    """Return a smooth plateau sampled at the 1-D points ``x``.

    Same layout as the Lipschitz profile (0 before 0.1, 1 on (0.3, 0.4),
    0 after 0.6).  The ramps on [0.1, 0.3] and [0.4, 0.6] use the polynomial
    s**4 (35 - 84 s + 70 s**2 - 20 s**3).  Its derivative,
    140 s**3 (1 - s)**3, vanishes to third order at both ends of each ramp.
    """
    a1, b1 = 0.1, 0.3
    a2, b2 = 0.4, 0.6

    def blend(s):
        return s**4 * (35 + s * (-84 + s * (70 - 20 * s)))

    rising = blend((x - a1) / (b1 - a1))
    falling = 1 - blend((x - a2) / (b2 - a2))
    # first matching condition wins, so each test only needs its upper bound
    return np.select(
        [x < a1, x <= b1, x < a2, x <= b2],
        [0.0, rising, 1.0, falling],
        default=0.0,
    )


def run_with_animation(init_cond, f_scheme, a, dt, dx, T):
    """Animate a scheme with matplotlib's FuncAnimation, repeating from t = 0.

    Parameters
    ----------
    init_cond : callable
        Maps the cell-centre array to the initial cell values.
    f_scheme : callable
        ``f_scheme(u, a, dt, dx)`` advances ``u`` in place by one time step.
    a, dt, dx : float
        Advection speed, time step and cell width on the unit interval.
    T : float
        Final time; the animation shows ``int(T / dt) + 1`` steps.
    """
    # initial condition on the cell centres
    x = np.arange(0.5 * dx, 1, dx)
    u0 = init_cond(x)
    u = np.empty_like(u0)

    # time discretisation: one scheme step per animation frame
    Nt = int(T / dt) + 1

    # initial plot
    fig = plt.figure()
    ax = fig.gca()
    plot, = ax.plot(x, u0)

    # reset to the initial condition (called again each time the animation repeats)
    def init_animation():
        np.copyto(u, u0)
        plot.set_ydata(u)
        ax.set_title("temps : 0")

    # advance one step and redraw
    def update_animation(frame_number):
        f_scheme(u, a, dt, dx)
        plot.set_ydata(u)
        # after the step of frame k the solution is at time (k + 1) * dt
        ax.set_title(f"temps : {(frame_number + 1) * dt:.2f}")

    # keep a reference to the animation object while the window is shown
    anim = animation.FuncAnimation(fig, update_animation, frames=Nt, init_func=init_animation, interval=1e3 * dt, repeat=True)
    plt.show()


def run_with_loop(init_cond, f_scheme, a, dt, dx, T):
    """Animate a scheme with a manual ``plt.pause`` loop.

    Each pass restarts from the initial condition and steps ``f_scheme`` until
    the clock reaches ``T``.  A new pass starts as long as the figure window
    exists.  The window is checked only between passes, not inside the
    stepping loop.
    """
    centres = np.arange(0.5 * dx, 1, dx)
    u_init = init_cond(centres)
    u = np.empty_like(u_init)

    fig = plt.figure()
    ax = fig.gca()
    (line,) = ax.plot(centres, u_init)

    def show(title):
        # redraw the current state and let the GUI process events for dt seconds
        line.set_ydata(u)
        ax.set_title(title)
        plt.pause(dt)

    while plt.fignum_exists(fig.number):
        np.copyto(u, u_init)
        t = 0
        show("temps : 0")

        while t < T:
            f_scheme(u, a, dt, dx)
            t += dt
            show(f"temps : {t:.2f}")

def run_without_plot(init_cond, f_scheme, a, dt, dx, T):
    """Integrate a scheme from t = 0 to t = T and return the final cell values.

    The unit interval is split into cells of width ``dx``, and ``init_cond`` is
    evaluated at the cell centres.  ``f_scheme(u, a, step, dx)`` is then applied
    in place.  Every step uses ``dt`` except possibly a shorter final step, so the
    returned solution is at time exactly ``T`` rather than the next multiple of
    ``dt``.

    Parameters
    ----------
    init_cond : callable
        Maps the cell-centre array to the initial cell values.
    f_scheme : callable
        ``f_scheme(u, a, dt, dx)`` advances ``u`` in place by one step of size ``dt``.
    a, dt, dx : float
        Advection speed, (maximal) time step and cell width.
    T : float
        Final time; ``T <= 0`` returns the initial condition unchanged.

    Returns
    -------
    ndarray
        Cell values at time ``T``.

    Raises
    ------
    ValueError
        If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")

    # initial condition on the cell centres
    x = np.arange(0.5 * dx, 1, dx)
    u = init_cond(x).copy()

    if T <= 0:
        return u

    n_full, remainder = divmod(T, dt)
    for _ in range(int(n_full)):
        f_scheme(u, a, dt, dx)
    # Shorter final step so the solution lands exactly on T.  A round-off-level
    # remainder is skipped: some schemes (e.g. Lax-Friedrichs) apply an O(1)
    # diffusion even for a vanishing step.
    if remainder > 1e-10 * dt:
        f_scheme(u, a, remainder, dx)

    return u


# Test of the schemes: final solution at time T for successively refined grids
T = 1.25
a = 1
a_cfl = 0.5
Nmin = 5
Nmax = 12

# (subplot title, scheme) pairs.  The Rusanov wrapper (viscosity speed 1.5 * a)
# is a lambda, whose __name__ would be "<lambda>", so titles are given explicitly.
schemes = [
    ("upwind", upwind),
    ("centre", centre),
    ("downwind", downwind),
    ("Lax_Friedrichs", Lax_Friedrichs),
    ("Rusanov", lambda u, a, dt, dx: Rusanov(u, a, dt, dx, 1.5 * a)),
    ("Lax_Wendroff", Lax_Wendroff),
    ("upwind_implicit", upwind_implicit),
]

fig, axes = plt.subplots(3, 3)

for k in range(Nmin, Nmax + 1):
    dx = 1 / 2**k
    dt = a_cfl * dx / a

    for ax, (name, f_scheme) in zip(axes.ravel(), schemes):
        u = run_without_plot(initial_condition_nonsmooth, f_scheme, a, dt, dx, T)
        ax.plot(np.arange(0.5 * dx, 1, dx), u)
        ax.set_title(name)

plt.show()


# Test of the CFL condition: every scheme run with a * dt / dx = 4
T = 1.25
a = 1

dx = 1 / 2**10
dt = 4 * dx / a

print(f"a = {a}")
print(f"dx = {dx}")
print(f"dt = {dt}")
print(f"CFL : a * dt / dx = {a * dt / dx}")

# (subplot title, scheme) pairs.  The Rusanov wrapper (viscosity speed 2 * a)
# is a lambda, whose __name__ would be "<lambda>", so titles are given explicitly.
schemes = [
    ("upwind", upwind),
    ("centre", centre),
    ("downwind", downwind),
    ("Lax_Friedrichs", Lax_Friedrichs),
    ("Rusanov", lambda u, a, dt, dx: Rusanov(u, a, dt, dx, 2 * a)),
    ("Lax_Wendroff", Lax_Wendroff),
    ("upwind_implicit", upwind_implicit),
]

fig, axes = plt.subplots(3, 3)

for ax, (name, f_scheme) in zip(axes.ravel(), schemes):
    u = run_without_plot(initial_condition_nonsmooth, f_scheme, a, dt, dx, T)
    ax.plot(np.arange(0.5 * dx, 1, dx), u)
    ax.set_title(name)

plt.show()


def run_conv(init_cond, Nmin, Nmax, T, a, a_cfl, sl_l1, sl_linf):
    """Plot L1 and Linf error curves of several schemes against the cell width.

    For each dx = 2**-k, k = Nmin..Nmax, every scheme is run up to time ``T``
    with ``dt = a_cfl * dx / a``.  The piecewise-constant numerical solution is
    sampled on a reference grid four times finer than the finest run
    (``dx_fine = 2**-(Nmax + 2)``).  It is compared there with the exact solution
    of the periodic advection problem, ``init_cond((x - a * T) mod 1)``.

    Parameters
    ----------
    init_cond : callable
        Maps cell-centre coordinates in [0, 1) to initial values.
    Nmin, Nmax : int
        Range of refinement levels k (dx = 2**-k), inclusive.
    T, a, a_cfl : float
        Final time, advection speed and CFL number ``a * dt / dx``.
    sl_l1, sl_linf : float
        Slopes of the reference lines ``dx**sl`` drawn on the L1 and Linf plots.
    """
    # (legend label, scheme); the Rusanov lambda needs an explicit label
    schemes = [
        ("upwind", upwind),
        ("Lax_Friedrichs", Lax_Friedrichs),
        ("Rusanov", lambda u, a, dt, dx: Rusanov(u, a, dt, dx, 1.5 * a)),
        ("Lax_Wendroff", Lax_Wendroff),
        ("upwind_implicit", upwind_implicit),
    ]
    dxs = 1 / 2**np.arange(Nmin, Nmax + 1)
    err_l1 = np.empty((len(schemes), dxs.size))
    err_linf = np.empty_like(err_l1)

    # reference grid, finer than every computed grid
    dx_fine = 1 / 2**(Nmax + 2)
    x = np.arange(0.5 * dx_fine, 1, dx_fine)
    # exact solution of u_t + a u_x = 0 on the periodic unit interval
    u_exact = init_cond(np.mod(x - a * T, 1.0))

    fig, axes = plt.subplots(2, 1)

    for ischeme, (name, f_scheme) in enumerate(schemes):
        for k, dx in enumerate(dxs):
            dt = a_cfl * dx / a

            u = run_without_plot(init_cond, f_scheme, a, dt, dx, T)
            # piecewise-constant reconstruction: each coarse cell spans
            # x.size // u.size reference cells
            u_interp = np.repeat(u, x.size // u.size)
            diff = u_interp - u_exact
            err_l1[ischeme, k] = dx_fine * np.linalg.norm(diff, ord=1)
            err_linf[ischeme, k] = np.linalg.norm(diff, ord=np.inf)

        axes[0].loglog(dxs, err_l1[ischeme, :], label=name)
        axes[1].loglog(dxs, err_linf[ischeme, :], label=name)

    axes[0].loglog(dxs, dxs**sl_l1, label=f"slope={sl_l1}")
    axes[0].set_title("Erreur L1")
    axes[0].legend()

    axes[1].loglog(dxs, dxs**sl_linf, label=f"slope={sl_linf}")
    axes[1].set_title("Erreur Linf")
    axes[1].legend()

    plt.show()


# Convergence studies on levels 5..12, up to T = 1 with a = 1 and CFL number 0.5.
# sl_l1 / sl_linf are the slopes of the reference lines drawn next to the errors.

# Non-smooth (discontinuous) initial data
run_conv(initial_condition_nonsmooth, Nmin=5, Nmax=12, T=1, a=1, a_cfl=0.5, sl_l1=0.5, sl_linf=0)

# Lipschitz-continuous initial data
run_conv(initial_condition_lipschitz, Nmin=5, Nmax=12, T=1, a=1, a_cfl=0.5, sl_l1=1, sl_linf=0.5)

# Smooth initial data
run_conv(initial_condition_smooth, Nmin=5, Nmax=12, T=1, a=1, a_cfl=0.5, sl_l1=1, sl_linf=1)
