def example_plot(rinfo, b=0, t=-1, mask_components=False, size=2, column_titles=True):
    """Draw a one-row summary figure for a single example.

    The panels are, left to right: the input image, the reconstruction, the
    predicted mask, and one panel per component.

    Args:
        rinfo: Nested mapping. Reads ``rinfo["data"]["image"]`` and the
            ``"recons"``, ``"pred_mask"`` and ``"components"`` entries of
            ``rinfo["outputs"]``.
        b: Index into the first axis of every array that is read.
        t: Index into the second axis of the ``outputs`` arrays
            (default ``-1``, i.e. the last entry).
        mask_components: If true, ``pred_mask[k]`` is forwarded to
            ``show_img`` as ``mask=`` for component ``k``; otherwise ``None``.
        size: Scale factor for the figure size; each panel gets
            ``size x size`` in ``figsize`` units.
        column_titles: If true, put a title on every panel.

    Returns:
        The created matplotlib figure.
    """
    image = rinfo["data"]["image"][b, 0]
    outputs = rinfo["outputs"]
    recons = outputs["recons"][b, t, 0]
    pred_mask = outputs["pred_mask"][b, t]
    components = outputs["components"][b, t]

    # The 4-way unpack doubles as a rank check; only the leading size is used.
    num_components, _, _, _ = components.shape
    palette = get_mask_plot_colors(num_components)

    n_rows = 1
    n_cols = num_components + 3
    fig, axes = plt.subplots(ncols=n_cols, figsize=(n_cols * size, n_rows * size))

    black = "#000000"
    show_img(image, ax=axes[0], color=black)
    show_img(recons, ax=axes[1], color=black)
    # NOTE(review): show_mask as visible in this file is declared as
    # (m, ax); the color kwarg is presumably consumed by a decorator that is
    # not visible here -- confirm.
    show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color=black)

    # The first three panels are taken above; the rest hold the components.
    for idx, panel in enumerate(axes[3:]):
        if mask_components:
            overlay = pred_mask[idx]
        else:
            overlay = None
        show_img(components[idx], ax=panel, color=palette[idx], mask=overlay)

    if column_titles:
        titles = ["Image", "Recons.", "Mask"]
        titles.extend("Component {}".format(i + 1) for i in range(num_components))
        for panel, text in zip(axes, titles):
            panel.set_title(text)

    plt.subplots_adjust(hspace=0.03, wspace=0.035)
    return fig
def inputs_plot(rinfo, b=0, t=0, size=2):
    """Plot every known spatial input as one row of per-component panels.

    Each input name found in ``rinfo["inputs"]["spatial"]`` (out of a fixed
    list of known names) becomes one row. The row has ``K`` panels plus a
    narrow trailing column that holds a colorbar for rows that use one and is
    hidden for the others.

    Args:
        rinfo: Nested mapping. Reads ``rinfo["outputs"]["components"]`` (a
            rank-6 array whose third axis size is ``K``) and
            ``rinfo["inputs"]["spatial"]`` (mapping from input name to an
            array indexed as ``[b, t, k]``).
        b: Index into the first axis of each input array.
        t: Index into the second axis of each input array.
        size: Scale factor for the figure size.

    Returns:
        The created matplotlib figure.
    """
    # The 6-way unpack doubles as a rank check; only K is used.
    _, _, K, _, _, _ = rinfo["outputs"]["components"].shape
    colors = get_mask_plot_colors(K)
    inputs = rinfo["inputs"]["spatial"]

    # (input name, plotting function, whether the row gets a colorbar)
    # NOTE(review): show_coords as visible in this file is declared as
    # (m, ax) but is called with color= below; a decorator outside this view
    # presumably consumes that kwarg -- confirm.
    known_rows = [
        ("image", show_img, False),
        ("components", show_img, False),
        ("dcomponents", functools.partial(show_img, norm=True), False),
        ("mask", show_mat, True),
        ("pred_mask", show_mat, True),
        ("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
        ("posterior", show_mat, True),
        ("log_prob", show_mat, True),
        ("counterfactual", show_mat, True),
        ("coordinates", show_coords, False),
    ]
    rows = [row for row in known_rows if row[0] in inputs]

    nrows = len(rows)
    ncols = K + 1
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * size - size * 0.9, nrows * size),
        gridspec_kw={"width_ratios": [1] * K + [0.1]},
        # Without this, a single row makes subplots() return a 1-D array and
        # every axes[r, k] lookup below raises IndexError.
        squeeze=False,
    )

    for r, (name, plot_fn, make_cbar) in enumerate(rows):
        values = inputs[name]
        axes[r, 0].set_ylabel(name)

        if make_cbar:
            # Use one value range for the whole row so that a single colorbar
            # is valid for all of its panels.
            vmin = np.min(values[b, t])
            vmax = np.max(values[b, t])
            if np.abs(vmin - vmax) < 1e-6:
                # Widen a (near-)constant range so it is not degenerate.
                vmin -= 0.1
                vmax += 0.1
            plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)

        # An input with a singleton third axis is repeated in every column
        # and drawn with a black color instead of the per-component one.
        shared = values.shape[2] == 1
        for k in range(K):
            if shared:
                m = values[b, t, 0]
                color = (0.0, 0.0, 0.0)
            else:
                m = values[b, t, k]
                color = colors[k]
            mappable = plot_fn(m, ax=axes[r, k], color=color)

        if make_cbar:
            # The colorbar is built from the last panel drawn in this row.
            fig.colorbar(mappable, cax=axes[r, K])
        else:
            axes[r, K].set_visible(False)

    for k in range(K):
        axes[0, k].set_title("Component {}".format(k + 1))

    plt.subplots_adjust(hspace=0.05, wspace=0.05)
    return fig
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
    """Draw one row per iteration plus a final reference row.

    Row ``t`` shows the reconstruction, the predicted mask and every component
    at iteration ``t``. The extra bottom row shows the input image, the ground
    truth mask and the mask logits of the last iteration for each component.

    Args:
        rinfo: Nested mapping. Reads ``"image"`` and ``"true_mask"`` from
            ``rinfo["data"]`` and ``"recons"``, ``"pred_mask"``,
            ``"pred_mask_logits"`` and ``"components"`` from
            ``rinfo["outputs"]``.
        b: Index into the first axis of every array that is read.
        mask_components: If true, ``pred_mask[t, k]`` is forwarded to
            ``show_img`` as ``mask=`` for each component panel; otherwise
            ``None``.
        size: Scale factor for the figure size; each panel gets
            ``size x size`` in ``figsize`` units.

    Returns:
        The created matplotlib figure.
    """
    data = rinfo["data"]
    image = data["image"][b]
    true_mask = data["true_mask"][b]
    outputs = rinfo["outputs"]
    recons = outputs["recons"][b]
    pred_mask = outputs["pred_mask"][b]
    pred_mask_logits = outputs["pred_mask_logits"][b]
    components = outputs["components"][b]

    # The 5-way unpack doubles as a rank check; only the first two are used.
    num_iters, num_components, _, _, _ = components.shape
    palette = get_mask_plot_colors(num_components)

    n_rows = num_iters + 1
    n_cols = num_components + 2
    fig, axes = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(n_cols * size, n_rows * size)
    )

    # One row per iteration.
    for step in range(num_iters):
        row = axes[step]
        show_img(recons[step, 0], ax=row[0])
        show_mask(pred_mask[step, Ellipsis, 0], ax=row[1])
        row[0].set_ylabel("iter {}".format(step))
        for comp in range(num_components):
            if mask_components:
                overlay = pred_mask[step, comp]
            else:
                overlay = None
            show_img(
                components[step, comp],
                ax=row[comp + 2],
                color=palette[comp],
                mask=overlay,
            )

    top = axes[0]
    bottom = axes[num_iters]
    top[0].set_title("Reconstruction")
    top[1].set_title("Mask")

    # Bottom row: references and the last iteration's mask logits.
    show_img(image[0], ax=bottom[0])
    show_mask(true_mask[0, Ellipsis, 0], ax=bottom[1])

    # All logit panels share one value range taken over every component.
    final_logits = pred_mask_logits[num_iters - 1]
    vmin = np.min(final_logits)
    vmax = np.max(final_logits)
    for comp in range(num_components):
        col = comp + 2
        top[col].set_title("Component {}".format(comp + 1))
        show_mat(final_logits[comp], ax=bottom[col], vmin=vmin, vmax=vmax)
        bottom[col].set_xlabel("Mask Logits for\nComponent {}".format(comp + 1))

    bottom[0].set_xlabel("Input Image")
    bottom[1].set_xlabel("Ground Truth Mask")

    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig
def show_coords(m, ax):
    """Draw a multi-channel map on ``ax`` as a color image.

    ``m`` is min-max rescaled over the whole array, then its last axis is
    combined with the matrix returned by ``get_mask_plot_colors`` through a
    dot product, and the result is shown with nearest-neighbour interpolation.

    A constant ``m`` (max equal to min) is drawn as all zeros rather than
    dividing by zero.

    Args:
        m: Array whose last axis is the channel axis.
        ax: Object providing ``imshow`` (a matplotlib axes).

    Returns:
        Whatever ``ax.imshow`` returns.
    """
    lo, hi = np.min(m), np.max(m)
    span = hi - lo
    if span == 0:
        # Constant input: 0 / 0 would fill the image with NaNs. (m - lo) is
        # already all zeros, so any non-zero divisor gives the zero image.
        span = 1
    scaled = (m - lo) / span
    # presumably one color row per channel -- confirm against
    # get_mask_plot_colors.
    palette = get_mask_plot_colors(m.shape[-1])
    return ax.imshow(np.dot(scaled, palette), interpolation="nearest")
def show_mask(m, ax):
    """Draw a stack of per-component masks on ``ax`` as one color image.

    The leading axis of ``m`` is moved to the end, combined with the matrix
    returned by ``get_mask_plot_colors`` through a dot product, clipped to
    ``[0, 1]`` and shown with nearest-neighbour interpolation.

    Args:
        m: Rank-3 array with the component axis first.
        ax: Object providing ``imshow`` (a matplotlib axes).

    Returns:
        Whatever ``ax.imshow`` returns.
    """
    palette = get_mask_plot_colors(m.shape[0])
    # Put the component axis last so the dot product contracts over it.
    components_last = np.transpose(m, [1, 2, 0])
    blended = np.dot(components_last, palette)
    clipped = np.clip(blended, 0.0, 1.0)
    return ax.imshow(clipped, interpolation="nearest")