#!/usr/bin/env python
"""
Simple 2-d wave function pde solver
"""
from numpy import *
import Gnuplot, Gnuplot.funcutils
import time



def _apply_boundary(u, x, y, t, bc):
    """
    Overwrite every boundary point of the grid array u with bc(x, y, t).

    u has shape (len(x), len(y)); the boundary consists of the first and
    last rows (i = 0, i = nx) and the first and last columns (j = 0, j = ny).
    bc is called with scalar coordinates, one grid point at a time.
    """
    nx = len(x) - 1
    ny = len(y) - 1
    for i in (0, nx):
        for j in range(ny + 1):
            u[i, j] = bc(x[i], y[j], t)
    for j in (0, ny):
        for i in range(nx + 1):
            u[i, j] = bc(x[i], y[j], t)


def solver2d(I,f,c,bc,Lx,Ly,nx,ny,dt,tstop,user_action=None):
    """
    Solve u_tt = c**2*(u_xx + u_yy) + f(x, y, t) on [0,Lx] x [0,Ly].

    Explicit centred finite differences on an (nx+1) x (ny+1) grid.

    I(x, y)     -- initial displacement, called with scalar coordinates.
                   The initial velocity is taken to be zero.
    f(x, y, t)  -- source term.  Called with broadcastable arrays of shape
                   (nx-1, 1) and (1, ny-1), so it must accept arrays
                   (returning a plain scalar is fine).
    c           -- wave speed.
    bc(x, y, t) -- boundary value imposed on the outer grid points, called
                   with scalar coordinates.
    Lx, Ly      -- domain lengths; nx, ny -- number of grid cells.
    dt          -- time step; if dt <= 0 the largest stable step is used.
    tstop       -- time stepping continues while t < tstop.
    user_action(u, x, t) -- optional callback invoked after every step with
                   the new solution u at time t.  u is a work array that the
                   solver reuses, so copy it if you want to keep it.

    Returns (dt, x, cpu_time): the time step actually used, the x grid
    coordinates, and the CPU time spent in the solver.  The solution itself
    is only available through user_action.
    """
    # time.clock was removed in Python 3.8; use process_time when available.
    cpu_clock = getattr(time, "process_time", None) or time.clock
    t0 = cpu_clock()
    x = linspace(0, Lx, nx+1)   # grid points in x direction
    dx = Lx/float(nx)
    y = linspace(0, Ly, ny+1)   # grid points in y direction
    dy = Ly/float(ny)
    xv = x[:, newaxis]          # column vector, for vectorized f calls
    yv = y[newaxis, :]          # row vector, for vectorized f calls
    if dt <= 0:
        # maximum time step for stability of the explicit scheme
        dt = (1.0/float(c))*(1./sqrt(1./dx**2 + 1./dy**2))
    Cx2 = (c*dt/dx)**2
    Cy2 = (c*dt/dy)**2
    dt2 = dt*dt

    up = zeros((nx+1, ny+1))   # solution at time step n+1 (u-plus)
    u = up.copy()              # solution at time step n   (u)
    um = up.copy()             # solution at time step n-1 (u-minus)

    def laplace_terms(v):
        # Cx2*(centred 2nd x-difference) + Cy2*(centred 2nd y-difference)
        # of v, evaluated at the interior points [1:nx, 1:ny].
        return (Cx2*(v[0:nx-1, 1:ny] - 2.*v[1:nx, 1:ny] + v[2:nx+1, 1:ny]) +
                Cy2*(v[1:nx, 0:ny-1] - 2.*v[1:nx, 1:ny] + v[1:nx, 2:ny+1]))

    # initial condition at every grid point, boundary included
    t = 0.0
    for i in range(nx+1):
        for j in range(ny+1):
            u[i, j] = I(x[i], y[j])

    # Fictitious level n = -1.  Zero initial velocity means u^{-1} = u^{1},
    # and the centred scheme then gives
    #   u^{-1} = u^0 + 0.5*(Cx2/Cy2 difference terms) + 0.5*dt^2*f.
    # Its boundary values are never read by the interior update.
    um[1:nx, 1:ny] = (u[1:nx, 1:ny] + 0.5*laplace_terms(u) +
                      0.5*dt2*f(xv[1:nx], yv[:, 1:ny], t))

    while t < tstop:
        t_old = t
        t += dt
        # discretized wave equation at all interior points
        up[1:nx, 1:ny] = (-um[1:nx, 1:ny] + 2.*u[1:nx, 1:ny] +
                          laplace_terms(u) +
                          dt2*f(xv[1:nx], yv[:, 1:ny], t_old))
        # up holds the solution at the new time t, so impose bc at t
        _apply_boundary(up, x, y, t, bc)

        if user_action is not None:
            # the only place the caller can see the solution
            user_action(up, x, t)

        um, u, up = u, up, um   # rotate time levels without copying
    t1 = cpu_clock()
    return dt, x, t1-t0

# ----------------  end of solver  -------------------------------------------



def test_solver1(N):
    """
    Very simple test case: a wave driven in from the x=0 and y=0 edges.

    Every N-th time level is stored and drawn as a 3-D surface with
    Gnuplot, pausing briefly after each frame.  The CPU time reported by
    solver2d is printed.

    Returns the list of stored solution arrays (independent copies).
    """
    g = Gnuplot.Gnuplot()
    c = 1.
    nx = 30; ny = 30; tstop = 30
    Lx = 20; Ly = 20
    # solver2d passes only the x coordinates to the callback; rebuild the
    # y grid the same way the solver does so points are plotted at their
    # physical positions.
    ygrid = linspace(0, Ly, ny+1)

    # Alternative initial condition (a Gaussian bump in the middle):
    #   exp(-(x-Lx/2.)**2/2. - (y-Ly/2.)**2/2.)
    def I2(x, y):
        return 0.

    def f(x, y, t):
        return 0.

    def bc(x, y, t):
        # travelling sine on the x=0 and y=0 edges, zero on the other two
        if x == 0:
            return 0.05*sin(0.2*pi*(y-0.5*c*t))
        elif y == 0:
            return 0.05*sin(0.2*pi*(x-0.5*c*t))
        else:
            return 0.

    solutions = []
    # One-element list used as a mutable counter: the closure below updates
    # it in place, so no module-level global variable is needed.
    time_level_counter = [0]

    def action(u, x, t):
        # u is a work array the solver reuses, so copy it before keeping it.
        if time_level_counter[0] % N == 0:
            snapshot = u.copy()
            solutions.append(snapshot)
            g.reset()
            g('set ylabel "y-axis [arb. units]"')   # raw gnuplot commands
            g('set xlabel "x-axis [arb. units]"')
            g('set view 45, 60, 1, 1')
            g('set zrange [-0.5:0.5]')
            # every grid point, including the last row and column
            points = [(x[i], ygrid[j], snapshot[i, j])
                      for i in range(len(x)) for j in range(len(ygrid))]
            g.splot(points, title='this should move')
            time.sleep(0.2)   # pause so each frame can be seen
        time_level_counter[0] += 1

    # The solver does not return the solution (action collects it), but it
    # does return the CPU time it used.
    dt, x, cpu = solver2d(I2, f, c, bc, Lx, Ly, nx, ny, 0, tstop,
                          user_action=action)
    print('CPU time: %s  sec' % cpu)
    return solutions

# ---------------------------------  end test_solver  -----------------------------


def main():
    """Run the Gnuplot wave demo, then wait for the user to press Return."""
    store_every = 2   # keep and plot every 2nd time level
    test_solver1(store_every)
    raw_input('Please press the return key to continue...\n')
    

# Only run the interactive demo when executed as a script, not on import.
if __name__ == "__main__":
    main()
