def plot_Cnorm(cnorm, labels, Ptrue=(0.1, 0.5), ax=None, title=None,
               fontsize=12):
    """Plot a `Cnorm` matrix as an annotated heatmap.

    Rows of ``cnorm`` correspond to the ``Ptrue`` values and columns to the
    predicted ``labels``. Every cell is annotated with its value (2 decimals),
    and the title reports the mean of the whole matrix.

    Parameters
    ----------
    cnorm : array-like, shape (len(Ptrue), len(labels))
        Matrix to display; converted to ``float32``.
    labels : sequence
        Tick labels of the x-axis (predicted labels).
    Ptrue : scalar or sequence of scalars (default: (0.1, 0.5))
        Values labelling the rows; a scalar is treated as a single row.
    ax : optional
        Target axis, resolved through ``to_axis``.
    title : str, optional
        Prefix of the title; the mean of ``cnorm`` is always appended.
    fontsize : int
        Base font size for ticks, axis labels and annotations.

    Returns
    -------
    The matplotlib axis that was drawn on.

    Raises
    ------
    ValueError
        If the number of rows of ``cnorm`` differs from ``len(Ptrue)``.
    """
    from matplotlib import pyplot as plt
    cnorm = np.asarray(cnorm, dtype='float32')
    if not isinstance(Ptrue, (tuple, list, np.ndarray)):
        Ptrue = (Ptrue,)
    Ptrue = [float(i) for i in Ptrue]
    if len(Ptrue) != cnorm.shape[0]:
        raise ValueError(
            "`Cnorm` was calculated for %d Ptrue values, but given only "
            "%d values for `Ptrue`: %s" %
            (cnorm.shape[0], len(Ptrue), str(Ptrue)))
    ax = to_axis(ax, is_3D=False)
    ax.imshow(cnorm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(Ptrue)))
    ax.set_xticklabels(labels, rotation=-57, fontsize=fontsize)
    ax.set_yticklabels([str(i) for i in Ptrue], fontsize=fontsize)
    ax.set_ylabel('Ptrue', fontsize=fontsize)
    ax.set_xlabel('Predicted label', fontsize=fontsize)
    # Center the value of each cell. Draw on `ax` explicitly: `plt.text`
    # would target the *current* axis, which is not necessarily `ax`.
    for i, j in itertools.product(range(len(Ptrue)), range(len(labels))):
        ax.text(j, i, '%.2f' % cnorm[i, j],
                weight='normal',
                color='red',
                fontsize=fontsize,
                verticalalignment="center",
                horizontalalignment="center")
    # Turns off grid on the left Axis.
    ax.grid(False)
    mean_cnorm = np.mean(cnorm)
    if title is None:
        title = "Cnorm: %.6f" % mean_cnorm
    else:
        title = "%s (Cnorm: %.6f)" % (str(title), mean_cnorm)
    ax.set_title(title, fontsize=fontsize + 2, weight='semibold')
    return ax
def plot_distance_heatmap(X, labels, lognorm=True, colormap='hot', ax=None,
                          legend_enable=True, legend_loc='upper center',
                          legend_ncol=3, legend_colspace=0.2, fontsize=10,
                          cbar=True, title=None):
    r"""Plot the pairwise (Euclidean) distance matrix of `X`, grouped by label.

    Samples are sorted by label (stable sort) so that samples of the same
    class form contiguous blocks. A colored bar along the top and left edges
    of the image marks the class of each row/column.

    Arguments:
      X : (n_samples, n_features). Coordination for scatter points;
        passed through ``K.length_norm`` before computing distances
      labels : (n_samples,). List of classes index or name. If the first
        label is a non-integer number, labels are treated as continuous and
        rescaled to [-1, 1] (shown with the 'bwr' colormap)
      lognorm : if True, distances are transformed with ``np.log1p``
      colormap : name of the matplotlib colormap used for the distances
      ax : target axis (resolved by ``to_axis``)
      legend_enable, legend_loc, legend_ncol, legend_colspace : class legend
      fontsize : base font size
      cbar : if True, draw a colorbar spanning the displayed distance range
      title : optional title

    Return:
      the matplotlib axis
    """
    import seaborn as sns
    from matplotlib import pyplot as plt
    from matplotlib.lines import Line2D
    from odin import backend as K
    # prepare data
    X = K.length_norm(X, axis=-1, epsilon=np.finfo(X.dtype).eps)
    ax = to_axis(ax)
    n_samples, _ = X.shape
    # processing labels
    labels = np.array(labels).ravel()
    assert labels.shape[0] == n_samples, "labels must be 1-D array."
    is_continuous = isinstance(labels[0], Number) and \
        int(labels[0]) != labels[0]
    # float values label (normalize -1 to 1) or binary classification
    if is_continuous:
        min_val = np.min(labels)
        max_val = np.max(labels)
        if max_val > min_val:
            labels = 2 * (labels - min_val) / (max_val - min_val) - 1
        else:  # constant labels: avoid 0/0
            labels = np.zeros(labels.shape, dtype='float64')
        n_labels = 2
        labels_name = {'-1': 0, '+1': 1}
    else:
        labels_name = {name: i for i, name in enumerate(np.unique(labels))}
        n_labels = len(labels_name)
        labels = np.array([labels_name[name] for name in labels])
    # ====== sorting label and X ====== #
    # stable sort: keeps the original order of samples inside each class
    order = np.argsort(labels, kind='stable')
    order_X = np.asarray(X)[order]
    order_label = labels[order][:, np.newaxis]  # (n_samples, 1)
    distance = sp.spatial.distance_matrix(order_X, order_X)
    if bool(lognorm):
        distance = np.log1p(distance)
    # raise the zero diagonal to the smallest non-zero distance so it does
    # not dominate the color range (skipped if every distance is zero)
    non_zero = distance[np.nonzero(distance)]
    if non_zero.size > 0:
        distance = np.clip(distance, a_min=np.min(non_zero),
                           a_max=np.max(distance))
    d_min = np.min(distance)
    d_max = np.max(distance)
    # ====== convert data to image ====== #
    # Colormaps expect floats in [0, 1] (values outside are clipped), so
    # rescale to make the image consistent with the colorbar [d_min, d_max].
    if d_max > d_min:
        scaled_distance = (distance - d_min) / (d_max - d_min)
    else:
        scaled_distance = np.zeros_like(distance)
    distance_img = plt.get_cmap(colormap)(scaled_distance)
    # labels colormap
    width = max(int(0.032 * n_samples), 8)
    if n_labels == 2:
        cm = plt.get_cmap('bwr')
        # Map the labels onto [0, 1] as floats: integer inputs would be read
        # by matplotlib as raw lookup-table indices (0 and 1 would then be
        # almost the same color), and negative floats would be clipped.
        lo = np.min(labels)
        span = np.max(labels) - lo
        if span > 0:
            label_val = (order_label - lo) / span
        else:
            label_val = np.zeros(order_label.shape, dtype='float64')
        horz_bar = np.repeat(cm(label_val.T), repeats=width, axis=0)
        vert_bar = np.repeat(cm(label_val), repeats=width, axis=1)
        all_colors = np.array((cm(0.), cm(1.)))
    else:
        # use seaborn color palette here is better
        palette = [i + (1.,) for i in sns.color_palette(n_colors=n_labels)]
        c = np.stack([palette[i] for i in order_label.ravel()])
        horz_bar = np.repeat(np.expand_dims(c, 0), repeats=width, axis=0)
        vert_bar = np.repeat(np.expand_dims(c, 1), repeats=width, axis=1)
        all_colors = palette
    # image: distance matrix in the bottom-right, label bars on top/left,
    # the top-left corner stays empty
    final_img = np.zeros(shape=(n_samples + width, n_samples + width,
                                distance_img.shape[2]),
                         dtype=distance_img.dtype)
    final_img[width:, width:] = distance_img
    final_img[:width, width:] = horz_bar
    final_img[width:, :width] = vert_bar
    assert np.sum(final_img[:width, :width]) == 0, \
        "Something wrong with my spacial coordination when writing this code!"
    # ====== plotting ====== #
    ax.imshow(final_img)
    ax.axis('off')
    # ====== legend ====== #
    if bool(legend_enable):
        legend_elements = [
            Line2D([0], [0],
                   marker='o',
                   color=color,
                   label=name,
                   linewidth=0,
                   markerfacecolor=color,
                   markersize=fontsize // 2)
            for color, name in zip(all_colors, labels_name.keys())
        ]
        ax.legend(handles=legend_elements,
                  markerscale=1.,
                  scatterpoints=1,
                  scatteryoffsets=[0.375, 0.5, 0.3125],
                  loc=legend_loc,
                  bbox_to_anchor=(0.5, -0.01),
                  ncol=int(legend_ncol),
                  columnspacing=float(legend_colspace),
                  labelspacing=0.,
                  fontsize=fontsize - 1,
                  handletextpad=0.1)
    # ====== final configurations ====== #
    if title is not None:
        ax.set_title(str(title), fontsize=fontsize)
    if cbar:
        from odin.visual import plot_colorbar
        plot_colorbar(colormap,
                      vmin=d_min,
                      vmax=d_max,
                      ax=ax,
                      orientation='vertical')
    return ax
def plot_scatter_layers(x_y_val, ax=None, layer_name=None, layer_color=None,
                        layer_marker=None, size=4.0, z_ratio=4, elev=None,
                        azim=88, ticks_off=True, grid=True, surface=True,
                        wireframe=False, wireframe_resolution=10,
                        colorbar=False, colorbar_horizontal=False,
                        legend_loc='upper center', legend_ncol=3,
                        legend_colspace=0.4, fontsize=8, title=None):
    r"""Plot several scatter heatmaps as stacked horizontal layers in 3D.

    Parameter
    ---------
    x_y_val : sequence of (x, y, val) triples (at least two)
      one triple per layer; ``val`` colors the points through the layer's
      colormap
    layer_name, layer_color, layer_marker, size :
      per-layer name (legend/colorbar label), colormap name, marker and
      marker size; normalized to one value per layer by `check_arg_length`
    z_ratio: float (default: 4)
      the amount of compression that layer in z_axis will be closer to
      each others compared to (x, y) axes: layers are placed evenly in
      ``[min / z_ratio, max / z_ratio]`` where min/max are taken over all
      x and y coordinates
    surface, wireframe, wireframe_resolution :
      draw a translucent plane and/or wireframe under each layer
    colorbar, colorbar_horizontal : add one colorbar per layer

    Return
    ------
    the matplotlib (3D) axis

    Raises
    ------
    ValueError : if ``z_ratio`` is not strictly positive
    """
    from matplotlib import pyplot as plt
    assert len(x_y_val) > 1, "Use `plot_scatter_heatmap` to plot only 1 layer"
    z_ratio = float(z_ratio)
    if z_ratio <= 0:
        raise ValueError("`z_ratio` must be > 0, given: %s" % str(z_ratio))
    max_z = -np.inf
    min_z = np.inf
    for x, y, val in x_y_val:
        assert len(x) == len(y) == len(val), "Number of samples mismatch"
        max_z = max(max_z, np.max(x), np.max(y))
        min_z = min(min_z, np.min(x), np.min(y))
    ax = to_axis(ax, is_3D=True)
    num_classes = len(x_y_val)
    # ====== preparing ====== #
    # name
    layer_name = check_arg_length(dat=layer_name, n=num_classes,
                                  dtype=string_types, default='',
                                  converter=lambda x: str(x))
    # colormap
    layer_color = check_arg_length(dat=layer_color, n=num_classes,
                                   dtype=string_types, default='Blues',
                                   converter=lambda x: plt.get_cmap(str(x)))
    # class marker
    layer_marker = check_arg_length(dat=layer_marker, n=num_classes,
                                    dtype=string_types, default='o',
                                    converter=lambda x: str(x))
    # size
    size = check_arg_length(dat=size, n=num_classes, dtype=Number,
                            default=4.0, converter=lambda x: float(x))
    # ====== plotting each class ====== #
    # (name, scatter artist, colormap) for every layer that has a name;
    # the colormap is kept so legend handles can be recolored correctly
    legend_entries = []
    layer_params = zip(np.linspace(0.05, 0.4, num_classes),
                       np.linspace(min_z / z_ratio, max_z / z_ratio,
                                   num_classes))
    for idx, (alpha, z_level) in enumerate(layer_params):
        x, y, val = x_y_val[idx]
        cmap = layer_color[idx]
        points = ax.scatter(x, y, np.full(shape=(len(x),), fill_value=z_level),
                            c=val,
                            s=size[idx],
                            marker=layer_marker[idx],
                            cmap=cmap)
        # ploting surface and wireframe
        if surface or wireframe:
            grid_x, grid_y = np.meshgrid(
                np.linspace(np.min(x), np.max(x), wireframe_resolution),
                np.linspace(np.min(y), np.max(y), wireframe_resolution))
            grid_z = np.full_like(grid_x, fill_value=z_level)
            if surface:
                ax.plot_surface(X=grid_x, Y=grid_y, Z=grid_z,
                                color=cmap(0.5),
                                edgecolor='none',
                                alpha=alpha)
            if wireframe:
                ax.plot_wireframe(X=grid_x, Y=grid_y, Z=grid_z,
                                  linewidth=0.8,
                                  color=cmap(0.8),
                                  alpha=alpha + 0.1)
        # legend
        name = layer_name[idx]
        if len(name) > 0:
            legend_entries.append((name, points, cmap))
        # colorbar: attached to this axis' figure, not pyplot's current one
        if colorbar:
            cba = ax.get_figure().colorbar(
                points, ax=ax, shrink=0.5, pad=0.01,
                orientation='horizontal' if colorbar_horizontal else 'vertical')
            if len(name) > 0:
                cba.set_label(name, fontsize=fontsize)
    # ====== plot the legend ====== #
    if len(legend_entries) > 0:
        legend = ax.legend([entry[1] for entry in legend_entries],
                           [entry[0] for entry in legend_entries],
                           markerscale=1.5,
                           scatterpoints=1,
                           scatteryoffsets=[0.375, 0.5, 0.3125],
                           loc=legend_loc,
                           bbox_to_anchor=(0.5, -0.01),
                           ncol=int(legend_ncol),
                           columnspacing=float(legend_colspace),
                           labelspacing=0.,
                           fontsize=fontsize,
                           handletextpad=0.1)
        # `legend_handles` since matplotlib 3.7 (`legendHandles` removed in 3.9)
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles
        # one handle per *named* layer, in the same order as legend_entries
        for handle, (_, _, cmap) in zip(handles, legend_entries):
            handle.set_color(cmap(.8))
    # ====== some configuration ====== #
    if ticks_off:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
    ax.grid(grid)
    if title is not None:
        ax.set_title(str(title))
    if elev is not None or azim is not None:
        ax.view_init(elev=ax.elev if elev is None else elev,
                     azim=ax.azim if azim is None else azim)
    return ax
def plot_heatmap(data, cmap="Blues", ax=None, xticklabels=None,
                 yticklabels=None, xlabel=None, ylabel=None, cbar_title=None,
                 cbar=False, fontsize=12, gridline=0, hide_spines=True,
                 annotation=None, text_colors=None, title=None):
    r"""Showing heatmap matrix.

    Arguments:
      data : 2-D array-like, the matrix to display
      cmap : colormap for ``imshow``
      xticklabels, yticklabels : tick labels; if only one of them is given,
        the other is filled with "X#i" / "Y#i". When neither is given all
        major tick labels are hidden
      cbar, cbar_title : draw a colorbar (bound to ``ax``) with optional title
      gridline : width of the white grid between cells (0 disables it)
      annotation : None/False (no text), True (format every value with
        '%.2g') or an array of strings with the same shape as ``data``
      text_colors : dict with optional keys 'diag', 'other', 'minrow',
        'maxrow', 'mincol', 'maxcol' giving annotation colors. None means
        dict(diag="black", minrow=None, mincol=None, maxrow=None,
        maxcol=None, other="black")

    Return:
      the matplotlib axis
    """
    from matplotlib import pyplot as plt
    if text_colors is None:
        text_colors = dict(diag="black", minrow=None, mincol=None,
                           maxrow=None, maxcol=None, other="black")
    data = np.asarray(data)
    ax = to_axis(ax, is_3D=False)
    ax.grid(False)
    n_rows, n_cols = data.shape[0], data.shape[1]
    # prepare labels
    if xticklabels is None and yticklabels is not None:
        xticklabels = ["X#%d" % i for i in range(n_cols)]
    if yticklabels is None and xticklabels is not None:
        yticklabels = ["Y#%d" % i for i in range(n_rows)]
    # Plot the heatmap
    im = ax.imshow(data, interpolation='nearest', cmap=cmap, aspect='equal',
                   origin='upper')
    # Create colorbar, explicitly attached to `ax` (not pyplot's current axis)
    if cbar:
        cb = ax.get_figure().colorbar(im, ax=ax, fraction=0.02, pad=0.02)
        if cbar_title is not None:
            cb.ax.set_ylabel(cbar_title, rotation=-90, va="bottom",
                             fontsize=fontsize)
    ## major ticks
    if xticklabels is not None and yticklabels is not None:
        ax.set_xticks(np.arange(n_cols))
        ax.set_xticklabels(xticklabels, fontsize=fontsize)
        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(list(yticklabels), fontsize=fontsize)
        # Let the horizontal axes labeling appear on top.
        ax.tick_params(top=True, bottom=False, labeltop=True,
                       labelbottom=False)
        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
                 rotation_mode="anchor")
    else:  # turn-off all ticks
        ax.tick_params(top=False, bottom=False, labeltop=False,
                       labelbottom=False)
    ## axis label
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize + 1)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fontsize + 1)
    ## Turn spines off
    if hide_spines:
        for spine in ax.spines.values():
            spine.set_visible(False)
    ## minor ticks and create white grid.
    # (if no minor ticks, the image will be cut-off)
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    if gridline > 0:
        ax.grid(which="minor", color="w", linestyle='-', linewidth=gridline)
        ax.tick_params(which="minor", bottom=False, left=False)
    # set the title
    if title is not None:
        ax.set_title(str(title), fontsize=fontsize + 2, weight='semibold')
    # prepare the annotation
    if annotation is not None and annotation is not False:
        if annotation is True:
            annotation = np.array([['%.2g' % x for x in row] for row in data])
        annotation = np.asarray(annotation)
        assert annotation.shape == data.shape, \
            "annotation shape %s != data shape %s" % (annotation.shape,
                                                      data.shape)
        text_kw = dict(horizontalalignment="center",
                       verticalalignment="center",
                       fontsize=fontsize)
        minrow = text_colors.get('minrow', None)
        maxrow = text_colors.get('maxrow', None)
        mincol = text_colors.get('mincol', None)
        maxcol = text_colors.get('maxcol', None)
        diag_color = text_colors.get('diag', 'black')
        other_color = text_colors.get('other', 'black')
        # row/column extremes computed once instead of once per cell
        row_min, row_max = data.min(axis=1), data.max(axis=1)
        col_min, col_max = data.min(axis=0), data.max(axis=0)
        # Loop over the data and create a `Text` for each "pixel".
        # Change the text's color depending on the data.
        for i in range(n_rows):
            for j in range(n_cols):
                value = data[i, j]
                # basics text config: diagonal cells are bold
                if i == j:
                    weight, color = 'bold', diag_color
                else:
                    weight, color = 'normal', other_color
                # min, max of row
                if minrow is not None and value == row_min[i]:
                    color = minrow
                elif maxrow is not None and value == row_max[i]:
                    color = maxrow
                # min, max of column (takes precedence over the row colors)
                if mincol is not None and value == col_min[j]:
                    color = mincol
                elif maxcol is not None and value == col_max[j]:
                    color = maxcol
                ax.text(j, i, annotation[i, j], color=color, weight=weight,
                        **text_kw)
    return ax
def _plot_scatter_points(*, x, y, z, val, color, marker, size, size_range,
                         alpha, max_n_points, cbar, cbar_horizontal,
                         cbar_nticks, cbar_ticks_rotation, cbar_title,
                         cbar_fontsize, legend_enable, legend_loc, legend_ncol,
                         legend_colspace, elev, azim, ticks_off, grid, fontsize,
                         centroids, xlabel, ylabel, title, ax, **kwargs):
    r"""Generator driving a 2D/3D scatter plot, one style group at a time.

    The points are downsampled, then grouped by their (color, marker, size)
    style. For every group this generator yields
    ``(ax, artist, x_, y_, z_, style)``. The caller must draw the group and
    append the created matplotlib artist to the list ``artist`` before
    resuming the generator; this is checked by an assertion.

    ``z_`` is None in 2D mode (``z`` is None). When ``val`` is given,
    ``style[0]`` is replaced by the RGBA colors of the group's ``val``
    values, normalized to ``[min(val), max(val)]`` and mapped through the
    group's colormap.

    After the last group has been drawn, the colorbar (colormap mode only),
    legend, ticks, grid, labels, title and 3D view are configured on ``ax``.

    Extra keyword ``text_marker`` (default False) is forwarded to
    ``_validate_color_marker_size_legend`` (keeps the marker as its text).
    """
    from matplotlib import pyplot as plt
    import matplotlib as mpl
    # keep the marker as its original text
    text_marker = kwargs.get('text_marker', False)
    x, y, z = _parse_scatterXYZ(x, y, z)
    assert len(x) == len(y), "Number of samples mismatch"
    if z is not None:
        assert len(y) == len(z)
    is_3D_mode = z is not None
    ax = to_axis(ax, is_3D_mode)
    ### check the colormap
    if val is None:
        vmin, vmax, color_normalizer = None, None, None
        is_colormap = False
    else:
        from matplotlib.colors import Colormap
        vmin = np.min(val)
        vmax = np.max(val)
        color_normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        is_colormap = True
        # any matplotlib Colormap (Linear-segmented or Listed) is accepted
        assert isinstance(color, (string_types, Colormap)), \
            "`colormap` can be string or instance of matplotlib Colormap, " + \
            "but given: %s" % type(color)
    if not is_colormap and isinstance(color, string_types) and color == 'bwr':
        color = 'b'
    ### perform downsample and select the styles
    max_n_points, x, y, z, color, marker, size = _downsample_scatter_points(
        x, y, z, max_n_points, color, marker, size)
    color, marker, size, legend = _validate_color_marker_size_legend(
        max_n_points,
        color,
        marker,
        size,
        text_marker=text_marker,
        is_colormap=is_colormap,
        size_range=size_range)
    ### centroid style
    centroid_style = dict(horizontalalignment='center',
                          verticalalignment='center',
                          fontsize=fontsize + 2,
                          weight="bold",
                          bbox=dict(boxstyle="circle",
                                    facecolor="black",
                                    alpha=0.48,
                                    pad=0.,
                                    edgecolor='none'))
    ### plotting
    artist = []
    legend_name = []
    for plot_idx, (style, name) in enumerate(legend.items()):
        style = list(style)
        x_, y_, z_, val_ = [], [], [], []
        # get the right set of data points
        for i, (c, m, s) in enumerate(zip(color, marker, size)):
            if c == style[0] and m == style[1] and s == style[2]:
                x_.append(x[i])
                y_.append(y[i])
                if is_colormap:
                    val_.append(val[i])
                if is_3D_mode:
                    z_.append(z[i])
        # 2D or 3D plot
        if not is_3D_mode:
            z_ = None
        # colormap or normal color
        if not is_colormap:
            val_ = None
        else:
            # `plt.get_cmap`: `matplotlib.cm.get_cmap` was removed in 3.9
            cm = plt.get_cmap(style[0])
            val_ = color_normalizer(val_)
            style[0] = cm(val_)
        # yield for plotting
        n_art = len(artist)
        yield ax, artist, x_, y_, z_, style
        # check new axis added
        assert len(artist) > n_art, \
            "Forgot adding new art object created by plotting"
        # check if ploting centroid
        # NOTE(review): in colormap mode style[0] holds one RGBA per point,
        # which `ax.text(color=...)` may reject — confirm intended usage.
        if centroids:
            if is_3D_mode:
                ax.text(np.mean(x_), np.mean(y_), np.mean(z_),
                        s=name[0], color=style[0], **centroid_style)
            else:
                ax.text(np.mean(x_), np.mean(y_),
                        s=name[0], color=style[0], **centroid_style)
        # make the shortest name: drop empty names, de-duplicate (in order)
        short_name = []
        for n in name:
            if len(n) > 0 and n not in short_name:
                short_name.append(n)
        name = ', '.join(short_name)
        if len(name) > 0:
            legend_name.append(name)
        ### at the end of the iteration, axis configuration
        if len(artist) == len(legend):
            ## colorbar (only enable when colormap is provided)
            if is_colormap and cbar:
                mappable = plt.cm.ScalarMappable(norm=color_normalizer, cmap=cm)
                mappable.set_clim(vmin, vmax)
                cba = plt.colorbar(
                    mappable,
                    ax=ax,
                    shrink=0.99,
                    pad=0.01,
                    orientation='horizontal' if cbar_horizontal else 'vertical')
                if isinstance(cbar_nticks, Number):
                    cbar_range = np.linspace(vmin, vmax, num=int(cbar_nticks))
                    cbar_nticks = [f'{i:.2g}' for i in cbar_range]
                elif isinstance(cbar_nticks, (tuple, list, np.ndarray)):
                    cbar_range = np.linspace(vmin, vmax, num=len(cbar_nticks))
                    cbar_nticks = [str(i) for i in cbar_nticks]
                else:
                    raise ValueError(
                        f"No support for cbar_nticks='{cbar_nticks}'")
                cba.set_ticks(cbar_range)
                cba.set_ticklabels(cbar_nticks)
                if cbar_title is not None:
                    if cbar_horizontal:  # horizontal colorbar
                        cba.ax.set_xlabel(str(cbar_title),
                                          fontsize=cbar_fontsize)
                    else:  # vertical colorbar
                        cba.ax.set_ylabel(str(cbar_title),
                                          fontsize=cbar_fontsize)
                cba.ax.tick_params(labelsize=cbar_fontsize,
                                   labelrotation=cbar_ticks_rotation)
            ## plot the legend
            if len(legend_name) > 0 and bool(legend_enable):
                markerscale = 1.5
                if isinstance(artist[0], mpl.text.Text):
                    # text plot special case: replace each text by a tiny
                    # scatter point of the same color so it has a legend handle
                    for i, art in enumerate(list(artist)):
                        pos = [art._x, art._y]
                        if is_3D_mode:
                            pos.append(art._z)
                        artist[i] = ax.scatter(*pos, c=art._color, s=0.1)
                    markerscale = 25
                # sort the legends
                legend_name, artist = zip(
                    *sorted(zip(legend_name, artist), key=lambda t: t[0]))
                legend_kw = {}
                if legend_loc is not None:
                    legend_kw['loc'] = legend_loc
                if legend_ncol is not None:
                    legend_kw['ncol'] = legend_ncol
                ax.legend(artist,
                          legend_name,
                          labelspacing=0.,
                          handletextpad=0.1,
                          markerscale=markerscale,
                          scatterpoints=1,
                          columnspacing=float(legend_colspace),
                          fontsize=fontsize,
                          **legend_kw)
            ## tick configuration
            if ticks_off:
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                if is_3D_mode:
                    ax.set_zticklabels([])
            if grid:
                ax.set_axisbelow(True)
                ax.grid(grid, which='both', axis='both', linewidth=0.8,
                        alpha=0.5)
            if xlabel is not None:
                ax.set_xlabel(str(xlabel), fontsize=fontsize - 1)
            if ylabel is not None:
                ax.set_ylabel(str(ylabel), fontsize=fontsize - 1)
            if title is not None:
                ax.set_title(str(title), fontsize=fontsize,
                             fontweight='regular')
            if is_3D_mode and (elev is not None or azim is not None):
                ax.view_init(elev=ax.elev if elev is None else elev,
                             azim=ax.azim if azim is None else azim)