#!/usr/bin/env python
"""
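Simple finite-difference solver for the 1-D wave equation
u_tt = c**2 u_xx + f(x, t) on [0, L], with u = 0 at both ends
and zero initial velocity.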
"""
from numpy import *
import Gnuplot, Gnuplot.funcutils
import time



def solver(I, f, c, L, n, dt, tstop, user_action=None):
    """
    Solve u_tt = c**2 u_xx + f(x, t) on [0, L] with u = 0 at x = 0 and
    x = L, initial displacement u(x, 0) = I(x) and zero initial velocity,
    using the standard explicit centered finite-difference scheme.

    I           -- initial displacement; called once as I(x) with the whole
                   grid array, so it must accept a numpy array
    f           -- source term; called as f(x, t) with the array of interior
                   grid points (a scalar return value is broadcast)
    c           -- wave speed
    L           -- length of the spatial domain [0, L]
    n           -- number of cells; the grid has n+1 points
    dt          -- time step; if dt <= 0 the largest stable step dx/c
                   (Courant number 1) is used
    tstop       -- time stepping continues while t < tstop
    user_action -- optional callback user_action(u, x, t), called after
                   every time step with the solution u at time t.  The
                   solver recycles its three arrays, so copy u to keep it.

    Returns (dt, x, cpu): the time step actually used, the grid points and
    the CPU time spent in the solver (seconds).  The solution itself is
    only available through user_action.
    """
    # time.clock was removed in Python 3.8; prefer process_time when present.
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t0 = cpu_clock()

    x = linspace(0, L, n+1)   # grid points in x direction
    dx = L/float(n)
    if dt <= 0:
        dt = dx/float(c)      # Courant number 1: largest stable time step
    C2 = (c*dt/dx)**2         # squared Courant number
    dt2 = dt*dt

    up = zeros(n+1)   # solution at time level n+1 (u-plus)
    u = up.copy()     # solution at time level n   (u)
    um = up.copy()    # solution at time level n-1 (u-minus)

    # Initial condition on ALL n+1 grid points: u[n] is a stencil neighbour
    # of u[n-1], so it must not be silently left at zero.
    t = 0.0
    u[:] = I(x)
    # Special first step.  With zero initial velocity u^{-1} = u^1, and the
    # general scheme reduces to
    #   u^1 = u^0 + 0.5*C2*(u_{i-1} - 2u_i + u_{i+1}) + 0.5*dt^2*f.
    # Storing that value in um makes the first loop pass yield exactly u^1.
    um[1:n] = (u[1:n] + 0.5*C2*(u[0:n-1] - 2*u[1:n] + u[2:n+1])
               + 0.5*dt2*f(x[1:n], t))
    um[0] = 0; um[n] = 0   # boundary conditions

    while t < tstop:
        t_old = t
        t += dt
        # centered differences in time and space, interior points only
        up[1:n] = (- um[1:n] + 2*u[1:n]
                   + C2*(u[0:n-1] - 2*u[1:n] + u[2:n+1])
                   + dt2*f(x[1:n], t_old))
        up[0] = 0; up[n] = 0   # boundary conditions
        um, u, up = u, up, um  # rotate references; no data is copied
        if user_action is not None:
            # After the rotation the newest solution lives in u; up now
            # holds the oldest level and is overwritten on the next step.
            user_action(u, x, t)

    t1 = cpu_clock()
    return dt, x, t1 - t0

# ----------------  end of solver  -------------------------------------------



def test_solver1(N):
    """
    Very simple test case: a sine initial profile with no source term.

    The solution is stored (and plotted with Gnuplot) every N time levels,
    and the CPU time reported by the solver is printed.

    N -- store/plot interval in time levels (N=1 keeps every level)

    Returns the list of stored solution arrays (private copies).
    """
    n = 100; tstop = 5*pi; L = 10
    g = Gnuplot.Gnuplot()

    def I(x):
        return sin(2*x*pi/L)

    def f(x, t):
        return 0

    solutions = []
    # A one-element list instead of a module-level global: the closure
    # below can mutate it without rebinding a name, which also works on
    # Python 2 where `nonlocal` is unavailable.
    time_level_counter = [0]

    def action(u, x, t):
        # The solver exposes its solution only through this callback.
        if time_level_counter[0] % N == 0:   # only every N time levels
            y = u.copy()   # the solver recycles u, so keep a private copy
            solutions.append(y)
            g('set ylabel "y-axis [arb. units]"')   # raw gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            g('set timestamp "%a/%b/%d/%y %H:%M:%S"')
            g('show time')
            g.plot(y, yrange=[-1, 1], title='this should move')
            time.sleep(0.05)   # pause so each frame can be seen
            g.reset()
        time_level_counter[0] += 1

    # The solver does not return the solution (action collects it),
    # but it does return the CPU time it used.
    dt, x, cpu = solver(I, f, 1.0, L, n, 0, tstop, user_action=action)
    # Single parenthesized argument: same output on Python 2 and 3.
    print('CPU time: %s  sec' % cpu)
    return solutions

# ---------------------------------  end test_solver1  ----------------------------


def main():
    """
    Run test_solver1 keeping every time level and write the collected
    solutions to 1dwave-sol.dat as str() of the list of arrays.
    """
    solutions = test_solver1(1)   # this is where the solver runs
    # Open the file only after a successful run, and let `with` close it
    # even if the write raises.
    # NOTE(review): str() of numpy arrays is a human-readable dump, and
    # arrays longer than numpy's print threshold (1000 by default) are
    # summarized with '...'; consider numpy.savetxt if the data is reread.
    with open("1dwave-sol.dat", "w") as output:
        output.write(str(solutions))


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
