Example #1
    def sample_and_weight(self, y, x):
        """
        Propose a new state via gradient-based adaptation and weigh it.

        Starting from the hidden model's predictive mean, takes
        ``self._n_steps`` gradient steps on the joint log-likelihood
        ``log p(y | mean) + log p(mean)``, optionally replacing the fixed
        step size with a diagonal Newton-like correction where the
        curvature allows, then samples the proposal from a Normal kernel
        centred at the adapted mean.

        Args:
            y: observation(s) to condition the proposal on.
            x: current timeseries state (project type; must support
               ``copy(values=...)`` and be accepted by the model's
               density builders).

        Returns:
            Tuple of the proposed state and the importance weights
            produced by ``self._weight_with_kernel``.
        """
        mean, std = self._model.hidden.mean_scale(x)

        # Clone so the in-place scale updates in the second-order branch
        # below do not mutate the tensor returned by ``mean_scale``.
        std = std.clone()
        x_copy = x.copy(values=mean)

        # TODO: Would optimally build density utilizing the mean and scale from above
        x_dist = self._model.hidden.build_density(x)

        for _ in range(self._n_steps):
            # Track gradients w.r.t. the current proposal mean.
            mean.requires_grad_(True)

            # Joint log-likelihood of the observation and the state value.
            y_dist = self._model.build_density(x_copy)
            logl = y_dist.log_prob(y) + x_dist.log_prob(mean)
            # ``create_graph`` only when second derivatives are needed below.
            g = grad(logl,
                     mean,
                     grad_outputs=torch.ones_like(logl),
                     create_graph=self._use_second_order)[-1]

            ones_like_g = torch.ones_like(g)
            # Default: plain gradient step with fixed step size alpha.
            step = self._alpha * ones_like_g
            if self._use_second_order:
                # Element-wise negative inverse Hessian (diagonal only).
                neg_inv_hess = -grad(g, mean,
                                     grad_outputs=ones_like_g)[-1].pow(-1.0)

                # TODO: There is a better approach in Dahlin, find it
                # Use the Newton step/scale only where the negated inverse
                # curvature is positive (usable); elsewhere keep the plain
                # gradient step and the original scale.
                mask = neg_inv_hess > 0.0

                step[mask] = neg_inv_hess[mask]
                std[mask] = neg_inv_hess[mask].sqrt()

                # Drop the second-order graph before the in-place update.
                g.detach_()

            # Detach before ``+=`` so each iteration starts a fresh graph
            # (in-place update on a leaf requiring grad would error).
            mean.detach_()
            mean += step * g

        kernel = Normal(mean, std, validate_args=False)
        if not self._is1d:
            # Treat the last dimension as the event dimension for
            # multivariate states.
            kernel = kernel.to_event(1)

        # NOTE(review): ``kernel.sample`` is passed uncalled — presumably
        # ``copy`` accepts a lazy callable for values; confirm against the
        # state class.
        x_result = x_copy.copy(values=kernel.sample)

        return x_result, self._weight_with_kernel(y, x_dist, x_result, kernel)