Пример #1
0
def diagnostic(results, tester, protocols):
    """
    Plot residual diagnostics and a current cross section for several fits.

    The first figure has one column per entry of ``results`` and three rows:
    the residual image, ``absfft`` of the residual, and a log-scale histogram
    of the residual values with a zero-mean Gaussian of standard deviation
    ``tester.sigma`` drawn on top. The second figure overlays one row of the
    second component of ``curl`` for ``tester.g`` and for every ``'gsol'``.

    Parameters
    ----------
    results : dict
        Maps a label to a dict providing the entries 'residual' and 'gsol'.
    tester : object
        Must expose the attributes ``sigma`` and ``g``.
    protocols : iterable of dict
        Paired in order with ``results``; each entry must provide 'sigma',
        which is shown in the column title.

    Returns
    -------
    (fig, axes), (fig2, axes2) : tuple of tuples
        Figure/axes pair of the residual grid and of the cross-section plot.
        ``axes`` is always a 2D array with shape (3, len(results)).
    """
    s = tester.sigma
    g = tester.g
    # squeeze=False keeps ``axes`` two-dimensional even for a single result,
    # so axes[row] is always iterable in the loops below.
    fig, axes = plt.subplots(3, len(results), squeeze=False)
    for ax_res, ax_fft, pro, (k, v) in zip(axes[0], axes[1], protocols,
                                           results.items()):
        ax_res.matshow(v['residual'])
        ax_res.set_title(r"$\lambda\sigma$={}, ".format(pro['sigma']) + k)
        ax_res.axis('off')
        ax_fft.matshow(absfft(v['residual']))
        ax_fft.axis('off')
    for ax, v in zip(axes[2], results.values()):
        res = v['residual'].ravel()
        _, b, _ = ax.hist(res, bins=50, density=True)
        x = 0.5 * (b[1:] + b[:-1])  # bin centres
        ax.plot(x, np.exp(-0.5 * (x / s)**2) / np.sqrt(2 * np.pi * s**2))
        ax.set_title("Standard Deviation = {:.3f}".format(res.std()))
        ax.set_yscale('log')
    fig.suptitle('Noise sigma = {}'.format(s))

    fig2, axes2 = plt.subplots()
    mid = len(g) // 2
    axes2.plot(curl(g)[1][mid], label='Ground truth')
    for k, v in results.items():
        axes2.plot(curl(v['gsol'])[1][mid], label=k)
    # Attach the legend to the cross-section axes explicitly instead of
    # relying on pyplot's notion of the current axes.
    axes2.legend()
    return (fig, axes), (fig2, axes2)
Пример #2
0
def plotcurrents(gfield, cross=True):
    """
    Show a g-field next to the two components returned by ``curl``.

    Parameters
    ----------
    gfield : array_like
        Field passed to ``curl`` and displayed in the first panel.
    cross : bool, optional
        When true a fourth panel plots the middle row of the second
        component.

    Returns
    -------
    fig, axes
        The figure and its row of axes (3 or 4 panels).
    """
    jx, jy = curl(gfield)
    n_panels = 4 if cross else 3
    fig, axes = plt.subplots(1, n_panels, figsize=(8.9, 3))

    # Both current panels share one colour range.
    lo = min(jx.min(), jy.min())
    hi = max(jx.max(), jy.max())

    axes[0].matshow(gfield, cmap='copper')
    axes[0].set_title(r'$g$-field')
    panels = ((axes[1], jx, r'$j_x$'), (axes[2], jy, r'$j_y$'))
    for ax, component, label in panels:
        ax.matshow(component, vmin=lo, vmax=hi)
        ax.set_title(label)

    if cross:
        mid_row = jy.shape[0] // 2
        axes[3].plot(jy[mid_row])
        axes[3].set_title(r'$j_y$ cross-section')

    simplify(axes)
    return fig, axes
Пример #3
0
def make_fake_data():
    """
    Create 'fake_data.npz' in this module's directory unless it exists.

    Loads the mask file 'fake_data_hallprobe_interpolated.npy' (calling
    ``make_mask()`` first when that file is missing), builds a
    ``ResistorNetworkModel`` with a ``GaussianKernel`` attached, sets its
    'J_ext' parameter and stores several of the model's arrays together
    with the generating parameters via ``np.savez``.

    Returns
    -------
    None
    """
    here = os.path.dirname(os.path.abspath(__file__))
    outfile = os.path.join(here, 'fake_data.npz')
    maskfile = os.path.join(here, 'fake_data_hallprobe_interpolated.npy')

    # Nothing to do when the data set has already been generated.
    if os.path.exists(outfile):
        print('"fake_data.npz" exists, skipping generation')
        return
    print('creating "fake_data.npz"')

    if not os.path.exists(maskfile):
        print("generating mask")
        make_mask()
    mask = np.load(maskfile)

    grid_shape = (300, 200)  # (Ly, Lx)
    y_by_x_ratio = 0.5

    # Alternative PSF parameters that were left disabled in the source:
    # [3.26043651e+00, 3.40755272e+00, 5.82311678e+00]
    true_params = {
        'J_ext': np.array([1000]),
        'sigma': np.array([1.40803307e-02]),
    }
    true_params['psf_params'] = np.array([3., 6., 10.])

    fake_data_offset = [840 - 100, 185]

    netmodel = ResistorNetworkModel(
        mask,
        phi_offset=fake_data_offset,
        gshape=grid_shape,
        electrodes=[50, 550],
    )
    kernel = GaussianKernel(
        mask.shape,
        params=true_params['psf_params'],
        rxy=1. / y_by_x_ratio,
    )
    netmodel.kernel = kernel
    netmodel.updateParams('J_ext', np.array([1000]))

    # NOTE(review): the result of this call is never used; it is retained so
    # the sequence of calls stays as it was. Confirm whether it can be dropped.
    _jx, _jy = curl(netmodel.g_ext, dx=kernel.rxy)

    payload = {
        'offset': fake_data_offset,
        'psf_params': true_params['psf_params'],
        'J_ext': true_params['J_ext'],
        'all_g': netmodel.gfield,
        'unitJ_flux': netmodel._unitJ_flux,
        'image_g': netmodel.g_ext,
        'image_flux': netmodel.ext_flux,
    }
    np.savez(outfile, **payload)
Пример #4
0
def diagnostic(kernel,
               ref_g,
               g_sol,
               ref_flux,
               netmodel,
               sigma,
               gamma,
               asp=0.5,
               title=None):
    """
    Create fancy diagnostic plot for analyzing fake data current reconstruction

    The figure is a 3x4 grid. Column 0 shows the reference flux and currents,
    column 1 the reconstructed ones, column 2 the residual image and current
    cross sections, and column 3 the residual histogram and cross sections of
    g together with its total variation.

    Parameters
    ----------
    kernel : Kernel
        Kernel object used for making synthetic data and reconstruction
    ref_g : array_like
        reference (ground truth) g-field
    g_sol : array_like
        best fit g-field reconstruction
    ref_flux : array_like
        (noisy) data shown to the optimizer
    netmodel : ResistorNetworkModel
        model of external current; its ``g_ext`` is added to ``g_sol``
        before the recovered currents are computed
    sigma : float
        noise amplitude added to synthetic data
    gamma : float
        regularization strength (constant in front of TV term)
    asp : float, optional
        aspect ratio of images
    title : str, optional
        title of plot; a default mentioning ``gamma`` is used when omitted

    Returns
    -------
    fig, axes : tuple (matplotlib.figure, matplotlib.axes)
    """
    jx, jy = curl(g_sol + netmodel.g_ext)
    jx, jy = kernel.crop(jx), kernel.crop(jy)
    true_jx, true_jy = curl(ref_g)
    true_jx, true_jy = kernel.crop(true_jx), kernel.crop(true_jy)
    fit_flux = kernel.applyM(g_sol).real.reshape(kernel.Ly, -1)
    fig, axes = plt.subplots(3, 4, figsize=(21, 13))

    # Cut a window of about three PSF widths per side out of the shifted PSF
    # for the inset image.
    _, sx, sy = kernel.params
    shx, shy = int(3 * sx), int(3 * sy)
    ker = np.fft.fftshift(
        kernel.psf.real)[kernel.Ly_pad - shy:kernel.Ly_pad + shy,
                         kernel.Lx_pad - shx + 1:kernel.Lx_pad + shx]

    residuals = ref_flux - fit_flux
    sly, slx = np.s_[kernel.Ly_pad // 2, :], np.s_[:, kernel.Lx_pad // 2]
    tv_ref_g = total_variation(ref_g, rxy=kernel.rxy).reshape(kernel._padshape)
    tv_fit_g = total_variation(g_sol, rxy=kernel.rxy).reshape(kernel._padshape)

    reference = [ref_flux, true_jx, true_jy]
    ref_label = ["Reference flux", "True $J_x$", "True $J_y$"]
    reconst = [fit_flux, jx, jy]
    rec_label = ["Reconstructed flux", "Recovered $J_x$", "Recovered $J_y$"]

    # Reference and reconstruction share a symmetric colour range per row.
    jxlim = min(np.abs(jx).max(), np.abs(true_jx).max())
    jylim = min(np.abs(jy).max(), np.abs(true_jy).max())
    flim = min(np.abs(ref_flux).max(), np.abs(fit_flux).max())

    ref_lim = [flim, jxlim, jylim]
    rec_lim = [flim, jxlim, jylim]

    def noise_pdf(x):
        # Zero-mean Gaussian density with standard deviation ``sigma``.
        return np.exp(-(x / sigma)**2 / 2) / np.sqrt(2 * np.pi * sigma**2)

    for axrow in range(3):
        for axcol in range(4):
            axe = axes[axrow, axcol]
            if axcol == 0:  # reference
                lim = ref_lim[axrow]
                axe.matshow(reference[axrow], aspect=asp, vmin=-lim, vmax=lim)
                axe.axis("off")
                axe.set_title(ref_label[axrow], fontsize=30)
                if axrow == 1:
                    # Kernel inset, sized relative to the host image
                    yy, xx = reference[axrow].shape
                    yk, xk = ker.shape
                    py, px = yk / yy, xk / xx
                    insetax = inset_axes(
                        axe,
                        width="{}%".format(px * 100),
                        height="{}%".format(py * 100),
                        loc=8,
                    )
                    insetax.matshow(ker, cmap="Greys")
                    simpleaxis(insetax)
                    insetax.set_ylabel("PSF",
                                       fontsize=18,
                                       rotation="horizontal",
                                       labelpad=20,
                                       y=0.2)
                    # marks the column used for the J_x cross section
                    axe.axvline(kernel.Lx // 2, color="k", alpha=0.1)
                if axrow == 2:
                    # marks the row used for the J_y cross section
                    axe.axhline(kernel.Ly // 2, color="k", alpha=0.1)
            elif axcol == 1:  # reconstruction
                lim = rec_lim[axrow]
                axe.matshow(reconst[axrow], aspect=asp, vmin=-lim, vmax=lim)
                axe.axis("off")
                axe.set_title(rec_label[axrow], fontsize=30)
                if axrow == 1:
                    axe.axvline(kernel.Lx // 2, color="k", alpha=0.1)
                if axrow == 2:
                    axe.axhline(kernel.Ly // 2, color="k", alpha=0.1)
            elif axcol == 2:  # residual image, current cross sections
                if axrow == 0:
                    axe.matshow(residuals,
                                aspect=asp,
                                vmin=-4 * sigma,
                                vmax=4 * sigma)
                    axe.axis("off")
                    axe.set_title("Residuals", fontsize=30)
                elif axrow == 1:
                    axe.plot(true_jx[:, kernel.Lx // 2], label=r"Reference")
                    axe.plot(jx[:, kernel.Lx // 2], label=r"Reconstructed")
                    axe.set_title(r"$J_x$ cross section", fontsize=30)
                    axe.set_xlim([0, kernel.Ly])
                elif axrow == 2:
                    axe.plot(true_jy[kernel.Ly // 2], label=r"Reference")
                    axe.plot(jy[kernel.Ly // 2], label=r"Reconstructed")
                    axe.set_title(r"$J_y$ cross section", fontsize=30)
                    axe.set_xlim([0, kernel.Lx])
                if axrow in (1, 2):
                    # Legend and y-label belong to the line plots only; the
                    # residual image has no labelled artists and no axis.
                    axe.legend(loc="best", fontsize=16)
                    axe.set_ylabel("Current", fontsize=18)
            elif axcol == 3:  # histogram, g/TV cross sections
                if axrow == 0:
                    # Histogram of residuals against the noise density
                    hist, bins = np.histogram(residuals.ravel(),
                                              bins=60,
                                              density=True)
                    resp = 0.5 * (bins[1:] + bins[:-1])  # bin centres
                    axe.plot(resp, hist, label="Residuals")
                    axe.plot(resp, noise_pdf(resp), label="True noise")
                    axe.set_yscale("log")
                    axe.legend(loc="best", fontsize=16)
                    axe.set_title("Residual distribution", fontsize=30)

                elif axrow == 1 or axrow == 2:
                    sl = {1: slx, 2: sly}[axrow]
                    axe2 = axe.twinx()  # second y-axis for the TV curves

                    gslice, ref_g_slice = g_sol[sl], ref_g[sl]
                    axe.plot(gslice, label="Fit $g$", lw=4, c="k")
                    axe.plot(ref_g_slice,
                             label="True $g$",
                             lw=3,
                             c="r",
                             alpha=0.6)
                    axe.set_ylabel(r"$g$", fontsize=24)
                    axe.set_ylim([3 * gslice.min(), 1.4 * gslice.max()])

                    tv_fit_slice, tv_ref_slice = tv_fit_g[sl], tv_ref_g[sl]
                    axe2.plot(tv_fit_slice,
                              lw=2,
                              label="TV of fit $g$",
                              c="k",
                              alpha=0.4)
                    axe2.plot(
                        tv_ref_slice,
                        ":",
                        lw=3,
                        label="TV of true $g$",
                        c="r",
                        alpha=0.6,
                    )
                    axe2.set_ylabel("TV(g)", fontsize=24)
                    axe2.set_ylim([0, 2 * tv_ref_slice.max()])

                    if axrow == 1:
                        axe.set_title("Vertical Cross Section", fontsize=20)
                        # thin lines at kernel.py and kernel.py + kernel.Ly
                        axe.axvline(kernel.py, color="k", lw=0.4)
                        axe.axvline(kernel.py + kernel.Ly, color="k", lw=0.4)
                        axe.set_xlim([0, kernel.Ly_pad])
                        axe2.legend(loc="upper right", fontsize=16)
                        axe.legend(loc="lower left", fontsize=16)
                    else:
                        axe.set_title("Horizontal Cross Section", fontsize=20)
                        axe.axvline(kernel.px, color="k", lw=0.4)
                        axe.axvline(kernel.px + kernel.Lx, color="k", lw=0.4)
                        axe.set_xlim([0, kernel.Lx_pad])
                        axe2.legend(loc="upper right", fontsize=16)
                        axe.legend(loc="upper left", fontsize=16)

    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    title = (title if title is not None else
             r"Reconstruction with $\mu$ = {}".format(gamma))
    plt.suptitle(title, fontsize=25)
    return fig, axes
Пример #5
0
def fit_diagnostic(model, ref_data, asp, title=None):
    """
    Plot a 3x3 summary of a fit: flux images, recovered currents, residuals.

    Row 0 shows the measured flux, the reconstructed flux and their
    difference. Row 1 shows the two cropped current components and a
    log-scale histogram of the residuals against a Gaussian of width
    ``model.sigma``. Row 2 shows, in its middle panel, the first current
    component and the flux data averaged over two central columns; the
    other panels of that row are blank.

    Parameters
    ----------
    model : object
        Fitted model. This function reads ``g_sol``, ``gfield``, ``dx``,
        ``dy``, ``Lx``, ``Ly``, ``sigma``, ``kernel.applyM``,
        ``extmodel.ext_flux``, ``crop`` and ``linearModel.mu_reg``.
    ref_data : array_like
        Flux data the fit is compared against.
    asp : float
        Aspect ratio passed to every ``matshow`` call.
    title : str, optional
        Figure title; a default mentioning ``model.linearModel.mu_reg`` is
        used when omitted.

    Returns
    -------
    fig, axes : tuple (matplotlib.figure, matplotlib.axes)
    """
    jx, jy = curl(model.g_sol, model.dx, model.dy)
    fit_flux = (model.kernel.applyM(model.gfield).real +
                model.extmodel.ext_flux)
    # Floor division keeps the slice bounds integral; true division yields
    # floats on Python 3, which cannot be used as slice indices.
    mid_x = model.Lx // 2
    slx = np.s_[:, mid_x - 1:mid_x + 1]

    fig, axes = plt.subplots(3, 3, figsize=(18, 15))

    flux_arrays = [ref_data, fit_flux, ref_data - fit_flux]
    flux_lim = np.abs(np.array(flux_arrays)).max()
    flux_labels = ['Measured flux', 'Reconstructed flux', 'Residuals']

    j_arrays = [model.crop(jx), model.crop(jy)]
    j_labels = ['Recovered $J_x$', 'Recovered $J_y$']
    j_lim = np.abs(np.array(j_arrays)).max()

    def noise_pdf(x):
        # Zero-mean Gaussian density with standard deviation model.sigma.
        return (np.exp(-(x / model.sigma)**2 / 2) /
                np.sqrt(2 * np.pi * model.sigma**2))

    for axrow in range(3):
        for axcol in range(3):
            ax = axes[axrow, axcol]
            if axrow == 0:
                if axcol < 2:
                    ax.matshow(flux_arrays[axcol], aspect=asp,
                               vmin=-flux_lim, vmax=flux_lim)
                else:
                    # Residuals get their own range of four standard
                    # deviations so they are not washed out.
                    rlim = 4 * flux_arrays[axcol].std()
                    ax.matshow(flux_arrays[axcol], aspect=asp,
                               vmin=-rlim, vmax=rlim)
                ax.axis('off')
                ax.set_title(flux_labels[axcol], fontsize=30)
            elif axrow == 1:
                if axcol < 2:
                    ax.matshow(j_arrays[axcol], aspect=asp,
                               vmin=-j_lim, vmax=j_lim)
                    ax.axis('off')
                    ax.set_title(j_labels[axcol], fontsize=30)
                if axcol == 2:
                    res = flux_arrays[2]
                    # ``density`` replaces the ``normed`` keyword, which is
                    # no longer accepted by current NumPy releases.
                    hist, bins = np.histogram(
                        res, bins=60, density=True,
                        range=(-4 * res.std(), 4 * res.std()))
                    resp = 0.5 * (bins[1:] + bins[:-1])  # bin centres
                    ax.plot(resp, hist, label='Fit residuals')
                    ax.plot(resp, noise_pdf(resp), label='Estimated noise')
                    ax.legend(loc='lower center', fontsize=16)
                    ax.set_title("Residual distribution", fontsize=30)
                    ax.set_yscale('log')
                    ax.set_ylabel('log $P(r)$', fontsize=20)
                    ax.set_xlabel('$r$', fontsize=20)
            else:
                if axcol == 1:
                    ax2 = ax.twinx()  # second y-axis for the flux curve
                    lj = ax.plot(j_arrays[0][slx].mean(1), c='b',
                                 label='Horizontal current')
                    lf = ax2.plot(ref_data[slx].mean(1), c='g',
                                  label='Flux data')
                    # One combined legend for the lines of both y-axes.
                    lab = lj + lf
                    ax.legend(lab, [l.get_label() for l in lab],
                              loc='upper right', fontsize=16)
                    ax.set_title("Vertical Cross sections", fontsize=30)
                    ax.set_xlabel("Vertical distance", fontsize=20)
                    ax.set_ylabel("$J_x$", fontsize=20)
                    ax2.set_ylabel("Flux", fontsize=20)
                    ax.set_xlim([0, model.Ly])
                else:
                    ax.axis('off')
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    if title is None:
        title = r"Reconstruction with $\mu$={}".format(
            model.linearModel.mu_reg)
    plt.suptitle(title, fontsize=25)
    return fig, axes
Пример #6
0
def j_density(g):
    """Return the element-wise ``np.hypot`` of the components of ``curl(g)``."""
    components = curl(g)
    return np.hypot(*components)
Пример #7
0
 def currents(self):
     """Return ``curl`` applied to this object's ``gfield`` attribute."""
     field = self.gfield
     return curl(field)