#!/usr/bin/env python
"""
Solve linear convection equation

    u_t + a u_x = 0

using the Lax-Wendroff scheme

u_i^{n+1} = u_i^n - (s/2) (u_{i+1}^n - u_{i-1}^n) + (s^2/2) (u_{i+1}^n - 2 u_i^n + u_{i-1}^n),

where s = a dt/dx is the Courant number (called sigma in the code).

See p. 307 in Hirsch, vol. 1

Author: Bob Wimmer

Date: March 13, 2010

"""
import numpy as np
import Gnuplot, Gnuplot.funcutils
import time

#global definitions
def Gaussian(x, t, sigma):
    """Return the unnormalized Gaussian exp(-(x - t)**2 / (2 sigma**2)).

    x     : point(s) at which to evaluate; a scalar or a numpy array.
    t     : location of the peak (the curve is shifted right by t).
    sigma : standard deviation, i.e. the width of the peak.

    The peak value is 1 at x == t.
    """
    offset = x - t
    two_variance = 2 * sigma**2
    return np.exp(-offset**2 / two_variance)



N = 1500        # Number of spatial points.
n_plot = 5      # Number of time steps to advance between plot updates.
dx = 1.0e0      # Spatial resolution (grid spacing).
# Spatial axis: N points spaced exactly dx apart.  (np.linspace(0, N, N) would
# give a spacing of dx*N/(N-1), inconsistent with the dx used for sigma and dt.)
x = dx * np.arange(N)
a = 1.          # Exact group velocity / phase velocity for a linear wave.
sigma = 0.5     # Courant number s = a dt/dx of the Lax-Wendroff scheme.
dt = sigma / a * dx
print('dt = %s' % dt)
tstop = N * dt  # Final model time, i.e. N time steps of size dt.

# Index lists for the vectorised stencil update (much faster than a Python
# loop over grid points).  Indexing is zero-based; the end points 0 and N-1
# are not part of the interior update.
IDX1 = list(range(1, N - 1))    # u[i]
IDX2 = list(range(2, N))        # u[i+1]
IDX3 = list(range(0, N - 2))    # u[i-1]
PA = 0                          # Past
PR = 1                          # Present
FU = 2                          # Future
u = np.zeros((3, N))            # One row each for the past, present and future levels.

# Initial condition: a sine wave of wavelength 200 grid cells on the first
# quarter of the domain, zero elsewhere.
# (Alternative step profile: u[PR, :N//2] = 1.; u[PR, N//2:] = 0.)
k = 2. * np.pi / 200
xn = x[:N // 4] / dx
u[PR, :N // 4] = np.sin(k * xn)
an = u[PR].copy()               # Initial profile, i.e. the analytic solution at t = 0.

# Fourier (von Neumann) analysis of the Lax-Wendroff scheme for this wave.
phi = k * dx                    # Phase angle per grid cell.
k_min = np.pi / dx              # NOTE: pi/dx is the largest wavenumber the grid resolves.
print('k delta x = phi = %s, k = %s, k_min = pi/dx = %s' % (phi, k, k_min))
print('Phi_tilde = sigma * phi = a k delta t = %s' % (a * k * dt))
# Ratio of numerical to exact phase speed.  arctan2 stays on the correct branch
# when the denominator turns negative (sigma > 1/sqrt(2) with phi near pi);
# for the present parameters it equals the plain arctan of the quotient.
epsphi = np.arctan2(sigma * np.sin(phi),
                    1. - 2. * sigma**2 * (np.sin(phi / 2))**2) / (sigma * phi)
print('a_num = a * %s, (a - a_num)*%d = %s' % (epsphi, N, a * (1. - epsphi) * N))
# Modulus of the amplification factor per step, and the damping after N steps.
G = np.sqrt(1. - 4. * sigma**2 * (1. - sigma**2) * (np.sin(phi / 2))**4)
print('G = %s, G^%d = %s' % (G, N, G**N))

def solver(tstop, user_action=None):
    """Advance the module-level array ``u`` with the Lax-Wendroff scheme.

    Parameters
    ----------
    tstop : float
        Model time at which to stop; steps of size ``dt`` are taken while the
        current model time is below ``tstop``.
    user_action : callable or None
        If given, called as ``user_action(u[PR], x, t)`` after every step
        with the present time level, the spatial axis and the model time.

    Returns
    -------
    (cpu, t)
        CPU time in seconds spent in the solver, and the model time reached.
        The solution itself is NOT returned; it is left in ``u`` (row PR
        holds the latest time level).
    """
    # time.clock was removed in Python 3.8; process_time is its CPU-time
    # replacement.  Fall back to time.clock on Python 2.
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t0 = cpu_clock()

    n_steps = 0
    t = 0.
    while t < tstop:
        # Lax-Wendroff update of the interior points.  The end points of
        # u[FU] are never written here and keep their initial values.
        u[FU, IDX1] = (u[PR, IDX1]
                       - sigma / 2 * (u[PR, IDX2] - u[PR, IDX3])
                       + sigma**2 / 2 * (u[PR, IDX2] - 2 * u[PR, IDX1] + u[PR, IDX3]))

        # Shift time levels (row assignment copies the data).
        u[PA] = u[PR]
        u[PR] = u[FU]
        # Compute t from the step count instead of accumulating t += dt, so
        # round-off in dt cannot add a spurious extra step.
        n_steps += 1
        t = n_steps * dt
        if user_action is not None:
            user_action(u[PR], x, t)

    return cpu_clock() - t0, t

# ----------------  end of solver  -------------------------------------------



def test_solver(n_plot):
    """
    Run the solver and plot the numerical and analytic solutions with Gnuplot.

    Every ``n_plot``-th time level the current numerical solution is plotted
    together with the analytic solution (the initial sine train shifted right
    by a*t).  Afterwards the CPU time used by the solver and the final model
    time are printed.

    Returns the module-level solution array ``u``.  The solver does not
    return it but updates it in place; the callback is how the solution is
    accessed during the run.
    """
    g = Gnuplot.Gnuplot()  # depending on your system, you may need persist=1 here
    # A one-element list lets the nested callback update the counter without
    # a module-level ``global`` (``nonlocal`` does not exist in Python 2).
    time_level_counter = [0]

    def action(u, x, t):  # called by the solver after every step
        if time_level_counter[0] % n_plot == 0:  # only plot every n_plot-th level
            g.reset()
            g('set ylabel "y-axis [arb. units]"')  # raw gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            g('set yrange [-1:1]')
            # Analytic solution: the initial profile translated by a*t, i.e. by
            # a*t/dx grid cells (truncated to a whole number of cells).
            # np.roll is periodic, so this assumes the train has not yet
            # reached the right end of the domain.
            n_init = N // 4
            shift = int(a * t / dx)
            an = np.zeros(N)
            an[:n_init] = np.sin(k * x[:n_init] / dx)
            an = np.roll(an, shift)
            data_u = Gnuplot.Data(x, u, using=(1, 2), with_='line', title='numerical')
            data_an = Gnuplot.Data(x, an, using=(1, 2), with_='line', title='analytic')
            g.plot(data_u, data_an)
        time_level_counter[0] += 1

    # The solver returns only the CPU time and the final model time, not u.
    cpu, t = solver(tstop, user_action=action)
    print('CPU time: %s sec' % cpu)
    print('time: %s' % t)
    return u

# ---------------------------------  end test_solver  -----------------------------


def main():
    """Run and plot the simulation, then wait for the user to press return."""
    test_solver(n_plot)  # the returned solution array is currently unused
    # raw_input exists only in Python 2; Python 3 renamed it to input.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    # Presumably keeps the Gnuplot window on screen until the user is done.
    read_line('Please press the return key to continue...\n')
    

# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
