def test_entropy_determinism(random_order, device, levy_area_approximation, return_U, return_A):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(msg="CUDA not available.")

    t0, t1 = 0.0, 1.0
    entropy = 56789
    points1 = torch.rand(1000)
    points2 = torch.rand(1000)
    outs = []
    tol = 1e-6 if random_order else 0.
    bm = torchsde.BrownianInterval(t0=t0,
                                   t1=t1,
                                   size=(),
                                   device=device,
                                   levy_area_approximation=levy_area_approximation,
                                   entropy=entropy,
                                   tol=tol,
                                   halfway_tree=random_order)
    for point1, point2 in zip(points1, points2):
        point1, point2 = sorted([point1, point2])
        outs.append(bm(point1, point2, return_U=return_U, return_A=return_A))

    bm = torchsde.BrownianInterval(t0=t0,
                                   t1=t1,
                                   size=(),
                                   device=device,
                                   levy_area_approximation=levy_area_approximation,
                                   entropy=entropy,
                                   tol=tol,
                                   halfway_tree=random_order)
    if random_order:
        perm = torch.randperm(1000)
        points1 = points1[perm]
        points2 = points2[perm]
        outs = [outs[i.item()] for i in perm]
    for point1, point2, out in zip(points1, points2, outs):
        point1, point2 = sorted([point1, point2])
        out_ = bm(point1, point2, return_U=return_U, return_A=return_A)

        # Assert equal
        if torch.is_tensor(out):
            out = (out,)
        if torch.is_tensor(out_):
            out_ = (out_,)
        for outi, outi_ in zip(out, out_):
            if torch.is_tensor(outi):
                assert (outi == outi_).all()
            else:
                assert outi == outi_
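# A minimal sketch of the property exercised above, using only constructor
# arguments already shown in this test: two independently constructed
# BrownianIntervals seeded with the same `entropy` return bit-identical
# samples for the same query. (Illustration only; not a collected test.)
def _example_entropy_determinism():
    bm1 = torchsde.BrownianInterval(t0=0., t1=1., size=(), entropy=56789)
    bm2 = torchsde.BrownianInterval(t0=0., t1=1., size=(), entropy=56789)
    assert (bm1(0.2, 0.7) == bm2(0.2, 0.7)).all()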
def test_reversibility(sde_cls):
    batch_size = 32
    state_size = 4
    t_size = 20
    dt = 0.1

    brownian_size = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: state_size,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]

    class MinusSDE(torch.nn.Module):
        def __init__(self, sde):
            super().__init__()  # Must run before assigning attributes on an nn.Module.
            self.noise_type = sde.noise_type
            self.sde_type = sde.sde_type
            self.f = lambda t, y: -sde.f(-t, y)
            self.g = lambda t, y: -sde.g(-t, y)

    sde = sde_cls(d=state_size, m=brownian_size, sde_type='stratonovich')
    minus_sde = MinusSDE(sde)
    y0 = torch.full((batch_size, state_size), 0.1)
    ts = torch.linspace(0, (t_size - 1) * dt, t_size)
    bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(batch_size, brownian_size))
    ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=dt, extra=True)

    backward_ts = -ts.flip(0)
    backward_ys = torchsde.sdeint(minus_sde, ys[-1], backward_ts, bm=torchsde.ReverseBrownian(bm),
                                  method='reversible_heun', dt=dt, extra_solver_state=(-f, -g, z))
    backward_ys = backward_ys.flip(0)

    torch.testing.assert_allclose(ys, backward_ys, rtol=1e-6, atol=1e-6)
def test_consistency(device, levy_area_approximation):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(msg="CUDA not available.")

    t0, t1 = 0.0, 1.0
    for _ in range(REPS):
        bm = torchsde.BrownianInterval(t0=t0,
                                       t1=t1,
                                       size=(LARGE_BATCH_SIZE,),
                                       device=device,
                                       levy_area_approximation=levy_area_approximation)
        for _ in range(MEDIUM_REPS):
            ta, t_, tb = sorted(npr.uniform(low=t0, high=t1, size=(3,)))
            if levy_area_approximation == 'none':
                W = bm(ta, tb)
                W1 = bm(ta, t_)
                W2 = bm(t_, tb)
            else:
                W, U = bm(ta, tb, return_U=True)
                W1, U1 = bm(ta, t_, return_U=True)
                W2, U2 = bm(t_, tb, return_U=True)
            # Increments over adjacent intervals must add up exactly.
            torch.testing.assert_allclose(W1 + W2, W, rtol=1e-6, atol=1e-6)
            if levy_area_approximation != 'none':
                # Additivity for the space-time integral:
                # U(ta, tb) = U(ta, t_) + U(t_, tb) + (tb - t_) * W(ta, t_).
                torch.testing.assert_allclose(U1 + U2 + (tb - t_) * W1, U, rtol=1e-6, atol=1e-6)
def test_normality_simple(device, levy_area_approximation):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(msg="CUDA not available.")

    t0, t1 = 0.0, 1.0
    for _ in range(REPS):
        base_W = torch.tensor(npr.randn(), device=device).repeat(LARGE_BATCH_SIZE)
        bm = torchsde.BrownianInterval(t0=t0,
                                       t1=t1,
                                       W=base_W,
                                       levy_area_approximation=levy_area_approximation)

        t_ = npr.uniform(low=t0, high=t1)
        W = bm(t0, t_)

        # Conditioned on the overall increment base_W, W(t0, t_) is a Brownian
        # bridge sample with the mean and standard deviation below.
        mean_W = base_W * (t_ - t0) / (t1 - t0)
        std_W = math.sqrt((t1 - t_) * (t_ - t0) / (t1 - t0))
        rescaled_W = (W - mean_W) / std_W
        _, pval = kstest(rescaled_W.cpu().detach().numpy(), 'norm')
        assert pval >= ALPHA

        if levy_area_approximation != 'none':
            W, U = bm(t0, t_, return_U=True)
            H = _U_to_H(W, U, t_ - t0)
            mean_H = 0
            std_H = math.sqrt((t_ - t0) / 12)
            rescaled_H = (H - mean_H) / std_H
            _, pval = kstest(rescaled_H.cpu().detach().numpy(), 'norm')
            assert pval >= ALPHA
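# `_U_to_H` is a helper defined elsewhere in this test module. A minimal sketch
# consistent with its use here is below; it assumes the usual convention for
# space-time Levy area, H = U / h - W / 2, which has mean 0 and standard
# deviation sqrt(h / 12) over an interval of length h, matching the statistics
# checked above.
def _U_to_H(W, U, h):
    # U is the time integral of the increment over an interval of length h;
    # removing the W-component of U leaves the space-time Levy area H.
    return U / h - 0.5 * W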
def main(
        batch_size=1024,
        context_size=64,
        hidden_size=128,
        lr_init=1e-2,
        t0=0.,
        t1=2.,
        lr_gamma=0.997,
        num_iters=5000,
        kl_anneal_iters=500,
        pause_every=50,
        noise_std=0.01,
        adjoint=False,
        train_dir='./dump/lorenz/',
        method="euler",
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    xs, ts = make_dataset(t0=t0, t1=t1, batch_size=batch_size, noise_std=noise_std,
                          train_dir=train_dir, device=device)
    latent_sde = LatentSDE(data_size=3, context_size=context_size, hidden_size=hidden_size).to(device)
    optimizer = optim.Adam(params=latent_sde.parameters(), lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_gamma)
    kl_scheduler = LinearScheduler(iters=kl_anneal_iters)

    # Fix the same Brownian motion for visualization.
    bm_vis = torchsde.BrownianInterval(
        t0=t0, t1=t1, size=(batch_size, 3,), device=device, levy_area_approximation="space-time")

    for global_step in tqdm.tqdm(range(1, num_iters + 1)):
        latent_sde.zero_grad()
        log_pxs, log_ratio = latent_sde(xs, ts, noise_std, adjoint, method)
        loss = -log_pxs + log_ratio * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        if global_step % pause_every == 0:
            lr_now = optimizer.param_groups[0]['lr']
            logging.warning(
                f'global_step: {global_step:06d}, lr: {lr_now:.5f}, '
                f'log_pxs: {log_pxs:.4f}, log_ratio: {log_ratio:.4f}, loss: {loss:.4f}, '
                f'kl_coeff: {kl_scheduler.val:.4f}'
            )
            img_path = os.path.join(train_dir, f'global_step_{global_step:06d}.pdf')
            vis(xs, ts, latent_sde, bm_vis, img_path)
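# `LinearScheduler` is a small helper defined elsewhere in this example. A
# minimal sketch consistent with how it is used above (`.val` read as the KL
# weight, `.step()` called once per iteration) might look like this; the exact
# ramp shape is an assumption.
class LinearScheduler(object):
    def __init__(self, iters, maxval=1.0):
        self._iters = max(1, iters)
        self._val = maxval / self._iters
        self._maxval = maxval

    def step(self):
        # Anneal linearly from ~0 up to maxval over `iters` steps, then stay flat.
        self._val = min(self._maxval, self._val + self._maxval / self._iters)

    @property
    def val(self):
        return self._val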
def forward(self, ts, batch_size, eps=None):
    eps = torch.randn(batch_size, 1).to(self.qy0_std) if eps is None else eps
    y0 = self.qy0_mean + eps * self.qy0_std
    qy0 = distributions.Normal(loc=self.qy0_mean, scale=self.qy0_std)
    py0 = distributions.Normal(loc=self.py0_mean, scale=self.py0_std)
    logqp0 = distributions.kl_divergence(qy0, py0).sum(dim=1)  # KL(t=0).

    bm = torchsde.BrownianInterval(t0=ts[0],
                                   t1=ts[-1],
                                   dtype=y0.dtype,
                                   device=y0.device,
                                   size=(batch_size, 2),
                                   pool_size=POOL_SIZE,
                                   cache_size=CACHE_SIZE)

    aug_y0 = torch.cat([y0, torch.zeros(batch_size, 1).to(y0)], dim=1)
    aug_ys = sdeint_fn(sde=self,
                       bm=bm,
                       y0=aug_y0,
                       ts=ts.to(device),
                       method=args.method,
                       dt=args.dt,
                       adaptive=args.adaptive,
                       adjoint_adaptive=args.adjoint_adaptive,
                       rtol=args.rtol,
                       atol=args.atol,
                       names={'drift': 'f_aug', 'diffusion': 'g_aug'})
    ys, logqp_path = aug_ys[:, :, 0:1], aug_ys[-1, :, 1]
    logqp = (logqp0 + logqp_path).mean(dim=0)  # KL(t=0) + KL(path).
    return ys, logqp
def test_sdeint_run_shape_method(sde_cls, use_bm, levy_area_approximation, sde_type, method,
                                 adaptive, logqp, device):
    """Tests that sdeint:
    (a) runs/raises an error as appropriate
    (b) produces tensors of the right shape
    (c) accepts every method
    """

    if method == 'milstein_grad_free':
        method = 'milstein'
        options = dict(grad_free=True)
    else:
        options = dict()

    should_fail = False
    if sde_type == 'ito':
        if method not in ('euler', 'srk', 'milstein'):
            should_fail = True
    else:
        if method not in ('euler_heun', 'heun', 'midpoint', 'log_ode', 'milstein'):
            should_fail = True
    if method in ('milstein', 'srk') and sde_cls.noise_type == 'general':
        should_fail = True
    if method == 'srk' and levy_area_approximation == 'none':
        should_fail = True
    if method == 'log_ode' and levy_area_approximation in ('none', 'space-time'):
        should_fail = True

    if sde_cls.noise_type in (NOISE_TYPES.scalar, NOISE_TYPES.diagonal):
        kwargs = {'d': d}
    else:
        kwargs = {'d': d, 'm': m}
    sde = sde_cls(sde_type=sde_type, **kwargs).to(device)

    if use_bm:
        if sde_cls.noise_type == 'scalar':
            size = (batch_size, 1)
        elif sde_cls.noise_type == 'diagonal':
            size = (batch_size, d + 1) if logqp else (batch_size, d)
        else:
            assert sde_cls.noise_type in ('additive', 'general')
            size = (batch_size, m)
        bm = torchsde.BrownianInterval(t0=t0,
                                       t1=t1,
                                       size=size,
                                       dtype=dtype,
                                       device=device,
                                       levy_area_approximation=levy_area_approximation)
    else:
        bm = None

    _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options)
def forward(self, y, adjoint=False, dt=0.02, adaptive=False, adjoint_adaptive=False,
            method="midpoint", rtol=1e-4, atol=1e-3):
    # Note: This works correctly, as long as we are requesting the nfe after each gradient update.
    # There are obviously cleaner ways to achieve this.
    self.nfe = 0
    sdeint = torchsde.sdeint_adjoint if adjoint else torchsde.sdeint

    if self.aug_zeros.numel() > 0:  # Add zero channels.
        aug_zeros = self.aug_zeros.expand(y.shape[0], *self.aug_zeros_size)
        y = torch.cat((y, aug_zeros), dim=1)  # 235200
    aug_y = torch.cat(
        (y.reshape(-1), self.flat_initial_params, torch.tensor([0.], device=y.device)))
    # 841609: (235200, 606408, 1)
    aug_y = aug_y[None]

    bm = torchsde.BrownianInterval(
        t0=self.ts[0], t1=self.ts[-1], size=aug_y.shape, dtype=aug_y.dtype, device=aug_y.device,
        cache_size=45 if adjoint else 30  # If not adjoint, don't really need to cache.
    )

    if adjoint_adaptive:
        _, aug_y1 = sdeint(self, aug_y, self.ts, bm=bm, method=method, dt=dt, adaptive=adaptive,
                           adjoint_adaptive=adjoint_adaptive, rtol=rtol, atol=atol)
    else:
        _, aug_y1 = sdeint(self, aug_y, self.ts, bm=bm, method=method, dt=dt, adaptive=adaptive,
                           rtol=rtol, atol=atol)
    # Remove the singleton batch dimension so the flat slicing below is valid.
    aug_y1 = aug_y1.squeeze(0)
    y1 = aug_y1[:y.numel()].reshape(y.size())
    logits = self.projection(y1)
    logqp = .5 * aug_y1[-1]
    return logits, logqp
def _setup(device, levy_area_approximation, shape):
    t0, t1 = torch.tensor([0., 1.], device=device)
    ta = torch.rand([], device=device)
    tb = torch.rand([], device=device)
    ta, tb = min(ta, tb), max(ta, tb)
    bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=shape, device=device,
                                   levy_area_approximation=levy_area_approximation)
    return ta, tb, bm
def test_adjoint(sde_cls, method, sde_type, adaptive):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.zeros(batch_size, d).to(device).fill_(0.1)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device)

    levy_area_approximation = {
        'euler': 'none',
        'milstein': 'none',
        'srk': 'space-time',
        'midpoint': 'none'
    }[method]
    bm = torchsde.BrownianInterval(t0=t0,
                                   t1=t1,
                                   size=(batch_size, m),
                                   dtype=dtype,
                                   device=device,
                                   levy_area_approximation=levy_area_approximation)

    def func(inputs, modules):
        y0, sde = inputs[0], modules[0]
        ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, dt=dt, method=method, adaptive=adaptive)
        return (ys[-1] ** 2).sum(dim=1).mean(dim=0)

    # `grad_inputs=True` also works, but we really only care about grad wrt params and want fast tests.
    utils.gradcheck(func, y0, sde, eps=1e-6, rtol=1e-2, atol=1e-2, grad_params=True)
def _compare(w0, ts, msg=''):
    bm = torchsde.BrownianPath(t0=t0, w0=w0)
    bp_py_time = _time_query(bm, ts)
    logging.warning(f'{msg} (torchsde.BrownianPath): {bp_py_time:.4f}')

    bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, tol=1e-5)
    bt_py_time = _time_query(bm, ts)
    logging.warning(f'{msg} (torchsde.BrownianTree): {bt_py_time:.4f}')

    bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=w0.shape, dtype=w0.dtype, device=w0.device)
    bi_py_time = _time_query(bm, ts)
    logging.warning(f'{msg} (torchsde.BrownianInterval): {bi_py_time:.4f}')

    return bp_py_time, bt_py_time, bi_py_time
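# `_time_query` is a benchmark helper defined elsewhere. A hypothetical sketch
# consistent with its use above (return the wall-clock seconds needed to
# evaluate the Brownian motion at every query time) could be:
import time


def _time_query(bm, ts):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Don't count pending GPU work.
    start = time.perf_counter()
    for t in ts:
        bm(t)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for the queries themselves to finish.
    return time.perf_counter() - start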
def test_specialised_functions(sde_type, method):
    vector = torch.randn(m)
    fg = problems.FGSDE(sde_type, vector)
    f_and_g = problems.FAndGSDE(sde_type, vector)
    g_prod = problems.GProdSDE(sde_type, vector)
    f_and_g_prod = problems.FAndGProdSDE(sde_type, vector)
    f_and_g_with_g_prod1 = problems.FAndGGProdSDE1(sde_type, vector)
    f_and_g_with_g_prod2 = problems.FAndGGProdSDE2(sde_type, vector)

    y0 = torch.randn(batch_size, d)
    outs = []
    for sde in (fg, f_and_g, g_prod, f_and_g_prod, f_and_g_with_g_prod1, f_and_g_with_g_prod2):
        bm = torchsde.BrownianInterval(t0, t1, (batch_size, m), entropy=45678)
        outs.append(torchsde.sdeint(sde, y0, [t0, t1], dt=dt, bm=bm)[1])
    for o in outs[1:]:
        # Equality of floating points, because we expect them to do everything exactly the same.
        assert o.shape == outs[0].shape
        assert (o == outs[0]).all()
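# The `problems.*` classes above all describe the same SDE, exposed through
# torchsde's alternative method combinations (`f`/`g`, fused `f_and_g`,
# `g_prod`, fused `f_and_g_prod`, ...), which solvers may call when defined.
# A hypothetical sketch of the first two variants, for illustration only (the
# actual drift/diffusion in `problems` may differ):
class _FG(torch.nn.Module):
    noise_type = 'general'

    def __init__(self, sde_type, vector):
        super().__init__()
        self.sde_type = sde_type
        self.register_buffer('vector', vector)

    def f(self, t, y):
        return -y

    def g(self, t, y):
        # Shape (batch, d, m): outer product of the state with a fixed vector.
        return y.unsqueeze(-1) * self.vector


class _FAndG(_FG):
    def f_and_g(self, t, y):
        # Fused drift-and-diffusion; must agree exactly with f and g above.
        return self.f(t, y), self.g(t, y)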
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = np.random.permutation(vis_batch_size)
    # From https://colorbrewer2.org/.
    if args.color == "blue":
        sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = len(sample_colors)
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
        num_samples = len(sample_colors)

    eps = torch.randn(vis_batch_size, 1).to(device)  # Fix seed for the random draws used in the plots.
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 1),
        device=device,
        levy_area_approximation='space-time')  # We need space-time Levy area to use the SRK solver.

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = LinearScheduler(iters=args.kl_anneal_iters)

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_dir = os.path.join(args.train_dir, 'prior.png')
            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -idx]
                plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

            # `zorder` determines who's on top; the larger the more at the top.
            plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_dir, dpi=args.dpi)
            plt.close()
            logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0:
            img_path = os.path.join(args.train_dir, f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
                        plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0)

                if args.show_arrows:
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(0.2, 1.8, num).to(device),
                        torch.linspace(-1.5, 1.5, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(), t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_, y_, dt_, dy_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_, ys_, marker='x', zorder=3, color='k', s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                plt.savefig(img_path, dpi=args.dpi)
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

                if args.save_ckpt:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'kl_scheduler': kl_scheduler
                        },
                        os.path.join(ckpt_dir, f'global_step_{global_step}.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.
        zs = zs[1:-1]

        likelihood_constructor = {
            "laplace": distributions.Laplace,
            "normal": distributions.Normal
        }[args.likelihood]
        likelihood = likelihood_constructor(loc=zs, scale=args.scale)
        logpy = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)

        loss = -logpy + kl * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(f'global_step: {global_step}, '
                     f'logpy: {logpy_metric.val:.3f}, '
                     f'kl: {kl_metric.val:.3f}, '
                     f'loss: {loss_metric.val:.3f}')
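# `EMAMetric` is a logging helper defined elsewhere in this example. A minimal
# sketch consistent with `.step(x)` / `.val` as used above (an exponential
# moving average of a scalar) could look like this; the decay rate is an
# assumption.
class EMAMetric(object):
    def __init__(self, gamma=0.99):
        self._val = 0.
        self._gamma = gamma

    def step(self, x):
        # Accept tensors or plain numbers; keep only a detached scalar.
        x = x.detach().cpu().item() if torch.is_tensor(x) else float(x)
        self._val = self._gamma * self._val + (1 - self._gamma) * x
        return self._val

    @property
    def val(self):
        return self._val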
def test_against_sdeint(sde_cls, sde_type, method, options, dt, rtol, atol, len_ts):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    ts = torch.linspace(0.0, 1.0, len_ts, device=device, dtype=torch.float64)
    t0 = ts[0]
    t1 = ts[-1]
    y0 = torch.full((batch_size, d), 0.1, device=device, dtype=torch.float64, requires_grad=True)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device, torch.float64)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(t0=t0,
                                   t1=t1,
                                   size=(batch_size, m),
                                   dtype=torch.float64,
                                   device=device,
                                   levy_area_approximation=levy_area_approximation)

    if method == METHODS.reversible_heun:
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        adjoint_method = None
        adjoint_options = None

    ys_true = torchsde.sdeint(sde, y0, ts, dt=dt, method=method, bm=bm, options=options)
    grad = torch.randn_like(ys_true)
    ys_true.backward(grad)
    true_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])

    y0.grad.zero_()
    for param in sde.parameters():
        param.grad.zero_()

    ys_test = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method, bm=bm,
                                      adjoint_method=adjoint_method, options=options,
                                      adjoint_options=adjoint_options)
    ys_test.backward(grad)
    test_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])

    torch.testing.assert_allclose(ys_true, ys_test)
    torch.testing.assert_allclose(true_grad, test_grad, rtol=rtol, atol=atol)
def test_against_numerical(sde_cls, sde_type, method, options, adaptive):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.full((batch_size, d), 0.1, device=device)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(t0=t0,
                                   t1=t1,
                                   size=(batch_size, m),
                                   dtype=dtype,
                                   device=device,
                                   levy_area_approximation=levy_area_approximation)

    if method == METHODS.reversible_heun:
        tol = 1e-6
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        tol = 1e-2
        adjoint_method = None
        adjoint_options = None

    def func(inputs, modules):
        y0, sde = inputs[0], modules[0]
        ys = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method,
                                     adjoint_method=adjoint_method, adaptive=adaptive, bm=bm,
                                     options=options, adjoint_options=adjoint_options)
        return (ys[-1] ** 2).sum(dim=1).mean(dim=0)

    # `grad_inputs=True` also works, but we really only care about grad wrt params and want fast tests.
    utils.gradcheck(func, y0, sde, eps=1e-6, rtol=tol, atol=tol, grad_params=True)
def test_normality_conditional(device, levy_area_approximation):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip(msg="CUDA not available.")

    t0, t1 = 0.0, 1.0
    for _ in range(REPS):
        bm = torchsde.BrownianInterval(t0=t0,
                                       t1=t1,
                                       size=(LARGE_BATCH_SIZE,),
                                       device=device,
                                       levy_area_approximation=levy_area_approximation)
        for _ in range(MEDIUM_REPS):
            ta, t_, tb = sorted(npr.uniform(low=t0, high=t1, size=(3,)))
            W = bm(ta, tb)
            W1 = bm(ta, t_)
            W2 = bm(t_, tb)

            mean_W1 = W * (t_ - ta) / (tb - ta)
            std_W1 = math.sqrt((tb - t_) * (t_ - ta) / (tb - ta))
            rescaled_W1 = (W1 - mean_W1) / std_W1
            _, pval = kstest(rescaled_W1.cpu().detach().numpy(), 'norm')
            assert pval >= ALPHA

            mean_W2 = W * (tb - t_) / (tb - ta)
            std_W2 = math.sqrt((tb - t_) * (t_ - ta) / (tb - ta))
            rescaled_W2 = (W2 - mean_W2) / std_W2
            _, pval = kstest(rescaled_W2.cpu().detach().numpy(), 'norm')
            assert pval >= ALPHA

            if levy_area_approximation != 'none':
                W, U = bm(ta, tb, return_U=True)
                W1, U1 = bm(ta, t_, return_U=True)
                W2, U2 = bm(t_, tb, return_U=True)
                h = tb - ta
                h1 = t_ - ta
                h2 = tb - t_
                denom = math.sqrt(h1 ** 3 + h2 ** 3)
                a = h1 ** 3.5 * h2 ** 0.5 / (2 * h * denom)
                b = h1 ** 0.5 * h2 ** 3.5 / (2 * h * denom)
                c = math.sqrt(3) * h1 ** 1.5 * h2 ** 1.5 / (6 * denom)
                H = _U_to_H(W, U, h)
                H1 = _U_to_H(W1, U1, h1)
                H2 = _U_to_H(W2, U2, h2)

                mean_H1 = H * (h1 / h) ** 2
                std_H1 = math.sqrt(a ** 2 + c ** 2) / h1
                rescaled_H1 = (H1 - mean_H1) / std_H1
                _, pval = kstest(rescaled_H1.cpu().detach().numpy(), 'norm')
                assert pval >= ALPHA

                mean_H2 = H * (h2 / h) ** 2
                std_H2 = math.sqrt(b ** 2 + c ** 2) / h2
                rescaled_H2 = (H2 - mean_H2) / std_H2
                _, pval = kstest(rescaled_H2.cpu().detach().numpy(), 'norm')
                assert pval >= ALPHA
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()
    summary = SummaryWriter(os.path.join(args.train_dir, 'tb'))

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = np.random.permutation(vis_batch_size)
    if args.color == "blue":
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = 60
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
        num_samples = len(sample_colors)

    eps = torch.randn(vis_batch_size, 1).to(device)  # Fix seed for the random draws used in the plots.
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 1),
        device=device,
        levy_area_approximation='space-time',
        pool_size=POOL_SIZE,
        cache_size=CACHE_SIZE,
    )  # We need space-time Levy area to use the SRK solver.

    # Model.
    # Note: This `mu` is selected based on the y-value of the two endpoints of the left and right segments.
    model = LatentSDE(mu=-0.80901699, sigma=args.sigma).to(device)
    optimizer = make_optimizer(optimizer=args.optimizer, params=model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.99997)
    kl_scheduler = LinearScheduler(iters=args.kl_anneal_iters, maxval=args.kl_coeff)
    nll_scheduler = ConstantScheduler(constant=args.nll_coef)

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    if os.path.exists(os.path.join(args.train_dir, 'ckpts', 'state.ckpt')):
        logging.info("Loading checkpoints...")
        checkpoint = torch.load(os.path.join(args.train_dir, 'ckpts', 'state.ckpt'))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        try:
            logpy_metric.set(checkpoint['logpy_metric'])
            kl_metric.set(checkpoint['kl_metric'])
            loss_metric.set(checkpoint['loss_metric'])
        except Exception:  # Older checkpoints may not contain the metrics.
            logging.warning("Could not successfully load logpy, kl, and loss metrics from checkpoint")
        logging.info(f"Successfully loaded checkpoints at global_step {checkpoint['global_step']}")

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis, batch_size=vis_batch_size, eps=eps, bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_dir = os.path.join(args.train_dir, 'prior.png')
            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -idx]
                plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

            # `zorder` determines who's on top; the larger the more at the top.
            plt.scatter(ts_, ys_[:, 0], marker='x', zorder=3, color='k', s=35)  # Data.
            if args.data != "irregular_sine":
                plt.scatter(ts_, ys_[:, 1], marker='x', zorder=3, color='k', s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_dir, dpi=args.dpi)
            summary.add_figure('Prior', plt.gcf(), 0)
            logging.info('Prior saved to tensorboard')
            plt.close()
            logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0 or global_step == (args.train_iters - 1):
            img_path = os.path.join(args.train_dir, "plots", f'global_step_{global_step}.png')

            with torch.no_grad():
                # TODO:
                zs = model.sample_q(ts=ts_vis, batch_size=vis_batch_size, eps=None, bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
                        plt.fill_between(ts_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_, samples_[:, j], linewidth=1.0)

                if args.show_arrows:
                    t_start, t_end = ts_vis_[0], ts_vis_[-1]
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(t_start, t_end, num).to(device),
                        torch.linspace(*ylims, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(), t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_, y_, dt_, dy_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_, ys_[:, 0], marker='x', zorder=3, color='k', s=35)  # Data.
                if args.data != "irregular_sine":
                    plt.scatter(ts_, ys_[:, 1], marker='x', zorder=3, color='k', s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                if global_step % args.save_fig == 0:
                    plt.savefig(img_path, dpi=args.dpi)
                current_fig = plt.gcf()
                summary.add_figure('Predictions plot', current_fig, global_step)
                logging.info('Predictions plot saved to tensorboard')
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

                if args.save_ckpt:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict()
                        },
                        os.path.join(args.train_dir, 'ckpts', f'global_step_{global_step}.ckpt'))
                    # For preemption.
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'global_step': global_step,
                            'logpy_metric': logpy_metric.val,
                            'kl_metric': kl_metric.val,
                            'loss_metric': loss_metric.val
                        },
                        os.path.join(args.train_dir, 'ckpts', 'state.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.
        zs = zs[1:-1]

        likelihood_constructor = {
            "laplace": distributions.Laplace,
            "normal": distributions.Normal,
            "cauchy": distributions.Cauchy
        }[args.likelihood]
        likelihood = likelihood_constructor(loc=zs, scale=args.scale)

        # Proper summation of log-likelihoods.
        logpy = 0.
        ys_split = ys.split(split_size=1, dim=-1)
        for _ys in ys_split:
            logpy = logpy + likelihood.log_prob(_ys).sum(dim=0).mean(dim=0)
        logpy = logpy / len(ys_split)

        loss = -logpy * nll_scheduler.val + kl * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()
        nll_scheduler.step(global_step)

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(f'global_step: {global_step}, '
                     f'logpy: {logpy_metric.val:.3f}, '
                     f'kl: {kl_metric.val:.3f}, '
                     f'loss: {loss_metric.val:.3f}')

        summary.add_scalar('KL Scheduler', kl_scheduler.val, global_step)
        summary.add_scalar('NLL Scheduler', nll_scheduler.val, global_step)
        summary.add_scalar('Loss', loss_metric.val, global_step)
        summary.add_scalar('KL', kl_metric.val, global_step)
        summary.add_scalar('Log(py) Likelihood', logpy_metric.val, global_step)
        logging.info('Logged loss, kl, logpy to tensorboard')

    summary.close()
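# `ConstantScheduler` is referenced above but defined elsewhere. Judging from
# its use (`.val` read as the NLL weight, `.step(global_step)` called once per
# iteration), a minimal sketch could be a fixed coefficient whose `step` is a
# no-op:
class ConstantScheduler(object):
    def __init__(self, constant):
        self._val = constant

    def step(self, global_step=None):
        pass  # The coefficient never changes; `global_step` is accepted for API symmetry.

    @property
    def val(self):
        return self._val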