Beispiel #1
0
def example_plot(rinfo,
                 b=0,
                 t=-1,
                 mask_components=False,
                 size=2,
                 column_titles=True):
    """Plot a single example as one row of panels.

    The row contains, from left to right: the input image, the reconstruction
    at iteration ``t``, the combined predicted mask at iteration ``t``, and
    one panel per component (each drawn with that component's colour from
    ``get_mask_plot_colors``).

    Args:
        rinfo: Nested dict of run information. Reads ``rinfo["data"]["image"]``
            and ``rinfo["outputs"]["recons"]``, ``["pred_mask"]`` and
            ``["components"]``.
        b: Batch index to plot.
        t: Iteration index used for the model outputs (default: the last one).
        mask_components: If True, each component panel is passed that
            component's predicted mask via ``mask=``.
        size: Width/height of each panel, in figure inches.
        column_titles: If True, put a title above every column.

    Returns:
        The matplotlib figure that was created.
    """
    data = rinfo["data"]
    image = data["image"][b, 0]
    outputs = rinfo["outputs"]
    recons = outputs["recons"][b, t, 0]
    pred_mask = outputs["pred_mask"][b, t]
    components = outputs["components"][b, t]

    # Unpacking (rather than indexing) also asserts components is 4-D.
    num_slots, _, _, _ = components.shape
    palette = get_mask_plot_colors(num_slots)

    # Three fixed panels followed by one panel per component, in a single row.
    n_cols = num_slots + 3
    fig, axes = plt.subplots(ncols=n_cols, figsize=(n_cols * size, size))

    black = "#000000"
    show_img(image, ax=axes[0], color=black)
    show_img(recons, ax=axes[1], color=black)
    show_mask(pred_mask[..., 0], ax=axes[2], color=black)
    for slot in range(num_slots):
        slot_mask = pred_mask[slot] if mask_components else None
        show_img(components[slot],
                 ax=axes[slot + 3],
                 color=palette[slot],
                 mask=slot_mask)

    if column_titles:
        titles = ["Image", "Recons.", "Mask"]
        titles.extend("Component {}".format(slot + 1)
                      for slot in range(num_slots))
        for ax, title in zip(axes, titles):
            ax.set_title(title)
    plt.subplots_adjust(hspace=0.03, wspace=0.035)
    return fig
Beispiel #2
0
def inputs_plot(rinfo, b=0, t=0, size=2):
    """Plot the spatial inputs of the model, one row per input kind.

    Only the input names listed below that are present in
    ``rinfo["inputs"]["spatial"]`` get a row. Each row has one column per
    component plus a narrow trailing column. For matrix-valued rows that
    column holds a colorbar shared by the whole row; otherwise it is hidden.

    Args:
        rinfo: Nested dict of run information. Reads the shape of
            ``rinfo["outputs"]["components"]`` (to get the number of
            components K) and the arrays in ``rinfo["inputs"]["spatial"]``.
        b: Batch index to plot.
        t: Iteration index to plot.
        size: Width/height of each panel, in figure inches.

    Returns:
        The matplotlib figure that was created.

    Raises:
        ValueError: If none of the known input names is present.
    """
    _, _, K, _, _, _ = rinfo["outputs"]["components"].shape
    colors = get_mask_plot_colors(K)
    inputs = rinfo["inputs"]["spatial"]
    # (input name, plotting function, whether the row gets a colorbar)
    rows = [
        ("image", show_img, False),
        ("components", show_img, False),
        ("dcomponents", functools.partial(show_img, norm=True), False),
        ("mask", show_mat, True),
        ("pred_mask", show_mat, True),
        ("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
        ("posterior", show_mat, True),
        ("log_prob", show_mat, True),
        ("counterfactual", show_mat, True),
        ("coordinates", show_coords, False),
    ]
    rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
    if not rows:
        raise ValueError(
            "inputs_plot: none of the known spatial inputs is present")
    nrows = len(rows)
    ncols = K + 1

    # squeeze=False keeps `axes` 2-D even when only a single input row is
    # present; otherwise plt.subplots returns a 1-D array and axes[r, k]
    # would raise IndexError.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * size - size * 0.9, nrows * size),
        gridspec_kw={"width_ratios": [1] * K + [0.1]},
        squeeze=False,
    )
    for r, (name, plot_fn, make_cbar) in enumerate(rows):
        values = inputs[name]
        row_axes = axes[r]
        row_axes[0].set_ylabel(name)
        if make_cbar:
            # One colour range shared across the whole row.
            vmin = np.min(values[b, t])
            vmax = np.max(values[b, t])
            if np.abs(vmin - vmax) < 1e-6:
                # Widen a degenerate (constant) range so the colour scale
                # is not collapsed to a single value.
                vmin -= 0.1
                vmax += 0.1
            plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
        # Inputs with a singleton component axis are shared by all
        # components: the same map is drawn in every column, in black.
        shared = values.shape[2] == 1
        mappable = None
        for k in range(K):
            if shared:
                m = values[b, t, 0]
                color = (0.0, 0.0, 0.0)
            else:
                m = values[b, t, k]
                color = colors[k]
            mappable = plot_fn(m, ax=row_axes[k], color=color)
        if make_cbar and mappable is not None:
            fig.colorbar(mappable, cax=row_axes[K])
        else:
            # No colorbar (or nothing was drawn): hide the trailing column.
            row_axes[K].set_visible(False)
    for k in range(K):
        axes[0, k].set_title("Component {}".format(k + 1))

    plt.subplots_adjust(hspace=0.05, wspace=0.05)
    return fig
Beispiel #3
0
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
    """Plot how the model's outputs evolve over iterations for one example.

    The grid has one row per iteration and a final summary row:

    * Iteration rows: the reconstruction, the combined predicted mask, then
      each component (drawn with its colour from ``get_mask_plot_colors``,
      and optionally with its predicted mask).
    * Final row: the input image, the ground-truth mask, and the
      last-iteration mask logits of every component. The logits share one
      colour range.

    Args:
        rinfo: Nested dict of run information. Reads ``rinfo["data"]["image"]``
            and ``["true_mask"]``, and ``rinfo["outputs"]["recons"]``,
            ``["pred_mask"]``, ``["pred_mask_logits"]`` and ``["components"]``.
        b: Batch index to plot.
        mask_components: If True, each component panel is passed that
            component's predicted mask via ``mask=``.
        size: Width/height of each panel, in figure inches.

    Returns:
        The matplotlib figure that was created.
    """
    data = rinfo["data"]
    image = data["image"][b]
    true_mask = data["true_mask"][b]
    outputs = rinfo["outputs"]
    recons = outputs["recons"][b]
    pred_mask = outputs["pred_mask"][b]
    pred_mask_logits = outputs["pred_mask_logits"][b]
    components = outputs["components"][b]

    num_iters, num_slots, _, _, _ = components.shape
    palette = get_mask_plot_colors(num_slots)
    n_rows, n_cols = num_iters + 1, num_slots + 2
    fig, axes = plt.subplots(nrows=n_rows,
                             ncols=n_cols,
                             figsize=(n_cols * size, n_rows * size))

    for it in range(num_iters):
        row = axes[it]
        show_img(recons[it, 0], ax=row[0])
        show_mask(pred_mask[it, ..., 0], ax=row[1])
        row[0].set_ylabel("iter {}".format(it))
        for slot in range(num_slots):
            slot_mask = pred_mask[it, slot] if mask_components else None
            show_img(components[it, slot],
                     ax=row[slot + 2],
                     color=palette[slot],
                     mask=slot_mask)

    axes[0, 0].set_title("Reconstruction")
    axes[0, 1].set_title("Mask")

    # Summary row below the last iteration.
    summary = axes[num_iters]
    show_img(image[0], ax=summary[0])
    show_mask(true_mask[0, ..., 0], ax=summary[1])
    final_logits = pred_mask_logits[num_iters - 1]
    vmin = np.min(final_logits)
    vmax = np.max(final_logits)

    for slot in range(num_slots):
        axes[0, slot + 2].set_title("Component {}".format(slot + 1))
        show_mat(final_logits[slot], ax=summary[slot + 2], vmin=vmin, vmax=vmax)
        summary[slot + 2].set_xlabel(
            "Mask Logits for\nComponent {}".format(slot + 1))
    summary[0].set_xlabel("Input Image")
    summary[1].set_xlabel("Ground Truth Mask")

    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig
Beispiel #4
0
def show_coords(m, ax):
    """Render a coordinate map as an RGB image.

    The values of ``m`` are min-max normalised over the whole array. The last
    axis (one channel per coordinate) is then mixed into RGB with the palette
    returned by ``get_mask_plot_colors(m.shape[-1])``.

    Args:
        m: Array whose last axis holds the coordinate channels. It is
            presumably (H, W, D); TODO confirm against callers.
        ax: matplotlib Axes to draw on.

    Returns:
        The image artist returned by ``ax.imshow``.
    """
    vmin, vmax = np.min(m), np.max(m)
    span = vmax - vmin
    m = m - vmin
    # A constant map would otherwise divide by zero and yield NaNs.
    # Leave it as all zeros instead.
    if span > 0:
        m = m / span
    color_conv = get_mask_plot_colors(m.shape[-1])
    color_mask = np.dot(m, color_conv)
    # Mixing several channels can push the sum above 1 (assuming palette
    # entries lie in [0, 1]). Clip as show_mask does so imshow receives
    # valid float RGB data.
    return ax.imshow(color_mask.clip(0.0, 1.0), interpolation="nearest")
Beispiel #5
0
def show_mask(m, ax):
    """Render a stack of per-component masks as a single RGB image.

    Each pixel's colour is the mask-weighted sum of the component colours
    from ``get_mask_plot_colors``, clipped to [0, 1].

    Args:
        m: Mask stack with the component axis first and two spatial axes
            after it. It is presumably (K, H, W); TODO confirm.
        ax: matplotlib Axes to draw on.

    Returns:
        The image artist returned by ``ax.imshow``.
    """
    palette = get_mask_plot_colors(m.shape[0])
    # Move the component axis last so each pixel's K weights mix the K colours.
    per_pixel = np.transpose(m, (1, 2, 0))
    rgb = np.clip(np.dot(per_pixel, palette), 0.0, 1.0)
    return ax.imshow(rgb, interpolation="nearest")