def diagnostic(results, tester, protocols):
    """Plot residual diagnostics for a collection of reconstruction runs.

    NOTE(review): another function named ``diagnostic`` is defined later in
    this module and rebinds the name at import time, so this version is
    shadowed unless one of the two is renamed.

    Parameters
    ----------
    results : dict
        Maps a run label to a dict holding at least ``'residual'`` (2D
        residual array) and ``'gsol'`` (reconstructed g-field).
    tester : object
        Provides ``sigma`` (noise amplitude used for the reference Gaussian)
        and ``g`` (ground-truth g-field).
    protocols : sequence of dict
        One entry per run, in the same order as ``results``; the ``'sigma'``
        key of each entry appears in the column titles.

    Returns
    -------
    (fig, axes), (fig2, axes2)
        ``axes`` is always a 2D array of shape ``(3, len(results))``: residual
        images, ``absfft`` of each residual, and residual histograms compared
        with a zero-mean Gaussian of standard deviation ``tester.sigma``.
        ``fig2`` overlays the middle row of the second ``curl`` component of
        each reconstruction on that of the ground truth.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("results must not be empty")

    sigma = tester.sigma
    g = tester.g
    # squeeze=False keeps ``axes`` 2D even for a single run, so ``axes[row]``
    # is always iterable in the loops below.
    fig, axes = plt.subplots(3, len(results), squeeze=False)

    for ax, pro, (label, run) in zip(axes[0], protocols, results.items()):
        ax.matshow(run['residual'])
        ax.set_title(r"$\lambda\sigma$={}, ".format(pro['sigma']) + label)
        ax.axis('off')

    for ax, _pro, run in zip(axes[1], protocols, results.values()):
        ax.matshow(absfft(run['residual']))
        ax.axis('off')

    for ax, run in zip(axes[2], results.values()):
        res = run['residual'].ravel()
        _, bins, _ = ax.hist(res, bins=50, density=True)
        centres = 0.5 * (bins[1:] + bins[:-1])
        gaussian = (np.exp(-0.5 * (centres / sigma) ** 2)
                    / np.sqrt(2 * np.pi * sigma ** 2))
        ax.plot(centres, gaussian)
        ax.set_title("Standard Deviation = {:.3f}".format(res.std()))
        ax.set_yscale('log')

    fig.suptitle('Noise sigma = {}'.format(sigma))

    fig2, axes2 = plt.subplots()
    mid = len(g) // 2
    axes2.plot(curl(g)[1][mid], label='Ground truth')
    for label, run in results.items():
        axes2.plot(curl(run['gsol'])[1][mid], label=label)
    axes2.legend()
    return (fig, axes), (fig2, axes2)
def plotcurrents(gfield, cross=True):
    """Show a g-field next to the two current components derived from it.

    Parameters
    ----------
    gfield : array_like
        2D g-field; its currents are obtained with ``curl``.
    cross : bool, optional
        If True (default), add a fourth panel plotting the middle row of
        ``j_y``.

    Returns
    -------
    fig, axes
        The figure and the axes array created by ``plt.subplots``.
    """
    jx, jy = curl(gfield)
    ncols = 4 if cross else 3
    fig, axes = plt.subplots(1, ncols, figsize=(8.9, 3))

    # Both current panels share a single colour scale.
    lo = min(jx.min(), jy.min())
    hi = max(jx.max(), jy.max())

    axes[0].matshow(gfield, cmap='copper')
    axes[0].set_title(r'$g$-field')
    panels = ((axes[1], jx, r'$j_x$'), (axes[2], jy, r'$j_y$'))
    for panel, component, name in panels:
        panel.matshow(component, vmin=lo, vmax=hi)
        panel.set_title(name)

    if cross:
        middle_row = jy[jy.shape[0] // 2]
        axes[3].plot(middle_row)
        axes[3].set_title(r'$j_y$ cross-section')

    simplify(axes)
    return fig, axes
def make_fake_data():
    """Generate ``fake_data.npz`` next to this module unless it already exists.

    Loads the mask from ``fake_data_hallprobe_interpolated.npy`` (calling
    ``make_mask()`` first if that file is missing), builds a
    ``ResistorNetworkModel`` on it with a ``GaussianKernel`` of fixed PSF
    parameters, sets the external current ``J_ext`` and saves the offset,
    PSF parameters, ``J_ext`` and the model's fields with ``np.savez``.
    Prints a message and returns early when the output file is present.
    """
    dirname = os.path.dirname(os.path.abspath(__file__))
    outfile = os.path.join(dirname, 'fake_data.npz')
    maskfile = os.path.join(dirname, 'fake_data_hallprobe_interpolated.npy')

    if os.path.exists(outfile):
        print('"fake_data.npz" exists, skipping generation')
        return
    print('creating "fake_data.npz"')

    if not os.path.exists(maskfile):
        print("generating mask")
        make_mask()
    mask = np.load(maskfile)

    Ly, Lx = 300, 200
    y_by_x_ratio = 0.5
    true_params = {
        'J_ext': np.array([1000]),
        # NOTE(review): 'sigma' is neither used below nor written to the file.
        'sigma': np.array([1.40803307e-02]),
        'psf_params': np.array([3., 6., 10.]),
    }
    fake_data_offset = [840 - 100, 185]

    netmodel = ResistorNetworkModel(mask, phi_offset=fake_data_offset,
                                    gshape=(Ly, Lx), electrodes=[50, 550])
    kernel = GaussianKernel(mask.shape, params=true_params['psf_params'],
                            rxy=1. / y_by_x_ratio)
    netmodel.kernel = kernel
    # Use the same J_ext value that is saved below so model and file agree;
    # pass a copy so the saved array cannot be altered through the model.
    netmodel.updateParams('J_ext', true_params['J_ext'].copy())

    np.savez(outfile,
             offset=fake_data_offset,
             psf_params=true_params['psf_params'],
             J_ext=true_params['J_ext'],
             all_g=netmodel.gfield,
             unitJ_flux=netmodel._unitJ_flux,
             image_g=netmodel.g_ext,
             image_flux=netmodel.ext_flux)
def diagnostic(kernel, ref_g, g_sol, ref_flux, netmodel, sigma, gamma,
               asp=0.5, title=None):
    """
    Create fancy diagnostic plot for analyzing fake data current reconstruction

    Parameters
    ----------
    kernel : Kernel
        Kernel object used for making synthetic data and reconstruction
    ref_g : array_like
        reference (ground truth) g-field
    g_sol : array_like
        best fit g-field reconstruction
    ref_flux : array_like
        (noisy) data shown to the optimizer
    netmodel : ResistorNetworkModel
        model of external current
    sigma : float
        noise amplitude added to synthetic data
    gamma : float
        regularization strength (constant in front of TV term)
    asp : float, optional
        aspect ratio of images
    title : str, optional
        title of plot

    Returns
    -------
    fig, axes : tuple
        (matplotlib.figure, matplotlib.axes)

    Notes
    -----
    The figure is a 3x4 grid. Column 0: reference flux, J_x (with a PSF
    inset) and J_y. Column 1: the reconstructed counterparts. Column 2: the
    residual image and J_x / J_y cross sections. Column 3: the residual
    histogram against the expected Gaussian noise, and g / TV(g) cross
    sections through the padded g-field.
    """
    # The reconstructed currents include the external-current contribution.
    jx, jy = curl(g_sol + netmodel.g_ext)
    jx, jy = kernel.crop(jx), kernel.crop(jy)
    true_jx, true_jy = curl(ref_g)
    true_jx, true_jy = kernel.crop(true_jx), kernel.crop(true_jy)
    fit_flux = kernel.applyM(g_sol).real.reshape(kernel.Ly, -1)
    residuals = ref_flux - fit_flux

    fig, axes = plt.subplots(3, 4, figsize=(21, 13))

    # Window of half-width 3*sx, 3*sy around (Ly_pad, Lx_pad) of the shifted
    # PSF, shown as an inset.  NOTE(review): the x range starts one pixel
    # later than the y range; confirm whether that asymmetry is intended.
    _, sx, sy = kernel.params
    shx, shy = int(3 * sx), int(3 * sy)
    ker = np.fft.fftshift(kernel.psf.real)[
        kernel.Ly_pad - shy:kernel.Ly_pad + shy,
        kernel.Lx_pad - shx + 1:kernel.Lx_pad + shx,
    ]

    # Middle row / middle column of the padded g-field for column-3 slices.
    sly, slx = np.s_[kernel.Ly_pad // 2, :], np.s_[:, kernel.Lx_pad // 2]
    tv_ref_g = total_variation(ref_g, rxy=kernel.rxy).reshape(kernel._padshape)
    tv_fit_g = total_variation(g_sol, rxy=kernel.rxy).reshape(kernel._padshape)

    reference = [ref_flux, true_jx, true_jy]
    ref_label = ["Reference flux", "True $J_x$", "True $J_y$"]
    reconst = [fit_flux, jx, jy]
    rec_label = ["Reconstructed flux", "Recovered $J_x$", "Recovered $J_y$"]
    # Symmetric colour limits shared by each reference/reconstruction pair.
    lims = [
        min(np.abs(ref_flux).max(), np.abs(fit_flux).max()),
        min(np.abs(jx).max(), np.abs(true_jx).max()),
        min(np.abs(jy).max(), np.abs(true_jy).max()),
    ]

    def draw_slice_guide(axe, axrow):
        # Faint line marking where the column-2 cross section is taken.
        if axrow == 1:
            axe.axvline(kernel.Lx // 2, color="k", alpha=0.1)
        elif axrow == 2:
            axe.axhline(kernel.Ly // 2, color="k", alpha=0.1)

    def gaussian(x):
        # Zero-mean normal density with standard deviation ``sigma``.
        return np.exp(-(x / sigma) ** 2 / 2) / np.sqrt(2 * np.pi * sigma ** 2)

    for axrow in range(3):
        for axcol in range(4):
            axe = axes[axrow, axcol]
            if axcol == 0:  # reference
                lim = lims[axrow]
                axe.matshow(reference[axrow], aspect=asp, vmin=-lim, vmax=lim)
                axe.axis("off")
                axe.set_title(ref_label[axrow], fontsize=30)
                if axrow == 1:
                    # PSF inset sized by the window/image pixel ratio.
                    yy, xx = reference[axrow].shape
                    yk, xk = ker.shape
                    py, px = yk / yy, xk / xx
                    insetax = inset_axes(
                        axe,
                        width="{}%".format(px * 100),
                        height="{}%".format(py * 100),
                        loc=8,
                    )
                    insetax.matshow(ker, cmap="Greys")
                    simpleaxis(insetax)
                    insetax.set_ylabel("PSF", fontsize=18,
                                       rotation="horizontal",
                                       labelpad=20, y=0.2)
                draw_slice_guide(axe, axrow)
            elif axcol == 1:  # reconstruction
                lim = lims[axrow]
                axe.matshow(reconst[axrow], aspect=asp, vmin=-lim, vmax=lim)
                axe.axis("off")
                axe.set_title(rec_label[axrow], fontsize=30)
                draw_slice_guide(axe, axrow)
            elif axcol == 2:  # residual image, current cross sections
                if axrow == 0:
                    axe.matshow(residuals, aspect=asp,
                                vmin=-4 * sigma, vmax=4 * sigma)
                    axe.axis("off")
                    axe.set_title("Residuals", fontsize=30)
                else:
                    if axrow == 1:
                        axe.plot(true_jx[:, kernel.Lx // 2], label=r"Reference")
                        axe.plot(jx[:, kernel.Lx // 2], label=r"Reconstructed")
                        axe.set_title(r"$J_x$ cross section", fontsize=30)
                        axe.set_xlim([0, kernel.Ly])
                    else:
                        axe.plot(true_jy[kernel.Ly // 2], label=r"Reference")
                        axe.plot(jy[kernel.Ly // 2], label=r"Reconstructed")
                        axe.set_title(r"$J_y$ cross section", fontsize=30)
                        axe.set_xlim([0, kernel.Lx])
                    # Only the cross-section panels carry labelled lines.
                    axe.legend(loc="best", fontsize=16)
                    axe.set_ylabel("Current", fontsize=18)
            else:  # axcol == 3: histogram, g/TV cross sections
                if axrow == 0:
                    hist, bins = np.histogram(residuals.ravel(), bins=60,
                                              density=True)
                    centres = 0.5 * (bins[1:] + bins[:-1])
                    axe.plot(centres, hist, label="Residuals")
                    axe.plot(centres, gaussian(centres), label="True noise")
                    axe.set_yscale("log")
                    axe.legend(loc="best", fontsize=16)
                    axe.set_title("Residual distribution", fontsize=30)
                else:
                    sl = slx if axrow == 1 else sly
                    axe2 = axe.twinx()
                    gslice, ref_g_slice = g_sol[sl], ref_g[sl]
                    axe.plot(gslice, label="Fit $g$", lw=4, c="k")
                    axe.plot(ref_g_slice, label="True $g$", lw=3, c="r",
                             alpha=0.6)
                    axe.set_ylabel(r"$g$", fontsize=24)
                    axe.set_ylim([3 * gslice.min(), 1.4 * gslice.max()])
                    tv_fit_slice, tv_ref_slice = tv_fit_g[sl], tv_ref_g[sl]
                    axe2.plot(tv_fit_slice, lw=2, label="TV of fit $g$",
                              c="k", alpha=0.4)
                    axe2.plot(tv_ref_slice, ":", lw=3, label="TV of true $g$",
                              c="r", alpha=0.6)
                    axe2.set_ylabel("TV(g)", fontsize=24)
                    axe2.set_ylim([0, 2 * tv_ref_slice.max()])
                    # Thin lines at py / py + Ly (px / px + Lx); presumably
                    # the bounds of the unpadded region -- confirm in Kernel.
                    if axrow == 1:
                        axe.set_title("Vertical Cross Section", fontsize=20)
                        axe.axvline(kernel.py, color="k", lw=0.4)
                        axe.axvline(kernel.py + kernel.Ly, color="k", lw=0.4)
                        axe.set_xlim([0, kernel.Ly_pad])
                        axe2.legend(loc="upper right", fontsize=16)
                        axe.legend(loc="lower left", fontsize=16)
                    else:
                        axe.set_title("Horizontal Cross Section", fontsize=20)
                        axe.axvline(kernel.px, color="k", lw=0.4)
                        axe.axvline(kernel.px + kernel.Lx, color="k", lw=0.4)
                        axe.set_xlim([0, kernel.Lx_pad])
                        axe2.legend(loc="upper right", fontsize=16)
                        axe.legend(loc="upper left", fontsize=16)

    if title is None:
        # Raw string: "\m" is not a valid escape sequence.
        title = r"Reconstruction with $\mu$ = {}".format(gamma)
    fig.suptitle(title, fontsize=25)
    return fig, axes
def fit_diagnostic(model, ref_data, asp, title=None):
    """Plot a 3x3 diagnostic summary of a reconstruction fitted to flux data.

    Row 0: measured flux, reconstructed flux and residuals.
    Row 1: recovered J_x and J_y, and a histogram of the residuals compared
    with a zero-mean Gaussian of standard deviation ``model.sigma``.
    Row 2 (middle panel only): J_x and the measured flux averaged over the
    two central columns, on twin y axes.

    Parameters
    ----------
    model : object
        Fitted model. Reads ``g_sol``, ``gfield``, ``dx``, ``dy``, ``Lx``,
        ``Ly``, ``sigma``, ``kernel``, ``extmodel.ext_flux`` and
        ``linearModel.mu_reg``, and calls ``crop``.
    ref_data : array_like
        2D measured flux image, compared pixel-wise with the reconstructed
        flux.
    asp : float
        Aspect ratio passed to ``matshow``.
    title : str, optional
        Figure title; defaults to one showing ``model.linearModel.mu_reg``.

    Returns
    -------
    fig, axes : tuple
        The figure and its (3, 3) array of axes.
    """
    jx, jy = curl(model.g_sol, model.dx, model.dy)
    fit_flux = model.kernel.applyM(model.gfield).real + model.extmodel.ext_flux
    residuals = ref_data - fit_flux
    # Two central columns.  Integer division is required: ``Lx / 2`` is a
    # float on Python 3 and cannot be used as a slice bound.
    mid = model.Lx // 2
    slx = np.s_[:, mid - 1:mid + 1]

    fig, axes = plt.subplots(3, 3, figsize=(18, 15))

    flux_arrays = [ref_data, fit_flux, residuals]
    flux_labels = ['Measured flux', 'Reconstructed flux', 'Residuals']
    flux_lim = max(np.abs(arr).max() for arr in flux_arrays)
    j_arrays = [model.crop(jx), model.crop(jy)]
    j_labels = ['Recovered $J_x$', 'Recovered $J_y$']
    j_lim = max(np.abs(arr).max() for arr in j_arrays)

    # Row 0: data and fit share one symmetric colour scale; the residuals
    # use their own +-4 standard deviation scale.
    for col, (ax, arr, label) in enumerate(zip(axes[0], flux_arrays,
                                               flux_labels)):
        lim = flux_lim if col < 2 else 4 * arr.std()
        ax.matshow(arr, aspect=asp, vmin=-lim, vmax=lim)
        ax.axis('off')
        ax.set_title(label, fontsize=30)

    # Row 1, first two panels: current components on a shared scale.
    for ax, arr, label in zip(axes[1, :2], j_arrays, j_labels):
        ax.matshow(arr, aspect=asp, vmin=-j_lim, vmax=j_lim)
        ax.axis('off')
        ax.set_title(label, fontsize=30)

    # Row 1, last panel: residual histogram against the noise model.
    ax = axes[1, 2]
    rstd = residuals.std()
    hist, bins = np.histogram(residuals, bins=60, density=True,
                              range=(-4 * rstd, 4 * rstd))
    centres = 0.5 * (bins[1:] + bins[:-1])
    noise = (np.exp(-(centres / model.sigma) ** 2 / 2)
             / np.sqrt(2 * np.pi * model.sigma ** 2))
    ax.plot(centres, hist, label='Fit residuals')
    ax.plot(centres, noise, label='Estimated noise')
    ax.legend(loc='lower center', fontsize=16)
    ax.set_title("Residual distribution", fontsize=30)
    ax.set_yscale('log')
    ax.set_ylabel('log $P(r)$', fontsize=20)
    ax.set_xlabel('$r$', fontsize=20)

    # Row 2: only the middle panel is used.
    axes[2, 0].axis('off')
    axes[2, 2].axis('off')
    ax = axes[2, 1]
    ax2 = ax.twinx()
    j_line = ax.plot(j_arrays[0][slx].mean(1), c='b',
                     label='Horizontal current')
    f_line = ax2.plot(ref_data[slx].mean(1), c='g', label='Flux data')
    lines = j_line + f_line
    ax.legend(lines, [line.get_label() for line in lines],
              loc='upper right', fontsize=16)
    ax.set_title("Vertical Cross sections", fontsize=30)
    ax.set_xlabel("Vertical distance", fontsize=20)
    ax.set_ylabel("$J_x$", fontsize=20)
    ax2.set_ylabel("Flux", fontsize=20)
    ax.set_xlim([0, model.Ly])

    if title is None:
        # Raw string: "\m" is not a valid escape sequence.
        title = r"Reconstruction with $\mu$={}".format(model.linearModel.mu_reg)
    fig.suptitle(title, fontsize=25)
    return fig, axes
def j_density(g):
    """Return the pointwise magnitude of the current derived from ``g``.

    The two components come from ``curl(g)`` and are combined with
    ``np.hypot``, i.e. ``sqrt(jx**2 + jy**2)`` element-wise.
    """
    jx, jy = curl(g)
    return np.hypot(jx, jy)
def currents(self):
    """Return the current components computed by ``curl`` from ``self.gfield``."""
    gfield = self.gfield
    return curl(gfield)