#!/usr/bin/env python
"""
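Solve the linear convection equation

    u_t + a u_x = 0

using a flux-limited form of the Lax-Wendroff scheme.  Without the limiter
the Lax-Wendroff update is

    u_i^{n+1} = u_i^n - (s/2) (u_{i+1}^n - u_{i-1}^n) + (s^2/2) (u_{i+1}^n - 2 u_i^n + u_{i-1}^n),

where s = a dt/dx is the Courant (CFL) number.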

See p. 307 in Hirsch, vol. 1

Author: Bob Wimmer

Date: March 13, 2010

"""
import numpy as np
import Gnuplot, Gnuplot.funcutils
import time

#global definitions
def Gaussian(x, t, sigma):
    """Evaluate an unnormalised Gaussian bump whose centre is shifted by ``t``.

    x     : evaluation point(s), scalar or numpy array
    t     : shift of the centre (the "time shift" of a translating profile)
    sigma : standard deviation, i.e. the width of the bump

    Returns exp(-(x - t)^2 / (2 sigma^2)), which peaks at 1 when x == t.
    """
    offset = x - t
    variance = sigma**2
    exponent = -offset**2 / (2 * variance)
    return np.exp(exponent)

def psi(r, alpha=0.5):
    """Flux limiter psi(r) of the ALFA family.

        psi(r) = max(0, min(2 r, alpha r + (1 - alpha), 2))

    With alpha = 0.5 the middle term is (1 + r)/2, which is the MUSCL
    (monotonized central) limiter.  For every alpha, psi(1) = 1.

    Parameters
    ----------
    r : float or array_like
        Ratio of consecutive solution gradients.
    alpha : float, optional
        Family parameter; the default 0.5 selects MUSCL.

    Returns
    -------
    ndarray
        Limiter values with the same shape as ``r``.
    """
    r = np.asarray(r, dtype=float)
    # Alternative limiters (swap in for the return below):
    #   van Leer:    (r + np.fabs(r))/(1. + r)
    #   van Albada:  (r*r + r)/(1. + r*r)
    #   superbee:    np.maximum(0., np.maximum(np.minimum(2.*r, 1.), np.minimum(r, 2.)))
    #   no limiter (plain Lax-Wendroff): np.ones_like(r)
    #   ultrabee depends on sigma -- TODO re-derive its formula before use.
    #
    # np.minimum / np.maximum are binary ufuncs.  A third positional argument
    # is the *output* array, not another operand, so a three-way minimum must
    # be nested.  The middle term alpha*r + (1 - alpha) makes alpha = 0.5 give
    # MUSCL and keeps psi(1) = 1.
    return np.maximum(0., np.minimum(np.minimum(2.*r, alpha*r + (1. - alpha)), 2.))


N = 1500          # Number of spatial points.
n_plot = 1        # Number of time steps between plot updates.
dx = 1.0e0        # Spatial resolution (grid spacing).
x = dx*np.arange(N)   # Spatial axis: N points exactly dx apart, starting at 0.
a = 1.            # Exact group/phase velocity of the linear wave.
sigma = 0.5       # Courant (CFL) number a*dt/dx.
dt = sigma/a*dx
print('dt =  %s' % dt)
tstop = 1.*N*dt   # Final model time (N time steps of size dt).

#  Direct index assignment is MUCH faster than using a spatial FOR loop, so
#  these index lists are used in the update equations.  Remember that Python
#  uses zero-based indexing.
IDX1 = range(1, N-1)    # u[i]   (interior points)
IDX2 = range(2, N)      # u[i+1]
IDX3 = range(0, N-2)    # u[i-1]
PA = 0                  # Past
PR = 1                  # Present
FU = 2                  # Future
u = np.zeros((3, N))    # one row each for past, present, and future
r = np.zeros(N)         # ratio of gradients: forward / backward difference
R = np.zeros(N)         # inverse ratio: backward / forward difference

# Initial condition: square pulse of height 1 on [N/4, N/2), zero elsewhere.
u[PR, N//4:N//2] = 1.
# Alternative initial condition (sine wave on the first quarter):
#k = 2.*np.pi/200
#xn = x[0:N//4]/dx
#u[PR, 0:N//4] = np.sin(k*xn)

# Von Neumann analysis of the unlimited Lax-Wendroff scheme for the shortest
# resolvable wave, k = pi/dx (phi = k dx = pi).
k_min = np.pi/dx
phi = k_min*dx
print('k delta x = phi = %s, k = k_min = pi/dx = %s' % (phi, k_min))
print('Phi_tilde = sigma * phi = a k delta t = %s' % (a*k_min*dt))
# Relative phase error epsphi = Phi/Phi_tilde, where
#   tan(Phi) = sigma sin(phi) / (1 - 2 sigma^2 sin^2(phi/2)).
# arctan2 keeps the correct quadrant if the denominator turns negative
# (possible for sigma > 1/sqrt(2)); otherwise it equals arctan of the ratio.
epsphi = np.arctan2(sigma*np.sin(phi),
                    1. - 2.*sigma**2*np.sin(phi/2)**2)/(sigma*phi)
print('Phi = %s, epsphi = %s' % (epsphi*a*k_min*dt, epsphi))
# Tied to N instead of a hard-coded 1500 so it follows the grid size.
print('a_num = a * %s, (a - a_num)*%d = %s' % (epsphi, N, a*(1. - epsphi)*N))
# Modulus of the Lax-Wendroff amplification factor, and its N-th power.
G = np.sqrt(1. - 4.*sigma**2*(1. - sigma**2)*np.sin(phi/2)**4)
print('G = %s, G^%d = %s' % (G, N, G**N))

# Scratch buffers for the forward (r1) and backward (r2) differences used by
# the solver; zero-initialised so the never-written end points hold no garbage.
r1 = np.zeros(N)
r2 = np.zeros(N)


def solver(tstop, user_action=None):
    """Advance the flux-limited Lax-Wendroff scheme from t = 0 to ``tstop``.

    Works in place on the module-level arrays:
    - ``u``: time levels PA/PR/FU;
    - ``r``/``R``: gradient ratios;
    - ``r1``/``r2``: difference buffers.
    It also reads the module-level ``sigma``, ``dt``, ``x``, ``psi`` and the
    index lists IDX1/IDX2/IDX3.  Only the interior points IDX1 of u[FU] are
    written.

    Parameters
    ----------
    tstop : float
        Model time to stop at; steps of size dt are taken while t < tstop.
    user_action : callable or None
        If given, called as ``user_action(u[PR], x, t)`` after every step.

    Returns
    -------
    (cpu_seconds, t)
        CPU time spent in the solver and the final model time.  The solution
        itself is not returned; it stays in the module-level ``u``.
    """
    # time.clock was removed in Python 3.8; process_time replaces it.
    clock = getattr(time, 'process_time', None) or time.clock
    t0 = clock()

    # Coefficient of the limited anti-diffusive terms, constant for the run.
    c = 0.5*sigma*(1. - sigma)

    t = 0.
    while t < tstop:
        # Forward (u_{i+1} - u_i) and backward (u_i - u_{i-1}) differences.
        r1[IDX1] = u[PR, IDX2] - u[PR, IDX1]
        r2[IDX1] = u[PR, IDX1] - u[PR, IDX3]
        fwd = r1[IDX1]
        bwd = r2[IDX1]

        # Gradient ratios r = fwd/bwd and R = bwd/fwd, vectorized.  Zero
        # differences are special-cased:
        # - both zero gives 1;
        # - a zero denominator gives a large value (100);
        # - a zero numerator gives 0.
        fwd_nz = fwd != 0.
        bwd_nz = bwd != 0.
        both = fwd_nz & bwd_nz
        neither = ~(fwd_nz | bwd_nz)
        # Divide by 1 where the denominator is zero so no warning is raised;
        # those entries are replaced by the special-case values below.
        ratio = fwd/np.where(bwd_nz, bwd, 1.)
        inv_ratio = bwd/np.where(fwd_nz, fwd, 1.)
        r[IDX1] = np.where(both, ratio,
                           np.where(neither, 1., np.where(fwd_nz, 100., 0.)))
        R[IDX1] = np.where(both, inv_ratio,
                           np.where(neither, 1., np.where(fwd_nz, 0., 100.)))
        r[IDX1] = np.nan_to_num(r[IDX1])
        R[IDX1] = np.nan_to_num(R[IDX1])

        # First-order upwind step plus limited anti-diffusive corrections.
        # With psi == 1 this is exactly the plain Lax-Wendroff scheme
        #   u_i - sigma/2 (u_{i+1} - u_{i-1}) + sigma^2/2 (u_{i+1} - 2u_i + u_{i-1}).
        u[FU, IDX1] = (u[PR, IDX1] - sigma*(u[PR, IDX1] - u[PR, IDX3])
                       - c*psi(R[IDX1])*(u[PR, IDX2] - u[PR, IDX1])
                       + c*psi(R[IDX3])*(u[PR, IDX1] - u[PR, IDX3]))

        # Rotate time levels (row assignment copies the data).
        u[PA] = u[PR]
        u[PR] = u[FU]
        t += dt
        if user_action is not None:
            user_action(u[PR], x, t)   # hook for plotting / diagnostics

    return clock() - t0, t

# ----------------  end of solver  -------------------------------------------



def test_solver(n_plot):
    """Run the solver and plot numerical vs. analytic solution with Gnuplot.

    The plot is refreshed every ``n_plot`` time steps.  The analytic solution
    is the initial square pulse translated by a*t/dx grid cells with np.roll,
    so it wraps around periodically.  Prints the CPU time used by the solver
    and the final model time.

    Returns the module-level solution array ``u`` (all three time levels).
    """
    g = Gnuplot.Gnuplot()  # depending on your system, you may need persist=1
    # Initial pulse, built once; shifted copies of it are the analytic solution.
    an0 = np.zeros(N)
    #an0[0:N//4] = np.sin(k*x[0:N//4]/dx)
    an0[N//4:N//2] = 1.
    # Step counter shared with the callback.  A one-element list lets the
    # nested function update it without a global (Python 2 has no nonlocal).
    step = [0]

    def action(u, x, t):  # called by the solver after every step
        if step[0] % n_plot == 0:
            g.reset()
            g('set ylabel "y-axis [arb. units]"')  # raw gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            g('set yrange [-0.5:1.2]')
            # Note: t/dt is the step number, not the model time.
            g('set title "time: ' + str(t/dt) + '"')
            # Exact solution: translate the pulse by a*t, counted in grid cells.
            shift = int(a*t/dx)
            an = np.roll(an0, shift)
            data_u = Gnuplot.Data(x, u, using=(1, 2), with_='line', title='numerical')
            data_an = Gnuplot.Data(x, an, using=(1, 2), with_='line', title='analytic')
            g.plot(data_u, data_an)
        step[0] += 1

    # The solver returns CPU time and final model time, not the solution.
    cpu, t = solver(tstop, user_action=action)
    print('CPU time: %s  sec' % cpu)
    print('time:  %s' % t)
    return u

# ---------------------------------  end test_solver  -----------------------------


def main():
    """Run the convection test case, then wait for Enter before exiting."""
    # test_solver returns the solution array u; it is currently not saved.
    test_solver(n_plot)
    # raw_input exists only in Python 2; Python 3's input() is the equivalent.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    read_line('Please press the return key to continue...\n')


if __name__ == '__main__':
    main()
