# Creates energy spectra and time-series plots from SOHO/EPHIN level-3 variable-bin (l3i varbins) data.
# Patrick Kuehl, date
import time
from pylab import *
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import numpy.ma as ma
from matplotlib.colors import LogNorm
from scipy.optimize import leastsq
from matplotlib import dates
import time
import datetime as date
from datetime import datetime
from datetime import timedelta
import os
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.mplot3d import Axes3D
import colormaps as cmaps

"""
creates spectra and time plots from lvl3 varbins
define paths here:
"""
work_folder="/home/pacifix/kuehl/work/datacorrection/ephin_lvl3_varbins/work_folder/"
data_folder="/data/projects/soho/ephin/level3/l3i_varbins/"



# helper

def totimestamp(dt, epoch=datetime(1995,1,1)):
    """Return the seconds elapsed from *epoch* to *dt* as a float.

    The default epoch is 1995-01-01 00:00. The result keeps the sub-second
    (microsecond) part and is negative when *dt* precedes *epoch*.
    """
    # total_seconds() uses true division on every Python version; an integer
    # microsecond count divided by 10**6 would truncate to whole seconds
    # under Python 2.
    return (dt - epoch).total_seconds()

def fromtimestamp(stamp, epoch=datetime(1995,1,1)):
    """Convert *stamp*, a number of seconds relative to *epoch*, into a datetime.

    Counterpart of totimestamp(); the default epoch is 1995-01-01 00:00.
    """
    offset = timedelta(seconds=stamp)
    return epoch + offset


""" now the actual functions"""
# load level3 varbin data
def load_level3_l3i_varbin_macro(macro_name,macro,year,doy,tres,annual=False,mission=False,unpack=True):
  """Load one level-3 variable-bin (.l3i) data file for a channel macro.

  Parameters
  ----------
  macro_name : str
      Sub-directory of ``data_folder`` that holds this macro's files.
  macro : dict
      Macro definition (as built by ``load_macro``). Only ``macro["chnames"]``
      is used here, to know how many channels the file contains.
  year, doy : int
      Select the file. ``doy`` is ignored when ``annual`` or ``mission`` is
      set; ``year`` is ignored when ``mission`` is set.
  tres : int
      Time resolution in minutes, e.g. 1, 5, 10, 30, 60, 1440.
  annual : bool
      Load the whole-year file instead of a single-day file.
  mission : bool
      Load the entire-mission file (takes precedence over ``annual``);
      these exist for 60 and 1440 min only.
  unpack : bool
      Split the table into columns (see Returns). If false, return the
      array produced by ``np.loadtxt`` unchanged.

  Returns
  -------
  If ``unpack``: ``(tyear, tmonth, tday, tdoy, thour, tminute, tstatus,
  taccumtime, chs_flux, chs_sys, chs_stat)``. The first eight are 1-D arrays
  with one entry per record. The last three are lists with one 1-D array per
  channel. Otherwise the raw ``np.loadtxt`` array.
  """
  number_of_channels = len(macro["chnames"])
  path = data_folder + macro_name + "/"
  if mission:
    # The path literal must match the file name on disk exactly; keep it as is.
    fname = path + "%imin/entire_misson_%imin.l3i" % (tres, tres)
  elif annual:
    fname = path + "%imin/%i.l3i" % (tres, year)
  else:
    fname = path + "%imin/%i/%i_%03d.l3i" % (tres, year, year, doy)
  data = np.loadtxt(fname)
  if not unpack:
    return data

  # loadtxt returns a 1-D array for a file with a single record. atleast_2d
  # restores the (records, columns) layout, so after transposing, every entry
  # of `cols` is a 1-D per-record column regardless of the record count.
  cols = np.atleast_2d(data).T
  tyear, tmonth, tday, tdoy, thour, tminute, tstatus, taccumtime = cols[:8]
  # After the 8 time/status columns, each channel occupies three consecutive
  # columns: flux, then the values returned as chs_sys and chs_stat
  # (presumably systematic and statistical uncertainties -- confirm).
  chs_flux = [cols[8 + 3 * q] for q in range(number_of_channels)]
  chs_sys = [cols[9 + 3 * q] for q in range(number_of_channels)]
  chs_stat = [cols[10 + 3 * q] for q in range(number_of_channels)]
  return tyear, tmonth, tday, tdoy, thour, tminute, tstatus, taccumtime,chs_flux,chs_sys,chs_stat

# loads macro file
def load_macro(macrofile):
  """Read a channel-definition ("macro") file into a dict of per-channel lists.

  Each data line describes one channel as whitespace-separated fields (spaces
  or tabs, any amount), in this order:
  chnames, chtypes, chranges, chgeoms, chgeomssys, chgeomsoff, chgeomsoffsys,
  chabmin, chabmax, chemin, chemax.
  The first three fields are kept as strings; the rest are converted to float.
  Blank lines and lines whose first non-blank character is "#" are skipped.

  Two derived entries are added per channel:
  "chemean" = (chemin + chemax) / 2 and "chelogmean" = sqrt(chemin * chemax).

  Returns a dict mapping each of the 13 keys above to a list with one entry
  per channel.

  Raises ValueError if a data line does not have exactly 11 fields or a
  numeric field cannot be parsed.
  """
  lists = ["chnames","chtypes","chranges","chgeoms","chgeomssys","chgeomsoff","chgeomsoffsys","chabmin","chabmax","chemin","chemax"]
  macro = {key: [] for key in lists + ["chemean", "chelogmean"]}
  with open(macrofile, "r") as f:
    for lineno, line in enumerate(f, 1):
      fields = line.split()  # tolerant of tabs, repeated and leading/trailing whitespace
      if not fields or fields[0].startswith("#"):
        continue
      if len(fields) != len(lists):
        # A short or long line would silently misalign the per-channel lists.
        raise ValueError("%s, line %d: expected %d fields, got %d"
                         % (macrofile, lineno, len(lists), len(fields)))
      for idx, (key, val) in enumerate(zip(lists, fields)):
        macro[key].append(val if idx < 3 else float(val))
  for emin, emax in zip(macro["chemin"], macro["chemax"]):
    macro["chemean"].append((emin + emax) / 2.)
    macro["chelogmean"].append(np.sqrt(emin * emax))
  return macro


# plot spectra
def plot_spectra_for_macrolist(macro_names,year,doy,tres,shour=0,smin=0,ehour=23,emin=59,plot_name="level3_varbins_spectra.pdf", bincenter="chelogmean",protononly=False):
  """Plot time-averaged flux spectra for one or more channel macros in one figure.

  Parameters
  ----------
  macro_names : list of str
      Macro names, e.g. ["nominal_ephinbins", "logspaced_bins_2+6"]. Each
      definition is read from ``work_folder + name + ".l3imacro"`` and its
      data from the single-day file of ``year``/``doy``.
  year, doy : int
      Day to plot.
  tres : int
      Time resolution in minutes. For 1440 (daily files) the daily value
      is plotted; otherwise fluxes are averaged over the time window.
  shour, smin, ehour, emin : int or str
      Averaging window (anything ``int()`` accepts; hour 24 is allowed for
      the end). The window runs from the last record starting at or before
      the start time through the last record ending at or before the end
      time. Ignored for tres=1440. The window is located using the record
      times of the first macro.
  plot_name : str
      Currently unused: the figure is displayed with ``plt.show()``, not saved.
  bincenter : str
      Macro key used for the energy axis ("chelogmean" or "chemean").
  protononly : bool
      If true, plot only channels whose chtype is "p". Otherwise the first
      half of the channels is drawn as protons ("p:") and the second half as
      helium ("he:").

  Raises
  ------
  ValueError
      If the requested window contains no complete record (tres != 1440).
  """
  field_names = ["tyear", "tmonth", "tday", "tdoy", "thour", "tminute",
                 "tstatus", "taccumtime", "chs_flux", "chs_sys", "chs_stat"]
  macros = [load_macro(work_folder + name + ".l3imacro") for name in macro_names]
  data = {}
  for macro_name, macro in zip(macro_names, macros):
    print(macro_name)
    # The loader returns its columns in exactly the order of field_names, so
    # zipping keeps "chs_sys" and "chs_stat" attached to the right arrays.
    columns = load_level3_l3i_varbin_macro(macro_name, macro, year, doy, tres,
                                           annual=False, mission=False, unpack=True)
    data[macro_name] = dict(zip(field_names, columns))

  plt.figure()
  si, ei = -1, -1
  if tres != 1440:
    ref = data[macro_names[0]]
    # Record start times in minutes since midnight (all records are on one day).
    starts = [int(h) * 60 + int(m)
              for h, m in zip(np.atleast_1d(ref["thour"]), np.atleast_1d(ref["tminute"]))]
    start_min = int(shour) * 60 + int(smin)
    end_min = int(ehour) * 60 + int(emin)
    for step, t in enumerate(starts):
      if t <= start_min:
        si = step  # last record starting at/before the window start
      if t + tres <= end_min:
        ei = step  # last record ending at/before the window end
    if si < 0 or ei < si:
      raise ValueError("no complete %i-min record between %02d:%02d and %02d:%02d"
                       % (tres, int(shour), int(smin), int(ehour), int(emin)))
    rshour, rsmin = divmod(starts[si], 60)
    # The averaged window ends where record `ei` ends.
    rehour, remin = divmod(starts[ei] + tres, 60)
    plt.title("%s min averages - %s-%03d %02d:%02d-%02d:%02d"%(tres,year,doy,rshour,rsmin,rehour,remin))
  else:
    plt.title("daily averages - %s-%03d"%(year,doy))

  for macro_name, macro in zip(macro_names, macros):
    half = len(macro["chnames"]) // 2
    fluxes = np.asarray(data[macro_name]["chs_flux"], dtype=float)
    if tres != 1440:
      tfluxes = np.average(fluxes[:, si:ei + 1], axis=1)
    else:
      # One value per channel; reshape accepts both a (channels,) and a
      # (channels, records) layout and averages over records (presumably
      # a single one in a daily file).
      tfluxes = fluxes.reshape(len(fluxes), -1).mean(axis=1)
    centers = macro[bincenter]
    if not protononly:
      plt.plot(centers[:half], tfluxes[:half], "o-", label="p: " + macro_name)
      plt.plot(centers[half:], tfluxes[half:], "s-", label="he: " + macro_name)
    else:
      picks = [i for i, ttype in enumerate(macro["chtypes"]) if ttype.lower() == "p"]
      plt.plot([centers[i] for i in picks], [tfluxes[i] for i in picks],
               "o-", label="p: " + macro_name, mec="w")
  plt.xscale("log")
  plt.yscale("log")
  plt.legend(fontsize=10,loc=3)
  plt.xlabel("Energy / MeV/nuc")
  plt.ylabel("Flux / (cm$^2$ s sr MeV/nuc)$^{-1}$")
  plt.show()


# plot time series
def plot_timeseries_for_macro(macro_name,syear,eyear,sdoy,edoy,tres,plot_name="level3_varbins_timeseries.pdf"):
  """Plot one stacked flux time-series panel per channel of a macro.

  Parameters
  ----------
  macro_name : str
      Macro to plot; its definition is read from
      ``work_folder + macro_name + ".l3imacro"``.
  syear, eyear : int
      First and last year shown. If they are equal, the annual file of
      ``syear`` is loaded; otherwise the entire-mission file is loaded.
  sdoy, edoy : int
      First and last day of year shown. Both are inclusive: the axis runs
      from the start of ``sdoy`` to the end of ``edoy``.
  tres : int
      Time resolution in minutes. Multi-year plots need tres >= 60, because
      entire-mission files exist only for 60 and 1440 min.
  plot_name : str
      Currently unused: the figure is displayed with ``plt.show()``, not saved.

  If a multi-year range is requested with tres < 60, a message is printed
  and the function returns without loading data or plotting.
  """
  multi_year = syear != eyear
  if multi_year and tres < 60:
    print("Can plot multiple years only for tres=60 or tres=1440!")
    print("aborting...")
    return  # no mission file exists at this resolution

  macro = load_macro(work_folder + macro_name + ".l3imacro")
  number_of_channels = len(macro["chnames"])
  (tyear, tmonth, tday, tdoy, thour, tminute, tstatus, taccumtime,
   chs_flux, chs_sys, chs_stat) = load_level3_l3i_varbin_macro(
      macro_name, macro, syear, 1, tres, annual=True, mission=multi_year, unpack=True)
  # Record time stamps, built from year + day-of-year + hour/minute.
  times = [datetime(int(y), 1, 1, int(h), int(m)) + timedelta(int(d) - 1)
           for y, d, h, m in zip(tyear, tdoy, thour, tminute)]
  xstart = datetime(syear, 1, 1) + timedelta(int(sdoy) - 1)
  # End of day `edoy` (inclusive), so sdoy == edoy still shows one full day.
  xend = datetime(eyear, 1, 1) + timedelta(int(edoy))

  plt.rcParams.update({'font.size': 8})
  plt.figure(figsize=(8.3,11.7))
  plt.subplots_adjust(left=0.09, right=0.96, top=0.98,bottom=0.03, wspace=0.01, hspace=0.1)
  colors = {"p": "r", "h": "g"}  # keyed by the channel name's first letter; otherwise blue
  axes = []
  for chidx in range(number_of_channels):
    ax = plt.subplot2grid((number_of_channels, 1), (chidx, 0), rowspan=1,
                          sharex=axes[0] if axes else None)
    axes.append(ax)
    chname = macro["chnames"][chidx]
    label = chname + " (%2.2f-%2.2f MeV/nuc)" % (macro["chemin"][chidx], macro["chemax"][chidx])
    ax.plot(times, chs_flux[chidx], drawstyle="steps-pre", ls="-", ms=0, label=label,
            rasterized=True, color=colors.get(chname[0].lower(), "b"))
    ax.set_yscale("log")
    ax.legend(fontsize=8, loc=2)
    if chidx == number_of_channels // 2:
      ax.set_ylabel("Flux / (cm$^2$ s sr MeV/nuc)$^{-1}$")
    if chidx != number_of_channels - 1:
      plt.setp(ax.get_xticklabels(), visible=False)  # only the bottom panel shows dates
  if axes:
    axes[0].set_xlim(xstart, xend)  # x axis is shared, so this applies to every panel
    axes[-1].xaxis.set_major_formatter(mdates.DateFormatter("%b %d, %Y\n%H:%M:%S"))
  plt.show()
  #plt.savefig(plot_name)



# Interactive entry point.

try:
  _read = raw_input  # Python 2: returns the typed text (Py2 input() would eval it)
except NameError:
  _read = input  # Python 3: input() already returns the typed text

# Menu number -> macro name (used as file/directory name by the loaders).
_MACRO_BY_BINTYPE = {
  0: "nominal_ephinbins",
  1: "logspaced_bins_1+3",
  2: "logspaced_bins_2+6",
  3: "logspaced_bins_4+12",
}


def _parse_doy(text, year):
  """Return a day of year from *text*: either a DOY number or 'mm-dd' within *year*."""
  try:
    return int(text)
  except ValueError:
    parts = text.split("-")
    return datetime(int(year), int(parts[0]), int(parts[1])).timetuple().tm_yday


def _main():
  """Prompt for plot type, bins, resolution and time range, then draw the plot."""
  plottype = _read("Spectra (s) or time series (t) ")
  bintype = int(_read("Select bins: nominal (0), log.spaced 4 bins (1), 8 bins (2) or 16 bins (3) "))
  if bintype not in _MACRO_BY_BINTYPE:
    print("Unknown bin selection %i, expected 0, 1, 2 or 3." % bintype)
    return
  macro_name = _MACRO_BY_BINTYPE[bintype]
  tres = int(_read("select timeresolution: 1, 5, 10, 30, 60 or 1440 minutes "))

  if plottype == "t":
    mysyear = int(_read("Start Year: "))
    myeyear = int(_read("End Year: "))
    sdoy_text = _read("Start Doy ('mm-dd' allowed as well): ")
    edoy_text = _read("End Doy ('mm-dd' allowed as well): ")
    mysdoy = _parse_doy(sdoy_text, mysyear)
    myedoy = _parse_doy(edoy_text, myeyear)
    plot_timeseries_for_macro(macro_name, mysyear, myeyear, mysdoy, myedoy, tres)
  elif plottype == "s":
    mysyear = int(_read("Year: "))
    mysdoy = _parse_doy(_read("Select Doy ('mm-dd' allowed as well): "), mysyear)
    if tres != 1440:
      shour, smin = _read("Select Start time ('hh:mm'): ").split(":")
      ehour, emin = _read("Select End time ('hh:mm'): ").split(":")
    else:
      shour, ehour, smin, emin = 0, 0, 0, 0  # takes daily spectrum anyhow
    plot_spectra_for_macrolist([macro_name], mysyear, mysdoy, tres, shour=shour, smin=smin,
                               ehour=ehour, emin=emin, plot_name="level3_varbins_spectra.pdf",
                               bincenter="chelogmean")
  else:
    print("Unknown plot type %r, expected 's' or 't'." % plottype)


if __name__ == "__main__":
  _main()






