Beispiel #1
0
def plot_heat_map(values,
                  x_name,
                  y_name,
                  z_name=None,
                  z_symbol=None,
                  z_min=None,
                  z_max=None,
                  symmetric=False,
                  labels=None,
                  label_kind=None,
                  center=None,
                  name=None):
    """Plot a 2-d array of values as a heat map without tick labels.

    Arguments:
        values: 2-d array-like exposing ``shape``, ``min()``, ``max()`` and
            NumPy-style fancy indexing (presumably a ``numpy.ndarray`` --
            confirm against callers).
        x_name: Label for the x-axis. Ignored when ``symmetric`` is true,
            in which case the (possibly extended) ``y_name`` is used for
            both axes.
        y_name: Label for the y-axis. " sorted" (and " by <label_kind>")
            is appended when ``labels`` are given.
        z_name: Label for the colour bar.
        z_symbol: If given, replaces ``z_name`` by ``"$" + z_symbol + "$"``.
        z_min: Lower limit of the colour scale; defaults to the minimum of
            ``values`` when ``None``.
        z_max: Upper limit of the colour scale; defaults to the maximum of
            ``values`` when ``None``.
        symmetric: Whether rows and columns describe the same items, so
            that both are reordered identically. Requires a square input.
        labels: Optional per-row labels; rows are ordered by
            ``numpy.argsort(labels)``.
        label_kind: Optional description of the labels used in the axis
            label.
        center: Passed on as ``center`` to ``seaborn.heatmap``.
        name: Passed on to ``saving.build_figure_name`` together with
            ``"heat_map"``.

    Returns:
        A tuple of the figure and its figure name.

    Raises:
        ValueError: If ``symmetric`` is true and ``values`` is not square.
    """

    n_examples, n_features = values.shape

    # Validate before building anything else, so invalid input fails early.
    if symmetric and n_examples != n_features:
        raise ValueError(
            "Input cannot be symmetric, when it is not given as a 2-d square "
            "array or matrix.")

    figure_name = saving.build_figure_name("heat_map", name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    # Compare against `None` explicitly: a limit of 0 is a valid choice and
    # must not be replaced by the data range.
    if z_min is None:
        z_min = values.min()

    if z_max is None:
        z_max = values.max()

    if z_symbol:
        z_name = "$" + z_symbol + "$"

    cbar_dict = {}

    if z_name:
        cbar_dict["label"] = z_name

    if symmetric:
        square_cells = True
    else:
        # Only force square cells when the matrix is not too elongated.
        aspect_ratio = n_examples / n_features
        square_cells = 1 / 5 < aspect_ratio < 5

    if labels is not None:
        row_indices = numpy.argsort(labels)
        y_name += " sorted"
        if label_kind:
            y_name += " by " + label_kind
    else:
        row_indices = numpy.arange(n_examples)

    if symmetric:
        # Same items along both axes: reuse the row order and axis label.
        column_indices = row_indices
        x_name = y_name
    else:
        column_indices = numpy.arange(n_features)

    seaborn.set(style="white")
    seaborn.heatmap(values[row_indices][:, column_indices],
                    vmin=z_min,
                    vmax=z_max,
                    center=center,
                    xticklabels=False,
                    yticklabels=False,
                    cbar=True,
                    cbar_kws=cbar_dict,
                    cmap=style.STANDARD_COLOUR_MAP,
                    square=square_cells,
                    ax=axis)
    style.reset_plot_look()

    axis.set_xlabel(x_name)
    axis.set_ylabel(y_name)

    return figure, figure_name
Beispiel #2
0
    "plot_probabilities", "plot_learning_curves",
    "plot_separate_learning_curves", "plot_accuracy_evolution",
    "plot_kl_divergence_evolution", "plot_centroid_probabilities_evolution",
    "plot_centroid_means_evolution",
    "plot_centroid_covariance_matrices_evolution", "plot_matrix",
    "plot_heat_map", "plot_values", "plot_variable_correlations",
    "plot_series", "plot_profile_comparison", "save_figure"
]

from scvae.analyses.figures.histograms import (plot_histogram,
                                               plot_cutoff_count_histogram,
                                               plot_class_histogram,
                                               plot_probabilities)
from scvae.analyses.figures.learning_curves import (
    plot_learning_curves,
    plot_separate_learning_curves,
    plot_accuracy_evolution,
    plot_kl_divergence_evolution,
    plot_centroid_probabilities_evolution,
    plot_centroid_means_evolution,
    plot_centroid_covariance_matrices_evolution,
)
from scvae.analyses.figures.matrices import plot_matrix, plot_heat_map
from scvae.analyses.figures.saving import save_figure
from scvae.analyses.figures.scatter import (plot_values,
                                            plot_variable_correlations)
from scvae.analyses.figures.series import plot_series, plot_profile_comparison
from scvae.analyses.figures.style import reset_plot_look

# Runs once when this module is imported, after the figure functions above
# have been imported. NOTE(review): presumably applies the package's
# default plot styling -- confirm in `style.reset_plot_look`.
reset_plot_look()
Beispiel #3
0
def plot_matrix(feature_matrix,
                plot_distances=False,
                center_value=None,
                example_label=None,
                feature_label=None,
                value_label=None,
                sorting_method=None,
                distance_metric="Euclidean",
                labels=None,
                label_kind=None,
                class_palette=None,
                feature_indices_for_plotting=None,
                hide_dendrogram=False,
                name_parts=None):
    """Plot a feature matrix, or its pairwise distances, as a heat map.

    Rows (examples) can optionally be sorted by label or by hierarchical
    clustering. When labels are given, a narrow strip showing the class
    colour of each row is drawn to the left of the heat map; when sorting
    by hierarchical clustering, a dendrogram is drawn to the left unless
    ``hide_dendrogram`` is set.

    Arguments:
        feature_matrix: 2-d array or SciPy sparse matrix with one row per
            example and one column per feature.
        plot_distances: If true, plot pairwise distances between examples
            instead of feature values. ``center_value`` and
            ``feature_label`` are then ignored and ``value_label`` is
            rewritten to describe the distances.
        center_value: Passed on as ``center`` to ``seaborn.heatmap``.
        example_label: Label for the y-axis; a description of the sorting
            is appended when sorting is used.
        feature_label: Label for the x-axis.
        value_label: Label for the colour bar.
        sorting_method: ``None``, ``"labels"`` or
            ``"hierarchical_clustering"``.
        distance_metric: Name of the metric; lower-cased and passed to
            ``sklearn.metrics.pairwise_distances``.
        labels: Optional per-example labels; each must be a key in
            ``class_palette``.
        label_kind: Description of the labels used in the axis label
            (defaults to ``"labels"`` when sorting by labels).
        class_palette: Mapping from label to colour; required whenever
            ``labels`` are given.
        feature_indices_for_plotting: Column indices to show when plotting
            feature values; defaults to all columns.
        hide_dendrogram: If true, do not draw the dendrogram when sorting
            by hierarchical clustering.
        name_parts: Passed on to ``saving.build_figure_name``.

    Returns:
        A tuple of the figure and its figure name.

    Raises:
        ValueError: If ``sorting_method`` is not recognised, if sorting by
            labels without labels, or if labels are given without a class
            palette.
    """

    # Validate arguments before creating a figure or computing distances.
    if sorting_method not in (None, "labels", "hierarchical_clustering"):
        raise ValueError("`sorting_method` should be either \"labels\""
                         " or \"hierarchical_clustering\"")

    if sorting_method == "labels" and labels is None:
        raise ValueError("No labels provided to sort after.")

    if labels is not None and not class_palette:
        raise ValueError("No class palette provided.")

    figure_name = saving.build_figure_name(name_parts)
    n_examples, n_features = feature_matrix.shape

    if plot_distances:
        center_value = None
        feature_label = None
        value_label = "Pairwise {} distances in {} space".format(
            distance_metric, value_label)

    if not plot_distances and feature_indices_for_plotting is None:
        feature_indices_for_plotting = numpy.arange(n_features)

    # Distances (if needed)
    distances = None
    if plot_distances or sorting_method == "hierarchical_clustering":
        distances = sklearn.metrics.pairwise_distances(
            feature_matrix, metric=distance_metric.lower())

    # Figure initialisation
    figure = pyplot.figure()

    axis_heat_map = figure.add_subplot(1, 1, 1)
    left_most_axis = axis_heat_map

    divider = make_axes_locatable(axis_heat_map)
    axis_colour_map = divider.append_axes("right", size="5%", pad=0.1)

    axis_labels = None
    axis_dendrogram = None

    if labels is not None:
        axis_labels = divider.append_axes("left", size="5%", pad=0.01)
        left_most_axis = axis_labels

    if sorting_method == "hierarchical_clustering" and not hide_dendrogram:
        axis_dendrogram = divider.append_axes("left", size="20%", pad=0.01)
        left_most_axis = axis_dendrogram

    # Label colours: map every example to the index of its class colour so
    # that the strip can be drawn as a one-column heat map.
    if labels is not None:
        label_colours = [
            tuple(colour) if isinstance(colour, list) else colour
            for colour in (class_palette[label] for label in labels)
        ]
        unique_colours = [
            tuple(colour) if isinstance(colour, list) else colour
            for colour in class_palette.values()
        ]
        value_for_colour = {
            colour: i
            for i, colour in enumerate(unique_colours)
        }
        label_colour_matrix = numpy.array([
            value_for_colour[colour] for colour in label_colours
        ]).reshape(n_examples, 1)
        label_colour_map = matplotlib.colors.ListedColormap(unique_colours)
    else:
        label_colour_matrix = None
        label_colour_map = None

    # Heat map aspect ratio: distance matrices are square by construction.
    square_cells = bool(plot_distances)

    seaborn.set(style="white")

    # Sorting and optional dendrogram
    if sorting_method == "labels":
        example_indices = numpy.argsort(labels)

        if not label_kind:
            label_kind = "labels"

        if example_label:
            example_label += " sorted by " + label_kind

    elif sorting_method == "hierarchical_clustering":
        # "average" names a linkage method, not a distance metric. With
        # condensed distances as input, SciPy ignores `metric`, so passing
        # it there would silently fall back to single linkage.
        linkage = scipy.cluster.hierarchy.linkage(
            scipy.spatial.distance.squareform(distances, checks=False),
            method="average")

        if axis_dendrogram is not None:
            # NOTE(review): `metric` and `method` are presumably unused by
            # seaborn here since a precomputed linkage is supplied.
            dendrogram = seaborn.matrix.dendrogram(distances,
                                                   linkage=linkage,
                                                   metric=None,
                                                   method="ward",
                                                   axis=0,
                                                   label=False,
                                                   rotate=True,
                                                   ax=axis_dendrogram)
            example_indices = dendrogram.reordered_ind
        else:
            # Dendrogram hidden: only compute the leaf order. Calling the
            # seaborn function without an axis would draw the dendrogram
            # onto the current axes instead.
            example_indices = scipy.cluster.hierarchy.dendrogram(
                linkage, no_plot=True)["leaves"]

        if example_label:
            example_label += " sorted by hierarchical clustering"

    else:
        example_indices = numpy.arange(n_examples)

    # Heat map of values
    if plot_distances:
        plot_values = distances[example_indices][:, example_indices]
    else:
        plot_values = feature_matrix[
            example_indices][:, feature_indices_for_plotting]

    if scipy.sparse.issparse(plot_values):
        plot_values = plot_values.toarray()

    colour_bar_dictionary = {}

    if value_label:
        colour_bar_dictionary["label"] = value_label

    seaborn.heatmap(plot_values,
                    center=center_value,
                    xticklabels=False,
                    yticklabels=False,
                    cbar=True,
                    cbar_kws=colour_bar_dictionary,
                    cbar_ax=axis_colour_map,
                    square=square_cells,
                    ax=axis_heat_map)

    # Colour labels
    if axis_labels is not None:
        seaborn.heatmap(label_colour_matrix[example_indices],
                        xticklabels=False,
                        yticklabels=False,
                        cbar=False,
                        cmap=label_colour_map,
                        ax=axis_labels)

    style.reset_plot_look()

    # Axis labels
    if example_label:
        left_most_axis.set_ylabel(example_label)

    if feature_label:
        axis_heat_map.set_xlabel(feature_label)

    return figure, figure_name
Beispiel #4
0
def _class_information_for_colour_coding(colouring_data_set, colour_coding):
    """Look up labels and class metadata for a label-like colour coding.

    Returns a tuple ``(labels, class_names, number_of_classes,
    class_palette, label_sorter)`` read from the attributes of
    ``colouring_data_set`` that match ``colour_coding``. The palette and
    sorter are ``None`` for cluster IDs and batches.
    """
    if colour_coding == "predicted_cluster_ids":
        labels = colouring_data_set.predicted_cluster_ids
        class_names = numpy.unique(labels).tolist()
        return labels, class_names, len(class_names), None, None

    if colour_coding == "predicted_labels":
        return (colouring_data_set.predicted_labels,
                colouring_data_set.predicted_class_names,
                colouring_data_set.number_of_predicted_classes,
                colouring_data_set.predicted_class_palette,
                colouring_data_set.predicted_label_sorter)

    if colour_coding == "predicted_superset_labels":
        return (colouring_data_set.predicted_superset_labels,
                colouring_data_set.predicted_superset_class_names,
                colouring_data_set.number_of_predicted_superset_classes,
                colouring_data_set.predicted_superset_class_palette,
                colouring_data_set.predicted_superset_label_sorter)

    if "superset" in colour_coding:
        return (colouring_data_set.superset_labels,
                colouring_data_set.superset_class_names,
                colouring_data_set.number_of_superset_classes,
                colouring_data_set.superset_class_palette,
                colouring_data_set.superset_label_sorter)

    if colour_coding == "batches":
        return (colouring_data_set.batch_indices.flatten(),
                colouring_data_set.batch_names,
                colouring_data_set.number_of_batches,
                None,
                None)

    return (colouring_data_set.labels,
            colouring_data_set.class_names,
            colouring_data_set.number_of_classes,
            colouring_data_set.class_palette,
            colouring_data_set.label_sorter)


def plot_values(values,
                colour_coding=None,
                colouring_data_set=None,
                centroids=None,
                sampled_values=None,
                class_name=None,
                feature_index=None,
                figure_labels=None,
                example_tag=None,
                name="scatter"):
    """Make a scatter plot of the first two columns of ``values``.

    Examples are shuffled with a fixed seed before plotting so that the
    drawing order does not reflect any prior ordering.

    Arguments:
        values: 2-d array or SciPy sparse matrix; only the first two
            columns are plotted.
        colour_coding: How to colour the examples. After normalisation
            with ``normalise_string``, codings containing "labels", "ids"
            or "class", or equal to "batches", colour by class;
            "count_sum" and "feature" colour by a continuous value;
            ``None`` plots all examples in black.
        colouring_data_set: Data set object providing the attributes
            needed for the chosen colour coding. Required whenever
            ``colour_coding`` is given.
        centroids: Optional mapping with a ``"prior"`` entry holding
            ``"probabilities"``, ``"means"`` and ``"covariance_matrices"``;
            drawn when there is more than one centroid.
        sampled_values: Optional samples shown as a hexbin background.
        class_name: Class to highlight for a "class" colour coding.
        feature_index: Index of the feature used for the "feature" colour
            coding.
        figure_labels: Optional mapping with "title", "x label" and
            "y label" entries.
        example_tag: Not used by this function.
        name: Prefix of the figure name.

    Returns:
        A tuple of the figure and its figure name.

    Raises:
        ValueError: If a colour coding is given without a colouring data
            set, if the colour coding is unknown, or if the feature index
            is missing or out of range for the "feature" colour coding.
    """

    # Check up front: every colour coding reads from the colouring data
    # set, and the figure-name construction below already accesses it.
    if colour_coding and colouring_data_set is None:
        raise ValueError("Colouring data set not given.")

    figure_name = name

    if figure_labels:
        title = figure_labels.get("title")
        x_label = figure_labels.get("x label")
        y_label = figure_labels.get("y label")
    else:
        title = "none"
        x_label = "$x$"
        y_label = "$y$"

    if not title:
        title = "none"

    figure_name += "-" + normalise_string(title)

    if colour_coding:
        colour_coding = normalise_string(colour_coding)
        figure_name += "-" + colour_coding
        if "predicted" in colour_coding:
            if colouring_data_set.prediction_specifications:
                figure_name += "-" + (
                    colouring_data_set.prediction_specifications.name)
            else:
                figure_name += "-unknown_prediction_method"

    if sampled_values is not None:
        figure_name += "-samples"

    values = values.copy()[:, :2]
    if scipy.sparse.issparse(values):
        values = values.toarray()

    # Randomise examples in values to remove any prior order
    n_examples, __ = values.shape
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    # Adjust marker size based on number of examples
    style._adjust_marker_size_for_scatter_plots(n_examples)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    colour_map = seaborn.dark_palette(style.STANDARD_PALETTE[0], as_cmap=True)

    # Make the examples translucent when samples are shown behind them.
    alpha = 0.5 if sampled_values is not None else 1

    if colour_coding and ("labels" in colour_coding or "ids" in colour_coding
                          or "class" in colour_coding
                          or colour_coding == "batches"):

        (labels, class_names, number_of_classes, class_palette,
         label_sorter) = _class_information_for_colour_coding(
            colouring_data_set, colour_coding)

        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                palette_class_name: index_palette[i]
                for i, palette_class_name in enumerate(
                    sorted(class_names, key=label_sorter))
            }

        # Examples are shuffled, so should their labels be
        labels = labels[shuffled_indices]

        if ("labels" in colour_coding or "ids" in colour_coding
                or colour_coding == "batches"):
            colours = []
            seen_classes = set()

            for i, label in enumerate(labels):
                colour = class_palette[label]
                colours.append(colour)

                # Plot one example for each class to add labels
                if label not in seen_classes:
                    seen_classes.add(label)
                    axis.scatter(values[i, 0],
                                 values[i, 1],
                                 color=colour,
                                 label=label,
                                 alpha=alpha)

            axis.scatter(values[:, 0], values[:, 1], c=colours, alpha=alpha)

            class_handles, class_labels = axis.get_legend_handles_labels()

            if class_labels:
                class_labels, class_handles = zip(
                    *sorted(zip(class_labels, class_handles),
                            key=(lambda t: label_sorter(t[0])
                                 ) if label_sorter else None))
                class_label_maximum_width = max(map(len, class_labels))
                if class_label_maximum_width <= 5 and number_of_classes <= 20:
                    axis.legend(class_handles, class_labels, loc="best")
                else:
                    # Long or many labels: put the legend above the axes
                    # in multiple columns.
                    if number_of_classes <= 20:
                        class_label_columns = 2
                    else:
                        class_label_columns = 3
                    axis.legend(
                        class_handles,
                        class_labels,
                        bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
                        loc="lower left",
                        ncol=class_label_columns,
                        mode="expand",
                        borderaxespad=0.,
                    )

        elif "class" in colour_coding:
            colours = []
            figure_name += "-" + normalise_string(str(class_name))
            ordered_indices_set = {str(class_name): [], "Remaining": []}

            for i, label in enumerate(labels):
                if label == class_name:
                    colour = class_palette[label]
                    ordered_indices_set[str(class_name)].append(i)
                else:
                    colour = style.NEUTRAL_COLOUR
                    ordered_indices_set["Remaining"].append(i)
                colours.append(colour)

            colours = numpy.array(colours)

            # Draw the remaining examples at the bottom and the selected
            # class on top of them.
            z_order_index = 1
            for label, ordered_indices in sorted(ordered_indices_set.items()):
                if label == "Remaining":
                    z_order = 0
                else:
                    z_order = z_order_index
                    z_order_index += 1
                ordered_values = values[ordered_indices]
                ordered_colours = colours[ordered_indices]
                axis.scatter(ordered_values[:, 0],
                             ordered_values[:, 1],
                             c=ordered_colours,
                             label=label,
                             alpha=alpha,
                             zorder=z_order)

            # Build the legend once, after all groups have been drawn.
            legend_handles, legend_labels = (
                axis.get_legend_handles_labels())
            legend_labels, legend_handles = zip(
                *sorted(zip(legend_labels, legend_handles),
                        key=lambda t: label_sorter(t[0])
                        if label_sorter else None))
            axis.legend(legend_handles,
                        legend_labels,
                        bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
                        loc="lower left",
                        ncol=2,
                        mode="expand",
                        borderaxespad=0.)

    elif colour_coding == "count_sum":

        n = colouring_data_set.count_sum[shuffled_indices].flatten()
        scatter_plot = axis.scatter(values[:, 0],
                                    values[:, 1],
                                    c=n,
                                    cmap=colour_map,
                                    alpha=alpha)
        colour_bar = figure.colorbar(scatter_plot)
        colour_bar.outline.set_linewidth(0)
        colour_bar.set_label("Total number of {}s per {}".format(
            colouring_data_set.terms["item"],
            colouring_data_set.terms["example"]))

    elif colour_coding == "feature":
        if feature_index is None:
            raise ValueError("Feature number not given.")
        # Valid indices run up to `number_of_features - 1`.
        if feature_index >= colouring_data_set.number_of_features:
            raise ValueError("Feature number higher than number of features.")

        feature_name = colouring_data_set.feature_names[feature_index]
        figure_name += "-{}".format(normalise_string(feature_name))

        f = colouring_data_set.values[shuffled_indices, feature_index]
        if scipy.sparse.issparse(f):
            f = f.toarray()
        f = f.squeeze()

        scatter_plot = axis.scatter(values[:, 0],
                                    values[:, 1],
                                    c=f,
                                    cmap=colour_map,
                                    alpha=alpha)
        colour_bar = figure.colorbar(scatter_plot)
        colour_bar.outline.set_linewidth(0)
        colour_bar.set_label(feature_name)

    elif colour_coding is None:
        axis.scatter(values[:, 0],
                     values[:, 1],
                     c="k",
                     alpha=alpha,
                     edgecolors="none")

    else:
        raise ValueError("Colour coding `{}` not found.".format(colour_coding))

    if centroids:
        prior_centroids = centroids["prior"]

        if prior_centroids:
            n_centroids = prior_centroids["probabilities"].shape[0]
        else:
            n_centroids = 0

        if n_centroids > 1:
            centroids_palette = style.darker_palette(n_centroids)

            means = prior_centroids["means"]
            covariance_matrices = prior_centroids["covariance_matrices"]

            for k in range(n_centroids):
                # Thick black cross first, acting as an outline for the
                # coloured cross drawn on top of it.
                axis.scatter(means[k, 0],
                             means[k, 1],
                             s=60,
                             marker="x",
                             color="black",
                             linewidth=3)
                axis.scatter(means[k, 0],
                             means[k, 1],
                             marker="x",
                             facecolor=centroids_palette[k],
                             edgecolors="black")
                ellipse_fill, ellipse_edge = _covariance_matrix_as_ellipse(
                    covariance_matrices[k],
                    means[k],
                    colour=centroids_palette[k])
                axis.add_patch(ellipse_edge)
                axis.add_patch(ellipse_fill)

    if sampled_values is not None:

        sampled_values = sampled_values.copy()[:, :2]
        if scipy.sparse.issparse(sampled_values):
            sampled_values = sampled_values.toarray()

        sample_colour_map = seaborn.blend_palette(("white", "purple"),
                                                  as_cmap=True)

        # Keep the axis limits determined by the examples; the samples
        # must not rescale the plot.
        x_limits = axis.get_xlim()
        y_limits = axis.get_ylim()

        axis.hexbin(sampled_values[:, 0],
                    sampled_values[:, 1],
                    gridsize=75,
                    cmap=sample_colour_map,
                    linewidths=0.,
                    edgecolors="none",
                    zorder=-100)

        axis.set_xlim(x_limits)
        axis.set_ylim(y_limits)

    # Reset marker size
    style.reset_plot_look()

    return figure, figure_name