#!/usr/bin/env python
"""
Solve Poisson equation with various multigrid methods

We will usually solve on grids that differ by a factor of two in each dimension, e.g., 1024x1024 --> 512x512 --> 256x256 --> 128x128.

author: Bob Wimmer
date: January 4, 2013
email: wimmer@physik.uni-kiel.de

This script follows to a large extent chapter 10 of Hirsch, vol. 1.

Note that boundaries are not treated correctly: I have simply set the boundary conditions to zero and do not treat the boundaries at all, so they remain zero at all times. This is a highly simplified treatment.

"""
from numpy import *
from numpy.linalg import norm as norm
#from scipy.linalg import eigh
#from scipy.sparse.linalg import eigsh
import time
import Gnuplot

# Grid dimensions (number of points along each axis) and mesh spacing.
NX, NY = 512, 512
dx = dy = 1.

# Work arrays, all float zeros of shape (NX, NY).
up = zeros((NX, NY))   # current iterate ("past", step n)
uf = up.copy()         # next iterate ("future", step n+1)
ufp = up.copy()        # next iterate, used by the over-relaxation methods
s = zeros((NX, NY))    # source term, filled in later via source()
bc = s.copy()          # boundary-condition term; currently left at zero

def plot_sol(u,nx,ny,t,gp=None):
    """
    plot solution

    Draw the 2-D array u as a gnuplot 'pm3d map' image over the grid
    x = 0..nx-1, y = 0..ny-1, with axis ranges [0:nx] and [0:ny] and
    title t.

    u      -- array to plot
    nx, ny -- number of grid points along x and y
    t      -- title string
    gp     -- Gnuplot.Gnuplot instance to draw into. If omitted, the
              module-level `g` is used; if that does not exist yet (e.g.
              this file was imported rather than run as a script), a new
              plotter is created and stored in `g` for later calls.

    Sleeps 0.1 s after plotting (presumably so gnuplot can render each
    frame of an iteration sequence).
    """
    global g
    if gp is None:
        try:
            gp = g
        except NameError:
            # `g` is only created in the __main__ section; build one on demand
            gp = g = Gnuplot.Gnuplot()
    gp.reset()
    gp("set xrange [0:" + str(nx) + "]")
    gp("set yrange [0:" + str(ny) + "]")
    gp('set size square')
    gp('set ylabel "y-axis [arb. units]"')   # this is how you access gnuplot commands
    gp('set xlabel "x-axis [arb. units]"')
    gp("set title '" + t + "'")
    gp('set pm3d map')
    data = Gnuplot.GridData(u, range(nx), range(ny), with_='image')
    gp.splot(data)
    time.sleep(0.1)


def source(s,x,y,radius,charge):
    """
    Add a uniformly charged disc to the source array s, in place.

    Every grid point (i, j) whose position (i*dx, j*dy) lies strictly
    inside the circle of the given radius centred at (x, y) gets `charge`
    added. Call successively to superimpose several charges. Parts of the
    disc that fall outside the array are ignored (no wrap-around).

    Uses the module-level grid spacings dx and dy.

    s       -- 2-D source array, modified in place
    x, y    -- centre of the disc (same units as dx, dy)
    radius  -- disc radius
    charge  -- value added at each point inside the disc

    Returns s. Also writes the clipped candidate index window
    "imin imax jmin jmax" to stdout.
    """
    import sys
    nxs, nys = s.shape
    # floor (not int(), which truncates toward zero) handles negative
    # coordinates; +1 because slice ends are exclusive; clipping keeps
    # discs near the edge from wrapping via negative indices or overrunning.
    imin = max(int(floor((x - radius)/dx)), 0)
    imax = min(int(floor((x + radius)/dx)) + 1, nxs)
    jmin = max(int(floor((y - radius)/dy)), 0)
    jmax = min(int(floor((y + radius)/dy)) + 1, nys)
    sys.stdout.write('%d %d %d %d\n' % (imin, imax, jmin, jmax))
    ii, jj = ogrid[imin:imax, jmin:jmax]
    inside = (ii*dx - x)**2 + (jj*dy - y)**2 < radius**2
    # basic slice is a view, so the masked add writes into s
    s[imin:imax, jmin:jmax][inside] += charge
    return s

def jacobi(nx,ny,up,s,bc):
    """
    solve by Jacobi method

    Perform one Jacobi sweep of the 5-point stencil. Every interior point
    (1..nx-2, 1..ny-2) becomes
        0.25*(up[i+1,j] + up[i-1,j] + up[i,j+1] + up[i,j-1] + s[i,j] + bc[i,j])
    using only old values from `up`. Boundary points are copied unchanged.

    nx, ny -- grid dimensions (must match the array shapes)
    up     -- current iterate (not modified)
    s, bc  -- source and boundary-condition arrays on the same grid

    Returns the new iterate as a new array.
    """
    uf = up.copy()
    # interior slice uses ny on the second axis so non-square grids work
    uf[1:nx-1,1:ny-1] = 0.25*((up[2:nx,1:ny-1] + up[0:nx-2,1:ny-1] + up[1:nx-1,2:ny] + up[1:nx-1,0:ny-2]) + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1])
    return uf

def jacobi_over(nx,ny,up,s,bc,om):
    """
    solve by Jacobi overrelaxation method.

    One weighted-Jacobi sweep: each interior point becomes
        om*J[i,j] + (1-om)*up[i,j]
    where J is the plain Jacobi value computed from `up` only. Boundary
    points are copied unchanged.

    Note: with om = 1 this is exactly the Jacobi method. The author's
    note (from Hirsch) is that unity is already the optimal Jacobi
    relaxation parameter, so over-relaxation is not expected to improve
    convergence.

    Returns the new iterate as a new array; `up` is not modified.
    """
    uf = up.copy()
    # second axis bounded by ny (not nx) so non-square grids work
    uf[1:nx-1,1:ny-1] = om*0.25*((up[2:nx,1:ny-1] + up[0:nx-2,1:ny-1] + up[1:nx-1,2:ny] + up[1:nx-1,0:ny-2]) + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1]) + (1.-om)*up[1:nx-1,1:ny-1]
    return uf

def jacobi_slor(nx,ny,up,s,bc,om):
    """
    Jacobi step followed by relaxation.

    Despite the name (presumably "successive line over-relaxation"), this
    is a point-wise update and no line systems are solved. First the plain
    Jacobi value J is computed for every interior point from `up`; then
    the result is relaxed as om*J + (1-om)*up. This is mathematically the
    same as jacobi_over and, for om = 1, the same as jacobi. Boundary points
    are copied unchanged.

    Returns the new iterate as a new array; `up` is not modified.
    """
    uf = up.copy()
    ufp = up.copy()
    # second axis bounded by ny (not nx) so non-square grids work
    uf[1:nx-1,1:ny-1] = 0.25*((up[2:nx,1:ny-1] + up[0:nx-2,1:ny-1] + up[1:nx-1,2:ny] + up[1:nx-1,0:ny-2]) + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1])
    ufp[1:nx-1,1:ny-1] = om*uf[1:nx-1,1:ny-1] + (1.-om)*up[1:nx-1,1:ny-1]
    return ufp

def gauss_seidel(nx,ny,up,s,bc):
    """
    solve by Gauss-Seidel method

    One lexicographic Gauss-Seidel sweep. It gives the same result as the
    classic double loop over interior i, then j, that overwrites each point
    in place:
        u[i,j] = 0.25*(u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1] + s[i,j] + bc[i,j])
    where u[i-1,j] and u[i,j-1] are already new values.

    The loop is vectorised by visiting the anti-diagonals i+j = k in
    increasing k. Points on one diagonal depend only on diagonal k-1 (new
    values) and k+1 (still old), never on each other. Boundary points are
    copied unchanged.

    Should converge faster than the Jacobi method.

    Returns the new iterate as a new array; `up` is not modified.
    """
    uf = up.copy()
    # k = i + j runs over 2 .. (nx-2)+(ny-2) for interior points
    for k in range(2, nx + ny - 3):
        i = arange(max(1, k - (ny - 2)), min(nx - 2, k - 1) + 1)
        j = k - i
        uf[i, j] = 0.25*((uf[i+1, j] + uf[i-1, j] + uf[i, j+1] + uf[i, j-1]) + s[i, j] + bc[i, j])
    return uf

def gauss_seidel_over(nx,ny,up,s,bc,om):
    """
    solve by Gauss-Seidel overrelaxation method

    One lexicographic point-SOR sweep. It gives the same result as the
    double loop over interior i, then j, doing in place
        gs = 0.25*(u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1] + s[i,j] + bc[i,j])
        u[i,j] = om*gs + (1-om)*u[i,j]
    so the i-1 / j-1 neighbours are already relaxed values. om = 1 reduces
    to gauss_seidel.

    The loop is vectorised over anti-diagonals i+j = k, whose points are
    mutually independent. Boundary points are copied unchanged.

    Returns the new iterate as a new array; `up` is not modified.
    """
    uf = up.copy()
    # k = i + j runs over 2 .. (nx-2)+(ny-2) for interior points
    for k in range(2, nx + ny - 3):
        i = arange(max(1, k - (ny - 2)), min(nx - 2, k - 1) + 1)
        j = k - i
        gs = 0.25*((uf[i+1, j] + uf[i-1, j] + uf[i, j+1] + uf[i, j-1]) + s[i, j] + bc[i, j])
        uf[i, j] = om*gs + (1.-om)*uf[i, j]
    return uf

def red_black(nx,ny,up,s,bc,om):
    """
    solve by red-black (checkerboard) point over-relaxation

    One sweep in two half-steps:
      1. red points (interior i+j even) are relaxed,
      2. black points (interior i+j odd) are relaxed using the NEW red values.
    Each relaxation is om*gs + (1-om)*u, with
        gs = 0.25*(u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1] + s[i,j] + bc[i,j]).
    Points of one colour have only neighbours of the other colour, so each
    half-step is vectorised. Works for any grid size; boundary points are
    copied unchanged.

    Should converge faster than the Gauss-Seidel method.

    Returns the new iterate as a new array; `up` is not modified.
    """
    uf = up.copy()
    inner = uf[1:nx-1,1:ny-1]          # view: masked writes land in uf
    a, b = ogrid[0:nx-2, 0:ny-2]
    # interior index is offset by 1 on both axes, so parity of a+b == parity of i+j
    red = (a + b) % 2 == 0
    for colour in (red, ~red):
        # recomputed from the current uf so black points see the relaxed red ones
        gs = 0.25*((uf[2:nx,1:ny-1] + uf[0:nx-2,1:ny-1] + uf[1:nx-1,2:ny] + uf[1:nx-1,0:ny-2]) + s[1:nx-1,1:ny-1] + bc[1:nx-1,1:ny-1])
        inner[colour] = om*gs[colour] + (1.-om)*inner[colour]
    return uf

def zebra_row(nx,ny,up,s,bc):
    """
    solve by a zebra sweep over lines of constant second index j

    First every odd line (j = 1, 3, ...) gets a Jacobi update from `up`.
    Then every even interior line (j = 2, 4, ..., including j = ny-2 when
    ny is even) is updated using the freshly computed odd lines for its
    j-1 / j+1 neighbours; its i-1 / i+1 neighbours on the same line still
    come from `up`. This is a point-wise update along each line (no
    tridiagonal line solve), and no relaxation factor is applied.

    Returns the new iterate as a new array; boundaries are copied from `up`.
    """
    uf = up.copy()
    # odd lines j = 1, 3, ...
    uf[1:nx-1,1:ny-1:2] = 0.25*((up[2:nx,1:ny-1:2] + up[0:nx-2,1:ny-1:2] + up[1:nx-1,2:ny:2] + up[1:nx-1,0:ny-2:2]) + s[1:nx-1,1:ny-1:2] + bc[1:nx-1,1:ny-1:2])
    # even lines j = 2, 4, ...; stop at ny-1 so the last interior line is included
    uf[1:nx-1,2:ny-1:2] = 0.25*((up[2:nx,2:ny-1:2] + up[0:nx-2,2:ny-1:2] + uf[1:nx-1,3:ny:2] + uf[1:nx-1,1:ny-2:2]) + s[1:nx-1,2:ny-1:2] + bc[1:nx-1,2:ny-1:2])
    return uf

def zebra_col(nx,ny,up,s,bc):
    """
    solve by a zebra sweep over lines of constant first index i

    First every odd line (i = 1, 3, ...) gets a Jacobi update from `up`.
    Then every even interior line (i = 2, 4, ..., including i = nx-2 when
    nx is even) is updated using the freshly computed odd lines for its
    i-1 / i+1 neighbours; its j-1 / j+1 neighbours on the same line still
    come from `up`. This is a point-wise update along each line (no
    tridiagonal line solve), and no relaxation factor is applied.

    Returns the new iterate as a new array; boundaries are copied from `up`.
    """
    uf = up.copy()
    # odd lines i = 1, 3, ...
    uf[1:nx-1:2,1:ny-1] = 0.25*((up[2:nx:2,1:ny-1] + up[0:nx-2:2,1:ny-1] + up[1:nx-1:2,2:ny] + up[1:nx-1:2,0:ny-2]) + s[1:nx-1:2,1:ny-1] + bc[1:nx-1:2,1:ny-1])
    # even lines i = 2, 4, ...; stop at nx-1 so the last interior line is included
    uf[2:nx-1:2,1:ny-1] = 0.25*((uf[3:nx:2,1:ny-1] + uf[1:nx-2:2,1:ny-1] + up[2:nx-1:2,0:ny-2] + up[2:nx-1:2,2:ny]) + s[2:nx-1:2,1:ny-1] + bc[2:nx-1:2,1:ny-1])
    return uf

def iterate(n_iter,nx,ny,up,s,bc,method,om=0.35,plot=True):
    """
    Apply n_iter relaxation sweeps of the chosen method.

    n_iter -- number of sweeps (int). For 'zebra', one pass is a column
              sweep followed by a row sweep, and n_iter//2 passes are made.
    nx, ny -- grid dimensions
    up     -- initial guess (not modified)
    s, bc  -- source and boundary-condition arrays on the same grid
    method -- 'jacobi', 'jacobi-over', 'jacobi-slor', 'gauss-seidel',
              'gauss-seidel-over', 'red-black' or 'zebra'
    om     -- relaxation parameter for the relaxed methods
              (default 0.35, the value previously hard-coded)
    plot   -- if true, plot the current solution after every pass

    Returns (u, eps): the final iterate, and the change produced by the last
    sweep (all zeros if no sweep was performed).

    Raises ValueError for an unknown method.
    """
    single_sweep = {
        'jacobi':            lambda u: jacobi(nx,ny,u,s,bc),
        'jacobi-over':       lambda u: jacobi_over(nx,ny,u,s,bc,om),
        'jacobi-slor':       lambda u: jacobi_slor(nx,ny,u,s,bc,om),
        'gauss-seidel':      lambda u: gauss_seidel(nx,ny,u,s,bc),
        'gauss-seidel-over': lambda u: gauss_seidel_over(nx,ny,u,s,bc,om),
        'red-black':         lambda u: red_black(nx,ny,u,s,bc,om),
    }
    if method == 'zebra':
        # one pass = column sweep, then row sweep
        sweeps = (lambda u: zebra_col(nx,ny,u,s,bc),
                  lambda u: zebra_row(nx,ny,u,s,bc))
        n_pass = n_iter // 2
    elif method in single_sweep:
        sweeps = (single_sweep[method],)
        n_pass = n_iter
    else:
        raise ValueError('unknown method %r' % (method,))

    eps = zeros_like(up)
    for i in range(n_pass):
        for sweep in sweeps:
            # the sweep functions return new arrays and leave their input intact
            uf = sweep(up)
            eps = uf - up
            up = uf
        if plot:
            plot_sol(up,nx,ny,str(i))
    return up, eps

def reduce_grid(nx,ny,u,factor,method):
    """
    Restrict the fine-grid field u (nx x ny) to a grid coarser by `factor`.

    Only factor 2 is implemented; other values raise NotImplementedError.

    methods:
      'simple' -- injection: coarse (I, J) = u[2I, 2J] for every coarse point
      'half'   -- half weighting, (1/8) * [0 1 0; 1 4 1; 0 1 0]
      anything else -- full weighting, (1/16) * [1 2 1; 2 4 2; 1 2 1]
    The stencils are normalised (weights sum to 1), so a constant field stays
    constant. For 'half' and full weighting, only interior coarse points
    (1 .. nx//2-2, 1 .. ny//2-2) are computed; the coarse boundary stays zero.

    Returns a new (nx//factor, ny//factor) array.
    """
    if factor != 2:
        raise NotImplementedError('reduce_grid: only factor 2 is implemented')
    ncx, ncy = nx//factor, ny//factor
    up = zeros((ncx, ncy))
    if method == 'simple':
        up[:, :] = u[0:2*ncx:2, 0:2*ncy:2]
        return up
    # interior coarse point I sits on fine point 2I (2 .. 2*ncx-4); neighbours at 2I-1, 2I+1
    cx, mx, px = slice(2, 2*ncx-2, 2), slice(1, 2*ncx-3, 2), slice(3, 2*ncx-1, 2)
    cy, my, py = slice(2, 2*ncy-2, 2), slice(1, 2*ncy-3, 2), slice(3, 2*ncy-1, 2)
    centre = u[cx, cy]
    edges = u[mx, cy] + u[px, cy] + u[cx, my] + u[cx, py]
    if method == 'half':
        up[1:ncx-1, 1:ncy-1] = (4.*centre + edges)/8.
    else:  # full weighting
        corners = u[mx, my] + u[px, my] + u[px, py] + u[mx, py]
        up[1:ncx-1, 1:ncy-1] = (4.*centre + 2.*edges + corners)/16.
    return up

def interpol_grid(nx,ny,u,factor,method):
    """
    Prolongate the coarse field u (nx x ny) onto a grid finer by `factor`
    using bilinear interpolation.

    Only factor 2 is implemented; other values raise NotImplementedError.
    `method` is accepted for symmetry with reduce_grid but currently ignored.

    Steps:
      1. coarse (i, j) is copied to fine (2i, 2j);
      2. odd first-index points on even second-index lines are set to the
         mean of their two neighbours along the first axis;
      3. every odd second-index line is set to the mean of its two
         neighbouring lines along the second axis.
    The last fine line on each axis (index 2*nx-1, 2*ny-1) has no outer
    coarse neighbour and is left at zero, consistent with the zero boundaries
    used elsewhere.

    Returns a new (nx*factor, ny*factor) array.
    """
    if factor != 2:
        raise NotImplementedError('interpol_grid: only factor 2 is implemented')
    nf, mf = nx*factor, ny*factor
    up = zeros((nf, mf))
    # put grid points of coarse grid onto even grid points of fine grid
    up[0:nf:2, 0:mf:2] = u[0:nx, 0:ny]
    # interpolate along the first axis on the even second-index lines
    up[1:nf-1:2, 0:mf:2] = 0.5*(up[0:nf-2:2, 0:mf:2] + up[2:nf:2, 0:mf:2])
    # interpolate the odd second-index lines in between
    up[:, 1:mf-1:2] = 0.5*(up[:, 0:mf-2:2] + up[:, 2:mf:2])
    return up

#this allows us to use this file as a module, but also to call it as a standalone script
if __name__ == "__main__":
    g = Gnuplot.Gnuplot()
    s = source(s,NX/4*dx,NY/4*dy,8.,-1.)
    #s = source(s,NX/2*dx,NY/2*dy,1.,4.)
    #s = source(s,NX/4*dx,3*NY/4*dy,1.,-1.)
    #s = source(s,3*NX/4*dx,NY/4*dy,1.,-1.)
    s = source(s,3*NX/4*dx,3*NY/4*dy,8.,1.)
    #plot_sol(s,NX,NY,'source')
    #raw_input('press return to continue')
    up = s.copy()
    t0 = time.clock()    
    #do a few passes on fine grid
    upc=up.copy()
    n_iter=8;nx=NX;ny=NY;method='zebra'
    up, eps = iterate(n_iter,nx,ny,upc,s,bc,method)
    print 'passed 1st set of iterations on ', nx, ' x ', ny, ' grid'
    print 'norm of eps = ', norm(eps)
    #raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer to coarser grid
    factor=2; 
    meth = 'full'
    upp = reduce_grid(nx,ny,up,factor,meth)
    sp = reduce_grid(nx,ny,s,factor,meth)
    bcp = reduce_grid(nx,ny,bc,factor,meth)
    #plot_sol(sp,nx/factor,nx/factor,'source'); raw_input('press return to continue')
    #plot_sol(bcp,nx/factor,ny/factor,'BCs'); raw_input('press return to continue')
    #print 'maxima of upp, sp, and bcp are: ', upp.max(), sp.max(), bcp.max()
    #do a few passes
    uppc=upp.copy()
    n_iter=8;nx=nx/factor;ny=ny/factor
    upp, epps = iterate(n_iter,nx,ny,uppc,sp,bcp,method)
    print 'passed 2nd set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    up = upp.copy()
    #raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer to coarser grid
    factor=2; 
    uppp = reduce_grid(nx,ny,up,factor,meth)
    spp = reduce_grid(nx,ny,sp,factor,meth)
    bcpp = reduce_grid(nx,ny,bcp,factor,meth)
    print 'maxima of upp, sp, and bcp are: ', uppp.max(), spp.max(), bcpp.max()
    #do a few passes
    uppc=uppp.copy()
    n_iter=8;nx=nx/factor;ny=ny/factor
    upp, epps = iterate(n_iter,nx,ny,uppc,spp,bcpp,method)
    print 'passed 3rd set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    up = upp.copy()
    ##raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer back to finer grid
    upp = interpol_grid(nx,ny,up,factor,meth)
    #do a few passes
    upc=upp.copy()
    n_iter=8;nx=nx*factor;ny=ny*factor
    upp, epps = iterate(n_iter,nx,ny,upc,sp,bcp,method)
    up = upp.copy()
    print 'passed 4th set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    #raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer back to finer grid
    upp = interpol_grid(nx,ny,up,factor,meth)
    #do a few passes
    upc=upp.copy()
    n_iter=16;nx=nx*factor;ny=ny*factor
    upp, epps = iterate(n_iter,nx,ny,upc,s,bc,method)
    up = upp.copy()
    print 'passed 5th set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    raw_input('press return to continue')

    t1 = time.clock() 
    print 'Simulation took %5.3f seconds'%(t1-t0)
    #plot_sol(uf)
