from __future__ import print_function

from . import absetup
import numpy
import subprocess
import os
import tempfile
import shutil
import utils
import stat
import logging

# Effective-temperature grid (K) used to select MARCS models: 100 K steps
# from 2500 K to 4000 K, then 250 K steps up to 8000 K.
# The grid is not complete -- not every (Teff, logg, [M/H]) combination
# exists -- so intpatm() has to search for models that are actually present.
tatm = (list(map(float, range(2500, 4001, 100))) +
        list(map(float, range(4250, 8001, 250))))
# Nominal logg and [M/H] grids, kept for reference only; intpatm() reads the
# values that are actually available from the model file names at run time:
#   logg  : -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5
#   [M/H] : -3.0, -2.5, -2.0, -1.5, -1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.50, 0.75

# Directories holding the standard-composition MARCS models inside the
# catalogue tree rooted at absetup.catpref.
mspath = '/'.join((absetup.catpref, 'cats', 'Marcs', 'Standard_composition', 'Spherical'))
mppath = '/'.join((absetup.catpref, 'cats', 'Marcs', 'Standard_composition', 'Plane-parallel'))


def _nearest_index(values, target):
    # Return the index of the element of `values` closest to `target`.
    # Ties resolve to the earliest element (strict '<' comparison).
    ibest = 0
    for i, value in enumerate(values):
        if abs(value - target) < abs(values[ibest] - target):
            ibest = i
    return ibest


def _bracketing_neighbour(values, ibest, target):
    # Given the index `ibest` of the grid value closest to `target` in an
    # ascending grid, return the adjacent grid value on the other side of
    # `target`, so the pair brackets it. At the edge of the grid (no such
    # neighbour) the closest value itself is returned.
    best = values[ibest]
    if best > target:
        return values[ibest - 1] if ibest > 0 else best
    return values[ibest + 1] if ibest < len(values) - 1 else best


def _write_model_lists(mpath, prefix, masstag, stripstr, tgrid, suffix, workdir):
    # Write two shell-generated listings of the model files in `mpath` whose
    # names start with `prefix` followed by `tgrid`, and contain `masstag`:
    #   <workdir>/atmpar<suffix>.lst  - each file name cut down by sed into
    #       whitespace-separated columns; columns 1 and 2 are read later as
    #       logg and [M/H].
    #   <workdir>/atmname<suffix>.lst - "<n> <full path>" per file (cat -n).
    # Both come from the same `ls | sed | grep` pipeline, so line n of one
    # corresponds to line n of the other.
    # NOTE(review): paths are interpolated into shell commands unquoted, so
    # they must not contain spaces or shell metacharacters.
    listing = "ls %s/%s%0.0f*mod* | sed s,'%s/',, | grep %s" % (mpath, prefix, tgrid, mpath, masstag)
    os.system("%s | sed -e s,_g,' ', -e s,'%s',' ', -e s,'_a.*',, > %s/atmpar%s.lst"
              % (listing, stripstr, workdir, suffix))
    os.system("%s | sed s,'^','%s/', | cat -n > %s/atmname%s.lst"
              % (listing, mpath, workdir, suffix))


def _model_file_for_index(listfile, index):
    # Return the path on the line numbered `index` of a `cat -n` listing
    # (atmname*.lst), or None if no line carries that number.
    path = None
    with open(listfile, 'r') as f:
        for line in f:
            fields = line.split()
            if int(fields[0]) == index:
                path = fields[1]
    return path


def _stage_model(fn, workdir):
    # Copy model file `fn` into `workdir` (gunzipping it if it ends in '.gz')
    # and return its base name there, without the '.gz' suffix.
    jj = fn.rfind('/') + 1
    if fn.endswith('.gz'):
        modout = fn[jj:-3]
        os.system('gzip -d < %s > %s/%s' % (fn, workdir, modout))
    else:
        modout = fn[jj:]
        os.system('cp -f %s %s/%s' % (fn, workdir, modout))
    return modout


def intpatm(teff, logg, mh, mstar, atmname, workdir='.', wait=True, atoms=None, abun=None, vturb=2.0):
    """Produce a MARCS model atmosphere for the requested (Teff, logg, [M/H]).

    The model is interpolated in the MARCS grid with the interpol_modeles
    programmes and written to '<cwd>/<atmname>.marcs', where <cwd> is the
    working directory at the time of the call.

    - Spherical models are used for logg <= 3 and plane-parallel models
      otherwise. Only 1 Msun spherical models are considered.
    - The optimised interpolation (interpol_modeles) is only valid for
      3800 < Teff < 7000 K. It is used for 3800 < Teff < 7000 K,
      0 < logg < 5 and -4 < [M/H] < 0. Elsewhere the linear variant
      (interpol_modeles_lin) is used, on the assumption that this is still
      better than just picking the closest model.
    - If logg or [M/H] falls outside the bracketing grid values, or any of
      the 8 corner models of the (Teff, logg, [M/H]) cube is missing, the
      closest grid model is copied instead.

    Intermediate files (model listings, staged models and the interpol.com
    driver script) are written to `workdir`, and the driver script is then
    executed.

    Parameters
    ----------
    teff : float
        Requested effective temperature (K). Must lie within `tatm`.
    logg : float
        Requested log g.
    mh : float
        Requested [M/H].
    mstar : float
        Stellar mass. It is currently only reported; the spherical grid is
        fixed at 1 Msun. For old GCs 1 Msun should be OK, but younger
        clusters would need higher masses.
    atmname : str
        Base name of the output model; '.marcs' is appended.
    workdir : str
        Scratch directory for intermediate files.
    wait : bool
        If True, block until the driver script has finished.
    atoms, abun, vturb
        Accepted for interface compatibility but not used by this function.

    Returns
    -------
    subprocess.Popen or None
        The running driver process if `wait` is False, otherwise None.

    Raises
    ------
    Exception
        If `teff` lies outside the Teff grid, or if the closest model is
        needed but could not be located.
    """
    spherical = True
    if logg > 3:
        spherical = False

    print("MARCS.INTPATM:")

    logging.info('MARCS: Searching for model atmosphere for Teff=%0.1f, Logg=%0.2f, [m/H]=%0.2f' % (teff, logg, mh))
    print("  Teff=%0.1f, logg=%0.2f, mh=%0.2f, mstar=%0.2f" % (teff, logg, mh, mstar))

    # Check temperature range
    if (teff < min(tatm)) or (teff > max(tatm)):
        print("    ERROR: Requested Teff outside available range (%0.0f K < Teff < %0.0f K)" % (min(tatm), max(tatm)))
        logging.info("    ERROR: Requested Teff outside available range (%0.0f K < Teff < %0.0f K)" % (min(tatm), max(tatm)))
        raise Exception('Requested Teff outside available range')

    # Closest grid Teff, and its neighbour on the other side of teff.
    itbest = _nearest_index(tatm, teff)
    tbest = tatm[itbest]
    t2best = _bracketing_neighbour(tatm, itbest, teff)

    # Listings '*.lst' refer to the best Teff, '*2.lst' to the 2nd-best one;
    # map them onto the lower/upper Teff of the bracket.
    if tbest > teff:
        Tefflow, Teffup = t2best, tbest
        atmlstlow, atmlstup = 'atmname2.lst', 'atmname.lst'
        atmparlow, atmparup = 'atmpar2.lst', 'atmpar.lst'
    else:
        Tefflow, Teffup = tbest, t2best
        atmlstlow, atmlstup = 'atmname.lst', 'atmname2.lst'
        atmparlow, atmparup = 'atmpar.lst', 'atmpar2.lst'

    print("  Requested Teff=%0.0f K" % teff)
    print("  Bracketing Marcs Teff values: %0.0f K and %0.0f K" % (Tefflow, Teffup))
    logging.info("  Requested Teff=%0.0f K" % teff)
    logging.info("  Bracketing Marcs Teff values: %0.0f K and %0.0f K" % (Tefflow, Teffup))

    # Make lists of available atmospheres for the relevant Teff values.
    # Currently a mass of 1 Msun is hard-coded for the spherical models (_m1.0).
    # The plane-parallel models are mass-independent and mass is set to 0 (_m0.0).
    # We currently assume vturb=2.0 for the spherical models (_t02) and 1.0 for
    # the plane-parallel ones (_t01).
    if spherical:
        mpath, prefix, masstag, stripstr = mspath, 's', '_m1', '_m1.0_t02_st_z'
    else:
        mpath, prefix, masstag, stripstr = mppath, 'p', '_m0', '_m0.0_t01_st_z'
    _write_model_lists(mpath, prefix, masstag, stripstr, tbest, '', workdir)
    _write_model_lists(mpath, prefix, masstag, stripstr, t2best, '2', workdir)

    # Find the logg values bracketing (or closest to) the requested value,
    # using the models available at the best Teff.
    logglst, zlst = utils.getcol('%s/atmpar.lst' % workdir, (1, 2))
    loggunq = sorted(set(logglst))
    print("  Unique logg values for Teff=%0.0f K: " % tbest, end='')
    print(loggunq)

    iloggbest = _nearest_index(loggunq, logg)
    loggbest = loggunq[iloggbest]
    logg2best = _bracketing_neighbour(loggunq, iloggbest, logg)
    logglo, logghi = min(loggbest, logg2best), max(loggbest, logg2best)

    # Closest single model (best Teff, best logg, closest [M/H]); this is the
    # fallback if the interpolation cube cannot be assembled. Model numbers
    # are 1-based line numbers of the listings.
    nn = numpy.arange(1, len(logglst) + 1)
    wg = numpy.where(logglst == loggbest)
    zw = zlst[wg]
    nnw = nn[wg]
    nnbest = nnw[_nearest_index(zw, mh)]
    fnbest = _model_file_for_index('%s/atmname.lst' % workdir, nnbest)

    # Check whether the requested logg falls within the interpolation range.
    # If not, use the closest matching atmosphere from the grid. If it does,
    # attempt to set up the interpolation cube; this may still fail if some
    # corners of the cube fall outside the grid.
    print("  Requested logg=%0.2f" % logg)
    logging.info("  Requested logg=%0.2f" % logg)

    if (logg < logglo) or (logg > logghi):
        print("  Requested logg outside available range (%0.3f < logg < %0.3f)" % (min(loggunq), max(loggunq)))
        logging.info("  Requested logg outside available range (%0.3f < logg < %0.3f)" % (min(loggunq), max(loggunq)))
        models = [None]
    else:
        print("  Bracketing Marcs logg values: %0.2f and %0.2f" % (logglo, logghi))
        logging.info("  Bracketing Marcs logg values: %0.2f and %0.2f" % (logglo, logghi))

        # Find bracketing logZ values. If these are matched at all
        # (Teff, logg) combinations then the interpolation can go ahead.
        print("  Requested logZ=%0.2f" % mh)
        logging.info("  Requested logZ=%0.2f" % mh)

        zlow, zup = min(zw), max(zw)

        if (mh < zlow) or (mh > zup):
            print("  Requested logZ outside available range (%0.3f < logZ < %0.3f)" % (zlow, zup))
            logging.info("  Requested logZ outside available range (%0.3f < logZ < %0.3f)" % (zlow, zup))
            models = [None]
        else:
            # Narrow down to the closest grid values at-or-below / at-or-above mh.
            for zi in zw:
                if (abs(zi - mh) < abs(zlow - mh)) and (zi <= mh):
                    zlow = zi
                if (abs(zi - mh) < abs(zup - mh)) and (zi >= mh):
                    zup = zi

            print("  Bracketing Marcs logZ values: %0.2f and %0.2f" % (zlow, zup))
            logging.info("  Bracketing Marcs logZ values: %0.2f and %0.2f" % (zlow, zup))

            # Stage the atmosphere file for each (Teff, logg, logZ) corner, in
            # the order Teff (low, up) > logg (low, up) > logZ (low, up). A
            # missing corner is recorded as None.
            models = []
            for _atmpar, _atmlst in zip((atmparlow, atmparup), (atmlstlow, atmlstup)):
                _logglst, _zlst = utils.getcol('%s/%s' % (workdir, _atmpar), (1, 2))
                _nn = numpy.arange(1, len(_logglst) + 1)

                for _logg in (logglo, logghi):
                    _wg = numpy.where(_logglst == _logg)
                    _zw = _zlst[_wg]
                    _nnw = _nn[_wg]
                    print(_zw)
                    print(_nnw)

                    if (zlow in _zw) and (zup in _zw):
                        _zl = list(_zw)
                        for _n in (_nnw[_zl.index(zlow)], _nnw[_zl.index(zup)]):
                            fn = _model_file_for_index('%s/%s' % (workdir, _atmlst), _n)
                            if fn is None:
                                # Previously a stale `fn` from an earlier
                                # corner could be reused silently here.
                                print("   Model number %d not found in %s" % (_n, _atmlst))
                                logging.info("   Model number %d not found in %s" % (_n, _atmlst))
                                models.append(None)
                                continue
                            print('Model = %s' % fn)
                            models.append(_stage_model(fn, workdir))
                    else:
                        print("   One or both LogZ values missing for logg=%0.2f" % _logg)
                        logging.info("   One or both LogZ values missing for logg=%0.2f" % _logg)
                        models.append(None)
                        models.append(None)

    cwd = os.getcwd()
    script = workdir + '/interpol.com'
    use_closest = None in models    # something went wrong: pick the closest model

    if use_closest and fnbest is None:
        raise Exception('Closest MARCS model could not be located')

    # `with` guarantees the driver script is closed even if a write fails.
    with open(script, 'w') as fi:
        fi.write('#!/bin/bash\n')
        fi.write('cd ' + workdir + '\n')

        if use_closest:
            print("    Failed to set up interpolation grid.")
            print("    Using closest model, %s" % fnbest)
            logging.info("    Failed to set up interpolation grid.")
            logging.info("    Using closest model, %s" % fnbest)

            if fnbest.endswith('.gz'):
                fi.write('gzip -d < %s > %s/%s.marcs\n' % (fnbest, cwd, atmname))
            else:
                fi.write('cp -f %s %s/%s.marcs\n' % (fnbest, cwd, atmname))
        else:
            # We have the required 8 models, now interpolate.
            logging.info("  Interpolating in these Marcs models:")
            for _model in models:
                logging.info("    %s" % _model)

            if (3800 < teff < 7000) and (0.0 < logg < 5.0) and (-4 < mh < 0):
                fi.write(absetup.binpath + '/interpol_modeles <<EOF\n')
            else:
                fi.write(absetup.binpath + '/interpol_modeles_lin <<EOF\n')

            # Staged models are bare file names; the script cds into workdir.
            for _model in models:
                fi.write("'%s'\n" % _model)

            fi.write("'%s/%s.marcs'\n" % (cwd, atmname))
            fi.write("'interp.out2'\n")
            fi.write('%0.1f\n' % teff)
            fi.write('%0.3f\n' % logg)
            fi.write('%0.3f\n' % mh)
            fi.write('.false.\n')
            fi.write('.false.\n')
            fi.write("'dummy'\n")
            fi.write('EOF\n')

    os.chmod(script, stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR)

    p = subprocess.Popen(script)

    if not wait:
        return p
    # Reap through Popen itself (rather than os.waitpid) so p.returncode is
    # set and the Popen object does not later warn that the child is still
    # running.
    p.wait()
    return None

