def __init__(self, cmap, vmin, vmax=None, label=True, label_position=None,
             label_rotation=None, clipmin=None, clipmax=None,
             orientation='horizontal', unit=None, contours=(), width=None,
             ticks=None, threshold=None, ticklocation='auto',
             background='white', tight=True, h=None, w=None,
             *args, **kwargs):
    """Create a figure that shows a single colorbar for ``cmap``.

    Parameters
    ----------
    cmap : str | Colormap | array
        Colormap looked up with ``mpl.cm.get_cmap``, or an array of colours
        that is wrapped in a ``ListedColormap`` named 'LUT'. If the array's
        maximum exceeds 1, it is treated as 0-255 and divided by 255.
    vmin : scalar | Normalize
        Lower end of the colour scale, or a ready-made ``Normalize``
        instance. When a ``Normalize`` is given, ``vmax`` is ignored.
    vmax : scalar
        Upper end of the colour scale. It is passed through
        ``fix_vlim_for_cmap`` together with ``vmin``.
    label : bool | str
        Axis label. ``True`` uses ``unit`` if it is given, otherwise the
        colormap name. Any falsy value gives an empty label.
    label_position : str
        Passed to ``axis.set_label_position``.
    label_rotation : scalar
        Label rotation in degrees. For vertical bars, rotations within about
        10 degrees of 0 (mod 360) also center the label vertically.
    clipmin, clipmax : scalar
        Only show the part of the colorbar between these data values. Not
        supported with ``SymmetricNormalize`` (raises NotImplementedError).
    orientation : 'horizontal' | 'vertical'
        Any other value raises ValueError.
    unit : str
        Default label. Unless ``ticks`` is a dict (or False), it is also used
        to format tick labels via ``find_axis_params_data``.
    contours : sequence of scalar
        Values at which black lines are drawn on the colorbar axes.
    width : scalar
        Stored as ``self._width`` for use by other methods of the class.
    ticks : None | False | sequence | dict
        Tick values. ``False`` means no ticks. A dict maps each tick value to
        its label.
    threshold
        Raises NotImplementedError if given together with an array ``cmap``.
        It is otherwise not used in this method.
    ticklocation : str
        Passed to ``ColorbarBase``.
    background
        Stored as ``self._background``.
    tight : bool
        Passed to ``Layout``.
    h, w : scalar
        Figure height and width, passed to ``Layout``. If both are None,
        ``h`` defaults to 1 (horizontal) or 4 (vertical).
    *args, **kwargs
        Passed on to ``Layout``.
    """
    # get Colormap
    # NOTE(review): ``mpl.cm.get_cmap`` is deprecated/removed in newer
    # matplotlib releases (3.9); ``mpl.colormaps`` may be needed — verify
    # against the pinned matplotlib version.
    if isinstance(cmap, np.ndarray):
        if threshold is not None:
            raise NotImplementedError("threshold parameter with cmap=array")
        if cmap.max() > 1:
            # colour values given on a 0-255 scale
            cmap = cmap / 255.
        cm = mpl.colors.ListedColormap(cmap, 'LUT')
    else:
        cm = mpl.cm.get_cmap(cmap)

    # prepare layout: a wide axes for horizontal bars, a tall one for vertical
    if orientation == 'horizontal':
        if h is None and w is None:
            h = 1
        ax_aspect = 4
    elif orientation == 'vertical':
        if h is None and w is None:
            h = 4
        ax_aspect = 0.3
    else:
        raise ValueError("orientation=%s" % repr(orientation))

    layout = Layout(1, ax_aspect, 2, tight, False, h, w, *args, **kwargs)
    EelFigure.__init__(self, cm.name, layout)
    ax = self._axes[0]

    # translate between axes and data coordinates
    if isinstance(vmin, Normalize):
        norm = vmin
    else:
        vmin, vmax = fix_vlim_for_cmap(vmin, vmax, cm.name)
        norm = Normalize(vmin, vmax)

    # value ticks: False -> no ticks; dict -> {value: label}
    if ticks is False:
        ticks = ()
        tick_labels = None
    elif isinstance(ticks, dict):
        tick_dict = ticks
        ticks = sorted(tick_dict)
        tick_labels = [tick_dict[t] for t in ticks]
    else:
        tick_labels = None

    # the value axis runs along the long side of the bar
    # NOTE(review): horizontal bars draw contours with ``axhline`` and
    # vertical bars with ``axvline``; confirm this pairing is intended.
    if orientation == 'horizontal':
        axis = ax.xaxis
        contour_func = ax.axhline
    else:
        axis = ax.yaxis
        contour_func = ax.axvline

    if label is True:
        if unit:
            label = unit
        else:
            label = cm.name
    elif not label:
        label = ''

    # show only part of the colorbar: cut the colour boundaries (one per
    # colormap entry, mapped back to data values) down to [clipmin, clipmax]
    if clipmin is not None or clipmax is not None:
        if isinstance(norm, SymmetricNormalize):
            raise NotImplementedError(
                "clipmin or clipmax with SymmetricNormalize")
        boundaries = norm.inverse(np.linspace(0, 1, cm.N + 1))
        if clipmin is None:
            start = None
        else:
            start = np.digitize(clipmin, boundaries, True)
        if clipmax is None:
            stop = None
        else:
            stop = np.digitize(clipmax, boundaries) + 1
        boundaries = boundaries[start:stop]
    else:
        boundaries = None

    colorbar = ColorbarBase(ax, cm, norm, boundaries=boundaries,
                            orientation=orientation,
                            ticklocation=ticklocation, ticks=ticks,
                            label=label)

    # fix tick location: with SymmetricNormalize, tick positions are set as
    # values normalized by a plain Normalize over the same vmin/vmax
    # (presumably matching the bar's internal coordinates — confirm)
    if isinstance(norm, SymmetricNormalize) and ticks is not None:
        tick_norm = Normalize(norm.vmin, norm.vmax, norm.clip)
        axis.set_ticks(tick_norm(ticks))

    # unit-based tick-labels (may also replace the axis label)
    if unit and tick_labels is None:
        formatter, label = find_axis_params_data(unit, label)
        tick_labels = tuple(map(formatter, colorbar.get_ticks()))

    if tick_labels is not None:
        if clipmin is not None:
            # drop labels of ticks below the visible range
            # NOTE(review): when labels come from ``unit`` and ``ticks`` is
            # None, ``zip(tick_labels, ticks)`` raises TypeError — confirm
            # and pair with ``colorbar.get_ticks()`` instead.
            tick_labels = [l for l, t in zip(tick_labels, ticks) if t >= clipmin]
        axis.set_ticklabels(tick_labels)

    # label position/rotation
    if label_position is not None:
        axis.set_label_position(label_position)
    if label_rotation is not None:
        axis.label.set_rotation(label_rotation)
        if orientation == 'vertical':
            # near-horizontal text next to a vertical bar: center it
            if (label_rotation + 10) % 360 < 20:
                axis.label.set_va('center')
    elif orientation == 'vertical' and len(label) <= 3:
        # short labels on vertical bars are written horizontally
        axis.label.set_rotation(0)
        axis.label.set_va('center')

    self._contours = [contour_func(c, c='k') for c in contours]
    # hooks defined elsewhere in the class, run when the figure is drawn
    self._draw_hooks.append(self.__fix_alpha)
    self._draw_hooks.append(self.__update_bar_tickness)

    self._background = background
    self._colorbar = colorbar
    self._orientation = orientation
    self._width = width
    self._show()
def heatmap_plot(df, filename=None, average=True, group_by="accuracy",
                 str_format="{:.2%}", box="max", print_flag=True,
                 time_heatmap=False, close=True):
    """Generate a heatmap of ``group_by`` per (c_name, v_name) and save it.

    The data are aggregated with ``group_df``. The result is written to
    ``plot/heatmap_{filename}.png``, relative to the working directory, which
    must already contain a ``plot`` directory. When ``average`` is True, a
    second copy with the non-average region shaded is also saved as
    ``plot/heatmap_{filename}_2.png``.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw results. It must contain the ``group_by`` column and the
        ``c_name`` / ``v_name`` columns, which are used to locate the best
        cell and the flags. An optional ``flag`` column is drawn in the cell
        corner.
    filename : str, optional
        Output-file suffix. Defaults to ``"{n_columns}_{n_rows}"``.
    average : bool
        Passed to ``group_df``. When True, the grouped frame must contain an
        "Average" row and column. Columns are then sorted by their average
        and the Average cells are marked in red.
    group_by : str
        Column of ``df`` to plot.
    str_format : str or callable
        Format string (applied via ``str.format``) or callable used for cell
        text and colorbar labels.
    box : {"max", "min"}
        Whether higher ("max") or lower ("min") values are better. The cell
        holding the best raw value is boxed in blue, and the image colours
        are flipped for "min".
    print_flag : bool
        Draw the value of the ``flag`` column (if present) in each cell.
    time_heatmap : bool
        If ``group_by == "accuracy"``, also plot train, predict and total
        time heatmaps. The input ``df`` is not modified.
    close : bool
        Close the figure after saving (default). Pass False to keep it open,
        e.g. for interactive display.
    """
    df_grouped = group_df(df, average=average, group_by=group_by,
                          ascending=box == "min")
    v_names = list(df_grouped.columns)
    if average:
        avg = df_grouped[df_grouped.index == "Average"].iloc[0]
        # Order columns by their (rounded) average; the sign flips the order
        # for "min".
        sign = 1 if box == "max" else -1
        v_names.sort(key=lambda x: round(avg[x], 4) * sign)
        df_grouped = df_grouped[v_names]
    c_names = list(df_grouped.index)

    # Range of the raw (ungrouped) values. Series.min/max skip NaN, unlike the
    # builtins, whose result with NaN depends on element order.
    min_value = df[group_by].min()
    max_value = df[group_by].max()

    ratio = (len(c_names) / len(v_names)) / 2
    figsize = (len(v_names) * 0.8 + 3.1, 2 + ratio * len(v_names) * 0.8)
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    fig.subplots_adjust(wspace=0.2)
    ax.set_title(group_by)

    # Use the lower half (0.0-0.5) of "Greens". For "min", the image uses the
    # same half reversed, and the colorbar labels are reversed to match.
    # plt.get_cmap replaces matplotlib.cm.get_cmap, removed in matplotlib 3.9.
    cmap = plt.get_cmap("Greens", 512)
    cb_cmap = im_cmap = ListedColormap(cmap(np.linspace(0.0, 0.5, 256)))
    if box == "min":
        im_cmap = ListedColormap(cmap(np.linspace(0.5, 0.0, 256)))
    ax.imshow(df_grouped, cmap=im_cmap, aspect="auto")

    if isinstance(str_format, str):
        str_formater = str_format.format
    else:
        str_formater = str_format

    # Create colorbar: n_ticks evenly spaced labels across the raw value range.
    if min_value != max_value:
        n_ticks = 10
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size=0.2, pad=0.2)
        ticks = np.linspace(min_value, max_value, n_ticks)
        if box == "min":
            ticks = ticks[::-1]
        norm = Normalize(vmin=0, vmax=n_ticks - 1, clip=False)
        cbar = ColorbarBase(cax, cmap=cb_cmap, ticks=range(n_ticks),
                            norm=norm)
        # get_ticks() returns float positions (0.0, 1.0, ...); numpy arrays
        # reject float indices, so convert before indexing.
        cbar.ax.set_yticklabels(
            [str_formater(ticks[int(round(pos))]) for pos in cbar.get_ticks()])

    ax.set_xticks(np.arange(len(v_names)))
    ax.set_yticks(np.arange(len(c_names)))
    ax.set_yticklabels([x.split(".")[-1] for x in c_names])
    ax.set_xticklabels(v_names)
    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Cell values, plus the optional per-cell flag (never on Average cells).
    show_flags = print_flag and "flag" in df.columns
    for i, c_name in enumerate(c_names):
        for j, v_name in enumerate(v_names):
            value = df_grouped[df_grouped.index == c_name][v_name].values[0]
            ax.text(j, i, str_formater(value), ha="center", va="center",
                    color="black")
            if not show_flags or "Average" in (c_name, v_name):
                continue
            row = df[(df.c_name == c_name) & (df.v_name == v_name)]
            if row.empty:
                continue
            flag = row.iloc[0].flag
            if flag:
                ax.text(j + 0.4, i - 0.25, flag, ha="center", va="center",
                        color="black")

    # White minor-grid lines on the cell borders separate the cells.
    ax.set_xticks(np.arange(len(v_names)) - 0.5, minor=True)
    ax.set_yticks(np.arange(len(c_names)) - 0.5, minor=True)
    ax.grid(True, which="minor", linestyle="-", color="w")
    ax.tick_params(which="minor", length=0.0)

    if average:
        x_avg = v_names.index("Average")
        y_avg = c_names.index("Average")
        # presumably draws lines bracketing the Average column/row
        plot_lines(ax, v=[x_avg - 0.5, x_avg + 0.5],
                   h=[y_avg - 0.5, y_avg + 0.5], color="red")

    # Box the cell holding the best raw value.
    box_value = max_value if box == "max" else min_value
    box_row = df[df[group_by] == box_value].iloc[0]
    x, y = v_names.index(box_row.v_name), c_names.index(box_row.c_name)
    plot_box(ax, left=x - 0.5, right=x + 0.5, bottom=y - 0.5, top=y + 0.5,
             color="b")

    fig.tight_layout()
    if filename is None:
        filename = f"{len(v_names)}_{len(c_names)}"
    fig.savefig(f"plot/heatmap_{filename}.png")
    logging.info("Saved 'plot/heatmap_%s.png'", filename)

    if average:
        # Second variant: shade everything left of and including the Average
        # column, and the rows from Average down to the right of it.
        ax.fill_between(np.arange(-0.5, x_avg + 1, 0.5), -0.5,
                        len(c_names) - 0.5, facecolor="black", alpha=0.2)
        ax.fill_between(np.arange(x_avg + 0.5, len(v_names), 0.5),
                        y_avg - 0.5, len(c_names) - 0.5, facecolor="black",
                        alpha=0.2)
        fig.savefig(f"plot/heatmap_{filename}_2.png")
        logging.info("Saved 'plot/heatmap_%s_2.png'", filename)

    if close:
        # pyplot keeps every figure alive until closed; release this one.
        plt.close(fig)

    if group_by == "accuracy" and time_heatmap:
        # assign() returns a copy, so the caller's frame gains no column.
        df = df.assign(total_time=df["train_time"] + df["predict_time"])
        for time_col, decimals in (("train_time", 2), ("predict_time", 3),
                                   ("total_time", 2)):
            heatmap_plot(df, filename=f"{filename}_{time_col}",
                         average=average, print_flag=False, group_by=time_col,
                         str_format=partial(time_format, decimals=decimals),
                         box="min", close=close)