#!/usr/bin/env python
#
# Copyright (C) 2017 Michael Janssen
#
# This library is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# You should have received a copy of the GNU Library General Public License
# along with this library; if not, write to the Free Software Foundation,
# Inc., 675 Massachusetts Ave, Cambridge, MA 02139, USA.
#
""" Diagnostic utilities: Generating logs, plots, and metadata. """
from collections import Counter
from collections import defaultdict
import os
import sys
import copy
import glob
import shutil
import datetime
import itertools
#matplotlib.use('Agg')
import distutils.spawn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
import matplotlib.backends.backend_pdf
import pipe_modules.auxiliary as auxiliary
from pipe_modules.default_casa_imports import *


# Input parameters that ms_metadata depends on. If any of them differs from the values stored
# alongside a previously saved ms_metadata object, that object is stale and ms_metadata is
# determined from scratch again (compared in check_for_metadata_reload).
global_ms_metadata_checkparams = [
    'science_target',
    'calibrators_instrphase',
    'calibrators_bandpass',
    'calibrators_rldly',
    'calibrators_dterms',
    'calibrators_phaseref',
    'num_selected_scans',
    'refant',
]

# Default plot settings; set_plot_sizes() may override them from the pc_fig_* input parameters.
global_plt_fig_xsize  = 18    # figure width (inches, cf. pc_fig_xsize_inches)
global_plt_fig_ysize  = 14    # figure height (inches, cf. pc_fig_ysize_inches)
global_plt_fontsize   = 25    # font size, also applied to legends
global_plt_markersize = 125   # marker size for scatter-type plots

def set_plot_sizes(_inp_params):
    """
    Override the module-wide plot settings from the optional pc_fig_* input parameters and
    push the figure size and legend font size into matplotlib's rcParams.

    Parameters that are not set (according to auxiliary.is_set) keep their current value.
    """
    global global_plt_fig_xsize, global_plt_fig_ysize, global_plt_fontsize, global_plt_markersize
    if auxiliary.is_set(_inp_params, 'pc_fig_xsize_inches'):
        global_plt_fig_xsize = _inp_params.pc_fig_xsize_inches
    if auxiliary.is_set(_inp_params, 'pc_fig_ysize_inches'):
        global_plt_fig_ysize = _inp_params.pc_fig_ysize_inches
    if auxiliary.is_set(_inp_params, 'pc_fig_fontsize'):
        global_plt_fontsize = _inp_params.pc_fig_fontsize
    if auxiliary.is_set(_inp_params, 'pc_fig_markersize'):
        global_plt_markersize = _inp_params.pc_fig_markersize
    # Modify the stored figsize list in place, then hand it back to rcParams for validation.
    figsize = plt.rcParams["figure.figsize"]
    figsize[0], figsize[1] = global_plt_fig_xsize, global_plt_fig_ysize
    plt.rcParams["figure.figsize"] = figsize
    plt.rcParams.update({'legend.fontsize': global_plt_fontsize})


class ms_metadata(object):
    """
    Stores useful metadata about the measurement set:
    Correlator integration time (also called accumulation period or exposure time).
    Correlations in the order of the data in the ms table [RR,RL,LR,LL].
    Number of channels in each spw and frequencies of each channel.
    Sourcenames with ordered IDs as [[sources],[ids]].
    Antennanames with ordered IDs as [[antennas],[ids]] (a dict is not used as it is not meant to be used to get keys via values).
    Mount type of each station.
    Ndarray of scans per source (axis 0), ordered, strings.
    Array of number of visibilities in each scan.
    Ndarray of antennas in each scan (axis 0), ordered, integer IDs.
    Ndarray of spw per antenna (axis 1) per scan (axis 0), ordered, integers: self.spwds.
    1D list of all spw present (assuming that all antennas share the same spw): self.all_spwds.
    Also stores a set of useful scans per source for diagnostic plots, flagging,
        and tuning of fringe-fit parameters, see self.get_selected_scans.
    And the length of each scan in a dict accessed via yield_scan_length().
    Lastly, stores the most common antennas across all scans in the experiment. Could be used as refants.
    """
    def __init__(self, _inp_params, _correlations, _sourcenames, _sourceIDs, _antennanames, _antennaIDs, _all_spwds, _channels,
                 _channels_nu, _scans, _antennas, _spwds, _scan_lengths, _scan_nRows, _scans_start, _scans_end,
                 _integration_time):
        """
        returns [[scans for source1], [scans for source2], ...]
                [[antennas in scan1], [antennas in scan2], ...]
                [[[spwds for antenna1 in scan 1], [spwds for antenna2 in scan1], ...],
                    [[spwds for antenna1 in scan 2], [spwds for antenna2 in scan2], ...] , ... ]
                        i.e. spwds[scan_i][antenna_i] gives an array of spw for (scan_i, antenna_i)
        """
        print ('\nInitializing metadata collection...')
        self.acp                                       = _integration_time
        self.scans_start                               = _scans_start
        self.scans_end                                 = _scans_end
        self.visibilities                              = _scan_nRows
        self.correlations                              = _correlations
        self.all_spwds                                 = _all_spwds
        self.channels                                  = _channels
        self.channels_nu                               = _channels_nu
        self.sourcenames                               = _sourcenames
        self.sourceIDs                                 = _sourceIDs
        self.antennanames                              = _antennanames
        self.antennaIDs                                = _antennaIDs
        #Scan numbers are stored as (unicode) strings:
        self.scans                                     = [sc.astype('U13') for sc in _scans]
        self.all_scans                                 = self.yield_scans()
        self.antennas                                  = _antennas
        self.spwds                                     = _spwds
        self.selected_scans, self.selected_scans_dict  = self.get_selected_scans(_inp_params)
        all_ants                                       = auxiliary.flatten_list(self.antennas)
        ants_counter                                   = Counter([int(ant) for ant in all_ants])
        self.antenna_set                               = list(set(all_ants))
        #Comma-separated antenna names, ordered by how many scans each antenna participates in:
        self.most_common_ants                          = ','.join([self.yield_antname(ant[0])
                                                                   for ant in ants_counter.most_common()
                                                                  ])
        self.scan_length                               = {}
        for sc, scl in zip(self.all_scans, _scan_lengths):
            self.scan_length[sc] = scl
        self.full_bandwidth                            = 0
        self.smallest_GHzdatafreq                      = float("inf")
        for spw in self.all_spwds:
            self.full_bandwidth += self.yield_bandwidth(spw)
            chan_GHzminfreq      = min(self.yield_freq_of_channels_in_spw(spw)) * 1.e-9
            if chan_GHzminfreq < self.smallest_GHzdatafreq:
                self.smallest_GHzdatafreq = chan_GHzminfreq
        #First channel frequency of the first spw:
        self.reference_frequency                       = self.yield_freq_of_channels_in_spw(self.all_spwds[0])[0]
        #Snapshot of the input parameters that, when changed, invalidate this stored metadata:
        self.checkparams                               = {}
        for checkparam in global_ms_metadata_checkparams:
            self.checkparams[checkparam] = getattr(_inp_params, checkparam)
        self.mounttypes                                = get_station_mounts(_inp_params)
        self.all_sources                               = self.yield_calibrators(_inp_params)
        _sci_sources                                   = self.yield_science_targets(_inp_params)
        if _sci_sources:
            self.all_sources.extend(_sci_sources)
        print ('Done\n')

    def got_something(self, x):
        """Return x if it is truthy, otherwise raise AttributeError with a hint to re-determine the metadata."""
        if x:
            return x
        else:
            raise AttributeError('Got an empty selection.\nMaybe you need to run the pipeline again with the -m option.')

    @staticmethod
    def _split_sources(_sources):
        """
        Turn a comma-separated source string (e.g. _inp_params.science_target) into a list of source names,
        so that membership tests compare complete names instead of substrings.
        Values without a split method (e.g. None or an already split list) are returned unchanged.
        """
        if hasattr(_sources, 'split'):
            return _sources.split(',')
        return _sources

    def yield_science_targets(self, _inp_params):
        """
        Return the selected sources that are listed in _inp_params.science_target (exact name match),
        or None if the science targets cannot be searched (e.g. science_target is None).
        """
        _science_targets = self._split_sources(_inp_params.science_target)
        try:
            _result = [src for src in self.selected_scans_dict.keys() if src in _science_targets]
            return _result
        except TypeError:
            return None

    def yield_calibrators(self, _inp_params):
        """
        Return the selected sources that are not listed in _inp_params.science_target (exact name match).
        If the science targets cannot be searched (e.g. science_target is None), all selected sources are returned.
        """
        _science_targets = self._split_sources(_inp_params.science_target)
        try:
            _result = [src for src in self.selected_scans_dict.keys() if src not in _science_targets]
        except TypeError:
            _result = [src for src in self.selected_scans_dict.keys()]
        return _result

    def yield_phaseref(self, _inp_params, sci_field):
        """
        Return the phase-reference calibrator paired with science target sci_field.
        The pairing is positional: the n-th entry of calibrators_phaseref belongs to the n-th science_target.
        Returns None if no phase-reference calibrators are set, sci_field is not a listed science target,
        or there is no phase-reference entry at its position.
        """
        if auxiliary.is_set(_inp_params, 'calibrators_phaseref'):
            science_src = _inp_params.science_target.split(',')
            phref_src   = _inp_params.calibrators_phaseref.split(',')
            try:
                #ValueError: sci_field is not a science target; IndexError: fewer phaseref cals than targets.
                selection = phref_src[science_src.index(sci_field)]
            except (ValueError, IndexError):
                selection = None
            return selection
        else:
            return None

    def yield_channels_in_spw(self, _spw):
        return self.channels[self.all_spwds.index(_spw)]

    def yield_freq_of_channels_in_spw(self, _spw):
        return self.channels_nu[self.all_spwds.index(_spw)]

    def yield_channel_spacing(self, _spw):
        """Absolute frequency difference between the first two channels of _spw (needs >= 2 channels)."""
        these_channel_freqs = self.yield_freq_of_channels_in_spw(_spw)
        return np.abs(these_channel_freqs[1] - these_channel_freqs[0])

    def yield_bandwidth(self, _spw):
        """Frequency span of _spw's channels plus one channel spacing."""
        channelfreqs = self.yield_freq_of_channels_in_spw(_spw)
        return max(channelfreqs) - min(channelfreqs) + self.yield_channel_spacing(_spw)

    def yield_sourcename(self, _id):
        """Map a source ID to its name; input that cannot be converted to int is returned unchanged."""
        try:
            isid = int(_id)
            return self.sourcenames[self.sourceIDs.index(_id)]
        except ValueError:
            return _id

    def yield_sourceID(self, _sourcename):
        """Map a source name to its ID; input that is already an integer-like ID is returned unchanged."""
        try:
            isid = int(_sourcename)
            return _sourcename
        except ValueError:
            return self.sourceIDs[self.sourcenames.index(_sourcename)]

    def yield_sourcename_from_scan(self, _scan, replacedot=False):
        """
        Return the name of the source observed in scan _scan.
        With replacedot=True, dots in the name are replaced by commas.
        Raises ValueError if the scan is not present in the MS.
        """
        indx = None
        for i, scans in enumerate(self.scans):
            if str(_scan) in scans:
                indx = i
                break
        if indx is None:
            raise ValueError('scan '+str(_scan)+' is not present in the MS.')
        _this_srcname = self.sourcenames[indx]
        if replacedot:
            _this_srcname = _this_srcname.replace('.', ',')
        return _this_srcname

    def yield_antname(self, _id):
        """Map an antenna ID to its name; input that cannot be converted to int is returned unchanged."""
        try:
            isid = int(_id)
            return self.antennanames[self.antennaIDs.index(_id)]
        except ValueError:
            return _id

    def yield_antID(self, _antennaname):
        """Map an antenna name to its ID; input that is already an integer-like ID is returned unchanged."""
        try:
            isid = int(_antennaname)
            return _antennaname
        except ValueError:
            return self.antennaIDs[self.antennanames.index(_antennaname)]

    def yield_antmount(self, _antennaname):
        """
        Return the mount type of an antenna (given by name or ID).
        Raises AttributeError (via got_something) if the stored mount type is empty.
        NOTE(review): a missing antenna in self.mounttypes raises KeyError, which is not caught here.
        """
        antname = self.yield_antname(_antennaname)
        try:
            return self.got_something(self.mounttypes[antname])
        except AttributeError:
            return self.got_something(None)

    def yield_scan_length(self, _scan):
        """Scan length rounded to an integer multiple of the integration time self.acp."""
        #Add positive round bias for python 2-3 consistency.
        return float(self.acp) * round(float(self.scan_length[str(_scan)])/float(self.acp) + 1.e-13)

    def yield_selected_scans(self, _sourcename=''):
        """Selected scans (as ints) of _sourcename, or of all selected sources if no source name is given."""
        if not _sourcename:
            scans = []
            for src in self.selected_scans_dict.keys():
                scans.extend(int(scan) for scan in self.selected_scans_dict[src].split(','))
            return scans
        else:
            return [int(scan) for scan in self.selected_scans_dict[_sourcename].split(',')]

    def yield_scans(self, _sourcename='', _IDs=None, do_not_squeeze_single_source=False):
        """
        all scans if not sourcename is given
        for 'src1, src2, ...' (or IDs) as input, returns a list of scans per src (even if only a single src is given)
        the same logic is applied to the other yield functions
        """
        if _sourcename:
            if isinstance(_sourcename, list):
                _sourcename = ','.join(_sourcename)
            scans = []
            for src in _sourcename.strip().split(','):
                scans.append(self.scans[self.sourcenames.index(src)])
            #NOTE(review): len(_sourcename) is the length of the joined string, not the number of sources,
            #  so this squeezes for virtually any input unless do_not_squeeze_single_source is set.
            if not do_not_squeeze_single_source and len(_sourcename) > 1:
                scans = np.squeeze(scans)
            try:
                forgiveme = len(scans)
            except TypeError:
                #a 0-d array remains after squeezing a single scan:
                scans = [str(scans)]
            return scans
        elif _IDs:
            scans = []
            for srcID in _IDs:
                scans.append(self.scans[self.sourceIDs.index(srcID)])
            scans = np.squeeze(scans)
            try:
                forgiveme = len(scans)
            except TypeError:
                scans = [str(scans)]
            return scans
        else:
            return auxiliary.flatten_list(self.scans)

    def yield_antennas(self, _scan=''):
        """Antennas per scan for a list or comma-separated string of scans (squeezed); all antennas if no scan given."""
        if _scan:
            antennas = []
            if isinstance(_scan, list):
                for scan in _scan:
                    antennas.append(self.antennas[self.all_scans.index(scan)])
                return np.squeeze(antennas)
            else:
                for scan in _scan.strip().split(','):
                    antennas.append(self.antennas[self.all_scans.index(scan)])
                return np.squeeze(antennas)
        else:
            return self.antennas

    def yield_numvisi(self, _scan=''):
        """Number of visibilities per scan (a single int for a single scan); all values if no scan given."""
        if _scan:
            if not isinstance(_scan, list):
                _scan = _scan.strip().split(',')
            numvis = [int(self.got_something(self.visibilities[self.all_scans.index(sc)])) for sc in _scan]
            if len(numvis) == 1:
                numvis = numvis[0]
            return numvis
        else:
            return self.visibilities

    def yield_scantime(self, _scan, scan_part='middle'):
        """
        gives time of a specific scan: its 'start', 'end', or (default) the midpoint between both
        """
        starttime = self.got_something(self.scans_start[self.all_scans.index(_scan)])
        endtime   = self.got_something(self.scans_end[self.all_scans.index(_scan)])
        if scan_part == 'start':
            return starttime
        elif scan_part == 'end':
            return endtime
        else:
            return 0.5 * (starttime + endtime)

    def yield_spwds(self, _scan=''):
        """
        gives array of spwd per antenna in the specified scan or for all scans
        """
        if _scan:
            spwds = []
            if isinstance(_scan, list):
                for scan in _scan:
                    spwds.append(self.spwds[self.all_scans.index(scan)])
                return np.squeeze(spwds)
            else:
                for scan in _scan.strip().split(','):
                    spwds.append(self.spwds[self.all_scans.index(scan)])
                return np.squeeze(spwds)
        else:
            return self.spwds

    def get_selected_scans(self, _inp_params):
        """
        Get (source, scans) from self.get_source_scans() for
        [[calibrator1, calibrator2,...], [[science_target1],[science_target2],...]]
        Also return the same information in a simple dictionary.
        Raises ValueError if scans were selected manually via _inp_params.selected_scans (not supported).
        """
        if _inp_params.selected_scans:
            raise ValueError('Sorry the functionality to manually enter scans is not supported yet. ' \
                             'You would have to do it here in the code and return tuples of (source, scans).'
                            )
            #_selected_scans = ...
        else:
            _selected_scans_dict = {}
            _selected_scans      = [[],[]]
            _all_sources         = [[],[]]
            _all_calibrators     = []
            if auxiliary.is_set(_inp_params, 'calibrators_instrphase'):
                for source in _inp_params.calibrators_instrphase.split(','):
                    _all_calibrators.append(source)
            if auxiliary.is_set(_inp_params, 'calibrators_bandpass'):
                for source in _inp_params.calibrators_bandpass.split(','):
                    _all_calibrators.append(source)
            if auxiliary.is_set(_inp_params, 'calibrators_rldly'):
                for source in _inp_params.calibrators_rldly.split(','):
                    _all_calibrators.append(source)
            if auxiliary.is_set(_inp_params, 'calibrators_dterms'):
                for source in _inp_params.calibrators_dterms.split(','):
                    _all_calibrators.append(source)
            if auxiliary.is_set(_inp_params, 'calibrators_phaseref'):
                for source in _inp_params.calibrators_phaseref.split(','):
                    _all_calibrators.append(source)
            #a calibrator may serve multiple purposes; keep it only once:
            _all_sources[0] = list(set(_all_calibrators))
            if auxiliary.is_set(_inp_params, 'science_target'):
                _all_sources[1] = _inp_params.science_target.split(',')
            src_diff = set(self.sourcenames).symmetric_difference(set(auxiliary.flatten_list(_all_sources)))
            if src_diff:
                print ('  Warning: There is a difference between all sources available\n  and the ones specified ' \
                       'as calibrators and science targets:\n    ' + str(src_diff) + '\n  ' \
                       'This either means you are not using all available sources\n  or you specified sources ' \
                       'which were not observed.\n  Maybe due to a typo?\n' \
                       '  In the former case the sources not specified will not be properly calibrated.\n' \
                       '  In the latter case the code should exit with a ValueError now.')
            for source in _all_sources[0]:
                source_scans = self.get_source_scans(_inp_params, source)
                _selected_scans[0].append((source, source_scans))
                _selected_scans_dict.update({source:source_scans})
            for source in _all_sources[1]:
                source_scans = self.get_source_scans(_inp_params, source)
                _selected_scans[1].append([(source, source_scans)])
                _selected_scans_dict.update({source:source_scans})
        return _selected_scans, _selected_scans_dict

    def get_source_scans(self, _inp_params, _source):
        """
        Get scans on a specific source where as many antennas as possible are present
            and which are evenly distributed over the observation while making sure that every antenna is present.
        If possible, should grab around ~_inp_params.num_selected_scans scans per source.
        Returns the selected scans as a comma-separated string.
        Raises OverflowError if no valid selection is found within the iteration limit.
        """
        _scans         = self.yield_scans(_source)
        _antennas      = [self.yield_antennas(scan) for scan in _scans]
        len_ants       = [len(ants) for ants in _antennas]
        indx_scan      = list(range(len(_scans)))
        #indices of _inp_params.num_selected_scans scans evenly distributed in time:
        target_scans   = auxiliary.get_evenly_spaced_elements(indx_scan, min(_inp_params.num_selected_scans, len(_scans)))
        #indices of scans sorted by numbers of antennas present:
        u_ants, u_indx = np.unique(len_ants, return_index=True)
        #2d array with [[scans with most antennas], ..., [scans with least antennas]]:
        most_ants      = auxiliary.make_ndlist(len(u_indx))
        #most_ants: list per number of antennas and each of these lists is [(scan_number, list_of_antennas_in_that_scan)]
        for i, (lant, ants) in enumerate(zip(len_ants, _antennas)):
            for j,uant in enumerate(list(reversed(u_ants))):
                if lant==uant:
                    most_ants[j].append((i, ants))
        most_ants_orig = copy.deepcopy(most_ants)
        #if we got all antennas available for that source:
        got_all        = False
        use_all_scans  = False
        #can pick a scan with less antennas if the scan with more antennas is very far away from a target_scan:
        max_dist       = int(1 + 0.3 * len(len_ants))
        got_scans      = []
        safety         = 0
        while target_scans or not got_all:
            if not target_scans:
                #we need more than the inp_params.num_selected_scans to cover all antennas:
                use_all_scans = True
            for j,mant in enumerate(most_ants):
                if mant:
                    if use_all_scans:
                        c_closest_match, c_difference, got_one = auxiliary.get_closest_match([item[0] for item in mant],
                                                                                             indx_scan, got_scans)
                    else:
                        c_closest_match, c_difference, got_one = auxiliary.get_closest_match([item[0] for item in mant],
                                                                                             target_scans, got_scans)
                    if c_closest_match:
                        closest_match = c_closest_match
                        difference    = c_difference
                    #take the scan with the most antennas
                    #  as we do not have all antennas yet or because we already have a good enough match:
                    if got_one:
                        if not got_all or difference < max_dist:
                            break
            try:
                #j (from the loop above) is the antenna-count group in which closest_match was found:
                got_index    = [a[0] for a in most_ants[j]].index(closest_match[0])
                got_antennas = most_ants[j][got_index][1]
                if use_all_scans:
                    indx_scan.remove(closest_match[1])
                else:
                    target_scans.remove(closest_match[1])
                got_scans.append(closest_match[0])
                #delete antennas from most_ants if we got them already and afterwards remove scans with no antennas left:
                for j,mant in enumerate(most_ants):
                    most_ants[j] = [(t[0], auxiliary.subtract_list(t[1], got_antennas)) for t in most_ants[j]]
                    most_ants[j] = [t for t in most_ants[j] if t[1]]
                if not any(most_ants):
                    #got all available antennas and can start over again by picking from the scans with the most antennas
                    got_all   = True
                    most_ants = copy.deepcopy(most_ants_orig)
            except ValueError:
                #Max_dist was too restrictive(small) and we missed all possible scans. Go again without distance restriction.
                max_dist += 1e9
            safety += 1
            if safety > 238:
                #Reset to all scans as last resort.
                most_ants = copy.deepcopy(most_ants_orig)
            if safety > 1337:
                raise OverflowError('Cannot get selected scans for target scans = \n'+str(target_scans)+ \
                                    '\nor all scans = \n'+str(indx_scan)+'\nfrom available scans = \n'+str(most_ants))
        return ','.join(_scans[selected_scan] for selected_scan in got_scans)

    def get_source_scans_old(self, _inp_params, _source):
        """
        Get scans on a specific source where as many antennas as possible are present
            and which are evenly distributed over the observation.
        If possible, should grab around ~_inp_params.num_selected_scans scans per source.
        (Superseded by get_source_scans, which also ensures that every antenna is covered.)
        """
        _scans    = self.yield_scans(_source)
        #number of antennas in each scan:
        _antennas = [len(self.yield_antennas(scan)) for scan in _scans]
        #sorted list of all unique(different) number of antennas from all _scans
        _len_ants = sorted(list(set(_antennas)), reverse=True)
        _num_ants = auxiliary.make_ndlist(len(_len_ants))
        #order scans by number of antennas present:
        for _scan, _ant in zip(_scans, _antennas):
            for i,lant in enumerate(_len_ants):
                if _ant == lant:
                    _num_ants[i].append(_scan)
        _selected_scans = []
        #select scans distributed over the whole observations, preferably where most antennas are present
        for nant in _num_ants:
            if len(_selected_scans) < _inp_params.num_selected_scans:
                n_scans = len(nant)
                if n_scans >= _inp_params.num_selected_scans:
                    _selected_scans.extend(auxiliary.get_evenly_spaced_elements(nant, _inp_params.num_selected_scans))
                    break
                else:
                    _selected_scans.extend(nant)
        return ','.join(_selected_scans)


def check_for_metadata_reload(_inp_params):
    """
    Decide whether a previously stored ms_metadata object is stale.

    Returns False when no stored metadata file exists at _inp_params.store_ms_metadata.
    Otherwise the stored object is loaded and True is returned as soon as any parameter listed in
    global_ms_metadata_checkparams differs from the current _inp_params value, or when the stored
    object cannot be loaded/compared (AttributeError, KeyError, ValueError, UnicodeDecodeError).
    Returns False if all compared parameters match.
    """
    if not auxiliary.isfile(_inp_params.store_ms_metadata):
        return False
    try:
        stored_metadata = auxiliary.store_object(_inp_params.store_ms_metadata)
        return any(stored_metadata.checkparams[param] != getattr(_inp_params, param)
                   for param in global_ms_metadata_checkparams)
    except (AttributeError, KeyError, ValueError, UnicodeDecodeError):
        return True


def file_or_gui(_inp_params, _plotfile, _to_path=None):
    """
    Decide whether a plot goes to a file or to the GUI and build the full output path.

    Returns (showgui, figfile):
      - If _plotfile is empty/None: (True, None), i.e. show the plot in the GUI.
      - Otherwise: (False, figfile) where figfile is
        _inp_params.workdir + _inp_params.diagdir + <prepended path> + _plotfile,
        and the directory of figfile is created.
    _to_path is an optional list of strings ['a', 'b/', 'c', 'd',...] (or a single string), which will be
    converted to separate folders, i.e. if they do not contain a trailing slash, one will be added.
      Then the path consisting of all items will be prepended to _plotfile:
        _plotfile -> a/b/c/d/.../_plotfile
        The logic behind this is to be able to form a path from multiple variables that does not break down
        if it was forgotten to add a '/' to one of the variables.
    """
    if not _plotfile:
        #No output file requested: the caller should display the figure (previously raised UnboundLocalError).
        return True, None
    if _to_path:
        if not isinstance(_to_path, list):
            _to_path = [_to_path]
        #os.path.join(tp, '') appends a trailing separator only where one is missing:
        prepend_to_plotfile = ''.join([os.path.join(tp, '') for tp in _to_path])
    else:
        prepend_to_plotfile = ''
    figfile = _inp_params.workdir + _inp_params.diagdir + prepend_to_plotfile + _plotfile
    auxiliary.makedir(os.path.dirname(figfile))
    return False, figfile


def get_any_diagdir(_inp_params):
    """
    Build a glob pattern matching every diagnostics directory derived from <diagdir>.

    The current datetime is attached to the original <diagdir> (from constants.inp) after an
    underscore, so everything from the first underscore on is dropped and replaced by '*/'.
    Useful for the -l and -d command line options, for example.
    The 'original' diagdir set in constants.inp is therefore not allowed to contain an underscore.
    """
    base_name, _, _ = _inp_params.diagdir.strip('/').partition('_')
    return base_name + '*/'


def get_last_diagdir(_inp_params):
    """
    Return the diagnostics folder (with datetime attached) from the most recent pipeline run,
    as chosen by auxiliary.get_latest_file among all folders matching get_any_diagdir().
    Returns False if no diagnostics folder exists yet.
    """
    candidate_dirs = auxiliary.glob_all(get_any_diagdir(_inp_params))
    if not candidate_dirs:
        return False
    return auxiliary.get_latest_file(candidate_dirs)


def init_diagnostics(_inp_params, _input_files):
    """
    Create the diagnostics directory and write the first logs.

    If _inp_params.diag_inp_params is set, the input files are copied to <diagdir>/used_input_files/
    and every input parameter is written as 'key = value' lines, showing how python stored them internally.
    If _inp_params.diag_cmd_args is set, the command line (sys.argv) is written to that log,
    appended on a new line if the log already exists.
    """
    auxiliary.makedir(_inp_params.diagdir)
    if _inp_params.verbose:
        print("\nWriting this run's diagnostics to\n" + str(os.path.abspath(_inp_params.diagdir)))
    if _inp_params.diag_inp_params:
        auxiliary.copyfiles(_input_files, _inp_params.diagdir + 'used_input_files/')
        logfile = _inp_params.diagdir + _inp_params.diag_inp_params
        #'with' guarantees the file is closed even if writing fails:
        with open(logfile, 'w') as fout:
            for key, val in _inp_params._asdict().items():
                fout.write(key + ' = ' + str(val) + '\n')
    if _inp_params.diag_cmd_args:
        logfile   = _inp_params.diagdir + _inp_params.diag_cmd_args
        do_append = auxiliary.isfile(logfile)
        with open(logfile, 'a' if do_append else 'w') as fout:
            if do_append:
                #separate this run's command line from the previous one:
                fout.write('\n')
            fout.write(' '.join(sys.argv))


def station_scan_flag_perc(_inp_params, _ms_metadata, stationID, scanID):
    """
    Return the fraction of *unflagged* (good) data on all baselines to a station in a scan.

    This is useful when checking if there is enough unflagged data to warrant using the station as refant for the scan.
    Returns 0 if the station is not part of the scan or if no FLAG rows match the selection.
    """
    if stationID not in _ms_metadata.yield_antennas(str(scanID)):
        return 0
    station   = str(stationID)
    selection = '(ANTENNA1==' + station + ' || ANTENNA2==' + station + ') && SCAN_NUMBER==' + str(scanID)
    flags     = auxiliary.read_CASA_table(_inp_params, 'FLAG', selection).ravel()
    n_values  = len(flags)
    if n_values == 0:
        #no data selected
        return 0
    return 1. - float(sum(flags)) / float(n_values)


def print_flags(_inp_params, _ms_metadata):
    """
    Print amount of flagged data to a file.

    For every selected source (calibrators and science targets from _ms_metadata.selected_scans), the FLAG
    column is read for each selected scan and each antenna present in that scan (all baselines to that antenna).
    The flagged percentage per station is then written to <diagdir> + _inp_params.diag_flags.
    Does nothing if _inp_params.diag_flags is not set.
    """
    if not _inp_params.diag_flags:
        return
    logfile = _inp_params.diagdir + _inp_params.diag_flags
    print('\nWriting an overview of flagged data percentages per station to\n  '+logfile+'\n    This can take a while...')
    #'with' guarantees the log is closed even if reading the CASA table fails:
    with open(logfile, 'w') as _flags:
        _flags.write('Amount of flagged data on all baselines to a certain station per source:\n')
        all_antids   = _ms_metadata.antennaIDs
        all_selected = _ms_metadata.selected_scans[0] + auxiliary.flatten_list(_ms_metadata.selected_scans[1])
        all_sources  = [item[0] for item in all_selected]
        all_scans    = [item[1].split(',') for item in all_selected]
        all_antennas = [_ms_metadata.yield_antennas(scans) for scans in all_scans]
        #handle the case where a certain source has only a single scan (antennas got squeezed to a 1D array):
        for i, ants in enumerate(all_antennas):
            if not isinstance(ants[0], np.ndarray):
                all_antennas[i] = [ants]
        for scans, ants_scans, src in zip(all_scans, all_antennas, all_sources):
            _flags.write('\n -- ' + str(src) + ' -- \n')
            flagamount = dict((antid, 0) for antid in all_antids)
            dataamount = dict((antid, 0) for antid in all_antids)
            for scan, ants_scan in zip(scans, ants_scans):
                for ant in ants_scan:
                    _select    = '(ANTENNA1==' + str(ant) + ' || ANTENNA2==' + str(ant) + ')' + \
                                 '&& FIELD_ID==' + str(_ms_metadata.yield_sourceID(src)) + \
                                 ' && SCAN_NUMBER==' + str(scan)
                    flagarray  = auxiliary.read_CASA_table(_inp_params, 'FLAG', _select)
                    flagarray  = flagarray.ravel()
                    #skip empty selections (no data for this antenna in this scan):
                    if flagarray.size:
                        flagamount[ant] += sum(flagarray)
                        dataamount[ant] += len(flagarray)
            for antid in all_antids:
                try:
                    antname = _ms_metadata.yield_antname(antid)
                except ValueError:
                    antname = str(antid)+('(unknown name in MS)')
                if dataamount[antid]:
                    antflags = 100. * float(flagamount[antid]) / float(dataamount[antid])
                    flagnote = '{0}: {1:.2f}%'.format(antname, antflags)
                else:
                    flagnote = '{0}: No data for this source'.format(antname)
                _flags.write(flagnote+'\n')
    print('Done\n')


def print_fringe_overview(_inp_params, _ms_metadata, fringe_overview_dict, filename):
    """
    Prints the content of fringe_overview_dict as csv lines to _inp_params.diagdir + filename.
    The fringe_overview_dict object is created by the exhaustive_baseline_search() function in the
    exhaustive_baseline_search module. It is indexed as [source][scan][antenna1][antenna2]; each leaf holds
    'snr', 'flag', and 'fringe' entries indexed by polarization (0 and 1), where fringe[pol] is (delay, rate).
    One line is written per source, scan, baseline, and polarization ID ('1' and '2'). An antenna1 entry
    without any baselines gets a single line with NA values.
    Nothing is done when filename is empty.
    """
    if not filename:
        return
    logfile = _inp_params.diagdir + filename
    cals    = _ms_metadata.yield_calibrators(_inp_params)
    with open(logfile, 'w') as fout:
        fout.write('#Scan,Time,Source,Calibrator,Antenna1,Antenna2,PolarizationID,Detection,SNR,FFT-delay[ns],FFT-rate[s/s]')
        for src in sorted(fringe_overview_dict):
            ssrc         = str(src)
            scal         = 'Y' if ssrc in cals else 'N'
            src_overview = fringe_overview_dict[src]
            sorted_scans = list(src_overview.keys())
            #sorts in place (natural order of the scan numbers):
            auxiliary.natural_sort_Ned_Batchelder(sorted_scans)
            for scan in sorted_scans:
                sscan = str(scan)
                stime = str(_ms_metadata.yield_scantime(sscan))
                for ant1 in sorted(src_overview[scan]):
                    sant1         = str(ant1)
                    this_overview = src_overview[scan][ant1]
                    if not this_overview:
                        fout.write('\n{0},{1},{2},{3},{4},NA,NA,NA,NA,NA,NA'.format(sscan, stime, ssrc, scal, sant1))
                        continue
                    for ant2 in sorted(this_overview):
                        sant2    = str(ant2)
                        baseline = this_overview[ant2]
                        snr      = baseline['snr']
                        flag     = baseline['flag']
                        fringe   = baseline['fringe']
                        for pID, spID in enumerate(['1', '2']):
                            psnr = snr[pID]
                            #flagged solutions or zero/empty SNR count as non-detections:
                            if bool(flag[pID]) or not psnr:
                                psdetect = 'N'
                            else:
                                psdetect = 'Y'
                            vals = '\n{0},{1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(sscan, stime, ssrc, scal, sant1,
                                                                                           sant2, spID, psdetect, str(psnr),
                                                                                           str(fringe[pID][0]),
                                                                                           str(fringe[pID][1])
                                                                                          )
                            fout.write(vals)


def task_listobs(_inp_params, _create_new_file=False):
    """
    Executes CASA's listobs task on _inp_params.ms_name, writing to _inp_params.diagdir + diag_listobs.
    Unless _create_new_file is set, a listobs file found via get_any_diagdir() is copied instead of
    running listobs again. Does nothing when diag_listobs is not set.
    """
    if not _inp_params.diag_listobs:
        return
    listobs_name = _inp_params.diag_listobs
    logfile      = _inp_params.diagdir + listobs_name
    print ('\nWriting listobs file to ' + logfile +'...')
    earlier_files = auxiliary.glob_all(get_any_diagdir(_inp_params) + listobs_name)
    if earlier_files and not _create_new_file:
        try:
            shutil.copyfile(earlier_files[0], logfile)
        except shutil.Error:
            #source and target are the same file: it is already there
            pass
    else:
        #no data selection at all: list the whole MS
        no_selection = dict.fromkeys(['spw', 'field', 'antenna', 'uvrange', 'timerange', 'correlation',
                                      'scan', 'intent', 'feed', 'array', 'observation'], '')
        tasks.listobs(vis        = _inp_params.ms_name,
                      selectdata = True,
                      verbose    = True,
                      listfile   = logfile,
                      listunfl   = False,
                      overwrite  = True,
                      **no_selection
                     )
    print ('Done\n')


def get_station_mounts(_inp_params):
    """
    Return a dictionary of all station mounts.
    Maps every antenna NAME to its MOUNT entry, as read from the ANTENNA sub-table of _inp_params.ms_name.
    """
    #rstrip instead of strip: an absolute MS path must keep its leading '/'.
    antt = _inp_params.ms_name.rstrip('/') + '/ANTENNA'
    mytb = casac.table()
    #the sub-table is only read, so there is no need to open it for writing:
    mytb.open(antt, nomodify=True)
    try:
        ants   = mytb.getcol('NAME')
        mounts = mytb.getcol('MOUNT')
    finally:
        mytb.done()
    return dict(zip(ants, mounts))


def plot_calibration_summary(_inp_params):
    """
    Single file calibration summary as suggested by Kazi Rygl and Elisabetta Liuzzo.
    Creates <diagdir>/CALIBRATION_SUMMARY.pdf with pdflatex. It contains all plots from the individual
    SUMMARY_*.pdf files of each calibration step (one section per step folder), three plots per row.
    Does nothing if there are no SUMMARY files or if pdflatex is not available.
    Temporary LaTeX files are written to the current working directory and are always removed afterwards.
    """
    sumfile = _inp_params.diagdir + 'CALIBRATION_SUMMARY.pdf'
    #sorted, since glob returns files in arbitrary order and the sections should have a reproducible order:
    i_sums  = sorted(glob.glob(_inp_params.diagdir+'*/SUMMARY_*.pdf'))
    if not i_sums:
        return
    elif not distutils.spawn.find_executable('pdflatex'):
        return
    print ('\nGenerating a calibration summary pdf file...')
    if auxiliary.is_set(_inp_params, 'pc_file_ext'):
        plt_ext = _inp_params.pc_file_ext
    else:
        plt_ext = 'png'
    #(summary pdf file, number of plots in its folder = pages to include, section name):
    i_pages = []
    for pdf_f in i_sums:
        fpath = os.path.dirname(pdf_f)
        plfs  = glob.glob(fpath+'/*.{0}'.format(plt_ext))
        Nfs   = len(plfs)
        if pdf_f in plfs:
            #with plt_ext=='pdf' the summary file itself matches the glob
            Nfs -= 1
        i_pages.append((pdf_f, Nfs, fpath.split('/')[-1]))
    dummy_tex = auxiliary.random_number_string(6) + '_tmp'
    try:
        with open(dummy_tex+'.tex', 'w') as f:
            f.write(
r"""
\documentclass[12pt]{article}
\usepackage[utf8]{inputenc}
\usepackage{graphicx}
\usepackage[margin=0.8in]{geometry}
\usepackage{fancyhdr}
\pagestyle{fancy}
\renewcommand{\sectionmark}[1]{%
\markboth{\thesection\quad #1}{}}
\fancyhead{}
\fancyhead[L]{\leftmark}
\fancyfoot{}
\fancyfoot[C]{\thepage}
\begin{document}

"""
                   )
            for page, Npages, section in i_pages:
                sectionname = section.replace('_', r'\_')
                #raw strings: '\h' is not a valid Python escape sequence
                f.write(r'\section*{\hfil ' + sectionname + r'\hfil}')
                f.write('\n')
                f.write(r'\markboth{' + sectionname + '}{}')
                f.write('\n')
                for pcounter in range(1, Npages + 1):
                    f.write(r'\includegraphics[width=0.33\textwidth, page='+str(pcounter)+']{'+page+'} ')
                    if (pcounter%3) == 0:
                        f.write(r'\\' + '\n')
                f.write('\n' + r'\newpage' + '\n')
            f.write(
r"""
\end{document}
"""
                   )
        #nonstopmode: never wait for terminal input when LaTeX hits an error
        os.system('pdflatex -interaction=nonstopmode ' + dummy_tex + '.tex > /dev/null 2>&1')
        if not os.path.isfile(dummy_tex+'.pdf'):
            print('  pdflatex could not create the calibration summary pdf file, skipping it.')
            return
        shutil.move(dummy_tex+'.pdf', sumfile)
    finally:
        #remove .tex/.aux/.log/... of this run, also if something went wrong
        for tmpfile in glob.glob(dummy_tex + '.*'):
            os.remove(tmpfile)
    print('\nDone\n')



def plot_calibrated_visibilities(_inp_params, _ms_metadata, _casalogfile=''):
    """
    Plots calibrated visibilities if _inp_params.diag_calib (the plot folder name) is set.
    jplotter is used when diag_calib_use_jplotter is set and either jplotter or singularity is found in the PATH.
    In that case only a shell script with jplotter instructions is prepared, named after _casalogfile
      (must be run outside CASA).
    Otherwise CASA-plotms is used. Afterwards, the single-file calibration summary pdf is generated.
    """
    if not _inp_params.diag_calib:
        return
    print ('\nGenerating plots of calibrated visibilities...')
    plotpath0 = os.path.join(_inp_params.diag_calib, '')
    #None: use plotms; True: jplotter executable in PATH; False: jplotter via singularity
    jplotter_in_path = None
    if _inp_params.diag_calib_use_jplotter:
        if distutils.spawn.find_executable('jplotter'):
            jplotter_in_path = True
        elif distutils.spawn.find_executable('singularity'):
            jplotter_in_path = False
    if jplotter_in_path is None:
        plot_calibrated_visibilities_with_plotms(_inp_params, _ms_metadata, plotpath0)
    else:
        make_jplotter_plot(_inp_params, _ms_metadata, plotpath0, _casalogfile, in_path=jplotter_in_path)
    print ('\nDone.\n')
    plot_calibration_summary(_inp_params)


def plot_calibrated_visibilities_with_plotms(_inp_params, _ms_metadata, add_to_dir):
    """
    Makes plotms plots of calibrated data. For each value in the comma-separated _inp_params.diag_calib_yvals,
      plots are made vs time and vs frequency for all correlations and the baselines of the selected scans
      from _ms_metadata. Plots from separate sources go into separate folders (add_to_dir + source + '/').
    Controlled by _inp_params:
      - diag_calib_only_parallel_hands: use only the first and last correlation.
      - diag_calib_only_crosscorr: if not set, also make autocorrelation plots for yval 'amp'.
      - diag_calib_only_to_refant: only baselines to the first refant, otherwise all baselines.
    Antennas for which _ms_metadata.yield_antname raises ValueError (not present in a split MS) are skipped.
    """
    print ('  Did not find jplotter in your path. Will have to use the slower plotms instead...')

    def _plot_baseline(baseline, src, str_scan, corr, yval):
        #one plot vs time and one vs frequency for a single baseline
        plotname0 = add_to_dir + src + '/' + str_scan + '_' + baseline + '_' + corr + '_' + yval
        make_plotsms_freqtime_plot(_inp_params, yval, baseline, str_scan, str(corr), plotname0)

    correlations = _ms_metadata.correlations
    if len(correlations) > 1 and _inp_params.diag_calib_only_parallel_hands:
        correlations = np.take(correlations, indices = [0,-1])
    selected   = _ms_metadata.selected_scans[0] + auxiliary.flatten_list(_ms_metadata.selected_scans[1])
    sources    = [entry[0] for entry in selected]
    scan_lists = [entry[1].split(',') for entry in selected]
    ant_lists  = [_ms_metadata.yield_antennas(scan_list) for scan_list in scan_lists]
    per_scan   = [_ms_metadata.yield_antennas(single) for single in auxiliary.flatten_list(scan_lists)]
    n_antscans = len(auxiliary.flatten_list(per_scan))
    #a source with a single scan yields one flat antenna array: wrap it into a list of per-scan arrays
    ant_lists  = [ants if isinstance(ants[0], np.ndarray) else [ants] for ants in ant_lists]
    yvals      = _inp_params.diag_calib_yvals.split(',')
    n_total    = str(2 * len(yvals) * len(correlations) * n_antscans)
    counter    = 0
    for yval in yvals:
        for corr in correlations:
            for src, scan_list, ants_per_scan in zip(sources, scan_lists, ant_lists):
                for scan, scan_ants in zip(scan_list, ants_per_scan):
                    str_scan = str(scan)
                    for position, ant in enumerate(scan_ants):
                        counter += 2
                        sys.stdout.write('  Plotting '+str(counter)+'/'+n_total+'\r')
                        sys.stdout.flush()
                        try:
                            if yval=='amp' and not _inp_params.diag_calib_only_crosscorr:
                                autoname = _ms_metadata.yield_antname(ant)
                                _plot_baseline(autoname + '&&' + autoname, src, str_scan, corr, yval)
                            if _inp_params.diag_calib_only_to_refant:
                                antname = _ms_metadata.yield_antname(ant)
                                refant0 = _inp_params.refant.split(',')[0]
                                if antname != refant0:
                                    _plot_baseline(antname + '&' + refant0, src, str_scan, corr, yval)
                            else:
                                for other in scan_ants[position+1:]:
                                    baseline = _ms_metadata.yield_antname(ant) + '&' + _ms_metadata.yield_antname(other)
                                    _plot_baseline(baseline, src, str_scan, corr, yval)
                        except ValueError:
                            #ant is not present in the MS (split MS)
                            pass


def make_jplotter_plot(_inp_params, _ms_metadata, _plotlocation, _casalogfile, in_path=False):
    """
    Uses jplotter to make all plots in one go as ps files per scan.
    jplotter is not executed here: the plotting instructions are written to the executable shell script
      _casalogfile + '.plotter.sh', which must be run outside of CASA afterwards.
    Input:
        - _plotlocation: plot folder, relative to _inp_params.diagdir.
        - _casalogfile: the shell script is named after this file.
        - in_path: if True, jplotter is called directly; otherwise it is run through singularity
          (shub://haavee/jiveplot).
    Plots:
        - If diag_calib_make_overviewplots is set, overview plots are made first: amp/phase vs time per source, and
          amp/phase vs frequency of the 'golden scan'. This is the scan with the largest yield_numvisi() value,
          taken preferably from the first calibrators_instrphase source.
        - Amp/phase vs frequency and vs time for every selected scan.
    """
    import stat
    print ('  Note: Will only prepare a file with plotting instructions\n  that will be executed outside of CASA afterwards.')
    print ('    For the default plots, crosses correspond to flagged data.')
    _scans = [str(scan) for scan in _ms_metadata.yield_selected_scans()]
    if not _scans:
        print ('  No selected scans found, nothing to plot with jplotter.')
        return
    plotpath       = _inp_params.diagdir + _plotlocation
    auxiliary.makedir(os.path.dirname(plotpath))
    make_overviews = auxiliary.is_set(_inp_params, 'diag_calib_make_overviewplots')
    multi_corr     = len(_ms_metadata.correlations) > 1
    if make_overviews:
        sscans = []
        for scan in _scans:
            this_numvisi = _ms_metadata.yield_numvisi(scan)
            this_src     = _ms_metadata.yield_sourcename_from_scan(scan)
            sscans.append((this_src, scan, this_numvisi))
        gscans = sscans
        if auxiliary.is_set(_inp_params, 'calibrators_instrphase'):
            ssrc   = _inp_params.calibrators_instrphase.split(',')[0]
            #fall back to all scans if this calibrator has no selected scan (would otherwise raise an IndexError)
            gscans = [scan for scan in sscans if scan[0] == ssrc] or sscans
        #largest numvisi; on ties the last scan in selection order (stable sort)
        golden_scan = sorted(gscans, key=lambda x: x[2])[-1][1]
        filename1 = plotpath + 'OVERVIEWPLOT_VPLOT_{0}'.format(_ms_metadata.all_sources[0])
        filename2 = plotpath + 'OVERVIEWPLOT_POSSM_scan{0}'.format(golden_scan)
        plotcmd0  = 'file {0}; ms {1} column=corrected_data; indexr; '.format(filename1, _inp_params.ms_name)
        plotcmd0 += 'fq */p; '
        plotcmd0 += 'bl cross; '
        plotcmd0 += 'avt none; avc vector; pt anptime; sort bl; y local; '
        if multi_corr:
            plotcmd0 += 'ckey p; '
        plotcmd0 += 'draw points points; show both; '
        plotcmd0 += r"label time : 'time [UTC]'; label phase : '\gF [deg]'; label amplitude : 'Amp [Jy]'; "
        plotcmd0 += 'src {0}; '.format(_ms_metadata.all_sources[0])
        plotcmd0 += 'pl; '
        for src in _ms_metadata.all_sources[1:]:
            filenamepp = plotpath + 'OVERVIEWPLOT_VPLOT_{0}'.format(src)
            plotcmd0  += 'refile {0}; src {1};pl;'.format(filenamepp, src)
        plotcmd0 += 'refile {0}; src none; '.format(filename2)
        plotcmd0 += 'avt vector; avc none; pt anpfreq; sort bl; y local; '
        if multi_corr:
            plotcmd0 += 'ckey p; '
        plotcmd0 += 'draw points points; show both; '
        plotcmd0 += r"label frequency : '\gn [MHz]'; label phase : '\gF [deg]'; label amplitude : 'Amp [Jy]'; "
        plotcmd0 += 'scan {0}; '.format(golden_scan)
        plotcmd0 += 'pl; '
        plotcmd0 += 'fq none; '

    freq_plots = [plotpath+_ms_metadata.yield_sourcename_from_scan(scan, True)+'_scan'+scan+'_plots_vs_freq' for scan in _scans]
    time_plots = [plotpath+_ms_metadata.yield_sourcename_from_scan(scan, True)+'_scan'+scan+'_plots_vs_time' for scan in _scans]
    if make_overviews:
        plotcmd0 += 'refile {0}; '.format(freq_plots[0])
    else:
        plotcmd0 = 'file {0}; ms {1} column=corrected_data; indexr; '.format(freq_plots[0], _inp_params.ms_name)
    if _inp_params.diag_calib_only_parallel_hands:
        plotcmd0 += 'fq */p; '
    plotcmd = plotcmd0
    #plots vs channel per scan
    plotcmd += 'avt vector; avc none; scan {0}; pt anpfreq; sort bl; '.format(_scans[0])
    plotcmd += 'x local; y local; nxy 1 1; '
    if multi_corr:
        plotcmd += 'ckey p; '
    plotcmd += 'draw points points; show both; show noheader; '
    plotcmd += r"label frequency : '\gn [MHz]'; label phase : '\gF [deg]'; label amplitude : 'Amp [Jy]'; "
    plotcmd += 'pl; '
    for fplot, scan in zip(freq_plots[1:], _scans[1:]):
        plotcmd += 'refile {0}; scan {1}; pl;'.format(fplot, scan)
    #plots vs time per scan
    plotcmd += 'avt none; avc vector; pt anptime; sort bl; '
    plotcmd += 'x local; y local; nxy 1 1; '
    if multi_corr:
        plotcmd += 'ckey p; '
    plotcmd += 'draw points points; show both; show noheader; '
    plotcmd += r"label time : 'time [UTC]'; label phase : '\gF [deg]'; label amplitude : 'Amp [Jy]'; "
    for tplot, scan in zip(time_plots, _scans):
        plotcmd += 'refile {0}; scan {1}; pl;'.format(tplot, scan)
    plotcmd = plotcmd.strip(';')
    if in_path:
        run_line = 'printf "${plotcmd}" | jplotter'
    else:
        run_line = 'printf "${plotcmd}" | singularity run --bind $PWD shub://haavee/jiveplot'
    plotter_file = _casalogfile+'.plotter.sh'
    with open(plotter_file, 'w') as plotting_helper:
        plotting_helper.write('#!/bin/sh\n')
        plotting_helper.write('set -e\n')
        plotting_helper.write('plotcmd="{0}"\n'.format(plotcmd))
        plotting_helper.write('cd {0}\n'.format(_inp_params.workdir))
        plotting_helper.write(run_line)
    #chmod +x without going through a shell (safe for paths containing spaces)
    os.chmod(plotter_file, os.stat(plotter_file).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)


def make_plotsms_freqtime_plot(_inp_params, _yval, _baseline, _scan, _correlation, _plotlocation):
    """
    Makes two plotms plots of _yval on _baseline for _scan and _correlation via task_plotms_general().
    The first is vs time, averaged over all channels and spws, saved as _plotlocation + '_vs_time.png'.
    The second is vs frequency, averaged in time, saved as _plotlocation + '_vs_freq.png'.
    Scalar averaging is used for autocorrelation baselines (written with '&&').
    """
    shared = dict(yaxis       = _yval,
                  antenna     = _baseline,
                  scan        = _scan,
                  correlation = _correlation,
                  scalar      = '&&' in _baseline
                 )
    task_plotms_general(_inp_params, xaxis='time', avgspw=True, avgchannel='9999',
                        plotfile=_plotlocation + '_vs_time.png', **shared)
    task_plotms_general(_inp_params, xaxis='freq', avgtime='9999',
                        plotfile=_plotlocation + '_vs_freq.png', **shared)


def task_plotms_general(_inp_params,
                        xaxis       = '',
                        yaxis       = '',
                        ydatacolumn = 'corrected',
                        field       = '',
                        spw         = '',
                        antenna     = '',
                        scan        = '',
                        correlation = '',
                        avgchannel  = '',
                        avgspw      = False,
                        avgtime     = '',
                        avgscan     = False,
                        avgbaseline = False,
                        scalar      = False,
                        iteraxis    = 'baseline',
                        coloraxis   = '',
                        xlabel      = '',
                        ylabel      = '',
                        plotfile    = '',
                        to_path     = ''
                       ):
    """
    General CASA plotms task, used by different pipeline functions with different input parameters.
    Plots yaxis vs xaxis of _inp_params.ms_name in a single 1x1 panel.
    plotfile and to_path are passed through file_or_gui(), which returns the showgui flag and the final plot file.
    Plotfile can contain a full path, the corresponding folders will be created in the diagnostics folder.
    Symbol sizes, fonts, and the image size come from the pl_* attributes of _inp_params.
    """
    scan = scan if isinstance(scan, str) else str(scan)
    _showgui, plotfile = file_or_gui(_inp_params, plotfile, to_path)
    #layout and axes (yaxislocation left at the plotms default):
    plotms_args = dict(vis=_inp_params.ms_name, gridrows=1, gridcols=1, rowindex=0, colindex=0, plotindex=0,
                       xaxis=xaxis, xdatacolumn='', yaxis=yaxis, ydatacolumn=ydatacolumn)
    #data selection:
    plotms_args.update(selectdata=True, field=field, spw=spw, timerange='', uvrange='', antenna=antenna,
                       scan=scan, correlation=correlation, array='', observation='', intent='', feed='',
                       msselect='')
    #averaging. avgchannel/avgtime can exceed the scan length / number of channels;
    #a single-baseline plot needs avgbaseline=True for proper averaging;
    #avgspw averages over all spws or none:
    plotms_args.update(averagedata=True, avgchannel=avgchannel, avgtime=avgtime, avgscan=avgscan,
                       avgfield=False, avgbaseline=avgbaseline, avgantenna=False, avgspw=avgspw,
                       scalar=scalar)
    #no transformations and no flag extension:
    plotms_args.update(transform=False, freqframe='', restfreq='', veldef='RADIO', shift=[0.0, 0.0],
                       extendflag=False, extcorr=False, extchannel=False)
    #iteration ('antenna' or 'baseline') and axis scaling:
    plotms_args.update(iteraxis=iteraxis, xselfscale=False, yselfscale=False, xsharedaxis=False,
                       ysharedaxis=False)
    #symbols (coloraxis e.g. 'baseline'); customflaggedsymbol is off:
    plotms_args.update(customsymbol=True, symbolshape='circle', symbolsize=_inp_params.pl_symbolsize,
                       symbolcolor='0000ff', symbolfill='fill', symboloutline=False, coloraxis=coloraxis,
                       customflaggedsymbol=False, flaggedsymbolshape='circle',
                       flaggedsymbolsize=_inp_params.pl_flaggedsymbolsize, flaggedsymbolcolor='ff0000',
                       flaggedsymbolfill='fill', flaggedsymboloutline=False)
    #labels, fonts, grids (legendposition left at the plotms default):
    plotms_args.update(plotrange=[], title='', titlefont=_inp_params.pl_titlefont, xlabel=xlabel,
                       xaxisfont=_inp_params.pl_xaxisfont, ylabel=ylabel, yaxisfont=_inp_params.pl_yaxisfont,
                       showmajorgrid=False, majorwidth=1, majorstyle='', majorcolor='B0B0B0',
                       showminorgrid=False, minorwidth=1, minorstyle='', minorcolor='D0D0D0',
                       showlegend=False)
    #export (the dpi parameter seems to be ignored, hence -1):
    plotms_args.update(plotfile=plotfile, expformat='', exprange='', highres=False, dpi=-1,
                       width=_inp_params.pl_width, height=_inp_params.pl_height, overwrite=True,
                       showgui=_showgui, clearplots=True)
    casaplotms.plotms(**plotms_args)


def task_plotcal_general(_inp_params,
                         caltable,
                         figfile,
                         xaxis     = '',
                         yaxis     = '',
                         poln      = '',
                         field     = '',
                         antenna   = '',
                         spw       = '',
                         timerange = '',
                         overplot  = False,
                         iteration = 'antenna',
                         plotrange = None,
                         showflags = False,
                         to_path   = ''
                        ):
    """
    General CASA plotcal task.
    Used by several calibration tasks to generate diagnostic output (visualize solutions).
    Figfile can contain a full path, the corresponding folders will be created in the diagnostics folder.
    If xaxis, yaxis, and/or poln are a csv string, then plots are made for every combination of values separately.
    Prepends <poln>_<yaxis>_vs_<xaxis>_ to the figfile name.
    plotrange: passed on to plotcal; None (the default) is passed on as an empty list.
    Raises IOError if caltable is not an existing directory.
    """
    if not os.path.isdir(caltable):
        raise IOError('Calibration table\n  ' + caltable + '\n does not exist!')
    if plotrange is None:
        #fresh list per call instead of a shared mutable default argument
        plotrange = []
    #could also be called explicitly with None as input:
    _poln  = (poln or '').split(',')
    _yaxis = (yaxis or '').split(',')
    _xaxis = (xaxis or '').split(',')
    for pln, yax, xax in itertools.product(_poln, _yaxis, _xaxis):
        _showgui, _figfile = file_or_gui(_inp_params, pln+'_'+yax+'_vs_'+xax+'_'+figfile, to_path)
        tasks.plotcal(caltable   = caltable,
                      xaxis      = xax,
                      yaxis      = yax,
                      poln       = pln,
                      field      = field,
                      antenna    = str(antenna),
                      spw        = spw,
                      timerange  = timerange,
                      subplot    = 111,
                      overplot   = overplot,
                      clearpanel = 'Auto',
                      iteration  = iteration,
                      plotrange  = plotrange,
                      showflags  = showflags,
                      plotsymbol = 'o',
                      plotcolor  = 'blue',
                      markersize = _inp_params.pc_markersize,
                      fontsize   = _inp_params.pc_fontsize,
                      showgui    = _showgui,
                      figfile    = _figfile
                     )


def my_plotcal(_inp_params, _ms_metadata, caltable, figname, figfolder, xaxes, yaxes):
    """
    Plot the solutions stored in a CASA calibration table, one figure per (xaxis, antenna) combination.

    Input:
        - _inp_params: Pipeline input parameters. Read here: verbose, bandtype_cmplx_bandpass, the calib_* table
          descriptors (indexed with C_NAM), the F_0*/F_1* fringe-parameter indices, and (optionally) pc_file_ext.
        - _ms_metadata: Measurement set metadata. Provides sourcenames, channels_nu, and the yield_antname(),
          yield_sourcename(), and yield_scantime() lookups.
        - caltable: Path to the calibration table whose solutions are plotted.
        - figname: From the corresponding taskname. The file_or_gui() function is used to save the plots in the right places.
        - figfolder: Folder where the plots will be created in.
        - xaxes: For every xaxis provided in a comma-separated string, separate plots are made.
                 Allowed values are time, freq, and elev (elev only for the gain curve table).
        - yaxes: Stacks all provided yaxis items from this comma-separated string on top of each other in every plot.
                 Allowed values are amp, phase, rate, delay, disp, snr, tec, tsys, and gc.
    Output:
        - Separate plots of calibration solutions for every antenna present in the calibration table and every input xaxes.
        - One summary pdf plot that contains all created plots in a single file.
    Returns early without plotting for BPOLY bandpass tables, for tables without a CPARAM/FPARAM column, and when the
    requested axes cannot be plotted for the given table (messages are printed if _inp_params.verbose is set).
    Raises:
        - ValueError: For an unknown xaxis or yaxis value, or for a yaxis that is not a fringe parameter of a
          fringe-fit table.
    """
    if _inp_params.bandtype_cmplx_bandpass=='BPOLY' and caltable==_inp_params.calib_complex_bandpass[_inp_params.C_NAM]:
        if _inp_params.verbose:
            print('  Unable to plot solutions from a BPOLY calibration table, sorry.')
        return
    mytb = casac.table()
    mytb.open(caltable)
    try:
        tb_colnames = mytb.colnames()
        tb_type     = mytb.getkeywords()['VisCal']
    finally:
        # Close the table even if reading its columns or keywords fails.
        mytb.close()
    if 'FPARAM' in tb_colnames:
        calparam = 'FPARAM'
    elif 'CPARAM' in tb_colnames:
        calparam = 'CPARAM'
    else:
        warn_msg = 'No CPARAM or FPARAM column in the {0} calibration table. '.format(caltable)
        warn_msg+= 'I only found {0}. Will not plot anything.'.format(str(tb_colnames))
        if _inp_params.verbose:
            print('  ' + warn_msg)
        return
    # For fringe-fit tables every yaxis is selected via its own first-axis index (fparams below)
    # instead of via the polarization index.
    is_fftab = 'Fringe' in tb_type
    # Assign consistent coloring to sources for individual polarizations.
    colors0 = ['tab:red', 'midnightblue', 'tab:brown', 'darkviolet', 'lime'  , 'deeppink', 'teal'  , 'darkgoldenrod']
    colors1 = ['black'  , 'tab:orange'  , 'green'    , 'red'       , 'cyan'  , 'dimgray' , 'maroon', 'magenta']
    N_srcs   = len(_ms_metadata.sourcenames)
    N_colors = len(colors0)
    if N_srcs > N_colors:
        colors0.extend(plt.cm.viridis(np.linspace(0,1,N_srcs-N_colors)))
        colors1.extend(plt.cm.plasma(np.linspace(0,1,N_srcs-N_colors)[::-1]))
    srccol = defaultdict(dict)
    for i, src in enumerate(_ms_metadata.sourcenames):
        srccol[0][src] = colors0[i]
        srccol[1][src] = colors1[i]
    markersize = global_plt_markersize
    fontsize   = global_plt_fontsize
    xaxes      = xaxes.replace(' ','')
    yaxes      = yaxes.replace(' ','')
    xaxesl     = xaxes.split(',')
    # Reversed so that the first requested yaxis ends up in the bottom panel.
    yaxesl     = yaxes.split(',')[::-1]
    monthnames = {1:'Jan', 2:'Feb', 3:'Mar', 4:'Apr', 5:'May', 6:'Jun', 7:'Jul', 8:'Aug', 9:'Sep', 10:'Oct', 11:'Nov', 12:'Dec'}
    ylabels    = {'amp': 'Gain amplitude', 'phase': 'Phase\n[deg]', 'rate': 'Rate\n[psec/sec]',
                  'delay': 'Delay\n[nsec]', 'disp': 'Disp', 'snr': 'S/N', 'tec': 'TEC[$10^{17}/m^2$]',
                  'tsys': 'Tsys [K]', 'gc': 'elevation gain [K/Jy]'
                 }
    # Per polarization (0/1): index of each fringe parameter along the first axis of the table data.
    fparams    = defaultdict(dict)
    fparams[0] = {'phase':_inp_params.F_0PHAS, 'delay':_inp_params.F_0DELA, 'rate':_inp_params.F_0RATE,
                  'disp':_inp_params.F_0DISP, 'snr':_inp_params.F_0PHAS
                 }
    fparams[1] = {'phase':_inp_params.F_1PHAS, 'delay':_inp_params.F_1DELA, 'rate':_inp_params.F_1RATE,
                  'disp':_inp_params.F_1DISP, 'snr':_inp_params.F_1PHAS
                 }
    allowed_ylabels = list(ylabels.keys())
    if auxiliary.is_set(_inp_params, 'pc_file_ext'):
        plt_ext = _inp_params.pc_file_ext
    else:
        plt_ext = 'png'
    # Map station -> field -> unique spws present in the calibration table.
    caltb_stations = np.unique(auxiliary.read_CASA_table(_inp_params, 'ANTENNA1', _tablename=caltable))
    caltb_stsrcspw = defaultdict(dict)
    for st in caltb_stations:
        qcrit         = 'ANTENNA1=={0}'.format(str(st))
        caltb_sources = np.unique(auxiliary.read_CASA_table(_inp_params, 'FIELD_ID', qcrit, _tablename=caltable))
        for src in caltb_sources:
            qcrit                   = 'ANTENNA1=={0} && FIELD_ID=={1}'.format(str(st), str(src))
            caltb_stsrcspw[st][src] = np.unique(auxiliary.read_CASA_table(_inp_params, 'SPECTRAL_WINDOW_ID', qcrit,
                                                                          _tablename=caltable)
                                               )
    # PDF summary plot with all individual plots gathered together.
    pl_fname_c      = 'SUMMARY_{0}_vs_{1}_{2}.pdf'.format(yaxes, xaxes, figname)
    _, summary_plot = file_or_gui(_inp_params, pl_fname_c, [figfolder])
    pdf_summary     = matplotlib.backends.backend_pdf.PdfPages(summary_plot)
    try:
        for iterp in itertools.product(xaxesl, list(caltb_stsrcspw.keys())):
            # Separate plots per xaxis and antenna.
            xaxis  = iterp[0]
            pl_ant = iterp[1]
            if xaxis != 'time' and xaxis != 'freq' and xaxis != 'elev':
                err_msg = 'Got an unknown xaxis value when trying to plot solutions from the {0} calibration table. '.format(caltable)
                err_msg+= 'Allowed values are time, freq, and elev for gc tables. '
                err_msg+= 'Got {0}. Please check your array_finetune.inp file.'.format(str(xaxis))
                plt.close('all')
                raise ValueError(err_msg)
            N_plots     = len(yaxesl)
            fig, axis_i = plt.subplots(nrows=N_plots, sharex=True)
            axis_i      = auxiliary.force_list(axis_i)
            axes_dict   = {}
            for i, ax in enumerate(axis_i):
                yaxis = yaxesl[i]
                if yaxis not in allowed_ylabels:
                    err_msg = 'Got an unknown yaxis argument when trying to plot solutions from the {0} caltb: '.format(caltable)
                    err_msg+= '{0}. Allowed values are {1}. Please check your array_finetune.inp file.'.format(yaxis,
                                                                                                               str(allowed_ylabels)
                                                                                                              )
                    plt.close('all')
                    raise ValueError(err_msg)
                # gc/elev only go together with the gain curve table; tsys only with the Tsys tables.
                cannot_plot = False
                if caltable==_inp_params.calib_gaincurve[_inp_params.C_NAM] and (yaxis!='gc' or xaxis!='elev'):
                    cannot_plot = True
                elif caltable!=_inp_params.calib_gaincurve[_inp_params.C_NAM] and (yaxis=='gc' or xaxis=='elev'):
                    cannot_plot = True
                elif (caltable!=_inp_params.calib_tsys[_inp_params.C_NAM] \
                and caltable!=_inp_params.calib_tsys_add_exptau[_inp_params.C_NAM]) and yaxis=='tsys':
                    cannot_plot = True
                if cannot_plot:
                    plt.close('all')
                    if _inp_params.verbose:
                        print('Unable to plot the selected parameters for the {0} caltb.'.format(caltable))
                        print('Please check your array_finetune.inp file.')
                    return
                axes_dict[yaxis] = ax
                ax.ticklabel_format(useOffset=False, style='plain')
                ax.set_ylabel(ylabels[yaxis], fontsize=fontsize)
                ax.tick_params(axis="y", direction='in', labelsize=fontsize)
                ax.tick_params(axis="x", direction='in', labelsize=fontsize)
            pl_antname = _ms_metadata.yield_antname(pl_ant)
            pl_title   = '{0} telescope'.format(pl_antname)
            if xaxis == 'time':
                pl_title += ' (lines between scans)'
            elif xaxis == 'freq':
                pl_title += ' (lines between spws)'
            pl_fname_s     = '{0}_{1}_vs_{2}_{3}.{4}'.format(pl_antname, yaxes, xaxis, figname, plt_ext)
            _, single_plot = file_or_gui(_inp_params, pl_fname_s, [figfolder])
            label_coll     = []
            caltb_sources  = list(caltb_stsrcspw[pl_ant].keys())
            src_spw_pairs  = []
            spw_sorted     = {}
            for src in caltb_sources:
                src_spws = caltb_stsrcspw[pl_ant][src]
                for spw in src_spws:
                    src_spw_pairs.append((src, spw))
                    minfreq = min(_ms_metadata.channels_nu[int(spw)])
                    if  minfreq not in spw_sorted:
                        spw_sorted[minfreq] = spw
            # Line between spws or scans:
            if xaxis == 'freq':
                sorted_chan_freqs = sorted(list(spw_sorted.keys()))
                sorted_spws       = [spw_sorted[chanfreq] for chanfreq in sorted_chan_freqs]
                for spw_iter, spw_number in enumerate(sorted_spws):
                    try:
                        next_spw = sorted_spws[spw_iter+1]
                    except IndexError:
                        break
                    end0   = max(_ms_metadata.channels_nu[int(spw_number)])
                    start1 = min(_ms_metadata.channels_nu[int(next_spw)])
                    # Midpoint between the two spws, converted from Hz to GHz (0.5 * 1e-9).
                    middle = (end0 + start1) * 5.e-10
                    for ax in axis_i:
                        ax.axvline(x=middle, ls='dashed', color='gray', linewidth=0.01*markersize, zorder=100)
            elif xaxis == 'time':
                scans = sorted(np.unique(auxiliary.read_CASA_table(_inp_params, 'SCAN_NUMBER', _tablename=caltable)))
                for scan_iter, scan_number in enumerate(scans):
                    try:
                        next_scan = scans[scan_iter+1]
                    except IndexError:
                        break
                    end0   = _ms_metadata.yield_scantime(str(scan_number), 'end')
                    start1 = _ms_metadata.yield_scantime(str(next_scan), 'start')
                    middle = 0.5 * (end0 + start1)
                    for ax in axis_i:
                        ax.axvline(x=auxiliary.time_convert(middle, to_datetime_obj=True), ls='dashed',
                                   color='gray', linewidth=0.01*markersize, zorder=100
                                  )
            starttime = float("inf")
            for iterp2 in src_spw_pairs:
                # Gather data, set xaxis, and combine sources and spws in stacked plots for every yaxis.
                pl_src   = iterp2[0]
                pl_spw   = iterp2[1]
                pl_sname = _ms_metadata.yield_sourcename(pl_src)
                qcrit    = 'ANTENNA1=={0} && FIELD_ID=={1} && SPECTRAL_WINDOW_ID=={2}'.format(str(pl_ant), str(pl_src), str(pl_spw))
                if xaxis == 'time':
                    # Data has the shape [pol/yaxis][channels to be plotted for the same t][time axis to be plotted against].
                    qselect                    = 'TIME, {0}, FLAG, SNR'.format(str(calparam))
                    pl_t, pl_y, flag, snr      = auxiliary.read_CASA_table(_inp_params, qselect, qcrit, 'TIME', _tablename=caltable)
                    snr                        = auxiliary.remove_flagged_values(snr, flag)
                    pl_y                       = auxiliary.remove_flagged_values(pl_y, flag)
                    try:
                        this_starttime = pl_t[0]
                        if this_starttime < starttime:
                            # Label the time axis with the date of the earliest solution.
                            starttime = this_starttime
                            startdate = auxiliary.time_convert(starttime)
                            startdate = startdate.split('/')
                            thismonth = monthnames[int(startdate[1])]
                            xlabelstr = 'Day/hour ({0} {1}, {2})'.format(thismonth, startdate[2], startdate[0])
                            plt.xlabel(xlabelstr, fontsize=fontsize)
                    except IndexError:
                        # No data here.
                        continue
                    axis_i[-1].xaxis.set_major_formatter(mdates.DateFormatter('%d/%H'))
                elif xaxis == 'freq':
                    # Data has the shape [pol/yaxis][channel axis to be plotted against][time axis to be plotted for the same freq].
                    qselect         = '{0}, FLAG, SNR'.format(str(calparam))
                    pl_y, flag, snr = auxiliary.read_CASA_table(_inp_params, qselect, qcrit, _tablename=caltable)
                    # Swap to have [pol/yaxis][plotted as same freq][axis to be plotted against] to match 'time'
                    flag            = np.swapaxes(flag, 1, 2)
                    pl_y            = np.swapaxes(pl_y, 1, 2)
                    snr             = np.swapaxes(snr, 1, 2)
                    snr             = auxiliary.remove_flagged_values(snr, flag)
                    pl_y            = auxiliary.remove_flagged_values(pl_y, flag)
                    try:
                        _ = pl_y[0]
                    except IndexError:
                        # No data here.
                        continue
                    plt.xlabel('Frequency [GHz]', fontsize=fontsize)
                elif xaxis == 'elev':
                    qselect = str(calparam)
                    pl_y    = auxiliary.read_CASA_table(_inp_params, qselect, qcrit, _tablename=caltable)
                    plt.xlabel('Elevation angle [deg]', fontsize=fontsize)
                for iterp3 in itertools.product(yaxesl, [0, 1]):
                    # Select the data to be plotted.
                    yaxis = iterp3[0]
                    pol   = iterp3[1]
                    ax    = axes_dict[yaxis]
                    if yaxis == 'snr':
                        this_param = snr
                    else:
                        this_param = pl_y
                    try:
                        # A single first-axis entry is labelled as holding both polarizations.
                        if this_param.shape[0] == 1:
                            label_overwrite_both_pols = 'R/L:\n'
                        else:
                            label_overwrite_both_pols = ''
                        if is_fftab:
                            try:
                                this_index = fparams[pol][yaxis]
                                pl_y_pol   = this_param[this_index]
                                if yaxis == 'phase':
                                    pl_y_pol = auxiliary.wrap_phase(pl_y_pol*180./np.pi, 180)
                                elif yaxis == 'rate':
                                    pl_y_pol*= 1.e12
                            except KeyError:
                                err_msg = '{0} is not a valid fringe parameter and cannot be plotted. '.format(yaxis)
                                err_msg+= 'Allowed values are snr,phase,delay,rate,disp. Please check your array_finetune.inp file.'
                                plt.close('all')
                                raise ValueError(err_msg)
                        elif yaxis == 'amp' or yaxis=='tsys':
                            pl_y_pol = auxiliary.get_vis_value(this_param[pol], 'amp')
                        elif yaxis == 'phase':
                            pl_y_pol = auxiliary.get_vis_value(this_param[pol], 'phase')
                        elif yaxis == 'tec':
                            pl_y_pol = np.multiply(this_param[pol], 1.e-17)
                        elif yaxis == 'gc':
                            # Evaluate the stored polynomial coefficients (first half: pol 0, second half: pol 1)
                            # over 0-90 deg elevation, as a function of 90-elevation.
                            pl_x  = np.linspace(0, 90, 256, endpoint=True)
                            polys = list(this_param.flatten())
                            Npoly = int(len(polys)/2)
                            if pol==0:
                                poly = polys[0:Npoly]
                            else:
                                poly = polys[Npoly:2*Npoly]
                            f        = np.poly1d(poly[::-1])
                            pl_y_pol = [f(90-pl_x)**2]
                        else:
                            pl_y_pol = this_param[pol]
                        if not is_fftab:
                            this_index = pol
                    except IndexError:
                        continue
                    try:
                        _ = pl_y_pol[0][0]
                    except IndexError:
                        # No valid data here.
                        continue
                    if xaxis == 'time':
                        # Handle flags along time axis: drop times at which every channel is flagged.
                        any_channel_sol_is_good = np.logical_not(np.mean(np.logical_not(flag[this_index]), 0)).astype(int)
                        pl_x                    = auxiliary.remove_flagged_values(pl_t, any_channel_sol_is_good)
                        pl_x                    = [auxiliary.time_convert(tt, to_datetime_obj=True) for tt in pl_x]
                    elif xaxis == 'freq':
                        # The channels can also be grouped together:
                        channels   = pl_y_pol.shape[-1]
                        chan_freqs = _ms_metadata.channels_nu[int(pl_spw)]
                        Nchans     = len(chan_freqs)
                        chan_step  = max(1, int(Nchans/channels))
                        chan_IDs   = np.linspace(0, Nchans-chan_step, channels).astype(int)
                        pl_x       = chan_freqs[chan_IDs]
                        pl_x       = np.multiply(pl_x, 1.e-9)
                    try:
                        pl_y_pol = np.asarray([np.nan_to_num(np.asanyarray(pyp, dtype=float), 0) for pyp in pl_y_pol])
                    except (TypeError, IndexError) as _:
                        pl_y_pol = np.nan_to_num(np.asarray(pl_y_pol, dtype=float), 0)
                    if isinstance(pl_sname, str):
                        # Wrap long source names every 5 characters for the legend.
                        pl_sname_fmt = '\n   '.join(pl_sname[n:n + 5] for n in range(0, len(pl_sname), 5))
                        pl_sname_str = pl_sname
                    else:
                        pl_sname_fmt = 'All'
                        pl_sname_str = _ms_metadata.yield_sourcename(0)
                    if pol == 0:
                        pl_symbol = '+'
                        pl_label  = 'R:{0}'.format(pl_sname_fmt)
                        pl_fcolor = None
                        pl_ecolor = None
                        pl_color  = srccol[0][pl_sname_str]
                    elif pol == 1:
                        pl_symbol = 'o'
                        pl_label  = 'L:{0}'.format(pl_sname_fmt)
                        pl_fcolor = 'none'
                        pl_ecolor = srccol[1][pl_sname_str]
                        pl_color  = None
                    if label_overwrite_both_pols:
                        pl_label = label_overwrite_both_pols + pl_sname_fmt
                    # Only the bottom panel collects legend entries, each label once.
                    if ax == axis_i[-1] and pl_label not in label_coll:
                        label_coll.append(pl_label)
                    else:
                        pl_label = ''
                    only_small_num = True
                    for j, yy in enumerate(pl_y_pol):
                        # Stack data along freq axis for 'time' and along time axis for 'freq'
                        if j != 0:
                            pl_label = ''
                        # NOTE(review): only the maximum is inspected, so large negative values do not switch to
                        # scientific notation — confirm this is intended.
                        if max(yy) > 1000:
                            fmt = mticker.ScalarFormatter(useOffset=False, useMathText=True)
                            g   = lambda x,pos : "${}$".format(fmt._formatSciNotation('%1.10e' % x))
                            ax.yaxis.set_major_formatter(mticker.FuncFormatter(g))
                            only_small_num = False
                        elif only_small_num and abs(max(yy)) < 2:
                            only_small_num = True
                        else:
                            only_small_num = False
                        try:
                            ax.scatter(pl_x, yy, s=markersize, marker=pl_symbol, label=pl_label, linewidth=0.017*markersize,
                                       zorder=200, color=pl_color, facecolor=pl_fcolor, edgecolor=pl_ecolor
                                      )
                        except ValueError:
                            # Flagged data not properly catched before?
                            pass
                    if only_small_num:
                        ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))
            for ax in axis_i:
                ymin, ymax = ax.get_ylim()
                ax.set_ylim(ymin-0.01, ymax+0.01)
                # Integer number of y ticks between 3 and 8, fewer per panel when more panels are stacked.
                Nyticks = min(max(12 // N_plots, 3), 8)
                ax.yaxis.set_major_locator(plt.MaxNLocator(Nyticks))
                yticklabels = ax.get_yticklabels()
                if yticklabels and ax is not axis_i[0]:
                    # Remove top tick label except for the top plot, so stacked panels do not overlap.
                    plt.setp(yticklabels[-1], visible=False)
            axis_i[0].set_title(pl_title, fontsize=fontsize)
            axis_i[-1].legend(loc=(1.01,0))
            axis_i[-1].xaxis.set_major_locator(plt.MaxNLocator(7))
            for tick in axis_i[-1].get_xticklabels():
                tick.set_rotation(45)
            plt.subplots_adjust(wspace=0, hspace=0, right=0.8, left=0.15, bottom=0.2)
            plt.savefig(single_plot)
            pdf_summary.savefig(fig)
            plt.close(fig)
    finally:
        # Finalize the summary pdf on every exit path: normal end, early return, or exception.
        pdf_summary.close()


def simple_plotter(_x, _y, _xaxis, _yaxis, _title, _pltname, _x2=None, _y2=None, _hline=0, _vline=0, _pltype='scatter',
                   rescale_axes=False, logx=False, logy=False):
    """
    Make a simple single-panel plot and save it to _pltname on disk.

    If _pltype == scatter:
        Makes a simple scatter plot of _y against _x.
    If _pltype == errbar:
        Makes a simple error bar plot. Here _y must be [vals, min_err, max_err]; the lower/upper error bars are
        |vals - min_err| and |vals - max_err|.
    Optional extras:
        - _x2, _y2: If _y2 is non-empty, a dashed black line through (_x2, _y2) is overplotted.
        - _hline, _vline: Red horizontal/vertical line at this value (0 means no line).
        - rescale_axes: Pad the x range by 10% of (|min(_x)| + |max(_x)|) on both sides.
        - logx, logy: Use logarithmic x/y axes.
    Raises:
        - ValueError: If _pltype is neither 'scatter' nor 'errbar' (checked before the current figure is touched).
    """
    if _pltype not in ('scatter', 'errbar'):
        raise ValueError(str(_pltype) + ' is not a valid plot type. Must be scatter or errbar.')
    # None defaults avoid shared mutable default lists; absent/empty means no overlay line.
    _y2 = np.asarray([] if _y2 is None else _y2)
    _x2 = np.asarray([] if _x2 is None else _x2)
    plt.ioff()
    plt.clf()
    ax = plt.gca()
    ax.ticklabel_format(useOffset=False, style='plain')
    if _pltype == 'scatter':
        plt.scatter(_x, _y, s=global_plt_markersize)
    else:
        # Convert explicitly so that plain (nested) lists work for the error computation too.
        yvals     = np.asarray(_y[0], dtype=float)
        yerr_low  = np.abs(yvals - np.asarray(_y[1], dtype=float))
        yerr_high = np.abs(yvals - np.asarray(_y[2], dtype=float))
        plt.errorbar(_x, yvals, yerr=[yerr_low, yerr_high], fmt='o', elinewidth=2, markeredgewidth=2,
                     capsize=10
                    )
    plt.xlabel(_xaxis, fontsize=global_plt_fontsize)
    plt.ylabel(_yaxis, fontsize=global_plt_fontsize)
    if logx:
        plt.xscale('log')
    if logy:
        ax.yaxis.set_tick_params('minor', length=4, width=2, labelsize=global_plt_fontsize)
        ax.yaxis.set_tick_params('major', length=9, width=2, labelsize=global_plt_fontsize)
        plt.yscale('log')
        ax.tick_params(axis="y", direction='in', which='minor')
    ax.tick_params(axis="x", direction='in', labelsize=global_plt_fontsize)
    ax.tick_params(axis="y", direction='in', labelsize=global_plt_fontsize)
    if rescale_axes:
        minx = min(_x)
        maxx = max(_x)
        ext  = 0.1 * (np.abs(minx) + np.abs(maxx))
        plt.xlim(minx - ext, maxx + ext)
    plt.title(_title, fontsize=global_plt_fontsize)
    if _hline:
        plt.axhline(y=_hline, c='r')
    if _vline:
        plt.axvline(x=_vline, c='r')
    # Check for presence rather than truthiness: an all-zero overlay curve must still be drawn.
    if _y2.size:
        plt.plot(_x2, _y2, linestyle='--', c='k')
    plt.savefig(_pltname)
    plt.close()


def simple_barplot(x, ys, _xaxis, _yaxis, _ylim, _pltname, _title, _labels, plotwidth, shiftwidths):
    """
    Draw a grouped bar chart and save it to _pltname on disk.

    Every series in ys is drawn as bars of width plotwidth whose left edges sit at x + offset, where offset is the
    matching entry of shiftwidths (align='edge'). Series are colored along the 'winter' colormap and labelled with
    the matching entry of _labels in the legend. Only as many series are drawn as the shortest of shiftwidths, ys,
    and _labels allows. The y range is fixed to _ylim.
    """
    bar_colors = plt.cm.winter(np.linspace(0, 1, len(ys)))
    plt.ioff()
    plt.clf()
    axes = plt.gca()
    axes.tick_params(axis='both', labelsize=global_plt_fontsize)
    axes.set_ylim(_ylim)
    plt.xlabel(_xaxis, fontsize=global_plt_fontsize)
    plt.ylabel(_yaxis, fontsize=global_plt_fontsize)
    for series_idx, (offset, heights, series_label) in enumerate(zip(shiftwidths, ys, _labels)):
        left_edges = np.asarray(x, dtype=float) + offset
        bar_height = np.asarray(heights, dtype=float)
        axes.bar(left_edges, bar_height, width=plotwidth, color=bar_colors[series_idx], label=series_label,
                 align='edge')
    plt.title(_title, fontsize=global_plt_fontsize)
    plt.legend(loc='best')
    plt.savefig(_pltname)
    plt.close()
