Пример #1
0
def plot_variable_correlations(values,
                               variable_names=None,
                               colouring_data_set=None,
                               name="variable_correlations"):
    """Plot a pairwise scatter-plot matrix of all variables in `values`.

    Arguments:
        values: 2-D array of shape ``(n_examples, n_features)``.
        variable_names: optional sequence of `n_features` axis labels;
            generic names are generated when omitted.
        colouring_data_set: optional data set providing labels and a class
            palette used to colour the points; a neutral colour is used
            when omitted.
        name: base name for the returned figure name.

    Returns:
        Tuple of the matplotlib figure and its figure name.
    """

    figure_name = saving.build_figure_name(name)
    n_examples, n_features = values.shape

    # The previous implementation crashed with a TypeError on the default
    # ``variable_names=None`` when labelling the axes below; fall back to
    # generic variable names instead.
    if variable_names is None:
        variable_names = [
            "variable {}".format(i + 1) for i in range(n_features)]

    # Shuffle examples with a fixed seed so plotting order (and therefore
    # overdraw) is deterministic but not biased by any prior ordering.
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    if colouring_data_set:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

        # Build a default palette when the data set does not provide one.
        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                class_name: index_palette[i]
                for i, class_name in enumerate(
                    sorted(class_names, key=label_sorter))
            }

        # Labels must be shuffled in lockstep with the values.
        labels = labels[shuffled_indices]

        colours = [class_palette[label] for label in labels]

    else:
        colours = style.NEUTRAL_COLOUR

    figure, axes = pyplot.subplots(nrows=n_features,
                                   ncols=n_features,
                                   figsize=[1.5 * n_features] * 2)

    for i in range(n_features):
        for j in range(n_features):
            axes[i, j].scatter(values[:, i], values[:, j], c=colours, s=1)

            axes[i, j].set_xticks([])
            axes[i, j].set_yticks([])

            # Only the bottom row gets x-axis labels.
            if i == n_features - 1:
                axes[i, j].set_xlabel(variable_names[j])

        # Only the first column gets y-axis labels.
        axes[i, 0].set_ylabel(variable_names[i])

    return figure, figure_name
Пример #2
0
def plot_variable_label_correlations(variable_vector,
                                     variable_name,
                                     colouring_data_set,
                                     name="variable_label_correlations"):
    """Scatter-plot a single variable against the class labels.

    Each example is drawn at (variable value, class id), coloured by its
    class; the y-axis ticks show the class names.

    Returns:
        Tuple of the matplotlib figure and its figure name.
    """

    figure_name = saving.build_figure_name(name)
    n_examples = variable_vector.shape[0]

    # Vectorised converters between class names and numeric class ids.
    names_to_ids = numpy.vectorize(
        lambda label: colouring_data_set.class_name_to_class_id[label])
    ids_to_names = numpy.vectorize(
        lambda identifier:
            colouring_data_set.class_id_to_class_name[identifier])

    labels = colouring_data_set.labels
    class_names = colouring_data_set.class_names
    number_of_classes = colouring_data_set.number_of_classes
    class_palette = colouring_data_set.class_palette
    label_sorter = colouring_data_set.label_sorter

    # Build a default palette when the data set does not provide one.
    if not class_palette:
        index_palette = style.lighter_palette(number_of_classes)
        class_palette = {}
        for i, class_name in enumerate(
                sorted(class_names, key=label_sorter)):
            class_palette[class_name] = index_palette[i]

    # Deterministically shuffle values and labels in lockstep so plotting
    # order is not biased by any prior ordering.
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    variable_vector = variable_vector[shuffled_indices]
    labels = labels[shuffled_indices]

    label_ids = numpy.expand_dims(names_to_ids(labels), axis=-1)
    colours = []
    for label in labels:
        colours.append(class_palette[label])

    unique_class_ids = numpy.unique(label_ids)
    unique_class_names = ids_to_names(unique_class_ids)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.scatter(variable_vector, label_ids, c=colours, s=1)

    axis.set_yticks(unique_class_ids)
    axis.set_yticklabels(unique_class_names)

    axis.set_xlabel(variable_name)
    axis.set_ylabel(capitalise_string(colouring_data_set.terms["class"]))

    return figure, figure_name
Пример #3
0
def analyse_distributions(data_set,
                          colouring_data_set=None,
                          cutoffs=None,
                          preprocessed=False,
                          analysis_level="normal",
                          export_options=None,
                          analyses_directory=None):
    """Plot and save histograms characterising a data set.

    Depending on the data set and `analysis_level`, this plots: the class
    and superset-class distributions, the count distribution (linear and
    log x-scales), count distributions with cut-offs, the count-sum
    distribution, and — at the "extensive" level — per-class count and
    count-sum distributions. Figures are saved under a "histograms"
    subdirectory of `analyses_directory`.

    Arguments:
        data_set: data set with ``values``, ``labels``, ``tags``, etc.
        colouring_data_set: data set providing labels/classes for the
            per-class plots; defaults to `data_set` itself.
        cutoffs: count cut-offs (used only at the "extensive" level for
            count data).
        preprocessed: when true, values are treated as non-discrete even
            if the data set is flagged discrete.
        analysis_level: "normal" or "extensive".
        export_options: forwarded to ``figures.save_figure``.
        analyses_directory: base output directory; defaults to
            ``defaults["analyses"]["directory"]``.
    """

    if not colouring_data_set:
        colouring_data_set = data_set

    if analysis_level is None:
        analysis_level = defaults["analyses"]["analysis_level"]

    if analyses_directory is None:
        analyses_directory = defaults["analyses"]["directory"]
    distribution_directory = os.path.join(analyses_directory, "histograms")

    data_set_title = data_set.kind + " set"
    data_set_name = data_set.kind
    if data_set.version != "original":
        data_set_title = data_set.version + " " + data_set_title
        data_set_name = None

    # Preprocessing (e.g. normalisation) makes originally-discrete counts
    # continuous, so discreteness only holds for unpreprocessed values.
    data_set_discreteness = data_set.discreteness and not preprocessed

    print("Plotting distributions for {}.".format(data_set_title))

    # Class distribution
    if (data_set.number_of_classes and data_set.number_of_classes < 100
            and colouring_data_set == data_set):
        distribution_time_start = time()
        figure, figure_name = figures.plot_class_histogram(
            labels=data_set.labels,
            class_names=data_set.class_names,
            class_palette=data_set.class_palette,
            normed=True,
            scale="linear",
            label_sorter=data_set.label_sorter,
            name=data_set_name)
        figures.save_figure(figure=figure,
                            name=figure_name,
                            options=export_options,
                            directory=distribution_directory)
        distribution_duration = time() - distribution_time_start
        print("    Class distribution plotted and saved ({}).".format(
            format_duration(distribution_duration)))

    # Superset class distribution
    if data_set.label_superset and colouring_data_set == data_set:
        distribution_time_start = time()
        figure, figure_name = figures.plot_class_histogram(
            labels=data_set.superset_labels,
            class_names=data_set.superset_class_names,
            class_palette=data_set.superset_class_palette,
            normed=True,
            scale="linear",
            label_sorter=data_set.superset_label_sorter,
            name=[data_set_name, "superset"])
        figures.save_figure(figure=figure,
                            name=figure_name,
                            options=export_options,
                            directory=distribution_directory)
        distribution_duration = time() - distribution_time_start
        print("    Superset class distribution plotted and saved ({}).".format(
            format_duration(distribution_duration)))

    # Count distribution. For sparse values, only the stored (non-zero)
    # entries form the series; the implicit zeros are tallied separately.
    if scipy.sparse.issparse(data_set.values):
        series = data_set.values.data
        excess_zero_count = data_set.values.size - series.size
    else:
        series = data_set.values.reshape(-1)
        excess_zero_count = 0
    distribution_time_start = time()
    for x_scale in ["linear", "log"]:
        figure, figure_name = figures.plot_histogram(
            series=series,
            excess_zero_count=excess_zero_count,
            label=data_set.tags["value"].capitalize() + "s",
            discrete=data_set_discreteness,
            normed=True,
            x_scale=x_scale,
            y_scale="log",
            name=["counts", data_set_name])
        figures.save_figure(figure=figure,
                            name=figure_name,
                            options=export_options,
                            directory=distribution_directory)
    distribution_duration = time() - distribution_time_start
    print("    Count distribution plotted and saved ({}).".format(
        format_duration(distribution_duration)))

    # Count distributions with cut-off
    if (analysis_level == "extensive" and cutoffs
            and data_set.example_type == "counts"):
        distribution_time_start = time()
        for cutoff in cutoffs:
            figure, figure_name = figures.plot_cutoff_count_histogram(
                series=series,
                excess_zero_count=excess_zero_count,
                cutoff=cutoff,
                normed=True,
                scale="log",
                name=data_set_name)
            figures.save_figure(figure=figure,
                                name=figure_name,
                                options=export_options,
                                directory=distribution_directory + "-counts")
        distribution_duration = time() - distribution_time_start
        print("    Count distributions with cut-offs plotted and saved ({}).".
              format(format_duration(distribution_duration)))

    # Count sum distribution
    distribution_time_start = time()
    figure, figure_name = figures.plot_histogram(
        series=data_set.count_sum,
        label="Total number of {}s per {}".format(data_set.tags["item"],
                                                  data_set.tags["example"]),
        normed=True,
        y_scale="log",
        name=["count sum", data_set_name])
    figures.save_figure(figure=figure,
                        name=figure_name,
                        options=export_options,
                        directory=distribution_directory)
    distribution_duration = time() - distribution_time_start
    print("    Count sum distribution plotted and saved ({}).".format(
        format_duration(distribution_duration)))

    # Count distributions and count sum distributions for each class
    if analysis_level == "extensive" and colouring_data_set.labels is not None:

        class_count_distribution_directory = distribution_directory
        if data_set.version == "original":
            class_count_distribution_directory += "-classes"

        if colouring_data_set.label_superset:
            labels = colouring_data_set.superset_labels
            class_names = colouring_data_set.superset_class_names
            class_palette = colouring_data_set.superset_class_palette
            label_sorter = colouring_data_set.superset_label_sorter
        else:
            labels = colouring_data_set.labels
            class_names = colouring_data_set.class_names
            class_palette = colouring_data_set.class_palette
            label_sorter = colouring_data_set.label_sorter

        # Build a default palette when the data set does not provide one.
        if not class_palette:
            index_palette = style.lighter_palette(
                colouring_data_set.number_of_classes)
            class_palette = {
                class_name: index_palette[i]
                for i, class_name in enumerate(
                    sorted(class_names, key=label_sorter))
            }

        distribution_time_start = time()
        for class_name in class_names:

            class_indices = labels == class_name

            if not class_indices.any():
                continue

            values_label = data_set.values[class_indices]

            if scipy.sparse.issparse(values_label):
                series = values_label.data
                excess_zero_count = values_label.size - series.size
            else:
                # BUG FIX: previously this used the full data set
                # (``data_set.values``) instead of the class subset,
                # so every per-class histogram showed all examples.
                series = values_label.reshape(-1)
                excess_zero_count = 0

            figure, figure_name = figures.plot_histogram(
                series=series,
                excess_zero_count=excess_zero_count,
                label=data_set.tags["value"].capitalize() + "s",
                discrete=data_set_discreteness,
                normed=True,
                y_scale="log",
                colour=class_palette[class_name],
                name=["counts", data_set_name, "class", class_name])
            figures.save_figure(figure=figure,
                                name=figure_name,
                                options=export_options,
                                directory=class_count_distribution_directory)

        distribution_duration = time() - distribution_time_start
        print("    Count distributions for each class plotted and saved ({}).".
              format(format_duration(distribution_duration)))

        distribution_time_start = time()
        for class_name in class_names:

            class_indices = labels == class_name
            if not class_indices.any():
                continue

            figure, figure_name = figures.plot_histogram(
                series=data_set.count_sum[class_indices],
                label="Total number of {}s per {}".format(
                    data_set.tags["item"], data_set.tags["example"]),
                normed=True,
                y_scale="log",
                colour=class_palette[class_name],
                name=["count sum", data_set_name, "class", class_name])
            figures.save_figure(figure=figure,
                                name=figure_name,
                                options=export_options,
                                directory=class_count_distribution_directory)

        distribution_duration = time() - distribution_time_start
        print("    "
              "Count sum distributions for each class plotted and saved ({}).".
              format(format_duration(distribution_duration)))

    print()
Пример #4
0
def analyse_matrices(data_set,
                     plot_distances=False,
                     name=None,
                     export_options=None,
                     analyses_directory=None):
    """Plot and save heat maps or pairwise-distance matrices for a data set.

    For each applicable sorting method ("labels" when the data set has
    labels, plus "hierarchical_clustering") and distance metric, one
    figure is plotted and saved. Examples are subsampled (with a fixed
    seed) when they exceed the relevant maximum, and for heat maps only
    the most-varying features are shown when there are too many.

    Arguments:
        data_set: data set with ``values``, ``labels``, ``tags``, etc.
        plot_distances: plot pairwise distances instead of heat maps.
        name: optional figure-name part(s), string or list.
        export_options: forwarded to ``figures.save_figure``.
        analyses_directory: base output directory; defaults to
            ``defaults["analyses"]["directory"]``.
    """

    if plot_distances:
        base_name = "distances"
    else:
        base_name = "heat_maps"

    if analyses_directory is None:
        analyses_directory = defaults["analyses"]["directory"]
    analyses_directory = os.path.join(analyses_directory, base_name)

    if not name:
        name = []
    elif not isinstance(name, list):
        name = [name]

    name.insert(0, base_name)

    # Subsampling indices (if necessary); fixed seed for reproducibility.
    random_state = numpy.random.RandomState(57)
    shuffled_indices = random_state.permutation(data_set.number_of_examples)

    # Feature selection for plotting (if necessary): keep the features
    # with the highest variance, in their original order.
    feature_indices_for_plotting = None
    if (not plot_distances and data_set.number_of_features >
            MAXIMUM_NUMBER_OF_FEATURES_FOR_HEAT_MAPS):
        feature_variances = data_set.values.var(axis=0)
        if isinstance(feature_variances, numpy.matrix):
            feature_variances = feature_variances.A.squeeze()
        feature_indices_for_plotting = numpy.argsort(
            feature_variances)[-MAXIMUM_NUMBER_OF_FEATURES_FOR_HEAT_MAPS:]
        feature_indices_for_plotting.sort()

    # Class palette (default palette when the data set provides none)
    class_palette = data_set.class_palette
    if data_set.labels is not None and not class_palette:
        index_palette = style.lighter_palette(data_set.number_of_classes)
        class_palette = {
            class_name: tuple(index_palette[i])
            for i, class_name in enumerate(
                sorted(data_set.class_names, key=data_set.label_sorter))
        }

    # Axis labels
    example_label = data_set.tags["example"].capitalize() + "s"
    feature_label = data_set.tags["feature"].capitalize() + "s"
    value_label = data_set.tags["value"].capitalize() + "s"

    version = data_set.version
    symbol = None
    value_name = "values"

    if version in ["z", "x"]:
        symbol = "$\\mathbf{{{}}}$".format(version)
        value_name = "component"
    elif version in ["y"]:
        symbol = "${}$".format(version)
        value_name = "value"

    if version in ["y", "z"]:
        feature_label = " ".join([symbol, value_name + "s"])

    if plot_distances:
        if version in ["y", "z"]:
            value_label = symbol
        else:
            value_label = version

    if feature_indices_for_plotting is not None:
        feature_label = "{} most varying {}".format(
            len(feature_indices_for_plotting), feature_label.lower())

    plot_string = "Plotting heat map for {} values."
    if plot_distances:
        plot_string = "Plotting pairwise distances in {} space."
    print(plot_string.format(data_set.version))

    sorting_methods = ["hierarchical_clustering"]

    if data_set.labels is not None:
        sorting_methods.insert(0, "labels")

    for sorting_method in sorting_methods:

        distance_metrics = [None]

        if plot_distances or sorting_method == "hierarchical_clustering":
            distance_metrics = ["Euclidean", "cosine"]

        for distance_metric in distance_metrics:

            start_time = time()

            if (sorting_method == "hierarchical_clustering"
                    and data_set.number_of_examples >
                    MAXIMUM_NUMBER_OF_EXAMPLES_FOR_DENDROGRAM):
                sample_size = MAXIMUM_NUMBER_OF_EXAMPLES_FOR_DENDROGRAM
            elif (data_set.number_of_examples >
                  MAXIMUM_NUMBER_OF_EXAMPLES_FOR_HEAT_MAPS):
                sample_size = MAXIMUM_NUMBER_OF_EXAMPLES_FOR_HEAT_MAPS
            else:
                sample_size = None

            indices = numpy.arange(data_set.number_of_examples)

            # BUG FIX: previously the shared ``example_label`` was
            # overwritten here, so iterations without subsampling could
            # show a stale "randomly sampled" label; use a per-iteration
            # local instead.
            current_example_label = example_label
            if sample_size:
                indices = shuffled_indices[:sample_size]
                current_example_label = "{} randomly sampled {}".format(
                    sample_size, data_set.tags["example"] + "s")

            figure, figure_name = figures.plot_matrix(
                feature_matrix=data_set.values[indices],
                plot_distances=plot_distances,
                example_label=current_example_label,
                feature_label=feature_label,
                value_label=value_label,
                sorting_method=sorting_method,
                distance_metric=distance_metric,
                labels=(data_set.labels[indices]
                        if data_set.labels is not None else None),
                label_kind=data_set.tags["class"],
                class_palette=class_palette,
                feature_indices_for_plotting=feature_indices_for_plotting,
                name_parts=name +
                [data_set.version, distance_metric, sorting_method])
            figures.save_figure(figure=figure,
                                name=figure_name,
                                options=export_options,
                                directory=analyses_directory)

            duration = time() - start_time

            plot_kind_string = "Heat map for {} values".format(
                data_set.version)

            if plot_distances:
                plot_kind_string = "{} distances in {} space".format(
                    distance_metric.capitalize(), data_set.version)

            subsampling_string = ""

            if sample_size:
                subsampling_string = "{} {} randomly sampled examples".format(
                    "for" if plot_distances else "of", sample_size)

            sort_string = "sorted using {}".format(
                sorting_method.replace("_", " "))

            if (not plot_distances
                    and sorting_method == "hierarchical_clustering"):
                sort_string += " (with {} distances)".format(distance_metric)

            print("    " + " ".join([
                s for s in [
                    plot_kind_string, subsampling_string, sort_string,
                    "plotted and saved", "({})".format(
                        format_duration(duration))
                ] if s
            ]) + ".")

    print()
Пример #5
0
def plot_values(values,
                colour_coding=None,
                colouring_data_set=None,
                centroids=None,
                sampled_values=None,
                class_name=None,
                feature_index=None,
                figure_labels=None,
                example_tag=None,
                name="scatter"):
    """Scatter-plot the first two columns of `values`.

    Points can be coloured by labels/ids/classes/batches, by count sum,
    or by a single feature of `colouring_data_set`; prior centroids and a
    hex-binned background of `sampled_values` can be overlaid.

    Arguments:
        values: 2-D array; only the first two columns are plotted.
        colour_coding: one of the label/id/class/batch codings,
            "count_sum", "feature", or ``None`` for plain black points.
        colouring_data_set: data set providing labels, counts or feature
            values; required when `colour_coding` is given.
        centroids: optional dict with a "prior" entry holding
            "probabilities", "means" and "covariance_matrices".
        sampled_values: optional values drawn as a background hexbin.
        class_name: class to highlight for the "class" colour codings.
        feature_index: feature column for the "feature" colour coding.
        figure_labels: optional dict with "title", "x label", "y label".
        example_tag: unused here; kept for interface compatibility.
        name: base figure name.

    Returns:
        Tuple of the matplotlib figure and its figure name.

    Raises:
        ValueError: for a missing colouring data set, a missing or
            out-of-range feature index, or an unknown colour coding.
    """

    figure_name = name

    if figure_labels:
        title = figure_labels.get("title")
        x_label = figure_labels.get("x label")
        y_label = figure_labels.get("y label")
    else:
        title = "none"
        x_label = "$x$"
        y_label = "$y$"

    if not title:
        title = "none"

    figure_name += "-" + normalise_string(title)

    if colour_coding:
        colour_coding = normalise_string(colour_coding)
        figure_name += "-" + colour_coding
        # BUG FIX: validate before use — previously the None check came
        # after ``colouring_data_set`` was already dereferenced for the
        # "predicted" codings, raising AttributeError instead.
        if colouring_data_set is None:
            raise ValueError("Colouring data set not given.")
        if "predicted" in colour_coding:
            if colouring_data_set.prediction_specifications:
                figure_name += "-" + (
                    colouring_data_set.prediction_specifications.name)
            else:
                figure_name += "unknown_prediction_method"

    if sampled_values is not None:
        figure_name += "-samples"

    values = values.copy()[:, :2]
    if scipy.sparse.issparse(values):
        values = values.A

    # Randomise examples in values to remove any prior order
    n_examples, __ = values.shape
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    # Adjust marker size based on number of examples
    style._adjust_marker_size_for_scatter_plots(n_examples)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    colour_map = seaborn.dark_palette(style.STANDARD_PALETTE[0], as_cmap=True)

    # Make foreground points translucent when a background sample is drawn.
    alpha = 1
    if sampled_values is not None:
        alpha = 0.5

    if colour_coding and ("labels" in colour_coding or "ids" in colour_coding
                          or "class" in colour_coding
                          or colour_coding == "batches"):

        if colour_coding == "predicted_cluster_ids":
            labels = colouring_data_set.predicted_cluster_ids
            class_names = numpy.unique(labels).tolist()
            number_of_classes = len(class_names)
            class_palette = None
            label_sorter = None
        elif colour_coding == "predicted_labels":
            labels = colouring_data_set.predicted_labels
            class_names = colouring_data_set.predicted_class_names
            number_of_classes = colouring_data_set.number_of_predicted_classes
            class_palette = colouring_data_set.predicted_class_palette
            label_sorter = colouring_data_set.predicted_label_sorter
        elif colour_coding == "predicted_superset_labels":
            labels = colouring_data_set.predicted_superset_labels
            class_names = colouring_data_set.predicted_superset_class_names
            number_of_classes = (
                colouring_data_set.number_of_predicted_superset_classes)
            class_palette = colouring_data_set.predicted_superset_class_palette
            label_sorter = colouring_data_set.predicted_superset_label_sorter
        elif "superset" in colour_coding:
            labels = colouring_data_set.superset_labels
            class_names = colouring_data_set.superset_class_names
            number_of_classes = colouring_data_set.number_of_superset_classes
            class_palette = colouring_data_set.superset_class_palette
            label_sorter = colouring_data_set.superset_label_sorter
        elif colour_coding == "batches":
            labels = colouring_data_set.batch_indices.flatten()
            class_names = colouring_data_set.batch_names
            number_of_classes = colouring_data_set.number_of_batches
            class_palette = None
            label_sorter = None
        else:
            labels = colouring_data_set.labels
            class_names = colouring_data_set.class_names
            number_of_classes = colouring_data_set.number_of_classes
            class_palette = colouring_data_set.class_palette
            label_sorter = colouring_data_set.label_sorter

        # Build a default palette when the data set does not provide one.
        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                class_name: index_palette[i]
                for i, class_name in enumerate(
                    sorted(class_names, key=label_sorter))
            }

        # Examples are shuffled, so should their labels be
        labels = labels[shuffled_indices]

        if ("labels" in colour_coding or "ids" in colour_coding
                or colour_coding == "batches"):
            colours = []
            classes = set()

            for i, label in enumerate(labels):
                colour = class_palette[label]
                colours.append(colour)

                # Plot one example for each class to add labels
                if label not in classes:
                    classes.add(label)
                    axis.scatter(values[i, 0],
                                 values[i, 1],
                                 color=colour,
                                 label=label,
                                 alpha=alpha)

            axis.scatter(values[:, 0], values[:, 1], c=colours, alpha=alpha)

            class_handles, class_labels = axis.get_legend_handles_labels()

            if class_labels:
                class_labels, class_handles = zip(
                    *sorted(zip(class_labels, class_handles),
                            key=(lambda t: label_sorter(t[0])
                                 ) if label_sorter else None))
                class_label_maximum_width = max(map(len, class_labels))
                if class_label_maximum_width <= 5 and number_of_classes <= 20:
                    axis.legend(class_handles, class_labels, loc="best")
                else:
                    if number_of_classes <= 20:
                        class_label_columns = 2
                    else:
                        class_label_columns = 3
                    axis.legend(
                        class_handles,
                        class_labels,
                        bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
                        loc="lower left",
                        ncol=class_label_columns,
                        mode="expand",
                        borderaxespad=0.,
                    )

        elif "class" in colour_coding:
            # Highlight one class; remaining examples in a neutral colour.
            colours = []
            figure_name += "-" + normalise_string(str(class_name))
            ordered_indices_set = {str(class_name): [], "Remaining": []}

            for i, label in enumerate(labels):
                if label == class_name:
                    colour = class_palette[label]
                    ordered_indices_set[str(class_name)].append(i)
                else:
                    colour = style.NEUTRAL_COLOUR
                    ordered_indices_set["Remaining"].append(i)
                colours.append(colour)

            colours = numpy.array(colours)

            # Draw "Remaining" underneath the highlighted class.
            z_order_index = 1
            for label, ordered_indices in sorted(ordered_indices_set.items()):
                if label == "Remaining":
                    z_order = 0
                else:
                    z_order = z_order_index
                    z_order_index += 1
                ordered_values = values[ordered_indices]
                ordered_colours = colours[ordered_indices]
                axis.scatter(ordered_values[:, 0],
                             ordered_values[:, 1],
                             c=ordered_colours,
                             label=label,
                             alpha=alpha,
                             zorder=z_order)

            # BUG FIXES: build the legend once after all scatters instead
            # of redundantly on every loop iteration, and apply the sort
            # key only when a label sorter exists — the old precedence,
            # ``lambda t: label_sorter(t[0]) if label_sorter else None``,
            # made every key None and broke sorting without a sorter.
            handles, handle_labels = axis.get_legend_handles_labels()
            handle_labels, handles = zip(
                *sorted(zip(handle_labels, handles),
                        key=(lambda t: label_sorter(t[0]))
                        if label_sorter else None))
            axis.legend(handles,
                        handle_labels,
                        bbox_to_anchor=(-0.1, 1.05, 1.1, 0.95),
                        loc="lower left",
                        ncol=2,
                        mode="expand",
                        borderaxespad=0.)

    elif colour_coding == "count_sum":

        n = colouring_data_set.count_sum[shuffled_indices].flatten()
        scatter_plot = axis.scatter(values[:, 0],
                                    values[:, 1],
                                    c=n,
                                    cmap=colour_map,
                                    alpha=alpha)
        colour_bar = figure.colorbar(scatter_plot)
        colour_bar.outline.set_linewidth(0)
        colour_bar.set_label("Total number of {}s per {}".format(
            colouring_data_set.terms["item"],
            colouring_data_set.terms["example"]))

    elif colour_coding == "feature":
        if feature_index is None:
            raise ValueError("Feature number not given.")
        if feature_index > colouring_data_set.number_of_features:
            raise ValueError("Feature number higher than number of features.")

        feature_name = colouring_data_set.feature_names[feature_index]
        figure_name += "-{}".format(normalise_string(feature_name))

        f = colouring_data_set.values[shuffled_indices, feature_index]
        if scipy.sparse.issparse(f):
            f = f.A
        f = f.squeeze()

        scatter_plot = axis.scatter(values[:, 0],
                                    values[:, 1],
                                    c=f,
                                    cmap=colour_map,
                                    alpha=alpha)
        colour_bar = figure.colorbar(scatter_plot)
        colour_bar.outline.set_linewidth(0)
        colour_bar.set_label(feature_name)

    elif colour_coding is None:
        axis.scatter(values[:, 0],
                     values[:, 1],
                     c="k",
                     alpha=alpha,
                     edgecolors="none")

    else:
        raise ValueError("Colour coding `{}` not found.".format(colour_coding))

    if centroids:
        prior_centroids = centroids["prior"]

        if prior_centroids:
            n_centroids = prior_centroids["probabilities"].shape[0]
        else:
            n_centroids = 0

        if n_centroids > 1:
            centroids_palette = style.darker_palette(n_centroids)
            classes = numpy.arange(n_centroids)

            means = prior_centroids["means"]
            covariance_matrices = prior_centroids["covariance_matrices"]

            for k in range(n_centroids):
                # Black cross underneath, coloured cross on top.
                axis.scatter(means[k, 0],
                             means[k, 1],
                             s=60,
                             marker="x",
                             color="black",
                             linewidth=3)
                axis.scatter(means[k, 0],
                             means[k, 1],
                             marker="x",
                             facecolor=centroids_palette[k],
                             edgecolors="black")
                ellipse_fill, ellipse_edge = _covariance_matrix_as_ellipse(
                    covariance_matrices[k],
                    means[k],
                    colour=centroids_palette[k])
                axis.add_patch(ellipse_edge)
                axis.add_patch(ellipse_fill)

    if sampled_values is not None:

        sampled_values = sampled_values.copy()[:, :2]
        if scipy.sparse.issparse(sampled_values):
            sampled_values = sampled_values.A

        sample_colour_map = seaborn.blend_palette(("white", "purple"),
                                                  as_cmap=True)

        # Keep the axis limits determined by the foreground points.
        x_limits = axis.get_xlim()
        y_limits = axis.get_ylim()

        axis.hexbin(sampled_values[:, 0],
                    sampled_values[:, 1],
                    gridsize=75,
                    cmap=sample_colour_map,
                    linewidths=0.,
                    edgecolors="none",
                    zorder=-100)

        axis.set_xlim(x_limits)
        axis.set_ylim(y_limits)

    # Reset marker size
    style.reset_plot_look()

    return figure, figure_name
Пример #6
0
def plot_class_histogram(labels, class_names=None, class_palette=None,
                         normed=False, scale="linear", label_sorter=None,
                         name=None):
    """Plot a bar chart of class counts (or frequencies) for `labels`.

    Arguments:
        labels: iterable of class labels, one per example.
        class_names: classes to include; inferred from `labels` if omitted.
        class_palette: mapping from class name to colour; a default
            palette is generated if omitted.
        normed: plot frequencies instead of raw counts.
        scale: y-axis scale ("linear" or "log").
        label_sorter: optional key function used to order the classes.
        name: optional extra figure-name part(s).

    Returns:
        Tuple of the matplotlib figure and its figure name.
    """

    figure_name = "histogram"

    if normed:
        figure_name += "-normed"

    figure_name += "-classes"

    figure_name = saving.build_figure_name(figure_name, name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if class_names is None:
        class_names = numpy.unique(labels)

    n_classes = len(class_names)

    # Build a default palette when none is supplied.
    if not class_palette:
        index_palette = style.lighter_palette(n_classes)
        class_palette = {
            class_name: index_palette[i] for i, class_name in enumerate(sorted(
                class_names, key=label_sorter))
        }

    histogram = {
        class_name: {
            "index": i,
            "count": 0,
            "colour": class_palette[class_name]
        }
        for i, class_name in enumerate(sorted(
            class_names, key=label_sorter))
    }

    total_count_sum = 0

    for label in labels:
        histogram[label]["count"] += 1
        total_count_sum += 1

    indices = []
    class_names = []

    for class_name, class_values in sorted(histogram.items()):
        index = class_values["index"]
        count = class_values["count"]
        # Guard against ZeroDivisionError when `labels` is empty.
        frequency = count / total_count_sum if total_count_sum else 0
        colour = class_values["colour"]
        indices.append(index)
        class_names.append(class_name)

        if normed:
            count_or_frequency = frequency
        else:
            count_or_frequency = count

        axis.bar(index, count_or_frequency, color=colour)

    axis.set_yscale(scale)

    # Rotate x-tick labels when class names are long. (The originals were
    # misleadingly named ``y_ticks_*`` although they style the x ticks.)
    maximum_class_name_width = max([
        len(str(class_name)) for class_name in class_names
        if class_name not in ["No class"]
    ])
    if maximum_class_name_width > 5:
        x_ticks_rotation = 45
        x_ticks_horizontal_alignment = "right"
        x_ticks_rotation_mode = "anchor"
    else:
        x_ticks_rotation = 0
        x_ticks_horizontal_alignment = "center"
        x_ticks_rotation_mode = None
    pyplot.xticks(
        ticks=indices,
        labels=class_names,
        horizontalalignment=x_ticks_horizontal_alignment,
        rotation=x_ticks_rotation,
        rotation_mode=x_ticks_rotation_mode
    )

    axis.set_xlabel("Classes")

    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name