#!/usr/bin/env python

"""A simple script for single particle motion

   The idea is to let particles move according to the Lorentz
   force they experience.

   Author: Bob Wimmer, IEAP, Univ. Kiel
   wimmer@physik.uni-kiel.de

   Date: June 9, 2009

"""


from numpy import *
from math import *
import Gnuplot, Gnuplot.funcutils
import time
from scipy.integrate import odeint

# ---------------------------------------------------------------------------
# Physical constants and model parameters (SI units unless noted otherwise)
# ---------------------------------------------------------------------------

# Charge-to-mass ratios [C/kg]
q_e_2_m_e = -1.758820150e11     # electron
q_e_2_m_p = 9.57883392e7        # proton

# Masses and mass ratios
m_e = 9.10938215e-31            # electron mass [kg]
m_e_2_m_p = 5.4461702177e-4     # electron / proton mass ratio
m_p_2_m_e = 1836.15267247       # proton / electron mass ratio
one_amu = 1.660538782e-27       # atomic mass unit [kg]
m_p = 1.672621637e-27           # proton mass [kg]
m_p_amu = 1.00727646677         # proton mass [amu]
m_4He = 6.64465620e-27          # 4He mass [kg]

# Fundamental constants
q_e = 1.602176487e-19           # elementary charge [C]
c_light = 299792458             # speed of light [m/s] (kept as an int)
mu_0 = 4.*pi*1.e-7              # vacuum permeability [N/A^2]
N_A = 6.02214179e23             # Avogadro constant [1/mol]
k_B = 1.3806504e-23             # Boltzmann constant [J/K]

# Energies in this script are computed with masses in amu and c_light = 1.
# Multiplying such an energy by E_conv converts it to keV, i.e.
#   E_conv = c_light**2 * one_amu / q_e / 1000.
E_conv = 931494.028233

# Dipole moment vector used by B_field (presumably [A m^2]; confirm).
dipole_moment = [0., 0., 1.e19]



def lorentz_gamma(speed):
    """Return the Lorentz factor gamma = 1/sqrt(1 - (speed/c_light)**2).

    speed -- particle speed in the same units as c_light (m/s).  Only
             speed**2 enters, so the sign does not matter.

    Raises ValueError if |speed| >= c_light, where gamma is undefined.
    """
    # float(): under Python 2, int/int is integer division and would
    # silently give beta = 0 (gamma = 1) for an integer speed.
    beta = speed / float(c_light)
    if abs(beta) >= 1.:
        raise ValueError("speed has to be less than speed of light!")
    return 1. / sqrt(1. - beta * beta)

class particle(object):
    """A charged test particle moving in the magnetic field given by B_field.

    A particle is defined by its kind (which fixes the rest mass), charge,
    velocity and position (3-vectors).  Units as used by the code:
      mass     -- rest mass in amu (multiplied by one_amu where SI is needed)
      charge   -- in units of the elementary charge q_e
      velocity -- m/s (compared against c_light)
      position -- passed straight to B_field (presumably metres; confirm)
    """
    KINDS = ["electron", "proton", "3He", "4He", "C", "N", "O", "20Ne", "22Ne", "Mg", "Si", "Fe"]

    @staticmethod
    def _rest_mass_amu(kind):
        """Return the rest mass of `kind` in amu, or None if it is unknown."""
        # Built at call time, so the module constants only need to exist
        # when a particle is actually created.
        masses = {
            "electron": m_e / one_amu,   # electron mass expressed in amu
            "proton": m_p_amu,
            "3He": 3.,
            "4He": m_4He / one_amu,      # 4He mass expressed in amu
            "C": 12.,                    # exact
            "N": 14.,
            "O": 16.,
            "20Ne": 20.,
            "22Ne": 22.,
            "Mg": 24.,
            "Si": 28.,
            "Fe": 56.,
        }
        return masses.get(kind)

    def __init__(self, kind, charge, velocity, position):
        """Create a particle.

        kind     -- one of KINDS; unknown kinds fall back to the proton mass
        charge   -- charge in units of q_e (the sign determines gyration sense)
        velocity -- 3-vector in m/s; its magnitude must be below c_light
        position -- 3-vector; the first point of the trajectory

        self.mass is the rest mass in amu.  self.energy is the relativistic
        kinetic energy in amu*c**2 (multiply by E_conv for keV), and
        self.epm is that energy per unit mass.
        """
        self.kind = kind
        mass = self._rest_mass_amu(kind)
        if mass is None:
            print("kind not yet implemented - defaulting to proton")
            mass = m_p_amu
        self.mass = mass
        self.charge = charge
        self.velocity = velocity
        self.speed = sqrt(vdot(self.velocity, self.velocity))
        self.position = position
        self.trajectory = [position]
        self.energy = self.mass * (lorentz_gamma(self.speed) - 1.)
        self.epm = self.energy / self.mass

    def __str__(self):
        """Return a multi-line summary of kind, energy, velocity, position and trajectory."""
        rep1 = (self.kind + ", mass = " + repr(self.mass)
                + ", E = " + repr(self.energy * E_conv)
                + "keV, E/m = " + repr(self.epm * E_conv) + " keV/nuc.\n")
        rep2 = "velocity = [" + ", ".join(repr(self.velocity[i]) for i in range(3)) + "]\n"
        rep3 = "position = [" + ", ".join(repr(self.position[i]) for i in range(3)) + "]"
        rep4 = "trajectory = [" + repr(self.trajectory) + "]"
        return rep1 + rep2 + rep3 + "\n" + rep4

    def gyro_radius(self):
        """Return the relativistic gyro radius at the current position (SI).

        r_g = gamma * m * v_perp / (|q| * |B|), where v_perp is the speed
        perpendicular to the local field B_field(self.position).
        """
        B = B_field(self.position)
        B_mag = sqrt(dot(B, B))
        # v_perp**2 = v**2 - v_par**2 equals (speed*sin(pitch angle))**2, but
        # unlike acos() it cannot fail when rounding pushes cos(pitch angle)
        # slightly outside [-1, 1] (e.g. for motion along the field).
        v_par = dot(self.velocity, B) / B_mag
        v_perp = sqrt(max(self.speed**2 - v_par**2, 0.))
        # abs(): a radius is a length.  main() divides it by the speed to get
        # a time step, which must not become negative for negative charges.
        return (self.mass * one_amu * v_perp * lorentz_gamma(self.speed)
                / (abs(self.charge) * q_e * B_mag))

    def propagate(self, dt, n_sub=3):
        """Advance the particle by the time step dt in the field B_field.

        Integrates d(gamma*m*v)/dt = q v x B with scipy's odeint.  Only a
        magnetic force acts, which does no work, so |v| and hence gamma stay
        constant and dv/dt = q/(gamma*m) v x B.

        dt    -- time step [s]
        n_sub -- number of equal sub-intervals in odeint's output grid
                 (default 3, the value that used to be hard-coded)

        Updates self.position and self.velocity and appends the new position
        to self.trajectory.  self.speed is left unchanged because it is a
        constant of this motion.
        """
        y0 = hstack([self.position, self.velocity])
        t = linspace(0., dt, max(int(n_sub), 1) + 1)
        # The charge-to-(relativistic-)mass ratio is fixed during the step.
        q_over_m = (self.charge * q_e) / (self.mass * one_amu * lorentz_gamma(self.speed))

        def deriv(y, t):
            """Return d/dt of the state y = [x, y, z, vx, vy, vz]."""
            pos, vel = hsplit(y, 2)
            accel = q_over_m * cross(vel, B_field(pos))
            return hstack([vel, accel])

        z = odeint(deriv, y0, t)
        self.position, self.velocity = hsplit(z[-1], 2)
        self.trajectory.append(self.position)

    def plot_trajectory(self):
        """Plot the trajectory in 3-D with gnuplot and wait for the user.

        All three axes share one range, from the smallest to the largest
        coordinate on any axis, so the plot is not distorted.
        """
        g = Gnuplot.Gnuplot()
        g.reset()
        g('set xlabel "x-axis [arb. units]"')
        g('set ylabel "y-axis [arb. units]"')
        g('set zlabel "z-axis [arb. units]"')
        g('set style data lines')
        traj = asanyarray(self.trajectory)
        lo = traj.min(axis=0)   # per-axis minima
        hi = traj.max(axis=0)   # per-axis maxima
        print("%s %s" % (lo, hi))
        axis_range = "[%s:%s]" % (lo.min(), hi.max())
        for axis in ("x", "y", "z"):
            g('set %srange %s' % (axis, axis_range))
        g.splot(self.trajectory, title='particle trajectory')
        # Keep the plot window open until the user confirms.
        try:
            ask = raw_input   # Python 2
        except NameError:
            ask = input       # Python 3
        ask('Please press the return key to continue...\n')



def B_field(x, moment=None):
    """Return the magnetic field of a point dipole located at the origin.

    x      -- 3-vector field point (SI units assumed, matching mu_0)
    moment -- dipole moment 3-vector; defaults to the module-level
              dipole_moment, looked up at call time

    Away from the origin this is
        mu_0/(4 pi) * (3 (m.x) x / r**5 - m / r**3),  with r = |x|.
    At the origin the value 2*mu_0/3 * m is returned instead of dividing by
    zero (the coefficient of the dipole field's delta-function term).
    """
    if moment is None:
        moment = dipole_moment
    r_sq = dot(x, x)
    if r_sq == 0.:
        return multiply(2. * mu_0 / 3., moment)
    r = sqrt(r_sq)
    a = multiply(3. * dot(moment, x), x)
    b = multiply(r**(-5.), a) - multiply(r**(-3.), moment)
    return multiply(mu_0 / (4. * pi), b)




def main():
    """Trace a proton through the dipole field for 64 steps, then plot it.

    Each step lasts 0.3 * gyro_radius / speed, so the step size follows the
    local gyro radius.  Prints the field at the start point and the elapsed
    time.
    """
    # time.clock() was removed in Python 3.8 and meant CPU time on Unix but
    # wall time on Windows; time.time() exists and behaves the same everywhere.
    t_0 = time.time()
    origin = [0., 1.e6, 0.]
    print(B_field(origin))
    p = particle("proton", 1, array([0., 1.e6, 1.e6]), origin)
    for _ in range(64):
        p.propagate(0.3 * p.gyro_radius() / p.speed)
    t_1 = time.time()
    print("calculations took " + str(t_1 - t_0) + " seconds")
    p.plot_trajectory()


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
