def plot_series(series, x_label, y_label, sort=False, scale="linear",
                bar=False, colour=None, name=None):
    """Plot a one-dimensional series either as a line or as bars.

    Arguments:
        series: array whose entries are plotted in order along the x axis.
        x_label, y_label: axis labels (passed through `capitalise_string`).
        sort: when true, plot the values in descending order and mark both
            the x label and the figure name as sorted.
        scale: y-axis scale; "log" also puts the bars in log mode.
        bar: draw bars instead of a line.
        colour: plot colour; defaults to the first standard palette colour.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("series", name)

    colour = colour or style.STANDARD_PALETTE[0]

    length = series.shape[0]
    positions = numpy.linspace(0, length, length)
    use_log = scale == "log"

    if sort:
        # Descending order: ascending sort reversed.
        series = numpy.sort(series)[::-1]
        x_label = "sorted " + x_label
        figure_name += "-sorted"

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if not bar:
        axis.plot(positions, series, color=colour)
        axis.set_yscale(scale)
    else:
        axis.bar(positions, series, log=use_log, color=colour, alpha=0.4)

    axis.set_xlabel(capitalise_string(x_label))
    axis.set_ylabel(capitalise_string(y_label))

    return figure, figure_name
def plot_accuracy_evolution(accuracies, name=None):
    """Plot accuracy (as a percentage) against epoch for each data-set kind.

    Arguments:
        accuracies: mapping from data-set kind (e.g. "training",
            "validation") to a sequence of per-epoch accuracies given as
            fractions, or `None` to skip that kind.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("accuracies", name)

    # Line style and palette index per known kind. Unknown kinds get a
    # dotted line instead of silently reusing the previous kind's style.
    kind_styles = {
        "training": ("solid", 0),
        "validation": ("dashed", 1),
    }
    palette = style.STANDARD_PALETTE

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    for accuracies_kind, kind_accuracies in sorted(accuracies.items()):
        if kind_accuracies is None:
            continue

        line_style, colour_index = kind_styles.get(
            accuracies_kind, ("dotted", 2))
        colour = palette[colour_index % len(palette)]

        label = "{} set".format(capitalise_string(accuracies_kind))
        # As an array, `100 * ...` scales element-wise; on a plain list it
        # would repeat the list 100 times.
        kind_accuracies = numpy.asarray(kind_accuracies)
        epochs = numpy.arange(len(kind_accuracies)) + 1
        axis.plot(epochs, 100 * kind_accuracies, color=colour,
                  linestyle=line_style, label=label)

    handles, labels = axis.get_legend_handles_labels()
    if handles:
        labels, handles = zip(
            *sorted(zip(labels, handles), key=lambda t: t[0]))
        axis.legend(handles, labels, loc="best")

    axis.set_xlabel("Epoch")
    axis.set_ylabel("Accuracies")

    return figure, figure_name
def plot_correlation_matrix(correlation_matrix, axis_label=None, name=None):
    """Draw a correlation matrix as a heat map on a fixed [-1, 1] scale.

    Arguments:
        correlation_matrix: 2-d data accepted by `seaborn.heatmap`.
        axis_label: optional label used for both the x and y axes.
        name: name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name(name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    seaborn.set(style="white")
    seaborn.heatmap(
        correlation_matrix,
        vmin=-1,
        vmax=1,
        center=0,
        xticklabels=False,
        yticklabels=False,
        cbar=True,
        cbar_kws={"label": "Pearson correlation coefficient"},
        square=True,
        ax=axis,
    )
    # Undo the global seaborn style change made just above.
    style.reset_plot_look()

    if axis_label:
        for set_label in (axis.set_xlabel, axis.set_ylabel):
            set_label(axis_label)

    return figure, figure_name
def plot_correlations(correlation_sets, x_key, y_key, x_label=None,
                      y_label=None, name=None):
    """Scatter-plot two quantities for one or more named sets.

    Arguments:
        correlation_sets: mapping from set name to a mapping holding the
            `x_key` and `y_key` values; anything that is not a dict is
            treated as a single set named "correlations".
        x_key, y_key: keys of the values plotted on each axis.
        x_label, y_label: axis labels.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("correlations", name)

    sets = (correlation_sets if isinstance(correlation_sets, dict)
            else {"correlations": correlation_sets})

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    for set_name, values in sets.items():
        axis.scatter(values[x_key], values[y_key], label=set_name)

    # A legend is only informative when several sets are shown.
    if len(sets) > 1:
        axis.legend(loc="best")

    return figure, figure_name
def combine_images_from_data_set(data_set, indices=None,
                                 number_of_random_examples=None, name=None):
    """Tile examples from a data set into one grid image.

    Examples are drawn in a shuffled order using a fixed seed, so repeated
    calls give the same selection.

    Arguments:
        data_set: object with `values`, `feature_dimensions` (width, height)
            and `number_of_examples`.
        indices: optional candidate example indices; they are shuffled and
            truncated to `number_of_random_examples` if that is given.
        number_of_random_examples: how many examples to include; defaults to
            `DEFAULT_NUMBER_OF_RANDOM_EXAMPLES_FOR_COMBINED_IMAGES` when no
            indices are given.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the combined 2-d image array and its name.

    Raises:
        ValueError: if no examples end up selected.
    """
    random_state = numpy.random.RandomState(13)

    if indices is not None:
        n_examples = len(indices)
        if number_of_random_examples is not None:
            n_examples = min(n_examples, number_of_random_examples)
        indices = random_state.permutation(indices)[:n_examples]
    else:
        if number_of_random_examples is not None:
            n_examples = number_of_random_examples
        else:
            n_examples = DEFAULT_NUMBER_OF_RANDOM_EXAMPLES_FOR_COMBINED_IMAGES
        indices = random_state.permutation(
            data_set.number_of_examples)[:n_examples]

    # The slice may hold fewer indices than requested (e.g. the data set is
    # smaller than `number_of_random_examples`); use the real count.
    n_examples = len(indices)
    if n_examples < 1:
        raise ValueError("No examples to combine into an image.")

    if n_examples == 1:
        image_name = saving.build_figure_name("image_example", name)
    else:
        image_name = saving.build_figure_name("image_examples", name)

    width, height = data_set.feature_dimensions

    examples = data_set.values[indices]
    if scipy.sparse.issparse(examples):
        examples = examples.toarray()
    examples = examples.reshape(n_examples, width, height)

    # Near-square grid: ceil(sqrt(n)) columns, and as many rows as needed.
    column_count = int(numpy.ceil(numpy.sqrt(n_examples)))
    row_count = int(numpy.ceil(n_examples / column_count))

    image = numpy.zeros((row_count * width, column_count * height))

    for m in range(n_examples):
        r, c = divmod(m, column_count)
        rows = slice(r * width, (r + 1) * width)
        columns = slice(c * height, (c + 1) * height)
        image[rows, columns] = examples[m]

    return image, image_name
def plot_cutoff_count_histogram(series, excess_zero_count=0, cutoff=None,
                                normed=False, scale="linear", colour=None,
                                name=None):
    """Plot how often each count value occurs, pooling values >= `cutoff`.

    Arguments:
        series: array of count values.
        excess_zero_count: extra zeros (not in `series`) added to the first
            bin.
        cutoff: last bin; it collects every value at or above it. Required.
        normed: plot frequencies instead of absolute numbers.
        scale: y-axis scale; "log" draws the bars in log mode.
        colour: bar colour; defaults to the first standard palette colour.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.

    Raises:
        ValueError: if `cutoff` is not given.
    """
    if cutoff is None:
        raise ValueError(
            "A `cutoff` is required for a cutoff count histogram.")

    figure_name = "histogram"
    if normed:
        figure_name += "-normed"
    figure_name += "-counts"
    figure_name = saving.build_figure_name(figure_name, name)
    figure_name += "-cutoff-{}".format(cutoff)

    if not colour:
        colour = style.STANDARD_PALETTE[0]

    y_log = scale == "log"

    count_number = numpy.arange(cutoff + 1)

    # Occurrences of each count value below the cutoff; the final bin pools
    # every value at or above the cutoff.
    count_number_count = numpy.array(
        [(series == value).sum() for value in count_number[:-1]]
        + [(series >= cutoff).sum()],
        dtype=float,
    )
    count_number_count[0] += excess_zero_count

    if normed:
        total = count_number_count.sum()
        # Avoid dividing by zero when there is nothing to count.
        if total > 0:
            count_number_count /= total

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.bar(count_number, count_number_count, log=y_log,
             color=colour, alpha=0.4)

    axis.set_xlabel("Count bins")
    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name
def plot_probabilities(posterior_probabilities, prior_probabilities,
                       x_label=None, y_label=None, palette=None,
                       uniform=False, name=None):
    """Bar plot of mixture-component probabilities.

    If posterior probabilities are given they are drawn as bars, and any
    prior probabilities are overlaid as dashed segments. Otherwise the
    prior probabilities are drawn as bars.

    Arguments:
        posterior_probabilities: sequence of probabilities, or `None`.
        prior_probabilities: sequence of probabilities, or `None`.
        x_label: x-axis label; defaults to "$k$".
        y_label: y-axis label; if not given, "$\\pi_{\\phi}^k$" (posterior
            bars) or "$\\pi_{\\theta}^k$" (prior bars) is used.
        palette: one bar colour per component; defaults to the first
            standard palette colour for every bar.
        uniform: not used by this function.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.

    Raises:
        ValueError: if both probability sequences are `None`.
    """
    figure_name = saving.build_figure_name("probabilities", name)

    if posterior_probabilities is None and prior_probabilities is None:
        raise ValueError("No posterior nor prior probabilities given.")

    n_centroids = max(
        len(probabilities)
        for probabilities in (posterior_probabilities, prior_probabilities)
        if probabilities is not None
    )

    if not x_label:
        x_label = "$k$"

    figure = pyplot.figure(figsize=(8, 6), dpi=80)
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if not palette:
        palette = [style.STANDARD_PALETTE[0]] * n_centroids

    if posterior_probabilities is not None:
        for k in range(len(posterior_probabilities)):
            axis.bar(k, posterior_probabilities[k], color=palette[k])
        default_y_label = "$\\pi_{\\phi}^k$"

        if prior_probabilities is not None:
            # One dashed segment per prior component, spanning its bar. The
            # prior's own length is used, since it may differ from the
            # posterior's.
            for k in range(len(prior_probabilities)):
                axis.plot(
                    [k - 0.4, k + 0.4], 2 * [prior_probabilities[k]],
                    color="black", linestyle="dashed"
                )
            prior_line = matplotlib.lines.Line2D(
                [], [], color="black", linestyle="dashed",
                label="$\\pi_{\\theta}^k$"
            )
            axis.legend(handles=[prior_line], loc="best", fontsize=18)
    else:
        for k in range(len(prior_probabilities)):
            axis.bar(k, prior_probabilities[k], color=palette[k])
        default_y_label = "$\\pi_{\\theta}^k$"

    axis.set_ylabel(y_label if y_label else default_y_label)
    axis.set_xlabel(x_label)

    return figure, figure_name
def plot_centroid_covariance_matrices_evolution(covariance_matrices,
                                                distribution, name=None):
    """Plot each component's covariance "determinant" against the epoch.

    The value plotted is the product of each covariance matrix's diagonal.
    This equals the determinant only for diagonal (or triangular) matrices.

    Arguments:
        covariance_matrices: 4-d array indexed as
            (epoch, component, row, column).
        distribution: distribution name; normalised and used in the y label
            and figure name.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    distribution = normalise_string(distribution)
    figure_name = "centroids_evolution-{}-covariance_matrices".format(
        distribution)
    figure_name = saving.build_figure_name(figure_name, name)

    y_label = _axis_label_for_symbol(
        symbol="\\Sigma",
        distribution=distribution,
        prefix="|",
        suffix="(y = k)|")

    n_epochs, n_centroids, __, __ = covariance_matrices.shape

    # Product of the diagonal for every (epoch, component) pair at once.
    determinants = numpy.prod(
        numpy.diagonal(covariance_matrices, axis1=2, axis2=3), axis=-1)

    # Switch to a log scale when the components' ranges over the epochs
    # differ by more than a factor of 100. The ratios are only meaningful
    # when every value is strictly positive.
    y_scale = "linear"
    if (determinants > 0).all():
        line_range_ratio = determinants.max(axis=0) / determinants.min(axis=0)
        range_ratio = line_range_ratio.max() / line_range_ratio.min()
        if range_ratio > 1e2:
            y_scale = "log"

    centroids_palette = style.darker_palette(n_centroids)
    epochs = numpy.arange(n_epochs) + 1

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    for k in range(n_centroids):
        axis.plot(epochs, determinants[:, k], color=centroids_palette[k],
                  label="$k = {}$".format(k))

    axis.set_xlabel("Epochs")
    axis.set_ylabel(y_label)
    axis.set_yscale(y_scale)
    axis.legend(loc="best")

    return figure, figure_name
def plot_variable_correlations(values, variable_names=None,
                               colouring_data_set=None,
                               name="variable_correlations"):
    """Scatter-plot matrix of every pair of variables (columns of `values`).

    Panel (i, j) plots variable j on the x axis against variable i on the y
    axis. The labels on the bottom row and on the first column therefore
    name the variables actually plotted.

    Arguments:
        values: 2-d array indexed as (example, variable).
        variable_names: optional names per variable; axes are left unlabelled
            when not given.
        colouring_data_set: optional data set whose labels colour the points
            (via its class palette, or a generated one if it has none).
        name: name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name(name)

    n_examples, n_features = values.shape

    # Fixed-seed shuffle so that points of later examples do not
    # consistently overdraw earlier ones.
    random_state = numpy.random.RandomState(117)
    shuffled_indices = random_state.permutation(n_examples)
    values = values[shuffled_indices]

    if colouring_data_set:
        labels = colouring_data_set.labels
        class_names = colouring_data_set.class_names
        number_of_classes = colouring_data_set.number_of_classes
        class_palette = colouring_data_set.class_palette
        label_sorter = colouring_data_set.label_sorter

        if not class_palette:
            index_palette = style.lighter_palette(number_of_classes)
            class_palette = {
                class_name: index_palette[i] for i, class_name in enumerate(
                    sorted(class_names, key=label_sorter))
            }

        labels = labels[shuffled_indices]
        colours = [class_palette[label] for label in labels]
    else:
        colours = style.NEUTRAL_COLOUR

    # squeeze=False keeps `axes` two-dimensional even for a single variable.
    figure, axes = pyplot.subplots(
        nrows=n_features, ncols=n_features,
        figsize=[1.5 * n_features] * 2, squeeze=False)

    for i in range(n_features):
        for j in range(n_features):
            axis = axes[i, j]
            axis.scatter(values[:, j], values[:, i], c=colours, s=1)
            axis.set_xticks([])
            axis.set_yticks([])
            if variable_names is not None and i == n_features - 1:
                axis.set_xlabel(variable_names[j])
        if variable_names is not None:
            axes[i, 0].set_ylabel(variable_names[i])

    return figure, figure_name
def plot_variable_label_correlations(variable_vector, variable_name,
                                     colouring_data_set,
                                     name="variable_label_correlations"):
    """Scatter one variable against the class of each example.

    Arguments:
        variable_vector: one value per example.
        variable_name: x-axis label.
        colouring_data_set: data set providing labels, class-name/id
            mappings, class palette and terminology.
        name: name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name(name)

    n_examples = variable_vector.shape[0]
    data_set = colouring_data_set

    to_class_ids = numpy.vectorize(
        lambda class_name: data_set.class_name_to_class_id[class_name])
    to_class_names = numpy.vectorize(
        lambda class_id: data_set.class_id_to_class_name[class_id])

    labels = data_set.labels
    class_names = data_set.class_names
    number_of_classes = data_set.number_of_classes
    class_palette = data_set.class_palette
    label_sorter = data_set.label_sorter

    if not class_palette:
        index_palette = style.lighter_palette(number_of_classes)
        ordered_names = sorted(class_names, key=label_sorter)
        class_palette = {
            class_name: index_palette[position]
            for position, class_name in enumerate(ordered_names)
        }

    # Same fixed seed as the other correlation plots.
    order = numpy.random.RandomState(117).permutation(n_examples)
    shuffled_variable = variable_vector[order]
    shuffled_labels = labels[order]

    label_ids = numpy.expand_dims(to_class_ids(shuffled_labels), axis=-1)
    colours = [class_palette[label] for label in shuffled_labels]

    unique_class_ids = numpy.unique(label_ids)
    unique_class_names = to_class_names(unique_class_ids)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.scatter(shuffled_variable, label_ids, c=colours, s=1)

    axis.set_yticks(unique_class_ids)
    axis.set_yticklabels(unique_class_names)
    axis.set_xlabel(variable_name)
    axis.set_ylabel(capitalise_string(data_set.terms["class"]))

    return figure, figure_name
def plot_elbo_heat_map(data_frame, x_label, y_label, z_label=None,
                       z_symbol=None, z_min=None, z_max=None, name=None):
    """Annotated heat map of ELBO values.

    Arguments:
        data_frame: 2-d data frame of values.
        x_label, y_label: axis labels.
        z_label: colour-bar label.
        z_symbol: if given, overrides `z_label` with "$<z_symbol>$".
        z_min, z_max: colour-scale limits. `None` means use the data's
            minimum/maximum; any other value, including 0, is used as given.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("ELBO_heat_map", name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    if z_min is None:
        z_min = data_frame.values.min()
    if z_max is None:
        z_max = data_frame.values.max()

    if z_symbol:
        z_label = "$" + z_symbol + "$"

    cbar_dict = {}
    if z_label:
        cbar_dict["label"] = z_label

    seaborn.set(style="white")
    seaborn.heatmap(
        data_frame,
        vmin=z_min, vmax=z_max,
        xticklabels=True, yticklabels=True,
        cbar=True, cbar_kws=cbar_dict,
        annot=True, fmt="-.6g",
        square=False,
        ax=axis
    )
    # Undo the global seaborn style change made just above.
    style.reset_plot_look()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return figure, figure_name
def plot_kl_divergence_evolution(kl_neurons, scale="log", name=None):
    """Heat map of per-unit KL divergences over epochs.

    Arguments:
        kl_neurons: 2-d array indexed as (epoch, unit). Within each epoch
            the values are sorted before plotting.
        scale: "log" plots the natural logarithm of the values; anything
            else plots them unchanged.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("kl_divergence_evolution", name)

    n_epochs, __ = kl_neurons.shape

    kl_neurons = numpy.sort(kl_neurons, axis=1)

    use_log = scale == "log"
    if use_log:
        kl_neurons = numpy.log(kl_neurons)
    scale_label = "$\\log$ " if use_log else ""

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    cbar_dict = {"label": scale_label + "KL$(p_i|q_i)$"}

    # With many epochs, label only every k-th epoch (about ten labels);
    # otherwise `True` labels every column.
    number_of_epoch_labels = 10
    epoch_label_frequency = (
        int(numpy.floor(n_epochs / number_of_epoch_labels))
        if n_epochs > 2 * number_of_epoch_labels
        else True
    )

    epochs = numpy.arange(n_epochs) + 1
    heat_map_data = pandas.DataFrame(kl_neurons.T, columns=epochs)

    seaborn.heatmap(
        heat_map_data,
        xticklabels=epoch_label_frequency,
        yticklabels=False,
        cbar=True,
        cbar_kws=cbar_dict,
        cmap=style.STANDARD_COLOUR_MAP,
        ax=axis
    )

    axis.set_xlabel("Epochs")
    axis.set_ylabel("$i$")

    seaborn.despine(ax=axis)

    return figure, figure_name
def plot_centroid_probabilities_evolution(probabilities, distribution,
                                          linestyle="solid", name=None):
    """Plot each mixture component's probability against the epoch.

    Arguments:
        probabilities: 2-d array indexed as (epoch, component).
        distribution: distribution name; normalised and used in the y label
            and figure name.
        linestyle: matplotlib line style used for every line.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    distribution = normalise_string(distribution)

    y_label = _axis_label_for_symbol(
        symbol="\\pi", distribution=distribution, suffix="^k")

    base_name = "centroids_evolution-{}-probabilities".format(distribution)
    figure_name = saving.build_figure_name(base_name, name)

    n_epochs, n_centroids = probabilities.shape
    palette = style.darker_palette(n_centroids)
    epochs = numpy.arange(n_epochs) + 1

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    for k, component_probabilities in enumerate(probabilities.T):
        axis.plot(
            epochs, component_probabilities,
            color=palette[k], linestyle=linestyle,
            label="$k = {}$".format(k)
        )

    axis.set_xlabel("Epochs")
    axis.set_ylabel(y_label)
    axis.legend(loc="best")

    return figure, figure_name
def plot_model_metric_sets(metrics_sets, x_key, y_key,
                           x_label=None, y_label=None,
                           primary_differentiator_key=None,
                           primary_differentiator_order=None,
                           secondary_differentiator_key=None,
                           secondary_differentiator_order=None,
                           special_cases=None,
                           other_method_metrics=None,
                           palette=None, marker_styles=None,
                           name=None):
    """Plot mean metric pairs (with standard deviations) for model runs.

    Each metrics set gives one point at the means of its `x_key` and
    `y_key` values, with error bars of one standard deviation.

    - Colour: selected by the set's `primary_differentiator_key` value,
      through its position in `primary_differentiator_order` indexing
      `palette`.
    - Marker: selected by the set's `secondary_differentiator_key` value,
      through its position in `secondary_differentiator_order` indexing
      `marker_styles`.
    - Unknown values fall back to black / no marker.

    Arguments:
        metrics_sets: a metrics set (mapping) or a list of them.
        x_key, y_key: metric keys for the two axes.
        x_label, y_label: axis labels.
        primary_differentiator_key, primary_differentiator_order: choose
            colours (see above). A `None` order behaves as empty.
        secondary_differentiator_key, secondary_differentiator_order:
            choose markers (see above). A `None` order behaves as empty.
        special_cases: optional mapping from a differentiator value to
            changes. Only {"errorbar_colour": "darken"} is acted upon. It is
            not modified.
        other_method_metrics: optional mapping from method name to metrics.
            The mean of each method's `y_key` values is drawn as a
            horizontal line, plus a ±1 standard-deviation band when there
            are several values.
        palette: colours; defaults to a copy of the standard palette.
        marker_styles: matplotlib marker codes.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("model_metric_sets", name)

    if other_method_metrics:
        figure_name += "-other_methods"

    if not isinstance(metrics_sets, list):
        metrics_sets = [metrics_sets]

    if not palette:
        palette = style.STANDARD_PALETTE.copy()

    if not marker_styles:
        marker_styles = [
            "X",  # cross
            "s",  # square
            "D",  # diamond
            "o",  # circle
            "P",  # plus
            "^",  # upright triangle
            "p",  # pentagon
            "*",  # star
        ]

    if primary_differentiator_order is None:
        primary_differentiator_order = []
    primary_differentiator_order = list(primary_differentiator_order)
    if secondary_differentiator_order is None:
        secondary_differentiator_order = []
    secondary_differentiator_order = list(secondary_differentiator_order)
    if special_cases is None:
        special_cases = {}

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    colours = {}
    markers = {}

    for metrics_set in metrics_sets:
        x = numpy.array(metrics_set[x_key])
        y = numpy.array(metrics_set[y_key])

        # Skip sets whose values do not form a numeric array.
        if x.dtype == "object" or y.dtype == "object":
            continue

        x_mean = x.mean()
        x_sd = x.std(ddof=1 if x.size > 1 else 0)
        y_mean = y.mean()
        y_sd = y.std(ddof=1 if y.size > 1 else 0)

        colour_key = metrics_set[primary_differentiator_key]
        if colour_key not in colours:
            try:
                index = primary_differentiator_order.index(colour_key)
                colours[colour_key] = palette[index]
            except (ValueError, IndexError):
                colours[colour_key] = "black"
        colour = colours[colour_key]

        # Coloured error bars labelled with the primary value (colour legend).
        axis.errorbar(
            x=x_mean, y=y_mean, yerr=y_sd, xerr=x_sd,
            capsize=2, linestyle="", color=colour, label=colour_key,
            markersize=7)

        marker_key = metrics_set[secondary_differentiator_key]
        if marker_key not in markers:
            try:
                index = secondary_differentiator_order.index(marker_key)
                markers[marker_key] = marker_styles[index]
            except (ValueError, IndexError):
                markers[marker_key] = None
        marker = markers[marker_key]

        # Black marker labelled with the secondary value (marker legend).
        axis.errorbar(
            x_mean, y_mean, color="black", marker=marker,
            linestyle="none", label=marker_key)

        errorbar_colour = colour
        darker_colour = seaborn.dark_palette(colour, n_colors=4)[2]

        # Copy before merging so the caller's `special_cases` is untouched.
        special_case_changes = dict(special_cases.get(colour_key, {}))
        special_case_changes.update(special_cases.get(marker_key, {}))
        if special_case_changes.get("errorbar_colour") == "darken":
            errorbar_colour = darker_colour

        axis.errorbar(
            x=x_mean, y=y_mean, yerr=y_sd, xerr=x_sd,
            ecolor=errorbar_colour, capsize=2,
            color=colour, marker=marker, markeredgecolor=darker_colour,
            markersize=7)

    if other_method_metrics:
        # Valid matplotlib names ("dashdot", not "dashdotted"); the styles
        # repeat if there are more baselines than styles.
        baseline_line_styles = ["dashed", "dotted", "dashdot"]
        baseline_colour = style.STANDARD_PALETTE[-1]
        n_baselines = 0

        for method_name, metric_values in other_method_metrics.items():
            y_values = metric_values.get(y_key, None)
            if y_values is None or numpy.size(y_values) == 0:
                continue

            y = numpy.atleast_1d(numpy.array(y_values))
            y_mean = y.mean()
            y_sd = y.std(ddof=1) if y.shape[0] > 1 else None

            line_style = baseline_line_styles[
                n_baselines % len(baseline_line_styles)]
            n_baselines += 1

            axis.axhline(
                y=y_mean, color=baseline_colour, linestyle=line_style,
                label=method_name, zorder=-1)
            if y_sd is not None:
                axis.axhspan(
                    ymin=y_mean - y_sd, ymax=y_mean + y_sd,
                    facecolor=baseline_colour, alpha=0.1, edgecolor=None,
                    label=method_name, zorder=-2)

    if len(metrics_sets) > 1:
        order = primary_differentiator_order + secondary_differentiator_order

        # Merge all handles sharing a label into one tuple legend entry.
        handles, labels = axis.get_legend_handles_labels()
        label_handles = {}
        for legend_label, handle in zip(labels, handles):
            label_handles.setdefault(legend_label, []).append(handle)

        def _legend_sort_key(entry):
            # Known labels follow `order`; others come after, alphabetically.
            entry_label = entry[0]
            if entry_label in order:
                return [order.index(entry_label), entry_label]
            return [len(order), entry_label]

        entries = sorted(
            ((legend_label, tuple(handle_set))
             for legend_label, handle_set in label_handles.items()),
            key=_legend_sort_key)

        if entries:
            labels, handles = zip(*entries)
            axis.legend(handles, labels, loc="best")

    return figure, figure_name
def plot_model_metrics(metrics_sets, key, label=None,
                       primary_differentiator_key=None,
                       primary_differentiator_order=None,
                       secondary_differentiator_key=None,
                       secondary_differentiator_order=None,
                       palette=None, marker_styles=None, name=None):
    """Plot the mean (± one standard deviation) of a metric per metrics set.

    Sets are grouped along the x axis by their `primary_differentiator_key`
    value: the group position is its index in `primary_differentiator_order`,
    or 0 if unknown. Within a group, each set is offset by its
    `secondary_differentiator_key` value. That value also picks its colour
    from `palette`; unknown values get offset 0 and black.

    Arguments:
        metrics_sets: a metrics set (mapping) or a list of them.
        key: metric to plot.
        label: y-axis label.
        primary_differentiator_key: key naming the x-axis groups; it is also
            used (capitalised, pluralised) as the x-axis label.
        primary_differentiator_order: group order; `None` behaves as empty.
        secondary_differentiator_key: key choosing offset and colour.
        secondary_differentiator_order: offset/colour order; `None` behaves
            as empty.
        palette: colours; defaults to a copy of the standard palette.
        marker_styles: accepted for interface compatibility; not used.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = saving.build_figure_name("model_metrics", name)

    if not isinstance(metrics_sets, list):
        metrics_sets = [metrics_sets]

    if not palette:
        palette = style.STANDARD_PALETTE.copy()

    if primary_differentiator_order is None:
        primary_differentiator_order = []
    primary_differentiator_order = list(primary_differentiator_order)
    if secondary_differentiator_order is None:
        secondary_differentiator_order = []
    secondary_differentiator_order = list(secondary_differentiator_order)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    axis.set_xlabel(primary_differentiator_key.capitalize() + "s")
    axis.set_ylabel(label)

    x_positions = {}
    x_offsets = {}
    colours = {}

    # Offsets are spread symmetrically around the group position. `x_gap`
    # pads the spread so neighbouring groups stay apart.
    x_gap = 3
    x_scale = len(secondary_differentiator_order) - 1 + 2 * x_gap

    for metrics_set in metrics_sets:
        y = numpy.array(metrics_set[key])

        # Skip sets whose values do not form a numeric array.
        if y.dtype == "object":
            continue

        y_mean = y.mean()
        y_sd = y.std(ddof=1 if y.size > 1 else 0)

        x_position_key = metrics_set[primary_differentiator_key]
        if x_position_key not in x_positions:
            try:
                x_positions[x_position_key] = (
                    primary_differentiator_order.index(x_position_key))
            except (ValueError, IndexError):
                x_positions[x_position_key] = 0
        x_position = x_positions[x_position_key]

        x_offset_key = metrics_set[secondary_differentiator_key]
        if x_offset_key not in x_offsets:
            try:
                index = secondary_differentiator_order.index(x_offset_key)
                x_offsets[x_offset_key] = (
                    (index + x_gap - x_scale / 2) / x_scale)
            except (ValueError, IndexError):
                x_offsets[x_offset_key] = 0
        x = x_position + x_offsets[x_offset_key]

        colour_key = x_offset_key
        if colour_key not in colours:
            try:
                index = secondary_differentiator_order.index(colour_key)
                colours[colour_key] = palette[index]
            except (ValueError, IndexError):
                colours[colour_key] = "black"
        colour = colours[colour_key]

        # Labelled error bars provide the legend entry; the second call
        # draws the mean as a horizontal tick.
        axis.errorbar(
            x=x, y=y_mean, yerr=y_sd, capsize=2, linestyle="",
            color=colour, label=colour_key)
        axis.errorbar(
            x=x, y=y_mean, yerr=y_sd, ecolor=colour, capsize=2,
            color=colour, marker="_")

    axis.set_xticks(list(x_positions.values()))
    axis.set_xticklabels(list(x_positions.keys()))

    if len(metrics_sets) > 1:
        order = primary_differentiator_order + secondary_differentiator_order

        # Merge all handles sharing a label into one tuple legend entry.
        handles, labels = axis.get_legend_handles_labels()
        label_handles = {}
        for legend_label, handle in zip(labels, handles):
            label_handles.setdefault(legend_label, []).append(handle)

        def _legend_sort_key(entry):
            # Known labels follow `order`; others come after, alphabetically.
            entry_label = entry[0]
            if entry_label in order:
                return [order.index(entry_label), entry_label]
            return [len(order), entry_label]

        entries = sorted(
            ((legend_label, tuple(handle_set))
             for legend_label, handle_set in label_handles.items()),
            key=_legend_sort_key)

        if entries:
            labels, handles = zip(*entries)
            axis.legend(handles, labels, loc="best")

    return figure, figure_name
def plot_heat_map(values, x_name, y_name, z_name=None, z_symbol=None,
                  z_min=None, z_max=None, symmetric=False, labels=None,
                  label_kind=None, center=None, name=None):
    """Heat map of a 2-d array, optionally with rows sorted by labels.

    Arguments:
        values: 2-d array.
        x_name, y_name: axis labels. For a symmetric matrix, the x label is
            replaced by the (possibly "sorted") y label.
        z_name: colour-bar label.
        z_symbol: if given, overrides `z_name` with "$<z_symbol>$".
        z_min, z_max: colour-scale limits. `None` means use the data's
            minimum/maximum; any other value, including 0, is used as given.
        symmetric: treat `values` as a symmetric square matrix, so the
            columns are reordered like the rows.
        labels: optional per-row labels; rows are sorted by them.
        label_kind: description of the labels used in the y label.
        center: value at which to centre the colour map.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.

    Raises:
        ValueError: if `symmetric` is set for a non-square array.
    """
    figure_name = saving.build_figure_name("heat_map", name)
    n_examples, n_features = values.shape

    if symmetric and n_examples != n_features:
        raise ValueError(
            "Input cannot be symmetric, when it is not given as a 2-d square "
            "array or matrix.")

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)

    if z_min is None:
        z_min = values.min()
    if z_max is None:
        z_max = values.max()

    if z_symbol:
        z_name = "$" + z_symbol + "$"

    cbar_dict = {}
    if z_name:
        cbar_dict["label"] = z_name

    # Square cells only for symmetric input, or when the array is not too
    # elongated (aspect ratio within a factor of five).
    if not symmetric:
        aspect_ratio = n_examples / n_features
        square_cells = 1 / 5 < aspect_ratio < 5
    else:
        square_cells = True

    if labels is not None:
        x_indices = numpy.argsort(labels)
        y_name += " sorted"
        if label_kind:
            y_name += " by " + label_kind
    else:
        x_indices = numpy.arange(n_examples)

    if symmetric:
        y_indices = x_indices
        x_name = y_name
    else:
        y_indices = numpy.arange(n_features)

    seaborn.set(style="white")
    seaborn.heatmap(
        values[x_indices][:, y_indices],
        vmin=z_min, vmax=z_max, center=center,
        xticklabels=False, yticklabels=False,
        cbar=True, cbar_kws=cbar_dict, cmap=style.STANDARD_COLOUR_MAP,
        square=square_cells, ax=axis
    )
    # Undo the global seaborn style change made just above.
    style.reset_plot_look()

    axis.set_xlabel(x_name)
    axis.set_ylabel(y_name)

    return figure, figure_name
def plot_matrix(feature_matrix, plot_distances=False, center_value=None,
                example_label=None, feature_label=None, value_label=None,
                sorting_method=None, distance_metric="Euclidean",
                labels=None, label_kind=None, class_palette=None,
                feature_indices_for_plotting=None, hide_dendrogram=False,
                name_parts=None):
    """Heat map of a feature matrix or of its pairwise example distances.

    Rows (examples) can be sorted by label or by hierarchical clustering
    (average linkage on the pairwise distances). Label colours can be shown
    in a strip beside the map, and the clustering dendrogram alongside it.

    Arguments:
        feature_matrix: 2-d (dense or SciPy sparse) array indexed as
            (example, feature).
        plot_distances: plot pairwise example distances instead of values.
        center_value: value at which to centre the colour map (ignored for
            distances).
        example_label: y-axis label; a sorting description is appended.
        feature_label: x-axis label (ignored for distances).
        value_label: colour-bar label; for distances it names the space.
        sorting_method: None, "labels" or "hierarchical_clustering".
        distance_metric: metric name for `sklearn.metrics.pairwise_distances`
            (lower-cased before use).
        labels: optional per-example labels.
        label_kind: description of the labels used in the y label.
        class_palette: mapping from label to colour; required with `labels`.
        feature_indices_for_plotting: feature columns to show; defaults to
            all of them.
        hide_dendrogram: compute the clustering order without drawing it.
        name_parts: name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.

    Raises:
        ValueError: for an unknown `sorting_method`, label sorting without
            labels, or labels without a class palette.
    """
    figure_name = saving.build_figure_name(name_parts)
    n_examples, n_features = feature_matrix.shape

    # Validate before creating a figure, so no figure is left behind.
    if sorting_method not in (None, "labels", "hierarchical_clustering"):
        raise ValueError("`sorting_method` should be either \"labels\""
                         " or \"hierarchical_clustering\"")

    if plot_distances:
        center_value = None
        feature_label = None
        value_label = "Pairwise {} distances in {} space".format(
            distance_metric, value_label)

    if not plot_distances and feature_indices_for_plotting is None:
        feature_indices_for_plotting = numpy.arange(n_features)

    if sorting_method == "labels" and labels is None:
        raise ValueError("No labels provided to sort after.")

    if labels is not None and not class_palette:
        raise ValueError("No class palette provided.")

    # Distances (if needed)
    distances = None
    if plot_distances or sorting_method == "hierarchical_clustering":
        distances = sklearn.metrics.pairwise_distances(
            feature_matrix, metric=distance_metric.lower())

    # Figure initialisation
    figure = pyplot.figure()

    axis_heat_map = figure.add_subplot(1, 1, 1)
    left_most_axis = axis_heat_map

    divider = make_axes_locatable(axis_heat_map)
    axis_colour_map = divider.append_axes("right", size="5%", pad=0.1)

    axis_labels = None
    axis_dendrogram = None

    if labels is not None:
        axis_labels = divider.append_axes("left", size="5%", pad=0.01)
        left_most_axis = axis_labels

    if sorting_method == "hierarchical_clustering" and not hide_dendrogram:
        axis_dendrogram = divider.append_axes("left", size="20%", pad=0.01)
        left_most_axis = axis_dendrogram

    # Label colours: each example becomes the index of its (hashable)
    # colour in a listed colour map.
    if labels is not None:
        label_colours = [
            tuple(colour) if isinstance(colour, list) else colour
            for colour in [class_palette[l] for l in labels]
        ]
        unique_colours = [
            tuple(colour) if isinstance(colour, list) else colour
            for colour in class_palette.values()
        ]
        value_for_colour = {
            colour: i for i, colour in enumerate(unique_colours)
        }
        label_colour_matrix = numpy.array([
            value_for_colour[colour] for colour in label_colours
        ]).reshape(n_examples, 1)
        label_colour_map = matplotlib.colors.ListedColormap(unique_colours)
    else:
        label_colour_matrix = None
        label_colour_map = None

    # Heat map aspect ratio: square cells only for the square distance map.
    square_cells = bool(plot_distances)

    seaborn.set(style="white")

    # Sorting and optional dendrogram
    if sorting_method == "labels":
        example_indices = numpy.argsort(labels)
        if not label_kind:
            label_kind = "labels"
        if example_label:
            example_label += " sorted by " + label_kind
    elif sorting_method == "hierarchical_clustering":
        # `linkage` receives a condensed distance matrix, for which its
        # `metric` argument is ignored. The linkage criterion must be
        # passed as `method`.
        linkage = scipy.cluster.hierarchy.linkage(
            scipy.spatial.distance.squareform(distances, checks=False),
            method="average")
        if axis_dendrogram is not None:
            dendrogram = seaborn.matrix.dendrogram(
                distances, linkage=linkage, metric=None, method="ward",
                axis=0, label=False, rotate=True, ax=axis_dendrogram)
            example_indices = dendrogram.reordered_ind
        else:
            # Without an axis, seaborn would draw on the current axes;
            # compute the same leaf order without plotting instead.
            example_indices = scipy.cluster.hierarchy.dendrogram(
                linkage, no_plot=True)["leaves"]
        if example_label:
            example_label += " sorted by hierarchical clustering"
    else:
        example_indices = numpy.arange(n_examples)

    # Heat map of values
    if plot_distances:
        plot_values = distances[example_indices][:, example_indices]
    else:
        plot_values = feature_matrix[
            example_indices][:, feature_indices_for_plotting]

    if scipy.sparse.issparse(plot_values):
        plot_values = plot_values.toarray()

    colour_bar_dictionary = {}
    if value_label:
        colour_bar_dictionary["label"] = value_label

    seaborn.heatmap(
        plot_values, center=center_value,
        xticklabels=False, yticklabels=False,
        cbar=True, cbar_kws=colour_bar_dictionary, cbar_ax=axis_colour_map,
        square=square_cells, ax=axis_heat_map
    )

    # Colour labels. Fixing the colour range to the colour indices keeps
    # each index on its own colour, even when some classes are absent.
    if axis_labels is not None:
        seaborn.heatmap(
            label_colour_matrix[example_indices],
            vmin=-0.5, vmax=len(unique_colours) - 0.5,
            xticklabels=False, yticklabels=False,
            cbar=False, cmap=label_colour_map, ax=axis_labels
        )

    # Undo the global seaborn style change made above.
    style.reset_plot_look()

    # Axis labels
    if example_label:
        left_most_axis.set_ylabel(example_label)

    if feature_label:
        axis_heat_map.set_xlabel(feature_label)

    return figure, figure_name
def plot_learning_curves(curves, model_type, epoch_offset=0, name=None):
    """Plot learning curves for an AE, VAE or GMVAE model.

    Axis layout by model type:
        AE: one axis.
        VAE: the second axis holds the KL divergences of z.
        GMVAE: a third axis holds the KL divergence of y.

    Arguments:
        curves: mapping from curve-set name ("training", "validation") to a
            mapping from curve name (e.g. "lower_bound",
            "reconstruction_error", "kl_divergence", "kl_divergence_z1",
            "log_likelihood") to per-epoch values, or `None` to skip it.
            Other set names are drawn dotted.
        model_type: "AE", "VAE" or "GMVAE".
        epoch_offset: number added to the epoch numbers (which start at 1).
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.

    Raises:
        ValueError: for an unknown `model_type`.
    """
    figure_name = saving.build_figure_name("learning_curves", name)

    x_label = "Epoch"
    y_label = "Nat"

    if model_type == "AE":
        figure = pyplot.figure()
        axis_1 = figure.add_subplot(1, 1, 1)
    elif model_type == "VAE":
        figure, (axis_1, axis_2) = pyplot.subplots(
            nrows=2, sharex=True, figsize=(6.4, 9.6))
        figure.subplots_adjust(hspace=0.1)
    elif model_type == "GMVAE":
        figure, (axis_1, axis_2, axis_3) = pyplot.subplots(
            nrows=3, sharex=True, figsize=(6.4, 14.4))
        figure.subplots_adjust(hspace=0.1)
    else:
        raise ValueError(
            "Model type should be \"AE\", \"VAE\" or \"GMVAE\", not {!r}."
            .format(model_type))

    # Line style and palette offset per known curve set.
    curve_set_styles = {
        "training": ("solid", 0),
        "validation": ("dashed", 1),
    }
    palette = style.STANDARD_PALETTE

    for curve_set_name, curve_set in sorted(curves.items()):
        line_style, colour_index_offset = curve_set_styles.get(
            curve_set_name, ("dotted", 2))

        def curve_colour(i, offset=colour_index_offset):
            return palette[(len(curves) * i + offset) % len(palette)]

        for curve_name, curve in sorted(curve_set.items()):
            if curve is None:
                continue

            # Defaults for curves not matched below: first axis and the
            # automatic colour cycle (rather than a previous curve's colour).
            colour = None
            axis = axis_1

            if curve_name == "lower_bound":
                curve_name = "$\\mathcal{L}$"
                colour = curve_colour(0)
            elif curve_name == "reconstruction_error":
                curve_name = "$\\log p(x|z)$"
                colour = curve_colour(1)
            elif "kl_divergence" in curve_name:
                if curve_name == "kl_divergence":
                    index = ""
                    colour = curve_colour(0)
                    axis = axis_2
                else:
                    latent_variable = curve_name.replace("kl_divergence_", "")
                    # e.g. "z1" -> "z_1" for the LaTeX subscript.
                    latent_variable = re.sub(pattern=r"(\w)(\d)",
                                             repl=r"\1_\2",
                                             string=latent_variable)
                    index = "$_{" + latent_variable + "}$"
                    if latent_variable in ["z", "z_2"]:
                        colour = curve_colour(0)
                        axis = axis_2
                    elif latent_variable == "z_1":
                        colour = curve_colour(1)
                        axis = axis_2
                    elif latent_variable == "y":
                        colour = curve_colour(0)
                        axis = axis_3
                curve_name = "KL" + index + "$(q||p)$"
            elif curve_name == "log_likelihood":
                curve_name = "$L$"

            epochs = numpy.arange(len(curve)) + epoch_offset + 1
            label = "{} ({} set)".format(curve_name, curve_set_name)
            axis.plot(epochs, curve, color=colour, linestyle=line_style,
                      label=label)

    def _add_sorted_legend(legend_axis):
        # Legend sorted by label; skipped when the axis has no labelled lines.
        handles, labels = legend_axis.get_legend_handles_labels()
        if handles:
            labels, handles = zip(
                *sorted(zip(labels, handles), key=lambda t: t[0]))
            legend_axis.legend(handles, labels, loc="best")

    _add_sorted_legend(axis_1)

    if model_type == "AE":
        axis_1.set_xlabel(x_label)
        axis_1.set_ylabel(y_label)
    else:
        # VAE and GMVAE share one figure-level y label, and the x label goes
        # on the bottom axis.
        _add_sorted_legend(axis_2)
        axis_1.set_ylabel("")
        axis_2.set_ylabel("")
        if model_type == "GMVAE":
            _add_sorted_legend(axis_3)
            axis_3.set_xlabel(x_label)
            axis_3.set_ylabel("")
        else:
            axis_2.set_xlabel(x_label)
        figure.text(-0.01, 0.5, y_label, va="center", rotation="vertical")

    seaborn.despine()

    return figure, figure_name
def plot_histogram(series, excess_zero_count=0, label=None, normed=False,
                   discrete=False, x_scale="linear", y_scale="linear",
                   colour=None, name=None):
    """Plot a histogram of a series.

    Arguments:
        series: 1-d array of values (not modified).
        excess_zero_count: extra zeros (not in `series`) added to the first
            bin and to the total used for normalisation.
        label: x-axis label; the axis is left unlabelled when not given.
        normed: plot frequencies instead of counts.
        discrete: for integer data with a small maximum, use one bin per
            integer from 0 to the maximum.
        x_scale: x-axis scale; "log" plots `series + 1` so zeros survive.
        y_scale: y-axis scale; "log" draws the bars in log mode.
        colour: bar colour; defaults to the first standard palette colour.
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    # Copy, since the series is shifted in place for a log x axis.
    series = series.copy()

    figure_name = "histogram"
    if normed:
        figure_name += "-normed"
    figure_name = saving.build_figure_name(figure_name, name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    series_length = len(series) + excess_zero_count
    series_max = series.max()

    if discrete and series_max < MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS:
        # For integer data: unit-wide bins centred on 0, 1, ..., maximum.
        number_of_bins = int(numpy.ceil(series_max)) + 1
        bin_range = numpy.array((-0.5, series_max + 0.5))
    else:
        if series_max < MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS:
            number_of_bins = "auto"
        else:
            number_of_bins = MAXIMUM_NUMBER_OF_BINS_FOR_HISTOGRAMS
        bin_range = numpy.array((series.min(), series_max))

    if colour is None:
        colour = style.STANDARD_PALETTE[0]

    if x_scale == "log":
        series += 1
        bin_range += 1
        if label:
            label += " (shifted one)"
        figure_name += "-log_values"

    y_log = y_scale == "log"

    histogram, bin_edges = numpy.histogram(
        series,
        bins=number_of_bins,
        range=bin_range
    )

    histogram[0] += excess_zero_count

    width = bin_edges[1] - bin_edges[0]
    bin_centres = bin_edges[:-1] + width / 2

    if normed:
        histogram = histogram / series_length

    axis.bar(
        bin_centres,
        histogram,
        width=width,
        log=y_log,
        color=colour,
        alpha=0.4
    )

    axis.set_xscale(x_scale)
    if label:
        axis.set_xlabel(capitalise_string(label))

    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name
def plot_class_histogram(labels, class_names=None, class_palette=None,
                         normed=False, scale="linear", label_sorter=None,
                         name=None):
    """Bar chart of how many examples have each class label.

    Arguments:
        labels: iterable of class labels, each one a key of `class_names`.
        class_names: classes to show; defaults to the unique labels.
        class_palette: mapping from class name to colour; generated from
            `style.lighter_palette` in `label_sorter` order when missing.
        normed: plot frequencies instead of counts.
        scale: y-axis scale.
        label_sorter: sort key for class order (bar positions).
        name: extra name parts for `saving.build_figure_name`.

    Returns:
        Tuple of the figure and its name.
    """
    figure_name = "histogram"
    if normed:
        figure_name += "-normed"
    figure_name += "-classes"
    figure_name = saving.build_figure_name(figure_name, name)

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    if class_names is None:
        class_names = numpy.unique(labels)
    n_classes = len(class_names)

    sorted_class_names = sorted(class_names, key=label_sorter)

    if not class_palette:
        index_palette = style.lighter_palette(n_classes)
        class_palette = {
            class_name: index_palette[i]
            for i, class_name in enumerate(sorted_class_names)
        }

    # Bar position ("index") follows the `label_sorter` order.
    histogram = {
        class_name: {
            "index": i,
            "count": 0,
            "colour": class_palette[class_name]
        }
        for i, class_name in enumerate(sorted_class_names)
    }

    total_count_sum = 0
    for label in labels:
        histogram[label]["count"] += 1
        total_count_sum += 1

    indices = []
    tick_labels = []

    for class_name, class_values in sorted(histogram.items()):
        index = class_values["index"]
        count = class_values["count"]

        indices.append(index)
        tick_labels.append(class_name)

        if normed:
            # An empty `labels` gives zero-height bars rather than a crash.
            count_or_frequency = (
                count / total_count_sum if total_count_sum else 0)
        else:
            count_or_frequency = count

        axis.bar(index, count_or_frequency, color=class_values["colour"])

    axis.set_yscale(scale)

    # Rotate the x tick labels when any (other than "No class") is long.
    maximum_class_name_width = max(
        (len(str(class_name)) for class_name in tick_labels
         if class_name not in ["No class"]),
        default=0
    )
    if maximum_class_name_width > 5:
        x_ticks_rotation = 45
        x_ticks_horizontal_alignment = "right"
        x_ticks_rotation_mode = "anchor"
    else:
        x_ticks_rotation = 0
        x_ticks_horizontal_alignment = "center"
        x_ticks_rotation_mode = None

    axis.set_xticks(indices)
    axis.set_xticklabels(
        tick_labels,
        horizontalalignment=x_ticks_horizontal_alignment,
        rotation=x_ticks_rotation,
        rotation_mode=x_ticks_rotation_mode
    )

    axis.set_xlabel("Classes")

    if normed:
        axis.set_ylabel("Frequency")
    else:
        axis.set_ylabel("Number of counts")

    return figure, figure_name
def plot_separate_learning_curves(curves, loss, name=None):
    """Plot the selected loss curves of each curve set against epochs.

    Args:
        curves: Mapping from curve-set name (``"training"`` or
            ``"validation"``) to a mapping from curve name to a sequence
            of per-epoch values. A ``None`` curve is skipped.
        loss: Curve name, or list of curve names, to plot. These are also
            appended to the figure name.
        name: Optional name, or list of names, for the figure. A list
            passed here is not modified.

    Returns:
        Tuple ``(figure, figure_name)``.
    """

    losses = loss if isinstance(loss, list) else [loss]

    # Work on a copy so that extending it never mutates the caller's list.
    names = list(name) if isinstance(name, list) else [name]
    names.extend(losses)

    figure_name = saving.build_figure_name("learning_curves", names)

    x_label = "Epoch"
    y_label = "Nat"

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    for curve_set_name, curve_set in sorted(curves.items()):

        # NOTE(review): curve sets other than training/validation get no
        # style of their own. They reuse the previous set's style, or fail
        # if they come first. Confirm that only these two sets are expected.
        if curve_set_name == "training":
            line_style = "solid"
            colour_index_offset = 0
        elif curve_set_name == "validation":
            line_style = "dashed"
            colour_index_offset = 1

        def curve_colour(i):
            # Interleave palette colours so each curve set gets its own
            # shade.
            return style.STANDARD_PALETTE[
                len(curves) * i + colour_index_offset]

        for curve_name, curve in sorted(curve_set.items()):
            if curve is None or curve_name not in losses:
                continue

            # Default colour, so curves without a dedicated branch below
            # (e.g. "log_likelihood") never reuse the previous curve's
            # colour or hit an unbound name.
            colour = curve_colour(0)

            if curve_name == "lower_bound":
                curve_name = "$\\mathcal{L}$"
                colour = curve_colour(0)
            elif curve_name == "reconstruction_error":
                curve_name = "$\\log p(x|z)$"
                colour = curve_colour(1)
            elif "kl_divergence" in curve_name:
                if curve_name == "kl_divergence":
                    index = ""
                    colour = curve_colour(0)
                else:
                    latent_variable = curve_name.replace("kl_divergence_", "")
                    # e.g. "z1" -> "z_1" for a LaTeX subscript.
                    latent_variable = re.sub(
                        pattern=r"(\w)(\d)",
                        repl=r"\1_\2",
                        string=latent_variable
                    )
                    index = "$_{" + latent_variable + "}$"
                    if latent_variable in ["z", "z_2"]:
                        colour = curve_colour(0)
                    elif latent_variable == "z_1":
                        colour = curve_colour(1)
                    elif latent_variable == "y":
                        colour = curve_colour(0)
                curve_name = "KL" + index + "$(q||p)$"
            elif curve_name == "log_likelihood":
                curve_name = "$L$"

            epochs = numpy.arange(len(curve)) + 1
            label = curve_name + " ({} set)".format(curve_set_name)
            axis.plot(
                epochs,
                curve,
                color=colour,
                linestyle=line_style,
                label=label
            )

    handles, labels = axis.get_legend_handles_labels()
    # Without plotted curves, the zip(*...) unpacking below would fail.
    if handles:
        labels, handles = zip(
            *sorted(zip(labels, handles), key=lambda t: t[0]))
        axis.legend(handles, labels, loc="best")

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return figure, figure_name
def plot_centroid_means_evolution(means, distribution, decomposed=False,
                                  name=None):
    """Plot how two-dimensional centroid means move over training epochs.

    Each centroid's trajectory is drawn as a line plus points coloured by
    epoch. A colour bar labels the epochs.

    Args:
        means: Array of shape ``(n_epochs, n_centroids, 2)`` with the mean
            of each centroid at each epoch.
        distribution: Name of the distribution. After ``normalise_string``
            it is used in the axis labels and the figure name.
        decomposed: Whether the means are PCA-decomposed. This only affects
            the axis labels.
        name: Optional extra name for the figure.

    Returns:
        Tuple ``(figure, figure_name)``.

    Raises:
        ValueError: If the last dimension of `means` is not 2.
    """

    symbol = "\\mu"

    if decomposed:
        decomposition_method = "PCA"
    else:
        decomposition_method = ""

    distribution = normalise_string(distribution)
    suffix = "(y = k)"

    x_label = _axis_label_for_symbol(
        symbol=symbol,
        coordinate=1,
        decomposition_method=decomposition_method,
        distribution=distribution,
        suffix=suffix
    )
    y_label = _axis_label_for_symbol(
        symbol=symbol,
        coordinate=2,
        decomposition_method=decomposition_method,
        distribution=distribution,
        suffix=suffix
    )

    figure_name = "centroids_evolution-{}-means".format(distribution)
    figure_name = saving.build_figure_name(figure_name, name)

    n_epochs, n_centroids, latent_size = means.shape

    # Exactly two coordinates are plotted, so anything else (including a
    # single coordinate, which would hit an IndexError) is rejected.
    if latent_size != 2:
        raise ValueError("Dimensions of means should be 2.")

    centroids_palette = style.darker_palette(n_centroids)
    epochs = numpy.arange(n_epochs) + 1

    figure = pyplot.figure()
    axis = figure.add_subplot(1, 1, 1)
    seaborn.despine()

    # Neutral-coloured scatter drawn underneath everything else. It only
    # provides the colour map that the epoch colour bar displays.
    colour_bar_scatter_plot = axis.scatter(
        means[:, 0, 0],
        means[:, 0, 1],
        c=epochs,
        cmap=seaborn.dark_palette(style.NEUTRAL_COLOUR, as_cmap=True),
        zorder=0
    )

    for k in range(n_centroids):
        colour = centroids_palette[k]
        colour_map = seaborn.dark_palette(colour, as_cmap=True)
        axis.plot(
            means[:, k, 0],
            means[:, k, 1],
            color=colour,
            label="$k = {}$".format(k),
            zorder=k + 1
        )
        # Points sit above all trajectory lines.
        axis.scatter(
            means[:, k, 0],
            means[:, k, 1],
            c=epochs,
            cmap=colour_map,
            zorder=n_centroids + k + 1
        )

    axis.legend(loc="best")

    colour_bar = figure.colorbar(colour_bar_scatter_plot)
    colour_bar.outline.set_linewidth(0)
    colour_bar.set_label("Epochs")

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)

    return figure, figure_name
def _set_clipped_log_y_scale(axis):
    """Give `axis` a logarithmic y scale that clips non-positive values.

    Matplotlib 3.3 renamed the log scale's ``nonposy`` keyword to
    ``nonpositive`` and later removed the old name. Try the new name first
    and fall back to the old one for older Matplotlib versions.
    """
    try:
        axis.set_yscale("log", nonpositive="clip")
    except (TypeError, ValueError):
        axis.set_yscale("log", nonposy="clip")


def plot_profile_comparison(observed_series, expected_series,
                            expected_series_total_standard_deviations=None,
                            expected_series_explained_standard_deviations=None,
                            x_name="feature", y_name="value", sort=True,
                            sort_by="expected", sort_direction="ascending",
                            x_scale="linear", y_scale="linear", y_cutoff=None,
                            name=None):
    """Plot an observed profile against its expected profile.

    Both series are plotted against a 1-based feature index, optionally
    after sorting by one of them. Bands of plus/minus the given standard
    deviations can be drawn around the expected series.

    Args:
        observed_series: Observed value per feature. SciPy sparse input is
            densified.
        expected_series: Expected value per feature. SciPy sparse input is
            densified.
        expected_series_total_standard_deviations: Optional total standard
            deviations, drawn as a band around the expected series.
        expected_series_explained_standard_deviations: Optional explained
            standard deviations, drawn as a second band. The band is only
            drawn when their mean is positive.
        x_name: Singular name of the x-axis items.
        y_name: Singular name of the plotted quantity.
        sort: Whether to reorder the features by the `sort_by` series.
        sort_by: ``"expected"`` or ``"observed"``. It also decides which
            series is drawn as a solid line and which as markers.
        sort_direction: ``"ascending"`` or ``"descending"``.
        x_scale: Matplotlib x-axis scale.
        y_scale: ``"linear"``, ``"log"`` or ``"both"``. ``"both"`` gives two
            stacked panels: a log scale above and a linear scale below.
        y_cutoff: Optional y limit. In ``"both"`` mode it is the boundary
            between the two panels. With a single linear axis it sets the
            y maximum; with a single log axis it sets the y minimum.
        name: Optional extra name for the figure.

    Returns:
        Tuple ``(figure, figure_name)``.

    Raises:
        ValueError: If `sort_by` is not ``"expected"`` or ``"observed"``,
            or if `sort` is true and `sort_direction` is not
            ``"ascending"`` or ``"descending"``.
    """

    sort_by = normalise_string(sort_by)
    sort_direction = normalise_string(sort_direction)
    figure_name = saving.build_figure_name("profile_comparison", name)

    # Densify SciPy sparse inputs. `toarray` works across SciPy versions,
    # unlike the deprecated `.A` attribute.
    if scipy.sparse.issparse(observed_series):
        observed_series = observed_series.toarray().squeeze()

    if scipy.sparse.issparse(expected_series):
        expected_series = expected_series.toarray().squeeze()

    if scipy.sparse.issparse(expected_series_total_standard_deviations):
        expected_series_total_standard_deviations = (
            expected_series_total_standard_deviations.toarray().squeeze())

    if scipy.sparse.issparse(expected_series_explained_standard_deviations):
        expected_series_explained_standard_deviations = (
            expected_series_explained_standard_deviations.toarray().squeeze())

    # Colours
    observed_colour = style.STANDARD_PALETTE[0]
    expected_palette = seaborn.light_palette(style.STANDARD_PALETTE[1], 5)
    expected_colour = expected_palette[-1]
    expected_total_standard_deviations_colour = expected_palette[1]
    expected_explained_standard_deviations_colour = expected_palette[3]

    # Labels
    if sort:
        x_label = "{}s sorted {} by {} {}s [sort index]".format(
            capitalise_string(x_name), sort_direction, sort_by, y_name.lower())
    else:
        x_label = "{}s [original index]".format(capitalise_string(x_name))
    y_label = capitalise_string(y_name) + "s"

    observed_label = "Observed"
    expected_label = "Expected"
    expected_total_standard_deviations_label = "Total standard deviation"
    expected_explained_standard_deviations_label = (
        "Explained standard deviation")

    # Sorting: the series sorted by is drawn as a solid line on top, and
    # the other as markers beneath it.
    if sort_by == "expected":
        sort_series = expected_series
        expected_marker = ""
        expected_line_style = "solid"
        expected_z_order = 3
        observed_marker = "o"
        observed_line_style = ""
        observed_z_order = 2
    elif sort_by == "observed":
        sort_series = observed_series
        expected_marker = "o"
        expected_line_style = ""
        expected_z_order = 2
        observed_marker = ""
        observed_line_style = "solid"
        observed_z_order = 3
    else:
        raise ValueError("Sort by can either be expected or observed.")

    if sort:
        sort_indices = numpy.argsort(sort_series)
        if sort_direction == "descending":
            sort_indices = sort_indices[::-1]
        elif sort_direction != "ascending":
            raise ValueError(
                "Sort direction can either be ascending or descending.")
    else:
        sort_indices = slice(None)

    # Standard deviation bands
    if expected_series_total_standard_deviations is not None:
        with_total_standard_deviations = True
        expected_series_total_standard_deviations_lower = (
            expected_series - expected_series_total_standard_deviations)
        expected_series_total_standard_deviations_upper = (
            expected_series + expected_series_total_standard_deviations)
    else:
        with_total_standard_deviations = False

    if (expected_series_explained_standard_deviations is not None
            and expected_series_explained_standard_deviations.mean() > 0):
        with_explained_standard_deviations = True
        expected_series_explained_standard_deviations_lower = (
            expected_series - expected_series_explained_standard_deviations)
        expected_series_explained_standard_deviations_upper = (
            expected_series + expected_series_explained_standard_deviations)
    else:
        with_explained_standard_deviations = False

    # Figure
    if y_scale == "both":
        figure, axes = pyplot.subplots(nrows=2, sharex=True)
        figure.subplots_adjust(hspace=0.1)
        axis_upper, axis_lower = axes
        # Draw the upper panel above the lower one.
        axis_upper.set_zorder(1)
        axis_lower.set_zorder(0)
    else:
        figure = pyplot.figure()
        axis = figure.add_subplot(1, 1, 1)
        axes = [axis]

    handles = []
    feature_indices = numpy.arange(len(observed_series)) + 1

    for i, axis in enumerate(axes):
        observed_plot, = axis.plot(
            feature_indices,
            observed_series[sort_indices],
            label=observed_label,
            color=observed_colour,
            marker=observed_marker,
            linestyle=observed_line_style,
            zorder=observed_z_order
        )
        expected_plot, = axis.plot(
            feature_indices,
            expected_series[sort_indices],
            label=expected_label,
            color=expected_colour,
            marker=expected_marker,
            linestyle=expected_line_style,
            zorder=expected_z_order
        )

        if with_total_standard_deviations:
            axis.fill_between(
                feature_indices,
                expected_series_total_standard_deviations_lower[sort_indices],
                expected_series_total_standard_deviations_upper[sort_indices],
                color=expected_total_standard_deviations_colour,
                zorder=0
            )

        if with_explained_standard_deviations:
            axis.fill_between(
                feature_indices,
                expected_series_explained_standard_deviations_lower[
                    sort_indices],
                expected_series_explained_standard_deviations_upper[
                    sort_indices],
                color=expected_explained_standard_deviations_colour,
                zorder=1
            )

        # Collect legend handles once. They are shared by all panels.
        if i == 0:
            handles.append(observed_plot)
            handles.append(expected_plot)
            if with_total_standard_deviations:
                handles.append(matplotlib.patches.Patch(
                    label=expected_total_standard_deviations_label,
                    color=expected_total_standard_deviations_colour
                ))
            if with_explained_standard_deviations:
                handles.append(matplotlib.patches.Patch(
                    label=expected_explained_standard_deviations_label,
                    color=expected_explained_standard_deviations_colour
                ))

    if y_scale == "both":
        axis_upper.legend(handles=handles, loc="best")
        seaborn.despine(ax=axis_upper)
        seaborn.despine(ax=axis_lower)

        _set_clipped_log_y_scale(axis_upper)
        axis_lower.set_yscale("linear")
        figure.text(0.04, 0.5, y_label, va="center", rotation="vertical")

        axis_lower.set_xscale(x_scale)
        axis_lower.set_xlabel(x_label)

        # The panels meet at `y_cutoff`. The lower panel's bottom limit is
        # never below -1.
        y_upper_max = axis_upper.get_ylim()[1]
        y_lower_min = axis_lower.get_ylim()[0]
        axis_upper.set_ylim(y_cutoff, y_upper_max)
        axis_lower.set_ylim(max(-1, y_lower_min), y_cutoff)

    else:
        axis.legend(handles=handles, loc="best")
        seaborn.despine()

        if y_scale == "log":
            _set_clipped_log_y_scale(axis)
        else:
            axis.set_yscale(y_scale)
        axis.set_ylabel(y_label)

        axis.set_xscale(x_scale)
        axis.set_xlabel(x_label)

        # The bottom y limit is never below -1.
        y_min, y_max = axis.get_ylim()
        y_min = max(-1, y_min)

        if y_cutoff:
            if y_scale == "linear":
                y_max = y_cutoff
            elif y_scale == "log":
                y_min = y_cutoff

        axis.set_ylim(y_min, y_max)

    return figure, figure_name