def test_reversibility(sde_cls):
    batch_size = 32
    state_size = 4
    t_size = 20
    dt = 0.1

    brownian_size = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: state_size,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]

    class MinusSDE(torch.nn.Module):
        """Time-reversed SDE: integrating it backwards undoes the forward solve."""
        def __init__(self, sde):
            super(MinusSDE, self).__init__()
            self.noise_type = sde.noise_type
            self.sde_type = sde.sde_type
            self.f = lambda t, y: -sde.f(-t, y)
            self.g = lambda t, y: -sde.g(-t, y)

    sde = sde_cls(d=state_size, m=brownian_size, sde_type='stratonovich')
    minus_sde = MinusSDE(sde)
    y0 = torch.full((batch_size, state_size), 0.1)
    ts = torch.linspace(0, (t_size - 1) * dt, t_size)
    bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(batch_size, brownian_size))
    ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=dt, extra=True)

    backward_ts = -ts.flip(0)
    backward_ys = torchsde.sdeint(minus_sde, ys[-1], backward_ts, bm=torchsde.ReverseBrownian(bm),
                                  method='reversible_heun', dt=dt, extra_solver_state=(-f, -g, z))
    backward_ys = backward_ys.flip(0)

    torch.testing.assert_allclose(ys, backward_ys, rtol=1e-6, atol=1e-6)

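# A minimal, self-contained sketch of the same round trip, assuming torchsde >= 0.2.5
# (which provides `reversible_heun`, `BrownianInterval` and `ReverseBrownian`). The
# `ToySDE` below is illustrative and is not one of the `sde_cls` fixtures above.
import torch
import torchsde


class ToySDE(torch.nn.Module):
    noise_type = 'diagonal'
    sde_type = 'stratonovich'

    def f(self, t, y):
        return -y  # Mean-reverting drift.

    def g(self, t, y):
        return 0.2 * torch.ones_like(y)  # Constant diagonal diffusion.


sde = ToySDE()
y0 = torch.full((8, 3), 0.1)
ts = torch.linspace(0., 1., 11)  # Grid spacing matches dt below, so steps land on the grid.
bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(8, 3))
ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=0.1, extra=True)

minus_sde = ToySDE()
minus_sde.f = lambda t, y: -sde.f(-t, y)  # Negate drift and diffusion, as in the test above.
minus_sde.g = lambda t, y: -sde.g(-t, y)
back = torchsde.sdeint(minus_sde, ys[-1], -ts.flip(0), bm=torchsde.ReverseBrownian(bm),
                       method='reversible_heun', dt=0.1, extra_solver_state=(-f, -g, z))
torch.testing.assert_allclose(ys, back.flip(0), rtol=1e-6, atol=1e-6)
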
def _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options):
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    if adaptive and method == 'euler' and sde.noise_type != 'additive':
        ctx = pytest.warns(UserWarning)
    else:
        ctx = _nullcontext()

    # Using `f` as drift.
    with torch.no_grad():
        try:
            with ctx:
                ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, logqp=logqp,
                                      options=options)
        except ValueError:
            if should_fail:
                return
            raise
        else:
            if should_fail:
                pytest.fail("Expected an error; did not get one.")
    if logqp:
        ans, log_ratio = ans
        assert log_ratio.shape == (T - 1, batch_size)
    assert ans.shape == (T, batch_size, d)

    # Using `h` as drift.
    with torch.no_grad():
        with ctx:
            ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, names={'drift': 'h'},
                                  logqp=logqp, options=options)
    if logqp:
        ans, log_ratio = ans
        assert log_ratio.shape == (T - 1, batch_size)
    assert ans.shape == (T, batch_size, d)

def inspect_strong_order():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 9))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        # The Brownian motion must have the correct size!
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))

        for dt in tqdm.tqdm(dts):
            # Only take the end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)
            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)
        del euler_mse_, srk_mse_

    # Halve the log-MSE before fitting, so the slope matches the textbook strong
    # order: for strong order q, the MSE scales as dt ** (2 * q).
    log = lambda x: np.log(np.array(x))
    euler_slope, _, _, _, _ = stats.linregress(log(dts), log(euler_mses_) / 2)
    srk_slope, _, _, _, _ = stats.linregress(log(dts), log(srk_mses_) / 2)

    plt.figure()
    plt.plot(dts, euler_mses_, label=f'euler(k={euler_slope:.4f})')
    plt.plot(dts, srk_mses_, label=f'srk(k={srk_slope:.4f})')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()

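# A quick sanity check for the slope fit above, on synthetic errors with a known
# order; pure numpy/scipy, no SDE solve needed. If MSE ~ C * dt**(2*q) for strong
# order q, then regressing log(MSE) / 2 on log(dt) recovers q.
import numpy as np
from scipy import stats

dts = np.array([2.0 ** -i for i in range(1, 9)])
q_true = 1.5                        # Hypothetical strong order.
mses = 3.0 * dts ** (2 * q_true)    # Synthetic MSEs consistent with that order.
slope, _, _, _, _ = stats.linregress(np.log(dts), np.log(mses) / 2)
assert abs(slope - q_true) < 1e-8
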
def _test_sdeint(self, sde, bm, adaptive, method, dt):
    # Using `f` as drift.
    with torch.no_grad():
        ans = sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive)
    self.assertEqual(ans.shape, (T, batch_size, d))

    # Using `h` as drift.
    with torch.no_grad():
        ans = sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, names={'drift': 'h'})
    self.assertEqual(ans.shape, (T, batch_size, d))

def test_rename_methods(self):
    # Test renaming works with a subset of names when `logqp=False`.
    sde = basic_sde.CustomNamesSDE().to(device)
    ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
    self.assertEqual(ans.shape, (T, batch_size, d))

    # Test renaming works with a subset of names when `logqp=True`.
    sde = basic_sde.CustomNamesSDELogqp().to(device)
    ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
    self.assertEqual(ans[0].shape, (T, batch_size, d))
    self.assertEqual(ans[1].shape, (T - 1, batch_size))

def test_srk_determinism(self):
    # SRK for additive noise.
    sde = basic_sde.AdditiveSDE(d=d, m=m).to(device)
    ys1 = sdeint(sde, y0, ts, bm=bm_general, adaptive=False, method='srk', dt=dt)
    ys2 = sdeint(sde, y0, ts, bm=bm_general, adaptive=False, method='srk', dt=dt)
    self.tensorAssertAllClose(ys1, ys2)

    # SRK for diagonal noise.
    sde = basic_sde.BasicSDE1(d=d).to(device)
    ys1 = sdeint(sde, y0, ts, bm=bm_diagonal, adaptive=False, method='srk', dt=dt)
    ys2 = sdeint(sde, y0, ts, bm=bm_diagonal, adaptive=False, method='srk', dt=dt)
    self.tensorAssertAllClose(ys1, ys2)

def inspect_rate():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        # The Brownian motion must have the correct size!
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))

        for dt in dts:
            # Only take the end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)
            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()

def sample(self, x0, ts, noise_std, normalize):
    """Sample data for training, optionally normalizing it and adding observation noise."""
    xs = torchsde.sdeint(self, x0, ts)
    if normalize:
        mean, std = torch.mean(xs, dim=(0, 1)), torch.std(xs, dim=(0, 1))
        xs.sub_(mean).div_(std).add_(torch.randn_like(xs) * noise_std)  # In-place: normalize, then add noise.
    return xs

def forward(self, xs, ts, noise_std, adjoint=False, method="euler"):
    # Contextualization is only needed for posterior inference.
    ctx = self.encoder(torch.flip(xs, dims=(0,)))
    ctx = torch.flip(ctx, dims=(0,))
    self.contextualize((ts, ctx))

    if adjoint:
        # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
        adjoint_params = (
                (ctx,) +
                tuple(self.f_net.parameters()) + tuple(self.g_nets.parameters()) + tuple(self.h_net.parameters())
        )
        _xs, log_ratio = torchsde.sdeint_adjoint(self, xs[0], ts, adjoint_params=adjoint_params, dt=1e-2,
                                                 logqp=True, method=method)
    else:
        _xs, log_ratio = torchsde.sdeint(self, xs[0], ts, dt=1e-2, logqp=True, method=method)

    xs_dist = Normal(loc=_xs, scale=noise_std)
    log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean()
    log_ratio = log_ratio.sum(dim=0).mean()
    return log_pxs, log_ratio

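# The pair returned by `forward` assembles into the usual latent-SDE ELBO. A
# hypothetical training step (`latent_sde`, `optimizer` and `kl_coeff` are
# illustrative names, not part of the code above):
#
#     log_pxs, log_ratio = latent_sde(xs, ts, noise_std=0.01)
#     loss = -log_pxs + kl_coeff * log_ratio  # kl_coeff is typically annealed from 0 to 1.
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
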
def test_sdeint_tuplesde(self):
    y0_ = (y0,)  # Make tuple input.
    sde = basic_sde.TupleSDE(d=d).to(device)
    bm = lambda t: (bm_diagonal(t),)
    with torch.no_grad():
        ans = sdeint(sde, y0_, ts, bm, method='euler', dt=dt)
        self.assertTrue(isinstance(ans, tuple))

def time_func():
    # The result is intentionally discarded; this closure exists only to be timed.
    ys = sdeint(geometric_bm, y0, ts, adaptive=False, dt=ts[1], options={'trapezoidal_approx': True})

def sample(self, x0, ts, bm=None):
    # Sample from the prior by using `h` (the prior drift) as the drift.
    return torchsde.sdeint(self, x0, ts, names={'drift': 'h'}, dt=1e-3, bm=bm)

def _time_sdeint_bp(sde, y0, ts, bm):
    now = time.perf_counter()
    sde.zero_grad()
    y0 = y0.clone().requires_grad_(True)
    ys = torchsde.sdeint(sde, y0, ts, bm, method='euler')
    ys.sum().backward()
    return time.perf_counter() - now

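# A matching helper for timing backpropagation through the adjoint method, for an
# apples-to-apples comparison; a sketch assuming the same setup as `_time_sdeint_bp` above.
def _time_sdeint_adjoint_bp(sde, y0, ts, bm):
    now = time.perf_counter()
    sde.zero_grad()
    y0 = y0.clone().requires_grad_(True)
    ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, method='euler')
    ys.sum().backward()
    return time.perf_counter() - now
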
def test_rename_methods(device):
    """Test renaming works with a subset of names."""
    sde = problems.CustomNamesSDE().to(device)
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
    assert ans.shape == (T, batch_size, d)

def test_rename_methods_logqp(device):
    """Test renaming works with a subset of names when `logqp=True`."""
    sde = problems.CustomNamesSDELogqp().to(device)
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
    assert ans[0].shape == (T, batch_size, d)
    assert ans[1].shape == (T - 1, batch_size)

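# A minimal, runnable sketch of the `names` remapping these tests exercise: with
# names={'drift': 'h'}, `sdeint` reads the drift from `h` instead of `f`. The
# `RenamedSDE` class here is illustrative, not one of the test problems.
import torch
import torchsde


class RenamedSDE(torch.nn.Module):
    noise_type = 'diagonal'
    sde_type = 'ito'

    def h(self, t, y):  # Drift, under a non-default name.
        return -y

    def g(self, t, y):  # Diffusion, under the default name.
        return 0.1 * torch.ones_like(y)


ys = torchsde.sdeint(RenamedSDE(), torch.zeros(4, 2), torch.linspace(0., 1., 5),
                     dt=1e-2, names={'drift': 'h'})
assert ys.shape == (5, 4, 2)
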
def inspect_rate():
    batch_size, d = 4096, 10
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)

    euler_mses_ = []
    milstein_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, 1).to(device))

        for dt in dts:
            # Only take the end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk')
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            milstein_mse = compute_mse(ys_milstein, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, milstein_mse_, srk_mse_ = to_numpy(euler_mse, milstein_mse, srk_mse)
            euler_mses_.append(euler_mse_)
            milstein_mses_.append(milstein_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, milstein_mses_, label='milstein')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()

def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
    x = self._prep_sdeint(x)
    sol = torchsde.sdeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol,
                          method=self.solver, dt=self.ds)
    return sol

def _autograd(self, x):
    self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
    # Only the terminal value of the solution is returned.
    return torchsde.sdeint(self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol,
                           adaptive=self.adaptive, method=self.solver, dt=self.ds)[-1]

def sample_p(self, ts, batch_size, eps=None, bm=None):
    eps = torch.randn(batch_size, 1).to(self.py0_mean) if eps is None else eps
    y0 = self.py0_mean + eps * self.py0_std
    return sdeint(self, y0, ts, bm=bm, method='srk', dt=args.dt, names={'drift': 'h'})

def sample_q(self, vis_span, n_sim, eps=None, bm=None, dt=0.01):
    """Sample posterior paths for visualization.

    :param vis_span: time points at which to evaluate the SDE.
    :param n_sim: number of sample paths to draw.
    :param eps: optional noise used to draw the initial state; sampled if None.
    :param bm: optional Brownian motion to drive the SDE.
    :param dt: solver step size.
    """
    eps = torch.randn(n_sim, 1).to(self.qy0_mean) if eps is None else eps
    y0 = self.qy0_mean + eps.to(self.device) * self.qy0_std
    return torchsde.sdeint(self.defunc, y0, vis_span, bm=bm, method='srk', dt=dt)

def inspect_samples():
    batch_size, d, m = 32, 1, 5
    steps = 10

    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 3e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = AdditiveSDE(d=d, m=m).to(device)

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))
        ys_em = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk', options={'trapezoidal_approx': False})
        # A fine-grained Euler solve serves as the reference "true" solution.
        ys_true = sdeint(sde, y0=y0, ts=ts, dt=1e-3, bm=bm, method='euler')

        ys_em = ys_em.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_true = ys_true.squeeze().t()

    ts_, ys_em_, ys_srk_, ys_true_ = to_numpy(ts, ys_em, ys_srk, ys_true)

    # Visualize sample paths.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    for i, (ys_em_i, ys_srk_i, ys_true_i) in enumerate(zip(ys_em_, ys_srk_, ys_true_)):
        plt.figure()
        plt.plot(ts_, ys_em_i, label='em')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_true_i, label='true')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()

def inspect_samples(y0: Tensor,
                    ts: Vector,
                    dt: Scalar,
                    sde: BaseSDE,
                    bm: BaseBrownian,
                    img_dir: str,
                    methods: Tuple[str, ...],
                    options: Optional[Tuple] = None,
                    vis_dim=0,
                    dt_true: Optional[float] = 2 ** -14,
                    labels: Optional[Tuple[str, ...]] = None):
    if options is None:
        options = (None,) * len(methods)
    if labels is None:
        labels = methods

    sde = copy.deepcopy(sde).requires_grad_(False)

    solns = [
        sdeint(sde, y0, ts, bm, method=method, dt=dt, options=options_)
        for method, options_ in zip(methods, options)
    ]

    method_for_true = 'euler' if sde.sde_type == SDE_TYPES.ito else 'midpoint'
    true = sdeint(sde, y0, ts, bm, method=method_for_true, dt=dt_true)

    labels += ('true',)
    solns += [true]

    # (T, batch_size, d) -> (T, batch_size) -> (batch_size, T).
    solns = [soln[..., vis_dim].t() for soln in solns]

    for i, samples in enumerate(zip(*solns)):
        utils.swiss_knife_plotter(
            img_path=os.path.join(img_dir, f'{i}'),
            plots=[
                {'x': ts, 'y': sample, 'label': label, 'marker': 'x'}
                for sample, label in zip(samples, labels)
            ]
        )

def inspect_sample():
    batch_size, d = 32, 1
    steps = 100

    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 1e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)
    sde.noise_type = "scalar"  # Override so the solvers treat the noise as scalar.

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros_like(y0))
        ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
        ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk', options={'trapezoidal_approx': False})
        ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

        ys_euler = ys_euler.squeeze().t()
        ys_milstein = ys_milstein.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_analytical = ys_analytical.squeeze().t()

    ts_, ys_euler_, ys_milstein_, ys_srk_, ys_analytical_ = to_numpy(
        ts, ys_euler, ys_milstein, ys_srk, ys_analytical)

    # Visualize sample paths.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)
    for i, (ys_euler_i, ys_milstein_i, ys_srk_i, ys_analytical_i) in enumerate(
            zip(ys_euler_, ys_milstein_, ys_srk_, ys_analytical_)):
        plt.figure()
        plt.plot(ts_, ys_euler_i, label='euler')
        plt.plot(ts_, ys_milstein_i, label='milstein')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_analytical_i, label='analytical')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()

def sde_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2, tweedie_correction=True):
    self.module.eval()
    # Negated times integrate the reverse-time SDE from t1 down to t0.
    t = torch.tensor([-self.t1, -self.t0], device=self.device) if t is None else t
    y = self.sample_t1_marginal(batch_size, tau) if y is None else y

    ys = torchsde.sdeint(self, y.flatten(start_dim=1), t, dt=dt)
    ys = ys.view(len(t), *y.size())
    if tweedie_correction:
        ys[-1] = self.tweedie_correction(self.t0, ys[-1], dt)
    return ys

def test_specialised_functions(sde_type, method):
    vector = torch.randn(m)
    fg = problems.FGSDE(sde_type, vector)
    f_and_g = problems.FAndGSDE(sde_type, vector)
    g_prod = problems.GProdSDE(sde_type, vector)
    f_and_g_prod = problems.FAndGProdSDE(sde_type, vector)
    f_and_g_with_g_prod1 = problems.FAndGGProdSDE1(sde_type, vector)
    f_and_g_with_g_prod2 = problems.FAndGGProdSDE2(sde_type, vector)

    y0 = torch.randn(batch_size, d)
    outs = []
    for sde in (fg, f_and_g, g_prod, f_and_g_prod, f_and_g_with_g_prod1, f_and_g_with_g_prod2):
        bm = torchsde.BrownianInterval(t0, t1, (batch_size, m), entropy=45678)
        outs.append(torchsde.sdeint(sde, y0, [t0, t1], dt=dt, bm=bm)[1])
    for o in outs[1:]:
        # Equality of floating points, because we expect them to do everything exactly the same.
        assert o.shape == outs[0].shape
        assert (o == outs[0]).all()

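# The bitwise equality above relies on the fixed `entropy` seed: two
# BrownianInterval objects built with the same entropy produce identical
# increments. A minimal check of that assumption:
import torch
import torchsde

bm1 = torchsde.BrownianInterval(t0=0., t1=1., size=(4, 3), entropy=45678)
bm2 = torchsde.BrownianInterval(t0=0., t1=1., size=(4, 3), entropy=45678)
assert torch.equal(bm1(0., 0.5), bm2(0., 0.5))
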
def sample_q(self, vis_span, n_sim, eps=None, bm=None, dt=0.01):
    """Sample posterior paths for visualization.

    Args:
        vis_span: time points at which to evaluate the SDE.
        n_sim: number of sample paths to draw.
        eps: optional noise used to draw the initial state; sampled if None.
        bm: optional Brownian motion to drive the SDE. Defaults to None.
        dt (float, optional): solver step size. Defaults to 0.01.

    Returns:
        Tensor of sampled paths, of shape (len(vis_span), n_sim, 1).
    """
    eps = torch.randn(n_sim, 1).to(self.qy0_mean) if eps is None else eps
    y0 = self.qy0_mean + eps.to(self.device) * self.qy0_std
    return torchsde.sdeint(self.defunc, y0, vis_span, bm=bm, method='srk', dt=dt)

def forward(self, ts, batch_size):
    # ts has shape (t_size,) and corresponds to the points we want to evaluate the SDE at.

    ###################
    # Actually solve the SDE.
    ###################
    init_noise = torch.randn(batch_size, self._initial_noise_size, device=ts.device)
    x0 = self._initial(init_noise)
    xs = torchsde.sdeint(self._func, x0, ts, method='midpoint', dt=1.0)  # shape (t_size, batch_size, hidden_size)
    xs = xs.transpose(0, 1)  # switch t_size and batch_size
    ys = self._readout(xs)

    ###################
    # Normalise the data to the form that the discriminator expects, in particular including time as a channel.
    ###################
    t_size = ts.size(0)
    ts = ts.unsqueeze(0).unsqueeze(-1).expand(batch_size, t_size, 1)
    return torchcde.linear_interpolation_coeffs(torch.cat([ts, ys], dim=2))

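# A shape-only sketch of what `linear_interpolation_coeffs` feeds into, assuming
# torchcde is installed; the sizes here are illustrative.
import torch
import torchcde

x = torch.randn(16, 10, 3)                        # (batch, t_size, channels)
coeffs = torchcde.linear_interpolation_coeffs(x)  # Coefficients the discriminator consumes.
X = torchcde.LinearInterpolation(coeffs)          # Recover a continuous path from them.
assert X.evaluate(torch.tensor(0.5)).shape == (16, 3)
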
def forward(self, eps: torch.Tensor, s_span=None):
    """Integrate the latent SDE and accumulate the KL penalty.

    Args:
        eps (torch.Tensor): noise used to draw the initial latent state.
        s_span: optional time points to integrate over; defaults to `self.s_span`.

    Returns:
        Tuple of the solution paths and the total KL divergence (time-0 KL plus path KL).
    """
    eps = eps.to(self.qy0_std)
    x0 = self.qy0_mean + eps * self.qy0_std

    qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std)
    py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
    logqp0 = kl_divergence(qy0, py0).sum(1).mean(0)  # KL(time=0).

    if s_span is not None:
        s_span_ext = s_span
    else:
        s_span_ext = self.s_span.cpu()

    zs, logqp = torchsde.sdeint(self.defunc, x0, s_span_ext, rtol=self.rtol, atol=self.atol,
                                logqp=True, options=self.options, adaptive=self.adaptive, method=self.solver)
    logqp = logqp.sum(0).mean(0)
    log_ratio = logqp0 + logqp  # KL(time=0) + KL(path).
    return zs, log_ratio

def forward(self, ts, batch_size, eps=None):
    eps = torch.randn(batch_size, 1).to(self.qy0_std) if eps is None else eps
    y0 = self.qy0_mean + eps * self.qy0_std

    qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std)
    py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
    logqp0 = kl_divergence(qy0, py0).sum(1).mean(0)  # KL(time=0).

    # `trapezoidal_approx` is for SRK. Disabling it gives better performance.
    if args.adjoint:
        zs, logqp = sdeint_adjoint(self, y0, ts, logqp=True, method=args.method, dt=args.dt,
                                   adaptive=args.adaptive, rtol=args.rtol, atol=args.atol,
                                   options={'trapezoidal_approx': False})
    else:
        zs, logqp = sdeint(self, y0, ts, logqp=True, method=args.method, dt=args.dt,
                           adaptive=args.adaptive, rtol=args.rtol, atol=args.atol,
                           options={'trapezoidal_approx': False})

    logqp = logqp.sum(0).mean(0)
    log_ratio = logqp0 + logqp  # KL(time=0) + KL(path).
    return zs, log_ratio

def get_data(batch_size, device):
    class OrnsteinUhlenbeckSDE(torch.nn.Module):
        sde_type = 'ito'
        noise_type = 'scalar'

        def __init__(self, mu, theta, sigma):
            super(OrnsteinUhlenbeckSDE, self).__init__()
            self.register_buffer('mu', torch.as_tensor(mu))
            self.register_buffer('theta', torch.as_tensor(theta))
            self.register_buffer('sigma', torch.as_tensor(sigma))

        def f(self, t, y):
            return self.mu * t - self.theta * y

        def g(self, t, y):
            return self.sigma.expand(y.size(0), 1, 1)

    dataset_size = 8192
    t_size = 64

    ou_sde = OrnsteinUhlenbeckSDE(mu=0.02, theta=0.1, sigma=0.4).to(device)
    y0 = torch.rand(dataset_size, device=device).unsqueeze(-1) * 2 - 1
    ts = torch.linspace(0, t_size - 1, t_size, device=device)
    ys = torchsde.sdeint(ou_sde, y0, ts, dt=1e-1)

    ###################
    # To demonstrate how to handle irregular data, we additionally drop some of the data (by setting it to NaN).
    ###################
    ys_num = ys.numel()
    to_drop = torch.randperm(ys_num)[:int(0.3 * ys_num)]
    ys.view(-1)[to_drop] = float('nan')

    ###################
    # Typically important to normalise data. Note that the data is normalised with respect to the statistics of the
    # initial data, _not_ the whole time series. This seems to help the learning process, presumably because if the
    # initial condition is wrong then it's pretty hard to learn the rest of the SDE correctly.
    ###################
    y0_flat = ys[0].view(-1)
    y0_not_nan = y0_flat.masked_select(~torch.isnan(y0_flat))
    ys = (ys - y0_not_nan.mean()) / y0_not_nan.std()

    ###################
    # As discussed, time must be included as a channel for the discriminator.
    ###################
    ys = torch.cat([ts.unsqueeze(0).unsqueeze(-1).expand(dataset_size, t_size, 1),
                    ys.transpose(0, 1)], dim=2)  # shape (dataset_size, t_size, 1 + data_size)

    ###################
    # Package up.
    ###################
    data_size = ys.size(-1) - 1  # How many channels the data has (not including time, hence the minus one).
    ys_coeffs = torchcde.linear_interpolation_coeffs(ys)  # as per neural CDEs.
    dataset = torch.utils.data.TensorDataset(ys_coeffs)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return ts, data_size, dataloader

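# A hypothetical smoke test for `get_data`; the batch size here is illustrative,
# and simulating all 8192 paths takes a few seconds on CPU.
ts, data_size, dataloader = get_data(batch_size=32, device='cpu')
(batch,) = next(iter(dataloader))
assert batch.shape == (32, ts.size(0), 1 + data_size)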