#!/usr/bin/env python
import cgi
import matplotlib
#matplotlib.use('Agg')

import time
from pylab import *
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
from matplotlib import dates
import time
import datetime as date
from datetime import datetime
from datetime import timedelta
import os
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

np.seterr(divide='ignore', invalid='ignore')
import matplotlib.dates as mdates

def year_doy_to_name(year,doy):
  """Return the SOHO daily file stem for the given year and day of year.

  Years before 2000 get the prefix "eph", later years "epi"; the year is
  written relative to its century with two digits and the day with three,
  e.g. (1998, 5) -> "eph98005", (2015, 12) -> "epi15012".
  """
  prefix, century = ("eph", 1900) if year < 2000 else ("epi", 2000)
  return "%s%02d%03d" % (prefix, year - century, doy)


def year_doy_to_name_chandra(year,doy):
  """Return the Chandra daily file stem: "a" + 4-digit year + 3-digit day, e.g. (2005, 7) -> "a2005007"."""
  return "a" + "%04d%03d" % (year, doy)

def _parse_doy(year, doy):
  """Return *doy* as an integer day of year.

  *doy* may be anything ``int()`` accepts, or an "mm-dd" string, which is
  converted to the matching day of year within *year*.
  """
  try:
    return int(doy)
  except (TypeError, ValueError):
    parts = doy.split("-")
    return int(datetime(int(year), int(parts[0]), int(parts[1])).strftime("%j"))


def _days_in_year(year):
  """Return the number of days in *year* (365 or 366)."""
  return (datetime(year + 1, 1, 1) - datetime(year, 1, 1)).days


def _iter_days(syear, eyear, sdoy, edoy):
  """Yield (year, doy) for every day from (syear, sdoy) to (eyear, edoy), both inclusive."""
  for year in range(syear, eyear + 1):
    first = sdoy if year == syear else 1
    last = edoy if year == eyear else _days_in_year(year)
    for doy in range(first, last + 1):
      yield year, doy


def _load_rl2(path, namer, days):
  """Load the daily .rl2 files for *days* and stack them into one 2-D array.

  *namer(year, doy)* returns the file stem; files are read from
  "<path><year>/<stem>.rl2". Missing or unparsable files are skipped, as are
  files with fewer than 3 rows or columns (e.g. a single line of zeros, seen for
  2015 doy 12) and files whose column count differs from the first one loaded.

  Raises ValueError if no usable file was found.
  """
  chunks = []
  for year, doy in days:
    fname = "%s%4d/%s.rl2" % (path, year, namer(year, doy))
    try:
      tdata = np.loadtxt(fname)
    except (IOError, OSError, ValueError):
      continue  # no file / unreadable content for this day
    if tdata.ndim != 2 or tdata.shape[0] < 3 or tdata.shape[1] < 3:
      continue
    if chunks and tdata.shape[1] != chunks[0].shape[1]:
      continue
    chunks.append(tdata)
  if not chunks:
    raise ValueError("no usable .rl2 files found under %s for the requested period" % path)
  # one vstack at the end instead of one per file avoids quadratic copying
  return np.vstack(chunks)


def _decode_status(status):
  """Decode the status column into (fmodes, ringoff) float arrays.

  fmodes is 0 when bit 0 is clear, otherwise 1 if bit 2 is set and 2 if not;
  ringoff is 1 when bit 1 is set, else 0. Values are truncated to int and
  their absolute value is used (matching a binary-string decoding).
  """
  bits = np.abs(np.trunc(np.asarray(status, dtype=float))).astype(np.int64)
  fmodes = np.where(bits & 1, np.where(bits & 4, 1.0, 2.0), 0.0)
  ringoff = np.where(bits & 2, 1.0, 0.0)
  return fmodes, ringoff


def _average(datearray, chs, states, days, mean):
  """Average samples into *mean*-minute bins laid out per day from 00:00.

  Channels are combined with nanmean, states with max (0 for an empty bin).
  Each bin is stamped at its centre. Results are written to NEW arrays so the
  input samples are never overwritten while they are still being read.
  Returns (new_dates, new_chs, new_states).
  """
  new_dates = []
  new_chs = [[] for _ in chs]
  new_states = [[] for _ in states]
  last = len(datearray) - 2  # the final sample is never used as a bin edge
  index = 0
  for year, doy in days:
    day_start = datetime(int(year), 1, 1) + timedelta(int(doy) - 1)
    startmin, endmin = 0, mean
    while endmin <= 24 * 60:
      if index > last: break
      starttime = day_start + timedelta(minutes=int(startmin))
      endtime = day_start + timedelta(minutes=int(endmin))
      while datearray[index] < starttime:
        index += 1
        if index > last: break
      index2 = index
      while datearray[index2] < endtime:
        index2 += 1
        if index2 > last: break
      for ch, out in zip(chs, new_chs):
        out.append(np.nanmean(ch[index:index2]))
      for state, out in zip(states, new_states):
        seg = state[index:index2]
        out.append(np.max(seg) if len(seg) else 0)
      new_dates.append(day_start + timedelta(minutes=int(startmin + mean / 2.)))
      startmin += mean
      endmin += mean
      index = index2
  return (new_dates,
          [np.array(c, dtype=float) for c in new_chs],
          [np.array(s, dtype=float) for s in new_states])


def _mark_gaps(datearray, chs, fmodes, ringoff, mean):
  """Insert two NaN samples into every gap longer than 2*mean minutes.

  For a gap between samples q and q+1, NaNs are stamped half a bin after
  sample q and half a bin before sample q+1, so line plots break there while
  all real samples keep their own timestamps. The state arrays repeat the
  value of sample q. Returns (dates_list, chs, fmodes, ringoff).
  """
  max_step = timedelta(minutes=int(mean * 2))
  half_bin = timedelta(minutes=mean / 2.)
  chs = [np.asarray(ch, dtype=float) for ch in chs]
  fmodes = np.asarray(fmodes, dtype=float)
  ringoff = np.asarray(ringoff, dtype=float)
  gaps = [q for q in range(len(datearray) - 1) if datearray[q + 1] - datearray[q] > max_step]
  if not gaps:
    return list(datearray), chs, fmodes, ringoff
  gap_idx = np.asarray(gaps)
  # insert twice before element q+1 (np.insert keeps the order of equal indices)
  pos = np.repeat(gap_idx + 1, 2)
  chs = [np.insert(ch, pos, np.nan) for ch in chs]
  fmodes = np.insert(fmodes, pos, np.repeat(fmodes[gap_idx], 2))
  ringoff = np.insert(ringoff, pos, np.repeat(ringoff[gap_idx], 2))
  new_dates = list(datearray)
  for q in reversed(gaps):  # back to front so earlier indices stay valid
    new_dates[q + 1:q + 1] = [datearray[q] + half_bin, datearray[q + 1] - half_bin]
  return new_dates, chs, fmodes, ringoff


def plot_daily_rl2(mysyear,myeyear,mysdoy,myedoy,log=True,mean=1,shour=0,ehour=0,orientation="Portrait",ranges=None, plotformat="PNG",timing_analysis=False,mission="s"):
  """Plot rl2 fluxes for a day range as four stacked panels and show the figure.

  Reads one .rl2 text file per day, optionally averages the data, breaks the
  lines at data gaps and plots electrons, protons, helium and the integral
  channel, shading periods whose status ring-off bit is set in red.

  Args:
    mysyear, myeyear: first and last year (anything int() accepts).
    mysdoy, myedoy: first and last day of year, or "mm-dd" strings; both
      ends are inclusive, also across year boundaries.
    log: use a logarithmic y scale when truthy.
    mean: averaging interval in minutes; 1 keeps the files' resolution.
    shour, ehour: hour of day at which the x range starts / ends.
    orientation: "Portrait" or "Landscape" page-size layout.
    ranges: optional sequence of four (ymin, ymax) pairs, one per panel.
    plotformat: accepted for backward compatibility; currently unused.
    timing_analysis: print the elapsed time after each processing stage.
    mission: "s" (SOHO) or "c" (Chandra); selects data path and file names.

  Raises:
    ValueError: unknown mission, mean < 1, or no usable data file in range.
  """
  start_time = time.time()

  def report(stage):
    if timing_analysis: print("--- %s seconds --- %s" % (time.time() - start_time, stage))

  paths = {"s": "/data/missions/soho/costep/level2/rl2/",
           "c": "/data/missions/chandra/ephin/rl2/"}
  if mission not in paths:
    raise ValueError("mission must be 's' (SOHO) or 'c' (Chandra), got %r" % (mission,))
  namer = year_doy_to_name if mission == "s" else year_doy_to_name_chandra

  mean = int(mean)
  if mean < 1:
    raise ValueError("mean must be at least 1 minute, got %r" % (mean,))
  mysyear, myeyear = int(mysyear), int(myeyear)
  shour, ehour = int(shour), int(ehour)
  mysdoy = _parse_doy(mysyear, mysdoy)
  myedoy = _parse_doy(myeyear, myedoy)

  ################## load data
  data = _load_rl2(paths[mission], namer, _iter_days(mysyear, myeyear, mysdoy, myedoy))
  report("DATA LOADED")

  # columns 6..18: e150 e300 e1300 e3000, p4 p8 p25 p41, he4 he8 he25 he41, integral
  chs = [data[:, col] for col in range(6, 19)]
  fmodes, ringoff = _decode_status(data[:, 47])
  datearray = [datetime(int(y), 1, 1, 0) + timedelta(int(d) - 1, int(ms / 1000.))
               for y, d, ms in zip(data[:, 0], data[:, 1], data[:, 2])]

  if mean != 1:
    datearray, chs, states = _average(datearray, chs, [fmodes, ringoff],
                                      _iter_days(mysyear, myeyear, mysdoy, myedoy), mean)
    fmodes, ringoff = states
  report("DATA AVERAGED")

  datearray, chs, fmodes, ringoff = _mark_gaps(datearray, chs, fmodes, ringoff, mean)
  chs = [ma.masked_invalid(ch) for ch in chs]
  report("DATA GAPS MARKED")

  ################## now plot data
  if orientation == "Portrait":
    plt.figure(figsize=(8.3, 11.7))
    plt.subplots_adjust(left=0.10, right=0.9, top=0.97, bottom=0.055, wspace=None, hspace=0.000)
  elif orientation == "Landscape":
    plt.figure(figsize=(11.7, 8.3))
    plt.subplots_adjust(left=0.07, right=0.93, top=0.957, bottom=0.075, wspace=None, hspace=0.000)
  mycolors = ["deepskyblue","mediumslateblue","steelblue","blue","orange","tomato","indianred","red","lightgreen","lawngreen","forestgreen","g","k"]
  mylabels = ["0.25 - 0.70 MeV","0.67 - 3.00 MeV","2.64 - 6.18 MeV","4.80 - 10.40 MeV","4.3 - 7.8 MeV","7.8 - 25 MeV","25 - 40.9 MeV","40.9 - 53 MeV","4.3 - 7.8 MeV/nuc","7.8 - 25 MeV/nuc","25 - 40.9 MeV/nuc","40.9 - 53 MeV/nuc","e$^-$: > 8.70 MeV\np: >53 MeV\nHe: > 53 MeV/nuc"]
  ylabels = ["electrons / (cm$^2$ s sr MeV)$^{-1}$",
             "protons / (cm$^2$ s sr MeV)$^{-1}$",
             "helium / (cm$^2$ s sr MeV/nuc)$^{-1}$",
             "integral channel / (cm$^2$ s sr)$^{-1}$"]
  lfs = 10
  myncol = 4
  mycs = 0
  myhtp = 0.1
  myds = "steps-mid"

  fmax = np.max(fmodes) if len(fmodes) else 0
  if fmax == 1:
    mylabels[2] = "2.64 - 10.40 MeV"
    mylabels[6] = "25.0 - 53.0 MeV"
    mylabels[10] = "25.0 - 53.0 MeV/nuc"
  if fmax == 2:
    mylabels[2] = "2.64 - 10.40 MeV (FMEnoEP)"
    mylabels[6] = "25 - 53 MeV (FMEnoEP)"
    mylabels[10] = "25 - 53 MeV/nuc (FMEnoEP)"

  # panels 0-2 hold four channels each (the 4th only when fmax < 1); panel 3 the integral channel
  axes = []
  for panel in range(4):
    ax = plt.subplot2grid((4, 1), (panel, 0), sharex=axes[0] if axes else None)
    if panel % 2:
      ax.yaxis.set_label_position("right")
      ax.yaxis.tick_right()
    else:
      ax.yaxis.tick_left()
    if panel < 3:
      chan_idx = range(4 * panel, 4 * panel + (4 if fmax < 1 else 3))
    else:
      chan_idx = [12]
    for i in chan_idx:
      ax.plot(datearray, chs[i], '-', color=mycolors[i], label=mylabels[i], drawstyle=myds, rasterized=True)
    ax.set_ylabel(ylabels[panel])
    if log: ax.set_yscale("log")
    if panel < 3:
      ax.legend(fontsize=lfs, ncol=myncol, columnspacing=mycs, handletextpad=myhtp)
      plt.setp(ax.get_xticklabels(), visible=False)  # only the bottom panel shows dates
    else:
      ax.legend(fontsize=lfs)
    if ranges is not None:
      ax.set_ylim(float(ranges[panel][0]), float(ranges[panel][1]))
    axes.append(ax)
  ax3 = axes[3]
  report("DATA PLOTTED")

  # shade ring-off periods across the full current y range of every panel
  for ax in axes:
    ylo, yhi = ax.get_ylim()
    ax.fill_between(datearray, ylo, yhi, where=ringoff > 0, facecolor='red', interpolate=False, alpha=0.3, rasterized=True)
  report("RING OFFS MARKED")

  # x range (shared by all panels)
  sdate = datetime(mysyear, 1, 1, shour) + timedelta(mysdoy - 1)
  edate = datetime(myeyear, 1, 1, ehour) + timedelta(myedoy - 1)
  datediff = edate - sdate
  ax3.set_xlim(sdate, edate)

  # formats
  ax3.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M\n %d %b \n %Y '))
  if datediff <= timedelta(hours=3): ax3.get_xaxis().set_minor_locator(dates.MinuteLocator(interval=1))
  elif datediff <= timedelta(hours=12): ax3.get_xaxis().set_minor_locator(dates.MinuteLocator(byminute=range(0, 60, 15)))
  elif datediff <= timedelta(days=2): ax3.get_xaxis().set_minor_locator(dates.HourLocator(byhour=range(0, 24, 1)))
  elif datediff <= timedelta(days=15): ax3.get_xaxis().set_minor_locator(dates.HourLocator(byhour=range(0, 24, 6)))
  elif datediff <= timedelta(days=50): ax3.get_xaxis().set_minor_locator(dates.DayLocator(interval=1))

  plt.rcParams.update({'axes.titlesize': 'small'})
  report("ALL DONE; NOW SAVE")
  plt.show()

#plot_daily_rl2(2016,2016,1,180,mean=1440,shour=0,ehour=0,timing_analysis=True,plotformat="PDF",orientation="Landscape") 


if __name__ == "__main__":
  # Interactive entry point; guarded so importing this module does not prompt.
  # raw_input only exists on Python 2 (where input() evaluates its text), so
  # prefer it and fall back to input() on Python 3.
  try:
    read = raw_input
  except NameError:
    read = input

  # answers are stripped and lower-cased so e.g. "S " or "L" are accepted
  mission = read("Mission: SOHO ('s') or CHANDRA ('c'): ").strip().lower()

  syear = int(read("Insert start year: "))
  eyear = int(read("Insert end year: "))
  start_doy = read("Insert start DoY (mm-dd supported): ")
  end_doy = read("Insert end DoY (mm-dd supported): ")
  shour = int(read("Insert start hour: "))
  ehour = int(read("Insert end hour: "))

  avg = int(read("Averaging in minutes (e.g. 1, 60, 1440, ...): "))
  orientation = read("Plot format: Portrait ('p') or Landscape ('l'): ").strip().lower()
  orientation = "Landscape" if orientation == "l" else "Portrait"

  plot_daily_rl2(syear, eyear, start_doy, end_doy, mean=avg, shour=shour, ehour=ehour,
                 timing_analysis=True, orientation=orientation, mission=mission)



