#! /usr/bin/env python

from numpy import *
import Gnuplot#, Gnuplot.funcutils
from numpy import arange, eye, linalg, ones

"""Try to solve a diffusion problem numerically"""


    

def bound(location,T):
    """Return the Dirichlet boundary value of u at ``location``.

    The left end (``location == 0``) is held at 1 and every other location
    at 0.  ``T`` (the current time) is accepted so that time-dependent
    boundary conditions can be plugged in -- e.g. ``1. - 0.5*sin(64*pi*T)``
    on the left and ``0.5 + 0.3*sin(32*pi*T)`` on the right were tried
    before -- but the current values do not depend on it.
    """
    if location == 0.:
        return 1.
    return 0.


def march(L,n,lo,hi,dt,alpha,T,u):
    """Advance the 1-D heat equation by one explicit (FTCS) Euler step.

    Interior nodes are updated with

        u_i <- u_i + r (u_{i+1} - 2 u_i + u_{i-1}),   r = alpha dt / dx^2,

    and the end points are set to ``bound(lo, T)`` and ``bound(hi, T)``.

    The explicit scheme is only stable for r <= 1/2.  With the script's
    default parameters (L = 1, n = 40, alpha = 5, dt = 0.15), r = 1200,
    so this step blows up there.

    Parameters: L domain length, n number of grid intervals, lo/hi the
    locations handed to ``bound``, dt time step, alpha diffusion
    coefficient, T current time, u solution array of length n+1 (it is
    read, not modified).

    Returns (up, x): the new solution and the grid coordinates, both of
    length n+1.
    """
    x = linspace(0, L, n + 1)  # grid points in x direction
    dx = float(L) / float(n)
    r = alpha * dt / dx / dx  # mesh ratio; must be <= 0.5 for stability
    up = zeros(n + 1)
    up[1:n] = u[1:n] + r * (u[2:n + 1] - 2 * u[1:n] + u[0:n - 1])
    up[0] = bound(lo, T)
    up[n] = bound(hi, T)
    return up, x

def implicit(L,n,lo,hi,dt,alpha,T,u,verbose=False):
    r"""Advance the 1-D heat equation by one Crank-Nicolson step.

    For every interior node i (1 <= i <= n-1) the difference equation is

       A u^{n+1}_{i-1} - B u^{n+1}_i + A u^{n+1}_{i+1}
           = -u^n_i - A (u^n_{i+1} - 2 u^n_i + u^n_{i-1}),

    where

       A = \frac{\alpha \Delta t}{2 \Delta x^2},   B = 1 + 2A.

    The boundary nodes u_0 and u_n are fixed to the Dirichlet values
    ``bound(lo, T)`` and ``bound(hi, T)``.  Their contribution to the first
    and last interior equations is moved to the right-hand side, so only the
    n-1 interior values are solved for and the boundary values are returned
    exactly.

    Parameters
    ----------
    L : float
        Domain length; the grid is ``linspace(0, L, n+1)``.
    n : int
        Number of grid intervals (n+1 grid points).
    lo, hi : float
        Locations passed to ``bound`` for the left/right boundary values.
    dt : float
        Time step.
    alpha : float
        Diffusion (heat conduction) coefficient.
    T : float
        Current time, passed to ``bound``.
    u : array of length n+1
        Solution at the current time level (read, not modified).
    verbose : bool, optional
        If true, print the boundary values, the system matrix and the
        right-hand side.  Defaults to False.

    Returns
    -------
    up : ndarray of length n+1
        Solution at the next time level, boundary values included.
    x : ndarray of length n+1
        Grid point coordinates.
    """
    x = linspace(0, L, n + 1)  # grid points in x direction
    dx = float(L) / float(n)
    A = alpha * dt / 2. / dx / dx
    B = 1. + 2. * A

    up = zeros(n + 1)
    up[0] = bound(lo, T)
    up[n] = bound(hi, T)
    if verbose:
        print("boundary conditions: u(lo) = " + str(up[0]) + ", u(hi) = " + str(up[n]))

    m = n - 1  # number of interior unknowns u_1 .. u_{n-1}
    if m < 1:
        return up, x  # no interior nodes to solve for

    # Tridiagonal system: -B on the diagonal, A on both off-diagonals.
    # Only in-range index pairs are written, so nothing wraps around to the
    # opposite corner of the matrix (negative indices would).
    mat = -B * eye(m)
    k = arange(m - 1)
    mat[k, k + 1] = A
    mat[k + 1, k] = A

    # Right-hand side: the explicit half of the scheme at the interior nodes.
    K = zeros(m)
    K[:] = -u[1:n] - A * (u[2:n + 1] - 2 * u[1:n] + u[0:n - 1])
    # The new-level boundary values are known, so move A*u_0 and A*u_n
    # from the left-hand side of the first/last equation to the right.
    K[0] -= A * up[0]
    K[-1] -= A * up[n]
    if verbose:
        print(mat)
        print(K)

    up[1:n] = linalg.solve(mat, K)
    return up, x
  

# when executed, just run:
if __name__ == '__main__':
    import time

    L = 1.
    n = 40
    alpha = 5.   # heat conduction coefficient
    dt = 1.5e-1  # time increment
    frame_delay = 0.1  # minimum wall-clock seconds per animation frame

    g = Gnuplot.Gnuplot()
    lo = 0.
    hi = 1.
    # Alternative initial condition matching the boundary values:
    #   u = linspace(bound(lo,0.), bound(hi,0.),n+1)
    u = linspace(0., 1.0, n + 1)
    g.plot(u)
    raw_input('Please press the return key to continue...\n')
    T = 0.
    while T < 1.:
        t0 = time.time()
        # Use march(L,n,lo,hi,dt,alpha,T,u) instead for the explicit scheme.
        up, x = implicit(L, n, lo, hi, dt, alpha, T, u)
        # Sleep out the rest of the frame instead of busy-waiting, so the
        # animation advances at most once per frame_delay seconds without
        # spinning the CPU.
        remaining = t0 + frame_delay - time.time()
        if remaining > 0:
            time.sleep(remaining)
        # g.plot() starts a fresh plot, replacing the previous curve.
        g.plot(up)
        u = up
        T += dt
    raw_input('Please press the return key to continue...\n')
   
