import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as an

# Figure and axes that all three curves are drawn into.
fig, ax = plt.subplots()

# Sample positions along the propagation axis (labelled 'z' on the plot).
len_x = 800
x = np.linspace(0., 80.*np.pi, len_x)

# Amplitudes, angular frequencies and wavenumbers of the two superposed waves
# (index 0 and index 1). The small mismatch in frequency and wavenumber is what
# produces the beat pattern.
ampl = np.array([1., 1.])
oms = np.array([1.07, 1.13])
ks = np.array([0.95, 1.05])

t = 0
# Total displacement: sum of the two travelling sine waves at time t.
y = ampl[0]*np.sin(oms[0]*t - ks[0]*x) + ampl[1]*np.sin(oms[1]*t - ks[1]*x)
# Beat envelope. With dphi the phase difference between the two waves, the sum
# has local amplitude |A0 + A1*exp(j*dphi)|. For equal amplitudes this reduces
# to 2*A*|cos(dphi/2)|, but unlike a hard-coded (A0 + A1)*cos(dphi/2) it stays
# correct when the amplitudes differ. Taking the complex modulus also avoids
# the NaNs a sqrt of a rounding-negative value could produce.
dphi = (oms[1]-oms[0])*t - (ks[1]-ks[0])*x
y2 = np.abs(ampl[0] + ampl[1]*np.exp(1j*dphi))

line, = ax.plot(x, y)
# Upper and lower envelope, drawn in black.
line2, = ax.plot(x, y2, c='k')
line3, = ax.plot(x, -y2, c='k')

ax.set_xlabel('z', fontsize=14)
ax.set_ylabel('x', fontsize=14)
ax.tick_params(labelsize=12)

#plt.savefig('wave_animation.eps', bbox_inches='tight', transparent=True)


def update(i):
    """Redraw the waves and their beat envelope for animation frame ``i``.

    The frame index ``i`` is used directly as the time in the wave phases.
    Returns the three modified line artists, as FuncAnimation expects from its
    callback (and requires when ``blit=True``).
    """
    y = ampl[0]*np.sin(oms[0]*i - ks[0]*x) + ampl[1]*np.sin(oms[1]*i - ks[1]*x)
    # Exact beat envelope |A0 + A1*exp(j*dphi)|, with dphi the phase difference
    # of the two waves. It equals 2*A*|cos(dphi/2)| for equal amplitudes and
    # remains correct when the amplitudes differ.
    dphi = (oms[1]-oms[0])*i - (ks[1]-ks[0])*x
    y2 = np.abs(ampl[0] + ampl[1]*np.exp(1j*dphi))
    line.set_ydata(y)
    line2.set_ydata(y2)
    line3.set_ydata(-y2)
    return line, line2, line3

# Drive `update` for 400 frames, one every 50 ms. The animation object is kept
# in a module-level name so it is not garbage-collected before it runs.
ani = an.FuncAnimation(
    fig,
    update,
    frames=400,
    interval=50,
)

# Render the movie to disk with matplotlib's configured default writer
# (writer=None), then open the interactive window.
ani.save(filename='wave_animation.mp4')
plt.show()

