import matplotlib.pyplot as plt 
import numpy as np
import sys
from matplotlib import ticker
import matplotlib.colors as colors
from scipy.optimize import curve_fit
plt.rc('font', family='serif')
import matplotlib
# LaTeX preamble: Matplotlib expects a single string here. The list form was
# deprecated in Matplotlib 3.3 and is rejected by newer releases.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"
matplotlib.rcParams['axes.linewidth'] = 4.0

matplotlib.rcParams['xtick.major.pad'] = 10
matplotlib.rcParams['ytick.major.pad'] = 10
# Shared styling constants used by decorate_axis and by the plot below.
myfontsize = 22
ticks_width = 2.5
ticks_length = 14


# Metallicities to choose from, in units of Z_sun.
ZZ_all = [1.0, 0.4, 0.2, 0.1, 0.04, 0.01]
# Initial masses to choose from, in M_sun (ascending).
masses_all = [10, 12, 14, 16, 18, 20, 22.5, 25, 27.5, 30, 32.5, 35, 37.5, 40,
              42.5, 45, 47.5, 50, 52.5, 55, 57.5, 60, 65, 67.5, 72.5, 75,
              77.5, 80.0]
# Exceptions (combinations presumably without a model; see read.me):
#   - masses > 60 Msun at solar metallicity (they do not expand in their
#     post-MS evolution)
#   - 52.5 Msun at 0.01 Zsun
#   - 67.5 Msun at 0.01 Zsun


# CHOOSE MASSES AND METALLICITIES HERE.
# Every (mass, Z) pair needs both ./DATA/<M>_<Z>_CE_data.dat and a row in
# ./lambda_R_fit.dat. Example selections:
# masses=[35.0, 37.5, 40.0, 42.5, 45.0, 47.5,50.0]
# masses=[50.0,55.0, 60.0, 65.0, 67.5,72.5,80.0]
masses = [30.0]
ZZ = [0.1]


def cubic(x, a, b, c, d):
    """Evaluate the cubic polynomial a*x**3 + b*x**2 + c*x + d.

    Works element-wise when ``x`` is a NumPy array (the plot below passes
    log10-radius grids). Terms are summed left to right, highest power first.
    """
    cubic_term = a * x * x * x
    quadratic_term = b * x * x
    linear_term = c * x
    return cubic_term + quadratic_term + linear_term + d


def conv_thresh_Teff(logL, Z):
    """Return 10**(a1*logL**2 + a2*logL + a3) for metallicity ``Z``.

    Each coefficient a_k is itself a quadratic in log10(Z):
    a_k = q_k*log10(Z)**2 + l_k*log10(Z) + c_k.
    The name suggests a threshold effective temperature, presumably for
    convective envelopes. TODO: confirm against the source of the fit.

    Args:
        logL: log10 luminosity (scalar or NumPy array).
        Z: metallicity in units of Z_sun (must be > 0 for log10).
    """
    log_z = np.log10(Z)
    # Column k holds the log10(Z)**2, log10(Z) and constant coefficients of a_k.
    quad_coef = [-0.006, 0.0596, -0.1637]
    lin_coef = [-0.0066, 0.0587, -0.1967]
    const_coef = [0.0173, -0.194, 4.0962]

    a1, a2, a3 = [
        quad * log_z**2. + lin * log_z + const
        for quad, lin, const in zip(quad_coef, lin_coef, const_coef)
    ]
    return 10.**(a1 * logL**2. + a2 * logL + a3)


def decorate_axis(ax):
    """Apply this figure's tick style to the Axes ``ax``.

    - Major ticks: label size ``myfontsize``, width ``ticks_width``,
      length ``ticks_length``.
    - Minor ticks: same width, half the length.
    - All ticks point inward and are drawn on all four sides; labels are
      shown on the bottom and left.
    """
    ax.tick_params(axis='both', which='major', labelsize=myfontsize,
                   width=ticks_width, length=ticks_length)
    ax.tick_params(axis='both', which='minor',
                   width=ticks_width, length=ticks_length / 2.)
    # Booleans, not 'on' strings. For the y axis the label switch is
    # ``labelleft``: Axes.tick_params ignores ``labelbottom`` for the y axis.
    ax.tick_params(axis='x', which='both', bottom=True, top=True,
                   labelbottom=True, direction='in')
    ax.tick_params(axis='y', which='both', left=True, right=True,
                   labelleft=True, direction='in')


# Create the figure, then make it 2.2 inches taller than its requested size.
fig = plt.figure(figsize=(11, 6.5))
w, h = fig.get_size_inches()
fig.set_size_inches(w, h + 2.2, forward=True)

# Subplot margins, as fractions of the figure.
ttop = 0.98
lleft = 0.12
rright = 0.98
bbottom = 0.08
plt.subplots_adjust(
    left=lleft,
    bottom=bbottom,
    right=rright,
    top=ttop,
    wspace=0.01,
    hspace=0.5,
)

# One axes spanning the whole 2x3 grid.
ax1 = plt.subplot2grid((2, 3), (0, 0), colspan=3, rowspan=2)

def _rlof_relevant_indices(radius_postMS, max_MS_radius):
    """Return indices of post-MS points whose radius exceeds every earlier radius.

    A point qualifies when its radius is larger than ``max_MS_radius`` and
    larger than all preceding entries of ``radius_postMS``. These are the
    points this script marks as relevant for RLOF. NaN radii never qualify
    and do not raise the running maximum.

    Returns a 1-tuple holding an integer index array. The result can be used
    for indexing exactly like the output of ``np.where``; an empty result is
    still a valid integer index.
    """
    radius_postMS = np.asarray(radius_postMS, dtype=float)
    # Largest radius reached strictly before each point, starting from the MS maximum.
    prior_max = np.fmax.accumulate(
        np.concatenate(([max_MS_radius], radius_postMS[:-1])))
    return (np.flatnonzero(radius_postMS > prior_max),)


def _legend_label(text, show):
    """Return ``text`` if ``show`` is true, otherwise None (no legend entry)."""
    return text if show else None


# lambda_CE fit table, one row per (mass, metallicity). It does not depend on
# the loop variables, so it is read once. ndmin=2 keeps every column 1-D even
# when the file holds a single row.
# Per row: M, Z, segment boundaries R12, R23, Rmax, then the cubic
# coefficients (a, b, c, d) of segments 1, 2 and 3.
(fit_M, fit_Z, fit_R12, fit_R23, fit_Rmax,
 fit_a1, fit_b1, fit_c1, fit_d1,
 fit_a2, fit_b2, fit_c2, fit_d2,
 fit_a3, fit_b3, fit_c3, fit_d3) = np.loadtxt("./lambda_R_fit.dat", unpack=True, ndmin=2)

for Mi in masses:
    for Z in ZZ:
        # Legend labels are attached only to the first model's artists, so
        # each legend entry appears once.
        first_model = (Mi == masses[0]) and (Z == ZZ[0])

        ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##
        # plotting lambda_CE as a function of radius from the data:
        ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##

        # Load the chosen model. A missing file is reported as an unavailable
        # model, and the script exits with a non-zero status.
        try:
            (star_age, star_mass, center_h1, center_he4,
             center_c12, radius, log_L, log_Teff, conv_env_mass_fractions,
             E_bind, E_grav, lambda_CE, M_remnant, R_remnant) = np.loadtxt(
                "./DATA/%s_%s_CE_data.dat" % (Mi, Z), unpack=True)
        except OSError:
            print("\nNo such model is available. Please see read.me for a list of available masses and metallicities.\n")
            sys.exit(1)

        # M_remnant is set to -1.0 during the MS and is positive afterwards.
        wt_MS = np.where(M_remnant < 0.0)
        wt_postMS = np.where(M_remnant > 0.)
        radius_postMS = radius[wt_postMS]
        lambda_postMS = lambda_CE[wt_postMS]
        max_MS_radius = max(radius[wt_MS])
        min_postMS_radius = min(radius_postMS)
        # Points relevant for RLOF: radius larger than any radius reached before.
        wt_increasing_radius_postMS = _rlof_relevant_indices(radius_postMS, max_MS_radius)

        # Whole post-MS track (dash-dotted line).
        ax1.plot(radius_postMS, lambda_postMS, lw=2, ls='-.', color='black',
                 label=_legend_label('stellar track\n(post-MS part)', first_model))
        # RLOF-relevant points are drawn as dots. The empty line only provides
        # their legend entry.
        if first_model:
            ax1.plot([], [], lw=3, color='black', label='part relevant\nfor RLOF')
        ax1.scatter(radius_postMS[wt_increasing_radius_postMS],
                    lambda_postMS[wt_increasing_radius_postMS], s=10, color='black')

        ## ~~~~~~~~~~~~~~~~~~~~~~~ ##
        # plotting a fit to the data:
        ## ~~~~~~~~~~~~~~~~~~~~~~~ ##

        print(Mi)
        fit_rows = np.flatnonzero((fit_Z == Z) & (fit_M == Mi))
        if fit_rows.size == 0:
            print("\nNo lambda_CE fit for M=%s, Z=%s in ./lambda_R_fit.dat.\n" % (Mi, Z))
            sys.exit(1)
        i_fit = fit_rows[0]

        p1_ft = [fit_a1[i_fit], fit_b1[i_fit], fit_c1[i_fit], fit_d1[i_fit]]
        p2_ft = [fit_a2[i_fit], fit_b2[i_fit], fit_c2[i_fit], fit_d2[i_fit]]
        p3_ft = [fit_a3[i_fit], fit_b3[i_fit], fit_c3[i_fit], fit_d3[i_fit]]
        l_R12_ft = np.log10(fit_R12[i_fit])
        l_R23_ft = np.log10(fit_R23[i_fit])
        l_Rmax_ft = np.log10(fit_Rmax[i_fit])
        # log10(R) grids for the three fit segments (np.linspace default: 50 points).
        xx1 = np.linspace(np.log10(min_postMS_radius), l_R12_ft)
        xx2 = np.linspace(l_R12_ft, l_R23_ft)
        xx3 = np.linspace(l_R23_ft, l_Rmax_ft)

        # Dashed fit lines. Each cubic gives log10(lambda) as a function of log10(R).
        lw = 5
        segments = (
            (xx1, p1_ft, lw, 'blue', 0.5, r'$R<R_{12}$'),
            (xx2, p2_ft, lw + 2, 'orange', 0.5, r'$R_{12}<R<R_{23}$'),
            (xx3, p3_ft, lw + 4, 'red', 0.7, r'$R_{23}<R<R_{\rm max}$'),
        )
        for xx, p_ft, seg_lw, seg_color, seg_alpha, seg_label in segments:
            ax1.plot(10**xx, 10**cubic(xx, *p_ft), lw=seg_lw, ls='--', c=seg_color,
                     zorder=3, alpha=seg_alpha,
                     label=_legend_label(seg_label, first_model))

        # Metallicity label at a fixed data position. It is drawn once per Z
        # (with the first mass); repeating it for every mass would only stack
        # identical copies.
        if Mi == masses[0]:
            ax1.annotate(r'Z/Z$_{\odot}$=' + str(Z), xy=(60, 3),
                         xycoords='data', fontsize=myfontsize, color='k',
                         weight='bold', rotation=0)
        # Mass label at the end of the last fitted segment. A zero leading
        # coefficient for segment 3 is treated as "no third segment".
        if p3_ft[0] != 0:
            x_end, p_end = xx3[-1], p3_ft
        else:
            x_end, p_end = xx2[-1], p2_ft
        ax1.annotate(r'M/M$_{\odot}$=' + str(Mi),
                     xy=(10**x_end, 1.3 * 10**cubic(x_end, *p_end)),
                     xycoords='data', fontsize=myfontsize, color='k',
                     weight='bold', rotation=0)

ax1.set_yscale('log')
ax1.set_xscale('log')
ax1.set_xlim([10, 20000])
# Keep the automatic lower y-limit, but cap the upper one at 7.
ylims = ax1.get_ylim()
ax1.set_ylim([ylims[0], min(7.0, ylims[1])])
ax1.legend(loc='lower left', fontsize=myfontsize - 3, frameon=False)

ax1.set_ylabel(r'$\lambda_{\rm CE}$', fontsize=myfontsize + 5)
ax1.set_xlabel(r'R [R$_{\odot}$]', fontsize=myfontsize + 2)
# decorate_axis sets the tick-label size itself (myfontsize). A separate
# tick_params(labelsize=...) call before it would simply be overwritten.
decorate_axis(ax1)
plt.tight_layout()
plt.savefig('./fit_example.png', dpi=300)
plt.show()

