def test_beta_sample_grad(self):
     """Check Beta.rsample() reparameterized gradients d(sample)/d(alpha).

     Autograd gradients are compared against a finite-difference estimate
     derived from the implicit-function theorem: along a level set of
     cdf(x; alpha, beta), dx/dalpha = -(d cdf / d alpha) / pdf(x).
     """
     self._set_rng_seed()
     num_samples = 20
     # Sweep alpha and beta over {0.01, 1, 100} (9 combinations) to cover
     # both sharply-peaked and diffuse Beta shapes.
     for alpha, beta in product([1e-2, 1e0, 1e2], [1e-2, 1e0, 1e2]):
         alphas = Variable(torch.Tensor([alpha] * num_samples),
                           requires_grad=True)
         betas = Variable(torch.Tensor([beta] * num_samples))
         x = Beta(alphas, betas).rsample()
         x.sum().backward()
         # Sort the samples (and reorder gradients to match) so failure
         # messages print in a readable order.
         x, ind = x.data.sort()
         x = x.numpy()
         actual_grad = alphas.grad.data[ind].numpy()
         # Compare with expected gradient dx/dalpha along constant cdf(x,alpha,beta).
         cdf = scipy.stats.beta.cdf
         pdf = scipy.stats.beta.pdf
         # Central-difference step in alpha, scaled so it stays small
         # relative to alpha across the whole sweep.
         eps = 0.02 * alpha / (1.0 + np.sqrt(alpha))
         cdf_alpha = (cdf(x, alpha + eps, beta) -
                      cdf(x, alpha - eps, beta)) / (2 * eps)
         cdf_x = pdf(x, alpha, beta)
         expected_grad = -cdf_alpha / cdf_x
         # The 1e-100 term guards against division by an exactly-zero
         # expected gradient.
         rel_error = np.abs(actual_grad - expected_grad) / (expected_grad +
                                                            1e-100)
         self.assertLess(
             np.max(rel_error), 0.01, '\n'.join([
                 'Bad gradients for Beta({}, {})'.format(alpha, beta),
                 'x {}'.format(x), 'expected {}'.format(expected_grad),
                 'actual {}'.format(actual_grad),
                 'rel error {}'.format(rel_error),
                 'max error {}'.format(rel_error.max())
             ]))
# Example #2 (original scraper marker: "Beispiel #2", score 0)
def log_prob_density(x, dist_args, args):
    """Per-sample log-density of actions ``x`` under the policy distribution.

    Args:
        x: actions, shape (batch, action_dim).
        dist_args: distribution parameters — ``(mean, std)`` for a Gaussian
            policy, ``(alpha, beta)`` for a Beta policy.
        args: config namespace; ``args.stat_policy`` selects the distribution.

    Returns:
        Log-probabilities summed over action dims, shape (batch, 1).

    Raises:
        ValueError: if ``args.stat_policy`` is neither "Gaussian" nor "Beta".
    """
    if args.stat_policy == "Gaussian":
        # NOTE(review): this omits the -log(std) normalization term of the
        # full Gaussian log-pdf; kept as-is to preserve existing behavior.
        log_prob_density = -(x - dist_args[0]).pow(2) / (2 * dist_args[1].pow(2)) \
                         - 0.5 * math.log(2 * math.pi)
    elif args.stat_policy == "Beta":
        log_prob_density = Beta(dist_args[0], dist_args[1]).log_prob(x)
    else:
        # Previously an unknown policy fell through and the return line raised
        # UnboundLocalError; fail with a clear message instead.
        raise ValueError("Unknown stat_policy: {}".format(args.stat_policy))
    return log_prob_density.sum(1, keepdim=True)
# Example #3 (original scraper marker: "Beispiel #3", score 0)
    def log_prior(self, reals, params, idx):
        """Accumulate the log-prior density of the transformed parameters.

        For each real-valued parameter, adds its prior log-probability plus
        the log|det J| of the transform from real space; the likelihood
        term 'y' is skipped (handled elsewhere).
        """
        total = 0.0
        for name in reals:
            if name == 'y':
                continue
            if name == 'v':
                # Stick-breaking weights: Beta(concentration, 1) prior whose
                # concentration depends on whether full stick-breaking is on.
                concentration = (params['alpha'] if self.use_stick_break
                                 else params['alpha'] / self.K)
                contrib = Beta(concentration, 1).log_prob(params['v'])
            else:
                contrib = self.priors[name].log_prob(params[name])
            contrib = contrib + self.mp[name].logabsdetJ(reals[name], params[name])
            total += contrib.sum()

        return total