#!/usr/bin/env python
"""
Simulate Gaussian wave packet interacting with potential wall.
Numerically solve Schroedinger's equation.

Some aspects copied from numpy cookbook article on Schroedinger's equation,
others from Langtangen, Python Scripting for Computational Sciences.

Author: Bob Wimmer

Date: March 13, 2010

"""
import numpy as np
import Gnuplot, Gnuplot.funcutils
import time

#global definitions
def Gaussian(x,t,sigma):
    """  Unnormalized Gaussian bell curve, exp(-(x-t)^2 / (2 sigma^2)).
        x = point(s) at which the curve is evaluated
        t = shift subtracted from x, i.e. the location of the peak
        sigma = standard deviation (width) of the curve
        The peak value is 1 (no normalization factor is applied).      """
    offset = x - t
    exponent = -offset**2/(2*sigma**2)
    return np.exp(exponent)

#Various definitions for the potential V(x)

def free(npts):
    """Free particle: a constant (zero) potential.
        npts = number of grid points
        Returns a float array of npts zeros."""
    potential = np.zeros(npts, dtype=float)
    return potential

def step(npts,v0):
    """Potential step: zero on the left half of the grid, v0 on the right half.
        npts = number of grid points
        v0   = height of the step
        Returns a float array of length npts."""
    v = np.zeros(npts)          # same zero baseline that free() produces
    # Floor division keeps the slice index an integer on Python 2 and 3 alike.
    v[npts//2:] = v0
    return v

def barrier(npts,v0,thickness):
    """Single rectangular barrier starting at the middle of the grid.
        npts      = number of grid points
        v0        = height of the barrier
        thickness = width of the barrier in grid points
        Returns a float array of length npts that is v0 on
        [npts//2, npts//2 + thickness) and zero elsewhere.  A barrier that
        would extend past the end of the grid is cut off at the grid edge."""
    v = np.zeros(npts)          # same zero baseline that free() produces
    # Floor division keeps the slice indices integers on Python 2 and 3 alike.
    start = npts//2
    v[start:start+thickness] = v0
    return v

def double_barrier(npts,v0,thickness):
    """Double barrier potential: two rectangular barriers of equal height.
        npts      = number of grid points
        v0        = height of both barriers
        thickness = width of each barrier in grid points
        The first barrier ends at npts//4 (extending to the left), the second
        starts at 3*npts//4 (extending to the right).
        Returns a float array of length npts."""
    v = np.zeros(npts)          # same zero baseline that free() produces
    # Floor division keeps the slice indices integers on Python 2 and 3 alike.
    first_end = npts//4
    second_start = 3*npts//4
    # Clamp at 0: a negative start would wrap around to the end of the array
    # and silently turn the first barrier into an empty slice.
    first_start = max(first_end - thickness, 0)
    v[first_start:first_end] = v0
    v[second_start:second_start+thickness] = v0
    return v

def coulomb(npts,v0,thickness):
    """Coulomb-like barrier: a 1/i tail starting at the middle of the grid.
        npts      = number of grid points
        v0        = scale of the potential; for even npts this is the value
                    at the midpoint index npts//2
        thickness = the tail covers 20*thickness grid points
        For each covered index i the value is v0*npts/2/i; all other points
        are zero.  Returns a float array of length npts."""
    v = np.zeros(npts)          # same zero baseline that free() produces
    # Index 0 is skipped because the 1/i expression is singular there
    # (this only matters for npts < 2).
    start = max(npts//2, 1)
    # Clamp at the grid edge: the tail may be longer than the remaining half
    # of the grid, and indexing past npts would raise IndexError.
    stop = min(npts//2 + 20*thickness, npts)
    idx = np.arange(start, stop)
    v[idx] = v0*npts/2.0/idx
    return v


# ---- Grid, time stepping and physical constants ----
N = 2000                        # Number of spatial grid points.
tstop = N*4                     # Number of time steps.  (Original author's note:
                                # 5*N is a nice value for terminating before
                                # anything reaches the boundaries.)
n_plot = 15                     # The plot is updated once every n_plot time steps.
dx = 1.0                        # Spatial resolution.
m = 1.0                         # Particle mass.
hbar = 1.0                      # Planck's constant.
X = np.linspace(0, N, N)*dx     # Spatial axis: N samples from 0 to N*dx inclusive.

# ---- Potential parameters ----
# Playing with the type of potential and with its height and thickness (for
# barriers) shows the various transmission/reflection regimes of quantum
# mechanical tunneling.
V0 = 0.048                      # Potential amplitude (used for steps and barriers).
THICK = 100                     # "Thickness" of the potential barrier (only used
                                # by the potential types that take a thickness).

# Select the potential type by leaving exactly one POTENTIAL line uncommented.

# Zero potential: the packet propagates freely.
#POTENTIAL = 'free'

# Potential step.  The height V0 chosen above determines the amount of
# reflection/transmission observed.
#POTENTIAL = 'step'

# Potential barrier.  BOTH the height (V0) and the thickness (THICK) affect
# the amount of tunneling vs reflection observed.
POTENTIAL = 'barrier'

# Double potential barrier (crude model for confinement).  BOTH the height
# (V0) and the thickness (THICK) affect the amount of tunneling vs reflection.
#POTENTIAL = 'double_barrier'

# Coulomb potential barrier.  BOTH the height (V0) and the thickness (THICK)
# affect the amount of tunneling vs reflection observed.
#POTENTIAL = 'coulomb'


# ---- Initial wave packet constants ----
sigma = 20.0                    # Standard deviation of the Gaussian envelope
                                # (remember Heisenberg uncertainty).
x0 = round(N/2) - sigma*25      # Center of the initial packet: 25 sigma to the
                                # left of the middle of the grid.
k0 = np.pi/10                   # Wavenumber (note that energy is a function of k).
# Energy of a localized Gaussian wave packet interacting with a localized
# potential (the interaction term can be neglected by computing the energy
# integral over a region where V=0).
E = (k0**2 + 0.5/sigma**2)*(hbar**2/2.0/m)


#  Build the potential array V for the selected POTENTIAL type.  Each entry is
#  a zero-argument builder so that only the selected potential is constructed.
_POTENTIAL_BUILDERS = {
    'free':           lambda: free(N),
    'step':           lambda: step(N, V0),
    'barrier':        lambda: barrier(N, V0, THICK),
    'double_barrier': lambda: double_barrier(N, V0, THICK),
    'coulomb':        lambda: coulomb(N, V0, THICK),
}
if POTENTIAL not in _POTENTIAL_BUILDERS:
    raise ValueError("Unrecognized potential type: %s" % POTENTIAL)
V = _POTENTIAL_BUILDERS[POTENTIAL]()

#  Time step and update coefficients.  The maximum stable time step depends on
#  the potential, so it is derived from the largest value of V.
Vmax = np.max(V)                                 #  Maximum potential of the domain.
dt = hbar / (2*hbar**2/(m*dx**2) + Vmax)         #  Critical time step.
c1 = hbar*dt / (m*dx**2)                         #  Coefficient of the second difference.
c2 = 2*dt / hbar                                 #  Coefficient of the potential term.
c2V = V*c2                                       #  Precomputed once, outside the update loop.
# Print summary info.  Single-argument print(...) calls with %-formatting give
# the same output under the Python 2 print statement and the Python 3 print
# function; the labels are padded so the values line up in one column.
print('One-dimensional Schrodinger equation - time evolution')
print('Wavepacket energy:    %s' % (E,))
print('Potential type:       %s' % (POTENTIAL,))
print('Potential height V0:  %s' % (V0,))
print('Barrier thickness:    %s' % (THICK,))


# Initialize the wave functions.  Three rows represent past, present, and future.
PA = 0                  #  Row index: past
PR = 1                  #  Row index: present
FU = 2                  #  Row index: future
psi_r = np.zeros((3,N)) #  Real part, one row per time level
psi_i = np.zeros((3,N)) #  Imaginary part, one row per time level
psi_p = np.zeros(N)     #  Observable probability (magnitude squared of the complex wave function).
#  Original author's note: a present-only state will "split" with half the
#  wave function propagating to the left and the other half to the right;
#  including a "past" state will cause it to propagate one way.  Here the
#  past rows are filled with the same values as the present rows.
#  The packet is only placed on the left half of the grid (indices 1 .. N//2-1).
#  Floor division keeps the bound an integer on Python 2 and 3 alike.
xn = np.arange(1, N//2)
x = X[xn]/dx            #  Normalized position coordinate
gg = Gaussian(x,x0,sigma)   #  Gaussian envelope centered at x0
cx = np.cos(k0*x)       #  real part of the plane-wave factor (array)
sx = np.sin(k0*x)       #  imaginary part of the plane-wave factor (array)
psi_r[PR,xn] = cx*gg    #  present, real part
psi_i[PR,xn] = sx*gg    #  present, imaginary part
psi_r[PA,xn] = cx*gg    #  past, real part
psi_i[PA,xn] = sx*gg    #  past, imaginary part
# Initial normalization of the wave functions.
#   Compute the observable probability of the present state.
psi_p = psi_r[PR]**2 + psi_i[PR]**2
#  Normalize the wave functions so that the total probability in the simulation
#  is equal to 1.
P   = dx * psi_p.sum()                      #  Total probability before normalization.
nrm = np.sqrt(P)
psi_r /= nrm            #  amplitudes scale with sqrt(P) ...
psi_i /= nrm
psi_p /= P              #  ... the probability density with P itself


#  Direct index assignment is MUCH faster than using a spatial FOR loop, so
#  these constants are used in the update equations.  They are slice objects:
#  a slice selects exactly the same elements as the equivalent list of indices
#  but avoids converting an N-element Python sequence to an index array on
#  every access inside the time loop.  Remember that Python uses zero-based
#  indexing.
IDX1 = slice(1,N-1)                            #  psi [ l ], where l stands for x_l
IDX2 = slice(2,N)                              #  psi [ l + 1 ]
IDX3 = slice(0,N-2)                            #  psi [ l - 1 ]




def solver(tstop,user_action=None):
    """
    Advance the wave function through the time steps.

    Works in place on the module-level arrays psi_r and psi_i (rows PA, PR
    and FU), using the module-level coefficients c1 and c2V, the time step dt
    and the index constants IDX1, IDX2, IDX3.

    tstop       = number of time steps (a non-integer value is rounded up)
    user_action = optional callable, invoked after every step as
                  user_action(psi_r[PR], psi_i[PR], psi_p, x, t), where psi_p
                  is the probability density of the new present state

    Returns (cpu, t): the CPU time used by the solver and the model time
    reached.  The solution itself is NOT returned; it stays in psi_r[PR] and
    psi_i[PR].  The module-level psi_p is not updated by this function.

    Note that boundary conditions are not treated correctly.  So don't believe
    results which approach the boundaries of the computational domain!
    """
    # time.clock() was removed in Python 3.8; use process_time() where it
    # exists and fall back to clock() on Python 2.
    cpu_clock = getattr(time, 'process_time', None) or time.clock
    t0 = cpu_clock()    #start timer to find out how much CPU time is used in solver

    # Count steps with an integer instead of comparing an accumulated float
    # against tstop*dt, so rounding can never add or drop a step.
    n_steps = int(np.ceil(tstop))
    t = 0.
    for step_no in range(n_steps):
        # Views of the present rows; they are read before the rows are
        # overwritten during the time-level rotation below.
        present_r = psi_r[PR]
        present_i = psi_i[PR]
        #  Apply the update equations: second difference on the interior
        #  points, potential term on the whole row.
        psi_i[FU,IDX1] = psi_i[PA,IDX1] + \
                         c1*(present_r[IDX2] - 2*present_r[IDX1] +
                             present_r[IDX3])
        psi_i[FU] -= c2V*present_r

        psi_r[FU,IDX1] = psi_r[PA,IDX1] - \
                         c1*(present_i[IDX2] - 2*present_i[IDX1] +
                             present_i[IDX3])
        psi_r[FU] += c2V*present_i
        #  Rotate the time levels: PR -> PA and FU -> PR.  The copy into PA
        #  must come first because present_r / present_i alias the PR rows.
        psi_r[PA] = present_r
        psi_r[PR] = psi_r[FU]
        psi_i[PA] = present_i
        psi_i[PR] = psi_i[FU]
        prob = psi_r[PR]**2 + psi_i[PR]**2
        t = (step_no + 1)*dt

        if user_action is not None:
            user_action(psi_r[PR], psi_i[PR], prob, x, t)       #caller can do whatever it likes here

    t1 = cpu_clock()
    return t1-t0, t

# ----------------  end of solver  -------------------------------------------



def test_solver(n_plot):
    """
    Very simple test case.
    Run the solver for tstop time steps, plotting the solution with gnuplot
    once every n_plot time steps, and report how long the solver worked.

    n_plot = number of time steps between two plot updates

    Returns (psi_r[PR], psi_i[PR], psi_p): real part, imaginary part and
    probability density of the final state.
    """
    g = Gnuplot.Gnuplot()#depending on your system, you may need to insert persist=1 in the parenthesis
    # Need time_level_counter as global variable since
    # it is assigned in the action function (that makes
    # a variable local to that block otherwise).
    global time_level_counter
    time_level_counter = 0
    # The y-range only depends on V0, so build the command once.
    yrange_cmd = 'set yrange [-0.15:'+str(V0*5.)+']'

    def action(psir, psii, psip, x, t): #called by the solver after every time step
        global time_level_counter
        if time_level_counter % n_plot == 0: #only plot every n_plot-th step
            g.reset()
            g('set ylabel "y-axis [arb. units]"')   #this is how you access gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            g(yrange_cmd)
            # Everything is plotted against the full grid X; the x argument
            # handed over by the solver is not used here.
            data_psir   = Gnuplot.Data(X,psir,using=(1,2),with_='line',title='psi_r')
            data_psii   = Gnuplot.Data(X,psii,using=(1,2),with_='line',title='psi_i')
            # The probability is scaled by 5 for display only.
            data_psip   = Gnuplot.Data(X,5.*psip,using=(1,2),with_='line lw 2',title='psi_p')
            data_V   = Gnuplot.Data(X,1.*V,using=(1,2),with_='line lw 2',title='V')
            g.plot(data_psir,data_psii,data_psip,data_V)
        time_level_counter += 1

    #call the solver here, note that the solution is not returned
    #but cpu time is
    cpu, t = solver(tstop, user_action=action)
    # Single-argument print(...) gives the same output on Python 2 and 3.
    print('CPU time: %s  sec' % (cpu,))
    print('time:  %s' % (t,))
    # solver() does not update the module-level psi_p, so compute the final
    # probability density from the final state instead of returning the
    # stale initial one.
    final_r = psi_r[PR]
    final_i = psi_i[PR]
    return final_r, final_i, final_r**2 + final_i**2

# ---------------------------------  end test_solver  -----------------------------


def main():
    """Run the simulation with live plotting, then wait for the return key."""
    # test_solver returns the final state (real part, imaginary part,
    # probability density) in case it should be written to a file.
    psir, psii, psip = test_solver(n_plot) 	#this is where I call the solver
    # raw_input only exists on Python 2; input is its Python 3 equivalent.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    read_line('Please press the return key to continue...\n')


# Only start the simulation when the file is executed as a script.
if __name__ == '__main__':
    main()
