#!/usr/bin/env python
"""
Solve the Poisson equation with the various iterative methods defined below.

We usually solve on a 128 x 128 grid.

author: Bob Wimmer
date: January 3, 2013
email: wimmer@physik.uni-kiel.de

This script largely follows chapter 10 of Hirsch, vol. 1.

"""
from numpy import *
from numpy.linalg import norm as norm
from scipy.linalg import solve_banded
#from scipy.sparse.linalg import eigsh
import time
import Gnuplot

# Grid size (number of nodes along the first / second array axis) and
# grid spacing in physical units.
NX = 128
NY = 128
dx = 1.
dy = 1.
# Number of iterations run by the script below; histmax[k] receives the
# maximum of the solution after iteration k.
N_iter = 256
histmax = zeros(N_iter)
# Relaxation parameters (not referenced by the code visible in this file).
om_jacobi = 1.
om_gs = 1.5

# Work arrays on the NX x NY grid, all starting at zero.
up = zeros((NX, NY))    # past iterate   (time step n)
uf = zeros((NX, NY))    # future iterate (time step n+1)
ufp = zeros((NX, NY))   # future iterate for the over-relaxation methods
s = zeros((NX, NY))     # source term, filled in later by source()
bc = zeros((NX, NY))    # boundary-condition term (left at zero here)
# Two independent gnuplot sessions: g shows the solution, h the history
# of its maximum.
g = Gnuplot.Gnuplot()
h = Gnuplot.Gnuplot()

def plot_sol(u,nx,ny,t):
    """
    Show the 2-D array u as a gnuplot image map in the session g.

    Parameters
    ----------
    u : 2-D array to display, sampled on the index grid range(nx) x range(ny)
    nx, ny : extent of the x and y axes
    t : plot title (must be a str; it is concatenated into the command)
    """
    g.reset()
    g("set xrange [0:%s]" % (nx,))
    g("set yrange [0:%s]" % (ny,))
    g('set size square')
    # Plain strings passed to g(...) are sent to gnuplot as commands.
    g('set ylabel "y-axis [arb. units]"')
    g('set xlabel "x-axis [arb. units]"')
    g("set title '" + t + "'")
    g('set pm3d map')
    image = Gnuplot.GridData(u, range(nx), range(ny), with_='image')
    g.splot(image)
    # Short pause after drawing, presumably so each frame stays visible.
    time.sleep(0.125)

def plot_histmax(histmax):
    """
    Plot the history of the solution maximum in the gnuplot session h.

    Parameters
    ----------
    histmax : 1-D array; entry k holds the solution maximum after
        iteration k. The x-axis spans len(histmax) iterations and the
        y-axis is fixed to [0, 1].
    """
    n = len(histmax)
    h.reset()
    h("set xrange [0:" + str(n) + "]")
    h("set yrange [0:1]")
    # Plain strings passed to h(...) are sent to gnuplot as commands.
    h('set ylabel "max [arb. units]"')
    h('set xlabel "iteration number"')
    h.plot(Gnuplot.Data(arange(n), histmax))
    # Short pause after drawing, presumably so the plot stays visible.
    time.sleep(0.125)


def source(s,x,y,radius,charge):
    """
    Add a uniformly charged disc to the source array s (in place).

    Every grid node (i*dx, j*dy) lying strictly inside the circle of radius
    `radius` centred at the physical position (x, y) receives `charge`.
    Call repeatedly to superpose several charges. The grid spacings dx and
    dy are the module-level constants.

    Parameters
    ----------
    s : 2-D array, modified in place
    x, y : centre of the disc in physical units
    radius : disc radius in physical units
    charge : value added to each covered node

    Returns
    -------
    s itself, so the call can be written ``s = source(s, ...)``.
    """
    nx, ny = s.shape
    # Index-space bounding box of the disc. floor/ceil make the box include
    # every node that can lie inside. Clipping to the grid keeps discs near
    # an edge from wrapping round to the opposite side through negative
    # indices. The distance mask below decides which nodes are really inside.
    imin = int(clip(floor((x - radius)/dx), 0, nx - 1))
    imax = int(clip(ceil((x + radius)/dx), 0, nx - 1))
    jmin = int(clip(floor((y - radius)/dy), 0, ny - 1))
    jmax = int(clip(ceil((y + radius)/dy), 0, ny - 1))
    print("%d %d %d %d" % (imin, imax, jmin, jmax))
    ii = arange(imin, imax + 1)[:, newaxis]
    jj = arange(jmin, jmax + 1)[newaxis, :]
    inside = (ii*dx - x)**2 + (jj*dy - y)**2 < radius**2
    window = s[imin:imax + 1, jmin:jmax + 1]   # basic slice -> view into s
    window[inside] += charge
    return s

def jacobi(nx,ny,up,s,bc):
    """
    Perform one Jacobi sweep for the 5-point discrete Poisson equation.

    Every interior node receives, from the previous iterate only,

        u_new[i,j] = (u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1]
                      + s[i,j] + bc[i,j]) / 4

    Boundary nodes (first/last index along either axis) are copied unchanged.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source term and boundary-condition term, same shape as up

    Returns
    -------
    New array holding the next iterate.
    """
    uf = up.copy()
    # The second axis is sliced with ny (not nx) so non-square grids work.
    uf[1:nx-1,1:ny-1] = 0.25*((up[2:nx,1:ny-1] + up[0:nx-2,1:ny-1]
                               + up[1:nx-1,2:ny] + up[1:nx-1,0:ny-2])
                              + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1])
    return uf

def jacobi_over(nx,ny,up,s,bc,om):
    """
    Perform one Jacobi over-relaxation sweep.

    Each interior node becomes om * (Jacobi update) + (1 - om) * (old value),
    where the Jacobi update is the 4-neighbour average plus a quarter of the
    source and boundary terms. Boundary nodes are copied unchanged.

    Note: the Jacobi method is already the 'optimal' Jacobi method.
    Over-relaxation analysis shows that the optimal Jacobi relaxation
    parameter is unity, so this does not improve convergence; with om = 1
    the result is the same as for the ordinary Jacobi method.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up
    om : relaxation factor

    Returns
    -------
    New array holding the next iterate.
    """
    uf = up.copy()
    # The second axis is sliced with ny (not nx) so non-square grids work.
    uf[1:nx-1,1:ny-1] = (om*0.25*((up[2:nx,1:ny-1] + up[0:nx-2,1:ny-1]
                                   + up[1:nx-1,2:ny] + up[1:nx-1,0:ny-2])
                                  + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1])
                         + (1.-om)*up[1:nx-1,1:ny-1])
    return uf

def jacobi_slor(nx,ny,up,s,bc,om):
    """
    Perform one Jacobi sweep followed by a separate relaxation step.

    First the plain Jacobi update is computed for every interior node. Then
    the result is blended with the previous iterate:
    u_new = om * jacobi + (1 - om) * u_old. Numerically this is the same as
    jacobi_over. Despite the name, no implicit line (tridiagonal) solve is
    performed. Boundary nodes are copied unchanged.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up
    om : relaxation factor

    Returns
    -------
    New array holding the relaxed next iterate.
    """
    uf = up.copy()
    ufp = up.copy()
    # The second axis is sliced with ny (not nx) so non-square grids work.
    uf[1:nx-1,1:ny-1] = 0.25*((up[2:nx,1:ny-1] + up[0:nx-2,1:ny-1]
                               + up[1:nx-1,2:ny] + up[1:nx-1,0:ny-2])
                              + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1])
    ufp[1:nx-1,1:ny-1] = om*uf[1:nx-1,1:ny-1] + (1.-om)*up[1:nx-1,1:ny-1]
    return ufp

def gauss_seidel(nx,ny,up,s,bc):
    """
    Perform one lexicographic Gauss-Seidel sweep for the 5-point Poisson
    stencil.

    Each interior node is set to the 4-neighbour average plus a quarter of
    the source and boundary terms. The (i-1, j) and (i, j-1) neighbours
    already carry their new values from this sweep; (i+1, j) and (i, j+1)
    still carry the old ones. This should converge faster than the Jacobi
    method. Boundary nodes are copied unchanged.

    A single vectorised slice assignment cannot express this. NumPy
    evaluates the whole right-hand side before writing, so every neighbour
    would still be an old value, which is just a Jacobi step. Instead the
    interior is swept along anti-diagonals i + j = k:
    - nodes on one anti-diagonal do not depend on each other;
    - their (i-1, j) and (i, j-1) neighbours lie on diagonal k-1, which is
      already updated;
    - their (i+1, j) and (i, j+1) neighbours lie on diagonal k+1, which is
      not yet updated.
    This gives exactly the row-by-row Gauss-Seidel result while keeping the
    work per diagonal vectorised.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up

    Returns
    -------
    New array holding the next iterate.
    """
    uf = up.copy()
    # Interior i in [1, nx-2], j in [1, ny-2]  =>  k = i + j in [2, nx+ny-4].
    for k in range(2, nx + ny - 3):
        lo = k - (ny - 2)          # smallest i with j = k - i <= ny - 2
        if lo < 1:
            lo = 1
        hi = k - 1                 # largest i with j = k - i >= 1
        if hi > nx - 2:
            hi = nx - 2
        i = arange(lo, hi + 1)
        j = k - i
        uf[i, j] = 0.25*((uf[i+1, j] + uf[i-1, j] + uf[i, j+1] + uf[i, j-1])
                         + s[i, j] + bc[i, j])
    return uf

def gauss_seidel_over(nx,ny,up,s,bc,om):
    """
    Perform one lexicographic successive over-relaxation (SOR) sweep.

    Each interior node is set to om * GS + (1 - om) * old. GS is the
    Gauss-Seidel value: the 4-neighbour average, in which (i-1, j) and
    (i, j-1) already carry their new, relaxed values, plus a quarter of the
    source and boundary terms. With 1 < om < 2 this should converge faster
    than plain Gauss-Seidel; om = 1 reproduces it. Boundary nodes are
    copied unchanged.

    The relaxation must be applied inside the sweep, so that later nodes see
    relaxed neighbours. Relaxing a finished vectorised update afterwards
    would only give Jacobi over-relaxation. The sweep therefore runs along
    anti-diagonals i + j = k. Nodes on one anti-diagonal are independent of
    each other, and their "new" neighbours all lie on diagonal k-1, so this
    equals the row-by-row ordering.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up
    om : relaxation factor

    Returns
    -------
    New array holding the next iterate.
    """
    uf = up.copy()
    # Interior i in [1, nx-2], j in [1, ny-2]  =>  k = i + j in [2, nx+ny-4].
    for k in range(2, nx + ny - 3):
        lo = k - (ny - 2)
        if lo < 1:
            lo = 1
        hi = k - 1
        if hi > nx - 2:
            hi = nx - 2
        i = arange(lo, hi + 1)
        j = k - i
        gs = 0.25*((uf[i+1, j] + uf[i-1, j] + uf[i, j+1] + uf[i, j-1])
                   + s[i, j] + bc[i, j])
        # uf[i, j] still holds the old value here: diagonal k is untouched.
        uf[i, j] = om*gs + (1.-om)*uf[i, j]
    return uf

def slor_col(nx,ny,ab,up,s,bc,om):
    """
    Perform one successive line over-relaxation (SLOR) sweep, with implicit
    lines along the second array axis (one tridiagonal solve per fixed i).

    Lines i = 1 .. nx-2 are processed in order. For line i, the interior
    unknowns v[j], j = 1 .. ny-2, solve A v = r with

        r[j] = 0.25*(u_old[i+1, j] + u_new[i-1, j] + s[i, j] + bc[i, j])

    A is taken from ab. The boundary nodes u[i, 0] and u[i, ny-1] are held
    fixed, and their couplings are moved to the right-hand side. The matrix
    built in this script's main section (unit diagonal, -0.25 off-diagonal)
    makes this the 5-point Poisson stencil. The line is then relaxed:
    u_new[i, 1:ny-1] = om*v + (1 - om)*u_old[i, 1:ny-1].

    Parameters
    ----------
    nx, ny : ignored; the grid size is re-read from up.shape (kept so the
        call signature matches the other solvers)
    ab : (3, N) tridiagonal matrix in scipy.linalg.solve_banded (1, 1)
        layout with N >= ny: row 0 upper, row 1 main, row 2 lower diagonal
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up
    om : relaxation factor

    Returns
    -------
    New array holding the next iterate; boundary nodes are copied from up.
    """
    uf = up.copy()
    nx, ny = up.shape[0], up.shape[1]
    if ny < 3:
        return uf                    # no interior nodes on any line
    band = asarray(ab)
    inner = band[:, 1:ny-1]          # banded sub-matrix for interior nodes
    # Couplings of the first/last interior node to the fixed boundary nodes
    # (banded storage ab[1 + r - c, c] = a[r, c]).
    c_first = band[2, 0]             # a[1, 0]
    c_last = band[0, ny-1]           # a[ny-2, ny-1]
    for i in range(1, nx-1):
        # Neighbours across the line share the node's own j; i+1 is still
        # old, i-1 has just been updated.
        rhs = 0.25*(up[i+1,1:ny-1] + uf[i-1,1:ny-1] + s[i,1:ny-1] + bc[i,1:ny-1])
        rhs[0] -= c_first*uf[i,0]
        rhs[-1] -= c_last*uf[i,ny-1]
        line = solve_banded((1,1), inner, rhs)
        uf[i,1:ny-1] = om*line + (1-om)*up[i,1:ny-1]
    return uf

def slor_row(nx,ny,ab,up,s,bc,om):
    """
    Perform one successive line over-relaxation (SLOR) sweep, with implicit
    lines along the first array axis (one tridiagonal solve per fixed j).

    Lines j = 1 .. ny-2 are processed in order. For line j, the interior
    unknowns v[k], k = 1 .. nx-2, solve A v = r with

        r[k] = 0.25*(u_old[k, j+1] + u_new[k, j-1] + s[k, j] + bc[k, j])

    A is taken from ab. The boundary nodes u[0, j] and u[nx-1, j] are held
    fixed, and their couplings are moved to the right-hand side. The matrix
    built in this script's main section (unit diagonal, -0.25 off-diagonal)
    makes this the 5-point Poisson stencil. The line is then relaxed:
    u_new[1:nx-1, j] = om*v + (1 - om)*u_old[1:nx-1, j].

    Parameters
    ----------
    nx, ny : ignored; the grid size is re-read from up.shape (kept so the
        call signature matches the other solvers)
    ab : (3, N) tridiagonal matrix in scipy.linalg.solve_banded (1, 1)
        layout with N >= nx: row 0 upper, row 1 main, row 2 lower diagonal
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up
    om : relaxation factor

    Returns
    -------
    New array holding the next iterate; boundary nodes are copied from up.
    """
    uf = up.copy()
    nx, ny = up.shape[0], up.shape[1]
    if nx < 3:
        return uf                    # no interior nodes on any line
    band = asarray(ab)
    inner = band[:, 1:nx-1]          # banded sub-matrix for interior nodes
    # Couplings of the first/last interior node to the fixed boundary nodes
    # (banded storage ab[1 + r - c, c] = a[r, c]).
    c_first = band[2, 0]             # a[1, 0]
    c_last = band[0, nx-1]           # a[nx-2, nx-1]
    for j in range(1, ny-1):
        # Neighbours across the line share the node's own k; j+1 is still
        # old, j-1 has just been updated.
        rhs = 0.25*(up[1:nx-1,j+1] + uf[1:nx-1,j-1] + s[1:nx-1,j] + bc[1:nx-1,j])
        rhs[0] -= c_first*uf[0,j]
        rhs[-1] -= c_last*uf[nx-1,j]
        line = solve_banded((1,1), inner, rhs)
        uf[1:nx-1,j] = om*line + (1-om)*up[1:nx-1,j]
    return uf

def red_black(nx,ny,up,s,bc,om):
    """
    Perform one red-black (checkerboard) point SOR sweep.

    Interior nodes with (i + j) even ("red") are updated first. All their
    stencil neighbours are black, so they use old values. Nodes with
    (i + j) odd ("black") are then updated using the freshly relaxed red
    values. Each update is

        u[i,j] <- om*GS[i,j] + (1 - om)*u[i,j]
        GS[i,j] = (u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1]
                   + s[i,j] + bc[i,j]) / 4

    With om = 1 this is red-black Gauss-Seidel, which should converge faster
    than Jacobi. Boundary nodes are copied unchanged. This works for any
    grid size, odd or even.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up
    om : relaxation factor

    Returns
    -------
    New array holding the next iterate.
    """
    ufp = up.copy()
    if nx < 3 or ny < 3:
        return ufp                        # no interior nodes
    # Offsets inside the interior block; (ii + jj) has the parity of (i + j)
    # because both indices are shifted by one.
    ii, jj = indices((nx - 2, ny - 2))
    colour_of = (ii + jj) % 2
    src = s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1]
    interior = ufp[1:nx-1,1:ny-1]         # view: writes land in ufp
    for colour in (0, 1):                 # 0 = red first, then 1 = black
        # Recomputed per colour so black nodes see the updated red values.
        nbrs = (ufp[2:nx,1:ny-1] + ufp[0:nx-2,1:ny-1]
                + ufp[1:nx-1,2:ny] + ufp[1:nx-1,0:ny-2])
        pts = colour_of == colour
        interior[pts] = om*0.25*(nbrs[pts] + src[pts]) + (1.-om)*interior[pts]
    return ufp

def zebra_row(nx,ny,up,s,bc):
    """
    Perform one zebra sweep over lines of constant j (second array index).

    Odd lines j = 1, 3, ... are updated first, entirely from the previous
    iterate. Even lines j = 2, 4, ... are then updated using the new
    odd-line values for their j-1 / j+1 neighbours, and old values for the
    i-1 / i+1 neighbours. The update within a line is pointwise (no
    implicit line solve and no relaxation factor). Boundary nodes are
    copied unchanged.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up

    Returns
    -------
    New array holding the next iterate.
    """
    uf = up.copy()
    # odd j-lines: all neighbours from the previous iterate
    uf[1:nx-1,1:ny-1:2] = 0.25*((up[2:nx,1:ny-1:2] + up[0:nx-2,1:ny-1:2]
                                 + up[1:nx-1,2:ny:2] + up[1:nx-1,0:ny-2:2])
                                + s[1:nx-1,1:ny-1:2] + bc[1:nx-1,1:ny-1:2])
    # even j-lines: the stop index ny-1 keeps the last interior line ny-2
    # when it is even; j+1 / j-1 come from the odd lines just updated
    uf[1:nx-1,2:ny-1:2] = 0.25*((up[2:nx,2:ny-1:2] + up[0:nx-2,2:ny-1:2]
                                 + uf[1:nx-1,3:ny:2] + uf[1:nx-1,1:ny-2:2])
                                + s[1:nx-1,2:ny-1:2] + bc[1:nx-1,2:ny-1:2])
    return uf

def zebra_col(nx,ny,up,s,bc):
    """
    Perform one zebra sweep over lines of constant i (first array index).

    Odd lines i = 1, 3, ... are updated first, entirely from the previous
    iterate. Even lines i = 2, 4, ... are then updated using the new
    odd-line values for their i-1 / i+1 neighbours, and old values for the
    j-1 / j+1 neighbours. The update within a line is pointwise (no
    implicit line solve and no relaxation factor). Boundary nodes are
    copied unchanged.

    Parameters
    ----------
    nx, ny : grid size along the first and second array axis
    up : current iterate (not modified)
    s, bc : source and boundary-condition terms, same shape as up

    Returns
    -------
    New array holding the next iterate.
    """
    uf = up.copy()
    # odd i-lines: all neighbours from the previous iterate
    uf[1:nx-1:2,1:ny-1] = 0.25*((up[2:nx:2,1:ny-1] + up[0:nx-2:2,1:ny-1]
                                 + up[1:nx-1:2,2:ny] + up[1:nx-1:2,0:ny-2])
                                + s[1:nx-1:2,1:ny-1] + bc[1:nx-1:2,1:ny-1])
    # even i-lines: the stop index nx-1 keeps the last interior line nx-2
    # when it is even; i+1 / i-1 come from the odd lines just updated
    uf[2:nx-1:2,1:ny-1] = 0.25*((uf[3:nx:2,1:ny-1] + uf[1:nx-2:2,1:ny-1]
                                 + up[2:nx-1:2,0:ny-2] + up[2:nx-1:2,2:ny])
                                + s[2:nx-1:2,1:ny-1] + bc[2:nx-1:2,1:ny-1])
    return uf

#this allows us to use this file as a module, but also to call it as a standalone script
if __name__ == "__main__":
    # Tridiagonal matrix for the SLOR methods, in scipy.linalg.solve_banded
    # (1, 1) layout: row 0 = upper diagonal (first entry unused),
    # row 1 = main diagonal, row 2 = lower diagonal (last entry unused).
    ud = insert(-0.25*ones(NX-1), 0, 0)
    ld = insert(-0.25*ones(NX-1), NX-1, 0)
    d = ones(NX)
    ab = array([ud, d, ld])     # plain ndarray rather than numpy.matrix
    # Source: a negative and a positive unit charge (radius 1).
    s = source(s,NX/4*dx,NY/4*dy,1.,-1.)
    #s = source(s,NX/2*dx,NY/2*dy,1.,4.)
    #s = source(s,NX/4*dx,3*NY/4*dy,1.,-1.)
    #s = source(s,3*NX/4*dx,NY/4*dy,1.,-1.)
    s = source(s,3*NX/4*dx,3*NY/4*dy,1.,1.)
    up = s.copy()               # initial guess: the source itself
    # CPU-time clock. time.clock was removed in Python 3.8;
    # time.process_time is its CPU-time successor.
    try:
        cpu_clock = time.process_time
    except AttributeError:      # Python 2
        cpu_clock = time.clock
    t0 = cpu_clock()
    nx = NX; ny = NY
    for i in range(N_iter):
        upc = up.copy()
        uf = jacobi(nx,ny,upc,s,bc)
        # Alternative solvers (enable exactly one):
        #uf = gauss_seidel(nx,ny,upc,s,bc)
        #uf = jacobi_over(nx,ny,upc,s,bc,1.)
        #uf = jacobi_slor(nx,ny,upc,s,bc,1.)
        #uf = gauss_seidel_over(nx,ny,upc,s,bc,0.35)
        #uf = slor_col(nx,ny,ab,upc,s,bc,1.5)
        #uf = slor_row(nx,ny,ab,upc,s,bc,1.5)
        #uf = red_black(nx,ny,upc,s,bc,.5)
        #uf = zebra_col(nx,ny,upc,s,bc)
        #uf = zebra_row(nx,ny,upc,s,bc)
        eps = uf-upc            # change made by this iteration
        if i % 2 == 0:
            plot_sol(uf,NX,NY,str(i))
        if i > 0:
            # last column: change of the maximum since the previous iteration
            print("%d : %s %s %s %s %s %s" % (i, norm(eps), norm(uf), norm(up),
                                             norm(upc), uf.max(),
                                             uf.max() - histmax[i-1]))
        histmax[i] = uf.max()
        if i % 2 == 0:
            plot_histmax(histmax)
        up = uf                 # update solution
    t1 = cpu_clock()
    print('Simulation took %5.3f seconds'%(t1-t0))
