import numpy as np
import os
import json
import datetime as dt
import warnings

### required code for calculating EPHIN histfluxes from lvl1-sci files


### working example:
"""
exec(open("ephin_histograms_funcs.py").read())  # Python 3 replacement for execfile
output_path="histfluxes/"
sci_path="/data/missions/soho/costep/level1/sci/"
response_dict_path="../SCRIPTS/"
timeconstant_coinc_counter=0.0102
time_accumulation=59.953
os.system("mkdir %s -p"%output_path)
os.system("rm -f %s*"%output_path)
create_histogram_flux_files(1997,234,2017,277,output_path,sci_path,response_dict_path,timeconstant_coinc_counter,time_accumulation)
"""


### included functions:
"""
# time helper functions
year_doy_iterator(year,doy) - returns the year and day one day after input date
time_from_msec(millisec) - returns hours and minutes from input millisec
timediff(array_year,array_doy,array_msec,line1,line2) - returns timedelta between two lines

# write header
write_header(toutfile,type,tdict,timeconstant_coinc_counter,time_accumulation) - writes header to output file

# data loading
load_level1_sci(year,doy,sci_path) - load level1 sci files
load_and_merge_consecutive_scis(year,doy,sci_path) - merges two consecutive sci files
create_ringbit_array(data) - extracts ringbit information from data array
create_hist_array(data) - extracts histograms from data array
create_histbit_array(data) - extracts histbit information from data array
create_total_coinc_counter_array(data) - extracts total sum of coincidence counters from data array
create_all_the_arrays(data) - executes all data extractions listed above

# load masterchannel response json
load_masterchannel_response_dict(response_dict_path,type) - loads the histogram response JSON

# flux calculation
calc_coinc_counter_per_min_during_hist(array_total_coinc_counter,line,thistdate,array_year,array_doy,array_msec) - calculates coincidence counters during hist integration interval
calc_coinc_deadtime_correction(coinc_counter_sum_per_min, timeconstant_coinc_counter,time_accumulation) - calculates counter scaling used during ringon
calc_coinc_counter_ratio(coinc_counter_sum_per_min, array_hist,line) - calculates count ratio coinc/histo used for scaling during ringoff
extract_and_unfold_hists(array_hist,line) - unfold histograms from hist array
calc_and_save_hist(array_ringbit,array_hist,array_histbit,array_total_coinc_counter,array_year,array_doy,array_msec,line,masterch_dict_list,types,timeconstant_coinc_counter,time_accumulation,output_path) - calculates and saves fluxes for a given histogram

# actual data processing
create_histogram_flux_files(start_year,start_doy,end_year,end_doy,output_path,sci_path,response_dict_path,timeconstant_coinc_counter,time_accumulation) - main function incorporating the above ones in order to create histogram flux files
"""

# time helper functions
def year_doy_iterator(year,doy):
	# Step one day forward in (year, day-of-year) coordinates.
	# Day numbers up to 370 are tolerated before rolling over to the next
	# year; this mirrors the loose day counting used by the sci files.
	if doy > 370:
		return year + 1, 1
	return year, doy + 1

def time_from_msec(millisec):
	"""Convert millisecond-of-day into (hour, minutes).

	Returns an integer hour and a float minutes-within-the-hour.
	Idiom fix: the original counted hours with a repeated-subtraction
	while loop; divmod does the same in one step.
	"""
	hour, minutes = divmod(millisec / 60000., 60)
	return int(hour), minutes

def timediff(array_year,array_doy,array_msec,line1,line2):
	# Return the time difference (datetime.timedelta) between two records,
	# each identified by its index into the year/doy/msec arrays.
	def _stamp(idx):
		# absolute datetime from year, day-of-year and millisecond-of-day
		return (dt.datetime(int(array_year[idx]), 1, 1)
		        + dt.timedelta(int(array_doy[idx]) - 1, milliseconds=array_msec[idx]))
	return _stamp(line2) - _stamp(line1)

# write header
def write_header(toutfile,type,tdict,timeconstant_coinc_counter,time_accumulation):
	"""Write the descriptive header of a yearly histogram-flux output file.

	toutfile -- path of the output file (created / overwritten)
	type     -- species label, e.g. "proton" or "helium"
	tdict    -- master-channel response dictionary (uses keys
	            "masterch_id", "ecenter", "emin", "emax")
	timeconstant_coinc_counter -- dead-time constant [s] quoted in the header
	time_accumulation          -- accepted for interface compatibility; not written

	Fixes vs. original: file opened via `with` (no leaked handle on error)
	and user-facing typo "Councidence" corrected to "Coincidence".
	"""
	nch = len(tdict["masterch_id"])
	caveat = "# !!! Note that the histogram fluxes feature several caveats (especially during hard energy spectra) - Please read the Manual/Documentation !!!\n"
	with open(toutfile, "w") as f:
		f.write("# SOHO/EPHIN histogram fluxes - %s\n"%type)
		f.write(caveat)
		f.write("# %s fluxes and their statistical uncertainties are given in %i channels\n"%(type,nch))
		f.write("# The time given indicates the beginning of a 8 minute integration period\n")
		f.write("# Coincidence counter time constant (used during ringon) was set to %4.8ss (see Documentation)\n"%(timeconstant_coinc_counter))
		f.write("# The columns represent:\n")
		f.write("# Year, Month, Day, DayOfYear, Hour, Minute, Ringbit, Coinccounterscaler, %i flux columns, %i number of counts columns\n"%(nch,nch))
		f.write("# If ringbit ==1: outer segments are turned off (high flux mode)\n")
		f.write("# Fluxes are given in units of (cm^2 sr s MeV/nuc)^-1\n")
		f.write("# The number of counts columns can be used to determine statistical uncertainties (sigma=flux/sqrt(N)) \n")
		f.write("# The energy channels are as follows (given in units of MeV/nuc):\n")
		# one line per energy-grid quantity, values blank-separated, 2 decimals
		for label, key in (("E_center", "ecenter"), ("E_min", "emin"), ("E_max", "emax")):
			f.write("# %s: "%label + "".join("%4.2f "%q for q in tdict[key]) + "\n")
		f.write(caveat)

# data loading

def load_level1_sci(year,doy,sci_path):
	"""Load one day of level1 sci data.

	Returns the array produced by np.loadtxt, or [] when the file is
	missing or empty.  File naming: eph<yy><doy>.sci before 2000,
	epi<yy><doy>.sci from 2000 onwards, inside a per-year subdirectory.
	"""
	if year>=2000:  fname="epi%02d%03d.sci"%(year-2000,doy)
	else:  fname="eph%02d%03d.sci"%(year-1900,doy)
	filename=sci_path+"%04d/"%year+fname
	if not os.path.isfile(filename):
		return []
	with warnings.catch_warnings():
		# np.loadtxt warns on empty files; emptiness is handled below
		warnings.simplefilter("ignore")
		data=np.loadtxt(filename)
	# bug fix: the original tested `data != []`, which on an ndarray is an
	# ambiguous elementwise comparison; test the element count instead
	if data.size:
		return data
	return []


def load_and_merge_consecutive_scis(year,doy,sci_path):
	"""Load the sci file at (year, doy) and vstack it with the next
	available day, skipping forward over missing/empty days.

	The second file must contain more than 8 records (one full histogram
	cycle), otherwise it is discarded and the search continues.
	NOTE(review): loops forever if no further data exists on disk —
	callers are expected to stay inside the mission's data range.

	Fix vs. original: the `data == []` / `data != []` sentinel tests were
	ambiguous once a numpy array was returned; `len(...)` works for both
	the [] sentinel and an ndarray.
	"""
	data1, data2 = [], []
	# advance until a non-empty first file is found
	while len(data1) == 0:
		data1 = load_level1_sci(year, doy, sci_path)
		year, doy = year_doy_iterator(year, doy)
	# advance until a second file with more than 8 records is found
	while len(data2) == 0:
		data2 = load_level1_sci(year, doy, sci_path)
		year, doy = year_doy_iterator(year, doy)
		if len(data2) and len(data2[:,0]) <= 8:
			data2 = []
	return np.vstack((data1, data2))

def create_ringbit_array(data):
	"""Extract the ring-off flag per record from the status column.

	Returns a float array: 1.0 = ring off, 0.0 = ring on.  The flag is
	bit 1 (second-least-significant bit) of the last column — the
	original tested character [-2] of the '{0:08b}' binary string.
	Vectorized bit arithmetic replaces that per-row Python loop.
	"""
	status = data[:, -1].astype(np.int64)
	return ((status >> 1) & 1).astype(float)

def create_hist_array(data):
	# Histogram counters occupy the 32 columns just before the trailing
	# 8 status/housekeeping columns.
	return data[:, -40:-8]

def create_histbit_array(data):
	# Histogram sequence number per record: 8th-from-last column minus
	# its fixed offset of 64.
	return data[:, -8] - 64

def create_total_coinc_counter_array(data):
	"""Per-record total of all coincidence counters.

	Sums the electron (AB..ABCDE singles), proton (3 sub-counters per
	coincidence level) and helium (4 sub-counters per level) counter
	columns.  The original unpacked 35 named columns and added them by
	hand; summing the same column indices directly gives the identical
	result with far less code.  Column 40 (total integral counts) is
	intentionally excluded, as in the original (it was commented out).
	"""
	e_cols = [36, 37, 38, 39]                                  # electron counters
	p_cols = [22, 23, 24, 25, 26, 27, 41, 42, 43, 44, 45, 46]  # proton sub-counters
	h_cols = [28, 29, 30, 31, 32, 33, 34, 35,
	          47, 48, 49, 50, 51, 52, 53, 54]                  # helium sub-counters
	return data[:, e_cols + p_cols + h_cols].sum(axis=1)

def create_all_the_arrays(data):
	# Run every per-record extraction in one go and hand back the results
	# as a single tuple: (ringbit, hist, histbit, total_coinc_counter,
	# year, doy, msec).
	return (
		create_ringbit_array(data),
		create_hist_array(data),
		create_histbit_array(data),
		create_total_coinc_counter_array(data),
		data[:, 0],  # year
		data[:, 1],  # day of year
		data[:, 2],  # millisecond of day
	)


# load masterchannel response json
def load_masterchannel_response_dict(response_dict_path,type):
	"""Load the histogram master-channel response dictionary for a species.

	response_dict_path -- directory prefix (must end with the path separator)
	type               -- species label, e.g. "proton" or "helium"

	Fix vs. original: the file handle returned by open() was never
	closed; a `with` block guarantees it.
	"""
	with open(response_dict_path + "histo_masterch_%s_parallel_dict.json" % type) as fh:
		return json.load(fh)



# flux calculation
def calc_coinc_counter_per_min_during_hist(array_total_coinc_counter,line,thistdate,array_year,array_doy,array_msec):
	# Mean coincidence-counter sum per record over the histogram integration
	# interval that ends at record `line`.  `thistdate` is the interval start
	# (the caller passes the histogram time stamp minus 8 minutes).
	#
	# Strategy: walk backwards from `line` until a record strictly older than
	# `thistdate` is found, then average array_total_coinc_counter over the
	# records in between (exclusive of `line` itself).
	lineoffset=0
	tcounter_year,tcounter_doy,tcounter_msec=array_year[line+lineoffset],array_doy[line+lineoffset],array_msec[line+lineoffset]
	tcounterdate=dt.datetime(int(tcounter_year),1,1)+dt.timedelta(int(tcounter_doy)-1,milliseconds=int(tcounter_msec))
	# step back while the inspected record still lies inside the interval
	while tcounterdate-thistdate>=dt.timedelta(0):
		lineoffset-=1
		tcounter_year,tcounter_doy,tcounter_msec=array_year[line+lineoffset],array_doy[line+lineoffset],array_msec[line+lineoffset]
		tcounterdate=dt.datetime(int(tcounter_year),1,1)+dt.timedelta(int(tcounter_doy)-1,milliseconds=int(tcounter_msec))
	lineoffset+=1 # we went one too far ;)
	# NOTE(review): if the record at `line` were already older than thistdate,
	# the slice below would be empty and np.mean would return nan — presumably
	# this cannot happen for histogram blocks that passed the timing checks in
	# create_histogram_flux_files; confirm.
	coinc_counter_sum_per_min=np.mean(array_total_coinc_counter[line+lineoffset:line])
	return coinc_counter_sum_per_min

def calc_coinc_deadtime_correction(coinc_counter_sum_per_min, timeconstant_coinc_counter,time_accumulation):
	"""Dead-time correction scaler for the coincidence counters (ring on).

	The returned scaler is to be multiplied with the histogram counts.
	With count rate x [1/s] and time constant p [s] the live-time
	fraction is 1/(1 + p*x), so the correction is its reciprocal,
	1 + p*x.  (The original built a nested helper and computed
	1/(1/(1 + p*x)); simplified here, mathematically identical.)
	"""
	coinc_counter_sum_per_second = coinc_counter_sum_per_min / time_accumulation
	return 1. + timeconstant_coinc_counter * coinc_counter_sum_per_second

def calc_coinc_counter_ratio(coinc_counter_sum_per_min, array_hist, line):
	# Ratio of total coincidence counts to total histogram counts over the
	# 8-record integration window starting at `line`; used as count scaler
	# while the ring is off.
	hist_total = np.sum(array_hist[line:line + 8, :])
	return coinc_counter_sum_per_min * 8. / hist_total

def extract_and_unfold_hists(array_hist, line):
	# Unfold the 8 histogram records starting at `line`: each coincidence
	# level occupies two consecutive records, flattened into one vector.
	window = array_hist[line:line + 8, :]
	labels = ("AB", "ABC", "ABCD", "ABCDE")
	return {lab: window[2 * i:2 * i + 2, :].flatten() for i, lab in enumerate(labels)}


def calc_and_save_hist(array_ringbit,array_hist,array_histbit,array_total_coinc_counter,array_year,array_doy,array_msec,line,masterch_dict_list,types,timeconstant_coinc_counter,time_accumulation,output_path):
	"""Compute fluxes for the validated 8-record histogram starting at `line`
	and append one data row per species to the yearly output file.

	For every species in `types`, counts are collected into master channels
	via the response dictionary, scaled (dead-time correction while ring on,
	coincidence/histogram count ratio while ring off) and converted to flux
	in (cm^2 sr s MeV/nuc)^-1.  A header is written once per output file.
	"""
	# derive deadtime scaler
	# `thistdate` = start of the integration interval (histogram time - 8 min)
	thist_year,thist_doy,thist_msec=array_year[line],array_doy[line],array_msec[line]
	thistdate=dt.datetime(int(thist_year),1,1)+dt.timedelta(int(thist_doy)-1,milliseconds=int(thist_msec),minutes=-8)
	coinc_counter_sum_per_min=calc_coinc_counter_per_min_during_hist(array_total_coinc_counter,line,thistdate,array_year,array_doy,array_msec)
	coinc_counter_deadtime_scaler=calc_coinc_deadtime_correction(coinc_counter_sum_per_min, timeconstant_coinc_counter,time_accumulation)
	coinc_counter_ringoff_scaler=calc_coinc_counter_ratio(coinc_counter_sum_per_min, array_hist,line)
	unscaled_hist_dict=extract_and_unfold_hists(array_hist,line)
	# loop over types
	for typerunner, type in enumerate(types):
		tdict=masterch_dict_list[typerunner]
		# extract counts in master chs
		# each master channel is the sum over its (coincidence, channel) pairs
		masterch_counts=[]
		for chrunner, masterchnummer in enumerate(tdict["masterch_id"]):
			masterch_counts.append(0)
			for coin_in_masterch_runner in range(len(tdict["coinces"][chrunner])):
				tcoinc= tdict["coinces"][chrunner][coin_in_masterch_runner]
				tchincoinc= tdict["ch_in_coinces"][chrunner][coin_in_masterch_runner]
				masterch_counts[-1]+= unscaled_hist_dict[tcoinc][tchincoinc]
		# derive flux/staterr in master chs
		# ring state taken from the record just before the histogram block;
		# the caller already verified no ring switch occurred in the window
		masterch_fluxes=[]
		for chrunner, masterchnummer in enumerate(tdict["masterch_id"]):
			if array_ringbit[line-1]==1.: 
				tresponse=tdict["response_ringoff"][chrunner]
				countscaler=coinc_counter_ringoff_scaler
			else:	
				tresponse=tdict["response_ringon"][chrunner]
				countscaler=coinc_counter_deadtime_scaler
			# flux = counts / integration time / energy width / response * scaler
			tflux=masterch_counts[chrunner] / (8.*time_accumulation) / tdict["ewidth"][chrunner] / tresponse *countscaler
			masterch_fluxes.append(tflux)
		# write output
		# NOTE(review): `countscaler` below is the value from the last loop
		# iteration - harmless since the ring state (and hence the scaler)
		# is the same for every channel of this histogram
		toutfile=output_path+"histfluxes_%s_%i.dat"%(type,thist_year)
		if not os.path.isfile(toutfile): write_header(toutfile,type,tdict,timeconstant_coinc_counter,time_accumulation)
		f=open(toutfile,"a")
		f.write("%i %i %i %i  %i %i  %i %4.4e  "%(thistdate.year,thistdate.month,thistdate.day,int(thistdate.strftime("%j")), thistdate.hour,thistdate.minute,  array_ringbit[line-1], countscaler))
		for tflux in masterch_fluxes:
			f.write(" %4.4e"%tflux)
		f.write("    ")
		for counts in masterch_counts:
			f.write(" %i"%counts)
		f.write("\n")
		f.close()



# actual data processing
def create_histogram_flux_files(start_year,start_doy,end_year,end_doy,output_path,sci_path,response_dict_path,timeconstant_coinc_counter,time_accumulation):
	"""Main driver: create yearly histogram flux files for the given interval.

	Scans the level1 sci data from (start_year, start_doy) to
	(end_year, end_doy), validates every 8-record histogram block
	(ordering, ring stability, timing) and writes the derived fluxes
	via calc_and_save_hist, one file per species and year in output_path.
	"""
	types=["proton","helium"]
	masterch_dict_list=[]
	for type in types: masterch_dict_list.append(load_masterchannel_response_dict(response_dict_path,type))
	stat_good,stat_bad=0.,0.
	tyear,tdoy=start_year,start_doy
	printyear=start_year-1
	data=load_and_merge_consecutive_scis(tyear,tdoy,sci_path)
	array_ringbit,array_hist,array_histbit,array_total_coinc_counter,array_year,array_doy,array_msec=create_all_the_arrays(data)
	# start at record 9 so the checks below may look back 8 records
	line=9
	while data[line,0] < end_year or (data[line,0] == end_year and data[line,1]<=end_doy+1):
		if data[line,0]!= printyear:
			# NOTE(review): assumes data contains every year without gaps;
			# a multi-year data gap would make this progress message lag
			printyear+=1
			print("Processing %i..."%printyear)
		if array_histbit[line]==0:
			no_hist_issue=True

			# check if hist is complete and correctly ordered
			comparison=array_histbit[line:line+8]==np.array([0,1,2,3,4,5,6,7])
			if not comparison.all(): no_hist_issue=False

			# check if no ringswitch occours during previous 8 minutes:
			tringbit=array_ringbit[line-8:line]
			if len(np.unique(tringbit))!=1:
				no_hist_issue=False

			# check if the time stamp from line-8 till line+7 is not too long or too short
			# (should be 16, used threshold is 18 in order to allow two scaling counters being missing - this is taken care of in the scaling stuff)
			thistimediff=timediff(array_year,array_doy,array_msec,line-8,line+7).seconds
			if thistimediff < 14.5*60 or thistimediff > 18.5*60:
				no_hist_issue=False

			if no_hist_issue==True:
				calc_and_save_hist(array_ringbit,array_hist,array_histbit,array_total_coinc_counter,array_year,array_doy,array_msec,line,masterch_dict_list,types,timeconstant_coinc_counter,time_accumulation,output_path)
				stat_good+=1
			else:
				stat_bad+=1

			line+=1
		else:
			line+=1
		# reload once we approach the end of the merged two-day buffer and
		# re-locate the current record in the freshly loaded data
		if line>len(array_ringbit)-10:
			tyear,tdoy,tmsec=array_year[line],array_doy[line],array_msec[line]
			data=load_and_merge_consecutive_scis(tyear,tdoy,sci_path)
			array_ringbit,array_hist,array_histbit,array_total_coinc_counter,array_year,array_doy,array_msec=create_all_the_arrays(data)
			line=np.where((array_msec==tmsec)&(array_doy==tdoy))[0][0]

	# bug fix: report the bad fraction of *all* histograms (the original
	# divided by stat_good only, despite the message saying "of all cases")
	# and guard against ZeroDivisionError when nothing was processed
	stat_total=stat_good+stat_bad
	if stat_total>0:
		print("Hist issues were found in %2.2f %% of all cases"%(stat_bad/stat_total*100))
	else:
		print("No histograms found in the requested interval")
