#!/usr/bin/env python
"""
Solve the Poisson equation with various multigrid methods

We usually solve on a sequence of grids whose size differs by a factor of two in each dimension, e.g., 1024x1024 --> 512x512 --> 256x256 --> 128x128

author: Bob Wimmer
date: January 4, 2013
email: wimmer@physik.uni-kiel.de

This script follows to a large extent chapter 10 of Hirsch, vol. 1.

Note that boundaries are not treated correctly. I have simply set the boundary conditions to zero and never update the boundary points, so they remain zero at all times. This is a highly simplified treatment of the boundaries...

"""
from numpy import *
from numpy.linalg import norm as norm
#from scipy.linalg import eigh
#from scipy.sparse.linalg import eigsh
import time
import Gnuplot
from try_laplace import *

# size and spacing of the finest grid
NX, NY = 512, 512
dx, dy = 1., 1.
N_iter = 8

# work arrays on the finest grid, all initialised to zero
up = zeros((NX, NY))    # past   (time step n)
uf = zeros((NX, NY))    # future (time step n+1)
ufp = zeros((NX, NY))   # future (time step n+1), needed for over-relaxation methods
s = zeros((NX, NY))     # source term, filled in later
bc = zeros((NX, NY))    # boundary conditions, to be implemented later
g = Gnuplot.Gnuplot()


def iterate(n_iter,nx,ny,up,s,bc,method,omega=0.35):
    """
    perform n_iter relaxation sweeps on the nx x ny grid solution up

    method selects the smoother:
      'jacobi', 'gauss-seidel'        -- called as f(nx,ny,u,s,bc)
      'gauss-seidel-over', 'red-black' -- called as f(nx,ny,u,s,bc,omega)
      'zebra'  -- one zebra_col sweep followed by one zebra_row sweep; such a
                  pair counts as two sweeps, so n_iter//2 pairs are done

    omega is passed unchanged as the last argument of gauss_seidel_over and
    red_black (presumably the relaxation parameter -- see try_laplace).

    Every smoother receives a fresh copy of the current solution, so the
    caller's up is never modified. The solution is plotted with plot_sol
    after each (pair of) sweep(s).

    returns (up, eps): the relaxed solution and eps = u_new - u_old of the
    last sweep (for 'zebra' the last row sweep). If no sweep is done, both
    returned values are the input up.

    raises ValueError for an unknown method
    """
    # lambdas look the smoother names up only when a sweep is actually run
    single_sweeps = {
        'jacobi': lambda u: jacobi(nx,ny,u,s,bc),
        'gauss-seidel': lambda u: gauss_seidel(nx,ny,u,s,bc),
        'gauss-seidel-over': lambda u: gauss_seidel_over(nx,ny,u,s,bc,omega),
        'red-black': lambda u: red_black(nx,ny,u,s,bc,omega),
    }
    if method == 'zebra':
        sweeps = [lambda u: zebra_col(nx,ny,u,s,bc),
                  lambda u: zebra_row(nx,ny,u,s,bc)]
        n_outer = n_iter // 2
    elif method in single_sweeps:
        sweeps = [single_sweeps[method]]
        n_outer = n_iter
    else:
        raise ValueError('iterate: unknown relaxation method %r' % (method,))

    eps = up
    for i in range(n_outer):
        for sweep in sweeps:
            # the smoother gets its own copy, so eps is measured against an
            # array it cannot have changed in place
            new = sweep(up.copy())
            eps = new - up
            up = new
        plot_sol(up,nx,ny,str(i))
    return up, eps

def reduce_grid(nx,ny,u,factor,method):
    """
    reduce the grid dimension by factor using method method (restriction)

    currently only factor 2 is implemented; any other factor raises ValueError

    methods:
      'simple' -- injection: coarse point (i,j) takes the fine value at (2i,2j)
      'half'   -- half weighting, stencil 1/8  * [0 1 0; 1 4 1; 0 1 0]
      other    -- full weighting, stencil 1/16 * [1 2 1; 2 4 2; 1 2 1]

    Both weighting stencils are normalised (weights sum to one), so a
    constant field restricts to the same constant. For the weighted methods
    only interior coarse points are computed; the outermost coarse rows and
    columns stay zero (boundaries are not treated, see module docstring).

    returns new grid solution up(nx//factor,ny//factor)
    """
    if factor != 2:
        raise ValueError('reduce_grid: only factor 2 is implemented, got %r' % (factor,))
    nxc = nx // factor
    nyc = ny // factor
    up = zeros((nxc, nyc))

    if method == 'simple':
        up[:, :] = u[0:2*nxc:2, 0:2*nyc:2]
        return up

    if nxc < 3 or nyc < 3:
        # no interior coarse points to fill
        return up

    def fine(di, dj):
        # fine-grid values at offset (di,dj) from the points (2i,2j) that
        # belong to the interior coarse points i=1..nxc-2, j=1..nyc-2
        return u[2+di:2*(nxc-1)+di:2, 2+dj:2*(nyc-1)+dj:2]

    centre = fine(0, 0)
    edges = fine(-1, 0) + fine(1, 0) + fine(0, -1) + fine(0, 1)
    if method == 'half':
        up[1:nxc-1, 1:nyc-1] = (4.*centre + edges) / 8.
    else:  # full weighting
        corners = fine(-1, -1) + fine(1, -1) + fine(1, 1) + fine(-1, 1)
        up[1:nxc-1, 1:nyc-1] = (4.*centre + 2.*edges + corners) / 16.
    return up

def interpol_grid(nx,ny,u,factor,method):
    """
    increase the grid dimension by factor using bilinear interpolation (prolongation)

    currently only factor 2 is implemented; any other factor raises ValueError.
    method is accepted for symmetry with reduce_grid but is not used.

    coarse point (i,j) is copied to fine point (2i,2j); the odd rows of the
    even columns are then filled by averaging along x, and finally every odd
    column is filled by averaging its two (now complete) neighbouring columns.
    The last fine row and column (index nx*factor-1 / ny*factor-1) have no
    coarse neighbour beyond them and stay zero.

    returns new grid solution up(nx*factor,ny*factor)
    """
    if factor != 2:
        raise ValueError('interpol_grid: only factor 2 is implemented, got %r' % (factor,))
    nxf = nx * factor
    nyf = ny * factor
    up = zeros((nxf, nyf))
    # put grid points of coarse grid onto even grid points of fine grid
    up[0:nxf:2, 0:nyf:2] = u[0:nx, 0:ny]
    # interpolate along x: fill the odd rows of the even columns
    up[1:nxf-1:2, 0:nyf:2] = 0.5*(up[0:nxf-2:2, 0:nyf:2] + up[2:nxf:2, 0:nyf:2])
    # interpolate the columns in between from their complete neighbours
    up[:, 1:nyf-1:2] = 0.5*(up[:, 0:nyf-2:2] + up[:, 2:nyf:2])
    return up

#this allows us to use this file as a module, but also to call it as a standalone script
if __name__ == "__main__":
#    g = Gnuplot.Gnuplot()
    s = source(s,NX/4*dx,NY/4*dy,8.,-1.)
    #s = source(s,NX/2*dx,NY/2*dy,1.,4.)
    #s = source(s,NX/4*dx,3*NY/4*dy,1.,-1.)
    #s = source(s,3*NX/4*dx,NY/4*dy,1.,-1.)
    s = source(s,3*NX/4*dx,3*NY/4*dy,8.,1.)
    #plot_sol(s,NX,NY,'source')
    #raw_input('press return to continue')
    up = s.copy()
    t0 = time.clock()    
    #do a few passes on fine grid
    upc=up.copy()
    n_iter=8;nx=NX;ny=NY;method='zebra'
    up, eps = iterate(n_iter,nx,ny,upc,s,bc,method)
    print 'passed 1st set of iterations on ', nx, ' x ', ny, ' grid'
    print 'norm of eps = ', norm(eps)
    #raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer to coarser grid
    factor=2; 
    meth = 'full'
    upp = reduce_grid(nx,ny,up,factor,meth)
    sp = reduce_grid(nx,ny,s,factor,meth)
    bcp = reduce_grid(nx,ny,bc,factor,meth)
    #plot_sol(sp,nx/factor,nx/factor,'source'); raw_input('press return to continue')
    #plot_sol(bcp,nx/factor,ny/factor,'BCs'); raw_input('press return to continue')
    #print 'maxima of upp, sp, and bcp are: ', upp.max(), sp.max(), bcp.max()
    #do a few passes
    uppc=upp.copy()
    n_iter=8;nx=nx/factor;ny=ny/factor
    upp, epps = iterate(n_iter,nx,ny,uppc,sp,bcp,method)
    print 'passed 2nd set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    up = upp.copy()
    #raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer to coarser grid
    factor=2; 
    uppp = reduce_grid(nx,ny,up,factor,meth)
    spp = reduce_grid(nx,ny,sp,factor,meth)
    bcpp = reduce_grid(nx,ny,bcp,factor,meth)
    print 'maxima of upp, sp, and bcp are: ', uppp.max(), spp.max(), bcpp.max()
    #do a few passes
    uppc=uppp.copy()
    n_iter=8;nx=nx/factor;ny=ny/factor
    upp, epps = iterate(n_iter,nx,ny,uppc,spp,bcpp,method)
    print 'passed 3rd set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    up = upp.copy()
    ##raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer back to finer grid
    upp = interpol_grid(nx,ny,up,factor,meth)
    #do a few passes
    upc=upp.copy()
    n_iter=8;nx=nx*factor;ny=ny*factor
    upp, epps = iterate(n_iter,nx,ny,upc,sp,bcp,method)
    up = upp.copy()
    print 'passed 4th set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    #raw_input('press return to continue')

    print '---------------------------------------------------'

    #transfer back to finer grid
    upp = interpol_grid(nx,ny,up,factor,meth)
    #do a few passes
    upc=upp.copy()
    n_iter=16;nx=nx*factor;ny=ny*factor
    upp, epps = iterate(n_iter,nx,ny,upc,s,bc,method)
    up = upp.copy()
    print 'passed 5th set of iterations (on ', nx, ' x ', ny, ' grid)'
    print 'norm of eps = ', norm(epps)
    raw_input('press return to continue')

    t1 = time.clock() 
    print 'Simulation took %5.3f seconds'%(t1-t0)
    #plot_sol(uf)
