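# Plot SOHO/COSTEP-EPHIN level-2 (rl2) intensities of electrons, protons, helium and the integral channel over a range of days.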
# Patrick Kuehl, date

import matplotlib
#matplotlib.use('Agg')

import time
from pylab import *
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
from matplotlib import dates
import time
import datetime as date
from datetime import datetime
from datetime import timedelta
import os
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 


import matplotlib.dates as mdates

""" ----------------- """
#### Set the input parameters in the plot_daily_rl2 call at the very bottom of this script!
""" ----------------- """

def year_doy_to_name(year,doy):
  """Return the rl2 file stem for a year and day of year.

  Years before 2000 use the prefix "eph" followed by the two-digit offset
  from 1900; later years use "epi" and the offset from 2000. The day of
  year is zero-padded to three digits.

  e.g. (1998, 126) -> "eph98126", (2021, 5) -> "epi21005"
  """
  prefix, century = ("eph", 1900) if year < 2000 else ("epi", 2000)
  return "%s%02d%03d" % (prefix, year - century, doy)


def _parse_doy(year, value):
  """Return the day of year for *value*, given as an int-like DOY or an "mm-dd" string.

  e.g. _parse_doy(2021, 298) -> 298, _parse_doy(2021, "10-25") -> 298
  """
  try:
    return int(value)
  except ValueError:
    # not a plain number: interpret it as "mm-dd" within the given year
    parts = value.split("-")
    return int(datetime(int(year), int(parts[0]), int(parts[1])).strftime("%j"))


def _rl2_days(syear, sdoy, eyear, edoy):
  """Return every (year, doy) pair from (syear, sdoy) to (eyear, edoy), inclusive.

  Stepping through calendar days keeps day 366 of leap years and every
  intermediate year of ranges spanning several years. The list is empty
  when the start lies after the end.
  """
  day = datetime(syear, 1, 1) + timedelta(sdoy - 1)
  last = datetime(eyear, 1, 1) + timedelta(edoy - 1)
  days = []
  while day <= last:
    days.append((day.year, day.timetuple().tm_yday))
    day += timedelta(1)
  return days


def _load_rl2(days, path_template="/data/missions/soho/costep/level2/rl2/%4d/%s.rl2"):
  """Read the rl2 table of every (year, doy) in *days* and stack them row-wise.

  Files that are missing, unparsable, empty or have a different column
  count than the first table are reported and skipped, so gaps in the
  archive do not abort a plot. Single-row files are promoted to 2-D.

  Raises IOError if none of the files yields any rows.
  """
  tables = []
  for year, doy in days:
    name = year_doy_to_name(year, doy)
    try:
      table = np.loadtxt(path_template % (year, name))
    except (IOError, OSError):
      print("no file %s!" % name)
      continue
    except ValueError:
      print("could not parse %s!" % name)
      continue
    if table.size == 0:
      print("no data in %s!" % name)
      continue
    table = np.atleast_2d(table)  # a one-line file loads as a 1-D array
    if tables and table.shape[1] != tables[0].shape[1]:
      print("unexpected column count in %s!" % name)
      continue
    tables.append(table)
  if not tables:
    raise IOError("no rl2 data found for the requested interval")
  # a single vstack at the end instead of one per file (which is quadratic)
  return np.vstack(tables)


def _decode_status(status):
  """Decode the rl2 status column into (fmodes, ringoff) float arrays.

  Bits are taken from the truncated absolute status value (the digits that
  '{0:08b}'.format(int(s)) shows); NaN is treated as 0.
    fmodes  = 0 if bit 0 is clear, else 1 when bit 2 is set and 2 otherwise
    ringoff = 1 if bit 1 is set, else 0
  """
  bits = np.abs(np.trunc(np.nan_to_num(status))).astype(np.int64)
  fmodes = np.where(bits & 1, np.where(bits & 4, 1., 2.), 0.)
  ringoff = np.where(bits & 2, 1., 0.)
  return fmodes, ringoff


def _bin_average(datearray, values, states, days, mean):
  """Average samples into *mean*-minute bins laid out over the given days.

  datearray: list of sample datetimes.
  values:    2-D array, one row per sample; each bin gets the np.nanmean.
  states:    2-D array, one row per sample; each bin gets the maximum.
  days:      (year, doy) pairs; bins start at each day's midnight and the
             last bin of a day is cut at the following midnight.
  mean:      bin width in minutes (> 0).

  A bin holds the samples with start <= t < end. Results go into new
  arrays, so raw samples are never overwritten. Empty bins get NaN values
  and zero states, so data gaps stay visible.
  Returns (bin_centres, values, states); centres are truncated to whole minutes.
  """
  day_minutes = 24 * 60
  ref = datetime(days[0][0], 1, 1)

  def seconds(t):
    return (t - ref).total_seconds()

  tsec = np.array([seconds(t) for t in datearray])
  order = np.argsort(tsec, kind="mergesort")  # searchsorted needs ascending times
  tsec, values, states = tsec[order], values[order], states[order]

  empty_values = np.full(values.shape[1], np.nan)
  empty_states = np.zeros(states.shape[1])
  centres, binned_values, binned_states = [], [], []
  for year, doy in days:
    midnight = datetime(year, 1, 1) + timedelta(doy - 1)
    startmin = 0
    while startmin < day_minutes:
      endmin = startmin + mean
      if endmin > day_minutes:
        endmin = day_minutes
      lo = np.searchsorted(tsec, seconds(midnight + timedelta(minutes=startmin)), side="left")
      hi = np.searchsorted(tsec, seconds(midnight + timedelta(minutes=endmin)), side="left")
      if hi > lo:
        binned_values.append(np.nanmean(values[lo:hi], axis=0))
        binned_states.append(np.max(states[lo:hi], axis=0))
      else:
        binned_values.append(empty_values)
        binned_states.append(empty_states)
      centres.append(midnight + timedelta(minutes=int(startmin + (endmin - startmin) / 2.)))
      startmin = endmin
  return centres, np.array(binned_values), np.array(binned_states)


def plot_daily_rl2(mysyear,myeyear,mysdoy,myedoy,output_path="plots",log=True,mean=None,shour=0,ehour=0):
  """Plot SOHO/EPHIN rl2 intensities from (mysyear, mysdoy) to (myeyear, myedoy).

  Parameters:
    mysyear, myeyear: start and end year (int-like).
    mysdoy, myedoy:   start and end day of year, either int-like or an
                      "mm-dd" string (e.g. "10-25").
    output_path:      kept for backward compatibility. NOTE(review): not used;
                      the figure is always saved as "ephin_rl2_wwwlike.pdf"
                      in the current working directory.
    log:              logarithmic y axis on all panels when true.
    mean:             bin width in minutes for averaging; None plots the
                      native (1 min) resolution.
    shour, ehour:     hour of the first / last day used for the x-axis limits.

  Draws four stacked panels (electrons, protons, helium, integral channel),
  shades samples whose status bit 1 ("ring off") is set in red, saves the
  figure and calls plt.show().

  Raises:
    IOError:    no rl2 file in the range could be read.
    ValueError: mean is given but not positive.
  """
  mysyear = int(mysyear)
  myeyear = int(myeyear)
  shour = int(shour)
  ehour = int(ehour)
  mysdoy = _parse_doy(mysyear, mysdoy)
  myedoy = _parse_doy(myeyear, myedoy)
  if mean is not None and mean <= 0:
    raise ValueError("mean must be a positive number of minutes, got %r" % (mean,))

  ################## load data
  days = _rl2_days(mysyear, mysdoy, myeyear, myedoy)
  data = _load_rl2(days)

  # columns 0-2: year, doy, msec; 6-18: the 13 plotted channels; 47: status
  datearray = [datetime(int(y), 1, 1, 0) + timedelta(int(d) - 1, int(ms / 1000.))
               for y, d, ms in zip(data[:, 0], data[:, 1], data[:, 2])]
  # channel order: e150, e300, e1300, e3000, p4, p8, p25, p41, he4, he8, he25, he41, inte
  chans = data[:, 6:19]
  fmodes, ringoff = _decode_status(data[:, 47])

  # average data if requested
  if mean is not None:
    datearray, chans, states = _bin_average(datearray, chans, np.column_stack((fmodes, ringoff)), days, mean)
    fmodes, ringoff = states[:, 0], states[:, 1]

  ################## now plot data
  plt.figure(figsize=(11.7, 8.3))
  plt.subplots_adjust(left=0.13, right=0.9, top=0.97, bottom=0.055, wspace=None, hspace=0.001)
  mycolors = ["deepskyblue","mediumslateblue","steelblue","blue","orange","tomato","indianred","red","lightgreen","lawngreen","forestgreen","g","k"]
  mylabels = ["0.25 - 0.70 MeV","0.67 - 3.00 MeV","2.64 - 6.18 MeV","4.80 - 10.40 MeV","4.3 - 7.8 MeV","7.8 - 25 MeV","25 - 40.9 MeV","40.9 - 53 MeV","4.3 - 7.8 MeV/nuc","7.8 - 25 MeV/nuc","25 - 40.9 MeV/nuc","40.9 - 53 MeV/nuc","e$^-$: > 8.70 MeV\np: >53 MeV\nHe: > 53 MeV/nuc"]
  lfs = 10
  legend_kw = dict(fontsize=lfs, ncol=4, columnspacing=0, handletextpad=0.1)
  myds = "steps"

  fmode = np.max(fmodes)
  if fmode == 1:
    mylabels[2] = "2.64 - 10.40 MeV"
    mylabels[6] = "25.0 - 53.0 MeV"
    mylabels[10] = "25.0 - 53.0 MeV/nuc"
  if fmode == 2:
    mylabels[2] = "2.64 - 10.40 MeV (FMEnoEP)"
    mylabels[6] = "25 - 53 MeV (FMEnoEP)"
    mylabels[10] = "25 - 53 MeV/nuc (FMEnoEP)"

  # (grid row, channel indices, y label, y axis on the right?)
  panels = [
    (0, [0, 1, 2, 3], "electrons / (cm$^2$ s sr MeV)$^{-1}$", False),
    (1, [4, 5, 6, 7], "protons / (cm$^2$ s sr MeV)$^{-1}$", True),
    (2, [8, 9, 10, 11], "helium / (cm$^2$ s sr MeV/nuc)$^{-1}$", False),
    (3, [12], "integral channel / (cm$^2$ s sr)$^{-1}$", True),
  ]
  axes = []
  for row, channels, ylabel, on_right in panels:
    ax = plt.subplot2grid((4, 1), (row, 0), sharex=axes[0] if axes else None)
    if on_right:
      ax.yaxis.set_label_position("right")
      ax.yaxis.tick_right()
    else:
      ax.yaxis.tick_left()
    for i in channels:
      # the fourth differential channel of each species is drawn only while no fmode flag is set
      if i in (3, 7, 11) and fmode >= 1:
        continue
      # Axes.plot handles datetime x values directly (plot_date is deprecated in newer matplotlib)
      ax.plot(datearray, chans[:, i], '-', color=mycolors[i], label=mylabels[i], drawstyle=myds)
    ax.set_ylabel(ylabel)
    if log:
      ax.set_yscale("log")
    if row < 3:
      ax.legend(**legend_kw)
      plt.setp(ax.get_xticklabels(), visible=False)  # only the bottom panel shows time labels
    else:
      ax.legend(fontsize=lfs)
    axes.append(ax)

  bottom = axes[-1]
  bottom.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M \n %d %b \n %Y '))
  bottom.set_xlim(datetime(mysyear, 1, 1, shour) + timedelta(mysdoy - 1),
                  datetime(myeyear, 1, 1, ehour) + timedelta(myedoy - 1))

  # fill ring offs
  for ax in axes:
    ylo, yhi = ax.get_ylim()
    ax.fill_between(datearray, ylo, yhi, where=ringoff > 0, facecolor='red', interpolate=False, alpha=0.3)

  plt.savefig("ephin_rl2_wwwlike.pdf")
  plt.show()

"""for mean in [30]: #None,30, 5
  for year in range(1995,2017):
    for doy in range(1,370):
      try:
        plot_daily_rl2(year,doy,mean=mean) 
        print "could do ", year, doy
      except:
        print "could not do ", year, doy"""

# Further example calls:
#   plot_daily_rl2(2015, 2016, "12-1", "9-1", mean=60*1, shour=0, ehour=23)
#   plot_daily_rl2(1998, 1998, 126, 126, mean=30)

# Only plot when run as a script, not when this module is imported.
if __name__ == "__main__":
  # plot 2021-10-25 02:00 to 2021-11-05 23:00 using 60-minute averages
  plot_daily_rl2(2021,2021,"10-25","11-5",mean=60*1,shour=2,ehour=23)

