Example #1
def test_reversibility(sde_cls):
    batch_size = 32
    state_size = 4
    t_size = 20
    dt = 0.1

    brownian_size = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: state_size,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]

    class MinusSDE(torch.nn.Module):
        def __init__(self, sde):
            super().__init__()
            self.noise_type = sde.noise_type
            self.sde_type = sde.sde_type
            self.f = lambda t, y: -sde.f(-t, y)
            self.g = lambda t, y: -sde.g(-t, y)

    sde = sde_cls(d=state_size, m=brownian_size, sde_type='stratonovich')
    minus_sde = MinusSDE(sde)
    y0 = torch.full((batch_size, state_size), 0.1)
    ts = torch.linspace(0, (t_size - 1) * dt, t_size)
    bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(batch_size, brownian_size))
    ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=dt, extra=True)
    backward_ts = -ts.flip(0)
    backward_ys = torchsde.sdeint(minus_sde, ys[-1], backward_ts, bm=torchsde.ReverseBrownian(bm),
                                  method='reversible_heun', dt=dt, extra_solver_state=(-f, -g, z))
    backward_ys = backward_ys.flip(0)

    torch.testing.assert_allclose(ys, backward_ys, rtol=1e-6, atol=1e-6)
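The test above relies on a pytest fixture supplying `sde_cls`, plus torchsde's `NOISE_TYPES` constants. A minimal sketch of a class satisfying the `sde_cls(d=..., m=..., sde_type=...)` constructor it calls, with all names hypothetical:

import torch

class DiagonalSDE(torch.nn.Module):
    # Hypothetical fixture: diagonal noise, so the Brownian size m equals the state size d.
    noise_type = 'diagonal'

    def __init__(self, d, m, sde_type='stratonovich'):
        super().__init__()
        self.sde_type = sde_type
        self.linear = torch.nn.Linear(d, d)

    def f(self, t, y):
        return self.linear(y)

    def g(self, t, y):
        # Diagonal noise: the diffusion has the same shape as y, i.e. (batch_size, d).
        return 0.2 * torch.tanh(y)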
Example #2
def _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail, options):
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    if adaptive and method == 'euler' and sde.noise_type != 'additive':
        ctx = pytest.warns(UserWarning)
    else:
        ctx = _nullcontext()

    # Using `f` as drift.
    with torch.no_grad():
        try:
            with ctx:
                ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, logqp=logqp,
                                      options=options)
        except ValueError:
            if should_fail:
                return
            raise
        else:
            if should_fail:
                pytest.fail("Expected an error; did not get one.")
    if logqp:
        ans, log_ratio = ans
        assert log_ratio.shape == (T - 1, batch_size)
    assert ans.shape == (T, batch_size, d)

    # Using `h` as drift.
    with torch.no_grad():
        with ctx:
            ans = torchsde.sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, names={'drift': 'h'},
                                  logqp=logqp, options=options)
    if logqp:
        ans, log_ratio = ans
        assert log_ratio.shape == (T - 1, batch_size)
    assert ans.shape == (T, batch_size, d)
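The `_nullcontext` helper above is defined elsewhere in the test suite; a plausible definition is the usual backport of `contextlib.nullcontext` (available under that name on Python >= 3.7):

import contextlib

@contextlib.contextmanager
def _nullcontext():
    # No-op context manager, used when no warning is expected.
    yield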
Example #3
def inspect_strong_order():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2**-i for i in range(1, 9))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))  # It's important to have the correct size!

        for dt in tqdm.tqdm(dts):
            # Only take end value.
            _, ys_euler = sdeint(sde,
                                 y0=y0,
                                 ts=ts,
                                 dt=dt,
                                 bm=bm,
                                 method='euler')
            _, ys_srk = sdeint(sde,
                               y0=y0,
                               ts=ts,
                               dt=dt,
                               bm=bm,
                               method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)

            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)
    del euler_mse_, srk_mse_

    # Halve the log-MSE: MSE scales like dt^(2p), so the regression slope then recovers the textbook strong order p.
    log = lambda x: np.log(np.array(x))
    euler_slope, _, _, _, _ = stats.linregress(log(dts), log(euler_mses_) / 2)
    srk_slope, _, _, _, _ = stats.linregress(log(dts), log(srk_mses_) / 2)

    plt.figure()
    plt.plot(dts, euler_mses_, label=f'euler(k={euler_slope:.4f})')
    plt.plot(dts, srk_mses_, label=f'srk(k={srk_slope:.4f})')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
Example #4
    def _test_sdeint(self, sde, bm, adaptive, method, dt):
        # Using `f` as drift.
        with torch.no_grad():
            ans = sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive)
        self.assertEqual(ans.shape, (T, batch_size, d))

        # Using `h` as drift.
        with torch.no_grad():
            ans = sdeint(sde, y0, ts, bm, method=method, dt=dt, adaptive=adaptive, names={'drift': 'h'})
        self.assertEqual(ans.shape, (T, batch_size, d))
Example #5
    def test_rename_methods(self):
        # Test renaming works with a subset of names when `logqp=False`.
        sde = basic_sde.CustomNamesSDE().to(device)
        ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
        self.assertEqual(ans.shape, (T, batch_size, d))

        # Test renaming works with a subset of names when `logqp=True`.
        sde = basic_sde.CustomNamesSDELogqp().to(device)
        ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
        self.assertEqual(ans[0].shape, (T, batch_size, d))
        self.assertEqual(ans[1].shape, (T - 1, batch_size))
Example #6
    def test_srk_determinism(self):
        # srk for additive.
        sde = basic_sde.AdditiveSDE(d=d, m=m).to(device)
        ys1 = sdeint(sde, y0, ts, bm=bm_general, adaptive=False, method='srk', dt=dt)
        ys2 = sdeint(sde, y0, ts, bm=bm_general, adaptive=False, method='srk', dt=dt)
        self.tensorAssertAllClose(ys1, ys2)

        # srk for diagonal.
        sde = basic_sde.BasicSDE1(d=d).to(device)
        ys1 = sdeint(sde, y0, ts, bm=bm_diagonal, adaptive=False, method='srk', dt=dt)
        ys2 = sdeint(sde, y0, ts, bm=bm_diagonal, adaptive=False, method='srk', dt=dt)
        self.tensorAssertAllClose(ys1, ys2)
Example #7
def inspect_rate():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2**-i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))  # It's important to have the correct size!

        for dt in dts:
            # Only take end value.
            _, ys_euler = sdeint(sde,
                                 y0=y0,
                                 ts=ts,
                                 dt=dt,
                                 bm=bm,
                                 method='euler')
            _, ys_srk = sdeint(sde,
                               y0=y0,
                               ts=ts,
                               dt=dt,
                               bm=bm,
                               method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)

            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
Example #8
 def sample(self, x0, ts, noise_std, normalize):
     """Sample data for training. Store data normalization constants if necessary."""
     xs = torchsde.sdeint(self, x0, ts)
     if normalize:
         mean, std = torch.mean(xs, dim=(0, 1)), torch.std(xs, dim=(0, 1))
         xs.sub_(mean).div_(std).add_(torch.randn_like(xs) * noise_std)
     return xs
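The docstring above promises to store normalization constants, but the body discards them. A minimal variant that keeps them (assuming the method lives on a torch.nn.Module; attribute names hypothetical):

def sample(self, x0, ts, noise_std, normalize):
    """Sample data for training; store normalization constants when normalizing."""
    xs = torchsde.sdeint(self, x0, ts)
    if normalize:
        mean, std = torch.mean(xs, dim=(0, 1)), torch.std(xs, dim=(0, 1))
        self.data_mean, self.data_std = mean, std  # hypothetical attribute names
        xs.sub_(mean).div_(std).add_(torch.randn_like(xs) * noise_std)
    return xs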
Example #9
    def forward(self, xs, ts, noise_std, adjoint=False, method="euler"):
        # Contextualization is only needed for posterior inference.
        ctx = self.encoder(torch.flip(xs, dims=(0, )))
        ctx = torch.flip(ctx, dims=(0, ))
        self.contextualize((ts, ctx))

        if adjoint:
            # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
            adjoint_params = ((ctx, ) + tuple(self.f_net.parameters()) +
                              tuple(self.g_nets.parameters()) +
                              tuple(self.h_net.parameters()))
            _xs, log_ratio = torchsde.sdeint_adjoint(
                self,
                xs[0],
                ts,
                adjoint_params=adjoint_params,
                dt=1e-2,
                logqp=True,
                method=method)
        else:
            _xs, log_ratio = torchsde.sdeint(self,
                                             xs[0],
                                             ts,
                                             dt=1e-2,
                                             logqp=True,
                                             method=method)

        xs_dist = Normal(loc=_xs, scale=noise_std)
        log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean()
        log_ratio = log_ratio.sum(dim=0).mean()
        return log_pxs, log_ratio
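A sketch of how the returned pair typically feeds a latent-SDE training step; `model`, `optimizer`, `xs`, and `ts` are hypothetical names:

log_pxs, log_ratio = model(xs, ts, noise_std=0.01)
loss = -(log_pxs - log_ratio)  # negative ELBO: reconstruction term minus KL term.
optimizer.zero_grad()
loss.backward()
optimizer.step()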
Example #10
 def test_sdeint_tuplesde(self):
     y0_ = (y0,)  # Make tuple input.
     sde = basic_sde.TupleSDE(d=d).to(device)
     bm = lambda t: (bm_diagonal(t),)
     with torch.no_grad():
         ans = sdeint(sde, y0_, ts, bm, method='euler', dt=dt)
         self.assertTrue(isinstance(ans, tuple))
Example #11
def time_func():
    ys = sdeint(geometric_bm,
                y0,
                ts,
                adaptive=False,
                dt=ts[1],
                options={'trapezoidal_approx': True})
Example #12
 def sample(self, x0, ts, bm=None):
     return torchsde.sdeint(self,
                            x0,
                            ts,
                            names={'drift': 'h'},
                            dt=1e-3,
                            bm=bm)
Example #13
def _time_sdeint_bp(sde, y0, ts, bm):
    now = time.perf_counter()
    sde.zero_grad()
    y0 = y0.clone().requires_grad_(True)
    ys = torchsde.sdeint(sde, y0, ts, bm, method='euler')
    ys.sum().backward()
    return time.perf_counter() - now
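A companion sketch timing backpropagation through `torchsde.sdeint_adjoint` instead (it takes the same leading arguments, assuming `sde` is a torch.nn.Module):

def _time_sdeint_adjoint_bp(sde, y0, ts, bm):
    now = time.perf_counter()
    sde.zero_grad()
    y0 = y0.clone().requires_grad_(True)
    ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, method='euler')
    ys.sum().backward()
    return time.perf_counter() - now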
Example #14
def test_rename_methods(device):
    """Test renaming works with a subset of names."""
    sde = problems.CustomNamesSDE().to(device)
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
    assert ans.shape == (T, batch_size, d)
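A plausible stand-in for `problems.CustomNamesSDE` (body hypothetical): the drift lives on a method named `forward`, which `names={'drift': 'forward'}` maps onto the role of `f`:

class CustomNamesSDE(torch.nn.Module):
    sde_type = 'ito'
    noise_type = 'diagonal'

    def forward(self, t, y):  # drift, selected via names={'drift': 'forward'}
        return y * t

    def g(self, t, y):  # diffusion, under its default name
        return torch.sigmoid(t) * y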
Example #15
def test_rename_methods_logqp(device):
    """Test renaming works with a subset of names when `logqp=True`."""
    sde = problems.CustomNamesSDELogqp().to(device)
    y0 = torch.ones(batch_size, d, device=device)
    ts = torch.linspace(t0, t1, steps=T, device=device)
    ans = torchsde.sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
    assert ans[0].shape == (T, batch_size, d)
    assert ans[1].shape == (T - 1, batch_size)
Example #16
def inspect_rate():
    batch_size, d = 4096, 10
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)

    euler_mses_ = []
    milstein_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, 1).to(device))

        for dt in dts:
            # Only take end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk')
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            milstein_mse = compute_mse(ys_milstein, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, milstein_mse_, srk_mse_ = to_numpy(euler_mse, milstein_mse, srk_mse)

            euler_mses_.append(euler_mse_)
            milstein_mses_.append(milstein_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, milstein_mses_, label='milstein')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
Example #17
 def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
     x = self._prep_sdeint(x)
     sol = torchsde.sdeint(self.defunc,
                           x,
                           s_span,
                           rtol=self.rtol,
                           atol=self.atol,
                           method=self.solver,
                           dt=self.ds)
     return sol
Example #18
 def _autograd(self, x):
     self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
     return torchsde.sdeint(self.defunc,
                            x,
                            self.s_span,
                            rtol=self.rtol,
                            atol=self.atol,
                            adaptive=self.adaptive,
                            method=self.solver,
                            dt=self.ds)[-1]
Example #19
 def sample_p(self, ts, batch_size, eps=None, bm=None):
     eps = torch.randn(batch_size, 1).to(self.py0_mean) if eps is None else eps
     y0 = self.py0_mean + eps * self.py0_std
     return sdeint(self,
                   y0,
                   ts,
                   bm=bm,
                   method='srk',
                   dt=args.dt,
                   names={'drift': 'h'})
Example #20
 def sample_q(self, vis_span, n_sim, eps=None, bm=None, dt=0.01):
     """
     :param vis_span:
     :param n_sim:
     :param eps:
     :param bm:
     :param dt:
     """
     eps = torch.randn(n_sim, 1).to(self.qy0_mean) if eps is None else eps
     y0 = self.qy0_mean + eps.to(self.device) * self.qy0_std
     return torchsde.sdeint(self.defunc, y0, vis_span, bm=bm, method='srk', dt=dt)
Example #21
def inspect_samples():
    batch_size, d, m = 32, 1, 5
    steps = 10

    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 3e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = AdditiveSDE(d=d, m=m).to(device)

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))
        ys_em = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_srk = sdeint(sde,
                        y0=y0,
                        ts=ts,
                        dt=dt,
                        bm=bm,
                        method='srk',
                        options={'trapezoidal_approx': False})
        ys_true = sdeint(sde, y0=y0, ts=ts, dt=1e-3, bm=bm, method='euler')

        ys_em = ys_em.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_true = ys_true.squeeze().t()

        ts_, ys_em_, ys_srk_, ys_true_ = to_numpy(ts, ys_em, ys_srk, ys_true)

    # Visualize sample path.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)

    for i, (ys_em_i, ys_srk_i, ys_true_i) in enumerate(zip(ys_em_, ys_srk_, ys_true_)):
        plt.figure()
        plt.plot(ts_, ys_em_i, label='em')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_true_i, label='true')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()
Example #22
def inspect_samples(y0: Tensor,
                    ts: Vector,
                    dt: Scalar,
                    sde: BaseSDE,
                    bm: BaseBrownian,
                    img_dir: str,
                    methods: Tuple[str, ...],
                    options: Optional[Tuple] = None,
                    vis_dim=0,
                    dt_true: Optional[float] = 2**-14,
                    labels: Optional[Tuple[str, ...]] = None):
    if options is None:
        options = (None, ) * len(methods)
    if labels is None:
        labels = methods

    sde = copy.deepcopy(sde).requires_grad_(False)

    solns = [
        sdeint(sde, y0, ts, bm, method=method, dt=dt, options=options_)
        for method, options_ in zip(methods, options)
    ]

    method_for_true = 'euler' if sde.sde_type == SDE_TYPES.ito else 'midpoint'
    true = sdeint(sde, y0, ts, bm, method=method_for_true, dt=dt_true)

    labels += ('true', )
    solns += [true]

    # (T, batch_size, d) -> (T, batch_size) -> (batch_size, T).
    solns = [soln[..., vis_dim].t() for soln in solns]

    for i, samples in enumerate(zip(*solns)):
        utils.swiss_knife_plotter(img_path=os.path.join(img_dir, f'{i}'),
                                  plots=[{
                                      'x': ts,
                                      'y': sample,
                                      'label': label,
                                      'marker': 'x'
                                  } for sample, label in zip(samples, labels)])
Example #23
def inspect_sample():
    batch_size, d = 32, 1
    steps = 100

    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 1e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)
    sde.noise_type = "scalar"

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros_like(y0))
        ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
        ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk', options={'trapezoidal_approx': False})
        ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

        ys_euler = ys_euler.squeeze().t()
        ys_milstein = ys_milstein.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_analytical = ys_analytical.squeeze().t()

        ts_, ys_euler_, ys_milstein_, ys_srk_, ys_analytical_ = to_numpy(
            ts, ys_euler, ys_milstein, ys_srk, ys_analytical)

    # Visualize sample path.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)

    for i, (ys_euler_i, ys_milstein_i, ys_srk_i, ys_analytical_i) in enumerate(
            zip(ys_euler_, ys_milstein_, ys_srk_, ys_analytical_)):
        plt.figure()
        plt.plot(ts_, ys_euler_i, label='euler')
        plt.plot(ts_, ys_milstein_i, label='milstein')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_analytical_i, label='analytical')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()
Example #24
    def sde_sample(self,
                   batch_size=64,
                   tau=1.,
                   t=None,
                   y=None,
                   dt=1e-2,
                   tweedie_correction=True):
        self.module.eval()

        t = torch.tensor([-self.t1, -self.t0], device=self.device) if t is None else t
        y = self.sample_t1_marginal(batch_size, tau) if y is None else y

        ys = torchsde.sdeint(self, y.flatten(start_dim=1), t, dt=dt)
        ys = ys.view(len(t), *y.size())
        if tweedie_correction:
            ys[-1] = self.tweedie_correction(self.t0, ys[-1], dt)
        return ys
Example #25
def test_specialised_functions(sde_type, method):
    vector = torch.randn(m)
    fg = problems.FGSDE(sde_type, vector)
    f_and_g = problems.FAndGSDE(sde_type, vector)
    g_prod = problems.GProdSDE(sde_type, vector)
    f_and_g_prod = problems.FAndGProdSDE(sde_type, vector)
    f_and_g_with_g_prod1 = problems.FAndGGProdSDE1(sde_type, vector)
    f_and_g_with_g_prod2 = problems.FAndGGProdSDE2(sde_type, vector)

    y0 = torch.randn(batch_size, d)

    outs = []
    for sde in (fg, f_and_g, g_prod, f_and_g_prod, f_and_g_with_g_prod1, f_and_g_with_g_prod2):
        bm = torchsde.BrownianInterval(t0, t1, (batch_size, m), entropy=45678)
        outs.append(torchsde.sdeint(sde, y0, [t0, t1], dt=dt, bm=bm)[1])
    for o in outs[1:]:
        # Exact floating-point equality: every variant should perform exactly the same computations.
        assert o.shape == outs[0].shape
        assert (o == outs[0]).all()
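torchsde also accepts fused methods such as `f_and_g` (and `g_prod`, `f_and_g_prod`) in place of separate `f` and `g`. A hypothetical stand-in for one of the `problems.*` classes above:

class FAndGSDE(torch.nn.Module):
    noise_type = 'general'

    def __init__(self, sde_type, vector):
        super().__init__()
        self.sde_type = sde_type
        self.register_buffer('vector', vector)  # shape (m,)

    def f_and_g(self, t, y):
        # Returns the drift, shape (batch_size, d), and the diffusion, shape (batch_size, d, m), in one call.
        return -y, y.unsqueeze(-1).sigmoid() * self.vector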
Example #26
    def sample_q(self, vis_span, n_sim, eps=None, bm=None, dt=0.01):
        """[summary]

        Args:
            vis_span ([type]): [description]
            n_sim ([type]): [description]
            eps ([type], optional): [description]. Defaults to None.
            bm ([type], optional): [description]. Defaults to None.
            dt (float, optional): [description]. Defaults to 0.01.

        Returns:
            [type]: [description]
        """
        eps = torch.randn(n_sim, 1).to(self.qy0_mean) if eps is None else eps
        y0 = self.qy0_mean + eps.to(self.device) * self.qy0_std
        return torchsde.sdeint(self.defunc,
                               y0,
                               vis_span,
                               bm=bm,
                               method='srk',
                               dt=dt)
Example #27
    def forward(self, ts, batch_size):
        # ts has shape (t_size,) and corresponds to the points we want to evaluate the SDE at.

        ###################
        # Actually solve the SDE.
        ###################
        init_noise = torch.randn(batch_size, self._initial_noise_size, device=ts.device)
        x0 = self._initial(init_noise)
        xs = torchsde.sdeint(self._func, x0, ts, method='midpoint',
                             dt=1.0)  # shape (t_size, batch_size, hidden_size)
        xs = xs.transpose(0, 1)  # switch t_size and batch_size
        ys = self._readout(xs)

        ###################
        # Normalise the data to the form that the discriminator expects, in particular including time as a channel.
        ###################
        t_size = ts.size(0)
        ts = ts.unsqueeze(0).unsqueeze(-1).expand(batch_size, t_size, 1)
        return torchcde.linear_interpolation_coeffs(torch.cat([ts, ys], dim=2))
Example #28
    def forward(self, eps: torch.Tensor, s_span=None):
        """[summary]

        Args:
            eps (torch.Tensor): [description]
            s_span ([type], optional): [description]. Defaults to None.

        Returns:
            [type]: [description]
        """

        eps = eps.to(self.qy0_std)
        x0 = self.qy0_mean + eps * self.qy0_std

        qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std)
        py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
        logqp0 = kl_divergence(qy0, py0).sum(1).mean(0)  # KL(time=0).

        if s_span is not None:
            s_span_ext = s_span
        else:
            s_span_ext = self.s_span.cpu()

        zs, logqp = torchsde.sdeint(sde=self.defunc,
                                    y0=x0,
                                    ts=s_span_ext,
                                    rtol=self.rtol,
                                    atol=self.atol,
                                    logqp=True,
                                    options=self.options,
                                    adaptive=self.adaptive,
                                    method=self.solver)

        logqp = logqp.sum(0).mean(0)
        log_ratio = logqp0 + logqp  # KL(time=0) + KL(path).

        return zs, log_ratio
Example #29
    def forward(self, ts, batch_size, eps=None):
        eps = torch.randn(batch_size, 1).to(self.qy0_std) if eps is None else eps
        y0 = self.qy0_mean + eps * self.qy0_std

        qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std)
        py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
        logqp0 = kl_divergence(qy0, py0).sum(1).mean(0)  # KL(time=0).

        # `trapezoidal_approx` is for SRK. Disabling it gives better performance.
        if args.adjoint:
            zs, logqp = sdeint_adjoint(self,
                                       y0,
                                       ts,
                                       logqp=True,
                                       method=args.method,
                                       dt=args.dt,
                                       adaptive=args.adaptive,
                                       rtol=args.rtol,
                                       atol=args.atol,
                                       options={'trapezoidal_approx': False})
        else:
            zs, logqp = sdeint(self,
                               y0,
                               ts,
                               logqp=True,
                               method=args.method,
                               dt=args.dt,
                               adaptive=args.adaptive,
                               rtol=args.rtol,
                               atol=args.atol,
                               options={'trapezoidal_approx': False})
        logqp = logqp.sum(0).mean(0)
        log_ratio = logqp0 + logqp  # KL(time=0) + KL(path).

        return zs, log_ratio
Example #30
def get_data(batch_size, device):
    class OrnsteinUhlenbeckSDE(torch.nn.Module):
        sde_type = 'ito'
        noise_type = 'scalar'

        def __init__(self, mu, theta, sigma):
            super(OrnsteinUhlenbeckSDE, self).__init__()
            self.register_buffer('mu', torch.as_tensor(mu))
            self.register_buffer('theta', torch.as_tensor(theta))
            self.register_buffer('sigma', torch.as_tensor(sigma))

        def f(self, t, y):
            return self.mu * t - self.theta * y

        def g(self, t, y):
            return self.sigma.expand(y.size(0), 1, 1)

    dataset_size = 8192
    t_size = 64

    ou_sde = OrnsteinUhlenbeckSDE(mu=0.02, theta=0.1, sigma=0.4).to(device)
    y0 = torch.rand(dataset_size, device=device).unsqueeze(-1) * 2 - 1
    ts = torch.linspace(0, t_size - 1, t_size, device=device)
    ys = torchsde.sdeint(ou_sde, y0, ts, dt=1e-1)

    ###################
    # To demonstrate how to handle irregular data, we additionally drop some of the data (by setting it to NaN).
    ###################
    ys_num = ys.numel()
    to_drop = torch.randperm(ys_num)[:int(0.3 * ys_num)]
    ys.view(-1)[to_drop] = float('nan')

    ###################
    # Typically important to normalise data. Note that the data is normalised with respect to the statistics of the
    # initial data, _not_ the whole time series. This seems to help the learning process, presumably because if the
    # initial condition is wrong then it's pretty hard to learn the rest of the SDE correctly.
    ###################
    y0_flat = ys[0].view(-1)
    y0_not_nan = y0_flat.masked_select(~torch.isnan(y0_flat))
    ys = (ys - y0_not_nan.mean()) / y0_not_nan.std()

    ###################
    # As discussed, time must be included as a channel for the discriminator.
    ###################
    ys = torch.cat([ts.unsqueeze(0).unsqueeze(-1).expand(dataset_size, t_size, 1),
                    ys.transpose(0, 1)], dim=2)
    # shape (dataset_size=8192, t_size=64, 1 + data_size=2)

    ###################
    # Package up.
    ###################
    data_size = ys.size(-1) - 1  # How many channels the data has (not including time, hence the minus one).
    ys_coeffs = torchcde.linear_interpolation_coeffs(ys)  # as per neural CDEs.
    dataset = torch.utils.data.TensorDataset(ys_coeffs)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

    return ts, data_size, dataloader
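A minimal usage sketch (values hypothetical) showing how the returned coefficients can be turned back into continuous paths, as per neural CDEs:

ts, data_size, dataloader = get_data(batch_size=32, device='cpu')
(batch_coeffs,) = next(iter(dataloader))
path = torchcde.LinearInterpolation(batch_coeffs)  # continuous path built from the coefficients
print(batch_coeffs.shape)  # (32, t_size, 1 + data_size)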