def inspect_strong_order():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 9))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        # The Brownian motion must have the noise dimension m, not the state dimension d.
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))

        for dt in tqdm.tqdm(dts):
            # `ts` has two entries, so unpacking drops the initial state and keeps only the end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)
            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)
            del euler_mse_, srk_mse_

    # Divide the log-MSE by 2: MSE scales as the squared strong error, and
    # textbook strong orders are stated for the unsquared error.
    log = lambda x: np.log(np.array(x))
    euler_slope, _, _, _, _ = stats.linregress(log(dts), log(euler_mses_) / 2)
    srk_slope, _, _, _, _ = stats.linregress(log(dts), log(srk_mses_) / 2)

    plt.figure()
    plt.plot(dts, euler_mses_, label=f'euler(k={euler_slope:.4f})')
    plt.plot(dts, srk_mses_, label=f'srk(k={srk_slope:.4f})')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    # Save under a distinct name so this plot does not overwrite the one from `inspect_rate`.
    plt.savefig(os.path.join(img_dir, 'strong_order'))
    plt.close()
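
# Illustrative sketch (not part of the diagnostics): why dividing the log-MSE
# by 2 recovers the textbook strong order. If the strong error scales as
# C * dt^q, then MSE ~ (C * dt^q)^2, so log(MSE) / 2 = log(C) + q * log(dt)
# and the regression slope is exactly q. The names below are made up for the
# demo; q = 1.5 mimics SRK under additive noise.
def _demo_strong_order_slope():
    dts_demo = np.array([2.0 ** -i for i in range(1, 9)])
    mses_demo = (0.3 * dts_demo ** 1.5) ** 2  # MSE ~ (strong error)^2.
    slope, *_ = stats.linregress(np.log(dts_demo), np.log(mses_demo) / 2)
    assert abs(slope - 1.5) < 1e-6  # Slope equals the strong order q.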
def inspect_rate():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        # The Brownian motion must have the noise dimension m, not the state dimension d.
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))

        for dt in dts:
            # Only take the end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)
            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
def test_normality(self):
    """Kolmogorov-Smirnov test against the Brownian bridge law."""
    t0_, t1_ = 0.0, 1.0
    t0, t1 = torch.tensor([t0_, t1_])
    eps = 1e-2
    for _ in range(REPS):
        # Use the same endpoints for the whole batch, so all samples come from
        # the same distribution.
        w0_, w1_ = 0.0, npr.randn() * np.sqrt(t1_)
        w0 = torch.tensor(w0_).repeat(BATCH_SIZE)
        w1 = torch.tensor(w1_).repeat(BATCH_SIZE)
        bm = BrownianPath(t0=t0, w0=w0)
        bm.insert(t=t1, w=w1)

        # Query a point strictly inside (t0, t1); samples there should follow
        # the Brownian bridge distribution.
        t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps)
        samples = bm(t_)
        samples_ = samples.detach().numpy()

        mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_)
        std_ = np.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
        ref_dist = norm(loc=mean_, scale=std_)

        _, pval = kstest(samples_, ref_dist.cdf)
        self.assertGreaterEqual(pval, ALPHA)
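
# Illustrative sketch (not part of the test suite): the reference distribution
# above is the Brownian bridge law. Conditioning W_{t1} on a small window
# around w1 by rejection sampling reproduces the closed-form mean and standard
# deviation used in `test_normality`. All names here are made up for the demo.
def _demo_brownian_bridge_moments():
    t0_, t_, t1_ = 0.0, 0.3, 1.0
    w0_, w1_ = 0.0, 0.7
    n, tol = 2_000_000, 0.02
    mid = w0_ + npr.randn(n) * np.sqrt(t_ - t0_)  # W_t given W_{t0} = w0.
    end = mid + npr.randn(n) * np.sqrt(t1_ - t_)  # W_{t1} given W_t.
    bridge = mid[np.abs(end - w1_) < tol]         # Condition on W_{t1} ~= w1.
    mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_)
    std_ = np.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
    print(bridge.mean(), mean_)  # Both ~= 0.21.
    print(bridge.std(), std_)    # Both ~= 0.458.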
def inspect_rate():
    batch_size, d = 4096, 10
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)

    euler_mses_ = []
    milstein_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        # Scalar noise: a single Brownian motion shared across state dimensions.
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, 1).to(device))

        for dt in dts:
            # Only take the end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk')
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            milstein_mse = compute_mse(ys_milstein, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, milstein_mse_, srk_mse_ = to_numpy(euler_mse, milstein_mse, srk_mse)
            euler_mses_.append(euler_mse_)
            milstein_mses_.append(milstein_mse_)
            srk_mses_.append(srk_mse_)

    # Textbook strong orders to expect on the log-log plot for multiplicative
    # scalar noise: Euler 0.5, Milstein 1.0, SRK 1.5.
    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, milstein_mses_, label='milstein')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
def inspect_samples():
    batch_size, d, m = 32, 1, 5
    steps = 10
    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 3e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = AdditiveSDE(d=d, m=m).to(device)

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))
        ys_em = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk',
                        options={'trapezoidal_approx': False})
        # Treat a fine-step Euler solve on the same Brownian path as ground truth.
        ys_true = sdeint(sde, y0=y0, ts=ts, dt=1e-3, bm=bm, method='euler')

        ys_em = ys_em.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_true = ys_true.squeeze().t()
        ts_, ys_em_, ys_srk_, ys_true_ = to_numpy(ts, ys_em, ys_srk, ys_true)

    # Visualize sample paths.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    for i, (ys_em_i, ys_srk_i, ys_true_i) in enumerate(zip(ys_em_, ys_srk_, ys_true_)):
        plt.figure()
        plt.plot(ts_, ys_em_i, label='em')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_true_i, label='true')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()
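
# Shape bookkeeping for the transposes above (illustrative; the name is made
# up): `sdeint` returns a (T, batch, d) tensor, and with d == 1 the
# `squeeze().t()` pattern yields (batch, T), so each row plotted is one path.
def _demo_path_layout():
    ys = torch.zeros(10, 32, 1)  # (T, batch, d) as returned by sdeint.
    assert ys.squeeze().t().shape == (32, 10)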
def inspect_sample():
    batch_size, d = 32, 1
    steps = 100
    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 1e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)
    # Override so the solvers treat the single Brownian motion as scalar noise.
    sde.noise_type = "scalar"

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros_like(y0))
        ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
        ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk',
                        options={'trapezoidal_approx': False})
        ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

        ys_euler = ys_euler.squeeze().t()
        ys_milstein = ys_milstein.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_analytical = ys_analytical.squeeze().t()

        ts_, ys_euler_, ys_milstein_, ys_srk_, ys_analytical_ = to_numpy(
            ts, ys_euler, ys_milstein, ys_srk, ys_analytical)

    # Visualize sample paths.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)
    for i, (ys_euler_i, ys_milstein_i, ys_srk_i, ys_analytical_i) in enumerate(
            zip(ys_euler_, ys_milstein_, ys_srk_, ys_analytical_)):
        plt.figure()
        plt.plot(ts_, ys_euler_i, label='euler')
        plt.plot(ts_, ys_milstein_i, label='milstein')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_analytical_i, label='analytical')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()
def _test_gradient(self, problem, method, adaptive, rtol=1e-6, atol=1e-5):
    if method == 'euler' and adaptive:
        return

    bm = BrownianPath(t0=t0, w0=w0)

    with torch.no_grad():
        grad_outputs = torch.ones(batch_size, d).to(device)
        alt_grad = problem.analytical_grad(y0, t1, grad_outputs, bm)

    problem.zero_grad()
    _, yt = sdeint_adjoint(problem, y0, ts, bm=bm, method=method, dt=dt,
                           adaptive=adaptive, rtol=rtol, atol=atol)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()

    # Compare the analytically derived gradient against the adjoint gradient.
    adj_grad = torch.cat(tuple(p.grad for p in problem.parameters()))
    self.tensorAssertAllClose(alt_grad, adj_grad)
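
# A hedged companion check (assumed helper, not from the source; it reuses the
# module-level fixtures y0, ts, and the same two-entry `ts` unpacking as
# `_test_gradient`): adjoint gradients can also be compared against plain
# autograd through `sdeint` on the same problem and Brownian path, since the
# two should agree up to solver tolerance.
def _autograd_reference_grad(problem, bm, method, dt):
    problem.zero_grad()
    _, yt = sdeint(problem, y0, ts, bm=bm, method=method, dt=dt)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()
    return torch.cat(tuple(p.grad for p in problem.parameters()))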
def _test_forward_and_backward(sde):
    bm = BrownianPath(t0=t0, w0=w0)
    for method in methods:
        _test_forward(sde, bm, method=method)
        _test_backward(sde, bm, method=method)
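
# A sketch of the kind of helper `_test_forward_and_backward` relies on
# (hypothetical shape only; the real `_test_forward` and `_test_backward` are
# defined elsewhere in the source, and module-level y0, ts, dt, batch_size, d
# are assumed):
def _test_forward_sketch(sde, bm, method):
    with torch.no_grad():
        ys = sdeint(sde, y0, ts, bm=bm, method=method, dt=dt)
    assert ys.shape == (len(ts), batch_size, d)  # (T, batch, state) layout.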
def on_epoch_end(self, vis_n_sim=1024):
    img_path = os.path.join(train_dir, f'global_step_{self.current_epoch}.png')

    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
    fill_color = '#9ebcda'
    mean_color = '#4d004b'
    num_samples = len(sample_colors)

    vis_idx = np.random.permutation(vis_n_sim)
    eps = torch.randn(vis_n_sim, 1)
    bm = BrownianPath(t0=self.vis_span[0], w0=torch.zeros(vis_n_sim, 1))

    # -- Not used -- from the `show_prior` option in the original implementation:
    # zs = self.model.sample_p(vis_span=self.vis_span, n_sim=vis_n_sim, eps=eps, bm=bm).squeeze()
    # ts_vis_, zs_ = self.vis_span.cpu().numpy(), zs.cpu().numpy()
    # zs_ = np.sort(zs_, axis=1)

    zs = self.model.lsde.sample_q(vis_span=self.vis_span, n_sim=vis_n_sim, eps=eps, bm=bm).squeeze()
    samples = zs[:, vis_idx]
    s_span_vis_ = self.vis_span.cpu().detach().numpy()
    zs_ = zs.cpu().detach().numpy()
    samples_ = samples.cpu().detach().numpy()
    zs_ = np.sort(zs_, axis=1)

    with torch.no_grad():
        plt.subplot(frameon=False)

        # Percentile bands over the sorted sample paths.
        for alpha, percentile in zip(alphas, percentiles):
            idx = int((1 - percentile) / 2. * vis_n_sim)
            zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
            plt.fill_between(s_span_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

        plt.plot(s_span_vis_, zs_.mean(axis=1), color=mean_color)

        for j in range(num_samples):
            plt.plot(s_span_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0)

        # Drift vector field on a grid, drawn as arrows.
        num, ds = 12, 0.12
        s, x = torch.meshgrid(
            [torch.linspace(0.2, 1.8, num), torch.linspace(-1.5, 1.5, num)]
        )
        s, x = s.reshape(-1, 1).to(self.device), x.reshape(-1, 1).to(self.device)
        ftx = self.model.lsde.defunc.f(s=s, x=x)
        ftx = ftx.cpu().reshape(num, num)
        ds = torch.zeros(num, num).fill_(ds)
        dx = ftx * ds
        ds_, dx_ = ds.cpu().detach().numpy(), dx.cpu().detach().numpy()
        s_, x_ = s.cpu().detach().numpy(), x.cpu().detach().numpy()
        plt.quiver(s_, x_, ds_, dx_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)

        # Data.
        plt.scatter(self.s_span.cpu().numpy(), self.x_sample.cpu().numpy(),
                    marker='x', zorder=3, color='k', s=35)

        plt.ylim(ylims)
        plt.xlabel('$t$')
        plt.ylabel('$Y_t$')
        plt.tight_layout()
        plt.savefig(img_path, dpi=400)
        plt.close()
t0 = 0.0
t1 = 0.3
T = 5
batch_size = 16
dt = 1e-2
ts = torch.linspace(t0, t1, steps=T).to(device)
y0 = torch.ones(batch_size, d).to(device)

basic_sdes = (
    basic_sde.BasicSDE1(d=d).to(device),
    basic_sde.BasicSDE2(d=d).to(device),
    basic_sde.BasicSDE3(d=d).to(device),
    basic_sde.BasicSDE4(d=d).to(device),
)

bm_diagonal = BrownianPath(t0=ts[0], w0=torch.zeros(batch_size, d).to(device))
bm_general = BrownianPath(t0=ts[0], w0=torch.zeros(batch_size, m).to(device))
bm_scalar = BrownianPath(t0=ts[0], w0=torch.zeros(batch_size, 1).to(device))


class TestSdeint(TorchTestCase):

    def test_rename_methods(self):
        # Test renaming works with a subset of names when `logqp=False`.
        sde = basic_sde.CustomNamesSDE().to(device)
        ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
        self.assertEqual(ans.shape, (T, batch_size, d))

        # Test renaming works with a subset of names when `logqp=True`.
        sde = basic_sde.CustomNamesSDELogqp().to(device)
        ans = sdeint(sde, y0, ts, dt=dt,
                     names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
        self.assertEqual(ans[0].shape, (T, batch_size, d))
        self.assertEqual(ans[1].shape, (T - 1, batch_size))
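
# A sketch of what a renamed-methods SDE might look like (assumption: this
# mirrors how `basic_sde.CustomNamesSDE` exposes its drift as `forward`, so
# that `names={'drift': 'forward'}` can find it; the real class lives in
# `basic_sde`, and depending on the torchsde version an `sde_type` attribute
# or an SDEIto base class may also be required):
class RenamedDriftSDE(torch.nn.Module):
    noise_type = 'diagonal'

    def forward(self, t, y):  # Drift, located via names={'drift': 'forward'}.
        return torch.sin(y) + t

    def g(self, t, y):  # Diffusion keeps its canonical name.
        return 0.1 * torch.sigmoid(y)

# Usage would then be: sdeint(RenamedDriftSDE(), y0, ts, dt=dt,
#                             names={'drift': 'forward'})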
def _setUp(self, device=None):
    t0, t1 = torch.tensor([0., 1.]).to(device)
    # Unpacking along the leading dimension yields two (BATCH_SIZE, D) endpoints.
    w0, w1 = torch.randn([2, BATCH_SIZE, D]).to(device)
    self.t = torch.rand([]).to(device)
    self.bm = BrownianPath(t0=t0, w0=w0)
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = npr.permutation(vis_batch_size)
    # From https://colorbrewer2.org/.
    if args.color == "blue":
        sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
    num_samples = len(sample_colors)

    # Fix the random draws used in the plots, so frames are comparable across steps.
    eps = torch.randn(vis_batch_size, 1).to(device)
    bm = BrownianPath(t0=ts_vis[0], w0=torch.zeros(vis_batch_size, 1).to(device))

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = utils.LinearScheduler(iters=args.kl_anneal_iters)

    logp_metric = utils.EMAMetric()
    log_ratio_metric = utils.EMAMetric()
    loss_metric = utils.EMAMetric()

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_dir = os.path.join(args.train_dir, 'prior.png')
            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -idx]
                plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

            # `zorder` determines who's on top; the larger the more at the top.
            plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_dir, dpi=args.dpi)
            plt.close()
            logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0:
            img_path = os.path.join(args.train_dir, f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
                        plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0)

                if args.show_arrows:
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(0.2, 1.8, num).to(device),
                        torch.linspace(-1.5, 1.5, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(), t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_, y_, dt_, dy_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                plt.savefig(img_path, dpi=args.dpi)
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

            if args.save_ckpt:
                torch.save({'model': model.state_dict()},
                           os.path.join(ckpt_dir, f'global_step_{global_step}.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, log_ratio = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        # Drop the first and last time points, which are only used to penalize
        # the out-of-data region and spread uncertainty.
        zs = zs[1:-1]

        likelihood = {
            "laplace": Laplace(loc=zs, scale=args.scale),
            "normal": Normal(loc=zs, scale=args.scale),
        }[args.likelihood]
        logp = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)

        loss = -logp + log_ratio * kl_scheduler()
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logp_metric.step(logp)
        log_ratio_metric.step(log_ratio)
        loss_metric.step(loss)

        logging.info(
            f'global_step: {global_step}, '
            f'logp: {logp_metric.val():.3f}, log_ratio: {log_ratio_metric.val():.3f}, '
            f'loss: {loss_metric.val():.3f}'
        )
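
# A minimal sketch of the KL annealing assumed above: `kl_scheduler()` returns
# a weight that ramps from 0 to 1 over `iters` steps, so the `log_ratio` (KL)
# term is phased in gradually. This is the assumed behavior of
# `utils.LinearScheduler`, not a verbatim copy of it.
class LinearSchedulerSketch:
    def __init__(self, iters, maxval=1.0):
        self._iters = max(1, iters)
        self._val = 0.0
        self._maxval = maxval

    def step(self):
        self._val = min(self._maxval, self._val + self._maxval / self._iters)

    def __call__(self):
        return self._val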