#!/usr/bin/env python

"""A simple script to fly a population of independent charged particles

   Each particle moves according to the Lorentz force it experiences in
   the static magnetic dipole field defined by B_field.

   Author: Bob Wimmer, IEAP, Univ. Kiel
   wimmer@physik.uni-kiel.de

   Date: June 9, 2009

"""


from numpy import *
from math import *
import Gnuplot, Gnuplot.funcutils
import time
from scipy.integrate import odeint

# ---------------------------------------------------------------------------
# Global physical constants (SI units unless noted otherwise)
# ---------------------------------------------------------------------------

# Charge-to-mass ratios [C/kg]
q_e_2_m_e = -1.758820150e11     # electron
q_e_2_m_p = 9.57883392e7        # proton

# Masses
m_e = 9.10938215e-31            # electron mass [kg]
m_e_2_m_p = 5.4461702177e-4     # electron-to-proton mass ratio
m_p_2_m_e = 1836.15267247       # proton-to-electron mass ratio
one_amu = 1.660538782e-27       # atomic mass unit [kg]
m_p = 1.672621637e-27           # proton mass [kg]
m_p_amu = 1.00727646677         # proton mass [amu]
m_4He = 6.64465620e-27          # 4He mass [kg]

# Fundamental constants
q_e = 1.602176487e-19           # elementary charge [C]
c_light = 299792458             # speed of light [m/s]
mu_0 = 4. * pi * 1.e-7          # vacuum permeability [T m/A]
N_A = 6.02214179e23             # Avogadro constant [1/mol]
k_B = 1.3806504e-23             # Boltzmann constant [J/K]

# Particle energies are computed with masses in amu and c_light = 1, i.e. in
# units of one_amu*c_light**2.  Multiplying such an energy by E_conv gives keV:
# E_conv = c_light**2*one_amu/q_e/1000.
E_conv = 931494.028233

# Magnetic moment of the dipole used by B_field, pointing along +z.
dipole_moment = [0., 0., 1.e19]




def lorentz_gamma(speed):
    """Return the Lorentz factor gamma = 1/sqrt(1 - (speed/c_light)**2).

    speed -- scalar speed in the same units as c_light (m/s).

    Raises ValueError when |speed| >= c_light, where gamma is undefined.
    """
    # float() avoids Python 2 integer division when speed is an int.
    beta = float(speed)/c_light
    if abs(beta) >= 1.:
        raise ValueError("speed has to be less than speed of light!")
    return 1./sqrt(1. - beta**2)

class particle(object):
    """A charged test particle pushed through the global field B_field.

    Attributes (set by __init__):
        kind       -- species name, normally one of KINDS
        mass       -- rest mass in amu
        charge     -- charge in units of the elementary charge q_e
        velocity   -- velocity vector [m/s] (float numpy array)
        speed      -- |velocity| [m/s]; propagate() does not re-evaluate it
        position   -- position vector [m] (float numpy array)
        trajectory -- list of positions: the start plus one per propagate()
        energy     -- kinetic energy m*(gamma - 1) in units of amu*c**2
                      (multiply by E_conv for keV)
        epm        -- energy per amu (energy / mass)
    """
    KINDS = ["electron", "proton", "3He", "4He", "C", "N", "O", "20Ne", "22Ne", "Mg", "Si", "Fe"]

    def __init__(self, kind, charge, velocity, position):
        """Create a particle; mass is the rest mass, energy is relativistic.

        velocity and position are copied into float arrays so the caller's
        sequences are never aliased, and so in-place updates in change_dir
        are not truncated by an integer dtype.
        """
        self.kind = kind
        self.mass = self._rest_mass_amu(kind)
        self.charge = charge
        self.velocity = array(velocity, dtype=float)
        self.speed = sqrt(dot(self.velocity, self.velocity))
        self.position = array(position, dtype=float)
        self.trajectory = [self.position]
        self.energy = self._kinetic_energy(self.mass, self.speed)
        self.epm = self.energy/self.mass

    @staticmethod
    def _rest_mass_amu(kind):
        """Return the rest mass of species `kind` in amu (proton if unknown)."""
        # Built on each call so the module constants are only needed at call time.
        masses = {
            "electron": m_e/one_amu,
            "proton": m_p_amu,
            "3He": 3.,
            "4He": m_4He/one_amu,
            "C": 12.,  # exact
            "N": 14.,
            "O": 16.,
            "20Ne": 20.,
            "22Ne": 22.,
            "Mg": 24.,
            "Si": 28.,
            "Fe": 56.,
        }
        if kind in masses:
            return masses[kind]
        print("kind not yet implemented - defaulting to proton")
        return m_p_amu

    @staticmethod
    def _kinetic_energy(mass, speed):
        """Return m*(gamma - 1) in amu*c**2 units (mass in amu, speed in m/s)."""
        beta = float(speed)/c_light
        return mass*(1./sqrt(1. - beta**2) - 1.)

    def __str__(self):
        """Multi-line summary: species, mass, energies, velocity, position, trajectory."""
        rep1 = (self.kind + ", mass = " + repr(self.mass) + ", E = " + repr(self.energy*E_conv)
                + "keV, E/m = " + repr(self.epm*E_conv) + " keV/nuc.\n")
        rep2 = "velocity = [" + ", ".join([repr(self.velocity[i]) for i in range(3)]) + "]\n"
        rep3 = "position = [" + ", ".join([repr(self.position[i]) for i in range(3)]) + "]"
        rep4 = "trajectory = [" + repr(self.trajectory) + "]"
        return rep1 + rep2 + rep3 + "\n" + rep4

    def change_dir(self, theta, phi):
        """Point the velocity along (theta, phi) without changing the speed.

        theta is the elevation above the x-y plane (v_z = speed*sin(theta)),
        phi the azimuth measured from the x-axis.  velocity is updated in place.
        """
        self.velocity[0] = self.speed*cos(theta)*cos(phi)
        self.velocity[1] = self.speed*cos(theta)*sin(phi)
        self.velocity[2] = self.speed*sin(theta)

    def change_speed(self, speed):
        """Set |velocity| to `speed` (m/s), keeping the direction of motion.

        Speeds >= c_light are clamped to 0.99999*c_light with a warning.
        A particle at rest has no direction; it is then given +x, which
        change_dir can override.  energy and epm are updated.
        """
        if speed >= c_light:
            print("nonsense! - speed can't be larger than c!")
            speed = 0.99999*c_light
        vel = asarray(self.velocity, dtype=float)
        norm = sqrt(dot(vel, vel))
        if norm > 0.:
            unit = vel/norm
        else:
            unit = array([1., 0., 0.])
        self.speed = speed
        self.velocity = speed*unit
        self.energy = self._kinetic_energy(self.mass, speed)
        self.epm = self.energy/self.mass

    def gyro_radius(self):
        """Return the gyro radius [m] at the current position.

        r = gamma*m*v_perp/(|q|*B), with v_perp = speed*sin(pitch angle).
        Returns 0 for a particle at rest and infinity for a neutral particle.
        """
        if self.speed == 0.:
            return 0.
        if self.charge == 0:
            return float("inf")
        B = B_field(self.position)
        B_mag = sqrt(dot(B, B))
        # Rounding can push the cosine slightly outside [-1, 1], where acos fails.
        cos_pitch = max(-1., min(1., dot(self.velocity, B)/self.speed/B_mag))
        pitch_angle = acos(cos_pitch)
        return ((self.mass*one_amu)/(abs(self.charge)*q_e)
                *(self.speed*sin(pitch_angle)/B_mag)*lorentz_gamma(self.speed))

    def propagate(self, dt):
        """Advance the particle by time dt [s] in the magnetic field B_field.

        Integrates dx/dt = v, dv/dt = q/(gamma*m) v x B with odeint.  In a
        pure magnetic field gamma is constant, so self.speed is left unchanged.
        Updates position and velocity and appends the new position to
        trajectory.
        """
        y0 = hstack([self.position, self.velocity])
        q_over_gamma_m = (self.charge*q_e)/(self.mass*one_amu*lorentz_gamma(self.speed))

        def deriv(y, t):
            """Return d/dt of the state y = [x, v]."""
            x, v = hsplit(y, 2)
            accel = multiply(q_over_gamma_m, cross(v, B_field(x)))
            return hstack([v, accel])

        z = odeint(deriv, y0, array([0., dt]))
        self.position, self.velocity = hsplit(z[-1], 2)
        self.trajectory.append(self.position)

    def plot_trajectory(self):
        """Plot this particle's trajectory in 3-D with gnuplot, then wait for Return.

        All three axes get the same range, from the smallest to the largest
        coordinate visited.
        """
        g = Gnuplot.Gnuplot()
        g.reset()
        g('set ylabel "y-axis [arb. units]"')
        g('set zlabel "z-axis [arb. units]"')
        g('set xlabel "x-axis [arb. units]"')
        g('set style data lines')
        traj = asanyarray(self.trajectory)
        axis_range = "[" + str(traj.min()) + ":" + str(traj.max()) + "]"
        g('set zrange ' + axis_range)
        g('set xrange ' + axis_range)
        g('set yrange ' + axis_range)
        g.splot(self.trajectory, title='particle trajectory')
        raw_input('Please press the return key to continue...\n')


def B_field(x):
    """Return the magnetic field of the point dipole `dipole_moment` at position x.

    Away from the origin this is mu_0/(4 pi) * (3 (m.x) x / r**5 - m / r**3).
    At the origin, where that expression is singular, 2*mu_0/3 * m is
    returned instead.
    """
    r_squared = dot(x, x)
    if r_squared == 0.:
        return multiply(2.*mu_0/3., dipole_moment)
    r = sqrt(r_squared)
    radial_term = multiply(3.*dot(dipole_moment, x), x)
    combined = multiply(r**(-5.), radial_term) - multiply(r**(-3.), dipole_moment)
    return multiply(mu_0/(4.*pi), combined)



class distribution(particle):
    """A population of independent particles, e.g. a thermal distribution.

    Allows moving an entity (particle or wave packet) from one place
    within the distribution to another.  Conservation of momentum, energy,
    particle number etc. is to be added later.

    NOTE(review): subclassing particle looks unintended -- distribution never
    initializes any particle state.  It is kept for backward compatibility.
    """

    def __init__(self):
        self.particles = []

    def __clear__(self):
        """Remove all particles (not a Python protocol method; call it explicitly)."""
        self.particles = []

    def __str__(self):
        """Return every particle's description followed by a separator line."""
        lines = [str(p) for p in self.particles]
        lines.append("-----------------")
        return "\n".join(lines)

    def add(self, particle):
        """Add a particle to the distribution."""
        self.particles.append(particle)

    def delete(self, particle):
        """Remove a particle from the distribution."""
        self.particles.remove(particle)

    def populate(self, type, N, kind, charge, origin, arg):
        """Add N particles of `kind` and `charge` at `origin`, distributed by `type`.

        type "thermal": arg is the temperature T [K].  Kinetic energies follow
        the non-relativistic 3-D Maxwell-Boltzmann distribution, i.e.
        Gamma(shape=3/2, scale=k_B*T).  Directions are isotropic.
        Any other type prints a message and adds nothing.
        TODO: a power-law type (dJ/dE ~ E**-gamma) is not implemented yet.
        """
        if type != "thermal":
            print("This distribution is not defined")
            return
        # k_B*T in the particle energy units: amu*c**2 (masses in amu, c = 1).
        E_scale = k_B*arg/(one_amu*c_light**2)
        for _ in range(N):
            # Start at rest so the particle's mass is known; speed is set below.
            p = particle(kind, charge, [0., 0., 0.], list(origin))
            E = random.gamma(1.5, E_scale)
            # Kinetic energy E = m*(gamma - 1)  =>  v/c = sqrt(eps*(eps + 2))/(1 + eps)
            # with eps = E/m (numerically stable for eps << 1).
            eps = E/p.mass
            speed = sqrt(eps*(eps + 2.))/(1. + eps)*c_light
            p.change_speed(speed)
            # sin(theta) uniform in [-1, 1] and phi uniform give isotropic directions.
            p.change_dir(asin(random.uniform(-1, 1)), random.uniform(0, 2*pi))
            self.particles.append(p)

    def propagate(self, time):
        """Propagate every particle for a total time `time` [s].

        Each particle advances via particle.propagate in steps of
        0.3*gyro_radius/speed.  The last step is shortened so every particle
        ends exactly at `time`.  Before each step the direction is
        re-randomized isotropically (speed kept) with probability 0.1.

        Raises ValueError if no positive step can be chosen, e.g. when the
        gyro radius is zero; this prevents an endless loop.
        """
        for p in self.particles:
            if p.speed == 0.:
                # q v x B vanishes: a particle at rest does not move.
                continue
            elapsed_time = 0.
            while elapsed_time < time:
                if random.uniform(0, 1) < 0.1:
                    p.change_dir(asin(random.uniform(-1, 1)), random.uniform(0, 2*pi))
                # Compute the step once so the clock advances by the step actually taken.
                dt = 0.3*abs(p.gyro_radius())/p.speed
                if not dt > 0.:
                    raise ValueError("cannot choose a positive time step for a "
                                     + str(p.kind) + " particle (zero gyro radius)")
                remaining = time - elapsed_time
                if dt >= remaining:
                    p.propagate(remaining)
                    elapsed_time = time
                else:
                    p.propagate(dt)
                    elapsed_time += dt

    def plot_trajectories(self):
        """Plot all particle trajectories in one 3-D gnuplot window, then wait for Return.

        After each particle is added, all axes are set to the common range
        spanning every coordinate plotted so far.
        """
        g = Gnuplot.Gnuplot()
        g.reset()
        g('set ylabel "y-axis [arb. units]"')
        g('set xlabel "x-axis [arb. units]"')
        g('set style data lines')
        # replot needs an existing plot, so start with a dummy one at the origin.
        g.splot([[0., 0., 0.], [0., 0., 0.]], title='particle trajectories')
        lo = hi = None
        for p in self.particles:
            traj = asanyarray(p.trajectory)
            p_lo, p_hi = traj.min(), traj.max()
            lo = p_lo if lo is None else min(lo, p_lo)
            hi = p_hi if hi is None else max(hi, p_hi)
            axis_range = "[" + str(lo) + ":" + str(hi) + "]"
            g('set zrange ' + axis_range)
            g('set xrange ' + axis_range)
            g('set yrange ' + axis_range)
            g.replot(p.trajectory, title='particle trajectories')
        raw_input('Please press the return key to continue...\n')

    def statistics(self):
        """Print and return population statistics of the current state.

        Returns (ave_pos, std_pos, ave_vel, std_vel).  Each is a per-component
        array taken across all particles' current positions and velocities.
        Returns None (after a message) for an empty population.
        TODO: per-time-step averages, speed statistics, etc.
        """
        if not self.particles:
            print("no particles in the distribution")
            return None
        positions = asarray([p.position for p in self.particles], dtype=float)
        velocities = asarray([p.velocity for p in self.particles], dtype=float)
        ave_pos = positions.mean(axis=0)
        std_pos = positions.std(axis=0)
        ave_vel = velocities.mean(axis=0)
        std_vel = velocities.std(axis=0)
        print("%s %s %s %s" % (ave_pos, std_pos, ave_vel, std_vel))
        return ave_pos, std_pos, ave_vel, std_vel


def main():
    """Demo run: a thermal (T = 1e7 K) proton population started at `origin`.

    Prints the field at the origin, propagates for 100 s, reports the
    wall-clock run time, plots the trajectories and prints statistics.
    """
    # time.clock() was removed in Python 3.8; time.time() works everywhere.
    t_0 = time.time()
    origin = [0., 1.e6, 0.]
    print(B_field(origin))
    f = distribution()
    temp = 1.e7
    f.populate("thermal", 1, "proton", 1, origin, temp)
    t = 100.
    f.propagate(t)
    t_1 = time.time()
    print("calculations took " + str(t_1 - t_0) + " seconds")
    f.plot_trajectories()
    f.statistics()

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
