Example #1
    def backward(ctx, *grad_outputs):
        ts, flat_params, *ans = ctx.saved_tensors
        sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options = (
            ctx.sde, ctx.dt, ctx.bm, ctx.adjoint_method, ctx.adaptive,
            ctx.rtol, ctx.atol, ctx.dt_min, ctx.adjoint_options)
        params = misc.make_seq_requires_grad(sde.parameters())
        n_tensors, n_params = len(ans), len(params)

        # TODO: Make use of adjoint_method.
        aug_bm = lambda t: tuple(-bmi for bmi in bm(-t))
        adjoint_sde, adjoint_method, adjoint_adaptive = _get_adjoint_params(
            sde=sde, params=params, adaptive=adaptive, logqp=True)

        T = ans[0].size(0)
        with torch.no_grad():
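            # Seed the adjoint states with the gradients arriving at the final
            # time: one per state path (adj_y), one per log-ratio path (adj_l).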
            adj_y = tuple(grad_outputs_[-1]
                          for grad_outputs_ in grad_outputs[:n_tensors])
            adj_l = tuple(grad_outputs_[-1]
                          for grad_outputs_ in grad_outputs[n_tensors:])
            adj_params = torch.zeros_like(flat_params)

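            # March backward over the saved time grid, one interval at a time.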
            for i in range(T - 1, 0, -1):
                ans_i = tuple(ans_[i] for ans_ in ans)
                aug_y0 = (*ans_i, *adj_y, *adj_l, adj_params)

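                # Solve the augmented adjoint SDE over [ts[i - 1], ts[i]] as a
                # forward problem in negated time, driven by the negated
                # Brownian motion aug_bm.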
                aug_ans = sdeint.integrate(sde=adjoint_sde,
                                           y0=aug_y0,
                                           ts=torch.tensor(
                                               [-ts[i], -ts[i - 1]]).to(ts),
                                           bm=aug_bm,
                                           method=adjoint_method,
                                           dt=dt,
                                           adaptive=adjoint_adaptive,
                                           rtol=rtol,
                                           atol=atol,
                                           dt_min=dt_min,
                                           options=adjoint_options)

                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_params = aug_ans[-1]

                adj_y = tuple(adj_y_[1] for adj_y_ in adj_y)
                adj_params = adj_params[1]

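                # Inject the gradient arriving at the earlier time point.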
                adj_y = misc.seq_add(
                    adj_y,
                    tuple(grad_outputs_[i - 1]
                          for grad_outputs_ in grad_outputs[:n_tensors]))
                adj_l = tuple(grad_outputs_[i - 1]
                              for grad_outputs_ in grad_outputs[n_tensors:])

                del aug_y0, aug_ans

        return (*adj_y, None, None, adj_params, None, None, None, None, None,
                None, None, None, None, None)
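
For orientation, the control flow of this backward pass can be reproduced on a toy problem. The sketch below is not the library's implementation: it stands in a deterministic ODE dy/dt = -theta * y for the SDE, a hand-rolled fixed-step Euler routine for sdeint.integrate, and made-up names (euler, aug_drift) for the illustration. The structure is the same, though: march backward one observation interval at a time, integrate an augmented system carrying (state, adjoint state, parameter gradient), and inject the incoming gradient at every saved time point. The result is checked against backpropagation through the forward solver.

    import torch

    def euler(func, y0, t0, t1, n=400):
        # Plain fixed-step Euler from t0 to t1; t1 may be smaller than t0.
        y, t, dt = y0, t0, (t1 - t0) / n
        for _ in range(n):
            y = y + dt * func(t, y)
            t = t + dt
        return y

    theta = torch.tensor(2.0, requires_grad=True)
    drift = lambda t, y: -theta * y
    ts = torch.linspace(0., 1., 4)

    # Forward pass: keep the state at every observation time (analogue of ans).
    ys = [torch.tensor(1.5)]
    for i in range(len(ts) - 1):
        ys.append(euler(drift, ys[i], float(ts[i]), float(ts[i + 1])))

    # Pretend the loss is the sum of all observed states, so the incoming
    # gradient (the analogue of grad_outputs) is 1 at every observation time.
    grad_outputs = [torch.ones(()) for _ in ts]

    def aug_drift(t, aug):
        # Reverse-time dynamics of the augmented state (y, adj_y, adj_theta).
        y, a, _ = aug
        with torch.enable_grad():
            y = y.detach().requires_grad_(True)
            fy = drift(t, y)
            # Vector-Jacobian products a * df/dy and a * df/dtheta.
            a_dfdy, a_dfdth = torch.autograd.grad(fy, (y, theta), grad_outputs=a)
        return torch.stack([fy.detach(), -a_dfdy, -a_dfdth])

    adj_y, adj_theta = grad_outputs[-1].clone(), torch.zeros(())
    for i in range(len(ts) - 1, 0, -1):
        aug = torch.stack([ys[i].detach(), adj_y, adj_theta])
        # Integrate the augmented system backward over [ts[i - 1], ts[i]].
        aug = euler(aug_drift, aug, float(ts[i]), float(ts[i - 1]))
        adj_y, adj_theta = aug[1], aug[2]
        adj_y = adj_y + grad_outputs[i - 1]  # inject gradient at earlier time.

    reference, = torch.autograd.grad(sum(ys), theta)
    print(adj_theta.item(), reference.item())  # agree to a few decimal places.
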
Example #2
    def f(self, t, y_aug):
        sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
        y, adj_y, adj_l = (y_aug[:n_tensors],
                           y_aug[n_tensors:2 * n_tensors],
                           y_aug[2 * n_tensors:3 * n_tensors])
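        # vjp_l is the drift returned for the log-ratio adjoint: identically
        # zero, since adj_l only enters through the correction term below.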
        vjp_l = tuple(torch.zeros_like(adj_l_) for adj_l_ in adj_l)

        with torch.enable_grad():
            y = tuple(y_.detach().requires_grad_(True) for y_ in y)
            adj_y = tuple(adj_y_.detach() for adj_y_ in adj_y)

            f_eval = sde.f(-t, y)
            f_eval = tuple(-f_eval_ for f_eval_ in f_eval)
            f_eval = misc.make_seq_requires_grad_y(f_eval, y)

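            # VJP of the (negated) reverse-time drift w.r.t. the state and the
            # parameters, seeded with -adj_y.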
            vjp_y_and_params = torch.autograd.grad(
                outputs=f_eval,
                inputs=y + params,
                grad_outputs=tuple(-adj_y_ for adj_y_ in adj_y),
                allow_unused=True,
                create_graph=True)
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_y = misc.convert_none_to_zeros(vjp_y, y)
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

            # Vector field change due to log-ratio term, i.e. ||u||^2 / 2.
            g_eval = sde.g(-t, y)
            h_eval = sde.h(-t, y)

            ginv_eval = tuple(torch.pinverse(g_eval_) for g_eval_ in g_eval)
            u_eval = misc.seq_sub(f_eval, h_eval)
            u_eval = tuple(
                torch.bmm(ginv_eval_, u_eval_)
                for ginv_eval_, u_eval_ in zip(ginv_eval, u_eval))
            log_ratio_correction = tuple(.5 * torch.sum(u_eval_**2., dim=1)
                                         for u_eval_ in u_eval)
            log_ratio_correction = misc.make_seq_requires_grad_y(
                log_ratio_correction, y)
            corr_vjp_y_and_params = torch.autograd.grad(
                outputs=log_ratio_correction,
                inputs=y + params,
                grad_outputs=adj_l,
                allow_unused=True,
            )
            corr_vjp_y = corr_vjp_y_and_params[:n_tensors]
            corr_vjp_y = misc.convert_none_to_zeros(corr_vjp_y, y)
            corr_vjp_params = corr_vjp_y_and_params[n_tensors:]
            corr_vjp_params = misc.flatten_convert_none_to_zeros(
                corr_vjp_params, params)

            vjp_y = misc.seq_add(vjp_y, corr_vjp_y)
            vjp_params = vjp_params + corr_vjp_params

        return (*f_eval, *vjp_y, *vjp_l, vjp_params)
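
Both drift evaluations above are built from one autograd primitive: torch.autograd.grad(outputs, inputs, grad_outputs=v) returns the vector-Jacobian product v^T d(outputs)/d(inputs) without ever materializing the Jacobian, and allow_unused=True yields None for inputs the outputs do not depend on (converted to zeros by the misc helpers). Below is a minimal, self-contained check of that pattern; the toy drift, W, and a are made up for the illustration.

    import torch

    y = torch.randn(5, requires_grad=True)
    W = torch.randn(3, 5, requires_grad=True)   # stands in for the parameters
    a = torch.randn(3)                          # stands in for the adjoint adj_y

    f = torch.tanh(W @ y)                       # toy drift evaluation, shape (3,)

    # One call gives both a^T df/dy and a^T df/dW, as in vjp_y_and_params above.
    vjp_y, vjp_W = torch.autograd.grad(f, (y, W), grad_outputs=a, allow_unused=True)

    # Sanity check of the y-part against an explicitly materialized Jacobian.
    J = torch.autograd.functional.jacobian(lambda y_: torch.tanh(W @ y_), y)
    print(torch.allclose(vjp_y, a @ J, atol=1e-6))  # True
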
Example #3
    def f(self, t, y_aug):
        sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 2
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

        with torch.enable_grad():
            y = tuple(y_.detach().requires_grad_(True) for y_ in y)
            adj_y = tuple(adj_y_.detach() for adj_y_ in adj_y)

            g_eval = sde.g(-t, y)
            g_eval = misc.make_seq_requires_grad_y(g_eval, y)

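            # Seeding the VJP with g itself gives sum_j g_j * dg_j/dy
            # (elementwise g * dg/dy when the noise is diagonal), used in the
            # Stratonovich correction below.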
            gdg = torch.autograd.grad(outputs=g_eval,
                                      inputs=y,
                                      grad_outputs=g_eval,
                                      allow_unused=True,
                                      create_graph=True)
            gdg = misc.convert_none_to_zeros(gdg, y)

            f_eval = sde.f(-t, y)

            # Stratonovich correction for reverse-time.
            f_eval_corrected = misc.seq_sub(gdg, f_eval)
            f_eval_corrected = misc.make_seq_requires_grad_y(
                f_eval_corrected, y)

            vjp_y_and_params = torch.autograd.grad(
                outputs=f_eval_corrected,
                inputs=y + params,
                grad_outputs=tuple(-adj_y_ for adj_y_ in adj_y),
                allow_unused=True,
                create_graph=True)
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_y = misc.convert_none_to_zeros(vjp_y, y)
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

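            # adj_y-weighted VJP of g, i.e. adj_y^T dg/dy, reused below as the
            # seed for the Itô conversion term.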
            adj_times_dgdx = torch.autograd.grad(outputs=g_eval,
                                                 inputs=y,
                                                 grad_outputs=adj_y,
                                                 allow_unused=True,
                                                 create_graph=True)
            adj_times_dgdx = misc.convert_none_to_zeros(adj_times_dgdx, y)

            # This extra term is due to converting the *adjoint* Stratonovich backward SDE to Itô.
            extra_vjp_y_and_params = torch.autograd.grad(
                outputs=g_eval,
                inputs=y + params,
                grad_outputs=adj_times_dgdx,
                allow_unused=True,
            )
            extra_vjp_y = extra_vjp_y_and_params[:n_tensors]
            extra_vjp_y = misc.convert_none_to_zeros(extra_vjp_y, y)

            extra_vjp_params = extra_vjp_y_and_params[n_tensors:]
            extra_vjp_params = misc.flatten_convert_none_to_zeros(
                extra_vjp_params, params)

            vjp_y = misc.seq_add(vjp_y, extra_vjp_y)
            vjp_params = vjp_params + extra_vjp_params

        return (*f_eval_corrected, *vjp_y, vjp_params)
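
The two g-related VJPs above rely on a simple identity: seeding torch.autograd.grad with the diffusion itself returns sum_j g_j * dg_j/dy_i, which for diagonal noise (each g_i depending only on y_i) is exactly the elementwise g * dg/dy in the Stratonovich correction, while seeding with adj_y gives the adj_y^T dg/dy factor used for the extra Itô-conversion term. A small standalone check, with a made-up elementwise diffusion g(y) = sin(y):

    import torch

    y = torch.randn(4, requires_grad=True)
    g = torch.sin(y)          # toy diagonal diffusion, elementwise in y
    adj_y = torch.randn(4)

    # grad_outputs=g collapses to the elementwise g * dg/dy for diagonal noise.
    gdg, = torch.autograd.grad(g, y, grad_outputs=g, create_graph=True)
    print(torch.allclose(gdg, torch.sin(y) * torch.cos(y)))        # True

    # grad_outputs=adj_y gives adj_y * dg/dy, the seed for the extra term.
    adj_dg, = torch.autograd.grad(g, y, grad_outputs=adj_y)
    print(torch.allclose(adj_dg, adj_y * torch.cos(y)))            # True
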