def f(self, t, y_aug):
    # Drift of the augmented reverse-time adjoint SDE for a Stratonovich base SDE,
    # expressed in Itô form.
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 2
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) for y_ in y]
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
        g_eval = misc.make_seq_requires_grad(g_eval)
        gdg = misc.grad(
            outputs=g_eval, inputs=y, grad_outputs=g_eval, allow_unused=True, create_graph=True)
        gdg = misc.convert_none_to_zeros(gdg, y)

        f_eval = sde.f(-t, y)
        f_eval_corrected = misc.seq_sub(gdg, f_eval)  # Stratonovich correction for reverse-time.
        f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)
        vjp_y_and_params = misc.grad(
            outputs=f_eval_corrected,
            inputs=y + params,
            grad_outputs=[-adj_y_ for adj_y_ in adj_y],
            allow_unused=True,
            create_graph=True)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_y = misc.convert_none_to_zeros(vjp_y, y)
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

        adj_times_dgdx = misc.grad(
            outputs=g_eval, inputs=y, grad_outputs=adj_y, allow_unused=True, create_graph=True)
        adj_times_dgdx = misc.convert_none_to_zeros(adj_times_dgdx, y)

        # This extra term is due to converting the *adjoint* Stratonovich backward SDE to Itô.
        extra_vjp_y_and_params = misc.grad(
            outputs=g_eval,
            inputs=y + params,
            grad_outputs=adj_times_dgdx,
            allow_unused=True,
        )
        extra_vjp_y = extra_vjp_y_and_params[:n_tensors]
        extra_vjp_y = misc.convert_none_to_zeros(extra_vjp_y, y)
        extra_vjp_params = extra_vjp_y_and_params[n_tensors:]
        extra_vjp_params = misc.flatten_convert_none_to_zeros(extra_vjp_params, params)

        vjp_y = misc.seq_add(vjp_y, extra_vjp_y)
        vjp_params = vjp_params + extra_vjp_params

    return (*f_eval_corrected, *vjp_y, vjp_params)
def f(self, t, y_aug):
    # Drift of the augmented reverse-time adjoint SDE, including the adjoint of the
    # log-ratio term carried in `adj_l`.
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
    y, adj_y, adj_l = (
        y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors])
    vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l]

    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) for y_ in y]
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        f_eval = sde.f(-t, y)
        f_eval = [-f_eval_ for f_eval_ in f_eval]
        f_eval = misc.make_seq_requires_grad(f_eval)

        vjp_y_and_params = misc.grad(
            outputs=f_eval,
            inputs=y + params,
            grad_outputs=[-adj_y_ for adj_y_ in adj_y],
            allow_unused=True,
            create_graph=True)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_y = misc.convert_none_to_zeros(vjp_y, y)
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

        # Vector field change due to log-ratio term, i.e. ||u||^2 / 2.
        g_eval = sde.g(-t, y)
        h_eval = sde.h(-t, y)
        g_inv_eval = [torch.pinverse(g_eval_) for g_eval_ in g_eval]
        u_eval = misc.seq_sub(f_eval, h_eval)
        u_eval = [
            torch.bmm(g_inv_eval_, u_eval_)
            for g_inv_eval_, u_eval_ in zip(g_inv_eval, u_eval)
        ]
        log_ratio_correction = [.5 * torch.sum(u_eval_**2., dim=1) for u_eval_ in u_eval]
        log_ratio_correction = misc.make_seq_requires_grad(log_ratio_correction)
        corr_vjp_y_and_params = misc.grad(
            outputs=log_ratio_correction,
            inputs=y + params,
            grad_outputs=adj_l,
            allow_unused=True,
        )
        corr_vjp_y = corr_vjp_y_and_params[:n_tensors]
        corr_vjp_y = misc.convert_none_to_zeros(corr_vjp_y, y)
        corr_vjp_params = corr_vjp_y_and_params[n_tensors:]
        corr_vjp_params = misc.flatten_convert_none_to_zeros(corr_vjp_params, params)

        vjp_y = misc.seq_add(vjp_y, corr_vjp_y)
        vjp_params = vjp_params + corr_vjp_params

    return (*f_eval, *vjp_y, *vjp_l, vjp_params)
def g_prod(self, t, y_aug, noise):
    # Diffusion-noise product of the augmented reverse-time adjoint SDE.
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 2
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) for y_ in y]
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = [-g_ for g_ in sde.g(-t, y)]
        g_eval = misc.make_seq_requires_grad(g_eval)
        vjp_y_and_params = misc.grad(
            outputs=g_eval,
            inputs=y + params,
            grad_outputs=[-noise_ * adj_y_ for noise_, adj_y_ in zip(noise, adj_y)],
            allow_unused=True,
        )
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_y = misc.convert_none_to_zeros(vjp_y, y)
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

        g_prod_eval = misc.seq_mul(g_eval, noise)

    return (*g_prod_eval, *vjp_y, vjp_params)
def f(self, t, y_aug):
    # Drift of the augmented reverse-time adjoint SDE (no correction terms).
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 2
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) for y_ in y]
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        f_eval = sde.f(-t, y)
        f_eval = [-f_eval_ for f_eval_ in f_eval]
        f_eval = misc.make_seq_requires_grad(f_eval)

        vjp_y_and_params = misc.grad(
            outputs=f_eval,
            inputs=y + params,
            grad_outputs=[-adj_y_ for adj_y_ in adj_y],
            allow_unused=True,
            create_graph=True)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_y = misc.convert_none_to_zeros(vjp_y, y)
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

    return (*f_eval, *vjp_y, vjp_params)
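# Illustrative sketch (assumption, not part of the library): the methods in this
# section operate on a flat "augmented state". Judging from the indexing
# (`len(y_aug) // 2` or `// 3`) and the return tuples, it is assumed to pack the
# state tensors, their adjoints, optionally the log-ratio adjoints, and a single
# flattened parameter adjoint as the last element. The helpers `pack_aug` /
# `unpack_aug` below are hypothetical and exist only for exposition.
import torch


def pack_aug(y, adj_y, flat_adj_params):
    # State tensors, then their adjoints, then one flat parameter-adjoint vector.
    return (*y, *adj_y, flat_adj_params)


def unpack_aug(y_aug):
    n_tensors = len(y_aug) // 2  # same arithmetic as in `f` above: (2n + 1) // 2 == n
    y = y_aug[:n_tensors]
    adj_y = y_aug[n_tensors:2 * n_tensors]
    flat_adj_params = y_aug[-1]
    return y, adj_y, flat_adj_params


if __name__ == "__main__":
    y = (torch.randn(16, 3),)
    adj_y = (torch.zeros_like(y[0]),)
    flat_adj_params = torch.zeros(128)
    y_aug = pack_aug(y, adj_y, flat_adj_params)
    y2, adj_y2, flat2 = unpack_aug(y_aug)
    assert y2[0] is y[0] and adj_y2[0] is adj_y[0] and flat2 is flat_adj_params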
def gdg_prod(self, t, y_aug, noise):
    # Product of the diffusion with its own gradient and the noise (g * dg/dy * noise)
    # for the augmented adjoint SDE, plus the corresponding vjp terms; the log-ratio
    # adjoint `adj_l` is carried through with zero vjp.
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
    y, adj_y, adj_l = (
        y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors])
    vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l]

    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) for y_ in y]
        adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
        g_eval = misc.make_seq_requires_grad(g_eval)
        gdg_times_v = misc.grad(
            outputs=g_eval,
            inputs=y,
            grad_outputs=misc.seq_mul(g_eval, noise),
            allow_unused=True,
            create_graph=True,
        )
        gdg_times_v = misc.convert_none_to_zeros(gdg_times_v, y)

        dgdy = misc.grad(
            outputs=g_eval,
            inputs=y,
            grad_outputs=[torch.ones_like(y_) for y_ in y],
            allow_unused=True,
            create_graph=True,
        )
        dgdy = misc.convert_none_to_zeros(dgdy, y)

        prod_partials_adj_y_and_params = misc.grad(
            outputs=g_eval,
            inputs=y + params,
            grad_outputs=misc.seq_mul(adj_y, noise, dgdy),
            allow_unused=True,
            create_graph=True,
        )
        prod_partials_adj_y = prod_partials_adj_y_and_params[:n_tensors]
        prod_partials_adj_y = misc.convert_none_to_zeros(prod_partials_adj_y, y)
        prod_partials_params = prod_partials_adj_y_and_params[n_tensors:]
        prod_partials_params = misc.flatten_convert_none_to_zeros(prod_partials_params, params)

        gdg_v = misc.grad(
            outputs=g_eval,
            inputs=y,
            grad_outputs=[p.detach() for p in misc.seq_mul(adj_y, noise, g_eval)],
            allow_unused=True,
            create_graph=True,
        )
        gdg_v = misc.convert_none_to_zeros(gdg_v, y)
        gdg_v = misc.make_seq_requires_grad(gdg_v)
        gdg_v = [gdg_v_.sum() for gdg_v_ in gdg_v]

        mixed_partials_adj_y_and_params = misc.grad(
            outputs=gdg_v,
            inputs=y + params,
            allow_unused=True,
        )
        mixed_partials_adj_y = mixed_partials_adj_y_and_params[:n_tensors]
        mixed_partials_adj_y = misc.convert_none_to_zeros(mixed_partials_adj_y, y)
        mixed_partials_params = mixed_partials_adj_y_and_params[n_tensors:]
        mixed_partials_params = misc.flatten_convert_none_to_zeros(mixed_partials_params, params)

    return (*gdg_times_v,
            *misc.seq_sub(prod_partials_adj_y, mixed_partials_adj_y),
            *vjp_l,
            prod_partials_params - mixed_partials_params)
def f(self, t, y_aug):
    # Drift of the augmented reverse-time adjoint SDE for a Stratonovich base SDE,
    # including the adjoint of the log-ratio term carried in `adj_l`.
    sde, params, n_tensors = self._base_sde, self.params, len(y_aug) // 3
    y, adj_y, adj_l = (
        y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors])
    vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l]

    with torch.enable_grad():
        y = [y_.detach().requires_grad_(True) for y_ in y]
        adj_y = [adj_y_.detach() for adj_y_ in adj_y]

        g_eval = sde.g(-t, y)
        g_eval = misc.make_seq_requires_grad(g_eval)
        gdg = misc.grad(
            outputs=g_eval,
            inputs=y,
            grad_outputs=g_eval,
            allow_unused=True,
            create_graph=True,
        )
        gdg = misc.convert_none_to_zeros(gdg, y)

        f_eval = sde.f(-t, y)
        f_eval_corrected = misc.seq_sub(gdg, f_eval)  # Stratonovich correction for reverse-time.
        f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)
        vjp_y_and_params = misc.grad(
            outputs=f_eval_corrected,
            inputs=y + params,
            grad_outputs=[-adj_y_ for adj_y_ in adj_y],
            allow_unused=True,
            create_graph=True)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_y = misc.convert_none_to_zeros(vjp_y, y)
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = misc.flatten_convert_none_to_zeros(vjp_params, params)

        adj_times_dgdx = misc.grad(
            outputs=g_eval, inputs=y, grad_outputs=adj_y, allow_unused=True, create_graph=True)
        adj_times_dgdx = misc.convert_none_to_zeros(adj_times_dgdx, y)

        # This extra term is due to converting the *adjoint* Stratonovich backward SDE to Itô.
        extra_vjp_y_and_params = misc.grad(
            outputs=g_eval,
            inputs=y + params,
            grad_outputs=adj_times_dgdx,
            allow_unused=True,
            create_graph=True,
        )
        extra_vjp_y = extra_vjp_y_and_params[:n_tensors]
        extra_vjp_y = misc.convert_none_to_zeros(extra_vjp_y, y)
        extra_vjp_params = extra_vjp_y_and_params[n_tensors:]
        extra_vjp_params = misc.flatten_convert_none_to_zeros(extra_vjp_params, params)

        # Vector field change due to log-ratio term, i.e. ||u||^2 / 2.
        h_eval = sde.h(-t, y)
        u_eval = misc.seq_sub_div(f_eval, h_eval, g_eval)
        log_ratio_correction = [.5 * torch.sum(u_eval_**2., dim=1) for u_eval_ in u_eval]
        log_ratio_correction = misc.make_seq_requires_grad(log_ratio_correction)
        corr_vjp_y_and_params = misc.grad(
            outputs=log_ratio_correction,
            inputs=y + params,
            grad_outputs=adj_l,
            allow_unused=True,
        )
        corr_vjp_y = corr_vjp_y_and_params[:n_tensors]
        corr_vjp_y = misc.convert_none_to_zeros(corr_vjp_y, y)
        corr_vjp_params = corr_vjp_y_and_params[n_tensors:]
        corr_vjp_params = misc.flatten_convert_none_to_zeros(corr_vjp_params, params)

        vjp_y = misc.seq_add(vjp_y, extra_vjp_y, corr_vjp_y)
        vjp_params = vjp_params + extra_vjp_params + corr_vjp_params

    return (*f_eval_corrected, *vjp_y, *vjp_l, vjp_params)
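# Illustrative sketch (assumption, not the actual `misc` module of the library):
# minimal stand-ins conveying the semantics the methods above rely on. `misc.grad`
# is assumed to wrap `torch.autograd.grad`, `make_seq_requires_grad` to ensure each
# output tensor participates in the autograd graph, and `seq_sub_div` to compute an
# elementwise (a - b) / c; the versions below are simplified for exposition only.
import functools
import operator

import torch


def grad(outputs, inputs, **kwargs):
    # Thin wrapper over autograd, forwarding grad_outputs / allow_unused / create_graph.
    return torch.autograd.grad(outputs, inputs, **kwargs)


def convert_none_to_zeros(xs, refs):
    # Replace `None` gradients (produced by allow_unused=True) with zero tensors.
    return [torch.zeros_like(ref) if x is None else x for x, ref in zip(xs, refs)]


def flatten_convert_none_to_zeros(xs, refs):
    # Same as above, but concatenated into a single flat vector over all parameters.
    return torch.cat([
        (torch.zeros_like(ref) if x is None else x).reshape(-1)
        for x, ref in zip(xs, refs)
    ])


def seq_add(*seqs):
    # Elementwise sum across sequences of tensors.
    return [sum(xs) for xs in zip(*seqs)]


def seq_sub(xs, ys):
    # Elementwise difference of two sequences of tensors.
    return [x - y for x, y in zip(xs, ys)]


def seq_mul(*seqs):
    # Elementwise product across sequences, e.g. seq_mul(adj_y, noise, dgdy).
    return [functools.reduce(operator.mul, xs) for xs in zip(*seqs)]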