''' 
# online material accompanying the paper: Chruslinska, M. & Nelemans, G., MNRAS, 2019 (DOI: 10.1093/mnras/stz2057)
# this python script draws metallicities from the star-formation-rate-density-weighted distribution of metallicities
# at a given redshift
# usage: python sample_metallicity.py
# adjust the input data file name (input_file) & redshift (z); see the example use at the bottom
# requires 'Time_redshift_deltaT.dat' and an input file with the data for the chosen variation (any of the '*FOH_z_dM.dat' files)
'''
import numpy as np
import matplotlib.pyplot as plt

def solar_metallicity_scales():
    '''Return the supported solar metallicity scales and their reference values.

    Returns
    -------
    scale_ref : numpy.ndarray of str
        Scale names, in the same order as the rows of Z_FOH_solar.
    Z_FOH_solar : numpy.ndarray, shape (4, 2)
        One row per scale: [solar metal mass fraction Zsun, solar 12+log(O/H)].
    '''
    # (name, Zsun, 12+log(O/H) of the Sun)
    table = (
        ('Asplund09', 0.0134, 8.69),
        ('AndersGrevesse89', 0.017, 8.83),
        ('GrevesseSauval98', 0.0201, 8.93),
        ('Villante14', 0.019, 8.85),
    )
    names = np.array([entry[0] for entry in table])
    values = np.array([[entry[1], entry[2]] for entry in table])
    return names, values

def FOH2ZZ(foh,solar_Z_scale='AndersGrevesse89'):
    '''Convert oxygen abundance 12+log(O/H) to metal mass fraction Z.

    Uses log10(Z) = log10(Zsun) + foh - FOHsun, where (Zsun, FOHsun) come
    from the requested entry of solar_metallicity_scales().

    Parameters
    ----------
    foh : float or array_like
        Oxygen abundance(s), 12+log(O/H).
    solar_Z_scale : str
        Name of the solar scale; an unknown name raises IndexError.

    Returns
    -------
    Metal mass fraction(s) Z, same shape as foh.
    '''
    names, solar_values = solar_metallicity_scales()
    row = np.flatnonzero(names == solar_Z_scale)[0]
    Z_sun = solar_values[row][0]
    FOH_sun = solar_values[row][1]
    log_Z = np.log10(Z_sun) + foh - FOH_sun
    return 10**log_Z

def ZZ2FOH(zz,solar_Z_scale='AndersGrevesse89'):
    '''Convert metal mass fraction Z to oxygen abundance 12+log(O/H).

    Inverse of FOH2ZZ: foh = log10(zz) - log10(Zsun) + FOHsun, with
    (Zsun, FOHsun) taken from the requested entry of solar_metallicity_scales().

    Parameters
    ----------
    zz : float or array_like
        Metal mass fraction(s).
    solar_Z_scale : str
        Name of the solar scale; an unknown name raises IndexError.

    Returns
    -------
    12+log(O/H) value(s), same shape as zz.
    '''
    names, solar_values = solar_metallicity_scales()
    row = np.flatnonzero(names == solar_Z_scale)[0]
    Z_sun = solar_values[row][0]
    FOH_sun = solar_values[row][1]
    return np.log10(zz) - np.log10(Z_sun) + FOH_sun

# Grid of oxygen-to-hydrogen abundances, FOH == 12 + log(O/H), exactly as
# used in the calculations that produced the input tables -- do not change.
FOH_min = 5.3
FOH_max = 9.7
FOH_arr = np.linspace(FOH_min, FOH_max, num=200)  # 200 points, both ends included
# spacing of the uniform FOH grid
dFOH = FOH_arr[1] - FOH_arr[0]

def get_data(input_file,zmin=0.,zmax=4):
    '''Load the star formation rate density table SFRD(z, FOH) of one model variation.

    Parameters
    ----------
    input_file : str
        Path to a '*FOH_z_dM.dat' table holding the mass formed per unit
        (comoving) volume, one row per redshift step and one column per
        FOH bin.
    zmin, zmax : float
        Redshift range to keep. The tabulated redshifts closest to each
        bound are used, and both end rows are included. With zmin=0 and
        zmax=10 the full table is returned.

    Returns
    -------
    image_data : numpy.ndarray, shape (n_z, n_FOH)
        Mass per bin divided by (1e6 * delt) of its row and by the FOH
        grid spacing dFOH.
    redshift : numpy.ndarray
        Redshift of each returned row, in the order of 'Time_redshift_deltaT.dat'.
    delt : numpy.ndarray
        Time step of each returned row, as read from 'Time_redshift_deltaT.dat'.
    '''
    # read time, redshift and timestep as used in the calculations;
    # rows start at the highest redshift (z=z_start=10) and go to z=0
    time, redshift_global, delt = np.loadtxt('Time_redshift_deltaT.dat', unpack=True)
    # mass per unit (comoving) volume formed in each z (row) - FOH (column) bin
    data = np.loadtxt(input_file)
    # SFRD(FOH, z): divide every row by its own time step in one broadcast.
    # NOTE(review): the 1e6 factor presumably converts Myr steps to yr -- confirm.
    image_data = data / (1e6 * delt[:, np.newaxis])

    redshift = redshift_global
    # select the interesting redshift range
    if zmax != 10 or zmin != 0:
        # first row closest to each bound; redshift decreases with row index,
        # so idx (zmax) <= idx0 (zmin)
        idx = np.argmin(np.abs(redshift_global - zmax))
        idx0 = np.argmin(np.abs(redshift_global - zmin))
        # +1: the slice end is exclusive, and the row closest to zmin must be kept
        image_data = image_data[idx:idx0 + 1]
        redshift = redshift_global[idx:idx0 + 1]
        delt = delt[idx:idx0 + 1]

    image_data /= dFOH
    return image_data, redshift, delt

def prepare_CDF(SFRD_data):
    '''Build the cumulative metallicity distribution for every redshift row.

    Each row of SFRD_data (one redshift) is normalised to unit sum. The
    returned value at [j, i] is the normalised weight of the FOH bins
    strictly below bin j in row i. Each column therefore starts at 0 and
    ends at 1 minus the weight of the last bin.

    Parameters
    ----------
    SFRD_data : array_like, shape (n_z, n_FOH)
        Non-negative weights per redshift (row) and FOH bin (column). A row
        summing to zero yields NaNs, as in plain division.

    Returns
    -------
    numpy.ndarray, shape (n_FOH, n_z)
        Exclusive cumulative sums, transposed so that column i belongs to
        redshift row i.
    '''
    sfrd = np.asarray(SFRD_data, dtype=float)
    # NORMALIZE each redshift row to unit total
    normed = sfrd / sfrd.sum(axis=1, keepdims=True)
    # Exclusive cumulative sum along FOH in a single O(n_FOH) pass per row,
    # instead of re-summing every prefix: Z_cumsum[:, j] = sum(normed[:, :j]).
    Z_cumsum = np.zeros_like(normed)
    Z_cumsum[:, 1:] = np.cumsum(normed, axis=1)[:, :-1]
    return Z_cumsum.T

def sample_SFRD_z(z,input_file,metallicity_measure,solar_Z_scale='AndersGrevesse89',n=1e5):
    '''Draw metallicities from the SFRD-weighted metallicity distribution at redshift z.

    Parameters
    ----------
    z : float
        Target redshift; the tabulated redshift closest to it is used.
    input_file : str
        One of the '*FOH_z_dM.dat' model files (read with get_data).
    metallicity_measure : str
        'Z' returns metal mass fractions (converted with FOH2ZZ); any other
        value returns 12+log(O/H) values.
    solar_Z_scale : str
        Solar scale used for the conversion when metallicity_measure == 'Z'.
    n : int or float
        Number of samples (converted with int()).

    Returns
    -------
    numpy.ndarray of Z values if metallicity_measure == 'Z', otherwise a
    list of 12+log(O/H) values taken from FOH_arr.

    Raises
    ------
    ValueError
        If the cumulative distribution at the selected redshift is not
        finite (e.g. no star formation tabulated there).
    '''
    image, redshift, _ = get_data(input_file, zmin=0, zmax=10)
    Z_cumsum = prepare_CDF(image)
    iz = np.argmin(np.abs(redshift - z))
    cdf = Z_cumsum[:, iz]
    if not np.all(np.isfinite(cdf)):
        raise ValueError('no finite metallicity distribution at redshift z=' + str(redshift[iz]))

    r = np.random.random(int(n))
    # For each draw pick the FOH grid point whose CDF value is closest to it.
    # argmin returns the first minimum, i.e. the same tie-breaking as taking
    # the first match. Draws are processed in chunks to bound the size of the
    # (chunk x n_FOH) temporary distance array.
    chunk = 10000
    picks = np.empty(r.size, dtype=np.intp)
    for start in range(0, r.size, chunk):
        draws = r[start:start + chunk]
        picks[start:start + chunk] = np.argmin(
            np.abs(draws[:, np.newaxis] - cdf[np.newaxis, :]), axis=1)
    foh = FOH_arr[picks]

    if metallicity_measure == 'Z':
        # honour the caller's solar scale (previously hard-coded)
        return FOH2ZZ(foh, solar_Z_scale=solar_Z_scale)
    return list(foh)

if __name__ == '__main__':
    # Example use (runs only when executed as a script, not on import):
    # Choose model variation (one of the files '*_FOH_z_dM.dat')
    input_file = 'low-Z_extreme_FOH_z_dM.dat'
    # Choose redshift (<10)
    z = 0
    # Choose metallicity measure: either 12+log(O/H) -- 'Z_OH' -- or metal mass fraction -- 'Z'
    metallicity_measure = 'Z_OH'
    # Solar metallicity scale (used to convert Z_OH to Z if metallicity_measure=='Z')
    solar_Z_scale = 'AndersGrevesse89'
    # draw the sample of metallicities from the star formation rate density weighted distribution of metallicities at redshift=z
    metallicities = sample_SFRD_z(z, input_file, metallicity_measure, solar_Z_scale=solar_Z_scale)

    # Show histogram (bins follow the FOH grid, converted to Z when needed):
    if metallicity_measure == 'Z':
        bins = FOH2ZZ(FOH_arr, solar_Z_scale)
        plt.xscale('log')
    else:
        bins = FOH_arr
    plt.hist(metallicities, bins=bins)
    plt.xlabel('metallicity; z=' + str(z))
    plt.tight_layout()
    plt.show()

