#!/usr/bin/env python
"""
Solve the non-linear inviscid Burgers equation as an example of a non-linear PDE that mimics the non-linear (convective) term in the Navier-Stokes equations.

    u_t + u u_x = 0

using the following schemes:
- FOU
- Lax
- Lax-Friedrichs
- Lax-Wendroff
- Warming and Beam 
- MacCormack 

Many of these schemes do not work. This motivates the next lecture in which we will look at chapter 9 of Hirsch on time integration methods.

The schemes are defined by: 


Lax:
u_i^n+1 = 0.5*(u_i+1^n + u_i-1^n) - dt/dx/4*(u_i+1^n*u_i+1^n - u_i-1^n*u_i-1^n)

Lax-Wendroff: (see below for meaning of IDX_i and PR and FU)
u[FU,IDX1] = u[PR,IDX1] - sigma/2*0.5*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX3]*u[PR,IDX3]) + sigma**2/4*((u[PR,IDX1] + u[PR,IDX2])*0.5*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX1]*u[PR,IDX1]) - (u[PR,IDX1]+u[PR,IDX3])*0.5*(u[PR,IDX1]*u[PR,IDX1] - u[PR,IDX3]*u[PR,IDX3]))

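Warming and Beam (NOTE: as written, the first bracket below, u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX2]*u[PR,IDX2], is identically zero; this looks like a typo):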
u[FU,IDX1] = u[PR,IDX1] - sigma/4*(3*u[PR,IDX1]*u[PR,IDX1] - 4*u[PR,IDX3]*u[PR,IDX3] + u[PR,IDX4]*u[PR,IDX4]) + sigma**2/4*(3*u[PR,IDX1]*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX2]*u[PR,IDX2]) - 4*u[PR,IDX3]*(u[PR,IDX1]*u[PR,IDX1] - u[PR,IDX3]*u[PR,IDX3]) + u[PR,IDX4]*(u[PR,IDX3]*u[PR,IDX3] - u[PR,IDX4]*u[PR,IDX4]))

Lax-Wendroff for the linear convection equation u_t + a u_x = 0 (s = a dt/dx):
u_i^n+1 = u_i^n - s/2 (u_i+1^n - u_i-1^n) + s^2/2 (u_i+1^n - 2u_i^n + u_i-1^n),


Warming and Beam for the linear convection equation (in an unstable implementation):
u_i^n+1 = u_i^n - s/2 (3u_i^n - 4u_i-1^n + u_i-2^n) + s^2/2 (u_i^n - 2u_i-1^n + u_i-2^n),

and MacCormack's method:
ub_i^n+1 = u_i^n - s*(f(u_i^n) - f(u_i-1^n))
u_i^n+1 = 1/2*(ub_i^n+1 + u_i^n) - s/2*(f(ub_i+1^n+1) - f(ub_i^n+1))
where f(u) = 1/2*u^2                      (see p. 458 in Hirsch)
(this is a predictor-corrector method)


See p. 298 ff in Hirsch, vol. 1 for their application to the convection equation. Here, we first need to find the difference equations given above.

In this exercise we do not treat flux limiters; for those, see ex9 in computational hydrodynamics.

Author: Bob Wimmer, wimmer@physik.uni-kiel.de

Date: December 28, 2012

"""
import numpy as np
import Gnuplot, Gnuplot.funcutils
import time

#global definitions
def Gaussian(x,t,sigma):
    """Evaluate an unnormalised Gaussian bell curve.

    x     -- evaluation point(s); a scalar or a numpy array
    t     -- centre of the bell (a time shift when x is a moving coordinate)
    sigma -- standard deviation, i.e. the width of the bell

    Returns exp(-(x - t)**2 / (2*sigma**2)); the peak value at x == t is 1.
    """
    offset = x - t
    variance = sigma**2
    exponent = -offset**2 / (2*variance)
    return np.exp(exponent)



N    =   1600     #  Number of spatial points.
n_plot= 5       #  Number of time steps to increment before updating the plot.
dx   = 1.0e0    #  Spatial resolution
x    = dx*np.linspace(0,N,N)        #  Spatial axis.
a = 1.           #exact group velocity/phase velocity for a linear wave
sigma = 0.9999      #CFL number (generally, although not always, sigma<1 for stability. See below)
dt = sigma/a*dx  #define time step based on CFL number. Time step will change depending on actual speed. 
print 'dt = ', dt
tstop = N/3*dt      #  Number of time steps.  

#  Direct index assignment is MUCH faster than using a spatial FOR loop, so
#  these constants are used in the update equations.  Remember that Python uses
#  zero-based indexing.
IDX1 = range(2,N-1)     #u[i]
IDX2 = range(3,N)       #u[i+1]
IDX3 = range(1,N-2)     #u[i-1]
IDX4 = range(0,N-3)     #u[i-2]
PA = 0                  #Past
PR = 1                  #Present
FU = 2                  #Future
u = np.zeros((3,N))     #one entry for past, present, and future
ub = np.zeros(N)        #u bar, used in MacCormack's method

#Define initial conditions:
k = 1.*np.pi/200
xn = x[range(0,N/4)]/dx
xn1 = x[range(0,int(N/4-a*dt/dx))]/dx
#square wave:
#u[PR,0:N/4] = 1.
#ramp:
#xi = np.ones(N/4)
#xu = np.divide(x[range(0,N/4)]/dx,x[N/4])
#u[PR,0:N/4] = (xi-xu)
#u[PA,0:N/4-a*dt/dx] = 1.
#wave packet:
u[PR,0:N/4] = np.sin(k*xn)
u[PA,0:N/4-a*dt/dx] = np.sin(k*xn1)

#analytival solution (will be propagated in plot routine):
AN = u[PR].copy()
#note: this is not correct for Burgers equation! Need to use try-ana-sol.py for that, but it does not yet work.

phi = k*dx
k_min = np.pi/dx
print 'k delta x = phi =', phi, ', k = ', k, ', k_min = pi/dx = ', k_min
print 'Phi_tilde = sigma * phi = a k delta t = ', a*k*dt
epsphi = np.arctan((sigma*np.sin(phi))/(1.-2.*sigma**2*(np.sin(phi/2))**2))/(sigma*phi)
print 'a_num = a * ',epsphi, ', (a - a_num)*1500 = ',a*(1.-epsphi)*1500
G = np.sqrt(1.-4.*sigma**2*(1.-sigma**2)*(np.sin(phi/2))**4)
print 'G = ',G,', G^',N,' = ', G**N

def solver(tstop,user_action=None):
    """Advance the module-level state u in time with MacCormack's scheme for Burgers' equation.

    tstop       -- model time at which to stop; steps are taken while t < tstop
    user_action -- optional callback, called as user_action(u[PR], x, t)
                   after every time step

    The time step is adapted to the CFL condition, dt = sigma*dx/max|u|,
    and every scheme is written in terms of the mesh ratio r = dt/dx.
    This keeps the discrete update consistent with the model time that
    accumulates in t.

    Returns (cpu_seconds, t): the CPU time spent (time.clock difference)
    and the final model time.  The solution itself is NOT returned; it is
    left in the module-level array u (row PR holds the newest time level).
    """

    def cfl_dt(v):
        # |u| is the local wave speed of Burgers' equation.
        speed = np.abs(v).max()
        if speed == 0:
            # Nothing propagates: fall back to the nominal linear-CFL step
            # instead of dividing by zero.
            return sigma/a*dx
        return sigma/speed*dx

    t0 = time.clock()   #start timer to find out how much CPU time is used in solver

    t = 0.
    dt = cfl_dt(u[PR])  #next dt needs to fulfill CFL condition
    while t < tstop:
        r = dt/dx       # mesh ratio multiplying the flux differences

        # Alternative schemes (uncomment one instead of MacCormack below):
        #FOU non conservative:
        #u[FU,IDX1] = u[PR,IDX1] - r*u[PR,IDX1]*(u[PR,IDX1] - u[PR,IDX3])
        #FOU conservative:
        #u[FU,IDX1] = u[PR,IDX1] - r/2*(u[PR,IDX1]*u[PR,IDX1] - u[PR,IDX3]*u[PR,IDX3])
        #Lax = Lax-Friedrichs (conservative; both names denote this same update):
        #u[FU,IDX1] = 0.5*(u[PR,IDX2] + u[PR,IDX3]) - r/4*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX3]*u[PR,IDX3])
        #Lax-Wendroff (conservative):
        #u[FU,IDX1] = u[PR,IDX1] - r/4*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX3]*u[PR,IDX3]) + r*r/8*((u[PR,IDX1] + u[PR,IDX2])*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX1]*u[PR,IDX1]) - (u[PR,IDX1] + u[PR,IDX3])*(u[PR,IDX1]*u[PR,IDX1] - u[PR,IDX3]*u[PR,IDX3]))
        #Warming and Beam (NOTE: the first r*r bracket below is identically zero as written -- probably a typo):
        #u[FU,IDX1] = u[PR,IDX1] - r/4*(3*u[PR,IDX1]*u[PR,IDX1] - 4*u[PR,IDX3]*u[PR,IDX3] + u[PR,IDX4]*u[PR,IDX4]) + r*r/4*(3*u[PR,IDX1]*(u[PR,IDX2]*u[PR,IDX2] - u[PR,IDX2]*u[PR,IDX2]) - 4*u[PR,IDX3]*(u[PR,IDX1]*u[PR,IDX1] - u[PR,IDX3]*u[PR,IDX3]) + u[PR,IDX4]*(u[PR,IDX3]*u[PR,IDX3] - u[PR,IDX4]*u[PR,IDX4]))

        #MacCormack, with flux f(u) = u^2/2:
        # predictor: backward difference of f.  It is evaluated up to index
        # N-1 because the corrector at i = N-2 reads ub[N-1].
        ub[2:N] = u[PR,2:N] - r/2*(u[PR,2:N]*u[PR,2:N] - u[PR,1:N-1]*u[PR,1:N-1])
        # corrector: forward difference of f(ub).
        u[FU,IDX1] = 0.5*(ub[IDX1] + u[PR,IDX1]) - r/4*(ub[IDX2]*ub[IDX2] - ub[IDX1]*ub[IDX1])

        #shift time levels (row assignments copy the values)
        u[PA] = u[PR]
        u[PR] = u[FU]
        t += dt
        dt = cfl_dt(u[PR])  #next dt needs to fulfill CFL condition
        if user_action is not None:
            user_action(u[PR], x, t)       #I can do whatever I'd like here
    t1 = time.clock()
    return t1-t0, t                     #solver does NOT return solution u!
                                        #But it returns how long it worked (t1-t0) and the current model time

# ----------------  end of solver  -------------------------------------------



def test_solver(n_plot):
    """
    Very simple test case.
    Store the solution at every N time levels.
    Measure how long the solver actually works
    """
    g = Gnuplot.Gnuplot()#depending on your system, you may need to insert persist=1 in the parenthesis
    # Need time_level_counter as global variable since
    # it is assigned in the action function (that makes
    # a variable local to that block otherwise).
    global time_level_counter
    time_level_counter = 0
    t = 0.
    cpu = 0.
    def action(u, x, t): #this is where I can access solution u
        global time_level_counter
        if time_level_counter % n_plot == 0: #only do this every N times
            g.reset()
            g('set ylabel "y-axis [arb. units]"')   #this is how you access gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            string = 'set yrange [-1:1.5 ]'
            g(string)
            N0 = int(a*t/dx)
            string = 'elapsed time: ' + str(t) + ', i.e., ' + str(t/dt) + ' CFL steps'
            str2 = 'set title "' + string + '"'
            g(str2)
            xn = x[range(0,N/4)]/dx
            #an = np.zeros(N)
            #an[0:N/4] = np.sin(k*xn)
            an = AN.copy()
            an = np.roll(an,N0)
            data_u   = Gnuplot.Data(x,u,using=(1,2), with_ = 'line', title = 'numerical')
            data_an   = Gnuplot.Data(x,an,using=(1,2), with_ = 'line', title = 'analytic')
            g.plot(data_u,data_an)
            #raw_input('press return')
            #g.plot(data_psip)
            z = time.sleep(0.1) 	#wait a little so you can look at it
            #print rho_res, dres
        time_level_counter += 1

    #call the solver here, note that the solution, u, is not returned
    #but cpu time is
    cpu, t = solver(tstop, user_action=action)
    print 'CPU time:', cpu, ' sec'
    print 'time: ', t
    #raw_input('press return')
    return u
                                                   #test_solver has access to solutions through 'action'
				                   #and returns solution array to main

# ---------------------------------  end test_solver  -----------------------------


def main():
    """Run the simulation with live plotting, then wait for the user before exiting."""
    u = test_solver(n_plot) 	#this is where I call the solver
    # To save the final solution, uncomment:
    #output = open("1dwave-sol.dat","w")
    #output.write(str(u)) 	#note: the output is an array
    #output.close()
    raw_input('Please press the return key to continue...\n')

# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
