Пример #1
0
def get_box_threshold_vectors(graphdata):
    """Extract significance threshold matrices for every box and source variable.

    For each entry of ``graphdata.boxindex`` and each entry of
    ``graphdata.sourcevar``, a file path is built from the module-level
    ``filename_sig_template`` (formatted with case, scenario, first method,
    box and source variable). That file is read with
    ``data_processing.read_header_values_datafile``, and its headers are
    discarded.

    Returns:
        A list of lists. The outer list has one entry per box, in
        ``graphdata.boxindex`` order. Each inner list has one value matrix
        per source variable, in ``graphdata.sourcevar`` order.
    """
    valuematrices = []
    # Outer loop: one list of results per box.
    for box in graphdata.boxindex:
        sourcevalues = []
        # Inner loop: one threshold matrix per source variable for this box.
        for sourcevar in graphdata.sourcevar:
            sourcefile = filename_sig_template.format(
                graphdata.case,
                graphdata.scenario,
                graphdata.method[0],
                box,
                sourcevar,
            )
            valuematrix, _ = data_processing.read_header_values_datafile(
                sourcefile)
            sourcevalues.append(valuematrix)
        valuematrices.append(sourcevalues)

    return valuematrices
Пример #2
0
def get_scenario_data_vectors(graphdata, example_sourcefile, example_scenario):
    """Read the same data file for every scenario in ``graphdata.scenarios``.

    ``example_sourcefile`` is the file path for ``example_scenario``, the
    scenario the graph was requested for; the graph is drawn once per
    scenario. For each scenario, the scenario component of that path is
    swapped using ``data_processing.change_dirtype`` and the resulting file
    is read.

    Returns:
        A list of value matrices (headers discarded), one per scenario, in
        the order of ``graphdata.scenarios``.
    """
    # TODO: Find a more elegant solution than rewriting the example path.
    def _load_for(target_scenario):
        # Re-point the example path at this scenario's directory, then load it.
        scenario_path = data_processing.change_dirtype(
            example_sourcefile, example_scenario, target_scenario)
        matrix, _headers = data_processing.read_header_values_datafile(
            scenario_path)
        return matrix

    return [_load_for(target) for target in graphdata.scenarios]
Пример #3
0
def fig_fft(graphdata, graph, scenario, savedir):
    """Plot FFT values over the frequency range for the configured variables.

    Reads ``<saveloc>/fftdata/<case>_<scenario>_fft.csv`` and uses its first
    column as the x-axis. It draws one line for each name in
    ``graphdata.plotvars``, finding each column by header name. The figure
    is written to ``<savedir>/<scenario>_fft.pdf``.

    Returns:
        None.
    """
    # Populate the graph-specific settings on graphdata, in this order.
    for loader in (graphdata.get_legendbbox,
                   graphdata.get_frequencyunit,
                   graphdata.get_plotvars):
        loader(graph)

    csv_name = "{}_{}_fft.csv".format(graphdata.case, scenario)
    fft_path = os.path.join(graphdata.saveloc, "fftdata", csv_name)
    matrix, column_names = data_processing.read_header_values_datafile(
        fft_path)

    plt.figure(1, (12, 6))

    for plotvar in graphdata.plotvars:
        column = column_names.index(plotvar)
        legend_text = r"${}$".format(column_names[column])
        plt.plot(matrix[:, 0], matrix[:, column], "-", label=legend_text)

    plt.ylabel("Normalised value", fontsize=14)
    plt.xlabel(r"Frequency ({})".format(graphdata.frequencyunit), fontsize=14)
    plt.legend(bbox_to_anchor=graphdata.legendbbox)

    # Apply explicit axis limits unless axis_limits is the False sentinel.
    limits = graphdata.axis_limits
    if limits is not False:
        plt.axis(limits)

    pdf_path = os.path.join(savedir, "{}_fft.pdf".format(scenario))
    plt.savefig(pdf_path)
    plt.close()

    return None
Пример #4
0
def fig_maxval_variables(graphdata, graph, scenario, savedir):
    """Plot a per-scenario value series against ``graphdata.xvals``.

    For every entry of ``graphdata.scenario``, a data file is located via the
    module-level ``filename_template`` and read. If ``delays`` is truthy, the
    maximum of every column except the first is plotted. Otherwise
    ``valuematrix[1:]`` is plotted directly. If ``drawfit`` is truthy, a
    straight-line fit in log-log space (a power law) is also drawn for each
    scenario. One line is drawn per scenario.

    NOTE(review): ``delays``, ``drawfit`` and ``graphname`` are neither
    parameters nor locals of this function. Unless they are module-level
    globals defined elsewhere, this raises NameError; ``delays`` is already
    reached on the first loop iteration. ``graphname`` looks like a leftover
    from an older signature, and ``graph`` is presumably meant. Confirm
    before changing.

    Returns:
        None.
    """

    # Get values for x-axis
    graphdata.get_xvalues(graph)
    graphdata.get_linelabels(graph)
    graphdata.get_legendbbox(graph)

    # Get back from savedir to trends source
    # This is up to the embed type level
    trendsdir = data_processing.change_dirtype(savedir, "graphs", "trends")

    # Extract current method and sigstatus from trendsdir
    dirparts = data_processing.getfolders(trendsdir)
    # The method is the third-to-last folder of the path
    method = dirparts[-3]
    # The sigstatus is the second-to-last folder of the path
    sigstatus = dirparts[-2]
    # NOTE(review): method and sigstatus are not used below; the filename
    # is built from graphdata.sigstatus instead.

    plt.figure(1, (12, 6))

    # NOTE(review): the loop variable shadows the ``scenario`` parameter, and
    # it iterates ``graphdata.scenario`` (singular). get_scenario_data_vectors
    # iterates ``graphdata.scenarios``, so confirm which attribute holds the list.
    for count, scenario in enumerate(graphdata.scenario):

        sourcefile = filename_template.format(
            graphdata.case,
            scenario,
            graphdata.method[0],
            graphdata.sigstatus,
            graphdata.boxindex,
            graphdata.sourcevar,
        )

        if delays:
            valuematrix, headers = data_processing.read_header_values_datafile(
                sourcefile)
            # Maximum of each column after the first (column 0 is skipped).
            max_values = [
                max(valuematrix[:, index + 1])
                for index in range(valuematrix.shape[1] - 1)
            ]
        else:
            valuematrix, headers = data_processing.read_header_values_datafile(
                sourcefile)
            # NOTE(review): this drops the first *row*, not the first column.
            # Presumably the file holds a single series in this mode; confirm.
            max_values = valuematrix[1:]

        plt.plot(
            graphdata.xvals,
            max_values,
            "--",
            marker="o",
            markersize=4,
            label=graphdata.linelabels[count],
        )

        if drawfit:
            # NOTE(review): fitlinelabels is called here as a method but
            # indexed as a list below. Other labels follow the pattern
            # get_linelabels(graph) then .linelabels, so presumably
            # get_fitlinelabels(graph) is meant; confirm.
            graphdata.fitlinelabels(graphname)
            # Fit log(y) = a*log(x) + b, i.e. the power law y = e**b * x**a.
            fit_params = np.polyfit(np.log(graphdata.xvals),
                                    np.log(max_values), 1)
            fit_y = [(i * fit_params[0] + fit_params[1])
                     for i in np.log(graphdata.xvals)]
            fitted_vals = [np.exp(val) for val in fit_y]

            plt.loglog(
                graphdata.xvals,
                fitted_vals,
                "--",
                label=graphdata.fitlinelabels[count],
            )

    plt.ylabel(yaxislabel[graphdata.method[0]], fontsize=14)
    plt.xlabel(r"Time constant ($\tau$)", fontsize=14)
    plt.legend(bbox_to_anchor=graphdata.legendbbox)

    # Apply explicit axis limits unless axis_limits is the False sentinel.
    if graphdata.axis_limits is not False:
        plt.axis(graphdata.axis_limits)

    # NOTE(review): unlike the other fig_* functions, the output path does
    # not use savedir.
    plt.savefig(graph_filename_template.format(graphname))
    plt.close()

    return None
Пример #5
0
def _plot_trend_vs_boxes(graphdata, trendsdir, savedir, scenario, typename,
                         sourcevar, ylabel):
    """Draw one trend file against box number and save it as a PDF.

    Reads ``<trendsdir>/<sourcevar>/<typename>.csv`` and plots one line per
    entry of ``graphdata.destvars``. The figure is saved to
    ``<savedir>/<scenario>_<typename>_<sourcevar>.pdf`` and then closed.
    """
    fig = plt.figure(1, figsize=(12, 6))
    ax = fig.add_subplot(111)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_xlabel(r"Box", fontsize=14)

    sourcefile = os.path.join(trendsdir, sourcevar, "{}.csv".format(typename))
    valuematrix, _ = data_processing.read_header_values_datafile(sourcefile)

    # Each row of the trend file is plotted at its own box position.
    boxes = np.arange(len(valuematrix[:, 0]))
    # NOTE(review): columns are chosen by position in graphdata.destvars, not
    # by header name. This assumes the trend file's columns follow the
    # destvars order; confirm against the trend extraction output.
    for destvarindex, destvar in enumerate(graphdata.destvars):
        ax.plot(
            boxes,
            valuematrix[:, destvarindex],
            marker="o",
            markersize=4,
            label=destvar,
        )

    # Shrink the axes width by 20%, presumably to make room for the legend
    # placed via legendbbox.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    ax.legend(loc="center left", bbox_to_anchor=graphdata.legendbbox)

    # Explicitly configured limits win. Otherwise only pin the lower y bound
    # slightly below zero, matching the delay plots.
    if graphdata.axis_limits is not False:
        ax.axis(graphdata.axis_limits)
    else:
        ax.set_ylim(bottom=-0.05)

    plt.savefig(
        os.path.join(
            savedir,
            "{}_{}_{}.pdf".format(scenario, typename, sourcevar),
        ))
    plt.close()


def fig_values_vs_boxes(graphdata, graph, scenario, savedir):
    """Plot measure values for different boxes and multiple variable pairs.

    Uses the trend data generated by trend extraction. The trend directory
    is found by mapping ``savedir`` from the ``graphs`` tree to the
    ``trends`` tree. The method and significance status are taken from the
    third-to-last and second-to-last folders of that path. One figure is
    saved per (trend type, source variable) for the weight trends, and then
    one per (delay trend type, source variable).

    Returns:
        None.
    """
    graphdata.get_legendbbox(graph)
    graphdata.get_timeunit(graph)
    graphdata.get_sourcevars(graph)
    graphdata.get_destvars(graph)

    # Get back from savedir to the trends source (up to the embed type level).
    trendsdir = data_processing.change_dirtype(savedir, "graphs", "trends")

    dirparts = data_processing.getfolders(trendsdir)
    method = dirparts[-3]
    sigstatus = dirparts[-2]

    # Select trend file names based on method and sigstatus.
    if method.startswith("transfer_entropy"):
        typenames = [
            "weight_absolute_trend",
            "signtested_weight_directional_trend",
        ]
        delay_typenames = ["delay_absolute_trend", "delay_directional_trend"]

        if sigstatus == "sigtested":
            typenames.append("sigweight_absolute_trend")
            typenames.append("signtested_sigweight_directional_trend")

    else:
        typenames = ["weight_trend"]
        delay_typenames = ["delay_trend"]

        if sigstatus == "sigtested":
            typenames.append("sigweight_trend")

    # Maps the long (transfer entropy) trend names to the absolute/directional
    # variant used as a prefix of the y-axis label key.
    yaxislabel_lookup = {
        "weight_absolute_trend": "absolute",
        "signtested_weight_directional_trend": "directional",
        "delay_absolute_trend": "absolute",
        "delay_directional_trend": "directional",
        "sigweight_absolute_trend": "absolute",
        "signtested_sigweight_directional_trend": "directional",
    }

    for typename in typenames:
        # Names longer than 15 characters are the transfer-entropy variants.
        # The shorter "weight_trend" / "sigweight_trend" use the plain method.
        if len(typename) > 15:
            yaxislabelstring = yaxislabel_lookup[typename] + "_" + method
        else:
            yaxislabelstring = method
        ylabel = yaxislabel[yaxislabelstring]

        for sourcevar in graphdata.sourcevars:
            _plot_trend_vs_boxes(graphdata, trendsdir, savedir, scenario,
                                 typename, sourcevar, ylabel)

    delay_ylabel = r"Delay ({})".format(graphdata.timeunit)
    for delay_typename in delay_typenames:
        for sourcevar in graphdata.sourcevars:
            _plot_trend_vs_boxes(graphdata, trendsdir, savedir, scenario,
                                 delay_typename, sourcevar, delay_ylabel)

    return None
Пример #6
0
def fig_diffscen_vs_delay(graphdata, graph, scenario, savedir):
    """Plot one variable's weights against delay for several scenarios.

    For each weight type, box index and source variable, the weight file of
    the current ``scenario`` is located in the ``weightdata`` tree that
    mirrors ``savedir``. The matching files for all scenarios are loaded with
    ``plotter.get_scenario_data_vectors`` and drawn on one figure, one line
    per (scenario, destination variable). Each line is annotated at its
    first maximum with the scenario's label. Figures are saved as
    ``<scenario>_<typename>_box<NNN>_<sourcevar>.pdf`` in ``savedir``.

    Assumes a single destination variable. With several, every line of a
    scenario carries the same label.

    Returns:
        None.
    """
    plt.close("all")

    graphdata.get_legendbbox(graph)
    graphdata.get_timeunit(graph)
    graphdata.get_boxindexes(graph)
    graphdata.get_sourcevars(graph)
    graphdata.get_destvars(graph)
    graphdata.get_sigthresholdplotting(graph)
    graphdata.get_linelabels(graph)

    # Get back from savedir to the weightdata source (up to the embed type level).
    weightdir = data_processing.change_dirtype(savedir, "graphs", "weightdata")

    # The method name is the third-to-last folder of the weightdata path.
    method = data_processing.getfolders(weightdir)[-3]

    # Select weight file types based on method.
    if method.startswith("transfer_entropy"):
        graphdata.get_typenames(graph)
        typenames = []
        if "simple" in graphdata.typenames:
            typenames.append("weights_absolute")
        if "directional" in graphdata.typenames:
            typenames.append("weights_directional")
    else:
        typenames = ["weights"]

    # Labels are indexed by scenario below, so the fallback must be
    # per-scenario too. The scenario matrices are presumably returned in
    # graphdata.scenarios order, as the get_scenario_data_vectors in this
    # file does. (The old fallback to destvars raised IndexError whenever
    # there were more scenarios than destination variables.)
    if graphdata.linelabels:
        graphdata.get_labelformat(graph)
        labels = [
            graphdata.labelformat.format(linelabel)
            for linelabel in graphdata.linelabels
        ]
    else:
        labels = list(graphdata.scenarios)

    for typename in typenames:
        # "weights_absolute" -> "absolute_<method>"; plain "weights" -> "<method>".
        if len(typename) > 8:
            yaxislabelstring = typename[8:] + "_" + method
        else:
            yaxislabelstring = method

        for boxindex, sourcevar in itertools.product(graphdata.boxindexes,
                                                     graphdata.sourcevars):

            fig = plt.figure(1, figsize=(12, 6))
            ax = fig.add_subplot(111)
            ax.set_ylabel(yaxislabel[yaxislabelstring], fontsize=14)
            ax.set_xlabel(r"Delay ({})".format(graphdata.timeunit),
                          fontsize=14)

            # The current scenario's file provides the headers and serves as
            # the template path for the other scenarios.
            sourcefile = os.path.join(
                weightdir,
                typename,
                "box{:03d}".format(boxindex),
                "{}.csv".format(sourcevar),
            )
            _, headers = data_processing.read_header_values_datafile(
                sourcefile)
            destvar_columns = [(destvar, headers.index(destvar))
                               for destvar in graphdata.destvars]

            valuematrices = plotter.get_scenario_data_vectors(
                graphdata, sourcefile, scenario)

            bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.8)

            for scenarioindex, valuematrix in enumerate(valuematrices):
                delay_axis = valuematrix[:, 0]
                for destvar, column in destvar_columns:
                    values = valuematrix[:, column]
                    ax.plot(
                        delay_axis,
                        values,
                        marker=markers[scenarioindex],
                        markersize=8,
                        label=labels[scenarioindex],
                    )

                    # Annotate the curve at its first maximum.
                    peak = int(np.argmax(values))
                    ax.text(
                        delay_axis[peak],
                        values[peak],
                        labels[scenarioindex],
                        ha="center",
                        va="center",
                        size=10,
                        bbox=bbox_props,
                    )

            if graphdata.axis_limits is not False:
                ax.axis(graphdata.axis_limits)
            else:
                ax.set_ylim(bottom=-0.05)

            plt.savefig(
                os.path.join(
                    savedir,
                    "{}_{}_box{:03d}_{}.pdf".format(scenario, typename,
                                                    boxindex, sourcevar),
                ),
                bbox_inches="tight",
                pad_inches=0,
            )

            plt.close("all")

    return None
Пример #7
0
def fig_values_vs_delays(graphdata, graph, scenario, savedir):
    """Generate figures showing how method values depend on delay.

    Each figure covers a single scenario, box and source variable, with one
    line per destination variable. The function iterates over the weight
    types selected for the method (absolute and/or directional for transfer
    entropy), and over every combination of ``graphdata.boxindexes`` and
    ``graphdata.sourcevars``. If ``graphdata.thresholdplotting`` is set, the
    matching significance-threshold curves are drawn as well. Each figure is
    saved in ``savedir`` as both PDF and SVG (the SVG allows manual editing).

    Returns:
        None.
    """
    plt.close("all")

    graphdata.get_legendbbox(graph)
    graphdata.get_timeunit(graph)
    graphdata.get_boxindexes(graph)
    graphdata.get_sourcevars(graph)
    graphdata.get_destvars(graph)
    graphdata.get_sigthresholdplotting(graph)
    graphdata.get_linelabels(graph)

    # Get back from savedir to the weightdata source (up to the embed type level).
    weightdir = data_processing.change_dirtype(savedir, "graphs", "weightdata")

    # The method name is the third-to-last folder of the weightdata path.
    method = data_processing.getfolders(weightdir)[-3]

    # Weight file types and their significance-threshold counterparts,
    # index-aligned so thresh_typenames[i] belongs to typenames[i].
    if method.startswith("transfer_entropy"):
        graphdata.get_typenames(graph)
        typenames = []
        thresh_typenames = []
        if "simple" in graphdata.typenames:
            typenames.append("weights_absolute")
            thresh_typenames.append("sigthresh_absolute")
        if "directional" in graphdata.typenames:
            typenames.append("weights_directional")
            thresh_typenames.append("sigthresh_directional")
    else:
        typenames = ["weights"]
        thresh_typenames = ["sigthresh"]

    # One label per destination variable.
    if graphdata.linelabels:
        graphdata.get_labelformat(graph)
        labels = [
            str(graphdata.labelformat).format(linelabel)
            for linelabel in graphdata.linelabels
        ]
    else:
        labels = list(graphdata.destvars)

    for typename, thresh_typename in zip(typenames, thresh_typenames):
        # "weights_absolute" -> "absolute_<method>"; plain "weights" -> "<method>".
        if len(typename) > 8:
            yaxislabelstring = typename[8:] + "_" + method
        else:
            yaxislabelstring = method

        for boxindex, sourcevar in itertools.product(graphdata.boxindexes,
                                                     graphdata.sourcevars):

            fig = plt.figure(1, figsize=(12, 6))
            ax = fig.add_subplot(111)
            ax.set_ylabel(yaxislabel[yaxislabelstring], fontsize=14)
            ax.set_xlabel(r"Delay ({})".format(graphdata.timeunit),
                          fontsize=14)

            boxdir = "box{:03d}".format(boxindex)
            csvname = "{}.csv".format(sourcevar)

            sourcefile = os.path.join(weightdir, typename, boxdir, csvname)
            valuematrix, headers = data_processing.read_header_values_datafile(
                sourcefile)

            if graphdata.thresholdplotting:
                threshold_sourcefile = os.path.join(weightdir, thresh_typename,
                                                    boxdir, csvname)
                # Keep the threshold file's headers separate. Overwriting
                # `headers` would make the value columns be looked up using
                # the threshold file's header order.
                (
                    threshmatrix,
                    threshheaders,
                ) = data_processing.read_header_values_datafile(
                    threshold_sourcefile)

            bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.8)

            delay_axis = valuematrix[:, 0]
            for destvarindex, destvar in enumerate(graphdata.destvars):
                values = valuematrix[:, headers.index(destvar)]
                ax.plot(
                    delay_axis,
                    values,
                    marker=markers[destvarindex],
                    markersize=8,
                    label=labels[destvarindex],
                )

                # Annotate the curve at its first maximum.
                peak = int(np.argmax(values))
                ax.text(
                    delay_axis[peak],
                    values[peak],
                    labels[destvarindex],
                    ha="center",
                    va="center",
                    size=10,
                    bbox=bbox_props,
                )

                if graphdata.thresholdplotting:
                    threshcolumn = threshheaders.index(destvar)
                    ax.plot(
                        threshmatrix[:, 0],
                        threshmatrix[:, threshcolumn],
                        marker="x",
                        markersize=4,
                        linestyle=":",
                        label=destvar + " threshold",
                    )

            if graphdata.axis_limits is not False:
                ax.axis(graphdata.axis_limits)
            else:
                ax.set_ylim(bottom=-0.05)

            plt.savefig(
                os.path.join(
                    savedir,
                    "{}_{}_box{:03d}_{}.pdf".format(scenario, typename,
                                                    boxindex, sourcevar),
                ),
                bbox_inches="tight",
                pad_inches=0,
            )

            # Also save as SVG to allow manual editing.
            plt.savefig(
                os.path.join(
                    savedir,
                    "{}_{}_box{:03d}_{}.svg".format(scenario, typename,
                                                    boxindex, sourcevar),
                ),
                bbox_inches="tight",
                pad_inches=0,
                format="svg",
            )
            plt.close()

    return None