# -*- coding: utf-8 -*-


import numpy as np
import numpy.linalg as npl
import matplotlib.pyplot as plt
import scipy.sparse as sp
from scipy.sparse.linalg import splu


def u0(x):
    """Return the initial profile sin(2*pi*x) evaluated at ``x``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    # Alternative initial profiles (swap in to experiment):
    #   np.exp(-50*(x-0.5)*(x-0.5))
    #   0.*x
    #   (x>0.)
    #   ((x>0.3)*(x<0.7)).astype(float)
    phase = 2.0 * np.pi * x
    return np.sin(phase)

def vitesse(x):
    """Return the speed 2 + sin(2*pi*x) evaluated at ``x``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    # Constant-speed alternative: return 1.
    oscillation = np.sin(2.0 * np.pi * x)
    return 2.0 + oscillation

# Driver: advance the profile u0 with an implicit upwind scheme on a periodic
# grid, using the space-dependent speed given by vitesse, and animate it.

# --- Parameters -------------------------------------------------------------
L = 1.      # domain length
J = 100     # number of grid cells
N = 10      # number of time steps
T = 1.      # final time
C = 0.5     # not used below (NOTE(review): looks like a CFL target -- confirm)

# --- Grid and initial data --------------------------------------------------
dx = L/J
X = dx*np.arange(J) + 0.5*dx    # cell centres
a = vitesse(X)                  # speed sampled at the cell centres
v = np.max(np.abs(a))           # largest speed magnitude (informational only)
u = u0(X)
n = 0
t = 0.
dt = T/N
nu = dt/dx

# --- Linear system solved at every step -------------------------------------
# Row i reads (1 + nu*a[i]) * u[i] - nu*a[i-1] * u[i-1] = u_old[i].
# The sub-diagonal therefore holds a[0..J-2]; the periodic wrap-around term
# for row 0 (left neighbour = last cell) is set explicitly afterwards.
mat = sp.diags([1. + nu*a, -nu*a[:-1]], [0, -1], shape=(J, J), format="lil")
mat[0, -1] = -nu*a[-1]
# splu works on CSC matrices; convert explicitly instead of relying on the
# implicit conversion (which emits a SparseEfficiencyWarning).
mat = splu(mat.tocsc())

# --- Time loop --------------------------------------------------------------
# A counted loop performs exactly N steps. Accumulating t += dt and testing
# t < T can run one step too many because of floating-point round-off.
for n in range(1, N + 1):
    t = n*dt
    u = mat.solve(u)

    plt.clf()
    plt.plot(X, u)
    plt.grid()
    plt.draw()
    plt.pause(dt)

    print("itération  ", n)
    print("t = ", t)


# --- Final state ------------------------------------------------------------
plt.plot(X, u)
plt.grid()
plt.show()

#############################
