from math import pi
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from math import pi


def osc(amplitudes, omegas, phases, ks, times, x):
    """
    Evaluate a set of sinusoidal waves on a spatial grid at several times.

    Component j at time t is
    ``amplitudes[j]*sin(omegas[j]*t - ks[j]*x + phases[j])``; summing over j
    (done by the caller) gives the total oscillation.

    amplitudes: amplitudes of the waves, one per wave
    omegas: angular frequencies of the waves, one per wave
    phases: phase offsets of the waves, one per wave
    ks: wave numbers of the waves, one per wave
    times: times at which to evaluate (a scalar is treated as one time)
    x: spatial coordinates at which to evaluate

    Returns an array of shape (len(times), n_waves, len(x)) where
    ``waves[i, j]`` is component j evaluated at ``times[i]``.

    Raises ValueError if amplitudes, omegas, phases and ks differ in length
    (the previous zip-based loop silently dropped the extra entries).
    """
    amplitudes, omegas, phases, ks, times, x = np.atleast_1d(
        amplitudes, omegas, phases, ks, times, x)

    n_waves = len(amplitudes)
    if not (len(omegas) == len(phases) == len(ks) == n_waves):
        raise ValueError(
            "amplitudes, omegas, phases and ks must have the same length, got "
            f"{len(amplitudes)}, {len(omegas)}, {len(phases)}, {len(ks)}")

    # Broadcast to (time, wave, x):
    #   times -> (n_t, 1, 1), per-wave params -> (n_waves, 1), x -> (n_x,)
    t = times[:, None, None]
    return amplitudes[:, None]*np.sin(
        omegas[:, None]*t - ks[:, None]*x + phases[:, None])


# --- Parameters of the superposed waves ------------------------------------
amplitudes = np.array([1., 1., 1., 1.])
omegas = np.array([1., 1.05, 1.10, 1.15])  # angular frequencies
mp = 0.  # phase multiplier; 0 makes every initial phase zero
phases = np.array([0., mp*pi/8, mp*pi/4, mp*3*pi/8])
# Wave numbers are negative: with sin(om*t - k*x + ph) each crest moves to -x.
ks = np.array([1., .95, .9, .85])*2.*np.pi*(-1.)
times = np.asarray([0., 0.16667, .333, .6667])  # snapshot times (before nt scaling)
len_x = 600
x = np.linspace(0., 5., len_x)
n = len(amplitudes)
nt = 5.  # multiplier for times
km = np.mean(ks)  # mean wave number (not used below)
om = np.mean(omegas)  # mean angular frequency (not used below)

# Times at which the snapshots are actually evaluated and drawn.
t_plot = nt*times
waves = osc(amplitudes, omegas, phases, ks, t_plot, x)

# squeeze=False keeps axs indexable even when there is a single snapshot.
fig, axs = plt.subplots(len(times), 1, sharex=True, squeeze=False)
axs = axs[:, 0]
for i, ax in enumerate(axs):
    # Crest markers must use the same (scaled) time as the plotted curves.
    t = t_plot[i]
    ws = np.zeros(len(x))
    xmaxi = 0.
    ymaxi = 0.
    for j in range(len(omegas)):
        w = waves[i, j].ravel()
        ax.plot(x, w, ':')
        ws += w
        # Crest of component j (the one with phase exactly pi/2):
        # om*t - k*x + ph = pi/2  =>  x = (om*t + ph - pi/2)/k.
        # NOTE: with ks < 0 this crest drifts to negative x as t grows, so a
        # marker can lie left of the plotted x range.
        xmax = (omegas[j]*t + phases[j] - 0.5*np.pi)/ks[j]
        # Value of component j at that crest, so the marker sits on its curve.
        ymax = amplitudes[j]*np.sin(omegas[j]*t - ks[j]*xmax + phases[j])
        ax.plot(xmax, ymax, 'o')
        xmaxi += xmax
        ymaxi += ymax
    ax.plot(x, ws, c='k')
    # Black marker: mean crest position of the components, summed crest heights.
    ax.plot(xmaxi/n, ymaxi, 'o', c='k')
    ax.tick_params(axis='y', labelsize=12)
# NOTE(review): x is plotted unscaled; confirm the x/2pi label matches its units.
axs[-1].set_xlabel(r'$x/2\pi$', fontsize=14)
axs[-1].tick_params(axis='x', labelsize=12)
plt.subplots_adjust(hspace=0)
plt.savefig('group_velocity.eps', bbox_inches='tight')
plt.show()
