예제 #1
0
    def f(self, t, y_aug):
        """Reverse-time drift of the adjoint system for a Stratonovich SDE.

        `y_aug` is split into two equally long sequences: the state `y` and
        its adjoint `adj_y`. The base SDE's drift and diffusion are evaluated
        at `-t`.

        Returns `(*f_eval_corrected, *vjp_y, vjp_params)`: the corrected
        reverse-time drift of `y`, the drift of `adj_y`, and the parameter
        adjoint drift as produced by `misc.flatten_convert_none_to_zeros`.
        """
        sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 2
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

        with torch.enable_grad():
            # Fresh leaf tensors for y so gradients w.r.t. y can be taken;
            # adj_y only ever appears as a (detached) weight below.
            y = [y_.detach().requires_grad_(True) for y_ in y]
            adj_y = [adj_y_.detach() for adj_y_ in adj_y]

            g_eval = sde.g(-t, y)
            g_eval = misc.make_seq_requires_grad(g_eval)

            # Gradient of g w.r.t. y weighted by g itself (a vector-Jacobian
            # product, assuming misc.grad wraps torch.autograd.grad).
            # create_graph=True keeps the result differentiable and retains
            # g_eval's graph for the later gradient calls.
            gdg = misc.grad(outputs=g_eval,
                            inputs=y,
                            grad_outputs=g_eval,
                            allow_unused=True,
                            create_graph=True)
            gdg = misc.convert_none_to_zeros(gdg, y)

            f_eval = sde.f(-t, y)

            # gdg minus f (see misc.seq_sub), i.e. the negated drift plus the
            # gdg correction term.
            f_eval_corrected = misc.seq_sub(
                gdg, f_eval)  # Stratonovich correction for reverse-time.
            f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)

            # Adjoint drift: gradient of the corrected drift w.r.t. y and the
            # parameters, weighted by -adj_y.
            vjp_y_and_params = misc.grad(
                outputs=f_eval_corrected,
                inputs=y + params,
                grad_outputs=[-adj_y_ for adj_y_ in adj_y],
                allow_unused=True,
                create_graph=True)
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_y = misc.convert_none_to_zeros(vjp_y, y)
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

            # Gradient of g w.r.t. y weighted by adj_y. create_graph=True also
            # retains g_eval's graph, which the next call differentiates again.
            adj_times_dgdx = misc.grad(outputs=g_eval,
                                       inputs=y,
                                       grad_outputs=adj_y,
                                       allow_unused=True,
                                       create_graph=True)
            adj_times_dgdx = misc.convert_none_to_zeros(adj_times_dgdx, y)

            # This extra term is due to converting the *adjoint* Stratonovich backward SDE to Itô.
            # No create_graph here: this is the last gradient taken through
            # g_eval's graph in this method, so it need not be retained.
            extra_vjp_y_and_params = misc.grad(
                outputs=g_eval,
                inputs=y + params,
                grad_outputs=adj_times_dgdx,
                allow_unused=True,
            )
            extra_vjp_y = extra_vjp_y_and_params[:n_tensors]
            extra_vjp_y = misc.convert_none_to_zeros(extra_vjp_y, y)

            extra_vjp_params = extra_vjp_y_and_params[n_tensors:]
            extra_vjp_params = misc.flatten_convert_none_to_zeros(
                extra_vjp_params, params)

            # Total adjoint / parameter drift = base VJP + Itô-conversion term.
            vjp_y = misc.seq_add(vjp_y, extra_vjp_y)
            vjp_params = vjp_params + extra_vjp_params

        return (*f_eval_corrected, *vjp_y, vjp_params)
예제 #2
0
    def f(self, t, y_aug):
        """Reverse-time Itô drift of the adjoint system augmented with a log-ratio state.

        `y_aug` is split into three equally long sequences: the state `y`, its
        adjoint `adj_y`, and the log-ratio adjoint `adj_l`. The base SDE's
        `f`, `g` and `h` are evaluated at `-t`.

        Returns `(*f_eval, *vjp_y, *vjp_l, vjp_params)`, where:

        - `f_eval` is the negated drift, i.e. the reverse-time drift of `y`;
        - `vjp_y` is the drift of `adj_y`, including the log-ratio correction
          weighted by `adj_l`;
        - `vjp_l` is all zeros because `adj_l` carries no drift here;
        - `vjp_params` is the parameter adjoint drift as produced by
          `misc.flatten_convert_none_to_zeros`.
        """
        sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
        y, adj_y, adj_l = y_aug[:n_tensors], y_aug[
            n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors]
        vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l]

        with torch.enable_grad():
            # Fresh leaf tensors for y so gradients w.r.t. y can be taken;
            # adj_y only ever appears as a (detached) weight below.
            y = [y_.detach().requires_grad_(True) for y_ in y]
            adj_y = [adj_y_.detach() for adj_y_ in adj_y]

            # Keep the forward drift: the log-ratio term below needs it
            # un-negated, while the reverse-time state drift is its negation.
            f_forward = sde.f(-t, y)
            f_eval = [-f_eval_ for f_eval_ in f_forward]
            f_eval = misc.make_seq_requires_grad(f_eval)

            # create_graph=True also retains the graph behind `f_forward`,
            # which the log-ratio gradient below differentiates through again.
            vjp_y_and_params = misc.grad(
                outputs=f_eval,
                inputs=y + params,
                grad_outputs=[-adj_y_ for adj_y_ in adj_y],
                allow_unused=True,
                create_graph=True)
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_y = misc.convert_none_to_zeros(vjp_y, y)
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

            # Vector field change due to log-ratio term, i.e. ||u||^2 / 2.
            g_eval = sde.g(-t, y)
            h_eval = sde.h(-t, y)

            # u = g^+ (f - h) must use the forward drift. The negated drift
            # would give g^+ (-f - h), whose squared norm is not ||u||^2.
            g_inv_eval = [torch.pinverse(g_eval_) for g_eval_ in g_eval]
            u_eval = misc.seq_sub(f_forward, h_eval)
            # NOTE(review): torch.bmm requires 3-D operands; confirm that the
            # drift tensors have a trailing singleton dim (e.g. (batch, d, 1)).
            u_eval = [
                torch.bmm(g_inv_eval_, u_eval_)
                for g_inv_eval_, u_eval_ in zip(g_inv_eval, u_eval)
            ]
            log_ratio_correction = [
                .5 * torch.sum(u_eval_**2., dim=1) for u_eval_ in u_eval
            ]
            log_ratio_correction = misc.make_seq_requires_grad(
                log_ratio_correction)
            # Last gradient through this graph, so it is not retained.
            corr_vjp_y_and_params = misc.grad(
                outputs=log_ratio_correction,
                inputs=y + params,
                grad_outputs=adj_l,
                allow_unused=True,
            )
            corr_vjp_y = corr_vjp_y_and_params[:n_tensors]
            corr_vjp_y = misc.convert_none_to_zeros(corr_vjp_y, y)
            corr_vjp_params = corr_vjp_y_and_params[n_tensors:]
            corr_vjp_params = misc.flatten_convert_none_to_zeros(
                corr_vjp_params, params)

            vjp_y = misc.seq_add(vjp_y, corr_vjp_y)
            vjp_params = vjp_params + corr_vjp_params

        return (*f_eval, *vjp_y, *vjp_l, vjp_params)
예제 #3
0
    def g_prod(self, t, y_aug, noise):
        """Reverse-time diffusion applied to `noise`, plus the matching adjoint terms.

        `y_aug` holds the state sequence followed by an equally long adjoint
        sequence. The diffusion is evaluated at `-t` and negated for reverse
        time.

        Returns `(*g_prod_eval, *vjp_y, vjp_params)`: the diffusion times
        noise for the state, the corresponding adjoint term, and the
        parameter adjoint term as produced by
        `misc.flatten_convert_none_to_zeros`.
        """
        sde = self._base_sde
        params = self.params
        n_tensors = len(y_aug) // 2
        states = y_aug[:n_tensors]
        adjoints = y_aug[n_tensors:2 * n_tensors]

        with torch.enable_grad():
            # Leaves we can differentiate w.r.t.; adjoints act only as weights.
            states = [s.detach().requires_grad_(True) for s in states]
            adjoints = [a.detach() for a in adjoints]

            neg_diffusion = misc.make_seq_requires_grad(
                [-g_ for g_ in sde.g(-t, states)])
            weights = [-v * a for v, a in zip(noise, adjoints)]

            # Single backward pass; the graph is not needed afterwards.
            grads = misc.grad(
                outputs=neg_diffusion,
                inputs=states + params,
                grad_outputs=weights,
                allow_unused=True,
            )
            vjp_y = misc.convert_none_to_zeros(grads[:n_tensors], states)
            vjp_params = misc.flatten_convert_none_to_zeros(
                grads[n_tensors:], params)
            g_prod_eval = misc.seq_mul(neg_diffusion, noise)

        return (*g_prod_eval, *vjp_y, vjp_params)
예제 #4
0
    def f(self, t, y_aug):
        """Reverse-time Itô drift of the state together with its adjoint drift.

        `y_aug` holds the state sequence followed by an equally long adjoint
        sequence. The drift is evaluated at `-t` and negated for reverse
        time.

        Returns `(*f_eval, *vjp_y, vjp_params)`: the negated drift, the drift
        of the adjoint state, and the parameter adjoint drift as produced by
        `misc.flatten_convert_none_to_zeros`.
        """
        sde = self._base_sde
        params = self.params
        n_tensors = len(y_aug) // 2
        states = y_aug[:n_tensors]
        adjoints = y_aug[n_tensors:2 * n_tensors]

        with torch.enable_grad():
            # Leaves we can differentiate w.r.t.; adjoints act only as weights.
            states = [s.detach().requires_grad_(True) for s in states]
            adjoints = [a.detach() for a in adjoints]

            reverse_drift = misc.make_seq_requires_grad(
                [-f_ for f_ in sde.f(-t, states)])

            grads = misc.grad(
                outputs=reverse_drift,
                inputs=states + params,
                grad_outputs=[-a for a in adjoints],
                allow_unused=True,
                create_graph=True)
            vjp_y = misc.convert_none_to_zeros(grads[:n_tensors], states)
            vjp_params = misc.flatten_convert_none_to_zeros(
                grads[n_tensors:], params)

        return (*reverse_drift, *vjp_y, vjp_params)
예제 #5
0
    def gdg_prod(self, t, y_aug, noise):
        """`gdg`-times-noise term of the adjoint system augmented with a log-ratio state.

        `y_aug` is split into three equally long sequences: `y`, `adj_y` and
        `adj_l`. The base diffusion is evaluated at `-t`.

        Returns `(*gdg_times_v, *adj_y_term, *vjp_l, params_term)` where
        `adj_y_term = prod_partials_adj_y - mixed_partials_adj_y` (via
        `misc.seq_sub`), `vjp_l` is all zeros, and
        `params_term = prod_partials_params - mixed_partials_params`.
        """
        sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
        y, adj_y, adj_l = y_aug[:n_tensors], y_aug[
            n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors]
        # The log-ratio adjoint component has no contribution here.
        vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l]

        with torch.enable_grad():
            y = [y_.detach().requires_grad_(True) for y_ in y]
            # NOTE(review): unlike the drift methods, adj_y is marked
            # requires_grad here, yet it is never in `inputs` of a grad call
            # below — confirm whether this is intended.
            adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y]

            g_eval = sde.g(-t, y)
            g_eval = misc.make_seq_requires_grad(g_eval)

            # Gradient of g w.r.t. y weighted by g * noise (see misc.seq_mul):
            # the gdg-times-v term for the state itself.
            gdg_times_v = misc.grad(
                outputs=g_eval,
                inputs=y,
                grad_outputs=misc.seq_mul(g_eval, noise),
                allow_unused=True,
                create_graph=True,
            )
            gdg_times_v = misc.convert_none_to_zeros(gdg_times_v, y)

            # Gradient of g w.r.t. y with all-ones weights. NOTE(review): this
            # equals the elementwise derivative of g only for diagonal noise —
            # presumably the intended case; confirm.
            dgdy = misc.grad(
                outputs=g_eval,
                inputs=y,
                grad_outputs=[torch.ones_like(y_) for y_ in y],
                allow_unused=True,
                create_graph=True,
            )
            dgdy = misc.convert_none_to_zeros(dgdy, y)

            # First-order product terms: gradient of g w.r.t. y and params,
            # weighted by adj_y * noise * dgdy.
            prod_partials_adj_y_and_params = misc.grad(
                outputs=g_eval,
                inputs=y + params,
                grad_outputs=misc.seq_mul(adj_y, noise, dgdy),
                allow_unused=True,
                create_graph=True,
            )
            prod_partials_adj_y = prod_partials_adj_y_and_params[:n_tensors]
            prod_partials_adj_y = misc.convert_none_to_zeros(
                prod_partials_adj_y, y)
            prod_partials_params = prod_partials_adj_y_and_params[n_tensors:]
            prod_partials_params = misc.flatten_convert_none_to_zeros(
                prod_partials_params, params)

            # Gradient of g w.r.t. y weighted by adj_y * noise * g. The weight
            # is detached so the second differentiation below treats it as a
            # constant and only differentiates the Jacobian of g.
            gdg_v = misc.grad(
                outputs=g_eval,
                inputs=y,
                grad_outputs=[
                    p.detach() for p in misc.seq_mul(adj_y, noise, g_eval)
                ],
                allow_unused=True,
                create_graph=True,
            )
            gdg_v = misc.convert_none_to_zeros(gdg_v, y)
            gdg_v = misc.make_seq_requires_grad(gdg_v)

            # Reduce to scalars so grad needs no grad_outputs; differentiating
            # again yields the mixed second-order partials w.r.t. y and params.
            gdg_v = [gdg_v_.sum() for gdg_v_ in gdg_v]
            mixed_partials_adj_y_and_params = misc.grad(
                outputs=gdg_v,
                inputs=y + params,
                allow_unused=True,
            )
            mixed_partials_adj_y = mixed_partials_adj_y_and_params[:n_tensors]
            mixed_partials_adj_y = misc.convert_none_to_zeros(
                mixed_partials_adj_y, y)
            mixed_partials_params = mixed_partials_adj_y_and_params[n_tensors:]
            mixed_partials_params = misc.flatten_convert_none_to_zeros(
                mixed_partials_params, params)

        return (*gdg_times_v,
                *misc.seq_sub(prod_partials_adj_y, mixed_partials_adj_y),
                *vjp_l, prod_partials_params - mixed_partials_params)
예제 #6
0
    def f(self, t, y_aug):
        """Reverse-time drift of the Stratonovich adjoint system with a log-ratio state.

        `y_aug` is split into three equally long sequences: the state `y`, its
        adjoint `adj_y`, and the log-ratio adjoint `adj_l`. The base SDE's
        `f`, `g` and `h` are evaluated at `-t`.

        Returns `(*f_eval_corrected, *vjp_y, *vjp_l, vjp_params)`: the
        corrected reverse-time drift of `y`, the drift of `adj_y` (base VJP +
        Itô-conversion term + log-ratio correction), zeros for `adj_l`, and
        the parameter adjoint drift as produced by
        `misc.flatten_convert_none_to_zeros`.
        """
        sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
        y, adj_y, adj_l = y_aug[:n_tensors], y_aug[
            n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors]
        # The log-ratio adjoint carries no drift of its own.
        vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l]

        with torch.enable_grad():
            # Fresh leaf tensors for y so gradients w.r.t. y can be taken;
            # adj_y only ever appears as a (detached) weight below.
            y = [y_.detach().requires_grad_(True) for y_ in y]
            adj_y = [adj_y_.detach() for adj_y_ in adj_y]

            g_eval = sde.g(-t, y)
            g_eval = misc.make_seq_requires_grad(g_eval)

            # Gradient of g w.r.t. y weighted by g itself; create_graph=True
            # keeps g_eval's graph alive for the later gradient calls.
            gdg = misc.grad(
                outputs=g_eval,
                inputs=y,
                grad_outputs=g_eval,
                allow_unused=True,
                create_graph=True,
            )
            gdg = misc.convert_none_to_zeros(gdg, y)

            # Reverse-time Stratonovich drift: gdg minus f (see misc.seq_sub).
            f_eval = sde.f(-t, y)
            f_eval_corrected = misc.seq_sub(gdg, f_eval)
            f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)

            # Adjoint drift: gradient of the corrected drift w.r.t. y and the
            # parameters, weighted by -adj_y.
            vjp_y_and_params = misc.grad(
                outputs=f_eval_corrected,
                inputs=y + params,
                grad_outputs=[-adj_y_ for adj_y_ in adj_y],
                allow_unused=True,
                create_graph=True)
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_y = misc.convert_none_to_zeros(vjp_y, y)
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

            # Gradient of g w.r.t. y weighted by adj_y.
            adj_times_dgdx = misc.grad(outputs=g_eval,
                                       inputs=y,
                                       grad_outputs=adj_y,
                                       allow_unused=True,
                                       create_graph=True)
            adj_times_dgdx = misc.convert_none_to_zeros(adj_times_dgdx, y)

            # Extra term from converting the adjoint Stratonovich backward SDE
            # to Itô. Unlike the variant without a log-ratio state, this call
            # needs create_graph=True: the log-ratio gradient below
            # differentiates through f_eval / g_eval's graph again.
            extra_vjp_y_and_params = misc.grad(
                outputs=g_eval,
                inputs=y + params,
                grad_outputs=adj_times_dgdx,
                allow_unused=True,
                create_graph=True,
            )
            extra_vjp_y = extra_vjp_y_and_params[:n_tensors]
            extra_vjp_y = misc.convert_none_to_zeros(extra_vjp_y, y)
            extra_vjp_params = extra_vjp_y_and_params[n_tensors:]
            extra_vjp_params = misc.flatten_convert_none_to_zeros(
                extra_vjp_params, params)

            # Vector field change due to log-ratio term, i.e. ||u||^2 / 2.
            # u is built from the un-negated drift f_eval; misc.seq_sub_div
            # presumably computes (f - h) / g elementwise — confirm in misc.
            h_eval = sde.h(-t, y)
            u_eval = misc.seq_sub_div(f_eval, h_eval, g_eval)
            log_ratio_correction = [
                .5 * torch.sum(u_eval_**2., dim=1) for u_eval_ in u_eval
            ]

            log_ratio_correction = misc.make_seq_requires_grad(
                log_ratio_correction)
            # Last gradient through this graph, weighted by adj_l.
            corr_vjp_y_and_params = misc.grad(
                outputs=log_ratio_correction,
                inputs=y + params,
                grad_outputs=adj_l,
                allow_unused=True,
            )
            corr_vjp_y = corr_vjp_y_and_params[:n_tensors]
            corr_vjp_y = misc.convert_none_to_zeros(corr_vjp_y, y)
            corr_vjp_params = corr_vjp_y_and_params[n_tensors:]
            corr_vjp_params = misc.flatten_convert_none_to_zeros(
                corr_vjp_params, params)

            # Sum the three contributions for both y-adjoint and parameters.
            vjp_y = misc.seq_add(vjp_y, extra_vjp_y, corr_vjp_y)
            vjp_params = vjp_params + extra_vjp_params + corr_vjp_params

        return (*f_eval_corrected, *vjp_y, *vjp_l, vjp_params)