示例#1
0
def plot_series(series, x_label, y_label, sort=False, scale="linear",
                bar=False, colour=None, name=None):
    """Draw a one-dimensional series as a line plot or as a bar chart.

    Args:
        series: Array-like with a ``shape`` attribute; its first dimension
            gives the number of plotted points.
        x_label: Label for the x-axis (capitalised before use).
        y_label: Label for the y-axis (capitalised before use).
        sort: If true, plot the values in descending order and mark both
            the x-axis label and the figure name as sorted.
        scale: Scale of the y-axis; ``"log"`` switches bars to log mode.
        bar: If true, draw bars instead of a line.
        colour: Plot colour; defaults to the first standard palette colour.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("series", name)
    colour = colour or style.STANDARD_PALETTE[0]

    n_points = series.shape[0]
    positions = numpy.linspace(0, n_points, n_points)
    use_log_bars = scale == "log"

    if sort:
        # Ascending sort reversed, i.e. largest value first
        series = numpy.sort(series)[::-1]
        x_label = "sorted " + x_label
        figure_name += "-sorted"

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if not bar:
        axis.plot(positions, series, color=colour)
        axis.set_yscale(scale)
    else:
        axis.bar(positions, series, log=use_log_bars, color=colour,
                 alpha=0.4)

    for set_label, text in ((axis.set_xlabel, x_label),
                            (axis.set_ylabel, y_label)):
        set_label(capitalise_string(text))

    return figure, figure_name
示例#2
0
def plot_accuracy_evolution(accuracies, name=None):
    """Plot accuracy per epoch for each kind of data subset.

    Args:
        accuracies: Mapping from subset kind to a sequence of per-epoch
            accuracies. ``"training"`` is drawn solid and ``"validation"``
            dashed, each in its own palette colour; any other kind is drawn
            dotted in the next colour of the axis colour cycle. Kinds whose
            value is ``None`` are skipped.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("accuracies", name)
    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    # Line style and palette index for the kinds with a dedicated look; the
    # palette is only indexed for kinds that are actually plotted.
    styles_by_kind = {
        "training": ("solid", 0),
        "validation": ("dashed", 1),
    }

    for kind, kind_accuracies in sorted(accuracies.items()):
        if kind_accuracies is None:
            continue

        plot_options = {"label": "{} set".format(capitalise_string(kind))}
        if kind in styles_by_kind:
            line_style, palette_index = styles_by_kind[kind]
            plot_options["linestyle"] = line_style
            plot_options["color"] = style.STANDARD_PALETTE[palette_index]
        else:
            # Unknown kinds previously reused the style of the preceding
            # kind (or failed); give them a neutral look of their own.
            plot_options["linestyle"] = "dotted"

        # Convert so that plain lists are scaled instead of repeated
        kind_accuracies = numpy.asarray(kind_accuracies)
        epochs = numpy.arange(len(kind_accuracies)) + 1
        axis.plot(epochs, 100 * kind_accuracies, **plot_options)

    handles, labels = axis.get_legend_handles_labels()

    # Unpacking an empty zip fails, so only build a legend for drawn lines
    if handles:
        labels, handles = zip(
            *sorted(zip(labels, handles), key=lambda t: t[0]))
        axis.legend(handles, labels, loc="best")

    axis.set_xlabel("Epoch")
    axis.set_ylabel("Accuracies")

    return figure, figure_name
示例#3
0
def plot_correlation_matrix(correlation_matrix, axis_label=None, name=None):
    """Draw a correlation matrix as a square-celled heat map.

    The colour range is fixed to [-1, 1] and centred on 0, and tick labels
    are hidden on both axes.

    Args:
        correlation_matrix: Matrix passed to ``seaborn.heatmap``.
        axis_label: Optional label used for both the x- and the y-axis.
        name: Passed as the only argument to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name(name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    heat_map_options = dict(
        vmin=-1,
        vmax=1,
        center=0,
        xticklabels=False,
        yticklabels=False,
        cbar=True,
        cbar_kws={"label": "Pearson correlation coefficient"},
        square=True,
        ax=axis,
    )

    # Use the white seaborn style for the heat map only, then restore the
    # project's plot look.
    seaborn.set(style="white")
    seaborn.heatmap(correlation_matrix, **heat_map_options)
    style.reset_plot_look()

    if axis_label:
        for set_label in (axis.set_xlabel, axis.set_ylabel):
            set_label(axis_label)

    return figure, figure_name
示例#4
0
def plot_correlations(correlation_sets,
                      x_key,
                      y_key,
                      x_label=None,
                      y_label=None,
                      name=None):
    """Scatter one or more correlation sets against each other.

    Args:
        correlation_sets: Either a dict mapping a set name to a set, or a
            single set (which is then named ``"correlations"``). Each set is
            indexed with ``x_key`` and ``y_key``.
        x_key: Key of the values plotted along the x-axis.
        y_key: Key of the values plotted along the y-axis.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("correlations", name)

    if isinstance(correlation_sets, dict):
        named_sets = correlation_sets
    else:
        named_sets = {"correlations": correlation_sets}

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    for set_name in named_sets:
        points = named_sets[set_name]
        axis.scatter(points[x_key], points[y_key], label=set_name)

    # A legend is only informative when there are several sets
    if len(named_sets) > 1:
        axis.legend(loc="best")

    return figure, figure_name
示例#5
0
文件: images.py 项目: studymeow/scvae
def combine_images_from_data_set(data_set,
                                 indices=None,
                                 number_of_random_examples=None,
                                 name=None):
    """Tile example images from a data set into one grid image.

    Examples are picked with a fixed random seed, so repeated calls with the
    same arguments select the same examples.

    Args:
        data_set: Object providing ``values`` (dense or SciPy-sparse, indexed
            by example), ``feature_dimensions`` (unpacked as width and
            height) and ``number_of_examples``.
        indices: Optional indices of the examples to use. If
            ``number_of_random_examples`` is also given, a random subset of
            at most that many of these indices is used.
        number_of_random_examples: Optional number of examples to draw. When
            ``indices`` is not given, it defaults to
            ``DEFAULT_NUMBER_OF_RANDOM_EXAMPLES_FOR_COMBINED_IMAGES``.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the combined image (a 2-d array) and its name.

    Raises:
        ValueError: If no examples are selected.
    """
    random_state = numpy.random.RandomState(13)

    if indices is not None:
        if number_of_random_examples is not None:
            n_requested = min(len(indices), number_of_random_examples)
            indices = random_state.permutation(indices)[:n_requested]
    else:
        if number_of_random_examples is not None:
            n_requested = number_of_random_examples
        else:
            n_requested = DEFAULT_NUMBER_OF_RANDOM_EXAMPLES_FOR_COMBINED_IMAGES
        indices = random_state.permutation(
            data_set.number_of_examples)[:n_requested]

    # The data set can hold fewer examples than requested, so count what was
    # actually selected; the reshape below depends on the exact number.
    n_examples = len(indices)

    if n_examples == 0:
        raise ValueError("No examples to combine into an image.")

    if n_examples == 1:
        image_name = saving.build_figure_name("image_example", name)
    else:
        image_name = saving.build_figure_name("image_examples", name)

    width, height = data_set.feature_dimensions

    examples = data_set.values[indices]
    if scipy.sparse.issparse(examples):
        examples = examples.toarray()
    examples = examples.reshape(n_examples, width, height)

    # Grid that is as close to square as possible
    column = int(numpy.ceil(numpy.sqrt(n_examples)))
    row = int(numpy.ceil(n_examples / column))

    image = numpy.zeros((row * width, column * height))

    for m in range(n_examples):
        r, c = divmod(m, column)
        rows = slice(r * width, (r + 1) * width)
        columns = slice(c * height, (c + 1) * height)
        image[rows, columns] = examples[m]

    return image, image_name
示例#6
0
def plot_cutoff_count_histogram(series,
                                excess_zero_count=0,
                                cutoff=None,
                                normed=False,
                                scale="linear",
                                colour=None,
                                name=None):
    """Plot a histogram of counts where all counts from a cutoff share a bin.

    Bins 0 to ``cutoff - 1`` hold the number of entries equal to that value;
    the last bin holds the number of entries greater than or equal to
    ``cutoff``.

    Args:
        series: Array of counts supporting element-wise comparison.
        excess_zero_count: Number of additional zeros added to the first bin.
        cutoff: Non-negative integer cutoff. Although it defaults to
            ``None``, a value is required.
        normed: If true, normalise the bins to sum to one.
        scale: Scale of the y-axis; ``"log"`` draws the bars in log mode.
        colour: Bar colour; defaults to the first standard palette colour.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.

    Raises:
        ValueError: If ``cutoff`` is missing or negative.
    """
    if cutoff is None:
        raise ValueError("A cutoff is required for a cutoff count histogram.")
    if cutoff < 0:
        raise ValueError("The cutoff cannot be negative.")

    figure_name = "histogram"

    if normed:
        figure_name += "-normed"

    figure_name += "-counts"
    figure_name = saving.build_figure_name(figure_name, name)
    figure_name += "-cutoff-{}".format(cutoff)

    if not colour:
        colour = style.STANDARD_PALETTE[0]

    y_log = scale == "log"

    count_number = numpy.arange(cutoff + 1)
    # Array to count counts of a given count number
    count_number_count = numpy.zeros(cutoff + 1)

    for i, number in enumerate(count_number[:-1]):
        count_number_count[i] = (series == number).sum()
    # The last bin collects everything at or above the cutoff
    count_number_count[-1] = (series >= cutoff).sum()

    count_number_count[0] += excess_zero_count

    if normed:
        total_count = count_number_count.sum()
        # Leave an all-zero histogram as it is instead of dividing by zero
        if total_count > 0:
            count_number_count /= total_count

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.bar(count_number,
             count_number_count,
             log=y_log,
             color=colour,
             alpha=0.4)

    axis.set_xlabel("Count bins")

    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name
示例#7
0
def plot_probabilities(posterior_probabilities, prior_probabilities,
                       x_label=None, y_label=None, palette=None,
                       uniform=False, name=None):
    """Plot mixture-component probabilities as bars.

    If posterior probabilities are given, they are drawn as bars, and any
    prior probabilities are overlaid as dashed horizontal line segments with
    a legend entry. Otherwise the prior probabilities are drawn as bars.

    Args:
        posterior_probabilities: Sequence of posterior probabilities, or
            ``None``.
        prior_probabilities: Sequence of prior probabilities, or ``None``.
        x_label: Label for the x-axis; defaults to ``"$k$"``.
        y_label: Label for the y-axis; defaults to the symbol of the
            probabilities drawn as bars.
        palette: Sequence of bar colours, reused cyclically if shorter than
            the number of bars; defaults to the first standard palette
            colour for every bar.
        uniform: Accepted for compatibility; not used by this function.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.

    Raises:
        ValueError: If neither posterior nor prior probabilities are given.
    """
    figure_name = saving.build_figure_name("probabilities", name)

    if posterior_probabilities is None and prior_probabilities is None:
        raise ValueError("No posterior nor prior probabilities given.")

    n_posterior_centroids = (
        0 if posterior_probabilities is None
        else len(posterior_probabilities))
    n_prior_centroids = (
        0 if prior_probabilities is None else len(prior_probabilities))
    n_centroids = max(n_posterior_centroids, n_prior_centroids)

    if not x_label:
        x_label = "$k$"

    figure = pyplot.figure(figsize=(8, 6), dpi=80)
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if not palette:
        palette = [style.STANDARD_PALETTE[0]] * n_centroids

    if posterior_probabilities is not None:
        bar_probabilities = posterior_probabilities
        default_y_label = "$\\pi_{\\phi}^k$"
    else:
        bar_probabilities = prior_probabilities
        default_y_label = "$\\pi_{\\theta}^k$"

    for k in range(len(bar_probabilities)):
        axis.bar(k, bar_probabilities[k], color=palette[k % len(palette)])

    # Honour a caller-supplied label; it used to be silently overridden
    axis.set_ylabel(y_label if y_label else default_y_label)

    if posterior_probabilities is not None and prior_probabilities is not None:
        # Iterate over the priors themselves so that differing lengths of
        # the two sequences cannot index out of range.
        for k in range(n_prior_centroids):
            axis.plot(
                [k-0.4, k+0.4],
                2 * [prior_probabilities[k]],
                color="black",
                linestyle="dashed"
            )
        # Proxy artist so the legend shows one entry for all prior segments
        prior_line = matplotlib.lines.Line2D(
            [], [],
            color="black",
            linestyle="dashed",
            label="$\\pi_{\\theta}^k$"
        )
        axis.legend(handles=[prior_line], loc="best", fontsize=18)

    axis.set_xlabel(x_label)

    return figure, figure_name
示例#8
0
def plot_centroid_covariance_matrices_evolution(covariance_matrices,
                                                distribution,
                                                name=None):
    """Plot how the covariance determinant of each centroid evolves.

    The determinant of each matrix is computed as the product of its
    diagonal, which presumably relies on the covariance matrices being
    diagonal -- confirm against the caller.

    Args:
        covariance_matrices: Array whose shape unpacks as
            (epochs, centroids, rows, columns).
        distribution: Name of the distribution; normalised and used in the
            figure name and the y-axis label.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    distribution = normalise_string(distribution)
    figure_name = "centroids_evolution-{}-covariance_matrices".format(
        distribution)
    figure_name = saving.build_figure_name(figure_name, name)

    y_label = _axis_label_for_symbol(symbol="\\Sigma",
                                     distribution=distribution,
                                     prefix="|",
                                     suffix="(y = k)|")

    n_epochs, n_centroids, __, __ = covariance_matrices.shape

    # Product of the diagonal of every matrix, for all epochs and centroids
    determinants = numpy.diagonal(
        covariance_matrices, axis1=-2, axis2=-1).prod(axis=-1).astype(float)

    # A log scale needs strictly positive values; use it only when the
    # centroids differ widely in how much their determinants vary.
    y_scale = "linear"
    if determinants.size > 0 and (determinants > 0).all():
        line_range_ratio = determinants.max(axis=0) / determinants.min(axis=0)
        range_ratio = line_range_ratio.max() / line_range_ratio.min()
        if range_ratio > 1e2:
            y_scale = "log"

    centroids_palette = style.darker_palette(n_centroids)
    epochs = numpy.arange(n_epochs) + 1

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    for k in range(n_centroids):
        axis.plot(epochs,
                  determinants[:, k],
                  color=centroids_palette[k],
                  label="$k = {}$".format(k))

    axis.set_xlabel("Epochs")
    axis.set_ylabel(y_label)

    axis.set_yscale(y_scale)

    axis.legend(loc="best")

    return figure, figure_name
示例#9
0
def plot_variable_correlations(values,
                               variable_names=None,
                               colouring_data_set=None,
                               name="variable_correlations"):
    """Plot pairwise scatter plots of all variables in a grid.

    The examples are shuffled with a fixed seed before plotting. The subplot
    in row ``i`` and column ``j`` shows variable ``j`` along the x-axis and
    variable ``i`` along the y-axis, matching the axis labels.

    Args:
        values: 2-d array whose shape unpacks as (examples, features).
        variable_names: Optional sequence with one name per feature, used to
            label the bottom row and the first column; no labels are set
            when it is ``None``.
        colouring_data_set: Optional data set providing ``labels``,
            ``class_names``, ``number_of_classes``, ``class_palette`` and
            ``label_sorter``, used to colour the points by class. Without
            it, the neutral colour is used.
        name: Passed as the only argument to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name(name)
    n_examples, n_features = values.shape

    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    if colouring_data_set:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                class_name: index_palette[i]
                for i, class_name in enumerate(
                    sorted(class_names, key=label_sorter))
            }

        # Keep the labels aligned with the shuffled values
        labels = labels[shuffled_indices]
        colours = [class_palette[label] for label in labels]

    else:
        colours = style.NEUTRAL_COLOUR

    # squeeze=False keeps ``axes`` two-dimensional for a single feature
    figure, axes = pyplot.subplots(nrows=n_features,
                                   ncols=n_features,
                                   figsize=[1.5 * n_features] * 2,
                                   squeeze=False)

    last_row = n_features - 1

    for i in range(n_features):
        for j in range(n_features):
            axis = axes[i, j]

            # Column variable on the x-axis and row variable on the y-axis,
            # so that the data agree with the labels set below.
            axis.scatter(values[:, j], values[:, i], c=colours, s=1)

            axis.set_xticks([])
            axis.set_yticks([])

            if variable_names is not None and i == last_row:
                axis.set_xlabel(variable_names[j])

        if variable_names is not None:
            axes[i, 0].set_ylabel(variable_names[i])

    return figure, figure_name
示例#10
0
def plot_variable_label_correlations(variable_vector,
                                     variable_name,
                                     colouring_data_set,
                                     name="variable_label_correlations"):
    """Scatter one variable against the class of each example.

    The examples are shuffled with a fixed seed before plotting. Class IDs
    are used as y-coordinates and the y-ticks show the matching class names.

    Args:
        variable_vector: Array of variable values, one per example.
        variable_name: Label for the x-axis.
        colouring_data_set: Data set providing ``labels``, ``class_names``,
            ``number_of_classes``, ``class_palette``, ``label_sorter``,
            ``terms`` and the mappings ``class_name_to_class_id`` and
            ``class_id_to_class_name``.
        name: Passed as the only argument to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name(name)
    n_examples = variable_vector.shape[0]

    def _class_id_for(class_name):
        return colouring_data_set.class_name_to_class_id[class_name]

    def _class_name_for(class_id):
        return colouring_data_set.class_id_to_class_name[class_id]

    # Element-wise converters between class names and class IDs
    names_to_ids = numpy.vectorize(_class_id_for)
    ids_to_names = numpy.vectorize(_class_name_for)

    labels = colouring_data_set.labels
    class_names = colouring_data_set.class_names
    number_of_classes = colouring_data_set.number_of_classes
    class_palette = colouring_data_set.class_palette
    label_sorter = colouring_data_set.label_sorter

    if not class_palette:
        # Assign palette colours to the class names in sorted order
        index_palette = style.lighter_palette(number_of_classes)
        class_palette = {}
        ordered_class_names = sorted(class_names, key=label_sorter)
        for position, class_name in enumerate(ordered_class_names):
            class_palette[class_name] = index_palette[position]

    random_state = numpy.random.RandomState(117)
    order = random_state.permutation(n_examples)
    variable_vector = variable_vector[order]
    labels = labels[order]

    label_ids = numpy.expand_dims(names_to_ids(labels), axis=-1)

    colours = []
    for label in labels:
        colours.append(class_palette[label])

    tick_ids = numpy.unique(label_ids)
    tick_names = ids_to_names(tick_ids)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.scatter(variable_vector, label_ids, c=colours, s=1)

    axis.set_yticks(tick_ids)
    axis.set_yticklabels(tick_names)

    axis.set_xlabel(variable_name)
    class_term = colouring_data_set.terms["class"]
    axis.set_ylabel(capitalise_string(class_term))

    return figure, figure_name
示例#11
0
def plot_elbo_heat_map(data_frame,
                       x_label,
                       y_label,
                       z_label=None,
                       z_symbol=None,
                       z_min=None,
                       z_max=None,
                       name=None):
    """Plot a data frame as an annotated heat map.

    Args:
        data_frame: Data frame passed to ``seaborn.heatmap``; its ``values``
            determine the default colour range.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        z_label: Optional label for the colour bar.
        z_symbol: Optional symbol for the colour bar; when given, it
            replaces ``z_label`` and is wrapped in ``$`` signs.
        z_min: Lower bound of the colour range; defaults to the minimum
            value in the data frame.
        z_max: Upper bound of the colour range; defaults to the maximum
            value in the data frame.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("ELBO_heat_map", name)
    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    # Compare against None: a bound of 0 is a valid explicit choice
    if z_min is None:
        z_min = data_frame.values.min()
    if z_max is None:
        z_max = data_frame.values.max()

    if z_symbol:
        z_label = "$" + z_symbol + "$"

    cbar_dict = {}
    if z_label:
        cbar_dict["label"] = z_label

    # Use the white seaborn style for the heat map only, then restore the
    # project's plot look.
    seaborn.set(style="white")
    seaborn.heatmap(data_frame,
                    vmin=z_min,
                    vmax=z_max,
                    xticklabels=True,
                    yticklabels=True,
                    cbar=True,
                    cbar_kws=cbar_dict,
                    annot=True,
                    fmt="-.6g",
                    square=False,
                    ax=axis)
    style.reset_plot_look()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return figure, figure_name
示例#12
0
def plot_kl_divergence_evolution(kl_neurons, scale="log", name=None):
    """Plot per-epoch KL divergences as a heat map.

    The values of each epoch are sorted in ascending order before plotting,
    and epochs run along the x-axis.

    Args:
        kl_neurons: 2-d array with epochs as its first dimension.
        scale: If ``"log"``, the natural logarithm of the values is plotted
            and the colour bar label is prefixed accordingly.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("kl_divergence_evolution", name)
    n_epochs, __ = kl_neurons.shape

    kl_neurons = numpy.sort(kl_neurons, axis=1)

    scale_label = ""
    if scale == "log":
        kl_neurons = numpy.log(kl_neurons)
        scale_label = "$\\log$ "

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    # Label roughly every tenth of the epochs when there are many;
    # otherwise label every epoch.
    number_of_epoch_labels = 10
    if n_epochs > 2 * number_of_epoch_labels:
        epoch_label_frequency = n_epochs // number_of_epoch_labels
    else:
        epoch_label_frequency = True

    epoch_numbers = numpy.arange(1, n_epochs + 1)
    heat_map_data = pandas.DataFrame(kl_neurons.T, columns=epoch_numbers)

    seaborn.heatmap(heat_map_data,
                    xticklabels=epoch_label_frequency,
                    yticklabels=False,
                    cbar=True,
                    cbar_kws={"label": scale_label + "KL$(p_i|q_i)$"},
                    cmap=style.STANDARD_COLOUR_MAP,
                    ax=axis)

    axis.set_xlabel("Epochs")
    axis.set_ylabel("$i$")

    seaborn.despine(ax=axis)

    return figure, figure_name
示例#13
0
def plot_centroid_probabilities_evolution(probabilities,
                                          distribution,
                                          linestyle="solid",
                                          name=None):
    """Plot how the probability of each centroid evolves over epochs.

    Args:
        probabilities: 2-d array whose shape unpacks as (epochs, centroids).
        distribution: Name of the distribution; normalised and used in the
            figure name and the y-axis label.
        linestyle: Line style used for every centroid.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    distribution = normalise_string(distribution)

    y_label = _axis_label_for_symbol(symbol="\\pi",
                                     distribution=distribution,
                                     suffix="^k")

    base_name = "centroids_evolution-{}-probabilities".format(distribution)
    figure_name = saving.build_figure_name(base_name, name)

    n_epochs, n_centroids = probabilities.shape

    centroids_palette = style.darker_palette(n_centroids)
    epoch_numbers = numpy.arange(1, n_epochs + 1)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    # One line per centroid, each in its own palette colour
    for k in range(n_centroids):
        line_options = {
            "color": centroids_palette[k],
            "linestyle": linestyle,
            "label": "$k = {}$".format(k),
        }
        axis.plot(epoch_numbers, probabilities[:, k], **line_options)

    axis.set_xlabel("Epochs")
    axis.set_ylabel(y_label)

    axis.legend(loc="best")

    return figure, figure_name
示例#14
0
def plot_model_metric_sets(metrics_sets,
                           x_key,
                           y_key,
                           x_label=None,
                           y_label=None,
                           primary_differentiator_key=None,
                           primary_differentiator_order=None,
                           secondary_differentiator_key=None,
                           secondary_differentiator_order=None,
                           special_cases=None,
                           other_method_metrics=None,
                           palette=None,
                           marker_styles=None,
                           name=None):
    """Plot two metrics of several models against each other.

    Each metrics set is drawn as one point at the mean of its x and y values
    with error bars of one standard deviation. The value under the primary
    differentiator key selects the colour and the value under the secondary
    differentiator key selects the marker. Metrics of other methods are
    drawn as horizontal baselines.

    Args:
        metrics_sets: A metrics set or a list of them. Each set is indexed
            with ``x_key``, ``y_key`` and both differentiator keys, so both
            keys must be present in every set.
        x_key: Key of the values plotted along the x-axis.
        y_key: Key of the values plotted along the y-axis.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        primary_differentiator_key: Key whose value selects the colour.
        primary_differentiator_order: Sequence of primary values; the
            position of a value selects its palette colour. Values not
            listed are drawn in black.
        secondary_differentiator_key: Key whose value selects the marker.
        secondary_differentiator_order: Sequence of secondary values; the
            position of a value selects its marker style. Values not listed
            get no marker.
        special_cases: Optional mapping from a differentiator value to a
            dict of changes. Only ``{"errorbar_colour": "darken"}`` is acted
            upon. The mapping is not modified.
        other_method_metrics: Optional mapping from a method name to a
            mapping of metric values; methods with values under ``y_key``
            are drawn as horizontal lines (with a band of one standard
            deviation when there are several values).
        palette: Sequence of colours; defaults to a copy of the standard
            palette.
        marker_styles: Sequence of marker styles; defaults to a built-in
            list.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("model_metric_sets", name)

    if other_method_metrics:
        figure_name += "-other_methods"

    if not isinstance(metrics_sets, list):
        metrics_sets = [metrics_sets]

    if not palette:
        palette = style.STANDARD_PALETTE.copy()

    if not marker_styles:
        marker_styles = [
            "X",  # cross
            "s",  # square
            "D",  # diamond
            "o",  # circle
            "P",  # plus
            "^",  # upright triangle
            "p",  # pentagon
            "*",  # star
        ]

    # Treat omitted optional arguments as "nothing specified" instead of
    # failing on attribute access or concatenation further down.
    if primary_differentiator_order is None:
        primary_differentiator_order = []
    if secondary_differentiator_order is None:
        secondary_differentiator_order = []
    if special_cases is None:
        special_cases = {}

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    colours = {}
    markers = {}

    for metrics_set in metrics_sets:

        x = numpy.array(metrics_set[x_key])
        y = numpy.array(metrics_set[y_key])

        # Skip sets whose values could not be converted to numbers
        if x.dtype == "object" or y.dtype == "object":
            continue

        x_mean = x.mean()
        x_ddof = 1 if x.size > 1 else 0
        x_sd = x.std(ddof=x_ddof)

        y_mean = y.mean()
        y_ddof = 1 if y.size > 1 else 0
        y_sd = y.std(ddof=y_ddof)

        colour_key = metrics_set[primary_differentiator_key]
        if colour_key in colours:
            colour = colours[colour_key]
        else:
            try:
                index = primary_differentiator_order.index(colour_key)
                colour = palette[index]
            except (ValueError, IndexError):
                colour = "black"
            colours[colour_key] = colour
            # Marker-less point drawn once per colour as its legend entry
            axis.errorbar(x=x_mean,
                          y=y_mean,
                          yerr=y_sd,
                          xerr=x_sd,
                          capsize=2,
                          linestyle="",
                          color=colour,
                          label=colour_key,
                          markersize=7)

        marker_key = metrics_set[secondary_differentiator_key]
        if marker_key in markers:
            marker = markers[marker_key]
        else:
            try:
                index = secondary_differentiator_order.index(marker_key)
                marker = marker_styles[index]
            except (ValueError, IndexError):
                marker = None
            markers[marker_key] = marker
            # Black point drawn once per marker as its legend entry
            axis.errorbar(x_mean,
                          y_mean,
                          color="black",
                          marker=marker,
                          linestyle="none",
                          label=marker_key)

        errorbar_colour = colour
        darker_colour = seaborn.dark_palette(colour, n_colors=4)[2]

        # Merge into a copy so the caller's special cases stay untouched
        special_case_changes = dict(special_cases.get(colour_key, {}))
        special_case_changes.update(special_cases.get(marker_key, {}))

        if special_case_changes.get("errorbar_colour") == "darken":
            errorbar_colour = darker_colour

        axis.errorbar(x=x_mean,
                      y=y_mean,
                      yerr=y_sd,
                      xerr=x_sd,
                      ecolor=errorbar_colour,
                      capsize=2,
                      color=colour,
                      marker=marker,
                      markeredgecolor=darker_colour,
                      markersize=7)

    # Named Matplotlib line styles, reused cyclically for the baselines
    baseline_line_styles = ["dashed", "dotted", "dashdot"]

    if other_method_metrics:
        baseline_index = 0

        for method_name, metric_values in other_method_metrics.items():

            y_values = metric_values.get(y_key, None)

            if y_values is None:
                continue

            y = numpy.atleast_1d(numpy.array(y_values))

            if y.size == 0:
                continue

            y_mean = y.mean()

            if y.size > 1:
                y_sd = y.std(ddof=1)
            else:
                y_sd = None

            line_style = baseline_line_styles[
                baseline_index % len(baseline_line_styles)]
            baseline_index += 1

            axis.axhline(y=y_mean,
                         color=style.STANDARD_PALETTE[-1],
                         linestyle=line_style,
                         label=method_name,
                         zorder=-1)

            if y_sd is not None:
                axis.axhspan(ymin=y_mean - y_sd,
                             ymax=y_mean + y_sd,
                             facecolor=style.STANDARD_PALETTE[-1],
                             alpha=0.1,
                             edgecolor=None,
                             label=method_name,
                             zorder=-2)

    if len(metrics_sets) > 1:

        order = (list(primary_differentiator_order)
                 + list(secondary_differentiator_order))
        handles, labels = axis.get_legend_handles_labels()

        # Group handles sharing a label into one combined legend entry
        label_handles = {}

        for label, handle in zip(labels, handles):
            label_handles.setdefault(label, []).append(handle)

        def legend_rank(entry):
            # Listed labels first, in the given order; the rest by name
            label = entry[0]
            position = order.index(label) if label in order else len(order)
            return [position, label]

        legend_entries = sorted(
            ((label, tuple(handle_set))
             for label, handle_set in label_handles.items()),
            key=legend_rank)

        # Unpacking an empty zip fails, so skip the legend without entries
        if legend_entries:
            labels, handles = zip(*legend_entries)
            axis.legend(handles, labels, loc="best")

    return figure, figure_name
示例#15
0
def plot_model_metrics(metrics_sets,
                       key,
                       label=None,
                       primary_differentiator_key=None,
                       primary_differentiator_order=None,
                       secondary_differentiator_key=None,
                       secondary_differentiator_order=None,
                       palette=None,
                       marker_styles=None,
                       name=None):
    """Plot one metric of several models grouped along the x-axis.

    Each metrics set is drawn at the mean of its values with an error bar of
    one standard deviation. The value under the primary differentiator key
    selects the x position (group), and the value under the secondary
    differentiator key selects the offset within the group and the colour.

    Args:
        metrics_sets: A metrics set or a list of them. Each set is indexed
            with ``key`` and both differentiator keys.
        key: Key of the metric values to plot.
        label: Label for the y-axis.
        primary_differentiator_key: Key whose value selects the group.
            Despite the ``None`` default it must be a string: it is
            capitalised for the x-axis label.
        primary_differentiator_order: Sequence of primary values; the
            position of a value is its x position. Values not listed are
            placed at 0.
        secondary_differentiator_key: Key whose value selects the offset
            and colour.
        secondary_differentiator_order: Sequence of secondary values; the
            position of a value selects its offset and palette colour.
            Values not listed get no offset and are drawn in black.
        palette: Sequence of colours; defaults to a copy of the standard
            palette.
        marker_styles: Accepted for compatibility; not used by this
            function.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.
    """
    figure_name = saving.build_figure_name("model_metrics", name)

    if not isinstance(metrics_sets, list):
        metrics_sets = [metrics_sets]

    if not palette:
        palette = style.STANDARD_PALETTE.copy()

    # Treat omitted orders as empty instead of failing on len() and "+"
    if primary_differentiator_order is None:
        primary_differentiator_order = []
    if secondary_differentiator_order is None:
        secondary_differentiator_order = []

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(primary_differentiator_key.capitalize() + "s")
    axis.set_ylabel(label)

    x_positions = {}
    x_offsets = {}
    colours = {}

    # Width of a group in offset units: the secondary values plus a gap on
    # either side separating neighbouring groups.
    x_gap = 3
    x_scale = len(secondary_differentiator_order) - 1 + 2 * x_gap

    for metrics_set in metrics_sets:

        y = numpy.array(metrics_set[key])

        # Skip sets whose values could not be converted to numbers
        if y.dtype == "object":
            continue

        y_mean = y.mean()
        y_ddof = 1 if y.size > 1 else 0
        y_sd = y.std(ddof=y_ddof)

        x_position_key = metrics_set[primary_differentiator_key]
        if x_position_key in x_positions:
            x_position = x_positions[x_position_key]
        else:
            try:
                x_position = primary_differentiator_order.index(
                    x_position_key)
            except (ValueError, IndexError):
                x_position = 0
            x_positions[x_position_key] = x_position

        x_offset_key = metrics_set[secondary_differentiator_key]
        if x_offset_key in x_offsets:
            x_offset = x_offsets[x_offset_key]
        else:
            try:
                index = secondary_differentiator_order.index(x_offset_key)
                x_offset = (index + x_gap - x_scale / 2) / x_scale
            except (ValueError, IndexError):
                x_offset = 0
            x_offsets[x_offset_key] = x_offset

        x = x_position + x_offset

        colour_key = x_offset_key
        if colour_key in colours:
            colour = colours[colour_key]
        else:
            try:
                index = secondary_differentiator_order.index(colour_key)
                colour = palette[index]
            except (ValueError, IndexError):
                colour = "black"
            colours[colour_key] = colour
            # Marker-less point drawn once per colour as its legend entry
            axis.errorbar(x=x,
                          y=y_mean,
                          yerr=y_sd,
                          capsize=2,
                          linestyle="",
                          color=colour,
                          label=colour_key)

        axis.errorbar(x=x,
                      y=y_mean,
                      yerr=y_sd,
                      ecolor=colour,
                      capsize=2,
                      color=colour,
                      marker="_")

    x_ticks = []
    x_tick_labels = []

    for model, x_position in x_positions.items():
        x_ticks.append(x_position)
        x_tick_labels.append(model)

    axis.set_xticks(x_ticks)
    axis.set_xticklabels(x_tick_labels)

    if len(metrics_sets) > 1:

        order = (list(primary_differentiator_order)
                 + list(secondary_differentiator_order))
        handles, labels = axis.get_legend_handles_labels()

        # Group handles sharing a label into one combined legend entry
        label_handles = {}

        for label, handle in zip(labels, handles):
            label_handles.setdefault(label, []).append(handle)

        def legend_rank(entry):
            # Listed labels first, in the given order; the rest by name
            label = entry[0]
            position = order.index(label) if label in order else len(order)
            return [position, label]

        legend_entries = sorted(
            ((label, tuple(handle_set))
             for label, handle_set in label_handles.items()),
            key=legend_rank)

        # Unpacking an empty zip fails, so skip the legend without entries
        if legend_entries:
            labels, handles = zip(*legend_entries)
            axis.legend(handles, labels, loc="best")

    return figure, figure_name
示例#16
0
文件: matrices.py 项目: ritwik7/scvae
def plot_heat_map(values,
                  x_name,
                  y_name,
                  z_name=None,
                  z_symbol=None,
                  z_min=None,
                  z_max=None,
                  symmetric=False,
                  labels=None,
                  label_kind=None,
                  center=None,
                  name=None):
    """Plot a matrix of values as a heat map without tick labels.

    Args:
        values: 2-d array or matrix; rows are drawn along the y-axis.
        x_name: Label for the x-axis (replaced by the y-axis label when
            ``symmetric`` is true).
        y_name: Label for the y-axis; extended with " sorted" (and the label
            kind) when ``labels`` are given.
        z_name: Optional label for the colour bar.
        z_symbol: Optional symbol for the colour bar; when given, it
            replaces ``z_name`` and is wrapped in ``$`` signs.
        z_min: Lower bound of the colour range; defaults to the minimum of
            ``values``.
        z_max: Upper bound of the colour range; defaults to the maximum of
            ``values``.
        symmetric: If true, ``values`` must be square, and rows and columns
            are ordered identically.
        labels: Optional labels, one per row; rows are ordered by sorting
            these labels.
        label_kind: Optional description of the labels used in the y-axis
            label.
        center: Value at which to centre the colour map, passed on to
            ``seaborn.heatmap``.
        name: Extra name part(s) forwarded to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and its figure name.

    Raises:
        ValueError: If ``symmetric`` is true but ``values`` is not square.
    """
    figure_name = saving.build_figure_name("heat_map", name)
    n_examples, n_features = values.shape

    if symmetric and n_examples != n_features:
        raise ValueError(
            "Input cannot be symmetric, when it is not given as a 2-d square "
            "array or matrix.")

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    # Compare against None: a bound of 0 is a valid explicit choice
    if z_min is None:
        z_min = values.min()

    if z_max is None:
        z_max = values.max()

    if z_symbol:
        z_name = "$" + z_symbol + "$"

    cbar_dict = {}

    if z_name:
        cbar_dict["label"] = z_name

    if symmetric:
        square_cells = True
    else:
        # Only use square cells when the matrix is not too elongated
        aspect_ratio = n_examples / n_features
        square_cells = 1 / 5 < aspect_ratio < 5

    if labels is not None:
        row_indices = numpy.argsort(labels)
        y_name += " sorted"
        if label_kind:
            y_name += " by " + label_kind
    else:
        row_indices = numpy.arange(n_examples)

    if symmetric:
        column_indices = row_indices
        x_name = y_name
    else:
        column_indices = numpy.arange(n_features)

    # Use the white seaborn style for the heat map only, then restore the
    # project's plot look.
    seaborn.set(style="white")
    seaborn.heatmap(values[row_indices][:, column_indices],
                    vmin=z_min,
                    vmax=z_max,
                    center=center,
                    xticklabels=False,
                    yticklabels=False,
                    cbar=True,
                    cbar_kws=cbar_dict,
                    cmap=style.STANDARD_COLOUR_MAP,
                    square=square_cells,
                    ax=axis)
    style.reset_plot_look()

    axis.set_xlabel(x_name)
    axis.set_ylabel(y_name)

    return figure, figure_name
示例#17
0
文件: matrices.py 项目: ritwik7/scvae
def plot_matrix(feature_matrix,
                plot_distances=False,
                center_value=None,
                example_label=None,
                feature_label=None,
                value_label=None,
                sorting_method=None,
                distance_metric="Euclidean",
                labels=None,
                label_kind=None,
                class_palette=None,
                feature_indices_for_plotting=None,
                hide_dendrogram=False,
                name_parts=None):
    """Plot a matrix, or the pairwise distances between its rows, as a
    heat map.

    Rows of ``feature_matrix`` are treated as examples and columns as
    features. The rows can optionally be reordered by label or by
    hierarchical clustering, and a colour strip showing the class of each
    row can be drawn to the left of the heat map.

    Args:
        feature_matrix: Two-dimensional matrix of values. SciPy sparse
            input is densified before plotting.
        plot_distances: If true, plot the pairwise distances between rows
            instead of the values; ``center_value`` and ``feature_label``
            are then ignored and ``value_label`` is embedded in a
            "Pairwise ... distances in ... space" colour-bar label.
        center_value: Forwarded to ``seaborn.heatmap`` as ``center``.
        example_label: Label for the y-axis of the left-most axis. A
            description of the sorting is appended when rows are sorted.
        feature_label: Label for the x-axis of the heat map.
        value_label: Label for the colour bar.
        sorting_method: ``None`` (keep row order), ``"labels"`` or
            ``"hierarchical_clustering"``.
        distance_metric: Metric name; it is lower-cased and passed to
            ``sklearn.metrics.pairwise_distances``.
        labels: Optional label for each row. Requires ``class_palette``.
        label_kind: Name of the kind of labels used in the axis label
            (defaults to ``"labels"`` when sorting by labels).
        class_palette: Mapping from label to colour.
        feature_indices_for_plotting: Column indices to show when plotting
            values; defaults to all columns.
        hide_dendrogram: If true, no dendrogram is drawn when sorting by
            hierarchical clustering.
        name_parts: Passed on to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and the figure name.

    Raises:
        ValueError: If sorting by labels without labels, if labels are
            given without a class palette, or if ``sorting_method`` is not
            recognised.
    """

    figure_name = saving.build_figure_name(name_parts)
    n_examples, n_features = feature_matrix.shape

    if plot_distances:
        center_value = None
        feature_label = None
        value_label = "Pairwise {} distances in {} space".format(
            distance_metric, value_label)

    if not plot_distances and feature_indices_for_plotting is None:
        feature_indices_for_plotting = numpy.arange(n_features)

    if sorting_method == "labels" and labels is None:
        raise ValueError("No labels provided to sort after.")

    if labels is not None and not class_palette:
        raise ValueError("No class palette provided.")

    # Reject unknown sorting methods before any distances are computed or
    # any figure is created, so a bad argument does not leave an orphaned
    # figure behind.
    if sorting_method not in (None, "labels", "hierarchical_clustering"):
        raise ValueError("`sorting_method` should be either \"labels\""
                         " or \"hierarchical_clustering\"")

    # Distances (if needed)
    distances = None
    if plot_distances or sorting_method == "hierarchical_clustering":
        distances = sklearn.metrics.pairwise_distances(
            feature_matrix, metric=distance_metric.lower())

    # Figure initialisation
    figure = pyplot.figure()

    axis_heat_map = figure.add_subplot(1, 1, 1)
    left_most_axis = axis_heat_map

    divider = make_axes_locatable(axis_heat_map)
    axis_colour_map = divider.append_axes("right", size="5%", pad=0.1)

    axis_labels = None
    axis_dendrogram = None

    if labels is not None:
        axis_labels = divider.append_axes("left", size="5%", pad=0.01)
        left_most_axis = axis_labels

    if sorting_method == "hierarchical_clustering" and not hide_dendrogram:
        axis_dendrogram = divider.append_axes("left", size="20%", pad=0.01)
        left_most_axis = axis_dendrogram

    # Label colours: each row's class colour is encoded as an integer index
    # into a listed colour map, so the strip can be drawn as a heat map.
    if labels is not None:
        label_colours = [
            tuple(colour) if isinstance(colour, list) else colour
            for colour in [class_palette[l] for l in labels]
        ]
        unique_colours = [
            tuple(colour) if isinstance(colour, list) else colour
            for colour in class_palette.values()
        ]
        value_for_colour = {
            colour: i
            for i, colour in enumerate(unique_colours)
        }
        label_colour_matrix = numpy.array([
            value_for_colour[colour] for colour in label_colours
        ]).reshape(n_examples, 1)
        label_colour_map = matplotlib.colors.ListedColormap(unique_colours)
    else:
        label_colour_matrix = None
        label_colour_map = None

    # Heat map aspect ratio: only distance matrices get square cells.
    square_cells = bool(plot_distances)

    seaborn.set(style="white")

    # Sorting and optional dendrogram
    if sorting_method == "labels":
        example_indices = numpy.argsort(labels)

        if not label_kind:
            label_kind = "labels"

        if example_label:
            example_label += " sorted by " + label_kind

    elif sorting_method == "hierarchical_clustering":
        # The distances are passed in condensed form, for which SciPy
        # ignores `metric`; the linkage criterion has to be given as
        # `method` (the original passed "average" as `metric`, which left
        # the default single linkage in effect).
        linkage = scipy.cluster.hierarchy.linkage(
            scipy.spatial.distance.squareform(distances, checks=False),
            method="average")

        if axis_dendrogram is not None:
            # NOTE(review): `metric`/`method` below appear to be unused by
            # seaborn when a precomputed linkage is supplied -- confirm.
            dendrogram = seaborn.matrix.dendrogram(distances,
                                                   linkage=linkage,
                                                   metric=None,
                                                   method="ward",
                                                   axis=0,
                                                   label=False,
                                                   rotate=True,
                                                   ax=axis_dendrogram)
            example_indices = dendrogram.reordered_ind
        else:
            # With the dendrogram hidden there is no axis to draw on;
            # calling seaborn with `ax=None` would draw the tree on the
            # current axes, so only the leaf order is extracted.
            example_indices = scipy.cluster.hierarchy.leaves_list(linkage)

        if example_label:
            example_label += " sorted by hierarchical clustering"

    else:
        example_indices = numpy.arange(n_examples)

    # Heat map of values
    if plot_distances:
        plot_values = distances[example_indices][:, example_indices]
    else:
        plot_values = feature_matrix[
            example_indices][:, feature_indices_for_plotting]

    if scipy.sparse.issparse(plot_values):
        # `.toarray()` works for both sparse matrices and sparse arrays.
        plot_values = plot_values.toarray()

    colour_bar_dictionary = {}

    if value_label:
        colour_bar_dictionary["label"] = value_label

    seaborn.heatmap(plot_values,
                    center=center_value,
                    xticklabels=False,
                    yticklabels=False,
                    cbar=True,
                    cbar_kws=colour_bar_dictionary,
                    cbar_ax=axis_colour_map,
                    square=square_cells,
                    ax=axis_heat_map)

    # Colour labels
    if axis_labels is not None:
        seaborn.heatmap(label_colour_matrix[example_indices],
                        xticklabels=False,
                        yticklabels=False,
                        cbar=False,
                        cmap=label_colour_map,
                        ax=axis_labels)

    style.reset_plot_look()

    # Axis labels
    if example_label:
        left_most_axis.set_ylabel(example_label)

    if feature_label:
        axis_heat_map.set_xlabel(feature_label)

    return figure, figure_name
示例#18
0
def plot_learning_curves(curves, model_type, epoch_offset=0, name=None):
    """Plot learning curves for a model, grouped on one axis per kind.

    Lower bound, reconstruction error and log-likelihood curves are drawn
    on the first axis. KL-divergence curves are drawn on the second axis,
    except the one for the latent variable ``y``, which is drawn on the
    third axis.

    Args:
        curves: Mapping from curve-set name (``"training"`` or
            ``"validation"``) to a mapping from curve name to a sequence
            of values. Curves that are ``None`` are skipped.
        model_type: ``"AE"`` (one axis), ``"VAE"`` (two axes) or
            ``"GMVAE"`` (three axes).
        epoch_offset: Number added to the epoch numbers on the x-axis.
        name: Passed on to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and the figure name.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """

    figure_name = saving.build_figure_name("learning_curves", name)

    x_label = "Epoch"
    y_label = "Nat"

    if model_type == "AE":
        figure = pyplot.figure()
        axis_1 = figure.add_subplot(1, 1, 1)
    elif model_type == "VAE":
        figure, (axis_1, axis_2) = pyplot.subplots(nrows=2,
                                                   sharex=True,
                                                   figsize=(6.4, 9.6))
        figure.subplots_adjust(hspace=0.1)
    elif model_type == "GMVAE":
        figure, (axis_1, axis_2, axis_3) = pyplot.subplots(nrows=3,
                                                           sharex=True,
                                                           figsize=(6.4, 14.4))
        figure.subplots_adjust(hspace=0.1)
    else:
        raise ValueError(
            "`model_type` should be \"AE\", \"VAE\", or \"GMVAE\".")

    def add_sorted_legend(legend_axis):
        # Add a legend with entries ordered by label. An axis without any
        # plotted curves gets no legend (unpacking an empty `zip` would
        # otherwise raise).
        handles, labels = legend_axis.get_legend_handles_labels()
        if not handles:
            return
        labels, handles = zip(
            *sorted(zip(labels, handles), key=lambda t: t[0]))
        legend_axis.legend(handles, labels, loc="best")

    for curve_set_name, curve_set in sorted(curves.items()):

        if curve_set_name == "training":
            line_style = "solid"
            colour_index_offset = 0
        elif curve_set_name == "validation":
            line_style = "dashed"
            colour_index_offset = 1

        def curve_colour(i):
            # Interleave the palette so that each curve kind gets adjacent
            # colours for its different curve sets.
            return style.STANDARD_PALETTE[len(curves) * i +
                                          colour_index_offset]

        for curve_name, curve in sorted(curve_set.items()):
            if curve is None:
                continue
            elif curve_name == "lower_bound":
                curve_name = "$\\mathcal{L}$"
                colour = curve_colour(0)
                axis = axis_1
            elif curve_name == "reconstruction_error":
                curve_name = "$\\log p(x|z)$"
                colour = curve_colour(1)
                axis = axis_1
            elif "kl_divergence" in curve_name:
                if curve_name == "kl_divergence":
                    index = ""
                    colour = curve_colour(0)
                    axis = axis_2
                else:
                    latent_variable = curve_name.replace("kl_divergence_", "")
                    # Turn e.g. "z1" into "z_1" for use as a subscript.
                    latent_variable = re.sub(pattern=r"(\w)(\d)",
                                             repl=r"\1_\2",
                                             string=latent_variable)
                    index = "$_{" + latent_variable + "}$"
                    if latent_variable in ["z", "z_2"]:
                        colour = curve_colour(0)
                        axis = axis_2
                    elif latent_variable == "z_1":
                        colour = curve_colour(1)
                        axis = axis_2
                    elif latent_variable == "y":
                        colour = curve_colour(0)
                        axis = axis_3
                curve_name = "KL" + index + "$(q||p)$"
            elif curve_name == "log_likelihood":
                curve_name = "$L$"
                # Set the colour explicitly; otherwise it is either
                # undefined or left over from a previously plotted curve.
                colour = curve_colour(0)
                axis = axis_1
            epochs = numpy.arange(len(curve)) + epoch_offset + 1
            label = "{} ({} set)".format(curve_name, curve_set_name)
            axis.plot(epochs,
                      curve,
                      color=colour,
                      linestyle=line_style,
                      label=label)

    add_sorted_legend(axis_1)

    if model_type == "AE":
        axis_1.set_xlabel(x_label)
        axis_1.set_ylabel(y_label)
    else:
        # Both "VAE" and "GMVAE" have a second axis and share one y-label
        # placed on the figure instead of on the individual axes.
        add_sorted_legend(axis_2)
        axis_1.set_ylabel("")
        axis_2.set_ylabel("")

        if model_type == "GMVAE":
            add_sorted_legend(axis_3)
            axis_3.set_xlabel(x_label)
            axis_3.set_ylabel("")
        else:
            axis_2.set_xlabel(x_label)
        figure.text(-0.01, 0.5, y_label, va="center", rotation="vertical")

    seaborn.despine()

    return figure, figure_name
示例#19
0
def plot_histogram(series, excess_zero_count=0, label=None, normed=False,
                   discrete=False, x_scale="linear", y_scale="linear",
                   colour=None, name=None):
    """Plot a histogram of a series as a bar chart.

    Args:
        series: Values to bin. A copy is made, so the input is not
            modified.
        excess_zero_count: Number of additional values that are added to
            the count of the first bin and to the series length used for
            normalisation.
        label: Label for the x-axis.
        normed: If true, plot frequencies (counts divided by the series
            length including ``excess_zero_count``) instead of counts.
        discrete: If true and the maximum value is below
            ``MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS``, use one unit-width
            bin per integer value.
        x_scale: Scale of the x-axis. For ``"log"`` the values and the bin
            range are shifted by one before binning.
        y_scale: Scale of the y-axis; ``"log"`` draws logarithmic bars.
        colour: Bar colour; defaults to the first standard palette colour.
        name: Passed on to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and the figure name.
    """

    series = series.copy()

    figure_name = "histogram"

    if normed:
        figure_name += "-normed"

    figure_name = saving.build_figure_name(figure_name, name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    series_length = len(series) + excess_zero_count

    series_max = series.max()

    if discrete and series_max < MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS:
        # One bin per integer value, centred on the integers.
        number_of_bins = int(numpy.ceil(series_max)) + 1
        bin_range = numpy.array((-0.5, series_max + 0.5))
    else:
        if series_max < MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS:
            number_of_bins = "auto"
        else:
            number_of_bins = MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS
        bin_range = numpy.array((series.min(), series_max))

    if colour is None:
        colour = style.STANDARD_PALETTE[0]

    if x_scale == "log":
        # Shift values and bin range by one before binning for the
        # logarithmic x-axis.
        series += 1
        bin_range += 1
        # `label` defaults to `None`, which cannot be concatenated with a
        # string.
        if label is not None:
            label += " (shifted one)"
        figure_name += "-log_values"

    y_log = y_scale == "log"

    histogram, bin_edges = numpy.histogram(
        series,
        bins=number_of_bins,
        range=bin_range
    )

    histogram[0] += excess_zero_count

    width = bin_edges[1] - bin_edges[0]
    bin_centres = bin_edges[:-1] + width / 2

    if normed:
        histogram = histogram / series_length

    axis.bar(
        bin_centres,
        histogram,
        width=width,
        log=y_log,
        color=colour,
        alpha=0.4
    )

    axis.set_xscale(x_scale)
    # NOTE(review): `label` is passed on unchanged when it is `None`;
    # confirm that `capitalise_string` accepts that.
    axis.set_xlabel(capitalise_string(label))

    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name
示例#20
0
def plot_class_histogram(labels, class_names=None, class_palette=None,
                         normed=False, scale="linear", label_sorter=None,
                         name=None):
    """Plot the number (or frequency) of examples in each class as bars.

    Args:
        labels: Iterable with one class label per example. Every label
            must be present in ``class_names``.
        class_names: Classes to show; defaults to the unique values in
            ``labels``.
        class_palette: Mapping from class name to bar colour. If not
            given, a palette is built from ``style.lighter_palette``.
        normed: If true, plot frequencies instead of counts.
        scale: Scale of the y-axis.
        label_sorter: Optional key function used to order the classes
            along the x-axis.
        name: Passed on to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and the figure name.
    """

    figure_name = "histogram"

    if normed:
        figure_name += "-normed"

    figure_name += "-classes"

    figure_name = saving.build_figure_name(figure_name, name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if class_names is None:
        class_names = numpy.unique(labels)

    n_classes = len(class_names)
    sorted_class_names = sorted(class_names, key=label_sorter)

    if not class_palette:
        index_palette = style.lighter_palette(n_classes)
        class_palette = {
            class_name: index_palette[i]
            for i, class_name in enumerate(sorted_class_names)
        }

    histogram = {
        class_name: {
            "index": i,
            "count": 0,
            "colour": class_palette[class_name]
        }
        for i, class_name in enumerate(sorted_class_names)
    }

    total_count_sum = 0

    for label in labels:
        histogram[label]["count"] += 1
        total_count_sum += 1

    indices = []
    tick_labels = []

    # Walk the classes in bar order. Ordering by the stored index (rather
    # than re-sorting the class names) respects `label_sorter` and does not
    # require the class names to be mutually comparable.
    for class_name, class_values in sorted(
            histogram.items(), key=lambda item: item[1]["index"]):
        index = class_values["index"]
        count = class_values["count"]
        indices.append(index)
        tick_labels.append(class_name)

        if normed:
            # Without any labels all frequencies are zero.
            bar_height = count / total_count_sum if total_count_sum else 0
        else:
            bar_height = count

        axis.bar(index, bar_height, color=class_values["colour"])

    axis.set_yscale(scale)

    # Long class names are rotated so that they do not overlap.
    maximum_class_name_width = max(
        (len(str(class_name)) for class_name in tick_labels
         if class_name not in ["No class"]),
        default=0
    )
    if maximum_class_name_width > 5:
        x_ticks_rotation = 45
        x_ticks_horizontal_alignment = "right"
        x_ticks_rotation_mode = "anchor"
    else:
        x_ticks_rotation = 0
        x_ticks_horizontal_alignment = "center"
        x_ticks_rotation_mode = None

    # Set the ticks on the axis owned by this function instead of relying
    # on it being pyplot's current axes.
    axis.set_xticks(indices)
    axis.set_xticklabels(
        tick_labels,
        horizontalalignment=x_ticks_horizontal_alignment,
        rotation=x_ticks_rotation,
        rotation_mode=x_ticks_rotation_mode
    )

    axis.set_xlabel("Classes")

    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name
示例#21
0
def plot_separate_learning_curves(curves, loss, name=None):
    """Plot the learning curves for selected losses on a single axis.

    Args:
        curves: Mapping from curve-set name (``"training"`` or
            ``"validation"``) to a mapping from curve name to a sequence
            of values. Curves that are ``None`` are skipped.
        loss: Name of the loss to plot, or a list of such names. Only
            curves whose names are listed are drawn.
        name: Name part, or list of name parts, for the figure name. The
            losses are appended to these parts; a list passed in is not
            modified.

    Returns:
        Tuple of the figure and the figure name.
    """

    losses = loss if isinstance(loss, list) else [loss]

    # Copy a list argument so that appending the losses does not modify
    # the caller's list.
    names = list(name) if isinstance(name, list) else [name]
    names.extend(losses)
    figure_name = saving.build_figure_name("learning_curves", names)

    x_label = "Epoch"
    y_label = "Nat"

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    for curve_set_name, curve_set in sorted(curves.items()):

        if curve_set_name == "training":
            line_style = "solid"
            colour_index_offset = 0
        elif curve_set_name == "validation":
            line_style = "dashed"
            colour_index_offset = 1

        def curve_colour(i):
            # Interleave the palette so that each curve kind gets adjacent
            # colours for its different curve sets.
            return style.STANDARD_PALETTE[len(curves) * i +
                                          colour_index_offset]

        for curve_name, curve in sorted(curve_set.items()):
            if curve is None or curve_name not in losses:
                continue
            elif curve_name == "lower_bound":
                curve_name = "$\\mathcal{L}$"
                colour = curve_colour(0)
            elif curve_name == "reconstruction_error":
                curve_name = "$\\log p(x|z)$"
                colour = curve_colour(1)
            elif "kl_divergence" in curve_name:
                if curve_name == "kl_divergence":
                    index = ""
                    colour = curve_colour(0)
                else:
                    latent_variable = curve_name.replace("kl_divergence_", "")
                    # Turn e.g. "z1" into "z_1" for use as a subscript.
                    latent_variable = re.sub(pattern=r"(\w)(\d)",
                                             repl=r"\1_\2",
                                             string=latent_variable)
                    index = "$_{" + latent_variable + "}$"
                    if latent_variable in ["z", "z_2"]:
                        colour = curve_colour(0)
                    elif latent_variable == "z_1":
                        colour = curve_colour(1)
                    elif latent_variable == "y":
                        colour = curve_colour(0)
                curve_name = "KL" + index + "$(q||p)$"
            elif curve_name == "log_likelihood":
                curve_name = "$L$"
                # Set the colour explicitly; otherwise it is either
                # undefined or left over from a previously plotted curve.
                colour = curve_colour(0)
            epochs = numpy.arange(len(curve)) + 1
            label = curve_name + " ({} set)".format(curve_set_name)
            axis.plot(epochs,
                      curve,
                      color=colour,
                      linestyle=line_style,
                      label=label)

    handles, labels = axis.get_legend_handles_labels()

    # Without any plotted curves there is nothing to sort or show, and
    # unpacking an empty `zip` would raise.
    if handles:
        labels, handles = zip(
            *sorted(zip(labels, handles), key=lambda t: t[0]))
        axis.legend(handles, labels, loc="best")

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return figure, figure_name
示例#22
0
def plot_centroid_means_evolution(means,
                                  distribution,
                                  decomposed=False,
                                  name=None):
    """Plot how the two-dimensional means of centroids move over epochs.

    Each centroid is drawn as a line through its means with one marker per
    epoch, where the marker shade encodes the epoch. A neutral-coloured
    scatter of the first centroid underlies the plot to provide a colour
    bar for the epochs.

    Args:
        means: Array with shape ``(n_epochs, n_centroids, 2)``.
        distribution: Name of the distribution; it is normalised with
            ``normalise_string`` and used in the axis labels and the
            figure name.
        decomposed: If true, the axis labels are built with ``"PCA"`` as
            the decomposition method.
        name: Passed on to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and the figure name.

    Raises:
        ValueError: If the last dimension of ``means`` is not 2.
    """

    symbol = "\\mu"
    decomposition_method = "PCA" if decomposed else ""
    distribution = normalise_string(distribution)
    suffix = "(y = k)"

    x_label = _axis_label_for_symbol(symbol=symbol,
                                     coordinate=1,
                                     decomposition_method=decomposition_method,
                                     distribution=distribution,
                                     suffix=suffix)
    y_label = _axis_label_for_symbol(symbol=symbol,
                                     coordinate=2,
                                     decomposition_method=decomposition_method,
                                     distribution=distribution,
                                     suffix=suffix)

    figure_name = "centroids_evolution-{}-means".format(distribution)
    figure_name = saving.build_figure_name(figure_name, name)

    n_epochs, n_centroids, latent_size = means.shape

    # Both coordinates are indexed below, so fewer than two dimensions is
    # as invalid as more than two.
    if latent_size != 2:
        raise ValueError("Dimensions of means should be 2.")

    centroids_palette = style.darker_palette(n_centroids)
    epochs = numpy.arange(n_epochs) + 1

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    # Drawn at the bottom only to provide a neutral colour bar for epochs.
    colour_bar_scatter_plot = axis.scatter(means[:, 0, 0],
                                           means[:, 0, 1],
                                           c=epochs,
                                           cmap=seaborn.dark_palette(
                                               style.NEUTRAL_COLOUR,
                                               as_cmap=True),
                                           zorder=0)

    for k in range(n_centroids):
        colour = centroids_palette[k]
        colour_map = seaborn.dark_palette(colour, as_cmap=True)
        axis.plot(means[:, k, 0],
                  means[:, k, 1],
                  color=colour,
                  label="$k = {}$".format(k),
                  zorder=k + 1)
        # All markers are placed above all lines.
        axis.scatter(means[:, k, 0],
                     means[:, k, 1],
                     c=epochs,
                     cmap=colour_map,
                     zorder=n_centroids + k + 1)

    axis.legend(loc="best")

    colour_bar = figure.colorbar(colour_bar_scatter_plot)
    colour_bar.outline.set_linewidth(0)
    colour_bar.set_label("Epochs")

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return figure, figure_name
示例#23
0
def plot_profile_comparison(observed_series, expected_series,
                            expected_series_total_standard_deviations=None,
                            expected_series_explained_standard_deviations=None,
                            x_name="feature", y_name="value", sort=True,
                            sort_by="expected", sort_direction="ascending",
                            x_scale="linear", y_scale="linear", y_cutoff=None,
                            name=None):
    """Plot an observed series against an expected series.

    The series used for sorting is drawn as a solid line and the other one
    as markers. Optional bands of one standard deviation around the
    expected series are drawn behind them.

    Args:
        observed_series: Observed values (dense, or SciPy sparse).
        expected_series: Expected values (dense, or SciPy sparse).
        expected_series_total_standard_deviations: Optional total standard
            deviations of the expected values.
        expected_series_explained_standard_deviations: Optional explained
            standard deviations of the expected values; only drawn if
            their mean is positive.
        x_name: Name of the quantity along the x-axis.
        y_name: Name of the quantity along the y-axis.
        sort: If true, order the values by the series given by
            ``sort_by``.
        sort_by: ``"expected"`` or ``"observed"``.
        sort_direction: ``"ascending"`` or ``"descending"``.
        x_scale: Scale of the x-axis.
        y_scale: Scale of the y-axis, or ``"both"`` for a logarithmic
            upper axis and a linear lower axis split at ``y_cutoff``.
        y_cutoff: Split value for ``"both"``; otherwise the upper limit
            for a linear y-axis or the lower limit for a logarithmic one.
        name: Passed on to ``saving.build_figure_name``.

    Returns:
        Tuple of the figure and the figure name.

    Raises:
        ValueError: If ``sort_by`` is not recognised, or if sorting with
            an unrecognised ``sort_direction``.
    """

    sort_by = normalise_string(sort_by)
    sort_direction = normalise_string(sort_direction)
    figure_name = saving.build_figure_name("profile_comparison", name)

    # Densify sparse input; `.toarray()` works for both sparse matrices
    # and sparse arrays.
    if scipy.sparse.issparse(observed_series):
        observed_series = observed_series.toarray().squeeze()

    if scipy.sparse.issparse(expected_series):
        expected_series = expected_series.toarray().squeeze()

    if scipy.sparse.issparse(expected_series_total_standard_deviations):
        expected_series_total_standard_deviations = (
            expected_series_total_standard_deviations.toarray().squeeze())

    if scipy.sparse.issparse(expected_series_explained_standard_deviations):
        expected_series_explained_standard_deviations = (
            expected_series_explained_standard_deviations.toarray()
            .squeeze())

    observed_colour = style.STANDARD_PALETTE[0]
    expected_palette = seaborn.light_palette(style.STANDARD_PALETTE[1], 5)

    expected_colour = expected_palette[-1]
    expected_total_standard_deviations_colour = expected_palette[1]
    expected_explained_standard_deviations_colour = expected_palette[3]

    if sort:
        x_label = "{}s sorted {} by {} {}s [sort index]".format(
            capitalise_string(x_name), sort_direction, sort_by, y_name.lower())
    else:
        x_label = "{}s [original index]".format(capitalise_string(x_name))
    y_label = capitalise_string(y_name) + "s"

    observed_label = "Observed"
    expected_label = "Expected"
    expected_total_standard_deviations_label = "Total standard deviation"
    expected_explained_standard_deviations_label = (
        "Explained standard deviation")

    # Sorting: the series sorted by is drawn as a line on top, the other
    # one as markers below it.
    if sort_by == "expected":
        sort_series = expected_series
        expected_marker = ""
        expected_line_style = "solid"
        expected_z_order = 3
        observed_marker = "o"
        observed_line_style = ""
        observed_z_order = 2
    elif sort_by == "observed":
        sort_series = observed_series
        expected_marker = "o"
        expected_line_style = ""
        expected_z_order = 2
        observed_marker = ""
        observed_line_style = "solid"
        observed_z_order = 3
    else:
        raise ValueError(
            "`sort_by` should be either \"expected\" or \"observed\".")

    if sort:
        sort_indices = numpy.argsort(sort_series)
        if sort_direction == "descending":
            sort_indices = sort_indices[::-1]
        elif sort_direction != "ascending":
            raise ValueError(
                "Sort direction can either be ascending or descending.")
    else:
        sort_indices = slice(None)

    # Standard deviations
    if expected_series_total_standard_deviations is not None:
        with_total_standard_deviations = True
        expected_series_total_standard_deviations_lower = (
            expected_series - expected_series_total_standard_deviations)
        expected_series_total_standard_deviations_upper = (
            expected_series + expected_series_total_standard_deviations)
    else:
        with_total_standard_deviations = False

    if (expected_series_explained_standard_deviations is not None
            and expected_series_explained_standard_deviations.mean() > 0):
        with_explained_standard_deviations = True
        expected_series_explained_standard_deviations_lower = (
            expected_series - expected_series_explained_standard_deviations)
        expected_series_explained_standard_deviations_upper = (
            expected_series + expected_series_explained_standard_deviations)
    else:
        with_explained_standard_deviations = False

    # Figure
    if y_scale == "both":
        figure, axes = pyplot.subplots(nrows=2, sharex=True)
        figure.subplots_adjust(hspace=0.1)
        axis_upper = axes[0]
        axis_lower = axes[1]
        # `set_zorder` is a method; assigning to it (as the original did)
        # replaces the method without changing the drawing order.
        axis_upper.set_zorder(1)
        axis_lower.set_zorder(0)
    else:
        figure = pyplot.figure()
        axis = figure.add_subplot(1, 1, 1)
        axes = [axis]

    # Legend handles are collected from the first axis only, since all
    # axes show the same artists.
    handles = []
    feature_indices = numpy.arange(len(observed_series)) + 1

    for i, axis in enumerate(axes):
        observed_plot, = axis.plot(
            feature_indices,
            observed_series[sort_indices],
            label=observed_label,
            color=observed_colour,
            marker=observed_marker,
            linestyle=observed_line_style,
            zorder=observed_z_order
        )
        if i == 0:
            handles.append(observed_plot)
        expected_plot, = axis.plot(
            feature_indices,
            expected_series[sort_indices],
            label=expected_label,
            color=expected_colour,
            marker=expected_marker,
            linestyle=expected_line_style,
            zorder=expected_z_order
        )
        if i == 0:
            handles.append(expected_plot)
        if with_total_standard_deviations:
            axis.fill_between(
                feature_indices,
                expected_series_total_standard_deviations_lower[sort_indices],
                expected_series_total_standard_deviations_upper[sort_indices],
                color=expected_total_standard_deviations_colour,
                zorder=0
            )
            if i == 0:
                # A patch stands in for the filled band in the legend.
                handles.append(matplotlib.patches.Patch(
                    label=expected_total_standard_deviations_label,
                    color=expected_total_standard_deviations_colour
                ))
        if with_explained_standard_deviations:
            axis.fill_between(
                feature_indices,
                expected_series_explained_standard_deviations_lower[
                    sort_indices],
                expected_series_explained_standard_deviations_upper[
                    sort_indices],
                color=expected_explained_standard_deviations_colour,
                zorder=1
            )
            if i == 0:
                handles.append(matplotlib.patches.Patch(
                    label=expected_explained_standard_deviations_label,
                    color=expected_explained_standard_deviations_colour
                ))

    # NOTE(review): the `nonposy` keyword used below was renamed
    # `nonpositive` in newer Matplotlib releases -- confirm against the
    # Matplotlib version this project targets.
    if y_scale == "both":
        axis_upper.legend(
            handles=handles,
            loc="best"
        )
        seaborn.despine(ax=axis_upper)
        seaborn.despine(ax=axis_lower)

        axis_upper.set_yscale("log", nonposy="clip")
        axis_lower.set_yscale("linear")
        figure.text(0.04, 0.5, y_label, va="center", rotation="vertical")

        axis_lower.set_xscale(x_scale)
        axis_lower.set_xlabel(x_label)

        # The upper axis shows values above the cut-off and the lower axis
        # the values below it.
        _, y_upper_max = axis_upper.get_ylim()
        y_lower_min, _ = axis_lower.get_ylim()
        axis_upper.set_ylim(y_cutoff, y_upper_max)

        y_lower_min = max(-1, y_lower_min)
        axis_lower.set_ylim(y_lower_min, y_cutoff)

    else:
        axis.legend(
            handles=handles,
            loc="best"
        )
        seaborn.despine()

        y_scale_arguments = {}
        if y_scale == "log":
            y_scale_arguments["nonposy"] = "clip"
        axis.set_yscale(y_scale, **y_scale_arguments)
        axis.set_ylabel(y_label)

        axis.set_xscale(x_scale)
        axis.set_xlabel(x_label)

        y_min, y_max = axis.get_ylim()
        y_min = max(-1, y_min)

        if y_cutoff:
            if y_scale == "linear":
                y_max = y_cutoff
            elif y_scale == "log":
                y_min = y_cutoff

        axis.set_ylim(y_min, y_max)

    return figure, figure_name