#!/bin/env python

import os
import sys
import optparse
import xml.dom.minidom


def getText(nodelist):
    """Return the concatenated character data of the text nodes in *nodelist*.

    Only the nodes found directly in *nodelist* are inspected; a node
    contributes its ``data`` when its ``nodeType`` equals ``TEXT_NODE``.
    Every other node is skipped without looking at its children.  The
    result is ``''`` when no text node is present.
    """
    return ''.join(
        item.data for item in nodelist if item.nodeType == item.TEXT_NODE
    )

def getDataValues(tag):
    """Return the numbers stored as text inside *tag* as a list of floats.

    The text directly under *tag* is gathered with getText(), split on
    runs of whitespace, and every resulting token is converted with
    float().  A token that is not a valid float makes float() raise
    ValueError.  An element with no text yields an empty list.
    """
    # split() without arguments never produces empty tokens, so every
    # token can be converted unconditionally.
    tokens = getText(tag.childNodes).split()
    return [float(token) for token in tokens]

def getOpts():
    """Parse the command line and return ``(filename, doSplit)``.

    Recognised options:
      -s / --split   store_true flag, False by default.

    The first positional argument is taken as the input file name; any
    further positional arguments are ignored.  The file must exist and
    its name must end in ``.xml`` or ``PDOS``.

    Any validation failure is reported through ``OptionParser.error()``,
    which prints the usage text and the message to stderr and terminates
    the process with exit status 2 (it raises SystemExit), so this
    function only ever returns on success.
    """
    op = optparse.OptionParser(usage="Usage: %prog [options] <pdos.xml>")
    op.add_option("-s", "--split", dest="doSplit", action="store_true", default=False,
                  help="save each orbital information to a gnuplot compatible file.")

    (options, args) = op.parse_args()

    # op.error() does not return, so no explicit return is needed after it.
    if not args:
        op.error("incorrect number of arguments")
    filename = args[0]
    if not os.path.exists(filename):
        op.error("can't find file %s" % filename)
    if not filename.endswith(('.xml', 'PDOS')):
        op.error("xml or PDOS file required")
    return (filename, options.doSplit)

def main():
    """List the orbitals of a pdos.xml file, or export one data file per orbital.

    The input file and the split flag come from getOpts().  The document
    is expected to contain an ``nspin`` element, an ``energy_values``
    element and any number of ``orbital`` elements, each carrying the
    attributes ``index``, ``atom_index``, ``species``, ``position``,
    ``n``, ``l``, ``m``, ``z`` and a ``data`` child element.  The values
    in ``data`` are read as interleaved by spin: the value for energy
    point ``i`` and spin ``ispin`` sits at position ``i*nspin + ispin``.

    Without the split flag, a summary line is printed for every
    (orbital, spin) pair.  With the split flag, the quantum numbers of
    every pair are printed and a two-column text file (energy, pdos) is
    written to the current directory, named
    ``<species><atom_index>_spin=.._n=.._l=.._m=.._z=..``.

    All output goes through single-argument ``print(...)`` calls so the
    function runs unchanged, and prints the same text, under both
    Python 2 and Python 3.

    Returns:
        0 on success (or the integer handed back by getOpts(), if any).

    Raises:
        IndexError: if a required element is missing, if ``position`` has
            fewer than three components, or if an orbital's ``data``
            holds fewer than ``len(energy_values) * nspin`` values.
        ValueError: if a numeric attribute or data token cannot be parsed.
    """
    opts = getOpts()
    # OptionParser.error() normally exits before getOpts() returns, but
    # honour an integer status if one ever comes back instead of a tuple.
    if isinstance(opts, int):
        return opts
    (filename, doSplit) = opts

    dom = xml.dom.minidom.parse(filename)

    tag_nspin = dom.getElementsByTagName("nspin")[0]
    tag_EnergyValues = dom.getElementsByTagName("energy_values")[0]

    nspin = int(getText(tag_nspin.childNodes))
    energy_values = getDataValues(tag_EnergyValues)
    npoints = len(energy_values)

    if not doSplit:
        print("Total number of spin components = %s" % nspin)
        print("Orbitals:")

    # One record per (orbital, spin) pair; only filled in split mode.
    orbitals = []
    for tag_orbital in dom.getElementsByTagName("orbital"):
        orbital_index = int(tag_orbital.getAttribute("index"))
        atom_index = int(tag_orbital.getAttribute("atom_index"))
        species = tag_orbital.getAttribute("species")
        items = tag_orbital.getAttribute("position").split()
        position = (float(items[0]), float(items[1]), float(items[2]))
        n = int(tag_orbital.getAttribute("n"))
        l = int(tag_orbital.getAttribute("l"))
        m = int(tag_orbital.getAttribute("m"))
        z = int(tag_orbital.getAttribute("z"))

        tag_data = tag_orbital.getElementsByTagName("data")[0]
        all_spins_pdos_values = getDataValues(tag_data)
        # Fail with a readable message instead of a bare index error when
        # the data block is shorter than the energy grid requires.
        if len(all_spins_pdos_values) < npoints * nspin:
            raise IndexError(
                "orbital %i: expected at least %i pdos values, found %i"
                % (orbital_index, npoints * nspin, len(all_spins_pdos_values)))

        # TODO: pdos.xml carries no explicit spin quantum number, so the
        # spin channels are separated here to obtain the usual
        # spin/n/l/m/z labelling.
        for ispin in range(nspin):
            quantum_numbers = (ispin, n, l, m, z)
            if not doSplit:
                print(" ".join((
                    "%4i%4i%4s" % (orbital_index, atom_index, species),
                    "(%12.6f%12.6f%12.6f)" % position,
                    "spin=%1i,n=%1i,l=%1i,m=%2i,z=%1i" % quantum_numbers,
                )))
                continue
            # Every nspin-th value starting at ispin, limited to the
            # first npoints energy points (extra trailing data is ignored).
            pdos_values = all_spins_pdos_values[ispin:npoints * nspin:nspin]
            orbitals.append((species, atom_index, quantum_numbers, pdos_values))
    if not doSplit:
        return 0

    for (species, atom_index, quantum_numbers, pdos_values) in orbitals:
        print(str(quantum_numbers))

        output = "%s%i_" % (species, atom_index) + "spin=%i_n=%i_l=%i_m=%i_z=%i" % quantum_numbers
        # The with-block closes the file even if a write fails.
        with open(output, 'w') as f:
            for (energy, pdos) in zip(energy_values, pdos_values):
                f.write("%12.6f%12.6f\n" % (energy, pdos))
    return 0

if __name__ == "__main__":
    # Run the tool only when executed as a script and pass the value
    # returned by main() to the interpreter as the exit status.
    exit_status = main()
    sys.exit(exit_status)

