Exemplo n.º 1
0
def compare(true_g,
            results,
            mirror_results,
            labels=(
                '(b) Gaussian\nwith mirror', '(c) TV\nwith mirror',
                '(d) TV with\nexternal model'
            )):
    """Plot the ground-truth |j| density against three reconstructions.

    The figure is a 2x4 image grid:

    * top row: ``j_density(true_g)`` followed by the densities of the three
      reconstructions, all on one shared grayscale range starting at 0;
    * bottom row: under each reconstruction, its signed error
      ``j_density(gsol) - j_density(true_g)`` on a symmetric diverging range
      (the bottom-left cell stays empty).

    Parameters
    ----------
    true_g
        Ground-truth solution; anything ``j_density`` accepts.
    results, mirror_results : dict
        Reconstruction results keyed by prior name. The panels show, in
        order, ``mirror_results['gaussian']['gsol']``,
        ``mirror_results['TV prior']['gsol']`` and
        ``results['TV prior']['gsol']``.
    labels : sequence of str, optional
        Titles of the three reconstruction panels, in the order above.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    grid : ImageGrid
        The 2x4 grid of image axes.
    """
    fig = plt.figure(figsize=(10., 4.5))
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(2, 4),
        axes_pad=0.01,
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
    )
    # One colorbar per row, drawn in the right-most column only; all image
    # axes have their ticks/frames hidden.
    for i in range(len(grid)):
        if i % 4 != 3:
            grid.cbar_axes[i].axis('off')
        grid[i].axis('off')

    true_j = j_density(true_g)
    gsols = [
        mirror_results['gaussian']['gsol'], mirror_results['TV prior']['gsol'],
        results['TV prior']['gsol']
    ]
    # Each density is used by both rows, so compute it only once.
    densities = [j_density(g) for g in gsols]

    # Top row: a single grayscale range covering truth and reconstructions.
    lim = max([true_j.max()] + [j.max() for j in densities])
    kwargs = {'cmap': 'gray_r', 'vmin': 0, 'vmax': lim}
    grid[0].matshow(true_j, **kwargs)
    grid[0].set_title("(a) Ground truth\nsample $|\\mathbf{j}|$", pad=0)
    for col, (lab, j) in enumerate(zip(labels, densities), start=1):
        im = grid[col].matshow(j, **kwargs)
        grid[col].set_title(lab, pad=0)
    grid.cbar_axes[3].colorbar(im)

    # Bottom row: symmetric limit taken from the errors alone. Seeding it
    # with the density limit (as before) made it always equal `lim` for
    # non-negative densities (|a - b| <= max(a, b)), washing out small
    # errors. NOTE(review): assumes j_density is non-negative, as implied by
    # the |j| title and vmin=0 above.
    errors = [j - true_j for j in densities]
    err_lim = max(max(abs(e.min()), e.max()) for e in errors)
    kwargs = {'cmap': 'RdBu_r', 'vmin': -err_lim, 'vmax': err_lim}
    for col, err in enumerate(errors, start=5):
        im = grid[col].matshow(err, **kwargs)
    grid[5].set_title("Reconstruction\nerror", pad=0)
    grid.cbar_axes[7].colorbar(im)

    grid.set_axes_pad((0.01, 0.5))
    return fig, grid
Exemplo n.º 2
0
def _plot_case_rows(grid, start, ncols, g_true, result, truth_title, letters):
    """Draw one test case into two consecutive grid rows starting at `start`.

    Grid index layout (``ncols`` = 1 + number of reconstructions):

    * ``start``                  -- ground-truth density
    * ``start + 1 + i``          -- reconstruction ``i``
    * ``start + ncols``          -- left empty (no error for the truth)
    * ``start + ncols + 1 + i``  -- error of reconstruction ``i``

    Panel letters are drawn from the shared iterator ``letters`` for the
    truth and reconstruction panels only, so lettering continues across
    successive calls.
    """
    truth = j_density(g_true)
    densities = [j_density(res['gsol']) for res in result.values()]

    # Truth and reconstructions share one grayscale range.
    vlim = get_j_density_range(g_true, result)
    grid[start].matshow(truth, cmap='gray_r', **vlim)
    grid[start].set_title("({}) ".format(next(letters)) + truth_title,
                          pad=0)
    for col, (label, dens) in enumerate(zip(result, densities),
                                        start=start + 1):
        im = grid[col].matshow(dens, cmap='gray_r', **vlim)
        grid[col].set_title("({}) ".format(next(letters)) + label, pad=0)
    if densities:  # `im` only exists if something was plotted
        grid.cbar_axes[start + ncols - 1].colorbar(im)

    # Errors get their own range, computed relative to the truth.
    vlim = get_j_density_range(g_true, result, truth)
    for i, dens in enumerate(densities):
        ax = grid[start + ncols + 1 + i]
        im = ax.matshow(dens - truth, cmap='RdBu', **vlim)
        if i == 0:
            ax.set_title("Reconstruction error", pad=0)
    if densities:
        grid.cbar_axes[start + 2 * ncols - 1].colorbar(im)


def compare_truth(uni_result, uni_tester, para_result, para_tester):
    """Compare reconstructions against a uniform and a parabolic ground truth.

    The figure is a ``4 x (1 + len(uni_result))`` image grid. Rows 0-1 show
    the uniform case and rows 2-3 the parabolic case. Each case is a row of
    ground truth plus reconstructions, followed by a row of reconstruction
    errors. Only the right-most colorbar of each row is drawn.

    Parameters
    ----------
    uni_result, para_result : dict
        Mapping ``label -> result`` where each result has a ``'gsol'`` entry.
        NOTE(review): the layout assumes both dicts have the same length.
    uni_tester, para_tester
        Objects exposing the ground-truth solution as ``.g``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    grid : ImageGrid
        The 4-row grid of image axes.
    """
    fig = plt.figure(figsize=(10., 11.))
    L = len(uni_result)
    ncols = L + 1
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(4, ncols),
        axes_pad=0.1,
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
    )

    # Hide every colorbar slot except the last column of each row.
    for i in range(len(grid)):
        if i % ncols != L:
            grid.cbar_axes[i].axis('off')

    # Panel letters (a), (b), ... continue from the uniform to the parabolic case.
    letters = iter(string.ascii_lowercase)
    _plot_case_rows(grid, 0, ncols, uni_tester.g, uni_result,
                    "Ground truth\nuniform $|\\mathbf{j}|$", letters)
    _plot_case_rows(grid, 2 * ncols, ncols, para_tester.g, para_result,
                    "Ground truth \nparabolic $|\\mathbf{j}|$", letters)

    for i in range(len(grid)):
        grid[i].axis('off')

    grid.set_axes_pad((0.1, 0.55))
    return fig, grid