#!/usr/bin/env python
"""
Use MacCormack's technique to solve the subsonic-supersonic isentropic nozzle flow

Author: Bob Wimmer

Date: January 17, 2009

"""
from numpy import *
import Gnuplot, Gnuplot.funcutils
import time

# global definitions
gamma = 1.4                  # ratio of specific heats
nx = 30                      # number of grid intervals (nx+1 grid points)
L = 3.                       # nozzle length
x = linspace(0., L, nx+1)    # grid points along the nozzle
dx = L/float(nx)             # uniform grid spacing
# Nozzle cross-section A(x); its minimum value of 1 lies at x = 1.5.
A = 1. + 2.2*(x - 1.5)**2
# Alternative area profile:
#A = 1. + x**1.5 + 2.2*(x - 1.5)**2 - 0.5*(x - 1.5)**4
tstop = 1.e2                 # model stop time handed to solver()

def _apply_boundary_conditions(rho, v, T):
    """Impose the nozzle boundary conditions on rho, v and T in place.

    Inflow (index 0): rho and T are held at 1; v floats and is linearly
    extrapolated from the two neighbouring interior nodes.
    Outflow (last index): rho, v and T all float and are linearly
    extrapolated from the two neighbouring interior nodes.
    """
    rho[0] = 1.
    T[0] = 1.
    v[0] = 2.*v[1] - v[2]   # one floating boundary condition at inflow
    for q in (rho, v, T):
        q[-1] = 2.*q[-2] - q[-3]


def _time_derivatives(rho, v, T, dlnA, dx, gam, forward):
    """Return (drho/dt, dv/dt, dT/dt) from one-sided spatial differences.

    The equations evaluated are
        drho/dt = -rho dv/dx - rho v dlnA/dx - v drho/dx
        dv/dt   = -v dv/dx - (dT/dx + T/rho drho/dx)/gamma
        dT/dt   = -v dT/dx - (gamma-1) T (dv/dx + v dlnA/dx)

    forward=True  (predictor): forward differences; node i uses nodes i and
                  i+1, so the results belong to nodes 0..n-2.
    forward=False (corrector): rearward differences; node i uses nodes i-1
                  and i, so the results belong to nodes 1..n-1.
    Each returned array has length n-1. `dlnA` is diff(log(A)).
    """
    node = slice(None, -1) if forward else slice(1, None)
    r, u, temp = rho[node], v[node], T[node]
    # diff(q)[k] = q[k+1] - q[k] is the difference across interval k, which
    # lines up with node k (forward) or node k+1 (rearward) via `node`.
    drho_dx = diff(rho)/dx
    dv_dx = diff(v)/dx
    dT_dx = diff(T)/dx
    dlnA_dx = dlnA/dx
    drho_dt = -r*dv_dx - r*u*dlnA_dx - u*drho_dx
    dv_dt = -u*dv_dx - (dT_dx + temp/r*drho_dx)/gam
    dT_dt = -u*dT_dx - (gam - 1.)*temp*(dv_dx + u*dlnA_dx)
    return drho_dt, dv_dt, dT_dt


def solver(tstop, user_action=None, courant=0.5, tol=1.e-7):
    """March the nozzle flow in time with MacCormack's predictor-corrector.

    Uses the module-level grid `x`, spacing `dx`, area `A` and `gamma`.

    Parameters
    ----------
    tstop : float
        Upper bound on model time. Stepping ends once t >= tstop, or earlier
        when the change of the density residual between two steps drops
        below `tol`.
    user_action : callable, optional
        Called after every step as
        user_action(rho, v, T, rho_res, v_res, T_res, dres, x, t).
        A new set of rho/v/T arrays is created each step.
    courant : float
        Courant number used for the time step (default 0.5).
    tol : float
        Convergence threshold on |rho_res(new) - rho_res(old)|.

    Returns
    -------
    (cpu_seconds, t)
        CPU time spent in the solver and the final model time. The solution
        itself is only exposed through `user_action`.
    """
    # time.clock() was removed in Python 3.8; time.process_time() is its
    # CPU-time successor (fall back to clock() on Python 2).
    clock = getattr(time, 'process_time', None) or time.clock
    t0 = clock()   # start timer to find out how much CPU time is used in solver

    dlnA = diff(log(A))   # geometry term; constant in time

    # Initial conditions: linear profiles along the nozzle, then BCs.
    rho = 1. - 0.3146*x
    T = 1. - 0.2314*x
    v = (0.1 + 1.09*x)*sqrt(T)
    _apply_boundary_conditions(rho, v, T)

    t = 0.
    dres = 1.
    oldres = 1.
    while dres > tol and t < tstop:
        # Time step: dt = C*dx/(a + v), minimized over all nodes.
        cs = sqrt(T)
        dt = fabs(courant*dx/(cs + v)).min()
        t += dt

        # Predictor: forward differences -> derivatives at nodes 0..nx-1.
        drho_dt, dv_dt, dT_dt = _time_derivatives(rho, v, T, dlnA, dx, gamma, True)
        rho_p, v_p, T_p = rho.copy(), v.copy(), T.copy()
        rho_p[:-1] += drho_dt*dt
        v_p[:-1] += dv_dt*dt
        T_p[:-1] += dT_dt*dt
        _apply_boundary_conditions(rho_p, v_p, T_p)

        # Corrector: rearward differences of predicted values -> nodes 1..nx.
        drho_p_dt, dv_p_dt, dT_p_dt = _time_derivatives(rho_p, v_p, T_p, dlnA, dx, gamma, False)

        # Average both derivatives on the interior nodes 1..nx-1, the only
        # nodes where both one-sided derivatives exist.
        drho_dt_ave = 0.5*(drho_dt[1:] + drho_p_dt[:-1])
        dv_dt_ave = 0.5*(dv_dt[1:] + dv_p_dt[:-1])
        dT_dt_ave = 0.5*(dT_dt[1:] + dT_p_dt[:-1])

        # New time level: advance the interior, then impose the boundary
        # conditions on the NEW arrays (end nodes are set only by the BCs).
        rho_new, v_new, T_new = rho.copy(), v.copy(), T.copy()
        rho_new[1:-1] += drho_dt_ave*dt
        v_new[1:-1] += dv_dt_ave*dt
        T_new[1:-1] += dT_dt_ave*dt
        _apply_boundary_conditions(rho_new, v_new, T_new)
        rho, v, T = rho_new, v_new, T_new

        # Residuals: sums of squared averaged time derivatives (interior).
        rho_res = (drho_dt_ave**2).sum()
        v_res = (dv_dt_ave**2).sum()
        T_res = (dT_dt_ave**2).sum()

        # Stop once the density residual no longer changes appreciably.
        dres = fabs(oldres - rho_res)
        oldres = rho_res

        if user_action is not None:
            user_action(rho, v, T, rho_res, v_res, T_res, dres, x, t)

    t1 = clock()
    # The solution is not returned; only the CPU time used and the model time.
    return t1 - t0, t

# ----------------  end of solver  -------------------------------------------



def test_solver(N, pause=0.1):
    """
    Run the solver with a live Gnuplot display and collect snapshots.

    Every N-th time level (N >= 1), this stores the model time, the three
    residuals and copies of rho, v and T. It also plots rho, v and T against
    x, then sleeps `pause` seconds so the plot can be looked at. The CPU time
    used by the solver and the final model time are printed.

    Returns (timesteps, rho_solutions, v_solutions, T_solutions,
    rho_residuals, v_residuals, T_residuals).
    Raises ValueError if N < 1.
    """
    if N < 1:
        raise ValueError('N must be a positive integer, got %r' % (N,))
    g = Gnuplot.Gnuplot()
    rho_solutions = []
    v_solutions = []
    T_solutions = []
    timesteps = []
    rho_residuals = []
    v_residuals = []
    T_residuals = []
    # One-element list so the nested callback can update the counter
    # (Python 2 has no `nonlocal`).
    time_level_counter = [0]

    def action(rho, v, T, rho_res, v_res, T_res, dres, x, t):
        """Solver callback: record and plot every N-th time level."""
        if time_level_counter[0] % N == 0:
            timesteps.append(t)
            rho_residuals.append(rho_res)
            v_residuals.append(v_res)
            T_residuals.append(T_res)
            rho_solutions.append(rho.copy())
            v_solutions.append(v.copy())
            T_solutions.append(T.copy())
            g.reset()
            g('set ylabel "y-axis [arb. units]"')   # raw gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            data_rho = Gnuplot.Data(x, rho, using=(1, 2))
            data_v = Gnuplot.Data(x, v, using=(1, 2))
            data_T = Gnuplot.Data(x, T, using=(1, 2))
            g.plot(data_rho, data_v, data_T)
            time.sleep(pause)   # wait a little so you can look at it
        time_level_counter[0] += 1

    # The solver returns CPU time and model time, not the solution; the
    # solution is collected through `action`.
    cpu, t = solver(tstop, user_action=action)
    # Single pre-formatted string: same output under Python 2 and 3.
    print('CPU time: %s  sec' % cpu)
    print('time:  %s' % t)
    return (timesteps, rho_solutions, v_solutions, T_solutions,
            rho_residuals, v_residuals, T_residuals)

# ---------------------------------  end test_solver  -----------------------------


def main():
    """Run the demo: solve, plot every 2nd time level, then wait for the user."""
    test_solver(2)   # this is where I call the solver
    # Pause before exiting so the last plot can be inspected.
    try:
        wait = raw_input   # Python 2
    except NameError:
        wait = input       # Python 3
    wait('Please press the return key to continue...\n')


if __name__ == '__main__':
    main()
