def _time_sdeint_adjoint(sde, y0, ts, bm):
    # Time one forward solve plus one backward pass through the adjoint solver.
    now = time.perf_counter()
    sde.zero_grad()
    y0 = y0.clone().requires_grad_(True)
    ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, method='euler')
    ys.sum().backward()
    return time.perf_counter() - now
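# A minimal sketch of how the timing helper above might be driven. `TinySDE`
# and all sizes here are illustrative assumptions, not part of the source.
import time

import torch
import torch.nn as nn
import torchsde


class TinySDE(nn.Module):
    noise_type = 'diagonal'
    sde_type = 'ito'

    def __init__(self, d):
        super().__init__()
        self.mu = nn.Linear(d, d)     # drift network
        self.sigma = nn.Linear(d, d)  # diagonal diffusion network

    def f(self, t, y):
        return self.mu(y)

    def g(self, t, y):
        return self.sigma(y)


batch_size, d = 16, 4
ts = torch.linspace(0., 1., 10)
bm = torchsde.BrownianInterval(t0=0., t1=1., size=(batch_size, d))
elapsed = _time_sdeint_adjoint(TinySDE(d), torch.zeros(batch_size, d), ts, bm)
print(f'adjoint forward+backward took {elapsed:.3f}s')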
def sdeint(sde, x0, s_span, bm=None, logqp=False, method='srk', adaptive=False,
           rtol=1e-6, atol=1e-5, dt=1e-2, dt_min=1e-4, options=None, names=None):
    """Thin wrapper that maps this API's argument names onto torchsde's.

    s_span -> ts:
        ts (Tensor or sequence of float): Query times in non-descending order.
        The state at the first time of `ts` should be `y0`.
    x0 -> y0:
        y0 (sequence of Tensor): Tensors for initial state.
    """
    return sdeint_adjoint(sde=sde, y0=x0, ts=s_span, bm=bm, logqp=logqp, method=method,
                          dt=dt, adaptive=adaptive, rtol=rtol, atol=atol, dt_min=dt_min,
                          options=options, names=names)
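# A hypothetical call site for the wrapper, using its renamed arguments. The
# SDE here reuses the illustrative `TinySDE` sketched earlier in this section.
s_span = torch.linspace(0., 1., 20)
x0 = torch.zeros(16, 4)
ys = sdeint(TinySDE(4), x0, s_span, method='euler', dt=1e-2)  # ys: (20, 16, 4)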
def forward(self, xs, ts, noise_std, adjoint=False, method="euler"):
    # Contextualization is only needed for posterior inference.
    ctx = self.encoder(torch.flip(xs, dims=(0,)))
    ctx = torch.flip(ctx, dims=(0,))
    self.contextualize((ts, ctx))

    if adjoint:
        # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
        adjoint_params = (
            (ctx,)
            + tuple(self.f_net.parameters())
            + tuple(self.g_nets.parameters())
            + tuple(self.h_net.parameters())
        )
        _xs, log_ratio = torchsde.sdeint_adjoint(
            self, xs[0], ts, adjoint_params=adjoint_params, dt=1e-2, logqp=True, method=method)
    else:
        _xs, log_ratio = torchsde.sdeint(self, xs[0], ts, dt=1e-2, logqp=True, method=method)

    xs_dist = Normal(loc=_xs, scale=noise_std)
    log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean()
    log_ratio = log_ratio.sum(dim=0).mean()
    return log_pxs, log_ratio
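# A self-contained sketch of the `adjoint_params` pattern used above: the
# drift closes over an external tensor `ctx` that is not an nn.Parameter, so
# it must be listed explicitly for the adjoint to produce its gradient.
# `CtxSDE` is illustrative only, not the source's latent SDE.
class CtxSDE(nn.Module):
    noise_type = 'diagonal'
    sde_type = 'ito'

    def __init__(self, ctx):
        super().__init__()
        self.ctx = ctx  # plain tensor attribute, invisible to self.parameters()
        self.net = nn.Linear(2, 1)

    def f(self, t, y):
        return self.net(torch.cat([y, self.ctx.expand(y.size(0), 1)], dim=1))

    def g(self, t, y):
        return 0.1 * torch.ones_like(y)


ctx = torch.randn((), requires_grad=True)
sde = CtxSDE(ctx)
ys = torchsde.sdeint_adjoint(
    sde, torch.zeros(8, 1), torch.linspace(0., 1., 5), dt=1e-2, method='euler',
    adjoint_params=(ctx,) + tuple(sde.parameters()))
ys.sum().backward()
assert ctx.grad is not None  # populated only because ctx was in adjoint_params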
def forward(self, ts, batch_size):
    # ts has shape (t_size,) and corresponds to the points we want to evaluate the SDE at.

    ###################
    # Actually solve the SDE.
    ###################
    init_noise = torch.randn(batch_size, self._initial_noise_size, device=ts.device)
    x0 = self._initial(init_noise)

    ###################
    # We use the reversible Heun method to get accurate gradients whilst using the adjoint method.
    ###################
    xs = torchsde.sdeint_adjoint(
        self._func, x0, ts,
        method='reversible_heun',
        dt=1.0,
        adjoint_method='adjoint_reversible_heun',
    )
    xs = xs.transpose(0, 1)
    ys = self._readout(xs)

    ###################
    # Normalise the data to the form that the discriminator expects, in particular including time as a channel.
    ###################
    ts = ts.unsqueeze(0).unsqueeze(-1).expand(batch_size, ts.size(0), 1)
    return torchcde.linear_interpolation_coeffs(torch.cat([ts, ys], dim=2))
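# The reversible Heun pair above is a Stratonovich solver in torchsde. A
# hypothetical minimal shape for a module like `self._func` (illustrative,
# not the source's generator function):
class GeneratorFunc(nn.Module):
    sde_type = 'stratonovich'  # expected by reversible_heun
    noise_type = 'general'

    def __init__(self, d, m):
        super().__init__()
        self._m = m
        self._drift = nn.Linear(d, d)
        self._diffusion = nn.Linear(d, d * m)

    def f(self, t, x):
        return self._drift(x)

    def g(self, t, x):
        # General noise: diffusion output has shape (batch, d, m).
        return self._diffusion(x).view(x.size(0), -1, self._m)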
def func(inputs, modules):
    y0, sde = inputs[0], modules[0]
    ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, dt=dt, method=method, adaptive=adaptive)
    return (ys[-1] ** 2).sum(dim=1).mean(dim=0)
def _adjoint(self, x):
    out = torchsde.sdeint_adjoint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                                  adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]
    return out
def func(x):
    ys_and_logqp = sdeint_adjoint(sde, x, ts, bm, logqp=True, method=method, dt=dt, adaptive=adaptive)
    ys, logqp = ys_and_logqp
    # Just another arbitrarily chosen function with two outputs.
    return torch.stack([(ys ** 2.).sum(), (logqp / 3.).sum()], dim=0)
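# How a closure like `func` is typically consumed: a finite-difference
# gradcheck through the solver against double-precision inputs (a sketch;
# the input shape and tolerances are illustrative assumptions).
x0 = torch.full((4, 3), 0.1, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(func, (x0,), rtol=1e-4, atol=1e-3)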
def _test_forward(sde, bm, method, adaptive=False, rtol=1e-6, atol=1e-5):
    sde.zero_grad()
    ys, log_ratio = sdeint_adjoint(sde, y0, ts, bm, logqp=True, method=method, dt=dt,
                                   adaptive=adaptive, rtol=rtol, atol=atol)
    loss = ys.sum(0).mean(0).sum(0) + log_ratio.sum(0).mean(0)
    loss.backward()
def test_basic(problem, method, adaptive):
    d = 10
    batch_size = 128
    ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.zeros(batch_size, d).to(device).fill_(0.1)

    problem = problem(d).to(device)
    num_before = _count_differentiable_params(problem)

    problem.zero_grad()
    _, yt = torchsde.sdeint_adjoint(problem, y0, ts, method=method, dt=dt, adaptive=adaptive)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()

    num_after = _count_differentiable_params(problem)
    assert num_before == num_after
def _test_basic(self, problem, method, adaptive, rtol=1e-6, atol=1e-5):
    if method == 'euler' and adaptive:
        return
    nbefore = _count_differentiable_params(problem)
    problem.zero_grad()
    _, yt = sdeint_adjoint(problem, y0, ts, method=method, dt=dt, adaptive=adaptive,
                           rtol=rtol, atol=atol)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()
    nafter = _count_differentiable_params(problem)
    self.assertEqual(nbefore, nafter)
def _test_gradient(self, problem, method, adaptive, rtol=1e-6, atol=1e-5):
    if method == 'euler' and adaptive:
        return
    bm = BrownianPath(t0=t0, w0=w0)

    with torch.no_grad():
        grad_outputs = torch.ones(batch_size, d).to(device)
        alt_grad = problem.analytical_grad(y0, t1, grad_outputs, bm)

    problem.zero_grad()
    _, yt = sdeint_adjoint(problem, y0, ts, bm=bm, method=method, dt=dt, adaptive=adaptive,
                           rtol=rtol, atol=atol)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()

    adj_grad = torch.cat(tuple(p.grad for p in problem.parameters()))
    self.tensorAssertAllClose(alt_grad, adj_grad)
def forward(self, ts, batch_size, eps=None):
    eps = torch.randn(batch_size, 1).to(self.qy0_std) if eps is None else eps
    y0 = self.qy0_mean + eps * self.qy0_std
    qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std)
    py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
    logqp0 = kl_divergence(qy0, py0).sum(1).mean(0)  # KL(time=0).

    # `trapezoidal_approx` is for SRK. Disabling it gives better performance.
    if args.adjoint:
        zs, logqp = sdeint_adjoint(self, y0, ts, logqp=True, method=args.method, dt=args.dt,
                                   adaptive=args.adaptive, rtol=args.rtol, atol=args.atol,
                                   options={'trapezoidal_approx': False})
    else:
        zs, logqp = sdeint(self, y0, ts, logqp=True, method=args.method, dt=args.dt,
                           adaptive=args.adaptive, rtol=args.rtol, atol=args.atol,
                           options={'trapezoidal_approx': False})

    logqp = logqp.sum(0).mean(0)
    log_ratio = logqp0 + logqp  # KL(time=0) + KL(path).
    return zs, log_ratio
def test_against_sdeint(sde_cls, sde_type, method, options, dt, rtol, atol, len_ts):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2,
    }[sde_cls.noise_type]
    batch_size = 4
    ts = torch.linspace(0.0, 1.0, len_ts, device=device, dtype=torch.float64)
    t0 = ts[0]
    t1 = ts[-1]
    y0 = torch.full((batch_size, d), 0.1, device=device, dtype=torch.float64, requires_grad=True)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device, torch.float64)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(batch_size, m), dtype=torch.float64,
                                   device=device, levy_area_approximation=levy_area_approximation)

    if method == METHODS.reversible_heun:
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        adjoint_method = None
        adjoint_options = None

    ys_true = torchsde.sdeint(sde, y0, ts, dt=dt, method=method, bm=bm, options=options)
    grad = torch.randn_like(ys_true)
    ys_true.backward(grad)
    true_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])

    y0.grad.zero_()
    for param in sde.parameters():
        param.grad.zero_()

    ys_test = torchsde.sdeint_adjoint(sde, y0, ts, dt=dt, method=method, bm=bm,
                                      adjoint_method=adjoint_method, options=options,
                                      adjoint_options=adjoint_options)
    ys_test.backward(grad)
    test_grad = torch.cat([y0.grad.view(-1)] + [param.grad.view(-1) for param in sde.parameters()])

    torch.testing.assert_allclose(ys_true, ys_test)
    torch.testing.assert_allclose(true_grad, test_grad, rtol=rtol, atol=atol)
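# Distilled to its essentials, the test above checks that sdeint and
# sdeint_adjoint agree on the trajectory (the full test also compares
# gradients) when driven by the same BrownianInterval. A sketch, reusing the
# illustrative `TinySDE` from earlier in this section:
sde = TinySDE(d=3).double()
y0 = torch.full((4, 3), 0.1, dtype=torch.float64, requires_grad=True)
ts = torch.linspace(0., 1., 9, dtype=torch.float64)
bm = torchsde.BrownianInterval(t0=0., t1=1., size=(4, 3), dtype=torch.float64)
ys_a = torchsde.sdeint(sde, y0, ts, bm=bm, dt=1e-2, method='euler')
ys_b = torchsde.sdeint_adjoint(sde, y0, ts, bm=bm, dt=1e-2, method='euler')
torch.testing.assert_allclose(ys_a, ys_b)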