#! /usr/bin/env python

from numpy import *
from Gnuplot import Gnuplot, GnuplotOpts, Data, funcutils

"""Solve the time-dependent 1-D diffusion equation with an explicit finite-difference scheme."""


def bound(location, T):  # boundary conditions
    """Return the fixed boundary value for *location*.

    A location equal to 0 gets +1; any other location gets -1.
    *T* (the time) is accepted so callers can pass it, but it is not used:
    the boundary values here do not depend on time.
    """
    return 1. if location == 0. else -1.

def march(L, n, lo, hi, dt, alpha, T, u):  # the actual solver
    """Advance the 1-D diffusion equation by one explicit time step.

    Uses the forward-time, centred-space (FTCS) finite-difference update on
    a uniform grid of ``n + 1`` points spanning ``[0, L]``.  This explicit
    scheme is only stable while ``alpha * dt / dx**2 <= 0.5``.

    Parameters
    ----------
    L : float
        Length of the domain.
    n : int
        Number of grid intervals (the grid has ``n + 1`` points).
    lo, hi : float
        Locations passed to ``bound`` to obtain the left and right
        boundary values.
    dt : float
        Time step.
    alpha : float
        Diffusion (heat conduction) coefficient.
    T : float
        Current time, forwarded to ``bound``.
    u : array of length ``n + 1``
        Solution at the current time.  It is read but not modified.

    Returns
    -------
    up : ndarray
        Solution after one step.  Interior points come from the FTCS update
        and the two end points from ``bound``.
    x : ndarray
        Grid-point coordinates, ``linspace(0, L, n + 1)``.
    """
    x = linspace(0, L, n + 1)  # grid points in x direction
    dx = float(L) / n
    r = alpha * dt / dx / dx  # mesh ratio; FTCS needs r <= 1/2 for stability
    up = zeros(n + 1)
    # Interior points: u_i + r * (u_{i+1} - 2*u_i + u_{i-1})
    up[1:n] = u[1:n] + r * (u[2:n + 1] - 2 * u[1:n] + u[0:n - 1])
    up[0] = bound(lo, T)
    up[n] = bound(hi, T)
    # The caller must rebind its own array to `up`; the original
    # `u = up` here only rebound a local name and had no effect.
    return up, x


# When executed as a script, integrate the equation and animate the solution.
if __name__ == '__main__':
    import time

    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input      # Python 3 renamed raw_input to input

    L = 1.            # length of the domain
    n = 40            # number of grid intervals (n + 1 grid points)
    alpha = 5.        # heat conduction coefficient
    dt = 0.00005      # time increment
    t_end = 0.0175    # stop once the simulated time reaches this value
    lo = 0.           # location of the left boundary (passed to bound)
    hi = 1.           # location of the right boundary (passed to bound)
    # With these values alpha*dt/dx**2 = 0.4, inside the explicit scheme's
    # stability limit of 0.5.

    # Initial condition: +1 on the left half, -1 on the right half, with both
    # end points matching the boundary values that march() imposes.
    # (Previously u[n] started at +1 although the right boundary is -1.)
    u = ones(n + 1)
    u[n // 2:n] = -1.  # floor division keeps the slice index an int on Python 3
    u[0] = bound(lo, 0.)
    u[n] = bound(hi, 0.)

    g = Gnuplot()
    g.reset()  # must be called; a bare `g.reset` does nothing
    # These settings never change, so send them once instead of every frame.
    for command in (
        'set style data lines',  # replaces the obsolete 'set data style'
        'set border lw 2',
        'set xlabel font "Helvetica, 18"',
        'set ylabel font "Helvetica, 18"',
        'set xtics font "Helvetica, 18"',
        'set ytics font "Helvetica, 18"',
        'set xtics out',
        'set ytics out',
        'set xlabel "Ort [a.u.]"',
        'set ylabel "B [a.u.]"',
        'set yrange [-1.5:1.5]',
        'set title font "Helvetica, 18"',
    ):
        g(command)

    step = 0
    T = 0.
    while T < t_end:
        up, x = march(L, n, lo, hi, dt, alpha, T, u)  # call solver
        step += 1
        # `up` is the solution at the new time.  Recompute that time from the
        # step count instead of summing dt, so rounding errors do not pile up.
        T = step * dt
        g('set title "time: {0:.1f} [a.u.]"'.format(T * 100000))
        # Plot against the grid coordinates so the x axis shows position.
        g.plot(Data(x, up, with_='line lw 2 lc rgb "red"'))
        time.sleep(0.05)  # slow the animation down so it can be followed
        u = up            # update solution
    read_line('Please press the return key to continue...\n')
   
