Code Example #1
File: brownian.py Project: xanderyin/torchsde
def _time_sdeint_adjoint(sde, y0, ts, bm):
    # Time one forward solve plus the backward pass through the adjoint.
    now = time.perf_counter()
    sde.zero_grad()
    y0 = y0.clone().requires_grad_(True)
    ys = torchsde.sdeint_adjoint(sde, y0, ts, bm, method='euler')
    ys.sum().backward()
    return time.perf_counter() - now
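
Every example on this page assumes an SDE object exposing a drift method f, a diffusion method g, and the class attributes noise_type and sde_type. Below is a minimal sketch of how the timing helper above could be driven; MinimalSDE and all sizes here are illustrative assumptions, not part of the original file:

import time

import torch
import torchsde


class MinimalSDE(torch.nn.Module):
    noise_type = 'diagonal'  # required by torchsde
    sde_type = 'ito'         # required by torchsde

    def __init__(self, d):
        super().__init__()
        self.theta = torch.nn.Parameter(0.1 * torch.randn(d, d))

    def f(self, t, y):
        # Drift: (batch, d) -> (batch, d).
        return y @ self.theta

    def g(self, t, y):
        # Diagonal diffusion: (batch, d) -> (batch, d).
        return 0.1 * torch.sigmoid(y)


batch_size, d = 32, 4
sde = MinimalSDE(d)
y0 = torch.full((batch_size, d), 0.1)
ts = torch.linspace(0., 1., 10)
bm = torchsde.BrownianInterval(t0=0., t1=1., size=(batch_size, d))
print(f"forward+backward: {_time_sdeint_adjoint(sde, y0, ts, bm):.3f}s")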
Code Example #2
File: neuralsde.py Project: ucalyptus/torchdyn
def sdeint(sde,
           x0,
           s_span,
           bm=None,
           logqp=False,
           method='srk',
           adaptive=False,
           rtol=1e-6,
           atol=1e-5,
           dt=1e-2,
           dt_min=1e-4,
           options=None,
           names=None):
    """
    s_span -> ts:   ts (Tensor or sequence of float): Query times in non-descending order.
                    The state at the first time of `ts` should be `y0`.
    x0 -> y0:       y0 (sequence of Tensor): Tensors for initial state.
    """

    return sdeint_adjoint(sde=sde,
                          y0=x0,
                          ts=s_span,
                          bm=bm,
                          logqp=logqp,
                          method=method,
                          dt=dt,
                          adaptive=adaptive,
                          rtol=rtol,
                          atol=atol,
                          dt_min=dt_min,
                          options=options,
                          names=names)
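
For reference, this wrapper is called exactly like torchsde's own solver apart from the renamed arguments; a hypothetical invocation (the sde object and sizes are made up):

s_span = torch.linspace(0., 1., 20)
x0 = torch.zeros(16, 2)
xs = sdeint(sde, x0, s_span, method='euler', dt=1e-2)  # xs has shape (20, 16, 2)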
Code Example #3
    def forward(self, xs, ts, noise_std, adjoint=False, method="euler"):
        # Contextualization is only needed for posterior inference.
        ctx = self.encoder(torch.flip(xs, dims=(0, )))
        ctx = torch.flip(ctx, dims=(0, ))
        self.contextualize((ts, ctx))

        if adjoint:
            # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`.
            adjoint_params = ((ctx, ) + tuple(self.f_net.parameters()) +
                              tuple(self.g_nets.parameters()) +
                              tuple(self.h_net.parameters()))
            _xs, log_ratio = torchsde.sdeint_adjoint(
                self,
                xs[0],
                ts,
                adjoint_params=adjoint_params,
                dt=1e-2,
                logqp=True,
                method=method)
        else:
            _xs, log_ratio = torchsde.sdeint(self,
                                             xs[0],
                                             ts,
                                             dt=1e-2,
                                             logqp=True,
                                             method=method)

        xs_dist = Normal(loc=_xs, scale=noise_std)
        log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean()
        log_ratio = log_ratio.sum(dim=0).mean()
        return log_pxs, log_ratio
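
During training the two returned terms are typically folded into a negative evidence lower bound. A hedged sketch (model, xs, ts and kl_coeff are hypothetical names, not from the original file):

log_pxs, log_ratio = model(xs, ts, noise_std=0.01)
loss = -log_pxs + kl_coeff * log_ratio  # -(reconstruction log-likelihood - weighted KL)
loss.backward()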
Code Example #4
File: sde_gan.py Project: shi27feng/torchsde
    def forward(self, ts, batch_size):
        # ts has shape (t_size,) and corresponds to the points we want to evaluate the SDE at.

        ###################
        # Actually solve the SDE.
        ###################
        init_noise = torch.randn(batch_size,
                                 self._initial_noise_size,
                                 device=ts.device)
        x0 = self._initial(init_noise)

        ###################
        # We use the reversible Heun method to get accurate gradients whilst using the adjoint method.
        ###################
        xs = torchsde.sdeint_adjoint(
            self._func,
            x0,
            ts,
            method='reversible_heun',
            dt=1.0,
            adjoint_method='adjoint_reversible_heun',
        )
        xs = xs.transpose(0, 1)
        ys = self._readout(xs)

        ###################
        # Normalise the data to the form that the discriminator expects, in particular including time as a channel.
        ###################
        ts = ts.unsqueeze(0).unsqueeze(-1).expand(batch_size, ts.size(0), 1)
        return torchcde.linear_interpolation_coeffs(torch.cat([ts, ys], dim=2))
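
Note that torchsde pairs method='reversible_heun' with adjoint_method='adjoint_reversible_heun' so that the backward pass can exploit the solver's algebraic reversibility. A hypothetical call into this generator (generator and the time grid are illustrative assumptions):

ts = torch.linspace(0., 63., 64)
fake = generator(ts, batch_size=128)  # interpolation coefficients, ready for a CDE discriminator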
Code Example #5
def func(inputs, modules):
    y0, sde = inputs[0], modules[0]
    ys = torchsde.sdeint_adjoint(sde,
                                 y0,
                                 ts,
                                 bm,
                                 dt=dt,
                                 method=method,
                                 adaptive=adaptive)
    return (ys[-1]**2).sum(dim=1).mean(dim=0)
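
Examples #5 through #11 are helpers nested inside test files and close over module-level fixtures. A plausible set of definitions, with guessed values, so the free names (ts, bm, dt, method, adaptive, y0, w0, device) can be read in context:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d, batch_size = 10, 128
t0, t1 = 0.0, 0.5
ts = torch.tensor([t0, t1], device=device)
dt = 1e-3
y0 = torch.full((batch_size, d), 0.1, device=device)
w0 = torch.zeros(batch_size, d, device=device)  # Brownian motion value at t0
bm = torchsde.BrownianPath(t0=t0, w0=w0)
method, adaptive = 'euler', False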
Code Example #6
File: neuralde.py Project: willw625731/torchdyn
    def _adjoint(self, x):
        out = torchsde.sdeint_adjoint(self.defunc,
                                      x,
                                      self.s_span,
                                      rtol=self.rtol,
                                      atol=self.atol,
                                      adaptive=self.adaptive,
                                      method=self.solver,
                                      dt=self.ds)[-1]
        return out
Code Example #7
def func(x):
    ys_and_logqp = sdeint_adjoint(sde,
                                  x,
                                  ts,
                                  bm,
                                  logqp=True,
                                  method=method,
                                  dt=dt,
                                  adaptive=adaptive)
    ys, logqp = ys_and_logqp
    # Just another arbitrarily chosen function with two outputs.
    return torch.stack([(ys**2.).sum(), (logqp / 3.).sum()], dim=0)
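
Helpers like this one and the one in example #5 read like the inner functions of gradient tests; a plausible harness is torch.autograd.gradcheck, which compares analytical gradients against finite differences (the double-precision input is an assumption, and the sde and bm would also need to be float64 for the check to pass):

y0_check = torch.full((batch_size, d), 0.1, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(func, (y0_check,), rtol=1e-4, atol=1e-3)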
Code Example #8
def _test_forward(sde, bm, method, adaptive=False, rtol=1e-6, atol=1e-5):
    sde.zero_grad()
    ys, log_ratio = sdeint_adjoint(sde,
                                   y0,
                                   ts,
                                   bm,
                                   logqp=True,
                                   method=method,
                                   dt=dt,
                                   adaptive=adaptive,
                                   rtol=rtol,
                                   atol=atol)
    loss = ys.sum(0).mean(0).sum(0) + log_ratio.sum(0).mean(0)
    loss.backward()
Code Example #9
File: test_adjoint.py Project: wtwong316/torchsde
def test_basic(problem, method, adaptive):
    d = 10
    batch_size = 128
    ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.zeros(batch_size, d).to(device).fill_(0.1)

    problem = problem(d).to(device)

    num_before = _count_differentiable_params(problem)

    problem.zero_grad()
    _, yt = torchsde.sdeint_adjoint(problem, y0, ts, method=method, dt=dt, adaptive=adaptive)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()

    num_after = _count_differentiable_params(problem)
    assert num_before == num_after
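
_count_differentiable_params is a helper defined elsewhere in the test file; a plausible definition, consistent with how it is used here:

def _count_differentiable_params(module):
    # Count parameters that participate in autograd.
    return len([p for p in module.parameters() if p.requires_grad])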
Code Example #10
File: test_adjoint.py Project: stjordanis/torchsde
    def _test_basic(self, problem, method, adaptive, rtol=1e-6, atol=1e-5):
        if method == 'euler' and adaptive:
            return

        nbefore = _count_differentiable_params(problem)

        problem.zero_grad()
        _, yt = sdeint_adjoint(problem,
                               y0,
                               ts,
                               method=method,
                               dt=dt,
                               adaptive=adaptive,
                               rtol=rtol,
                               atol=atol)
        loss = yt.sum(dim=1).mean(dim=0)
        loss.backward()

        nafter = _count_differentiable_params(problem)
        self.assertEqual(nbefore, nafter)
Code Example #11
File: test_adjoint.py Project: stjordanis/torchsde
    def _test_gradient(self, problem, method, adaptive, rtol=1e-6, atol=1e-5):
        if method == 'euler' and adaptive:
            return

        bm = BrownianPath(t0=t0, w0=w0)
        with torch.no_grad():
            grad_outputs = torch.ones(batch_size, d).to(device)
            alt_grad = problem.analytical_grad(y0, t1, grad_outputs, bm)

        problem.zero_grad()
        _, yt = sdeint_adjoint(problem,
                               y0,
                               ts,
                               bm=bm,
                               method=method,
                               dt=dt,
                               adaptive=adaptive,
                               rtol=rtol,
                               atol=atol)
        loss = yt.sum(dim=1).mean(dim=0)
        loss.backward()
        adj_grad = torch.cat(tuple(p.grad for p in problem.parameters()))
        self.tensorAssertAllClose(alt_grad, adj_grad)
Code Example #12
File: latent_sde.py Project: stjordanis/torchsde
    def forward(self, ts, batch_size, eps=None):
        eps = torch.randn(batch_size, 1).to(
            self.qy0_std) if eps is None else eps
        y0 = self.qy0_mean + eps * self.qy0_std

        qy0 = Normal(loc=self.qy0_mean, scale=self.qy0_std)
        py0 = Normal(loc=self.py0_mean, scale=self.py0_std)
        logqp0 = kl_divergence(qy0, py0).sum(1).mean(0)  # KL(time=0).

        # `trapezoidal_approx` is for SRK. Disabling it gives better performance.
        if args.adjoint:
            zs, logqp = sdeint_adjoint(self,
                                       y0,
                                       ts,
                                       logqp=True,
                                       method=args.method,
                                       dt=args.dt,
                                       adaptive=args.adaptive,
                                       rtol=args.rtol,
                                       atol=args.atol,
                                       options={'trapezoidal_approx': False})
        else:
            zs, logqp = sdeint(self,
                               y0,
                               ts,
                               logqp=True,
                               method=args.method,
                               dt=args.dt,
                               adaptive=args.adaptive,
                               rtol=args.rtol,
                               atol=args.atol,
                               options={'trapezoidal_approx': False})
        logqp = logqp.sum(0).mean(0)
        log_ratio = logqp0 + logqp  # KL(time=0) + KL(path).

        return zs, log_ratio
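
With logqp=True both solvers return a pair: the sample path of shape (len(ts), batch_size, d) plus KL path increments of shape (len(ts) - 1, batch_size), which is why the code sums over dimension 0 before averaging. A quick shape check (latent_sde and latent_dim are illustrative names):

zs, logqp = sdeint(latent_sde, y0, ts, logqp=True, method='euler', dt=1e-2)
assert zs.shape == (len(ts), batch_size, latent_dim)
assert logqp.shape == (len(ts) - 1, batch_size)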
Code Example #13
def test_against_sdeint(sde_cls, sde_type, method, options, dt, rtol, atol,
                        len_ts):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (
            METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    ts = torch.linspace(0.0, 1.0, len_ts, device=device, dtype=torch.float64)
    t0 = ts[0]
    t1 = ts[-1]
    y0 = torch.full((batch_size, d),
                    0.1,
                    device=device,
                    dtype=torch.float64,
                    requires_grad=True)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device, torch.float64)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=(batch_size, m),
        dtype=torch.float64,
        device=device,
        levy_area_approximation=levy_area_approximation)

    if method == METHODS.reversible_heun:
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        adjoint_method = None
        adjoint_options = None

    ys_true = torchsde.sdeint(sde,
                              y0,
                              ts,
                              dt=dt,
                              method=method,
                              bm=bm,
                              options=options)
    grad = torch.randn_like(ys_true)
    ys_true.backward(grad)

    true_grad = torch.cat([y0.grad.view(-1)] +
                          [param.grad.view(-1) for param in sde.parameters()])
    y0.grad.zero_()
    for param in sde.parameters():
        param.grad.zero_()

    ys_test = torchsde.sdeint_adjoint(sde,
                                      y0,
                                      ts,
                                      dt=dt,
                                      method=method,
                                      bm=bm,
                                      adjoint_method=adjoint_method,
                                      options=options,
                                      adjoint_options=adjoint_options)
    ys_test.backward(grad)
    test_grad = torch.cat([y0.grad.view(-1)] +
                          [param.grad.view(-1) for param in sde.parameters()])

    torch.testing.assert_allclose(ys_true, ys_test)
    torch.testing.assert_allclose(true_grad, test_grad, rtol=rtol, atol=atol)
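
NOISE_TYPES, METHODS and LEVY_AREA_APPROXIMATIONS come from torchsde.settings and are thin containers of string constants, so the BrownianInterval construction above is equivalent to the plain-string form (a sketch, not the test's own code):

bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(batch_size, m),
                               dtype=torch.float64, device=device,
                               levy_area_approximation='space-time')  # or 'none' for non-SRK methods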