def g_prod(self, t, y_aug, noise):
    """Diffusion-times-noise for the augmented adjoint system.

    ``y_aug`` is read as two equal groups of ``len(y_aug) // 2`` tensors:
    the states ``y`` followed by their adjoints ``adj_y``.  The base SDE's
    diffusion is queried at reversed time ``-t`` and negated.

    Returns a flat tuple with three parts:
      * the negated diffusion multiplied by ``noise`` (via ``misc.seq_mul``),
      * the vector-Jacobian product w.r.t. each state tensor, using
        ``-noise * adj_y`` as the cotangent (``None`` gradients become zeros),
      * a single flattened tensor holding the same VJP w.r.t. ``self.params``.
    """
    base_sde = self._base_sde
    params = self.params
    n = len(y_aug) // 2
    states = y_aug[:n]
    adjoints = y_aug[n:2 * n]

    with torch.enable_grad():
        # Fresh leaves so the VJP below is taken w.r.t. these exact tensors.
        states = tuple(s.detach().requires_grad_(True) for s in states)
        adjoints = tuple(a.detach() for a in adjoints)

        neg_g = tuple(-g_ for g_ in base_sde.g(-t, states))
        neg_g = misc.make_seq_requires_grad_y(neg_g, states)

        cotangents = tuple(-z * a for z, a in zip(noise, adjoints))
        grads = torch.autograd.grad(
            outputs=neg_g,
            inputs=states + params,
            grad_outputs=cotangents,
            allow_unused=True,
        )

        # Split the joint gradient back into its state and parameter parts.
        vjp_y = misc.convert_none_to_zeros(grads[:n], states)
        vjp_params = misc.flatten_convert_none_to_zeros(grads[n:], params)
        g_prod_eval = misc.seq_mul(neg_g, noise)
    return (*g_prod_eval, *vjp_y, vjp_params)
def gdg_prod(self, t, y, v):
    """Return the vector-Jacobian product ``J_g(y)^T (g(t, y) * v)``.

    ``g`` is the wrapped SDE's diffusion.  State tensors that already require
    grad are used as-is, which keeps any graph they belong to.  The others are
    detached and turned into grad-requiring leaves.  The product is built with
    ``create_graph=True`` so the result can itself be differentiated.
    Gradients for inputs ``g`` does not use come back as ``None`` and are
    replaced by zeros.
    """
    with torch.enable_grad():
        leaves = []
        for y_ in y:
            if y_.requires_grad:
                leaves.append(y_)
            else:
                leaves.append(y_.detach().requires_grad_(True))
        leaves = tuple(leaves)

        g_val = misc.make_seq_requires_grad(self._base_sde.g(t, leaves))
        cotangent = misc.seq_mul(g_val, v)
        grads = torch.autograd.grad(
            outputs=g_val,
            inputs=leaves,
            grad_outputs=cotangent,
            create_graph=True,
            allow_unused=True,
        )
        result = misc.convert_none_to_zeros(grads, leaves)
    return result
def gdg_prod(self, t, y_aug, noise):
    """Return the ``g * dg/dy`` correction times ``noise`` for the augmented adjoint system.

    This variant expects an extra ``adj_l`` group.  ``y_aug`` is split into
    three groups of ``len(y_aug) // 3`` tensors each: ``y``, ``adj_y`` and
    ``adj_l``.  NOTE(review): any trailing entry of ``y_aug`` (presumably the
    flattened parameter adjoint) is not read here. Please confirm against
    the caller.

    The base SDE's diffusion ``g`` is evaluated at reversed time ``-t``.

    Args:
        t: Adjoint time; ``g`` is queried at ``-t``.
        y_aug: Sequence ``(*y, *adj_y, *adj_l, ...)`` of tensors.
        noise: Noise tensors combined with ``g_eval`` / ``adj_y`` via
            ``misc.seq_mul``. Presumably there is one per state tensor (TODO confirm).

    Returns:
        A flat tuple with four parts:
          * the ``y`` part ``J_g^T (g * noise)``,
          * the ``adj_y`` part, equal to the first-order terms minus the
            second-order ("mixed") terms computed below,
          * zeros for each ``adj_l`` tensor,
          * one flattened tensor for the parameters, again first-order minus
            mixed terms.
    """
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
    y, adj_y, adj_l = y_aug[:n_tensors], y_aug[
        n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors]
    # The adj_l component of the correction is identically zero.
    vjp_l = tuple(torch.zeros_like(adj_l_) for adj_l_ in adj_l)
    with torch.enable_grad():
        # Fresh leaves so every autograd.grad call below targets these tensors.
        y = tuple(y_.detach().requires_grad_(True) for y_ in y)
        # NOTE(review): adj_y is never passed as an autograd `inputs`. Marking
        # it as requiring grad only makes the products built from it below
        # track a graph. Confirm this is intended.
        adj_y = tuple(adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y)
        # g is used un-negated here. Presumably the reversed-time diffusion is
        # -g, and the two sign flips cancel in g * dg/dy (TODO confirm).
        g_eval = sde.g(-t, y)
        # NOTE(review): helper presumably ensures g_eval participates in
        # autograd even when g does not depend on y. Verify in misc.
        g_eval = misc.make_seq_requires_grad_y(g_eval, y)
        # y component: J_g^T (g * noise).
        gdg_times_v = torch.autograd.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=misc.seq_mul(g_eval, noise),
            allow_unused=True,
            create_graph=True,
        )
        gdg_times_v = misc.convert_none_to_zeros(gdg_times_v, y)
        # J_g^T 1. For diagonal noise (g elementwise in y, assumed here) this
        # is the elementwise derivative dg/dy.
        dgdy = torch.autograd.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=tuple(torch.ones_like(y_) for y_ in y),
            allow_unused=True,
            create_graph=True,
        )
        dgdy = misc.convert_none_to_zeros(dgdy, y)
        # First-order terms: J^T (adj_y * noise * dg/dy) w.r.t. y and params.
        prod_partials_adj_y_and_params = torch.autograd.grad(
            outputs=g_eval, inputs=y + params,
            grad_outputs=misc.seq_mul(adj_y, noise, dgdy),
            allow_unused=True,
            create_graph=True,
        )
        prod_partials_adj_y = prod_partials_adj_y_and_params[:n_tensors]
        prod_partials_adj_y = misc.convert_none_to_zeros(
            prod_partials_adj_y, y)
        prod_partials_params = prod_partials_adj_y_and_params[n_tensors:]
        prod_partials_params = misc.flatten_convert_none_to_zeros(
            prod_partials_params, params)
        # J_g^T c with c = adj_y * noise * g detached, so c is a constant for
        # the next derivative. create_graph=True keeps this result
        # differentiable w.r.t. y and params.
        gdg_v = torch.autograd.grad(
            outputs=g_eval, inputs=y,
            grad_outputs=tuple(
                p.detach() for p in misc.seq_mul(adj_y, noise, g_eval)),
            allow_unused=True,
            create_graph=True,
        )
        gdg_v = misc.convert_none_to_zeros(gdg_v, y)
        gdg_v = misc.make_seq_requires_grad_y(gdg_v, y)
        # Reduce to scalars so the next autograd.grad call needs no
        # grad_outputs.
        gdg_v = tuple(gdg_v_.sum() for gdg_v_ in gdg_v)
        # Second-order ("mixed") terms: derivative of sum(J_g^T c) w.r.t. y and params.
        mixed_partials_adj_y_and_params = torch.autograd.grad(
            outputs=gdg_v, inputs=y + params,
            allow_unused=True,
        )
        mixed_partials_adj_y = mixed_partials_adj_y_and_params[:n_tensors]
        mixed_partials_adj_y = misc.convert_none_to_zeros(
            mixed_partials_adj_y, y)
        mixed_partials_params = mixed_partials_adj_y_and_params[n_tensors:]
        mixed_partials_params = misc.flatten_convert_none_to_zeros(
            mixed_partials_params, params)
    # Both the adj_y and params components are first-order minus
    # second-order terms.
    return (*gdg_times_v, *misc.seq_sub(prod_partials_adj_y, mixed_partials_adj_y), *vjp_l,
            prod_partials_params - mixed_partials_params)
def g_prod(self, t, y, v):
    """Apply the wrapped SDE's diffusion ``g(t, y)`` to ``v``.

    The combining helper is chosen from ``self.noise_type``:
      * ``"diagonal"`` uses ``misc.seq_mul`` (presumably an elementwise product),
      * ``"scalar"`` uses ``misc.seq_mul_bc`` (presumably a broadcasting product),
      * any other value uses ``misc.seq_batch_mvp`` (presumably a batched
        matrix-vector product).
    """
    noise_type = self.noise_type
    g_eval = self._base_sde.g(t, y)
    if noise_type == "diagonal":
        return misc.seq_mul(g_eval, v)
    if noise_type == "scalar":
        return misc.seq_mul_bc(g_eval, v)
    return misc.seq_batch_mvp(ms=g_eval, vs=v)