def sample_and_weight(self, y, x):
    """Propose new states by iteratively moving the prior mean, then sample and weight them.

    Starting from the mean returned by ``self._model.hidden.mean_scale(x)``, the method
    runs ``self._n_steps`` iterations that move ``mean`` along the gradient ``g`` of

        logl = y_dist.log_prob(y) + x_dist.log_prob(mean)

    where ``x_dist`` is built once from ``x`` via ``self._model.hidden.build_density``
    and ``y_dist`` is rebuilt every iteration from ``x_copy`` via
    ``self._model.build_density``.

    How far each element moves depends on ``self._use_second_order``:

    * When it is false, every element moves by ``self._alpha * g``.
    * When it is true, the derivative ``h`` of ``g`` w.r.t. ``mean`` is also computed.
      Elements where ``-1 / h > 0`` move by ``-g / h`` and have their scale replaced
      by ``sqrt(-1 / h)``.
    * All other elements keep the ``self._alpha`` step and their current scale.

    Afterwards a ``Normal(mean, std)`` kernel is built at the final ``mean``. It is
    wrapped with ``to_event(1)`` when ``self._is1d`` is false, and is used to produce
    the returned state and weights.

    Args:
        y: Observation. It is scored with ``y_dist.log_prob`` and forwarded to
            ``self._weight_with_kernel``.
        x: Previous state. It provides the initial mean/scale and the prior density
            ``x_dist``, and serves (via ``x.copy``) as the template for new states.

    Returns:
        Tuple ``(x_result, w)``.
        ``x_result`` is ``x_copy.copy(values=kernel.sample)``.
        ``w`` is whatever ``self._weight_with_kernel(y, x_dist, x_result, kernel)``
        returns (presumably importance weights; confirm against that helper).
    """
    mean, std = self._model.hidden.mean_scale(x)
    # Clone because ``std`` is written in place below (``std[mask] = ...``). The clone
    # keeps those writes from touching the tensor returned by ``mean_scale``.
    std = std.clone()

    # NOTE(review): this appears to rely on ``x.copy`` keeping a reference to ``mean``
    # rather than a clone. The loop mutates ``mean`` in place (``requires_grad_``,
    # ``detach_``, ``+=``) and only rebuilds ``y_dist`` from ``x_copy``. ``y_dist``
    # follows the updates, and ``g`` includes the observation term, only if
    # ``x_copy``'s values share storage with ``mean``. Confirm the semantics of ``copy``.
    x_copy = x.copy(values=mean)

    # TODO: Would optimally build density utilizing the mean and scale from above
    x_dist = self._model.hidden.build_density(x)

    for _ in range(self._n_steps):
        # Start tracking gradients w.r.t. the current proposal mean.
        mean.requires_grad_(True)

        y_dist = self._model.build_density(x_copy)
        logl = y_dist.log_prob(y) + x_dist.log_prob(mean)

        # Gradient of logl w.r.t. mean. Passing ones as grad_outputs sums the
        # contributions of all elements of logl. The graph is kept only when the
        # second derivative is needed.
        g = grad(logl, mean, grad_outputs=torch.ones_like(logl), create_graph=self._use_second_order)[-1]
        ones_like_g = torch.ones_like(g)

        # Default: fixed step size ``self._alpha`` for every element.
        step = self._alpha * ones_like_g
        if self._use_second_order:
            # Derivative of g w.r.t. mean, contracted with ones, then negated and
            # inverted elementwise: -1 / h. This h equals the Hessian diagonal only if
            # elements of ``mean`` do not interact in ``logl``. Otherwise it is a
            # row sum of the Hessian.
            neg_inv_hess = -grad(g, mean, grad_outputs=ones_like_g)[-1].pow(-1.0)

            # TODO: There is a better approach in Dahlin, find it
            # Use the curvature-based step (and scale) only where -1 / h > 0, i.e.
            # h < 0. Elsewhere keep the ``self._alpha`` step and the current scale.
            mask = neg_inv_hess > 0.0

            step[mask] = neg_inv_hess[mask]
            std[mask] = neg_inv_hess[mask].sqrt()

            # ``g`` was built with create_graph=True; drop its graph before the update.
            g.detach_()

        # Detach so the in-place update below is not recorded by autograd.
        # An in-place op on a leaf that requires grad would raise.
        mean.detach_()
        mean += step * g

    # NOTE(review): with second order enabled, ``std`` reflects the curvature at the
    # mean *before* the last ``+=`` update, not at the final ``mean``. Each element
    # keeps the value from the last iteration in which it was masked.
    kernel = Normal(mean, std, validate_args=False)
    if not self._is1d:
        # Presumably Pyro-style ``to_event``, which reinterprets the rightmost batch
        # dimension as an event dimension. ``torch.distributions.Normal`` has no such
        # method, so confirm which ``Normal`` is imported.
        kernel = kernel.to_event(1)

    # NOTE(review): this passes the bound method ``kernel.sample``, not a sampled
    # tensor. Presumably ``copy`` accepts a callable and evaluates it lazily.
    # Confirm; otherwise this should be ``kernel.sample()``.
    x_result = x_copy.copy(values=kernel.sample)

    return x_result, self._weight_with_kernel(y, x_dist, x_result, kernel)