#!/usr/bin/env python


"""
   Python script to solve the ODEs describing three coupled oscillators
   with masses m_i and friction coefficients b_i.
   All oscillators share the same coupling constant D.
"""


from scipy.integrate import odeint
from numpy import loadtxt, savetxt, hsplit
import numpy as np
import matplotlib.pyplot as plt

# Physical constants of the problem (feel free to experiment with them).
# Masses of the three oscillators:
m1 = m2 = m3 = 1.
# Friction (damping) coefficients:
b1 = b2 = b3 = 0.25
# Coupling constant between oscillators (the same D is used for every spring):
D = 1.

# Bundle all constants into one parameter list, in the order eqs() unpacks it.
p = [m1, m2, m3, b1, b2, b3, D]


def eqs(x, t, p):
    """
    Right-hand side of the first-order system for three coupled oscillators.

    x:  state vector (x1, y1, x2, y2, x3, y3); x_i is the displacement and
        y_i the velocity of oscillator i
    t:  time (not used in the body; odeint passes it to every call)
    p:  parameters (m1, m2, m3, b1, b2, b3, D)

    Returns the list of time derivatives, in the same order as x.
    """
    x1, y1, x2, y2, x3, y3 = x
    m1, m2, m3, b1, b2, b3, D = p

    # Each acceleration is (damping force + spring forces) divided by the
    # mass of *that* oscillator: oscillator 2 uses m2 and oscillator 3 uses m3.
    f = [y1,
         (-b1*y1 - 2*D*x1 + D*x2) / m1,
         y2,
         (-b2*y2 + D*x1 - 2*D*x2 + D*x3) / m2,
         y3,
         (-b3*y3 + D*x2 - 2*D*x3) / m3]
    return f


# Initial state: every oscillator starts at its equilibrium position (x_i = 0),
# and only oscillator 1 is given an initial velocity.
x1, y1 = 0., 0.1
x2, y2 = 0., 0.
x3, y3 = 0., 0.

# Pack the initial conditions into the state vector z0 (ordering used by eqs()).
z0 = [x1, y1, x2, y2, x3, y3]

# Integrator tolerances (absolute, relative) and the time span to cover.
abserr, relerr = 1.0e-8, 1.0e-6
stoptime = 25.0
numpoints = 750

# Evenly spaced times, both endpoints included, at which solutions are wanted.
t = np.linspace(start=0, stop=stoptime, num=numpoints)

# Integrate the system on the time grid t.
z = odeint(eqs, z0, t, args=(p,), atol=abserr, rtol=relerr)

# One column of the solution per state variable (x1, y1, x2, y2, x3, y3).
s1, sp1, s2, sp2, s3, sp3 = hsplit(z, 6)

# Diagnostic output: shape of one displacement series next to the time grid,
# then the displacement history of oscillator 1.
print(np.shape(s1[:, 0]), np.shape(t))
print(s1[:, 0])

# Save results: one row per time step, columns t, x1, y1, x2, y2, x3, y3.
savetxt('three-oscillators.dat',
        np.column_stack([t] + [col[:, 0] for col in (s1, sp1, s2, sp2, s3, sp3)]))

# Plot the displacements, reading them back from the data file written above.
t, x1, y1, x2, y2, x3, y3 = loadtxt('three-oscillators.dat', unpack=True)

fig, ax = plt.subplots()
ax.plot(t, x1, 'k', lw=2, label=r'$x_1$')
ax.plot(t, x2, 'r', lw=2, label=r'$x_2$')
ax.plot(t, x3, 'b', lw=2, label=r'$x_3$')
ax.set_xlabel('time', fontsize=16)
ax.set_ylabel('amplitude', fontsize=16)
ax.tick_params(labelsize=14)
ax.legend(loc='upper right')
ax.set_title('Three coupled oscillators', fontsize=16)
# The savefig keyword is `bbox_inches`; the previous misspelling `bb_inches`
# never applied the tight bounding box (and newer matplotlib rejects it).
fig.savefig('three-oscillators.png', bbox_inches='tight', dpi=100)
plt.show()
