def update_a(i, merge_df, modes, var_x, var_ys, var_t, sources, lines, ax):
    """
    Update function for animated plots.

    Used as the frame callback of the ``FuncAnimation`` created in
    ``animated_plot``.  Frame ``i`` shows the ``i``-th value of the sorted
    unique values of ``merge_df[var_t]``: the title is rebuilt for that value
    and every line in ``lines`` receives the x/y data of the matching rows.

    Parameters
    ----------
    i : int
        Frame index into the sorted unique values of ``merge_df[var_t]``.
    merge_df : pandas.DataFrame
        Data holding ``var_x``, ``var_t`` and one ``var_y + sep + source``
        column for every (var_y, source) pair.
    modes : dict
        Plot options; ``modes["title"]`` and ``modes["scale"]`` are read.
    var_x, var_t : str
        Column names of the x-axis variable and of the animated variable.
    var_ys, sources : list of str
        Channels and sources; together they name the y columns.
    lines : list
        Line artists ordered var_y-major, source-minor (the order in which
        ``animated_plot`` creates them).
    ax : matplotlib.axes.Axes
        Axes owning ``lines``.
    """
    var_t_vals = np.sort(merge_df[var_t].unique())
    var_t_val = var_t_vals[i]

    if var_t == "Time":
        var_t_string = str(var_t_val).rstrip('0').rstrip('.')
    elif var_t == "Freq":
        freq_MHz = var_t_val / 1e6
        var_t_string = "{:7.3f} MHz".format(freq_MHz)
    else:
        var_t_string = ("%.4f" % var_t_val).rstrip('0').rstrip('.')

    title = "Plot of "
    for source in sources:
        title = add_key(title, sources, source)

    title = title + " for "
    for var_y in var_ys:
        title = add_key(title, var_ys, var_y)

    title = (title + "-channels over " + gen_pretty_name(var_x) + " at\n" +
             gen_pretty_name(var_t))

    title = "\n".join([modes["title"], title + " of " + var_t_string])

    ax.set_title(title, wrap=True)

    # Select this frame's rows once; the x values and every line below reuse
    # the same slice instead of re-evaluating the boolean mask per line.
    frame_df = merge_df.loc[merge_df[var_t] == var_t_val].reset_index(
        drop=True)

    var_x_vals = plottable(frame_df, var_x)

    # sets the y axis scale to logarithmic if requested.
    if 'log' in modes['scale']:
        ax.set_yscale('log')

    # sets the y axis scale to percentage if requested.
    if 'percent' in modes['scale']:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    no_sources = len(sources)
    for y_index, var_y in enumerate(var_ys):
        for source_index, source in enumerate(sources):
            sep = get_source_separator(source)
            var_y_vals = plottable(frame_df, (var_y + sep + source))

            # values are shown as percentages if requested.
            if 'percent' in modes['scale']:
                var_y_vals = var_y_vals * 100

            # lines are stored var_y-major, source-minor
            line_index = (y_index * no_sources) + source_index
            lines[line_index].set_data(var_x_vals, var_y_vals)
    ax.set_aspect('auto')
def calc_fom_1d(merge_df, m_keys, fom):
    """
    Calculate a figure of merit comparing model and scope data.

    For each channel key in ``m_keys`` one figure of merit is computed over
    all rows of ``merge_df``:

    * ``"rmse"`` -- root-mean-square of the ``<key>_diff`` column.
    * ``"corr"`` -- Pearson correlation coefficient between the
      ``<key>_model`` and ``<key>_scope`` columns.

    Parameters
    ----------
    merge_df : pandas.DataFrame
        Merged data frame holding the per-channel columns named above.
    m_keys : list of str
        Channel keys to evaluate.
    fom : str
        ``"rmse"`` or ``"corr"``.  Any other value yields an empty list.

    Returns
    -------
    list of float
        One figure of merit per key, in ``m_keys`` order.
    """
    # output list for figures of merit
    fom_outs = []

    if fom == "rmse":
        for key in m_keys:
            diffs = plottable(merge_df, (key + '_diff'))
            fom_outs.append(np.mean(diffs**2)**0.5)
    elif fom == "corr":
        for key in m_keys:
            # NOTE(review): no abs() is applied here; presumably plottable()
            # converts complex values to real ones so they can be
            # correlated -- confirm.
            model_vals = list(plottable(merge_df, (key + '_model')))
            scope_vals = list(plottable(merge_df, (key + '_scope')))
            # pearsonr returns (coefficient, two-tailed p-value); only the
            # correlation coefficient is kept.
            fom_outs.append(pearsonr(model_vals, scope_vals)[0])

    return fom_outs
def normalise_data(merge_df, modes, channel, out_str=""):
    '''
    Normalise one channel of ``merge_df`` according to ``modes['norm']``.

    The mode is chosen by substring tests on ``modes['norm']``, checked in
    this order:

    * ``'o'`` -- overall: divide by the maximum of the whole channel.
    * ``'f'`` -- per frequency: divide by the maximum for each ``Freq``
      value (see norm_operation).
    * ``'t'`` -- per time: divide by the maximum for each ``Time`` value
      (see norm_operation).
    * ``'n'`` -- none: the data is left unchanged.

    Anything else prints a warning (verbose >= 1) and changes nothing.

    The normalised values are written to column ``channel + out_str``
    (overwriting ``channel`` itself when ``out_str`` is empty).  A zero
    maximum yields zeros instead of a division by zero.

    Returns
    -------
    pandas.DataFrame
        ``merge_df`` itself, modified in place.
    '''
    if modes['verbose'] >= 2:
        print("Normalising data")
    # NOTE(review): these are substring tests, so a word-valued mode such as
    # "none" would match 'o' first -- confirm modes['norm'] is a short code.
    if 'o' in modes['norm']:
        if modes['verbose'] >= 2:
            print("Normalisation basis: Overall")
        # normalises by dividing by the overall maximum; a zero maximum
        # gives zeros, matching norm_operation's per-group behaviour
        overall_max = np.max((plottable(merge_df[channel])))
        if overall_max != 0:
            merge_df[channel + out_str] = merge_df[channel] / overall_max
        else:
            merge_df[channel + out_str] = 0
    elif 'f' in modes['norm']:
        if modes['verbose'] >= 2:
            print("Normalisation basis: Frequency")
        # normalises by dividing by the maximum for each frequency
        norm_operation(merge_df, 'Freq', channel, modes, out_str)
    elif 't' in modes['norm']:
        if modes['verbose'] >= 2:
            print("Normalisation basis: time")
        # normalises by dividing by the maximum for each time
        norm_operation(merge_df, 'Time', channel, modes, out_str)
    elif 'n' in modes['norm']:
        if modes['verbose'] >= 2:
            print("Normalisation basis: None")
        # nothing to be done
    else:
        if modes['verbose'] >= 1:
            print("WARNING: Normalisation mode not specified correctly!")

    return merge_df
def crop_operation(in_df, modes):
    """
    Return a copy of ``in_df`` with zero-valued and outlying rows removed.

    Every column except 'Time', 'Freq', 'd_Time' and 'original_Time' is
    processed in turn:

    1. rows where the column equals 0.0 are dropped;
    2. if ``modes['crop']`` is non-zero, rows where the column exceeds a
       limit are dropped.  The limit depends on ``modes['crop_type']``:

       * ``"median"``     -- median of the column * ``modes['crop']``
       * ``"mean"``       -- mean of the column * ``modes['crop']``
       * ``"percentile"`` -- the ``modes['crop']``-th percentile; values of
         100 or more print a warning and use the maximum of
         ``plottable(column)`` instead
       * anything else    -- prints a warning and uses the median rule.

    Columns are processed sequentially, so limits for later columns are
    computed only on rows that survived the earlier columns.  NaN values
    never compare equal/greater and are therefore kept.

    Rows are removed with boolean masks rather than ``drop`` by index label,
    so a non-unique index cannot cause unrelated rows that share a label to
    be removed as well.  ``in_df`` is not modified.
    """
    if modes['verbose'] >= 2:
        print("Carrying out Crop Operation")
    out_df = in_df.copy()
    independent_cols = ['Time', 'Freq', 'd_Time', 'original_Time']
    # goes through all the columns of the data
    for col in list(out_df.columns):
        # only the dependent variables are cropped
        if col in independent_cols:
            continue
        # drops all zero values from the data
        out_df = out_df.loc[out_df[col] != 0.0]
        # if the cropping mode isn't set to 0, crop the scope data
        if 0.0 != modes['crop']:
            crop = modes['crop']
            crop_type = modes['crop_type']
            if crop_type == "median":
                col_limit = np.median(out_df[col]) * crop
            elif crop_type == "mean":
                col_limit = np.mean(out_df[col]) * crop
            elif crop_type == "percentile":
                if crop < 100:
                    col_limit = np.percentile(out_df[col], crop)
                else:
                    if modes['verbose'] >= 1:
                        print("WARNING: Percentile must be less than 100")
                    col_limit = np.max(plottable(out_df[col]))
            else:
                if modes['verbose'] >= 1:
                    print("WARNING: crop_type incorrectly specified.")
                col_limit = np.median(out_df[col]) * crop
            # "not greater than" (rather than "<=") keeps NaN rows
            out_df = out_df.loc[~(out_df[col] > col_limit)]

    return out_df
def norm_operation(in_df, var_str, channel, modes, out_str=""):
    '''
    Normalise ``channel`` separately for each unique value of ``var_str``.

    For every distinct value of column ``var_str`` the rows sharing that
    value are divided by their own maximum (taken over ``plottable`` of the
    channel).  Groups whose maximum is zero are set to 0 instead of dividing
    by zero.  NaN keys match no rows and are skipped.

    The result is written into column ``channel + out_str`` of ``in_df`` in
    place (overwriting ``channel`` when ``out_str`` is empty); nothing is
    returned.
    '''
    if modes['verbose'] >= 2:
        print("Carrying out normalisation")

    out_col = channel + out_str

    # iterates over all the unique values of the grouping column
    for unique_val in in_df[var_str].unique():
        # the row selection is computed once and reused for this group
        rows = in_df[var_str] == unique_val
        if not rows.any():
            # NaN never compares equal, so a NaN key selects nothing
            continue

        unique_max = np.max(plottable(in_df.loc[rows, channel]))

        if unique_max != 0:
            in_df.loc[rows, out_col] = in_df.loc[rows, channel] / unique_max
        else:
            in_df.loc[rows, out_col] = 0
def calc_fom_nd(in_df, var_str, m_keys, modes, fom="rmse"):
    """
    Calculate and plot a figure of merit between the scope and model values
    for the given channels, as distributed against the column ``var_str``.

    The rows of ``in_df`` are grouped by each unique value of ``var_str``
    (sorted ascending) and ``calc_fom_1d`` is evaluated on every group.  This
    gives one figure of merit per channel per value.  The resulting curves
    are overlaid on one plot.  The plot is shown when ``modes['out_dir']`` is
    None and saved to a file otherwise.

    In current versions, usable values for var_str are "Time" and "Freq".
    In current versions, usable values for fom are "rmse" and "corr".

    Returns
    -------
    list of list
        ``n_foms[i]`` holds the figures of merit for ``m_keys[i]``, one per
        sorted unique value of ``var_str``.
    """
    if modes['verbose'] >= 2:
        print("Calculating the " + gen_pretty_name(fom) +
              " between observed and model data.")

    # one list of figures of merit per channel, in m_keys order
    n_foms = [[] for _ in m_keys]

    # sorted unique values of the independent variable
    unique_vals = np.sort(in_df[var_str].unique())

    for unique_val in unique_vals:
        # only the rows that match the current unique value
        unique_merge_df = in_df[in_df[var_str] ==
                                unique_val].copy().reset_index(drop=True)
        # one figure of merit per channel for this value
        n_fom = calc_fom_1d(unique_merge_df, m_keys, fom)

        for i, fom_list in enumerate(n_foms):
            fom_list.append(n_fom[i])

    # creates an overlaid plot of how the figure of merit between model and
    # scope varies for each of the channels against var_str
    if modes['verbose'] >= 2:
        print("Plotting the " + gen_pretty_name(fom) +
              " between model and scope for " +
              channel_maker(m_keys, modes, ", ") + " against " +
              gen_pretty_name(var_str))

    if modes['colour'] in ["matching", "matching_dark"] and len(m_keys) == 1:
        text_colour = colour_models(m_keys[0])
    elif modes['colour'] in ["dark", "matching_dark"]:
        text_colour = "white"
    else:  # light or None
        text_colour = "black"

    mpl.rcParams.update({
        'text.color': text_colour,
        'axes.labelcolor': text_colour,
        'xtick.color': text_colour,
        'ytick.color': text_colour
    })

    mpl.rc('axes', edgecolor=text_colour)
    fig, ax = plt.subplots()
    if modes['dpi'] is not None:
        fig.set_dpi(modes['dpi'])

    if modes['image_size'] is not None:
        fig.set_size_inches(modes['image_size'])

    if modes['colour'] in ["dark", "matching_dark"]:
        ax.set_facecolor('black')
        fig.patch.set_facecolor('black')

    graph_title = "\n".join(
        [modes['title'], "Plot of the " + gen_pretty_name(fom) + " in "])

    # the x positions are shared by every channel's curve
    x_vals = plottable(unique_vals, var_str)
    for key, key_foms in zip(m_keys, n_foms):
        plt.plot(x_vals,
                 key_foms,
                 label=key + '_' + fom,
                 color=colour_models(key))

        graph_title = add_key(graph_title, m_keys, key)

    graph_title = graph_title + "-channels over " + gen_pretty_name(var_str)

    plt.title(graph_title, wrap=True)

    # rotates the labels.  This is necessary for timestamps
    plt.xticks(rotation=90)
    plt.legend(frameon=False)
    plt.xlabel(gen_pretty_name(var_str, units=True), wrap=True)

    # prints or saves the plot
    if modes['out_dir'] is None:
        plt.show()
    else:
        # creates an output-friendly string for the channel
        str_channel = list_to_string(m_keys, '_')

        plt_file = prep_out_file(modes,
                                 plot=fom,
                                 ind_var=var_str,
                                 channel=str_channel,
                                 out_type=modes['image_type'])

        if modes['verbose'] >= 2:
            print("Saving: " + plt_file)
        try:
            plt.savefig(plt_file,
                        bbox_inches='tight',
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
        except ValueError:
            if modes['verbose'] >= 1:
                print("ERROR: Unable to save file, showing instead.")
            try:
                plt.show()
            except Exception:
                print("Unable to show file.")
        plt.close()

    # returns the figure-of-merit lists if needed
    return n_foms
def four_var_plot(in_df,
                  modes,
                  var_x,
                  var_y,
                  var_z,
                  var_y2,
                  source,
                  plot_name=""):
    """
    Plots a two part plot of four variables from in_df as controlled by
    modes.

    Plot 1 (top) is a 3-d plot of column ``var_z + sep + source`` against
    ``var_x`` and ``var_y``, where ``sep`` comes from get_source_separator.
    ``modes["three_d"]`` selects "colour"/"color" (tripcolor) or "contour".
    A colour bar is drawn beside it when plotting succeeds.

    Plot 2 (bottom) is a 2-d scatter plot of ``var_y2`` against the same
    ``var_x``.

    ``modes["scale"]`` may contain "percent" (z values x100, percent colour
    bar) and "log" (log / symmetric-log colour normalisation, colour plots
    only).  The figure is shown when ``modes['out_dir']`` is None and saved
    otherwise.

    var_z must be one of the dependent variables.
    """
    if modes['verbose'] >= 2:
        print("Plotting " + gen_pretty_name(source) + "\nfor " +
              gen_pretty_name(var_z) + " against " + gen_pretty_name(var_x) +
              " and " + gen_pretty_name(var_y) + " and " +
              gen_pretty_name(var_y2, plot_name) + " against " +
              gen_pretty_name(var_x))

    if modes['colour'] in ["matching", "matching_dark"]:
        text_colour = colour_models(var_y)
    elif modes['colour'] == "dark":
        text_colour = "white"
    else:  # light or None
        text_colour = "black"

    mpl.rcParams.update({
        'text.color': text_colour,
        'axes.labelcolor': text_colour,
        'xtick.color': text_colour,
        'ytick.color': text_colour
    })

    mpl.rc('axes', edgecolor=text_colour)

    # All axes come from the GridSpec below, so only a figure is created
    # here; plt.subplots() would also create a full-figure axes that the
    # gridded axes merely overlap.
    fig = plt.figure()
    if modes['dpi'] is not None:
        fig.set_dpi(modes['dpi'])

    if modes['image_size'] is not None:
        fig.set_size_inches(modes['image_size'])

    dark = modes['colour'] in ["dark", "matching_dark"]
    if dark:
        fig.patch.set_facecolor('black')

    # create a 2 X 2 grid: plots on the left, colour bar in the narrow column
    gs = grd.GridSpec(2,
                      2,
                      height_ratios=[1, 1],
                      width_ratios=[20, 1],
                      wspace=0.1)

    upper_ax = plt.subplot(gs[0])
    if dark:
        upper_ax.set_facecolor('black')
    upper_title = ("Plot of " + gen_pretty_name(source) + "\nfor " +
                   gen_pretty_name(var_z) + " against " +
                   gen_pretty_name(var_x) + " and " + gen_pretty_name(var_y))
    label = "\n".join([modes["title"], upper_title])
    plt.title(label, wrap=True)

    sep = get_source_separator(source)
    # column actually plotted; also used for the colour limits below
    z_col = var_z + sep + source
    percent = "percent" in modes["scale"]

    # plots the channel in a colour based on its name
    colours = plt.get_cmap(colour_models(var_z + '_s'))

    # artist drawn in the upper axes; stays None if plotting fails
    p = None
    if modes["three_d"] in ["colour", "color"]:
        try:
            x_vals = plottable(in_df, var_x)
            y_vals = plottable(in_df, var_y)
            z_vals = plottable(in_df, z_col)
            if percent:
                z_vals = z_vals * 100

            if "log" in modes["scale"]:
                # finds the limits of the z variable
                maxz = np.max(z_vals)
                minz = np.min(z_vals)

                # if the values go below zero, then plot with symmetric log,
                # otherwise use log plotting
                if minz <= 0:
                    log_lim = 10
                    linthresh = max([abs(maxz), abs(minz)]) / log_lim
                    norm = SymLogNorm(linthresh,
                                      linscale=1.0,
                                      vmin=minz,
                                      vmax=maxz,
                                      clip=False)
                else:
                    norm = LogNorm()

                p = plt.tripcolor(x_vals,
                                  y_vals,
                                  z_vals,
                                  cmap=colours,
                                  norm=norm)
            else:
                p = plt.tripcolor(x_vals, y_vals, z_vals, cmap=colours)
        except RuntimeError:
            if modes['verbose'] >= 1:
                print(
                    "ERROR: Data not suitable for 3d colour plot.  Possible alternatives: contour/animated plots"
                )
    elif modes["three_d"] in ["contour"]:
        try:
            cols = np.unique(in_df[var_y]).shape[0]
            x_vals = np.array(in_df[var_x]).reshape(-1, cols)
            y_vals = np.array(in_df[var_y]).reshape(-1, cols)
            z_vals = np.array(in_df[z_col]).reshape(-1, cols)
            if percent:
                z_vals = z_vals * 100
            p = plt.contour(x_vals, y_vals, z_vals, cmap=colours)
        except Exception:
            if modes['verbose'] >= 1:
                print(
                    "ERROR: Data not suitable for 3d contour plot.  Possible alternatives: colour/animated plots"
                )

    if p is not None:
        # TODO: fix percentile plotting limits
        # the limits are scaled exactly like the plotted values
        clim_vals = plottable(in_df, z_col)
        if percent:
            clim_vals = clim_vals * 100
        p.set_clim(np.percentile(clim_vals, 5), np.percentile(clim_vals, 95))

    # plots axes
    plt.xticks([])
    plt.ylabel(gen_pretty_name(var_y, units=True), wrap=True)

    # color bar in its own axis
    colorAx = plt.subplot(gs[1])

    if p is not None:
        if percent:
            cb = plt.colorbar(p, cax=colorAx, format='%.3g%%')
        else:
            cb = plt.colorbar(p, cax=colorAx)
        cb.set_label(source + " for " + var_z)

    lower_ax = plt.subplot(gs[2])
    if dark:
        lower_ax.set_facecolor('black')

    lower_title = ("Plot of " + gen_pretty_name(var_y2, plot_name) +
                   " against " + gen_pretty_name(var_x))
    plt.title(lower_title, wrap=True)

    # plots the scattergraph
    plt.plot(plottable(in_df, var_x),
             plottable(in_df, var_y2),
             color=colour_models(var_y2),
             marker=".",
             linestyle="None")

    plt.xlabel(gen_pretty_name(var_x, units=True), wrap=True)
    plt.ylabel(gen_pretty_name(var_y2, units=True), wrap=True)
    plt.legend(frameon=False)

    # prints or saves the plot
    if modes['out_dir'] is None:
        plt.show()
    else:
        plt_file = prep_out_file(modes,
                                 source=source,
                                 plot=var_x,
                                 dims="nd",
                                 channel=var_z,
                                 ind_var=var_y,
                                 plot_name=plot_name,
                                 out_type=modes['image_type'])
        if modes['verbose'] >= 2:
            print("Saving: " + plt_file)
        try:
            plt.savefig(plt_file,
                        bbox_inches='tight',
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
        except ValueError:
            if modes['verbose'] >= 1:
                print("ERROR: Unable to save file, showing instead.")
            try:
                plt.show()
            except Exception:
                print("Unable to show file.")
        plt.close()
def plot_1f(merge_df, m_keys, modes, sources, var_str):
    """
    Plot the values of several channels and sources against ``var_str``.

    One line is drawn for every (key, source) pair, using column
    ``key + sep + source`` (``sep`` from get_source_separator).  The title
    reports ``modes['freq'][0]`` in MHz as the frequency.  The output file
    name uses ``min(merge_df.Freq)``, so presumably ``merge_df`` already
    holds a single frequency -- confirm with callers.

    ``modes['scale']`` may contain "log" (log y axis) and "percent" (values
    x100 with a percent formatter).  The figure is shown when
    ``modes['out_dir']`` is None and saved as a PNG otherwise.

    Returns
    -------
    int
        Always 0.
    """
    # creates an overlaid plot of how the sources
    # vary for each of the channels against var_str
    title = "Plot of "
    for source in sources:
        title = add_key(title, sources, source)
    title = title + " for "
    for key in m_keys:
        title = add_key(title, m_keys, key)

    title = title + "-channels over " + gen_pretty_name(var_str)
    freq_MHz = modes['freq'][0] / 1e6

    title = title + "\nat a Frequency of {:7.3f} MHz".format(freq_MHz)

    if modes['verbose'] >= 2:
        print(title)

    if modes['colour'] in ["matching", "matching_dark"] and len(m_keys) == 1:
        text_colour = colour_models(m_keys[0])
    elif modes['colour'] in ["dark", "matching_dark"]:
        text_colour = "white"
    else:  # light or None
        text_colour = "black"

    mpl.rcParams.update({
        'text.color': text_colour,
        'axes.labelcolor': text_colour,
        'xtick.color': text_colour,
        'ytick.color': text_colour
    })

    mpl.rc('axes', edgecolor=text_colour)
    fig, ax = plt.subplots()
    if modes['dpi'] is not None:
        fig.set_dpi(modes['dpi'])

    if modes['image_size'] is not None:
        fig.set_size_inches(modes['image_size'])

    if modes['colour'] in ["dark", "matching_dark"]:
        ax.set_facecolor('black')
        fig.patch.set_facecolor('black')

    # the x values are shared by every line
    x_vals = plottable(merge_df, var_str)
    for key in m_keys:
        for source in sources:
            sep = get_source_separator(source)
            col = key + sep + source

            var_y_vals = plottable(merge_df, col)

            # values are shown as percentages if requested.
            if 'percent' in modes['scale']:
                var_y_vals = var_y_vals * 100

            ax.plot(x_vals, var_y_vals, label=col, color=colour_models(col))

    ax.set_title(title, wrap=True)

    ax.legend(frameon=False)
    # plots the axis labels rotated so they're legible
    ax.tick_params(labelrotation=45)
    ax.set_xlabel(gen_pretty_name(var_str, units=True))

    # sets the y axis scale to logarithmic if requested.
    if 'log' in modes['scale']:
        ax.set_yscale('log')

    # sets the y axis scale to percentage if requested.
    if 'percent' in modes['scale']:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    # prints or saves the plot
    if modes['out_dir'] is None:
        plt.show()
    else:
        str_sources = channel_maker(sources, modes)
        # NOTE(review): out_type is hard-coded to "png" here while the other
        # plots use modes['image_type'] -- confirm this is intentional.
        plt_file = prep_out_file(modes,
                                 plot="vals",
                                 dims="1d",
                                 ind_var=var_str,
                                 channel=list_to_string(m_keys, '_'),
                                 source=str_sources,
                                 freq=min(merge_df.Freq),
                                 out_type="png")
        if modes['verbose'] >= 2:
            print("Saving: " + plt_file)
        try:
            plt.savefig(plt_file,
                        bbox_inches='tight',
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
        except ValueError:
            if modes['verbose'] >= 1:
                print("ERROR: Unable to save file, showing instead.")
            try:
                plt.show()
            except Exception:
                print("Unable to show file.")
        plt.close()
    return 0
def plot_3d_graph(merge_df, key, modes, source, var_x, var_y):
    """
    Generate a 3-d plot of one channel of one source against two variables.

    Column ``key + sep + source`` of ``merge_df`` (``sep`` from
    get_source_separator) is drawn against ``var_x`` and ``var_y``.
    ``modes["three_d"]`` selects "colour"/"color" (tripcolor) or "contour".
    A colour bar is added when plotting succeeds.

    ``modes["scale"]`` may contain "percent" (values x100, percent
    colour-bar labels) and, for colour plots, "log" (log or symmetric-log
    colour normalisation).  The figure is shown when ``modes['out_dir']`` is
    None and saved otherwise.
    """
    sep = get_source_separator(source)

    if modes['verbose'] >= 2:
        print("Generating a 3-d plot of " + gen_pretty_name(source) + " for " +
              key)

    if modes['colour'] in ["matching", "matching_dark"]:
        text_colour = colour_models(key)
    elif modes['colour'] == "dark":
        text_colour = "white"
    else:  # light or None
        text_colour = "black"

    mpl.rcParams.update({
        'text.color': text_colour,
        'axes.labelcolor': text_colour,
        'xtick.color': text_colour,
        'ytick.color': text_colour
    })

    mpl.rc('axes', edgecolor=text_colour)
    fig, ax = plt.subplots()
    if modes['dpi'] is not None:
        fig.set_dpi(modes['dpi'])

    if modes['image_size'] is not None:
        fig.set_size_inches(modes['image_size'])

    if modes['colour'] in ["dark", "matching_dark"]:
        ax.set_facecolor('black')
        fig.set_facecolor('black')

    graph_title = "\n".join([
        modes['title'],
        ("Plot of the " + gen_pretty_name(source) + " for " + key +
         "-channel \nover " + gen_pretty_name(var_x) + " and " +
         gen_pretty_name(var_y) + ".")
    ])
    plt.title(graph_title, wrap=True)

    var_z = (key + sep + source)
    percent = "percent" in modes["scale"]

    # plots the channel in a colour based on its name
    colours = plt.get_cmap(colour_models(key + '_s'))

    # artist used for the colour bar; stays None if plotting fails
    mappable = None
    if modes["three_d"] in ["colour", "color"]:
        try:
            x_vals = plottable(merge_df, var_x)
            y_vals = plottable(merge_df, var_y)
            z_vals = plottable(merge_df, var_z)
            if percent:
                z_vals = z_vals * 100

            if "log" in modes["scale"]:
                maxz = np.max(z_vals)
                minz = np.min(z_vals)

                # if the values go below zero, then plot with symmetric log,
                # otherwise use log plotting
                if minz <= 0:
                    log_lim = 10
                    linthresh = max([abs(maxz), abs(minz)]) / log_lim
                    norm = SymLogNorm(linthresh,
                                      linscale=1.0,
                                      vmin=minz,
                                      vmax=maxz,
                                      clip=False)
                else:
                    norm = LogNorm()

                mappable = plt.tripcolor(x_vals,
                                         y_vals,
                                         z_vals,
                                         cmap=colours,
                                         norm=norm)
            else:
                mappable = plt.tripcolor(x_vals, y_vals, z_vals, cmap=colours)
        except RuntimeError:
            if modes['verbose'] >= 1:
                print(
                    "ERROR: Data not suitable for 3d colour plot.  Possible alternatives: contour/animated plots"
                )
    elif modes["three_d"] in ["contour"]:
        try:
            cols = np.unique(merge_df[var_y]).shape[0]
            x_vals = np.array(merge_df[var_x]).reshape(-1, cols)
            y_vals = np.array(merge_df[var_y]).reshape(-1, cols)
            z_vals = np.array(merge_df[var_z]).reshape(-1, cols)
            # percentage scaling only when requested, as in the colour branch
            if percent:
                z_vals = z_vals * 100

            mappable = plt.contour(x_vals, y_vals, z_vals, cmap=colours)
        except Exception:
            if modes['verbose'] >= 1:
                print(
                    "ERROR: Data not suitable for 3d contour plot.  Possible alternatives: colour/animated plots"
                )

    plt.legend(frameon=False)

    if var_x in ['d_Time']:
        # plots x-label using start time
        plt.xlabel(gen_pretty_name(var_x, units=True) + "\nStart Time: " +
                   str(min(merge_df.Time)),
                   wrap=True)
    else:
        plt.xlabel(gen_pretty_name(var_x, units=True), wrap=True)

    plt.ylabel(gen_pretty_name(var_y, units=True), wrap=True)

    # a colour bar needs a drawn artist; skip it if plotting failed
    if mappable is not None:
        if percent:
            plt.colorbar(mappable, format='%.3g%%')
        else:
            plt.colorbar(mappable)

    plt.tight_layout()

    # prints or saves the plot
    if modes['out_dir'] is None:
        plt.show()
    else:
        plt_file = prep_out_file(modes,
                                 source=source,
                                 plot="vals",
                                 dims="nd",
                                 channel=key,
                                 out_type=modes['image_type'])
        if modes['verbose'] >= 2:
            print("plotting: " + plt_file)
        try:
            plt.savefig(plt_file,
                        bbox_inches='tight',
                        facecolor=fig.get_facecolor(),
                        edgecolor='none')
        except ValueError:
            if modes['verbose'] >= 1:
                print("ERROR: Unable to save file, showing instead.")
            try:
                plt.show()
            except Exception:
                print("Unable to show file.")

        plt.close()
def animated_plot(merge_df,
                  modes,
                  var_x,
                  var_ys,
                  var_t,
                  sources,
                  time_delay=20):
    """
    Produces an animated linegraph(s) with the X, Y and T variables specified.

    One line per (var_y, source) pair shows column ``var_y + sep + source``
    against ``var_x``.  Each animation frame is one sorted unique value of
    ``var_t`` and is drawn by ``update_a``.  The y limits are fixed up front
    from the data of all frames, using the same ``plottable`` transform and
    percent scaling as the plotted lines.

    The animation is shown (looping) when ``modes['out_dir']`` is None;
    otherwise it is saved once with the 'pillow' writer.

    Parameters
    ----------
    time_delay : int
        Delay between frames in milliseconds (FuncAnimation ``interval``).
    """
    # Each FuncAnimation is kept in a module-level list so it stays
    # referenced; an unreferenced animation is garbage-collected and stops.
    global anim

    if modes['colour'] in ["matching", "matching_dark"] and len(var_ys) == 1:
        text_colour = colour_models(var_ys[0])
    elif modes['colour'] in ["dark", "matching_dark"]:
        text_colour = "white"
    else:  # light or None or matching and multiple y
        text_colour = "black"

    mpl.rcParams.update({
        'text.color': text_colour,
        'axes.labelcolor': text_colour,
        'xtick.color': text_colour,
        'ytick.color': text_colour
    })

    mpl.rc('axes', edgecolor=text_colour)

    fig, ax = plt.subplots()

    if modes['dpi'] is not None:
        fig.set_dpi(modes['dpi'])

    if modes['image_size'] is not None:
        fig.set_size_inches(modes['image_size'])

    if modes['colour'] in ["dark", "matching_dark"]:
        ax.set_facecolor('black')
        fig.patch.set_facecolor('black')

    # hard coded for now, need to parameterise
    percentile_gap = 0  # 5
    multiplier = 1  # 1.5

    percent = 'percent' in modes['scale']
    log_scale = 'log' in modes['scale']

    # the first frame is drawn here; update_a draws the rest
    var_t_vals = np.sort(merge_df[var_t].unique())
    var_t_val = var_t_vals[0]

    if var_t == "Time":
        var_t_string = str(var_t_val).rstrip('0').rstrip('.')
    elif var_t == "Freq":
        freq_MHz = var_t_val / 1e6
        var_t_string = "{:7.3f} MHz".format(freq_MHz)
    else:
        var_t_string = ("%.4f" % var_t_val).rstrip('0').rstrip('.')

    title = "Plot of "
    for source in sources:
        title = add_key(title, sources, source)
    title = title + " for "
    for var_y in var_ys:
        title = add_key(title, var_ys, var_y)

    title = (title + "-channels over " + gen_pretty_name(var_x) + " at\n" +
             gen_pretty_name(var_t))

    if modes['verbose'] >= 2:
        print("Generating an Animated " + title)

    title = "\n".join([modes["title"], title + " of " + var_t_string])

    ax.set_title(title, wrap=True)

    # rows of the first frame, shared by the x values and every line
    frame_df = merge_df.loc[merge_df[var_t] == var_t_val].reset_index(
        drop=True)
    var_x_vals = plottable(frame_df, var_x)

    # lines are created var_y-major, source-minor; update_a relies on it
    lines = []
    local_mins = []
    local_maxes = []

    for var_y in var_ys:
        for source in sources:
            sep = get_source_separator(source)
            col = var_y + sep + source

            var_y_vals = plottable(frame_df, col)
            # every frame's values, transformed and scaled exactly like the
            # plotted ones so the fixed y limits match what is drawn
            var_y_vals_all = plottable(merge_df, col)

            if percent:
                var_y_vals = var_y_vals * 100
                var_y_vals_all = var_y_vals_all * 100

            line, = ax.plot(var_x_vals,
                            var_y_vals,
                            label=col,
                            color=colour_models(col))
            lines.append(line)

            local_mins.append(
                np.percentile(var_y_vals_all, percentile_gap) * multiplier)
            local_maxes.append(
                np.percentile(var_y_vals_all, 100 - percentile_gap) *
                multiplier)

    # sets the y axis scale to logarithmic if requested.
    if log_scale:
        ax.set_yscale('log')

    # sets the y axis scale to percentage if requested.
    if percent:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    # the smallest positive float keeps max_y distinct from a zero min_y
    max_y = max([np.nextafter(0, 1)] + local_maxes)
    if log_scale:
        # a log axis cannot start at or below zero: use the smallest positive
        # data minimum, or leave the bottom to autoscaling if there is none
        positive_mins = [m for m in local_mins if m > 0]
        min_y = min(positive_mins) if positive_mins else None
    else:
        # non-log axes always include zero
        min_y = min([0] + local_mins)

    ax.set_ylim(min_y, max_y)

    ax.set_xlabel(gen_pretty_name(var_x, units=True), wrap=True)
    ax.set_ylabel(channel_maker(var_ys, modes, ", ") +
                  " flux\n(arbitrary units)",
                  wrap=True)

    ax.legend(frameon=False)

    # loop forever when shown interactively, play once when saving
    repeat_option = modes['out_dir'] is None

    if "anim" not in globals():
        anim = []

    anim.append(
        FuncAnimation(fig,
                      update_a,
                      frames=range(len(var_t_vals)),
                      interval=time_delay,
                      fargs=(merge_df, modes, var_x, var_ys, var_t, sources,
                             lines, ax),
                      repeat=repeat_option))

    ax.set_aspect('auto')

    plt.subplots_adjust(top=0.80)  # TODO: automate this so it's not fixed
    if modes['out_dir'] is not None:
        str_channel = channel_maker(var_ys, modes)
        str_sources = list_to_string(sources, "_")
        plot_name = var_x + "_over_" + var_t
        plt_file = prep_out_file(modes,
                                 source=str_sources,
                                 plot=plot_name,
                                 dims="nd",
                                 channel=str_channel,
                                 out_type=modes['image_type'])

        try:
            anim[-1].save(plt_file,
                          dpi=80,
                          writer='pillow',
                          savefig_kwargs={
                              'facecolor': fig.get_facecolor(),
                              'edgecolor': 'none'
                          })
        except ValueError:
            if modes['verbose'] >= 1:
                print("ERROR: Unable to save file, try showing instead.")
            try:
                plt.show()
            except Exception:
                print("ERROR: Unable to show file.")

        # plt.close() # TODO: fix this so it works
    else:
        plt.show()  # will just loop the animation forever.