Example 1
    def forward(ctx, *args):
        assert len(args) >= 14, 'Internal error: all arguments required.'
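        # The leading positional args are the initial-state tensors y0; the last 13
        # are the SDE, time points, flattened parameters, and solver/adjoint settings.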
        y0 = args[:-13]
        (sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, rtol,
         atol, dt_min, options, adjoint_options) = args[-13:]
        (
            ctx.sde, ctx.dt, ctx.bm, ctx.adjoint_method, ctx.adaptive,
            ctx.rtol, ctx.atol, ctx.dt_min, ctx.adjoint_options
        ) = sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options

        # Wrap the user SDE as a forward Itô SDE and integrate under no_grad;
        # gradients are produced by the custom backward pass, not by autograd
        # tracing the solver.
        sde = base_sde.ForwardSDEIto(sde)
        with torch.no_grad():
            ans = sdeint.integrate(sde=sde,
                                   y0=y0,
                                   ts=ts,
                                   bm=bm,
                                   method=method,
                                   dt=dt,
                                   adaptive=adaptive,
                                   rtol=rtol,
                                   atol=atol,
                                   dt_min=dt_min,
                                   options=options)
        # Save the tensors needed by the backward pass; non-tensor settings were
        # already stashed on ctx above.
        ctx.save_for_backward(ts, flat_params, *ans)
        return ans
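
A minimal caller-side sketch of how this forward might be invoked, assuming the Function above is exposed as a class here called _SdeintAdjointMethod (an illustrative name, not necessarily the library's own) and that flat_params is simply the SDE's parameters flattened into one tensor; the argument order mirrors the args[-13:] unpacking above.

import torch

def sdeint_adjoint_sketch(sde, y0, ts, bm, method, dt, adjoint_method,
                          adaptive, rtol, atol, dt_min,
                          options=None, adjoint_options=None):
    # Assumption: flat_params packs all SDE parameters into one tensor so the
    # custom backward can return a single gradient buffer for them.
    flat_params = torch.cat([p.reshape(-1) for p in sde.parameters()])
    # y0 is a tuple of state tensors; it goes first, followed by the 13 settings
    # that forward() recovers with args[-13:], in exactly that order.
    return _SdeintAdjointMethod.apply(  # hypothetical name for the Function above
        *y0, sde, ts, flat_params, dt, bm, method, adjoint_method,
        adaptive, rtol, atol, dt_min, options, adjoint_options)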
Example 2
    def forward(ctx, *args):
        assert len(args) >= 14, 'Internal error: all arguments required.'
        y0 = args[:-13]
        (sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, rtol,
         atol, dt_min, options, adjoint_options) = args[-13:]
        (
            ctx.sde, ctx.dt, ctx.bm, ctx.adjoint_method, ctx.adaptive,
            ctx.rtol, ctx.atol, ctx.dt_min, ctx.adjoint_options
        ) = sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options

        sde = base_sde.ForwardSDEIto(sde)
        with torch.no_grad():
            ans_and_logqp = sdeint.integrate(sde=sde,
                                             y0=y0,
                                             ts=ts,
                                             bm=bm,
                                             method=method,
                                             dt=dt,
                                             adaptive=adaptive,
                                             rtol=rtol,
                                             atol=atol,
                                             dt_min=dt_min,
                                             options=options,
                                             logqp=True)
            ans, logqp = ans_and_logqp[:len(y0)], ans_and_logqp[len(y0):]

        # Don't need to save `logqp`, since it is never used in the backward pass to compute gradients.
        ctx.save_for_backward(ts, flat_params, *ans)
        return ans + logqp
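
The logqp variant returns one flat tuple whose first len(y0) entries are the solution paths and whose remaining entries are the logqp terms. A caller-side split might look like the sketch below; _SdeintLogqpAdjointMethod is an assumed name for the Function above.

ret = _SdeintLogqpAdjointMethod.apply(  # hypothetical name for the Function above
    *y0, sde, ts, flat_params, dt, bm, method, adjoint_method,
    adaptive, rtol, atol, dt_min, options, adjoint_options)
ys, logqp = ret[:len(y0)], ret[len(y0):]  # solution paths first, then the logqp terms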
Example 3
    def backward(ctx, *grad_outputs):
        ts, flat_params, *ans = ctx.saved_tensors
        sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options = (
            ctx.sde, ctx.dt, ctx.bm, ctx.adjoint_method, ctx.adaptive,
            ctx.rtol, ctx.atol, ctx.dt_min, ctx.adjoint_options)
        params = misc.make_seq_requires_grad(sde.parameters())
        n_tensors, n_params = len(ans), len(params)

        # TODO: Make use of adjoint_method.
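        # Reverse time for the backward integration: negate both the time argument
        # and the Brownian sample, so the reversed path has the same increments
        # over each sub-interval as the original.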
        aug_bm = lambda t: tuple(-bmi for bmi in bm(-t))
        adjoint_sde, adjoint_method, adjoint_adaptive = _get_adjoint_params(
            sde=sde, params=params, adaptive=adaptive, logqp=True)

        T = ans[0].size(0)
        with torch.no_grad():
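            # Seed the adjoints with the incoming gradients at the final time point.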
            adj_y = tuple(grad_outputs_[-1]
                          for grad_outputs_ in grad_outputs[:n_tensors])
            adj_l = tuple(grad_outputs_[-1]
                          for grad_outputs_ in grad_outputs[n_tensors:])
            adj_params = torch.zeros_like(flat_params)

            # Walk backwards through the saved trajectory, integrating the adjoint
            # over one sub-interval [ts[i - 1], ts[i]] at a time.
            for i in range(T - 1, 0, -1):
                ans_i = tuple(ans_[i] for ans_ in ans)
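                # Augmented state: solution at ts[i], adjoint w.r.t. the state,
                # adjoint w.r.t. the logqp terms, and the running parameter gradient.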
                aug_y0 = (*ans_i, *adj_y, *adj_l, adj_params)

                aug_ans = sdeint.integrate(sde=adjoint_sde,
                                           y0=aug_y0,
                                           ts=torch.tensor(
                                               [-ts[i], -ts[i - 1]]).to(ts),
                                           bm=aug_bm,
                                           method=adjoint_method,
                                           dt=dt,
                                           adaptive=adjoint_adaptive,
                                           rtol=rtol,
                                           atol=atol,
                                           dt_min=dt_min,
                                           options=adjoint_options)

                # aug_ans is laid out like aug_y0; pull out the adjoint state and the
                # parameter gradient, then keep their values at the end of the
                # reversed sub-interval, i.e. at ts[i - 1].
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_params = aug_ans[-1]

                adj_y = tuple(adj_y_[1] for adj_y_ in adj_y)
                adj_params = adj_params[1]

                # Add the incoming output gradients at ts[i - 1] before moving on to
                # the next (earlier) sub-interval.
                adj_y = misc.seq_add(
                    adj_y,
                    tuple(grad_outputs_[i - 1]
                          for grad_outputs_ in grad_outputs[:n_tensors]))
                adj_l = tuple(grad_outputs_[i - 1]
                              for grad_outputs_ in grad_outputs[n_tensors:])

                del aug_y0, aug_ans

        # One gradient slot per forward input: the state adjoints for *y0, None for
        # sde and ts, the accumulated adj_params for flat_params, and None for each
        # of the remaining solver/adjoint settings.
        return (*adj_y, None, None, adj_params, None, None, None, None, None,
                None, None, None, None, None)
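
A small self-contained check of the time-reversal trick behind aug_bm above: negating both the time argument and the Brownian sample lets a forward-in-time solver traverse [ts[i - 1], ts[i]] in reverse while seeing the same noise increments. The bm below is a deterministic stand-in, not the library's Brownian motion class.

import torch

def bm(t):
    # Deterministic stand-in path with W(t) proportional to t, used only to check increments.
    return (torch.full((2,), float(t)),)

aug_bm = lambda t: tuple(-bmi for bmi in bm(-t))

t0, t1 = 0.3, 0.7
fwd_inc = bm(t1)[0] - bm(t0)[0]            # forward increment over [t0, t1]
bwd_inc = aug_bm(-t0)[0] - aug_bm(-t1)[0]  # increment of the reversed path over [-t1, -t0]
print(torch.allclose(fwd_inc, bwd_inc))    # True: the reversed path reproduces the increments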