#!/usr/bin/env python
#
# Copyright (C) 2017 Michael Janssen
#
# This library is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# You should have received a copy of the GNU Library General Public License
# along with this library; if not, write to the Free Software Foundation,
# Inc., 675 Massachusetts Ave, Cambridge, MA 02139, USA.
#
"""
Auxiliary module: general-purpose utility functions.
"""
#from __future__ import print_function
#from builtins import input
import os
import re
import sys
import copy
import time
import pickle
import pyfits
import random
import shutil
import string
import datetime
import resource
import subprocess
import distutils.dir_util
import numpy as np
import scipy.signal as scsig
from glob import glob
from time import gmtime, strftime
from collections import namedtuple, defaultdict
from scipy.interpolate import interp1d
from pipe_modules.default_casa_imports import *


global_uvfits_mountcodes     = {'ALT-AZ': 0, 'EQUATORIAL': 1, 'ORBITING': 2, 'X-Y': 3,
                                'ALT-AZ+NASMYTH-R': 4, 'ALT-AZ+NASMYTH-L': 5
                               }
first_mem_request            = ()
flat_mem0_mpi                = 0
max_mem_to_resources_scaling = 1
max_mem_usage                = 0
all_processes_running_stable = 0
exhaustive_memory_usage_sum  = 0


######################################### --------------- PART 1 --------------- #########################################
######################################### functions with a general functionality #########################################


#def input_stable(msg):
#    """
#    Works as input() for python3 and raw_input() for python2.
#    """
#    inp = input(msg)
#    assert isinstance(inp, str)

#def print_continue(string):
#    """
#    Print function that does not write a new line, which works in both python 2 and 3.
#    """
#    print (string, end='')


class space_print(object):
    """
    Can be used to replace the standard sys.stdout instance so that every print() call is prefixed with print0.
    """
    def __init__(self, print0 = '  '):
        self.print0 = print0
    def write(self, s):
        sys.__stdout__.write(self.print0+s)
    def flush(self):
        sys.__stdout__.flush()
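
# Illustrative usage (a sketch, not executed here): prefix every print() with the default two spaces,
# then restore the original stdout afterwards:
#   sys.stdout = space_print()
#   ...
#   sys.stdout = sys.__stdout__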


def progress_print(current_counter, max_counter, small_step_counter=0, big_step_counter=0):
    """
    Continuously prints progress on the same line.
    """
    big_step   = 10
    small_step = 2
    if current_counter == max_counter-1:
        sys.stdout.write('100%')
        sys.stdout.write('\n')
        return 999, 999
    this_percentage = int(100. * current_counter / max_counter)
    next_small_step = (small_step_counter + 1) * small_step
    next_big_step   = big_step_counter * big_step
    if this_percentage == next_big_step:
        sys.stdout.write('{}%'.format(str(next_big_step)))
        sys.stdout.flush()
        big_step_counter   += 1
        small_step_counter += 2
    elif this_percentage == next_small_step:
        sys.stdout.write('.')
        sys.stdout.flush()
        small_step_counter += 1
    return small_step_counter, big_step_counter
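
# Illustrative usage (a sketch; n_items is a placeholder): the two counters are threaded through the loop:
#   s_cnt, b_cnt = 0, 0
#   for i in range(n_items):
#       s_cnt, b_cnt = progress_print(i, n_items, s_cnt, b_cnt)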


def input23(_msg):
    """Keyboard input for both python3 and 2.7."""
    if sys.version_info >= (3, 0):
        _inp = input(_msg)
    else:
        _inp = raw_input(_msg)
    return _inp


def get_git_version(_path_to_main_script):
    """Get version from git tag."""
    _path_to_git = '--git-dir=' + _path_to_main_script + '/../.git'
    if not isdir(_path_to_main_script + '/../.git'):
        return '(unknown version)'
    try:
        return subprocess.check_output(['git', _path_to_git, 'describe', '--always', '--tag']).strip().decode()
    except:
        return '(unknown version)'


def handle_cmd_args():
    """
    Tries to mimic some of Python's argparse functionality for 'casa -c ...' calls.
    """
    #can combine multiple command line arguments with a single hyphen:
    these_argv             = split_single_hyphens(sys.argv)
    known_cmd_args         = ['casa', 'main_picard.py']
    pass_next              = ''
    pipedir                = None
    casalogfile            = None
    mpi_and_err_logfile    = None
    got_new_input          = False
    usequickmode           = False
    _no_fringe_params_load = False
    _no_ms_metadata_load   = False
    _force_restore_flags   = False
    _new_listf             = False
    _use_previous_diagdir  = False
    _scrap_old_caltb       = False
    _interactive_mode      = False
    _help_and_exit         = False
    for i,sargv in enumerate(these_argv):
        if sargv=='--help' or sargv=='-h' or sargv=='-help' or sargv=='--h':
            _help_and_exit = True
        elif sargv == '--logfile':
            #truncate the logfile (the file may already be open elsewhere):
            pass_next = these_argv[i+1]
            with open(these_argv[i+1], 'w') as logfile:
                logfile.seek(0)
                logfile.truncate()
        elif sargv == '--pipedir':
            pipedir   = os.path.join(these_argv[i+1], '').rstrip('/')
            pass_next = these_argv[i+1]
        elif sargv == '--caslogf':
            casalogfile = os.path.abspath(these_argv[i+1])
            pass_next   = these_argv[i+1]
        elif sargv == '--errlogf':
            mpi_and_err_logfile = os.path.abspath(these_argv[i+1])
            pass_next           = these_argv[i+1]
        elif sargv == '--input':
            input_folder  = os.path.join(these_argv[i+1], '')
            pass_next     = these_argv[i+1]
            got_new_input = True
        elif sargv == '-p':
            inpf1 = os.path.join('input', '')
            inpf2 = os.path.join('input_template', '')
            if isdir(inpf1):
                input_folder = inpf1
            elif isdir(inpf2):
                input_folder = inpf2
            else:
                raise IOError('No input folders (input/ or input_template/) found in the pwd.')
            got_new_input = True
        elif sargv=='-r':
            _force_restore_flags = True
            try:
                if str(these_argv[i+1]) == 'a':
                    _force_restore_flags = 'a'
                    pass_next            = these_argv[i+1]
            except IndexError:
                pass
        elif sargv=='-q' or sargv=='--quick' or sargv=='--q' or sargv=='-quick':
            try:
                quicklist    = these_argv[i+1]
                pass_next    = these_argv[i+1]
                quicklist    = quicklist.split(',')
                formatted    = []
                for item in quicklist[:]:  #iterate over a copy; removing items from the list being iterated would skip entries
                    if '-' in item:
                        quicklist.remove(item)
                if not quicklist:
                    formatted = 'Abort'
                for item in quicklist:
                    if '~' in item:
                        iitem = item.split('~')
                        iitem = range(int(iitem[0]), int(iitem[1])+1)
                        formatted.extend([str(it) for it in iitem])
                    elif item == 'x':
                        formatted.extend(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q'])
                    else:
                        formatted.append(item)
            except IndexError:
                formatted = 'Abort'
            usequickmode = True
        elif sargv=='-f':
            _no_fringe_params_load = True
        elif sargv=='-m':
            _no_ms_metadata_load = True
        elif sargv=='-s':
            _scrap_old_caltb = True
        elif sargv=='-l':
            _new_listf = True
            try:
                if str(these_argv[i+1]) == 'e':
                    _new_listf = 'exit'
                    pass_next  = these_argv[i+1]
            except IndexError:
                pass
        elif sargv=='-d':
            _use_previous_diagdir = True
        elif sargv=='-i':
            _interactive_mode = True
        elif sargv=='-c':
            pass_next = these_argv[i+1]
        elif sargv=='-n':
            pass_next = these_argv[i+1]
        elif sargv and not check_array_for_match(sargv, known_cmd_args) and sargv!=pass_next:
            raise SyntaxError('Got an unknown command line argument: ' + str(sargv))
    if got_new_input:
        got_new_input = input_folder
    if usequickmode:
        usequickmode = formatted
    return pipedir, casalogfile, mpi_and_err_logfile, got_new_input, usequickmode, _no_fringe_params_load, \
           _no_ms_metadata_load, _force_restore_flags, _new_listf, _use_previous_diagdir, _scrap_old_caltb, \
           _interactive_mode, _help_and_exit
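
# Illustrative example of the flags handled above (paths and step numbers are placeholders):
#   ... -c main_picard.py --pipedir /path/to/rundir -q 1~3,x -i
# would set pipedir, expand the quick-mode steps to ['1', '2', '3', 'a', ..., 'q'],
# and switch on the interactive mode.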


def get_attr(_namedtuple, _pattern):
    """
    Get all attributes (keys) of a _namedtuple that match _pattern*
    """
    keys = []
    for key, val in _namedtuple._asdict().items():
        keys.append(key)
    regex      = re.compile(_pattern+'.*')
    attrs      = []
    for key in keys:
        if re.match(regex, key):
            try:
                attrs.append(key.decode('utf-8'))
            except AttributeError:
                #str objects have no decode() method in python3
                attrs.append(key)
    return attrs
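
# Illustrative example (hypothetical namedtuple, only for demonstration):
#   Pars = namedtuple('Pars', ['refant_1', 'refant_2', 'solint'])
#   get_attr(Pars('AA', 'LM', 60), 'refant')  ->  ['refant_1', 'refant_2']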


def is_set(_object, _parameter):
    """
    True if _object has _parameter and that parameter is not None.
    Else, return False.
    """
    if hasattr(_object, _parameter):
        if getattr(_object, _parameter):
            return True
        else:
            return False
    return False


def fancy_msg(print_str):
    """
    Prints a message in an unnecessarily flashy format.
    """
    #try:
    #    cols = subprocess.check_output(['tput', 'cols'])
    #except subprocess.CalledProcessError:
    #    #this can happen for mpicasa...
    #    cols = '90'
    cols          = '80'
    console_width = int(cols)
    message       = print_str.split('\n')
    msg_lengths   = [len(s) for s in message]
    max_length    = max(msg_lengths)
    print ('')
    print ('_'*console_width)
    print (' ' * (int(console_width / 2) - int(max_length / 2) - 2) + '.'*(max_length + 4)+'\n')
    for length, msg in zip(msg_lengths, message):
        print (' ' * (int(console_width / 2) - int(length / 2)) + msg)
    print (' ' * (int(console_width / 2) - int(max_length / 2)) + '*'*max_length)
    print ('\n\n')


def split_single_hyphens(inlist):
    """
    Takes a list of strings (typically sys.argv) and returns a list where every '-xyz' is replaced by '-x', '-y', '-z'.
    """
    outlist = []
    for val in inlist:
        if len(val)>2 and val[0]=='-' and val[1]!='-':
            thisval = ['-'+v for v in val[1:]]
        else:
            thisval = [val]
        outlist.extend(thisval)
    return outlist
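
# Illustrative example: combined single-hyphen flags are split, double-hyphen ones are kept as-is:
#   split_single_hyphens(['casa', '-qf', '--pipedir'])  ->  ['casa', '-q', '-f', '--pipedir']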


def wrap_list(inlist, items_per_line=1, indent='', two_dim=False):
    """
    Wrap lines along outer axis when printing long lists.
    Hacky code; may or may not work in a general case.
    """
    prtout = indent + '['
    for i,item in enumerate(inlist):
        if two_dim:
            prtout += '['
            for j,it in enumerate(item):
                prtout += str(it) + ', '
                if items_per_line == 1:
                    if j%items_per_line == 0:
                        prtout += '\n' + indent
                else:
                    if j!=0 and j%items_per_line == 0:
                        prtout += '\n' + indent
            prtout = prtout.rstrip(', \n') + '],\n' + indent
        else:
            prtout += str(item) + ', '
            if items_per_line == 1:
                if i%items_per_line == 0:
                    prtout += '\n' + indent
            else:
                if i!=0 and i%items_per_line == 0:
                    prtout += '\n' + indent
    prtout = prtout.rstrip(', \n') + ']'
    return prtout


def glob_all(path):
    """
    Returns a list of all files and directories matching the path pattern (thin wrapper around glob).
    """
    return glob(path)


def rm_duplicate_files(inlist):
    """
    Removes all duplicate files (e.g., from redundant links) from inlist.
    """
    no_dup_outlist = []
    for f in inlist:
        _unique = True
        for ff in no_dup_outlist:
            if os.path.samefile(f, ff):
                _unique = False
        if _unique:
            no_dup_outlist.append(f)
    return no_dup_outlist


def get_extension_matches_in_all_subdirs(_directory, _extensions, no_duplicates=True, also_dirs=False):
    """
    Recursively grab all files in _directory with _extensions=['.x', '.y', 'z',...]
    """
    matches = []
    for root, dirs, files in os.walk(_directory, followlinks=True):
        for filename in files:
            if filename.endswith(tuple(_extensions)):
                matches.append(os.path.join(root, filename))
        if also_dirs:
            if root.endswith(tuple(_extensions)):
                matches.append(root)
    if no_duplicates:
        matches = rm_duplicate_files(matches)
    return matches


def check_available_space(_inp_params, infiles, overhead=3.8):
    """
    Makes a guess on how much disk space will be needed for the MS for the default case where infiles = raw fits-idi files.
    There should be enough space for ~4 times the size of the original fits-idi files:
      Factor of 1.9 for idi->MS.
      Another factor of two because it is necessary to run importfitsidi() and then partition()
        (two MS are then present on disk at the same time).
    Can also be used for infiles=[MS] with overhead=2 if the MS is being copied due to partitioning.
    Ignores the much smaller averaged data products written by exportdata.
    """
    if not isinstance(infiles, list):
        infiles = [infiles]
    avail = os.statvfs(_inp_params.workdir)
    avail = avail.f_frsize * avail.f_bavail
    size1 = 0
    for fitsidi in infiles:
        size1 += os.path.getsize(fitsidi)
    size4 = overhead * size1
    if size4 > avail:
        raise IOError('Not enough disk space to create the measurement set for the pipeline.\n'
                      'Required: ' + str(size4) +'\nAvailable: ' + str(avail)
                     )


def subprocess_ospopen(str_cmd):
    """ Emulate os.popen() with subprocesses. """
    oscmd     = str_cmd.split()
    p         = subprocess.Popen(oscmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, _ = p.communicate()
    try:
        return stdout.split('\n')
    except TypeError:
        return stdout.decode().split('\n')


def get_free_memory(_mem0=0):
    """
    Get the amount of available memory on the system in kB (a better name would have been get_available_memory).
    The old method was to parse the output of the 'free' command, but that does not report the desired quantity on all systems.
    """
    i = 0
    while True:
        if i > 5:
            if _mem0:
                return _mem0
            else:
                return 64000000
        else:
            pass
        try:
            with open('/proc/meminfo', 'r') as memf:
                for line in memf:
                    lline = line.split()
                    if 'MemAvailable' in line:
                        #Available on new kernels (>=v3.14)
                        return max(int(lline[1]), 0)
                    elif 'MemFree' in line:
                        MemFree = int(lline[1])
                    elif 'Active(file)' in line:
                        ActiveFile = int(lline[1])
                    elif 'Inactive(file)' in line:
                        InactiveFile = int(lline[1])
                    elif 'SReclaimable' in line:
                        SReclaimable = int(lline[1])
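            # Fallback for kernels without MemAvailable (all /proc/meminfo values are in kB):
            # estimate available memory as MemFree + reclaimable slab + page cache, with the
            # latter two reduced by the low watermarks, mimicking the kernel's MemAvailable heuristic.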
            LowWaterMark_sum = 0
            with open('/proc/zoneinfo', 'r') as lwmf:
                for line in lwmf:
                    if 'low' in line:
                        LowWaterMark_sum += float(line.split()[-1])
            LowWaterMark_sum *= 12
            SReclaimable     -= min(SReclaimable/2, LowWaterMark_sum)
            PageCache         = ActiveFile + InactiveFile
            PageCache        -= min(PageCache/2, LowWaterMark_sum)
            return max(MemFree + SReclaimable + PageCache, 0)
        except OSError:
            print('-- Encountered a bash issue, kill zombie CASA MPIServer processes or raise ulimit -n.')
            i += 1
            time.sleep(60)
        except NameError:
            #For strange systems that I do not know about.
            return max(int(subprocess_ospopen('free')[1].split()[-1]), 0)


def get_used_memory(pid, active=0):
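    """
    Return the memory usage (RSS in kB, from ps) of the process with the given pid if its
    current CPU usage (the %CPU column from top) exceeds the 'active' threshold; 0 otherwise.
    """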
    i = 0
    while True:
        if i > 5:
            return 0
        try:
            stdout  = subprocess_ospopen('top -b -n 1 -p {0}'.format(pid))
            stdout  = [stdo for stdo in stdout if stdo]
            thisout = stdout[-1]
            try:
                this_active = float(thisout.replace(',','.').split()[-4])
            except ValueError:
                return 0
            if this_active > active:
                stdout = subprocess_ospopen('ps -p {0} -o rss'.format(pid))
                return int(stdout[1].split()[-1])
            #if float(os.popen('top -b -n 1 -p {0}'.format(pid)).readlines()[-1].split()[-4]) > active:
            #    return int(os.popen('ps -p {0} -o rss'.format(pid)).readlines()[1].split()[-1])
            else:
                return 0
        except OSError:
            print('-- Waiting for free memory...')
            i += 1
            time.sleep(60)


def check_no_overflow(_inp_params, running_processes, free_memory0, ff_res_est, N_workers, fudge=0.75):
    """
    Returns True when there seems to be enough memory available for ff_res_est number of resources.
    """
    global first_mem_request
    global flat_mem0_mpi
    global max_mem_to_resources_scaling
    if not running_processes:
        return True
    wait_until_memory_settles(_inp_params, sleeptime=5)
    resources    = [float(rp[0]) for rp in running_processes.values()]
    used_mem_r   = sum([get_used_memory(pid, 5) for pid in list(_inp_params.MPI_processIDs.values())])
    used_mem_a   = sum([get_used_memory(pid, 0) for pid in list(_inp_params.MPI_processIDs.values())])
    this_scaling = float(used_mem_a) / float(sum(resources))
    if this_scaling > max_mem_to_resources_scaling:
        max_mem_to_resources_scaling = this_scaling
    avail_mem   = abs(fudge*free_memory0 - used_mem_r)
    mem_request = max_mem_to_resources_scaling * ff_res_est
    #correct mem_request by actual differences in memory used assuming a non-scalable difference
    if not first_mem_request:
        first_mem_request = (get_free_memory(free_memory0), mem_request)
    elif not flat_mem0_mpi:
        flat_mem0_mpi = abs(first_mem_request[0] - first_mem_request[1] - get_free_memory(free_memory0))
    if flat_mem0_mpi > 0:
        mem_request += flat_mem0_mpi
    #print (running_processes, resources, [get_used_memory(pid, 5) for pid in list(_inp_params.MPI_processIDs.values())], free_memory0, used_mem_r, used_mem_a, max_mem_to_resources_scaling, get_free_memory(), avail_mem, mem_request)
    if mem_request > free_memory0:
        _warnm = '      May not have enough memory for next fringe-fit. Will try it anyway sequentially.'
        print(_warnm)
        if not running_processes:
            return True
        else:
            total_sleeptime = 0
            while True:
                if not running_processes:
                    return True
                #keep track of 'local copy' of running_processes here (no need to report back to continue_mpi_ff function)
                waitfortable, running_processes = update_running_processes(_inp_params, False, running_processes, N_workers)
                time.sleep(1)
                total_sleeptime += 1
                if total_sleeptime > 1200:
                    return True
                if total_sleeptime > 10 and check_if_all_processes_are_idle(_inp_params):
                    return True
    if mem_request < avail_mem:
        return True
    else:
        return False


def estimate_ff_resources(_ms_metadata, scan, solint):
    """
    Estimate the relative amount of 'resources' that will be allocated when fringe-fitting a scan, computed as the number of
    visibilities times the fraction of the scan length covered by solint.
    """
    if str(solint) == 'inf' or str(solint) == '0':
        this_solint_frac = 1.
    else:
        this_solint_frac = float(solint) / float(_ms_metadata.yield_scan_length(scan))
    return _ms_metadata.yield_numvisi(str(scan)) * this_solint_frac


def wait_until_memory_settles(_inp_params, threshold=0.001, sleeptime=1, for_fringefit=True):
    """
    Returns when no mpi process shows an increase of memory allocation by more than threshold.
    """
    if not sum([get_used_memory(pid, 5) for pid in list(_inp_params.MPI_processIDs.values())]):
        return
    if for_fringefit:
        # Wait for all processes running fringefit to arrive at the least-squares stage.
        total_sleeptime = 0
        last_file_pos   = 0
        active_pids     = []
        while True:
            with open(_inp_params.CASALogFile, 'r') as casalogf:
                # Continue reading where we left off the last time.
                casalogf.seek(last_file_pos, 0)
                for line in casalogf:
                    if 'Begin Task: fringefit' in line:
                        try:
                            server = int(line.split('#')[0].rstrip()[-1])
                            if server not in active_pids:
                                active_pids.append(server)
                        except (IndexError, ValueError) as _:
                            pass
                    elif 'Starting least squares optimization.' in line or 'End Task: fringefit' in line:
                        try:
                            server = int(line.split('MPIServer-')[1][0])
                            if server in active_pids:
                                active_pids.remove(server)
                        except (IndexError, ValueError) as _:
                            pass
                last_file_pos = casalogf.tell()
            if not any(active_pids):
                break
            time.sleep(sleeptime)
            total_sleeptime += sleeptime
            if total_sleeptime > 1200:
                return
    total_sleeptime = 0
    process_ids     = copy.deepcopy(list(_inp_params.MPI_processIDs.values()))
    while True:
        if not process_ids:
            return
        oldval = {}
        newval = {}
        for pid in process_ids[:]:
            oldval[pid] = max(get_used_memory(pid),1)
        time.sleep(sleeptime)
        total_sleeptime += sleeptime
        for pid in process_ids[:]:
            newval[pid] = max(get_used_memory(pid),1)
            if oldval[pid]==newval[pid] or newval[pid]<oldval[pid] or frac_diff(oldval[pid], newval[pid]) < threshold:
                process_ids.remove(pid)
        if total_sleeptime > 1200:
            return


def update_running_processes(_inp_params, waitfortable, running_processes, N_workers):
    """
    Removes written tables from running_processes.
    """
    for sc in list(running_processes.keys()):
        try:
            if isdir(running_processes[sc][1]):
                check_if_table_is_written(running_processes[sc][1])
                del running_processes[sc]
                waitfortable = False
                wait_until_memory_settles(_inp_params)
        except KeyError:
            pass
    if len(running_processes) < N_workers:
        waitfortable = False
    return waitfortable, running_processes


def continue_mpi_ff(_inp_params, _ms_metadata, scan, scan_solint, scan_table, free_memory0, N_workers,
                    running_processes, force_no_mpi=False, increase_mem_settletime=1):
    """
    Wait until there is enough memory free to continue allocating fringefit jobs.
    Updates running_processes {(number of resources (#visibilities*solint), caltable)} unless a simple memory estimation is used.
    """
    global max_mem_usage
    global exhaustive_memory_usage_sum
    simple_memory_estimation = True
    if _inp_params.mpi_memory_safety or force_no_mpi=='safety':
        if simple_memory_estimation:
            if not running_processes:
                max_mem_usage               = 0.
                exhaustive_memory_usage_sum = 0
                running_processes[scan]     = ('first_process', [scan_table])
                return running_processes
            elif 'full_load' in [m[0] for m in running_processes.values()]:
                # If at some point all MPI workers are running, we can continue to keep them fully occupied.
                running_processes[scan] = ('full_load', [scan_table])
                return running_processes
            if is_set(_inp_params, 'ff_mem_threshold'):
                mem_thresh = _inp_params.ff_mem_threshold
            else:
                mem_thresh = 95
            if is_set(_inp_params, 'ff_mem_settletime'):
                mem_settletime = _inp_params.ff_mem_settletime
            else:
                mem_settletime = 30
            if scan in running_processes.keys() and not isinstance(running_processes[scan][0], str):
                this_mem_estimate , running_processes = balance_memory_with_known_consumptions(_inp_params, running_processes,
                                                                                               scan, scan_table, mem_thresh
                                                                                              )
            else:
                this_mem_estimate       = simple_memory_management(_inp_params, mem_settletime*increase_mem_settletime,
                                                                   mem_thresh
                                                                  )
                running_processes[scan] = (this_mem_estimate, [scan_table])
            return running_processes
        these_resources = estimate_ff_resources(_ms_metadata, str(scan), scan_solint)
        if not running_processes:
            running_processes[scan] = (these_resources, scan_table)
            #first ff scan should just go through
            return running_processes
        processed_scans = list(running_processes.keys())
        num_processed   = len(processed_scans)
        if num_processed < N_workers:
            waitfortable = False
        else:
            waitfortable = True
        wait_until_memory_settles(_inp_params)
        total_sleeptime = 0
        waitfortable, running_processes = update_running_processes(_inp_params, waitfortable, running_processes, N_workers)
        while not check_no_overflow(_inp_params, running_processes, free_memory0, these_resources, N_workers) or waitfortable:
            waitfortable, running_processes = update_running_processes(_inp_params, waitfortable, running_processes, N_workers)
            time.sleep(15)
            total_sleeptime += 15
            if total_sleeptime > 15000:
                raise MemoryError('Cannot fringefit scan ' + str(scan) + '. Not enough free memory.')
            waitfortable, running_processes = update_running_processes(_inp_params, waitfortable, running_processes, N_workers)
            if total_sleeptime > 30 and check_if_all_processes_are_idle(_inp_params):
                return {}
        if scan not in running_processes.keys():
            running_processes[scan] = (these_resources, scan_table)
        return running_processes
    else:
        return True


def check_if_all_processes_are_idle(_inp_params):
    """
    Table may not have been written in the first place due to a fringe-fit error.
    If so, there will be zombie processes. This function checks for this.
    """
    if not sum([get_used_memory(pid, 5) for pid in list(_inp_params.MPI_processIDs.values())]):
        return True
    else:
        return False


def check_if_process_is_active(pid, cpu_perc_threshold=30):
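    """
    Return True if the process with the given pid shows a CPU usage (%CPU from top) above cpu_perc_threshold.
    Also returns True when the check itself fails, to err on the safe side.
    """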
    try:
        stdout  = subprocess_ospopen('top -b -n 1 -p {0}'.format(pid))
        stdout  = [stdo for stdo in stdout if stdo]
        thisout = stdout[-1]
        try:
            this_active = float(thisout.replace(',','.').split()[-4])
        except ValueError:
            return True
        return this_active > cpu_perc_threshold
    except OSError:
        return True


def get_memory_info(_inp_params):
    """
    Returns:
      - Number of CPU-active _inp_params.MPI_processIDs.
      - Number of CPU-inactive _inp_params.MPI_processIDs.
      - The total % of occupied memory.
      - List of % memory usage of each CPU-active _inp_params.MPI_processIDs.
      - The largest % of memory usage out of all CPU-inactive _inp_params.MPI_processIDs.
        This is useful because I assume that the inactive PID with the largest residual lingering memory from a previous process
        will be used for the next active process, where the lingering memory will be overwritten by active memory usage.
    """
    mem_tot = 0.
    mem_est = []
    mem_cur = 0.
    Nact    = 0
    Ninact  = 0
    ps_aux  = subprocess_ospopen('ps aux')
    for ps in ps_aux[1::]:
        ps_info  = ps.split()
        try:
            this_pid = ps_info[1]
            this_mem = float(ps_info[3])
        except (IndexError, ValueError) as _:
            continue
        mem_tot += this_mem
        if int(this_pid) in list(_inp_params.MPI_processIDs.values()):
            if check_if_process_is_active(this_pid):
                Nact += 1
                mem_est.append(this_mem)
            else:
                Ninact += 1
                if this_mem > mem_cur:
                    mem_cur = this_mem
                else:
                    pass
    return Nact, Ninact, mem_tot, mem_est, mem_cur


def simple_memory_management(_inp_params, mem_settletime=30, thresh_percentage=95):
    """
    Simple way to estimate if enough memory is available to continue with a CASA process:
    Does the number of active processes multiplied by the largest memory usage of a single process exceed thresh_percentage?
    """
    global max_mem_usage
    global all_processes_running_stable
    Nproc                      = len(_inp_params.MPI_processIDs.values())
    safety                     = 0
    safety_thresh              = 18000
    stable_mem_for_stable_proc = 1.
    while True:
        if not max_mem_usage:
            # First run.
            wait_until_memory_settles(_inp_params, sleeptime=mem_settletime)
        Nact, Ninact, mem_tot, mem_est, mem_cur = get_memory_info(_inp_params)
        if Nact == Nproc:
            return 'full_load'
        elif Ninact == Nproc:
            return 'all_inactive'
        if not Nact:
            # No active process found.
            sum_mem_est = 0.
            max_mem_est = 0.
        else:
            sum_mem_est = np.sum(mem_est)
            max_mem_est = np.max(mem_est)
        if not max_mem_usage and frac_diff(mem_tot, stable_mem_for_stable_proc) > 0.01:
            all_processes_running_stable = 0
        if not max_mem_usage:
            # First run.
            if all_processes_running_stable:
                # See below.
                max_mem_usage = 0.5 * (np.min(mem_est) + np.median(mem_est))
            else:
                max_mem_usage = max_mem_est
        elif max_mem_est > max_mem_usage and not all_processes_running_stable:
            # New round of processes, make sure that the memory has settled before continuing.
            wait_until_memory_settles(_inp_params, sleeptime=mem_settletime)
            max_mem_usage = max_mem_est
            continue
        mem_tot_orig = mem_tot
        if not all_processes_running_stable:
            # To avoid waiting for the memory to settle for each active process:
            # Assume that the inactive mem_cur will be replaced by the max_mem_usage from all active processes and that
            # every single active process will eventually reach max_mem_usage (instead of settling at the current sum_mem_est).
            mem_tot += Nact * max_mem_usage - sum_mem_est
        else:
            # Every new process beyond the Nact stable ones will settle at an estimated max_mem_usage.
            # I can ignore sum_mem_est because I do not know which ones belong to the stable running processes and because we are
            # scheduling processes faster than they will allocate memory.
            mem_tot += max(Nact-all_processes_running_stable, 0) * max_mem_usage
        #print (mem_tot, Nact, all_processes_running_stable, max_mem_est, max_mem_usage, mem_tot + max_mem_usage - mem_cur)
        if max_mem_est < max_mem_usage:
            # A new round of processes, where the top memory usage has gone down. The memory really has to settle here and we can
            # take our time, because we first let as many processes as possible pass with the old threshold; those are now running
            # in the background.
            wait_until_memory_settles(_inp_params, sleeptime=mem_settletime)
            wait_until_memory_settles(_inp_params, sleeptime=mem_settletime)
            max_mem_usage = max_mem_est
            continue
        if mem_tot + max_mem_usage - mem_cur < thresh_percentage:
            return max_mem_usage
        safety += 10
        time.sleep(10)
        if safety > 5*mem_settletime:
            # Only when we are certain that the memory of all processes has settled.
            # Memory will settle again as for the first run and then we will use a more aggressive memory estimator
            # for the next processes until thresh_percentage is reached again.
            all_processes_running_stable = Nact
            stable_mem_for_stable_proc   = mem_tot_orig
            max_mem_usage                = 0.
            safety                       = 0
            safety_thresh               -= 5*mem_settletime
            continue
        elif safety > safety_thresh:
            # Should be safe to continue after 5 hours in any case.
            return 'long_wait'
        all_processes_running_stable = 0
        stable_mem_for_stable_proc   = 1.


def balance_memory_with_known_consumptions(_inp_params, running_processes, scan, scan_table, thresh_percentage=95):
    """
    Takes the role of the simple_memory_management() function when running_processes already records how much memory the
    process for each scan will require.
    Will be used for the exhaustive fringe search.
    """
    global exhaustive_memory_usage_sum
    safety                       = 0
    safety_thresh                = 18000
    this_mem_estimate            = running_processes[scan][0]
    exhaustive_memory_usage_sum += this_mem_estimate
    while True:
        _, _, mem_tot, mem_est, mem_cur = get_memory_info(_inp_params)
        if mem_tot + exhaustive_memory_usage_sum - np.sum(mem_est) - mem_cur < thresh_percentage or safety > safety_thresh:
            break
        for key in running_processes.keys():
            if 'active' in running_processes[key]:
                for table in running_processes[key][1]:
                    if isdir(table):
                        check_if_table_is_written(table)
                        try:
                            exhaustive_memory_usage_sum -= float(running_processes[key][0])
                        except ValueError:
                            continue
                        these_tables = running_processes[key][1]
                        these_tables.remove(table)
                        running_processes[key] = (running_processes[key][0], these_tables, 'active')
                    else:
                        pass
            else:
                pass
        safety += 10
        time.sleep(10)
    if scan in running_processes.keys():
        scan_table_list = running_processes[scan][1]
        if 'active' in running_processes[scan]:
            scan_table_list.extend([scan_table])
        else:
            # Do not take inactive calibration tables into account that were already written earlier.
            scan_table_list = [scan_table]
        running_processes[scan] = (this_mem_estimate, scan_table_list, 'active')
    else:
        running_processes[scan] = (this_mem_estimate, scan_table, 'active')
    return this_mem_estimate, running_processes


def add_known_mem_usages_to_running_processes(running_processes):
    """
    Replace placeholder values such as 'first_process' and 'full_load' in running_processes with known memory usages.
    """
    try:
        proc_vals  = [v[0] for v in running_processes.values() if not isinstance(v[0], str)]
    except AttributeError:
        return running_processes
    try:
        obs_minmem = min(proc_vals)
        obs_maxmem = max(proc_vals)
    except ValueError:
        return running_processes
    for key in list(running_processes.keys()):
        if running_processes[key][0] == 'full_load':
            running_processes[key] = (obs_minmem, running_processes[key][1])
        elif isinstance(running_processes[key][0], str):
            running_processes[key] = (obs_maxmem, running_processes[key][1])
        else:
            pass
    return running_processes


def get_latest_file(inlist):
    """From a list of files, grab the latest one."""
    return max(inlist, key=os.path.getctime)


def sort_by_time(inlist):
    """Sort a list of files by age."""
    return sorted(inlist, key=os.path.getctime)


def natural_sort_Ned_Batchelder(inlist):
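    """
    Sort inlist in place in natural (human) order, following Ned Batchelder's recipe.
    Silently does nothing (the TypeError is caught) if the elements cannot be split into text and number parts.
    """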
    try:
        texttonum    = lambda text: int(text) if text.isdigit() else text
        alphanum_key = lambda key: [ texttonum(c) for c in re.split('([0-9]+)', key) ]
        inlist.sort( key=alphanum_key )
    except TypeError:
        return


def rm_file_if_empty(filename):
    if not isfile(filename):
        return True
    if os.stat(filename).st_size == 0:
        os.remove(filename)
        return True
    else:
        return False


def check_for_filecontent(filename):
    """
    Returns True when file exists and something is written in there.
    Else, returns False.
    """
    if not filename:
        return False
    if not isfile(filename):
        return False
    with open(filename, 'r') as contentf:
        check_lines = contentf.readlines()
    if not check_lines:
        return False
    else:
        return True


def rm_dir_if_present(directory):
    directories = glob_all(directory)
    for dirname in directories:
        if os.path.isdir(dirname):
            shutil.rmtree(dirname, ignore_errors=True)


def rm_file_if_present(filename):
    fnames = glob_all(filename)
    for fname in fnames:
        if isfile(fname):
            os.remove(fname)


def mk_dir_if_not_present(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)


def mk_dirname_path_if_not_present(directory):
    dirname = os.path.dirname(directory)
    if not dirname:
        return
    elif not os.path.exists(dirname):
        os.makedirs(dirname)


def isfile(filename):
    return os.path.isfile(filename)


def isdir(directory, returndir=False):
    if not directory:
        return False
    if returndir:
        if os.path.isdir(directory):
            return directory
        else:
            return False
    else:
        return os.path.isdir(directory)


def copyfiles(inpathlist, new_home):
    """
    Copy all files in inpathlist to new_home.
    """
    makedir(new_home)
    for f in inpathlist:
        shutil.copy(f, new_home)


def copydir(source, destination):
    _ = distutils.dir_util.copy_tree(source,destination)


def changedir(dirpath):
    """
    Change working directory to dirpath and return the previous directory.
    """
    previous_dir = os.getcwd()
    os.chdir(dirpath)
    print ('\nChanging directories from\n' + previous_dir +'\nto\n' + dirpath + '\n')
    return previous_dir


def makedir(dirpath):
    """
    Creates directory if it does not exist yet.
    """
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)


def find_sign_flips(array):
    """
    Return the list of indices i for which sign(a_i) differs from both sign(a_{i-1})
    and sign(a_{i+1}), i.e., isolated sign flips in the array.
    """
    array      = np.sign(array)
    len_array  = len(array) - 1
    sign_flips = []
    for i,a in enumerate(array):
        if i == 0 or i == len_array:
            pass
        elif a != array[i-1] and a != array[i+1]:
            sign_flips.append(i)
    return sign_flips
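
# Illustrative example: only the isolated sign flip at index 2 is reported:
#   find_sign_flips([3., 2., -5., 4., 1.])  ->  [2]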


def proper_line(f, comment_char = '#'):
    """
    Returns only non-commented and non-blank lines from input file.
    Lets lines continue if they end with a backslash character.
    """
    lline = None
    for l in f:
        lline = l.rstrip()
        lline = lline.rstrip('\n')
        if lline:
            if comment_char not in lline[0]:
                while lline.endswith('\\'):
                    lline = lline[:-1] + next(f).rstrip().rstrip('\n')
                yield lline


def store_object(_filename, _object=None, _operation=''):
    """
    Reads or writes an object from/to disk.
    """
    if not _operation:
        if os.path.isfile(_filename):
            _operation = 'read'
        else:
            _operation = 'write'
    if _operation == 'read':
        try:
            with open(_filename, 'rb') as dumped:
                loaded_object = pickle.load(dumped)
            return loaded_object
        except IOError:
            raise IOError(str(_filename) + ' not found on disk. Please run the pipeline step to generate it first.')
    elif _operation == 'write':
        if not _object:
            raise ValueError('Must supply a valid object for the write operation')
        with open(_filename, 'wb') as dump:
            pickle.dump(_object, dump, pickle.HIGHEST_PROTOCOL)
    else:
        raise ValueError('Operation '+str(_operation)+' is not supported. Must be either read or write.')
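
# Illustrative usage (a sketch; the filename and object are placeholders):
#   store_object('ms_metadata.p', metadata_obj, 'write')
#   metadata_obj = store_object('ms_metadata.p', _operation='read')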


def string_to_bool(_string):
    """
    Convert a string to None/True/False if it matches one of these literal expressions; otherwise return it unchanged.
    """
    if _string == 'None':
        _string = None
    elif _string == 'True':
        _string = True
    elif _string == 'False':
        _string = False
    return _string


def uniques_from_tuples(_tuple_list, _place, _return_index = False):
    """
    For a list of tuples: [(x_1,x_2,x_3,...), (y_1,y_2,y_3,...), ...],
      get for i=_place all unique elements of [x_i, y_i, ...]
    """
    if _tuple_list:
        return np.unique(np.transpose(_tuple_list)[_place], return_index = _return_index)
    else:
        return None
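
# Illustrative example: unique elements at position 1 of each tuple:
#   uniques_from_tuples([(1, 2), (3, 2), (1, 4)], 1)  ->  array([2, 4])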


def replace_valids_by_input(inlist, _input):
    """
    Returns a copy of inlist where all valid (truthy) entries are replaced by _input.
    """
    outlist = copy.deepcopy(inlist)
    for i, item in enumerate(outlist):
        if item:
            outlist[i] = _input
    return outlist


def random_number_string(_len):
    """
    Returns a string of N=_len random digits.
    Useful for creating unique names.
    """
    return ''.join(random.choice(string.digits) for digit in range(_len))


def unique_filename(_fnam):
    """
    Uses random number strings to create a unique filename for a folder.
    """
    len0 = len(_fnam)
    uf   = _fnam + '.' + random_number_string(6)
    while isdir(uf):
        uf += '.' + random_number_string(6)
        if len(uf) > 10*len0 + 18:
            raise OverflowError('This should not have happened... You have boldly gone where no man has gone before.')
    return uf


def src_caltb(caltable, sourcename):
    """
    Attach sourcename to caltable.
    Unused (using CASA gainfield method instead).
    """
    return '{0}.__{1}__'.format(str(caltable), str(sourcename))


def findall_src_caltb(caltable):
    """Gets all existing src_caltbs."""
    return glob_all(src_caltb(caltable, '*'))


def cut_last_axis(ndarray, break_points, len_y=0, istuple=False):
    """
    Takes an array of shape (x1,x2,x3,...,y) and a list of indices [i0,i1,i2,...,in].
    The indices (break_points) are used to cut the last axis (y) into multiple pieces.
    Returns the list of slices [ndarray[...,i0:i1], ndarray[...,i1:i2], ..., ndarray[...,in:len_y]].
    For istuple=True, ndarray must be a 1d list of tuples, which is then cut along its single axis.
    """
    if not len_y:
        if istuple:
            len_y = len(ndarray)
        else:
            len_y = np.shape(ndarray)[-1]
    N_end_indx = len(break_points)-1
    outlist    = []
    for i,indx in enumerate(break_points):
        if i == N_end_indx:
            if istuple:
                outlist.append(ndarray[indx:len_y])
            else:
                outlist.append(ndarray[...,indx:len_y])
        else:
            if istuple:
                outlist.append(ndarray[indx:break_points[i+1]])
            else:
                outlist.append(ndarray[...,indx:break_points[i+1]])
    return outlist
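
# Illustrative example (arr stands for any array with a last axis of length 6):
#   cut_last_axis(arr, [0, 2, 4])  ->  [arr[..., 0:2], arr[..., 2:4], arr[..., 4:6]]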


def get_closest_match(shootlist, targetlist, skip=[], ignoreval=''):
    """
    Go through all items in shootlist and all items in targetlist finding the closest pair.
    Skip items from shootlist if they are in skip.
    Skip any items that are equal to ignoreval.
    Returns the closest pair, the absolute difference, and a success flag.
    """
    if not isinstance(shootlist, list):
        shootlist = [shootlist]
    if not isinstance(targetlist, list):
        targetlist = [targetlist]
    if not isinstance(shootlist[0], int):
        shootlist  = [float(shoot) for shoot in shootlist]
    if not isinstance(targetlist[0], int):
        targetlist = [float(target) for target in targetlist]
    mindiff = 1e9 + np.abs(max(shootlist)) + np.abs(max(targetlist))
    sucess  = False
    minx    = None
    miny    = None
    for x in shootlist:
        if x not in skip:
            if x != ignoreval:
                for y in targetlist:
                    if y != ignoreval:
                        diff = np.abs(x-y)
                        if diff < mindiff:
                            mindiff = diff
                            minx    = x
                            miny    = y
                            sucess  = True
                            if mindiff == 0:
                                return (minx,miny), mindiff, sucess
    return (minx,miny), mindiff, sucess
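
# Illustrative example: the closest pair between the two lists is (5.0, 4.0) with difference 1.0:
#   get_closest_match([1.0, 5.0], [4.0, 9.0])  ->  ((5.0, 4.0), 1.0, True)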


def check_for_positive_trend(inlist, fit=True, cut=1.1):
    """
    Takes an (ordered) inlist of numbers and checks if there is a positive trend in the data.
    Method: If fit=True:
                Use linear regression.
            Else:
                Computes differences between neighboring numbers and computes the fraction of positive differences
                over negative differences.
                If that fraction is bigger than cut(=1.1), the function returns True as there is some positive trend.
                Else, False is returned.
    """
    inl = np.asarray(list(inlist))
    if fit:
        x   = np.arange(len(inl))
        a,_ = np.polyfit(x, inl, 1)
        if a > 0:
            return True
        else:
            return False
    else:
        diff     = np.ediff1d(inl)
        positive = float(len( np.where(diff>0)[0] )) + 1
        negative = float(len( np.where(diff<0)[0] )) + 1
        if positive/negative > cut:
            return True
        else:
            return False


def find_last_uphill_point(inlist_y, inlist_x, rel2change=0.1):
    """
    Returns the index of inlist corresponding to the first point which either
    shows a negative first derivative or a relative flattening of the second derivative by more than rel2change.
    """
    inarray       = np.asarray(inlist_y)
    dx            = np.diff(np.asarray(inlist_x))
    firstd        = np.diff(inarray) / dx
    secondd       = np.diff(firstd)
    first_i       = 0
    nothing_found = True
    for i,fd in enumerate(firstd):
        if fd < 0:
            first_i       = i
            nothing_found = False
            break
    for i,sd in enumerate(secondd):
        if sd < 0:
            if abs(sd) / firstd[i] > rel2change:
                this_i = i + 1
                if this_i < first_i or nothing_found:
                    first_i = this_i
                    break
    if not first_i:
        return False
    else:
        return first_i


def find_first_SNR_diff(inlist_y, inlist_x, cutoff_mode='sqrt', room_for_error=1.e-6, snr_mincut=7.01):
    """
    Takes inlist_y values that should increase with sqrt(inlist_x) [typically SNR vs bandwidth or coherent averaging time].
    If cutoff_mode=='sqrt':
      Uses inlist_y[0]/sqrt(inlist_x[0]) to find A and assumes that inlist_y = A * sqrt(inlist_x).
      Gives the first index where (inlist_y - A*sqrt(inlist_x))/inlist_y < -room_for_error, i.e. where inlist_y
        no longer follows the sqrt(inlist_x) growth to within the relative room_for_error.
      Or returns False when inlist_y always rises with more than sqrt(inlist_x).
    If cutoff_mode=='minsnr':
      Returns first index where inlist_y > snr_mincut
    Also, returns inlist_x and A * sqrt(inlist_x) on a fine grid for plotting
    """
    _y = np.asarray(inlist_y)
    _x = np.asarray(inlist_x)
    sx = np.sqrt(_x)
    a0 = _y[0] / sx[0]
    sx *= a0
    dy = (_y - sx) / _y
    xp = np.linspace(min(_x), max(_x), 100)
    sp = a0 * np.sqrt(xp)
    try:
        if cutoff_mode == 'sqrt':
            c_indx = np.where(dy < -room_for_error)[0][0]
            #take the previous point which was still good:
            c_indx -= 1
            if c_indx < 0:
                c_indx = 0
        elif cutoff_mode == 'minsnr':
            c_indx = np.where(_y > snr_mincut)[0][0]
        else:
            raise ValueError('cutoff_mode must be sqrt or minsnr.')
        return c_indx, xp, sp
    except IndexError:
        return False, xp, sp


def interpolate1d_over_flags(x, y, flags, kind='linear', input_bounds_error=False, pass_single=False):
    """
    Uses scipy interp1d to interpolate (x,y), while ignoring flagged(==False) data.
    """
    flags_array = np.asarray(flags)
    xx          = np.asarray(x)[flags_array]
    yy          = np.asarray(y)[flags_array]
    xxx         = []
    yyy         = []
    for i,val in enumerate(yy):
        if not np.isnan(val):
            yyy.append(val)
            xxx.append(xx[i])
    num_good_sols = len(xxx)
    if pass_single:
        if num_good_sols==0:
            return False
        elif num_good_sols==1:
            return interpolate1d_single(xxx,yyy, kind, False)
        else:
            pass
    elif num_good_sols<2:
        return False
    if not input_bounds_error or input_bounds_error == 'extrapolate':
        bounds_error = False
    elif num_good_sols < input_bounds_error:
        bounds_error = False
    else:
        bounds_error = True
    xxx      = np.asarray(xxx)
    yyy      = np.asarray(yyy)
    min_indx = np.argmin(xxx)
    max_indx = np.argmax(xxx)
    return interp1d(xxx, yyy, kind=kind, bounds_error=bounds_error, fill_value=(yyy[min_indx],yyy[max_indx]))


def interpolate1d_single(x,y, kind='linear', bounds_error=False):
    """
    Uses scipy interp1d to interpolate (x,y), even if len(x)=len(y)=1.
    """
    if len(x)<2:
        xx    = np.asarray(copy.deepcopy(x))
        yy    = np.asarray(copy.deepcopy(y))
        xx    = np.append(xx,xx[0])
        yy    = np.append(yy,yy[0])
        xx[0] = xx[0] - 1.e-13
        xx[1] = xx[1] + 1.e-13
    else:
        xx = x
        yy = y
    min_indx = np.argmin(xx)
    max_indx = np.argmax(xx)
    return interp1d(xx, yy, kind=kind, bounds_error=bounds_error, fill_value=(yy[min_indx ],yy[max_indx]))


def regrid_1D_fine(x_orig, y_orig, regrid_dx, bounds=False, return_x=False, kind='linear'):
    """
    Takes x_orig, y_orig data, interpolates between them with a certain kind (linear by default)
    and returns re-gridded data spaced by regrid_dx.
    """
    if bounds:
        x_start = bounds[0]
        x_end   = bounds[1]
    else:
        x_start = min(x_orig)
        x_end   = max(x_orig)
    interp_func = interpolate1d_single(x_orig, y_orig, kind)
    x_new       = np.arange(x_start, x_end, regrid_dx)
    x_new       = np.append(x_new, x_end)
    if return_x:
        return x_new, np.asarray([interp_func(x) for x in x_new])
    else:
        return np.asarray([interp_func(x) for x in x_new])


def check_array_for_match(val, array):
    """ Check if any values in array are in val (useful for strings) """
    _match = False
    for x in array:
        if x in val:
            _match = x
            break
    return _match
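
# Illustrative example: returns the first entry of array that is contained in val, or False if none is:
#   check_array_for_match('run main_picard.py now', ['casa', 'main_picard.py'])  ->  'main_picard.py'
#   check_array_for_match('--pipedir', ['casa', 'main_picard.py'])               ->  False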


def remove_selected_from_lists(_selected, selection_list, allother_lists=[], allother_dict={}):
    """
    Input: - selection_list, which should contain _selected
           - allother_lists, which should be a list of lists,
               each of which with items in an order corresponding to selection_list.
           - allother_dict, containing lists just like allother_lists but for keywords in a dict.
    Returns selection_list*, allother_lists*, allother_dict* where selection_list* has _selected removed and for each list in
    allother_lists and allother_dict the item at the index of _selected in selection_list has been removed.
    """
    rm_indx             = selection_list.index(_selected)
    selection_list_star = copy.deepcopy(selection_list)
    allother_lists_star = copy.deepcopy(allother_lists)
    allother_dict_star  = {}
    del selection_list_star[rm_indx]
    for i in range(len(allother_lists)):
        del allother_lists_star[i][rm_indx]
    for key in allother_dict.keys():
        this_list = copy.deepcopy(list(allother_dict[key]))
        del this_list[rm_indx]
        allother_dict_star[key] = this_list
    return selection_list_star, allother_lists_star, allother_dict_star
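
# Illustrative example: removing 'b' also removes the corresponding entries of the other lists:
#   remove_selected_from_lists('b', ['a', 'b', 'c'], [[1, 2, 3]], {'w': [10, 20, 30]})
#   ->  (['a', 'c'], [[1, 3]], {'w': [10, 30]})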


def closest_match(array, value):
    """
    Value in array that is closest to input value.
    """
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]


def subtract_list(l1,l2):
    """
    Returns all elements of l1 that are not in l2.
    """
    return [x for x in l1 if x not in l2]


def overlap_lists(l1,l2):
    """
    Returns all elements that are both in l1 and l2.
    """
    return [x for x in l1 if x in l2]


def list_position_weights(l1,l2):
    """
    Returns sum of all indices based on the position of all elements of the list l1 in the list l2.
    """
    return sum([l2.index(x) for x in l1])


def get_evenly_spaced_elements(inlist, n_elements):
    """
    Pick n_elements evenly distributed over inlist.
    """
    length  = float(len(inlist))
    if n_elements > length:
        raise ValueError('Cannot get more than the total number of elements in the list.')
    outlist = []
    for i in range(n_elements):
        outlist.append(inlist[int(np.ceil(i * length / n_elements))])
    return outlist
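
# Illustrative example: three elements picked from a list of ten:
#   get_evenly_spaced_elements(list(range(10)), 3)  ->  [0, 4, 7]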


def make_ndlist(shape):
    """
    Make a nested list of empty lists with the given shape.
    """
    a = np.zeros(shape = shape)
    b = np.empty((a.shape) + (0, ), dtype = object)
    b.fill([])
    return b.tolist()
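
# Illustrative example: a 2x3 nested list of empty lists:
#   make_ndlist((2, 3))  ->  [[[], [], []], [[], [], []]]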


def add_extra_dim(inarray, _dim):
    """
    Takes an array and adds extra dimensions until it has at least _dim dimensions.
    """
    current_dim = 0
    dummy_copy  = copy.deepcopy(inarray)
    #robust measure for array dimension even if it consists of numpy subarrays:
    while True:
        try:
            dummy_copy   = dummy_copy[0]
            current_dim += 1
        except (TypeError, IndexError) as reduced_to_scalar:
            break
    num_add_dim = range(_dim - current_dim)
    for _ in num_add_dim:
        inarray = np.expand_dims(inarray, 0)
    return inarray


def unique_unsorted(inarray):
    inarray  = np.asarray(inarray)
    try:
        inarray  = inarray.reshape(max(inarray.shape))
    except ValueError:
        pass
    _, index = np.unique(inarray, return_index=True)
    return np.asarray(inarray)[np.sort(index)]
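
# Illustrative example: duplicates are dropped while the original order is kept:
#   unique_unsorted([3, 1, 3, 2])  ->  array([3, 1, 2])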


def flatten_list(inlist, check_for_1dstr=False):
    if check_for_1dstr:
        if isinstance(inlist[0], str):
            return inlist
        else:
            pass
    return [item for sublist in inlist for item in sublist]


def transpose_list(inlist, squeeze=False, check1D=False):
    if squeeze:
        inlist = np.squeeze(inlist)
    if check1D:
        try:
            _ = inlist[0][0]
        except (TypeError, IndexError) as list_or_array:
            return [inlist]
    return list(map(list, zip(*inlist)))
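
# Illustrative example:
#   transpose_list([[1, 2], [3, 4], [5, 6]])  ->  [[1, 3, 5], [2, 4, 6]]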


def force_list(initem):
    try:
        if not initem:
            return initem
        else:
            pass
    except ValueError:
        return initem
    if not isinstance(initem, list) and not isinstance(initem, np.ndarray):
        return [initem]
    else:
        return initem


def get_all_baselines(in_antennas, in_baselines=[]):
    """
    Return all baseline combinations from the input in_antennas list, adding all baselines from the in_baselines list.
    Uses the CASA syntax.
    """
    basels = ''
    for i, ant1 in enumerate(in_antennas):
        for ant2 in in_antennas[i:]:
            if ant1!=ant2:
                basels += '{0}&{1};'.format(str(ant1),str(ant2))
    for inb in in_baselines:
        basels += inb+';'
    basels = basels.rstrip(';')
    return basels
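

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline; station codes are made up):
def _example_get_all_baselines():
    #all antenna pairs in CASA baseline syntax, plus one extra baseline passed in explicitly:
    assert get_all_baselines(['AA', 'AP', 'LM']) == 'AA&AP;AA&LM;AP&LM'
    assert get_all_baselines(['AA', 'AP'], in_baselines=['LM&PV']) == 'AA&AP;LM&PV'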


def remove_flagged_values(values, flags):
    """
    Returns an array of the entries of values wherever flags==0/False (i.e., the unflagged data).
    """
    mavals_orig = np.ma.masked_array(values, flags)
    try:
        _ = values[0][0]
        # Keep the array shape.
        S = list(mavals_orig.shape)
        mavals = np.asarray([m.compressed() for m in mavals_orig])
        S[0]  = mavals.shape[0]
        S[-1] = mavals.shape[-1]
        try:
            return mavals.reshape(tuple(S))
        except ValueError:
            return np.asarray([[m.compressed() for m in subarray] for subarray in mavals_orig])
    except IndexError:
        return mavals_orig.compressed()
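

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline; numpy arrays assumed as input):
def _example_remove_flagged_values():
    vals  = np.array([1.0, 2.0, 3.0, 4.0])
    flags = np.array([0, 1, 0, 0])
    #the entry with flags==1 is dropped:
    assert list(remove_flagged_values(vals, flags)) == [1.0, 3.0, 4.0]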


def medfilt_nobias(inarray, kernelsize=3, bad_vals=[], allowed_range=[-float('Inf'), float('Inf')], minval=1.e-13):
    """
    Applies a sliding median window filter while removing the bias from invalid (flagged) values first:
      If any value is outside of the min/max values of allowed_range or equal to any value in bad_vals,
      then they are first replaced with the global median of the whole array before the sliding median is applied.
        Values smaller than minval are ignored for the calculation of the global median.
        Does nothing if all values are smaller than minval (e.g., delays of reference antenna).
    If allowed_range=[0,0], all values in inarray will be replaced by the global median of the whole array.
    """
    if all(abs(val) < minval or np.ma.is_masked(val) for val in inarray):
        return inarray
    if not isinstance(bad_vals, list):
        bad_vals = [bad_vals]
    if len(inarray) == 1 or np.ma.count(inarray) < 2:
        return inarray
    outarray      = copy.deepcopy(inarray)
    valids_median = np.ma.masked_invalid(np.ma.masked_where(np.ma.abs(outarray) < minval, outarray))
    global_median = np.ma.median(valids_median)
    for i, val in enumerate(outarray):
        if check_array_for_match([val], bad_vals) is not False or val<=min(allowed_range) or val>=max(allowed_range) \
        or np.ma.is_masked(val):
            outarray[i] = global_median
    return scsig.medfilt(outarray, kernelsize)


def average_unique_values(invals, _=0):
    """
    Computes the average of all valid unique numbers in invals.
    """
    counter  = 0
    vals_sum = 0
    for vals in invals:
        for val in list(set(vals)):
            if val:
                counter += 1
                if not np.isnan(val):
                    vals_sum += val
    return vals_sum / float(max(counter,1))


def median_unique_values(invals, _=0):
    """
    Computes the median of all valid unique numbers in invals.
    """
    vals_list = []
    for vals in invals:
        for val in list(set(vals)):
            if val:
                if not np.isnan(val):
                    vals_list.append(val)
    if vals_list:
        return np.median(vals_list)
    else:
        return 0


def min_unique_values(invals, _=0):
    """
    Computes the smallest of all valid unique numbers in invals.
    """
    vals_list = []
    for vals in invals:
        for val in list(set(vals)):
            if val:
                if not np.isnan(val):
                    vals_list.append(val)
    if vals_list:
        return np.min(vals_list)
    else:
        return 0


def range_unique_values_min(invals, minval):
    """
    For a list of lists, takes the smallest number>0 from each inner list and returns
    a) the smallest, median, and largest values out of these numbers in one array
    and
    b) assuming each inner list consists of two unique numbers, the number of values>minval for each of these two numbers.
    """
    vals_list = []
    p0_nums   = []
    p1_nums   = []
    for vals in invals:
        unique_vals = list(set(vals))
        proper_vals = []
        for val in unique_vals:
            if val and not np.isnan(val):
                proper_vals.append(val)
        if proper_vals:
            p0_nums.append(proper_vals[0])
            try:
                p1_nums.append(proper_vals[1])
            #both p0 and p1 are the same unique number
            except IndexError:
                p1_nums.append(proper_vals[0])
            vals_list.append(np.min(proper_vals))
    if vals_list:
        p0_detections = np.asarray(p0_nums)
        p1_detections = np.asarray(p1_nums)
        p0_detections = len(p0_detections[p0_detections>minval])
        p1_detections = len(p1_detections[p1_detections>minval])
        return [np.min(vals_list), np.median(vals_list), np.max(vals_list)], p0_detections, p1_detections
    else:
        return [0]


def phase_rotate(c, d):
    """
    Rotate a complex number or numpy array of complex numbers c by a phase d [rad].
    """
    return c * ( np.cos(d) + np.sin(d)*1j )


def delay2phase(delay_ns, datafreq_GHz, reffreq_GHz):
    """
    Compute phase rotation from delay.
    """
    return 2*np.pi * delay_ns * (datafreq_GHz - reffreq_GHz)
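

#Illustrative usage sketch for the two phase helpers above (hypothetical example helper, not called anywhere in the pipeline):
def _example_phase_helpers():
    #rotating 1+0j by pi/2 rad gives (approximately) 0+1j:
    assert np.isclose(phase_rotate(1 + 0j, np.pi / 2), 1j)
    #with the delay in ns and frequencies in GHz the phase comes out in rad;
    #a 1 ns delay at 0.5 GHz offset from the reference frequency is half a turn:
    assert np.isclose(delay2phase(1.0, 230.5, 230.0), np.pi)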


def wrap_arrayphase(inarray):
    """
    Wraps phase of array [rad] to get a continuous phase even if it goes outside of the [-pi,pi] range.
    Can only solve for a single jump in inarray. Call multiple times to solve multiple jumps.
    """
    outarray = np.asarray(copy.deepcopy(inarray),float)
    for i,ph in enumerate(inarray):
        try:
            nextph = inarray[i+1]
            if abs(nextph - ph) > 1.1*np.pi:
                outarray[i+1:] -= np.sign(nextph) * 2 * np.pi
                return outarray
        except IndexError:
            return outarray


def wrap_phase(a,b):
    """
    Wraps phase of a [rad] to within [-b,b].
    """
    return (a+b)%(2*b) - b
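

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline):
def _example_wrap_phase():
    #4 rad lies outside [-pi, pi] and is wrapped down by one full turn:
    assert np.isclose(wrap_phase(4.0, np.pi), 4.0 - 2 * np.pi)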


def frac_diff(x,y):
    fx = float(x)
    fy = float(y)
    return np.abs(fx-fy)/(min(fx,fy) + 1.e-13)


def get_vis_value(complx_number, quanitity='amp'):
    try:
        cnum = np.asarray([np.asanyarray(c, dtype=complex) for c in complx_number])
    except TypeError:
        cnum = complx_number
    real = np.real(cnum)
    imag = np.imag(cnum)
    if quanitity == 'amp':
        return np.sqrt(np.power(real, 2) + np.power(imag, 2))
    elif quanitity == 'phase':
        return np.arctan2(imag, real) * 180./np.pi
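

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline):
def _example_get_vis_value():
    #amplitude and phase [deg] of the complex visibility 3+4j:
    assert np.isclose(get_vis_value(3 + 4j), 5.0)
    assert np.isclose(get_vis_value(3 + 4j, 'phase'), 53.130102, atol=1e-4)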


def unity_ampltiude_solution(complx_number):
    """
    Takes a (real,imag) number and modifies it such that
    Amplitude = sqrt(real**2 + imag**2) is set to one while leaving the phase = atan(imag/real) unchanged.
    """
    real = np.real(complx_number)
    imag = np.imag(complx_number)
    norm = np.sqrt(real**2 + imag**2)
    real /= norm
    imag /= norm
    return complex(real, imag)


def get_nearest_divisor(numerator, denominator):
    """
    Alters denominator to the nearest number that will make the remainder of numerator/denominator equal to zero.
    """
    _eps_quotient        = 1.e-1
    change_made          = False
    adjusted_denominator = denominator
    if denominator > 0.75 * numerator:
        adjusted_denominator = numerator
        change_made          = 'rounded'
    else:
        quotient, remainder = divmod(numerator, denominator)
        if abs(1.0 - (numerator / quotient)/denominator) > _eps_quotient:
            if int(quotient) == 1:
                _diff = np.abs(denominator - remainder)
                if _diff < 0.5 * numerator:
                    quotient += 1.
            adjusted_denominator = numerator / quotient
            change_made          = True
    return adjusted_denominator, change_made
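

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline;
#float inputs are used so that the quotient test is not affected by python2 integer division):
def _example_get_nearest_divisor():
    #3.0 does not divide 10.0 evenly; the nearest number that does is 10/3:
    adjusted, changed = get_nearest_divisor(10.0, 3.0)
    assert np.isclose(adjusted, 10.0 / 3.0) and changed is True
    #denominators larger than 0.75*numerator are rounded up to the numerator itself:
    adjusted, changed = get_nearest_divisor(10.0, 9.0)
    assert adjusted == 10.0 and changed == 'rounded'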


def add_day_frac(d0, frac):
    """
    Adds a fraction of a day to yyyy/mm/dd to get yyyy/mm/dd/hh:mm:ss.
    """
    date = datetime.datetime.strptime(d0, "%Y/%m/%d") + datetime.timedelta(days=float(frac))
    date = str(date).replace(' ', '/')
    date = date.replace('-', '/')
    return date
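

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline):
def _example_add_day_frac():
    #half a day added to 2021/04/15 gives noon of the same day:
    assert add_day_frac('2021/04/15', 0.5) == '2021/04/15/12:00:00'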


def xyz_to_height(_x, _y, _z, eccentricity=8.1819190842622e-2, radius=6378137):
    """ Compute height of station to within 10 meters from ITRF Geocentric coordinates. """
    #lon            = np.arctan2(_y, _x)
    #r_3d           = np.sqrt(_x**2 + _y**2 + _z**2)
    _e2            = eccentricity**2
    r_2d           = np.sqrt(_x**2 + _y**2)
    lat_geocentric = np.arctan2(r_2d, _z)
    lat_geodetic   = lat_geocentric
    height         = 0.
    for i in range(10):
        R_N = ellipse_curvatre_radius(_e2, radius, lat_geodetic)
        height_new = r_2d / np.cos(lat_geodetic) - R_N
        if np.abs(height - height_new) < 10:
            break
        height       = height_new
        _arg_inverse = r_2d * ( 1 - _e2 * R_N / (R_N + height) )
        lat_geodetic = np.arctan( _z/ _arg_inverse )
    return min(max(height,0), 5200)
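

#Illustrative usage sketch (hypothetical example helper, not called anywhere in the pipeline;
#the test point is synthetic: on the equator, 1 km above the reference ellipsoid radius):
def _example_xyz_to_height():
    assert abs(xyz_to_height(6378137.0 + 1000.0, 0.0, 0.0) - 1000.0) < 10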


def ellipse_curvatre_radius(eccentricity2, radius, latitude):
    """ Help function for xyz_to_height() above."""
    return radius / np.sqrt( 1 - eccentricity2 * np.sin(latitude)**2)


def finalize(_logs,  _finalmsg, _diagdir='', _origdir=''):
    """ For a graceful exit. """
    with open(_logs[0], 'r') as casalogf:
        for line in casalogf:
            if 'Leap second table TAI_UTC seems out-of-date' in line:
                print ('WARNING: Leap seconds may have been out of date! One of the following should be done:')
                print ('1) Follow the online casadocs website about the <CASA Data Repository> on how to update your CASA data.')
                print ('2) Replace the CASA version used by rPICARD by the one from the rPICARD README file.' + \
                       ' That one is updated daily.'
                      )
                break
            else:
                pass
    if _diagdir and _origdir:
        copyfiles(_logs, _diagdir)
        changedir(_origdir)
    fancy_msg(_finalmsg)


######################################### ------------------ PART 2 ------------------ #########################################
######################################### functions with a CASA specific functionality #########################################


def time_convert(mytime, myunit='s', to_datetime_obj=False):
    """
    Convert CASA time format. See https://casaguides.nrao.edu/index.php/Formats_for_Time
    Or return it as a datetime object if to_datetime_obj=True.
    """
    if to_datetime_obj:
        # Difference between MDJ(1858/11/17/00:00:00) and UTC(1970/01/01/00:00:00).
        return datetime.datetime.utcfromtimestamp(mytime - 3506716800)
    myqa = casac.quanta()
    if type(mytime).__name__ != 'list': mytime=[mytime]
    myTimestr = []
    for _time in mytime:
        q1=myqa.quantity(_time,myunit)
        time1=myqa.time(q1,form='ymd')
        myTimestr.append(time1)
    return myTimestr[0][0]


def start_mpi():
    client = MPICommandClient()
    client.start_services()
    client.set_log_mode('redirect')
    client.set_log_level('NORMAL')
    return client


def chop_scans(_ms_metadata, sources, N_num):
    """
    Takes a string of sources separated by commas and finds all scans on these sources.
    Returns these scans as csv array in parts of len=N_num.
    Used by applycal for example to not process all scans at the same time in parallel with MPI.
    """
    src_scans = []
    for scans in _ms_metadata.yield_scans(sources, do_not_squeeze_single_source=True):
        src_scans.extend(scans)
    N_scans = []
    for scans in [src_scans[i:i+N_num] for i  in range(0, len(src_scans), N_num)]:
        N_scans.append(','.join(scans))
    return N_scans
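
#Illustrative chunking behaviour (sketch only; a real _ms_metadata object is needed to call chop_scans()):
#if the scans found for the requested sources were ['1', '2', '3', '4', '5'] and N_num=2,
#the function would return ['1,2', '3,4', '5'].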


def station_NamesToIds(ms_name, station_namedict):
    """
    Convert a dict with station names to station IDs.
    """
    mytb = casac.table()
    mytb.open(ms_name+'/ANTENNA')
    namelist = mytb.getcol('NAME').tolist()
    try:
        isnotname = int(namelist[0])
        namelist  = mytb.getcol('STATION').tolist()
    except ValueError:
        pass
    station_IDdict = {}
    for stname in station_namedict.keys():
        station_IDdict[int(namelist.index(stname))] = station_namedict[stname]
    return station_IDdict


def read_CASA_table(_inp_params, _query_select = 'data', _query_criteria = '', _query_sort = '', _tablename = ''):
    """
    Reads CASA data from table (measurementset by default) and returns a numpy array based on a TaQL _query.
    _query_select: TaQL SELECT -> String of comma-separated column names of the MS that are to be returned.
        If _query_select contains 'data' then it is replaced by CORRECTED_DATA if available and DATA otherwise.
    _query_criteria: TaQL WHERE statement to make specific cuts through the data.
    _query_sort: TaQL ORDERBY -> String of comma-separated column names which dictates the sort order.
    Advanced TaQL statements are also supported: http://casacore.github.io/casacore-notes/199.html
    FLAG and DATA/CORRECTED_DATA have the shape [Npol, Nchan] even for Npol=1,Nchan=1.
    DATA_DESC_ID are the spw.
    """
    #useful columnnames: ['FLAG', 'ANTENNA1', 'ANTENNA2', 'DATA_DESC_ID', 'SCAN_NUMBER', 'TIME', '(CORRECTED_)DATA']
    if not _tablename:
        _tablename = _inp_params.ms_name
    check_if_table_is_written(_tablename)
    mytb = casac.table()
    mytb.open(_tablename)
    all_cols = mytb.colnames()
    if 'data' in _query_select:
        if 'CORRECTED_DATA' in all_cols:
            _query_select = _query_select.replace('data', 'CORRECTED_DATA')
        elif 'DATA' in all_cols:
            _query_select = _query_select.replace('data', 'DATA')
        else:
            raise ValueError('No data found in the measurement set.\n')
    _q_sel   = 'SELECT ' + _query_select + ' FROM ' + _tablename + ' '
    if _query_criteria:
        _q_whe = 'WHERE ' + _query_criteria + ' '
    else:
        _q_whe = ''
    if _query_sort:
        _q_ord = 'ORDERBY ' + _query_sort
    else:
        _q_ord = ''
    qtb         = mytb.taql(_q_sel + _q_whe + _q_ord)
    _selections = re.findall(r"[\w']+", _query_select)
    table_cols  = []
    for col in _selections:
        if col in all_cols:
            table_cols.append(col)
    table_data  = [qtb.getcol(table_col) for table_col in table_cols]
    if len(table_data) == 1:
        table_data = table_data[0]
    qtb.done()
    mytb.done()
    return table_data
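
#Illustrative call pattern (sketch only; requires a real inp_params object pointing to an existing MS):
#  times_and_vis = read_CASA_table(inp_params, 'TIME,data', 'ANTENNA1==0 && DATA_DESC_ID==2', 'TIME')
#would return the TIME column plus the (CORRECTED_)DATA column for all rows of antenna ID 0 in
#spectral window 2, sorted by time.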


def extract_amp_selfcalsols(selfcaltable):
    """
    Takes a mode 'a', type 'G' selfcaltable and returns a dict with
    [(time, ampcorr-factor)] for each [scan][station][spwid][polid].
    """
    all_data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    mytb     = casac.table()
    mytb.open(selfcaltable)
    nrows = mytb.nrows()
    for row in range(nrows):
        thisscan  = mytb.getcell('SCAN_NUMBER', row)
        thisant   = mytb.getcell('ANTENNA1', row)
        thisspwid = mytb.getcell('SPECTRAL_WINDOW_ID', row)
        thistime  = mytb.getcell('TIME', row)
        thesevals = mytb.getcell('CPARAM', row)
        for i,val in enumerate(thesevals):
            thisampval = float(np.real(val[0]))
            try:
                all_data[thisscan][thisant][thisspwid][i].append((thistime, thisampval))
            except KeyError:
                all_data[thisscan][thisant][thisspwid][i] = [(thistime, thisampval)]
    mytb.done()
    return all_data


def check_field_in_table(_inp_params, _ms_metadata, table, field):
    this_fid     = _ms_metadata.yield_sourceID(field)
    present_fids = read_CASA_table(_inp_params, 'unique FIELD_ID', _tablename=table)
    return str(this_fid) in np.asarray(present_fids, dtype=str)


def create_caltable(_inp_params, tablename, parametertype, calibrationtype, singlechannel=False):
    """
    Create a custom calibration table for the _inp_params.ms_name measurement set (inherit metadata).
    Used for my scalar bandpass function.
    """
    mycb = casac.calibrater()
    mycb.open(_inp_params.ms_name, False, False, False)
    mycb.createcaltable(tablename, parametertype, calibrationtype, singlechannel)
    mycb.close()


def check_if_table_is_written(tablename, _wait=60, _maxiter=5):
    """
    Check if a table is properly written and wait 60s (by default) if not.
    Useful for very slow disks...
    """
    mytb = casac.table()
    if isdir(tablename):
        safety = 0
        while True:
            try:
                mytb.open(tablename)
                mytb.close()
                break
            except RuntimeError:
                time.sleep(_wait)
            safety += 1
            if safety > _maxiter:
                raise OverflowError('The table ' + str(tablename) + ' seems to be stuck on write...')


def create_dummy_caltableinfo(_inp_params, _tbname, _columns, _columnvaltypes, _gaintype='B'):
    """
    Returns dminfo and desc that can be used for a simple cal table (only tested for 'B' type table).
    Table description will be for _columns with _columnvaltypes (arrays must be 1D and same shape).
    _tbname must point to the place on disk where the table will be created (currently unused).
    """
    npcols = np.asarray(_columns, dtype='|S19')
    dfname = 'MSMTAB'
    dftype = 'StandardStMan'
    msname = os.path.abspath(_inp_params.ms_name.strip('/'))

    dfkeywords                    = {}
    dfkeywords['ANTENNA']         = 'Table: ' + msname + '/ANTENNA'
    dfkeywords['FIELD']           = 'Table: ' + msname + '/FIELD'
    dfkeywords['HISTORY']         = 'Table: ' + msname + '/HISTORY'
    dfkeywords['MSName']          = _inp_params.ms_name.strip('/')
    dfkeywords['OBSERVATION']     = 'Table: ' + msname + '/OBSERVATION'
    dfkeywords['ParType']         = 'Complex'
    dfkeywords['PolBasis']        = 'unknown'
    dfkeywords['SPECTRAL_WINDOW'] = 'Table: ' + msname + '/SPECTRAL_WINDOW'
    dfkeywords['VisCal']          = _gaintype + ' Jones'

    dfspec                    = {}
    dfspec['ActualCacheSize'] = 2
    dfspec['BUCKETSIZE']      = 2560
    dfspec['IndexLength']     = 150
    dfspec['PERSCACHESIZE']   = 2

    dummydmi0            = {}
    dummydmi0['COLUMNS'] = npcols
    dummydmi0['NAME']    = dfname
    dummydmi0['SEQNR']   = 0
    dummydmi0['SPEC']    = dfspec
    dummydmi0['TYPE']    = dftype
    dummydmi             = {}
    dummydmi['*1']       = dummydmi0

    dfdesc_interval_keyword                 = {}
    dfdesc_interval_keyword['QuantumUnits'] = np.asarray(['s'], dtype='|S2')

    dfdesc_time_keyword                 = {}
    dfmeasinfo_timekey                  = {}
    dfmeasinfo_timekey['Ref']           = 'UTC'
    dfmeasinfo_timekey['type']          = 'epoch'
    dfdesc_time_keyword['MEASINFO']     = dfmeasinfo_timekey
    dfdesc_time_keyword['QuantumUnits'] = dfdesc_interval_keyword['QuantumUnits']

    dfdesc_special_keywords             = {}
    dfdesc_special_keywords['INTERVAL'] = dfdesc_interval_keyword
    dfdesc_special_keywords['TIME']     = dfdesc_time_keyword

    dfdesc_special_ndim             = {}
    dfdesc_special_ndim['CPARAM']   = -1
    dfdesc_special_ndim['FLAG']     = -1
    dfdesc_special_ndim['PARAMERR'] = -1
    dfdesc_special_ndim['SNR']      = -1
    dfdesc_special_ndim['WEIGHT']   = -1

    dummydesc = {}
    for _col, _colvaltype in zip(_columns, _columnvaltypes):
        thisdict                     = {}
        thisdict['comment']          = ''
        thisdict['dataManagerGroup'] = dfname
        thisdict['dataManagerType']  = dftype
        thisdict['maxlen']           = 0
        thisdict['valueType']        = _colvaltype
        if _col in dfdesc_special_keywords.keys():
            thisdict['keywords'] = dfdesc_special_keywords[_col]
        else:
            thisdict['keywords'] = {}
        if _col in dfdesc_special_ndim.keys():
            thisdict['ndim']   = dfdesc_special_ndim[_col]
            thisdict['option'] = 0
        else:
            thisdict['option'] = 5
        dummydesc[_col] = thisdict
    dummydesc['_define_hypercolumn_'] = {}
    dummydesc['_keywords_']           = dfkeywords
    dummydesc['_private_keywords_']   = {}
    return dummydesc, dummydmi


def fix_antenna_mounts(_inp_params, f_correct_mounts):
    """
    Overwrites the mount types in the ANTENNA table (read from idi files)
    by values specified in a file f_correct_mounts which should have the format
    STATIONCODE1 MOUNTTYPE1
    STATIONCODE2 MOUNTTYPE2
    STATIONCODE3 MOUNTTYPE3
    .
    .
    .
    """
    mytb = casac.table()
    antt = _inp_params.ms_name.strip('/')+'/ANTENNA'
    mytb.open(antt, nomodify=False)
    ants = mytb.getcol('NAME')
    with open(f_correct_mounts) as cmounts:
        for line in cmounts:
            values = line.split()
            for i,val in enumerate(values):
                if val in ants:
                    antindx = np.where(ants==val)[0][0]
                    try:
                        mytb.putcell('MOUNT', int(antindx), str(values[i+1]))
                    except IndexError:
                        #handle a more generic file format
                        pass
    mytb.flush()
    mytb.done()


def get_spwmap(_inp_params, _ms_metadata, _caltables):
    """
    Returns an ordered list for the spwmap parameter based on _caltables.
    [] for each caltable that has solutions for all spws of the MS.
    For each caltable that covers only a subset of the spws, a map with one entry per MS spw,
      pointing each MS spw to the closest caltable spw at or below it
      (e.g. [0, 0, 0, 0] for a 4-spw MS and a caltable with solutions only for spw 0).
    """
    mytb = casac.table()
    if not _caltables:
        return []
    try:
        if isinstance(_caltables, basestring):
            _caltables = [_caltables]
        else:
            pass
    except NameError:
        if isinstance(_caltables, str):
            _caltables = [_caltables]
        else:
            pass
    all_spw = len(_ms_metadata.all_spwds)
    spwmap  = []
    spwids  = []
    for caltb in _caltables:
        check_if_table_is_written(caltb)
        mytb.open(caltb)
        if 'SPECTRAL_WINDOW_ID' in mytb.colnames():
            spwids.append(read_CASA_table(_inp_params, 'unique SPECTRAL_WINDOW_ID', _tablename=caltb))
        else:
            spwids.append([])
        mytb.close()
        mytb.clearlocks()
    spwids = list(spwids)
    for CALTBspws in spwids:
        try:
            _ = CALTBspws[0]
        except IndexError:
            spwmap.append([])
            continue
        N_spws = len(CALTBspws)
        if N_spws<all_spw:
            this_spwmap      = []
            current_CALTBspw = np.min(CALTBspws)
            for MSspw in range(all_spw):
                if MSspw in CALTBspws:
                    current_CALTBspw = MSspw
                else:
                    pass
                this_spwmap.append(current_CALTBspw)
            spwmap.append(this_spwmap)
        else:
            spwmap.append([])
    return spwmap
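
#Illustrative spwmap semantics (sketch only; real caltables are needed to call get_spwmap()):
#for an MS with 4 spws, a caltable with solutions only for spws [0, 2] yields the entry [0, 0, 2, 2]
#(each MS spw points to the closest caltable spw at or below it), while a caltable with solutions
#for all spws yields [].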


def extend_solutions_to_all_spw(_ms_metadata, _caltable):
    """
    *Replaced by get_spwmap() for now*
    Takes a caltable solved with combine='spw' and extends the solutions to all spws in the MS.
    Assumes that all antennas have the same number of spws.
    May become obsolete in the future where CASA does this automatically at the solver stage.
    Example usage: Multiband fringefit solutions.
    """
    mytb = casac.table()
    mytb.open(_caltable, nomodify=False)
    nrows = mytb.nrows()
    spwds = _ms_metadata.all_spwds[1:]
    for spwd in spwds:
        mytb.copyrows(_caltable, nrow = nrows)
        for i in range(spwd*nrows, (1+spwd)*nrows):
            mytb.putcell('SPECTRAL_WINDOW_ID', i, spwd)
    mytb.flush()
    mytb.done()


def concat_caltables(_caltables, outtable, cleanup=True, pass_missing_tables=False, return_missing_tables=False):
    """
    Takes a list of calibration tables and concatenates all that exist into outtable.
    If cleanup=True: delete all original tables.
    """
    mytb = casac.table()
    existing_tables = []
    missing_tables  = []
    for caltb in _caltables:
        if isdir(caltb):
            check_if_table_is_written(caltb)
            existing_tables.append(caltb)
        else:
            missing_tables.append(caltb)
    if not existing_tables:
        if pass_missing_tables:
            if return_missing_tables:
                return missing_tables
            else:
                return False
        else:
            _err = 'Trying to concatenate non-existing tables:' + str(_caltables)
            _err+= '. Check the input parameters of your last calibration step.'
            raise IOError(_err)
    #Sorting tables by scans.
    first_scans = []
    for caltb in existing_tables:
        first_scans.append(min(read_CASA_table(None, 'unique SCAN_NUMBER', _tablename=caltb)))
    existing_tables_sorted = [tab for _,tab in sorted(zip(first_scans, existing_tables))]
    reftable               = existing_tables_sorted[0]
    for caltb in existing_tables_sorted[1:]:
        mytb.open(caltb)
        mytb.copyrows(reftable)
        mytb.close()
    shutil.move(reftable, outtable)
    if cleanup:
        for caltb in existing_tables_sorted[1:]:
            shutil.rmtree(caltb, ignore_errors=True)
    if pass_missing_tables:
        return missing_tables


def get_info_from_listobs(_inp_params):
    """
    Opens the list.obs file to read off the polarizations (correlations) in the order of DATA in the MS.
    Also determines the sources observed for each scan and their IDs.
    And antenna names vs their IDs.
    And spectral windows with number of channels each.
    And number of visibilities of each scan.
    This info could also be read from the MS table itself but this is also fun.
    Makes use of read_CASA_table() to get in each scan all antennas present
      and for each (scan, antenna) also the spectral windows available,
      plus for each scan the length of the scan in seconds.
    """
    myms   = casac.ms()
    mymsmd = casac.msmetadata()
    if _inp_params.diag_listobs:
        anames     = list(read_CASA_table(_inp_params, _query_select = 'STATION', _tablename = _inp_params.ms_name+'/ANTENNA'))
        aIDs       = range(len(anames))
        snames     = list(read_CASA_table(_inp_params, _query_select = 'NAME', _tablename = _inp_params.ms_name+'/FIELD'))
        sIDs       = list(range(len(snames)))
        logfile    = open(_inp_params.diagdir + _inp_params.diag_listobs, 'r')
        tmp_lines  = logfile.readlines()
        for i,line in enumerate(tmp_lines):
            if 'Corrs' in line:
                corr_line = tmp_lines[i+1]
                break
        corr_line = list(reversed(corr_line.split()))
        corrs = []
        for l in corr_line:
            try:
                float(l)
                break
            except ValueError:
                corrs.append(l)
        corrs = list(reversed(corrs))
        logfile.close()
    else:
        raise IOError('Error: Should set inp_params.diag_listobs. ' \
                 'Without this file the correlations [RR,RL,LR,LL] cannot be determined.\n')
    scans        = [read_CASA_table(_inp_params, 'unique SCAN_NUMBER', 'FIELD_ID=='+str(sid) , 'SCAN_NUMBER') for sid in sIDs]
    all_scans    = flatten_list(scans)
    antennas     = [read_CASA_table(_inp_params, 'unique ANTENNA1', 'SCAN_NUMBER=='+str(scan), 'ANTENNA1') \
                    for scan in all_scans]
    antennas2    = [read_CASA_table(_inp_params, 'unique ANTENNA2', 'SCAN_NUMBER=='+str(scan), 'ANTENNA1') \
                    for scan in all_scans]
    for i,ants1 in enumerate(antennas):
        for ant2 in antennas2[i]:
            if ant2 not in ants1:
                antennas[i] = np.append(antennas[i], ant2)
    spwds        = []
    scan_lengths = []
    for scan,ants in zip(all_scans,antennas):
        spwds.append([read_CASA_table(_inp_params, 'unique DATA_DESC_ID',
                                      'SCAN_NUMBER=='+str(scan)+' && ANTENNA1=='+str(a), 'DATA_DESC_ID') \
                                      for a in ants])
        _time = read_CASA_table(_inp_params, 'unique TIME_CENTROID', 'SCAN_NUMBER=='+str(scan))
        #length of scan accounting for missing half second at both ends:
        scan_lengths.append(np.max(_time) - np.min(_time) + 1.0)
    myms.open(_inp_params.ms_name)
    scansum   = myms.getscansummary()
    ssum_fkey = list(scansum.keys())[0]
    ac_period = round(float(scansum[ssum_fkey]['0']['IntegrationTime']), 6)
    scan_visi = [scansum[str(sc)]['0']['nRow'] for sc in all_scans]
    myms.close()
    mymsmd.open(_inp_params.ms_name)
    spws           = range(mymsmd.nspw())
    chan_spws      = [mymsmd.nchan(spw_iter) for spw_iter in range(mymsmd.nspw())]
    chan_freqs     = [mymsmd.chanfreqs(spw_iter) for spw_iter in range(mymsmd.nspw())]
    all_info       = mymsmd.summary()
    scan_info_dict = {}
    obs_ids        = []
    scan_startimes = []
    scan_endtimes  = []
    for obs_key in all_info.keys():
        if 'observationID' in obs_key:
            obs_ids.append(obs_key)
    for obs_id in obs_ids:
        for array_key in all_info[obs_id].keys():
            for scan_key in all_info[obs_id][array_key]:
                for indx in all_info[obs_id][array_key][scan_key].keys():
                    if 'fieldID' in indx:
                        this_fieldstring_indx = indx
                        break
                this_scan = str(scan_key.split('=')[-1])
                scan_info_dict[this_scan] = all_info[obs_id][array_key][scan_key][this_fieldstring_indx]
    for scan in all_scans:
        this_scaninfo = scan_info_dict[str(scan)]
        scan_startimes.append(this_scaninfo['begin time'])
        scan_endtimes.append(this_scaninfo['end time'])
    mymsmd.close()
    return corrs, snames, sIDs, anames, aIDs, spws, chan_spws, chan_freqs, scans, antennas, spwds, scan_lengths, scan_visi, \
           scan_startimes, scan_endtimes, ac_period


def read_inputs(_input_location, _casalogf):
    """
    Reads all input from all files passed to the function.
    Automatically detects floats, ints, strings, {None,True,False} booleans, and arrays with a ';' delimiter.
    Attaches the current datetime to diagdir.
    And reads the current casa logfile to determine a list with MPI process IDs.
    Returns a namedtuple that contains all input parameters for the code in a single object.
        This object can then be passed around in the pipeline.
        The parameters can be retrieved from the corresponding keywords in the input files.
    Also returns the names of the files that were actually read in.
    """
    print('\nReading input...')
    files  = glob_all(_input_location)
    kwargs = []
    args   = []
    if not files:
        raise IOError('Did not find input files in  ' + _input_location + '\n')
    for filename in files:
        with open(filename, 'r') as f:
            for line in proper_line(f):
                if '=' in line:
#FIXME: should clean-up the code below at some point...
                    line = line.replace(' ','')
                    line = line.split('#', 1)[0]
                    line = line.split('=')
                    kwargs.append(line[0])
                    if ';' in line[1]:
                        x = line[1].split(';')
                        isarray = True
                    #no array:
                    else:
                        x = line[1]
                        isarray = False
                    #int or float:
                    try:
                        isfloat = False
                        if isarray:
                            if '.' in x[0]:
                                isfloat = True
                                x = [float(xx) for xx in x]
                        elif '.' in x:
                            x = float(x)
                            isfloat = True
                        if not isfloat:
                            if isarray:
                                x = [int(xx) for xx in x]
                            else:
                                x = int(x)
                    #it is a string instead:
                    except ValueError:
                        if isarray:
                            x = [xx.replace("'",'') for xx in x]
                            x = [xx.replace('"','') for xx in x]
                        else:
                            x = x.replace("'",'')
                            x = x.replace('"','')
                    #possibly arrays of booleans:
                    if isarray:
                        x = [string_to_bool(xx) for xx in x]
                    #or scalar booleans:
                    else:
                        x = string_to_bool(x)
                    args.append(x)
    logf  = open(_casalogf, 'r')
    lines = logf.readlines()
    pids  = {}
    for l in lines:
        if 'MPICommandClient::send_start_service_signal::MPICommandClient::send_start_service_signal::casa' in l:
            if 'Server with rank' in l:
                lline        = l.split('Server with rank')[-1].strip('\n').split()
                server       = int(lline[0])
                pid          = int(lline[-1])
                pids[server] = pid
    logf.close()
    kwargs.append('MPI_processIDs')
    args.append(pids)
    kwargs.append('CASALogFile')
    args.append(_casalogf)
    ps_aux  = subprocess_ospopen('ps aux')
    for ps in ps_aux[1::]:
        ps_info  = ps.split()
        try:
            this_pid = ps_info[1]
            if '-auth /tmp/CASA_MPIServer_xauth' in ps:
                unnecessary_display_for_an_MPI_pid = ps_info[ps_info.index('Xvfb')+1]
                unnecessary_display_for_an_MPI_pid= int(unnecessary_display_for_an_MPI_pid.replace(':', ''))
                if unnecessary_display_for_an_MPI_pid in pids:
                    os.kill(int(this_pid), 9)
        except (IndexError, ValueError, OSError, OverflowError) as _:
            continue
    current_datetime   = strftime("_%Y-%m-%d_%H-%M-%S/", gmtime())
    logdir_indx        = kwargs.index('diagdir')
    msname_indx        = kwargs.index('ms_name')
    #remove trailing slash if present:
    args[msname_indx]  = args[msname_indx].rstrip('/')
    if '_' in args[logdir_indx]:
        raise ValueError('A _ is not allowed in the diagdir input parameter:\nSo diagdir=' + args[logdir_indx] + ' is invalid.')
    #remove trailing slash if present and attach datetime with slash:
    args[logdir_indx]  = args[logdir_indx].rstrip('/') + current_datetime
    workdir_indx       = kwargs.index('workdir')
    flagdir_indx       = kwargs.index('flag_dir')
    #expand $pwd:
    if args[workdir_indx].lower() == '$pwd':
        args[workdir_indx] = os.path.abspath(_input_location.rstrip('*.inp')).rstrip('/')+'/../'
    #add trailing slash if not present:
    args[workdir_indx] = os.path.expanduser(os.path.join(args[workdir_indx], ''))
    args[flagdir_indx] = os.path.join(args[flagdir_indx], '')
    inputs_obj         = namedtuple('inp_params', kwargs )
    inputs_list        = inputs_obj(*args)
    if inputs_list.verbose:
        print('  Found ' + str(len(inputs_list)) + ' parameters')
    print ('Done\n')
    return inputs_list, [os.path.abspath(f) for f in files]
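
#Illustrative input-file lines (sketch only; the keyword names below are hypothetical and the
#boolean conversion is handled by string_to_bool(), defined elsewhere):
#  example_int_inp   = 60            -> int 60
#  example_float_inp = 2.5           -> float 2.5
#  example_list_inp  = AA; AP; LM    -> list ['AA', 'AP', 'LM']
#  example_bool_inp  = True          -> boolean True
#Everything after a '#' on an input line is treated as a comment and whitespace around '=' is ignored.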


def search_gcdpfu_in_idi(_inp_params, _ms_metadata, antab_file, fitsidi_files):
    """
    Looks for a GAIN_CURVE extension table with DPFU and gain curve information in the loaded fitsidi_files.
    If the info is found and no antab_file was loaded previously, the gain curve and dpfu data is written to
    a file specified in the _inp_params gc_dpfu_fromidi_file parameter and passed on as ANTAB table.
    The Tsys data should have been loaded elsewhere.
    """
    if antab_file and os.path.abspath(antab_file) != os.path.abspath(_inp_params.gc_dpfu_fromidi_file):
        return antab_file
    if not fitsidi_files:
        return antab_file
    try:
        if isinstance(fitsidi_files, str):
            fitsidi_files = [fitsidi_files]
        fitsidi_files = list(fitsidi_files)
        for idi_file in fitsidi_files:
            hdulist = pyfits.open(idi_file)
            try:
                gcext          = hdulist['GAIN_CURVE']
                new_antab_file = _inp_params.gc_dpfu_fromidi_file
                outf           = open(new_antab_file, 'w')
                wrote_any      = False
                for i, antno in enumerate(gcext.data['ANTENNA_NO']):
                    if not int(np.unique(gcext.data['TYPE_1'][i])[0]) == 2:
                        #Can only handle polynomial gain curves.
                        continue
                    else:
                        wrote_any = True
                    this_antname = str(_ms_metadata.yield_antname(int(antno)-1))
                    gain_coeffs  = gcext.data['NTERM_1'][i][0]
                    gain_curve   = gcext.data['GAIN_1'][i][:gain_coeffs]
                    dpfu_pol1    = str(gcext.data['SENS_1'][i][0])
                    dpfu_pol2    = str(gcext.data['SENS_2'][i][0])
                    this_ANline  = 'GAIN {0} ELEV DPFU = {1}, {2} POLY ='.format(this_antname, dpfu_pol1, dpfu_pol2)
                    #Higher order polynomials not handled by JIVE scripts (yet?) or not compatible with VLA-type gain curve(?)
                    for gain_curve_val in gain_curve[:3]:
                        this_ANline += ' {0},'.format(str(gain_curve_val))
                    this_ANline = this_ANline[0:-1] + ' /\n'
                    outf.write(this_ANline)
                hdulist.close()
                outf.close()
                rm_file_if_empty(new_antab_file)
                if wrote_any:
                    if check_for_filecontent(new_antab_file):
                        return new_antab_file
                    else:
                        return antab_file
                else:
                    return antab_file
            except KeyError:
                hdulist.close()
                rm_file_if_empty(new_antab_file)
    except:
        rm_file_if_empty(new_antab_file)
        return antab_file


def get_antabtable(_inp_params):
    if _inp_params.antab_name:
        antab_file = _inp_params.antab_name
    else:
        antab_file = get_extension_matches_in_all_subdirs(_inp_params.workdir, _inp_params.antab_extensions)
    if not antab_file:
        if _inp_params.pass_missing_antab:
            return False
        else:
            raise IOError('No valid ANTAB table found in ' + _inp_params.workdir+ '\n')
    if len(antab_file) > 1:
        raise IOError('Found more than one ANTAB table. This is not supported.\n' \
                      'Please concatenate these files into a single table:\n' + \
                      wrap_list(antab_file, indent='  ')
                     )
    else:
        return antab_file[0]


def get_wxfiles(_inp_params):
    if _inp_params.wxfile_name:
        wx_files = _inp_params.wxfile_name
        if not isinstance(wx_files, list):
            wx_files = [wx_files]
    else:
        wx_files = get_extension_matches_in_all_subdirs(_inp_params.workdir, _inp_params.wxfile_extensions)
    if not wx_files:
        wx_files = [None]
    if len(wx_files) > 1:
        raise IOError('Found more than one weather table. This is not supported.\n' \
                      'Please concatenate these files into a single table:\n' + \
                      wrap_list(wx_files, indent='  ')
                     )
    else:
        return wx_files[0]


def get_fitsidifiles(_inp_params):
    if _inp_params.fitsidi_name:
        fitsidi_files = _inp_params.fitsidi_name
    else:
        fitsidi_files_candidates = get_extension_matches_in_all_subdirs(_inp_params.workdir, _inp_params.fitsidi_extensions)
        fitsidi_files = []
        for fitsidi in fitsidi_files_candidates:
            if has_antenna_table(_inp_params, fitsidi):
                fitsidi_files.append(fitsidi)
    if not fitsidi_files:
        if _inp_params.pass_missing_fitsidi:
            return False
        else:
            raise IOError('  No valid fits-idi files found in ' + _inp_params.workdir+ '\n')
    return fitsidi_files


def get_flagfiles(_inp_params):
    if _inp_params.flagfile_name:
        flag_files = _inp_params.flagfile_name
        if not isinstance(flag_files, list):
            flag_files = [flag_files]
    else:
        flag_files = get_extension_matches_in_all_subdirs(_inp_params.workdir, _inp_params.flagfile_extensions)
    return flag_files


def get_modelfiles(_inp_params):
    if _inp_params.modelfile_name:
        model_files = _inp_params.modelfile_name
        if not isinstance(model_files, list):
            model_files = [model_files]
    else:
        model_files = get_extension_matches_in_all_subdirs(_inp_params.workdir, _inp_params.modelfile_extensions, also_dirs=True)
    return model_files


def get_trecfiles(_inp_params):
    if _inp_params.trecfile_name:
        trec_files = _inp_params.trecfile_name
        if not isinstance(trec_files, list):
            trec_files = [trec_files]
    else:
        trec_files = get_extension_matches_in_all_subdirs(_inp_params.workdir, _inp_params.trecfile_extensions)
    return trec_files


def find_first_dobs(idifiles):
    first_dobs = datetime.datetime.strptime('9999/12/12', '%Y/%m/%d')
    for match in idifiles:
        hdulist = pyfits.open(match)
        try:
            dobs = hdulist['PRIMARY'].header['DATE-OBS']
        except KeyError:
            dobs = hdulist['UV_DATA'].header['DATE-OBS']
        dobs = dobs.replace('-','/')
        dobs = datetime.datetime.strptime(dobs, '%Y/%m/%d')
        if dobs < first_dobs:
            first_dobs = dobs
        hdulist.close()
    return first_dobs.strftime('%Y/%m/%d')


def get_num_antenna(idifile):
    hdulist = pyfits.open(idifile)
    try:
        N_ant = len(hdulist['ANTENNA'].data)
    except KeyError:
        N_ant = 0
    hdulist.close()
    return N_ant


def fix_uvfits_AN(uvf_file, _ms_metadata):
    """
    It seems like exportuvfits in CASA sometimes writes 'AIPS AN' tables with corrupted antenna names and a bad RDATE header.
    Moreover, all mount types are set to 0 in the MNTSTA column of the AN table.
    This function corrects these issues.
    """
    if not isfile(str(uvf_file)):
        return
    uvfits = pyfits.open(uvf_file, mode='update')
    #dobs   = uvfits['PRIMARY'].header['DATE-OBS']
    for extension_iter in range(len(uvfits)):
        try:
            thisname = uvfits[extension_iter].name
            if 'AIPS AN' not in thisname:
                continue
            ANdata = uvfits[extension_iter].data
        except KeyError:
            uvfits.close()
            return
        ANhead          = uvfits[extension_iter].header
        #ANhead['RDATE'] = dobs
        for i, antname in enumerate(ANdata['ANNAME']):
            #remove illegal characters:
            thisant = ''.join(char for char in str(antname) if char.isalnum())
            try:
                thisant = thisant.decode()
            except AttributeError:
                pass
            goodant = 'unknown'
            if thisant in _ms_metadata.antennanames:
                goodant = thisant
            else:
                for realant in _ms_metadata.antennanames:
                    if thisant.startswith(realant) or thisant.lstrip('b').startswith(realant):
                        goodant = realant
                        break
            ANdata['ANNAME'][i] = goodant
            try:
                ANdata['MNTSTA'][i] = global_uvfits_mountcodes[_ms_metadata.yield_antmount(goodant).upper()]
            except KeyError:
                ANdata['MNTSTA'][i] = global_uvfits_mountcodes['ALT-AZ']
        pyfits.update(uvf_file, ANdata, header=ANhead, ext=extension_iter)
    uvfits.flush()
    uvfits.close()


def lookfor_illegal_chars(_inp_params, key, illegal_keyword_char):
    """
    Raises ValueError if _inp_params.key contains illegal_keyword_char.
    """
    if is_set(_inp_params, key):
        this_keyword = getattr(_inp_params, key)
        if isinstance(this_keyword, str) and illegal_keyword_char in this_keyword:
            raise ValueError('Illegal  {0}  char in  {1}  input parameter.'.format(illegal_keyword_char, key))
        else:
            pass
    else:
        pass


def check_srcinp(_inp_params, srcinp_name):
    """
    Check sources specified in observation.inp.
    Can also be used to check if an input has a ; when it should not.
    """
    if is_set(_inp_params, srcinp_name):
        this_src = getattr(_inp_params, srcinp_name)
        if isinstance(this_src, list):
            raise ValueError('Illegal  ;  character in {0} input. Please use  ,  instead.'.format(srcinp_name))
        this_src = this_src.split(',')
    else:
        this_src = []
    return this_src


def check_duplicate_src(sci_srcs, cal_srcs):
    """
    A source cannot be a science target and calibrator at the same time.
    """
    for sci_s in sci_srcs:
        for cal_s in cal_srcs:
            if cal_s in sci_s:
                raise ValueError('{0} is illegally specified as both a calibrator and science target.'.format(cal_s))
            else:
                pass


def check_input(_inp_params):
    """
    Check if several input parameters are set correctly.
    """
    sci_src = check_srcinp(_inp_params, 'science_target')
    c_cal   = check_srcinp(_inp_params, 'calibrators_instrphase')
    b_cal   = check_srcinp(_inp_params, 'calibrators_bandpass')
    r_cal   = check_srcinp(_inp_params, 'calibrators_rldly')
    d_cal   = check_srcinp(_inp_params, 'calibrators_dterms')
    p_cal   = check_srcinp(_inp_params, 'calibrators_phaseref')
    check_duplicate_src(sci_src, c_cal)
    check_duplicate_src(sci_src, b_cal)
    check_duplicate_src(sci_src, r_cal)
    check_duplicate_src(sci_src, d_cal)
    check_duplicate_src(sci_src, p_cal)
    check_srcinp(_inp_params, 'refant')
    should_not_contain_comma = ['fringe_delay_window_initial', 'fringe_rate_window_initial', 'fringe_delay_window_mb_sci_short',
                                'fringe_rate_window_mb_sci_short', 'fringe_solint_optimize_search_cal',
                                'fringe_solint_optimize_search_sci', 'fringe_solint_mb_reiterate'
                               ]
    for sncc in should_not_contain_comma:
        lookfor_illegal_chars(_inp_params, sncc, ',')


def check_data(_inp_params, fitsidi_files):
    """
    Before trying to attach a TSYS table to the idi files or working with an MS directly,
    raise an IOError if no MS and no idi files are present (i.e. no data present to work on).
    This check must be done before amplitude_calibration.attach_tsys_to_idi() is called.
    """
    if not isdir(_inp_params.ms_name) and not fitsidi_files:
        raise IOError('No data found! There is no MS called ' + str(_inp_params.ms_name) + ' and no fits-idi files.\n' + \
                      'Check inputs: fitsidi_extensions in constants.inp and ms_name in observation.inp.\n' + \
                      'Also, make sure that fits-idi files are within the workdir path:\n' + str(_inp_params.workdir)
                     )
    elif fitsidi_files:
        for ff in fitsidi_files:
            try:
                checkfile = pyfits.open(ff)
                checkfile.close()
            except IOError:
                raise IOError(str(ff) + ' is not a valid fits-idi file! Check fits_extensions in constants.inp')


def has_antenna_table(_inp_params, infile):
    """
    Return True if infile has an ANTENNA table extension (should be the case for fits-idi files).
    Return False otherwise.
    """
    try:
        fopen = pyfits.open(infile)
    except IOError:
        return False
    try:
        _ = fopen['ANTENNA']
        fopen.close()
        return True
    except KeyError:
        fopen.close()
        if _inp_params.verbose:
            print ('  Found {0} as a possible raw-data file but it is not a valid IDI file. Skipping.'.format(str(infile)))
        return False


def sort_fitsidi_files(fitsidi_files):
    """
    Sort FITS-IDI files by descending number of antennas to get consistent antenna table orderings across different days for
    EHT data.
    """
    if fitsidi_files:
        these_fitsidi_files = copy.deepcopy(fitsidi_files)
        these_fitsidi_files = force_list(these_fitsidi_files)
        idi_Nants           = [get_num_antenna(idi) for idi in these_fitsidi_files]
        if len(these_fitsidi_files) > 1:
            these_fitsidi_files = [idi for _, idi in sorted(zip(idi_Nants, these_fitsidi_files), key=lambda pair: pair[0])][::-1]
        return these_fitsidi_files
    else:
        return fitsidi_files


def rescale_sigma_and_weights(ms_name='VLBI.ms.avg', sigmafactor=1/0.881):
    """
    Multiply all SIGMA and WEIGHT values of ms_name by sigmafactor and sigmafactor^(-2), respectively.
    """
    mytb = casac.table()
    mytb.open(ms_name, nomodify=False)
    VALS = mytb.getcol('SIGMA')
    VALS*= sigmafactor
    mytb.putcol('SIGMA', VALS)
    VALS = mytb.getcol('WEIGHT')
    VALS*= sigmafactor**(-2)
    mytb.putcol('WEIGHT', VALS)
    mytb.flush()
    mytb.done()
    mytb.clearlocks()


def set_const_modelamp(ms_name='VLBI.ms', amp=1.0):
    """
    Set all model amplitudes to a const. amp value, while leaving the phases untouched.
    """
    mytb = casac.table()
    mytb.open(ms_name, nomodify=False)
    nrows = mytb.nrows()
    prt_c = 0
    prt_p = 1
    prt_t = nrows // 10 - 2  #integer division so the progress-print threshold stays an int under python3
    for row in range(nrows):
        CELL = mytb.getcell('MODEL_DATA', row)
        CELL*= np.divide(amp, np.sqrt(np.add(np.power(np.real(CELL), 2), np.power(np.imag(CELL), 2))) + 1.e-13)
        mytb.putcell('MODEL_DATA', row, CELL)
        if prt_c == prt_t:
            sys.stdout.write("\r    ...{0}%".format(str(10*prt_p)))
            sys.stdout.flush()
            prt_p+= 1
            prt_c = 0
        prt_c += 1
    print('\n')
    mytb.flush()
    mytb.done()
    mytb.clearlocks()


def load_the_data(_inp_params, fitsidi_files):
    """
    If _inp_params.ms_name does not exist:
      - check if there is enough disk space for the MS and if so:
        * executes CASA's importfitsidi task, generating a temporary MS
        * executes CASA's partition task to generate a MMS.
    """
    print ('\nLoading the data...')
    allowed_pmodes = ['fitsidi', 'MS', 'MS_clean', 'MS_fitsidi', 'MS_fitsidi_clean']
    this_pmode     = _inp_params.MS_partitioning
    if not this_pmode:
        this_pmode = [None]
    elif this_pmode not in allowed_pmodes:
        raise ValueError(str(this_pmode) + ' is not an allowed partitioning mode. Must be any of ' + str(allowed_pmodes))
    these_fitsidi_files = copy.deepcopy(fitsidi_files)
    if os.path.exists(_inp_params.ms_name):
        print ('  The measurement set ' + _inp_params.ms_name + ' already exists.\n    ' \
               'I assume you want to work with the same measurement set again, \n    ' \
               'but probably with a different calibration strategy.\n    ' \
               'Therefore, I will not load the fits-idi files again and keep the old MS.'
              )
        if 'MS' in this_pmode and not isdir(_inp_params.ms_name+'/SUBMSS'):
            oldms = unique_filename(_inp_params.ms_name+'.old_unpartitioned')
            print ('  But the MS is not a MMS so I will move it to\n' \
                   '  ' + oldms + '\n  and create a MMS from it for the pipeline.'
                  )
            check_available_space(_inp_params, _inp_params.ms_name, overhead=2)
            old_flagvers = _inp_params.ms_name + '.flagversions'
            if isdir(old_flagvers):
                raise IOError('Found flagversion files for a MS that I was about to create. Please delete\n' + old_flagvers)
            shutil.move(_inp_params.ms_name, oldms)
            task_partition_general(_inp_params, oldms)
            if 'clean' in this_pmode:
                print ('  Got ' + str(this_pmode) + ' as partitioning mode - will delete the old MS.')
                shutil.rmtree(oldms, ignore_errors=True)
    else:
        if _inp_params.verbose:
            print ('  Found\n' + wrap_list(these_fitsidi_files, indent='    '))
        if is_set(_inp_params, 'fitsidi_to_MS_overhead'):
            assumed_disk_space_factor = _inp_params.fitsidi_to_MS_overhead
        else:
            assumed_disk_space_factor = 3.8
        check_available_space(_inp_params, these_fitsidi_files, assumed_disk_space_factor)
        tmpms = 'tmp.ms.' + random_number_string(6)
        tasks.importfitsidi(fitsidifile      = these_fitsidi_files,
                            vis              = tmpms,
                            constobsid       = True,
                            scanreindexgap_s = _inp_params.scanreindexgap
                           )
        if 'fitsidi' in this_pmode:
            task_partition_general(_inp_params, tmpms)
            shutil.rmtree(tmpms, ignore_errors=True)
        else:
            shutil.move(tmpms, _inp_params.ms_name)
    print ('Done\n')


def export_the_data(_inp_params, _ms_metadata):
    """
    If _inp_params.avg_final_ms is set, then make an averaged .avg MS using mstransform(), passing only cross-correlations.
    If _inp_params.exportuvfits is set, then generate uvfits files from all sources, from averaged data if applicable.
    """
    if not _inp_params.avg_final_ms and not _inp_params.exportuvfits:
        print ('\nKeeping calibrated {0} as final product as no export options are specified\n'.format(_inp_params.ms_name))
    else:
        print ('\nExporting the calibrated data...')
        if _inp_params.avg_final_ms:
            _avg_ms_name = _inp_params.ms_name + '.avg'
            if _inp_params.verbose:
                print ('  Generating an averaged MS: ' + _avg_ms_name)
            if _inp_params.avg_final_spw:
                print('    *Warning: avg_final_spw=True may be broken and cause a Segfault...*')
                _combinespws = True
            else:
                _combinespws = False
            if _inp_params.avg_final_channel:
                _chanaverage = True
                _chanbin     = _inp_params.avg_final_channel
            else:
                _chanaverage = False
                _chanbin     = 1
            if _inp_params.avg_final_time:
                _timeaverage = True
                _timebin     = _inp_params.avg_final_time
            else:
                _timeaverage = False
                _timebin     = '0s'
            if _inp_params.verbose:
                _use = 'combinespws={0}, chanbin={1}, timebin={2}'.format(str(_combinespws), str(_chanbin), _timebin)
                print ('    Using ' + _use)
            if isdir(_avg_ms_name):
                shutil.rmtree(_avg_ms_name, ignore_errors=True)
            if isdir(_avg_ms_name+'.flagversions'):
                shutil.rmtree(_avg_ms_name+'.flagversions', ignore_errors=True)
            task_mstransform_general(_inp_params, _avg_ms_name, '*&', _combinespws, _chanaverage, _chanbin,
                                     _timeaverage, _timebin
                                    )
            if is_set(_inp_params, 'sigmascale'):
                rescale_sigma_and_weights(_avg_ms_name, _inp_params.sigmascale)
        if _inp_params.exportuvfits:
            if _inp_params.verbose:
                if _inp_params.avg_final_ms:
                    print ('  Generating uvfits files for all sources from ' + _avg_ms_name)
                else:
                    print ('  Generating uvfits files for all sources from ' + _inp_params.ms_name)
            if _inp_params.avg_final_ms:
                _vis = _avg_ms_name
            else:
                _vis = _inp_params.ms_name
            spw_part = _inp_params.spwpartition_uvf.split(',')
            sources  = list(_ms_metadata.selected_scans_dict.keys())
            for _field in sources:
                for these_spw in spw_part:
                    _fitsfile = _field + '_calibrated.uvf'
                    if these_spw:
                        _fitsfile += '.spw' + these_spw.replace('~', 'to')
                    if _inp_params.verbose:
                        print ('    Exporting ' + _fitsfile)
                    task_exportuvfits_general(_vis, _fitsfile, _field, these_spw)
                fix_uvfits_AN(_fitsfile, _ms_metadata)
        print ('Done\n')


def task_mstransform_general(_inp_params, outputvis, antenna='', combinespws=False, chanaverage=False, chanbin=1,
                             timeaverage=False, timebin='0s', keepflags=True, correlation=""):
    """ Generic mstransform task. Used by exportdata() for time- and frequency-averaging."""
    tasks.mstransform(vis                   =  _inp_params.ms_name,
                      outputvis             =  outputvis,
                      createmms             =  False,
                      separationaxis        =  "auto",
                      numsubms              =  "auto",
                      tileshape             =  [0],
                      field                 =  "",
                      spw                   =  "",
                      scan                  =  "",
                      antenna               =  antenna,
                      correlation           =  correlation,
                      timerange             =  "",
                      intent                =  "",
                      array                 =  "",
                      uvrange               =  "",
                      observation           =  "",
                      feed                  =  "",
                      datacolumn            =  "corrected",
                      realmodelcol          =  False,
                      keepflags             =  keepflags,
                      usewtspectrum         =  False,
                      combinespws           =  combinespws,
                      chanaverage           =  chanaverage,
                      chanbin               =  chanbin,
                      hanning               =  False,
                      regridms              =  False,
                      mode                  =  "channel",
                      nchan                 =  -1,
                      start                 =  0,
                      width                 =  1,
                      nspw                  =  1,
                      interpolation         =  "linear",
                      phasecenter           =  "",
                      restfreq              =  "",
                      outframe              =  "",
                      veltype               =  "radio",
                      preaverage            =  False,
                      timeaverage           =  timeaverage,
                      timebin               =  timebin,
                      timespan              =  "",
                      maxuvwdistance        =  0.0,
                      docallib              =  False,
                      callib                =  "",
                      douvcontsub           =  False,
                      fitspw                =  "",
                      fitorder              =  0,
                      want_cont             =  False,
                      denoising_lib         =  True,
                      nthreads              =  1,
                      niter                 =  1
                     )
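# Hedged usage sketch for task_mstransform_general() above: exportdata() calls it to write a
# channel- and time-averaged copy of _inp_params.ms_name without autocorrelations
# (antenna='*&'); the output name and averaging values here are purely illustrative.
#   task_mstransform_general(_inp_params, 'calibrated_avg.ms', antenna='*&',
#                            combinespws=False, chanaverage=True, chanbin=4,
#                            timeaverage=True, timebin='10s')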


def task_exportuvfits_general(vis, fitsfile, field, spw=""):
    """ Generic exportuvfits task. Used by exportdata() to export the calibrated data as uvfits file."""
    if isfile(fitsfile):
        os.remove(fitsfile)
    tasks.exportuvfits(vis                =  vis,
                       fitsfile           =  fitsfile,
                       datacolumn         =  "corrected",
                       field              =  field,
                       spw                =  spw,
                       antenna            =  "",
                       timerange          =  "",
                       writesyscal        =  False,
                       multisource        =  False,
                       combinespw         =  True,
                       writestation       =  True,
                       padwithflags       =  True,
                       overwrite          =  True
                      )
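# Hedged usage sketch for task_exportuvfits_general() above: exportdata() calls it once per
# source (and per spw selection); the vis, fitsfile, field, and spw values here are illustrative.
#   task_exportuvfits_general('calibrated_avg.ms', '3C273_calibrated.uvf', '3C273', spw='0~3')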


def task_partition_general(_inp_params, inpvis):
    """ Generic partition task. Used to create a MMS with the SUBMSS along the scan axis."""
    mymsmd = casac.msmetadata()
    mymsmd.open(inpvis)
    # Limit the number of SUBMSS so that the open-file limit (ulimit -n) is not exhausted.
    ulimit_n  = int(resource.getrlimit(resource.RLIMIT_NOFILE)[0])
    num_scans = min(len(mymsmd.scannumbers()), ulimit_n // 30)
    mymsmd.close()
    tasks.partition(vis             = inpvis,
                    outputvis       = _inp_params.ms_name,
                    createmms       = True,
                    separationaxis  = 'scan',
                    numsubms        = num_scans,
                    flagbackup      = True,
                    datacolumn      = 'all',
                    field           = '',
                    scan            = '',
                    spw             = '',
                    antenna         = '',
                    correlation     = '',
                    timerange       = '',
                    intent          = '',
                    array           = '',
                    uvrange         = '',
                    observation     = '',
                    feed            = ''
                   )
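# Hedged usage sketch for task_partition_general() above: turn a regular MS into a multi-MS
# split by scan, written to _inp_params.ms_name; the input MS name here is illustrative.
#   task_partition_general(_inp_params, 'raw_import.ms')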


def task_flagmanager(_inp_params, mode, versionname):
    """
    Use flagmanager to back up and restore the status of the data flags.
    """
    tasks.flagmanager(vis         = _inp_params.ms_name,
                      mode        = mode,
                      versionname = versionname,
                      oldname     = "",
                      comment     = "",
                      merge       = "replace"
                     )
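# Hedged usage sketch for task_flagmanager() above: save a flag backup before a flagging step
# and restore it afterwards; the version name follows the flags.<backuptype>_<N> convention
# read back by get_flagversions() below, but is otherwise illustrative.
#   task_flagmanager(_inp_params, 'save',    'flagdata_1')
#   task_flagmanager(_inp_params, 'restore', 'flagdata_1')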


def get_flagversions(_inp_params, backuptype):
    """
    For a given backuptype (typically 'applycal' or 'flagdata'), looks through all flag backups
    in ms_name.flagversions/flags.backuptype_* and returns
    the latest backup and a list of all backups (or None, None if no backups exist).
    """
    backup0       = _inp_params.ms_name + '.flagversions/flags.'
    these_backups = glob(backup0 + backuptype + '_*')
    if these_backups:
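        # Backup directories end in '_<N>'; the highest index marks the most recent backup.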
        _latest       = [int(tb.split('_')[-1]) for tb in these_backups]
        _latest       = sorted(_latest)[-1]
        latest_backup = backuptype + '_' + str(_latest)
        return latest_backup, these_backups
    else:
        return None, None
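# Hedged example for get_flagversions() above: with <ms_name>.flagversions/flags.applycal_1 and
# flags.applycal_2 on disk, get_flagversions(_inp_params, 'applycal') returns
# ('applycal_2', [the full paths of both backups]); with no matching backups it returns (None, None).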


def only_keep_latest_flagver(_inp_params, backuptype):
    """Deletes all flagversions except the latest one for the specified backuptype."""
    if _inp_params.only_single_flagbackup:
        latest_flags, all_flags = get_flagversions(_inp_params, backuptype)
        # get_flagversions() returns (None, None) when no backups of this type exist yet.
        if not all_flags:
            return
        for flagver in all_flags:
            if latest_flags not in flagver and _inp_params.restore_init_flags not in flagver:
                rm_dir_if_present(flagver)


def restore_init_flags(_inp_params, restore_method):
    """
    If _inp_params.restore_init_flags is set:
      for restore_method=='a', restores the flags to the latest 'applycal' backup
        (i.e. the flag state saved before the last applycal run);
      otherwise, restores the flags to the backup named in restore_init_flags if it exists,
        or saves the current flags under that name if it does not exist yet.
    """
    if _inp_params.restore_init_flags:
        if restore_method == 'a':
            _backup, _ = get_flagversions(_inp_params, 'applycal')
            print('\nRestoring flags to the version before\n  ' + _backup + ' ...')
            task_flagmanager(_inp_params, 'restore', _backup)
        else:
            _backup = _inp_params.ms_name + '.flagversions/flags.' + _inp_params.restore_init_flags
            if os.path.exists(_backup):
                print('\nRestoring flags to initial version from\n  ' + _backup + ' ...')
                task_flagmanager(_inp_params, 'restore', _inp_params.restore_init_flags)
            else:
                print('\nSaving initial flag version to\n  ' + _backup + ' ...')
                task_flagmanager(_inp_params, 'save', _inp_params.restore_init_flags)
        print('Done\n')
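# Hedged usage sketch for restore_init_flags() above: restore_method='a' rolls the flags back to
# the latest 'applycal' backup; any other value saves or restores the flag version named in
# _inp_params.restore_init_flags (the 'b' value here is illustrative).
#   restore_init_flags(_inp_params, 'a')
#   restore_init_flags(_inp_params, 'b')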
