# -*- coding: utf-8 -*-

import numpy as np
import numpy.linalg as npl
import matplotlib.pyplot as plt

def u0(x): # donnée initiale
#    return np.sin(2*np.pi*(x))
	return (x>0.3)*(x<0.7).astype(float)
#	return x*x*x*x*(1-x)*(1-x)*(1-x)*(1-x)

def f(t,x): # terme source
#    return(np.sin(2.*np.pi*x)*(1. - 1./(t+1)))
	return(0.)

L = 1. # longueur de l'intervalle en espace
J = 1000 # nombre de degrés de liberté
dx = L/(J+1) # pas d'espace
kappa = 1. # coefficient de diffusion
T = 0.01 # temps final
theta = 1. # paramètre du theta-schéma
c = 0.5 # coefficient de sécurité pour la condition de stabilité (doit être inférieur à 1)
X = np.linspace(0,L,J+2)

if theta < 0.5: # calcul d'un pas de temps assurant la stabilité lorsque theta est inférieur à 1/2 (si c est inférieur à 1)
	dt = c*0.5*dx*dx/kappa/(1.-2.*theta)
else:
	dt = c*dx
X = np.linspace(0,L,J+2)
A = 2.*np.diag(np.ones(J),0) - np.diag(np.ones(J-1),1) - np.diag(np.ones(J-1),-1)
#A[0,0] = 1 # pour Neumann homogène : c'est donc à commenter pour retrouver les conditions de Dirichlet homogènes
#A[J-1,J-1] = 1 # pour Neumann homogène : c'est donc à commenter pour retrouver les conditions de Dirichlet homogènes
B = np.diag(np.ones(J),0) + theta*kappa*dt/dx/dx*A # matrice pour la partie implicite
C = np.diag(np.ones(J),0) + (theta-1.)*kappa*dt/dx/dx*A # matrice pour la partie explicite
U0 = u0(X)
U = u0(X[1:J+1])
t = 0.
n = 0
k = 0

plt.grid('on')
plt.plot(X,np.hstack((0,U,0)))
plt.grid()
plt.pause(dt)

et = np.dot(np.transpose(U),U)/J
tt = np.array([0])

while t<T:
	plt.clf()
	print("itération n = ", n, "temps t = ", t)
	dt = min(dt,T-t)
	F = f(t,X)
	if t + dt<= T: # ça c'est pour tous les pas de temps sauf le dernier
		F = (1 - theta)*f(t,X) + theta*f(t + dt,X) # terme source
		t = t + dt
		n = n + 1
		U = npl.solve(B,np.dot(C,U) + dt*F)
		e = np.dot(np.transpose(U),U)/J
		et = np.vstack((et,e))
		tt = np.vstack((tt,t))
		titre = "Solution au temps t = " + str(t)
		plt.plot(X,U0,label="Donnée initiale")
		plt.legend()
		plt.plot(X,np.hstack((0,U,0)),label=titre) # pour Dirichlet homogène
		#plt.plot(X,np.hstack((U[0],U,U[J-1])),label=titre) # pour Neumann homogène
		plt.legend()
		plt.grid()
		plt.pause(dt)
	else: # ça c'est pour le dernier pas de temps
		ds = T-dt
		F = (1 - theta)*f(t,X) + theta*f(t + ds,X) # terme source 
		t = T
		n = n + 1
		Bs = np.diag(np.ones(J),0) + theta*kappa*ds/dx/dx*A # matrice pour la partie implicite
		Cs = np.diag(np.ones(J),0) + (theta-1.)*kappa*ds/dx/dx*A # matrice pour la partie explicite 
		U = npl.solve(Bs,np.dot(Cs,U) + ds*F)
 
titre = "Solution au temps final, T = " + str(T)
plt.plot(X,U0,label="Donnée initiale")
plt.legend()
plt.plot(X,np.hstack((0,U,0)),label=titre) # pour Dirichlet homogène
#plt.plot(X,np.hstack((U[0],U,U[J-1])),label=titre) # pour Neumann homogène
plt.legend()
plt.grid()
plt.show()

plt.plot(tt,et,label="Énergie au cours du temps")
plt.legend()
plt.grid()
plt.show()
#############################
