def testLogProbParam(self, dist_name, data):
  """Checks transformations of `log_prob` for each generated parameter.

  A single sample is drawn (fixed PRNG key) from a distribution built with
  the constrained parameters. Then, for every (param_name, param, ..., func)
  tuple yielded by `self._param_func_generator`, the transformation check
  runs on `func` bound to that parameter name. Each tuple's `func` is built
  around a log-prob of that fixed sample.
  """
  # Skip blocklisted distributions normally; when FLAGS.blocklists_only is
  # set, run only the blocklisted ones instead.
  is_blocklisted = dist_name in self.logprob_param_blocklist
  if is_blocklisted != FLAGS.blocklists_only:
    self.skipTest('Distribution currently broken.')

  unconstrained_strategy = dhps.base_distribution_unconstrained_params(
      enable_vars=False, dist_name=dist_name)
  params, batch_shape = data.draw(unconstrained_strategy)

  sampling_dist = data.draw(
      dhps.base_distributions(
          batch_shape=batch_shape,
          enable_vars=False,
          dist_name=dist_name,
          params=dhps.constrain_params(params, dist_name)))
  fixed_sample = sampling_dist.sample(seed=random.PRNGKey(0))

  def _log_prob_of_fixed_sample(dist):
    return dist.log_prob(fixed_sample)

  param_funcs = self._param_func_generator(
      data, dist_name, params, batch_shape, _log_prob_of_fixed_sample)
  # The distribution factory (third element) is not needed here.
  for param_name, param, _, func in param_funcs:
    self._test_transformation(
        functools.partial(func, param_name), param, msg=param_name)
def testLogProbSample(self, dist_name, data):
  """Checks transformations of `log_prob` with respect to the sample itself.

  A distribution is built from constrained parameters and sampled with a
  fixed PRNG key. The transformation check runs on the sample's
  log-probability as a function of the sample. Distributions whose samples
  have an integer dtype are skipped, since there is no derivative to take.
  """
  # Skip blocklisted distributions normally; when FLAGS.blocklists_only is
  # set, run only the blocklisted ones instead.
  blocklisted = dist_name in self.logprob_sample_blocklist
  if blocklisted != FLAGS.blocklists_only:
    self.skipTest('Distribution currently broken.')

  params, batch_shape = data.draw(
      dhps.base_distribution_unconstrained_params(
          enable_vars=False, dist_name=dist_name))
  dist = data.draw(
      dhps.base_distributions(
          batch_shape=batch_shape,
          enable_vars=False,
          dist_name=dist_name,
          params=dhps.constrain_params(params, dist_name)))

  drawn = dist.sample(seed=random.PRNGKey(0))
  is_integer_valued = np.issubdtype(drawn.dtype, np.integer)
  if is_integer_valued:
    self.skipTest(
        '{} has integer samples; no derivative.'.format(dist_name))

  def _log_prob_at(value):
    return dist.log_prob(value)

  self._test_transformation(_log_prob_at, drawn)
def testSample(self, dist_name, data):
  """Checks transformations of `sample` for each generated parameter.

  For every tuple yielded by `self._param_func_generator`, a distribution is
  built for that parameter. If that distribution is not FULLY_REPARAMETERIZED,
  the whole test is skipped. Otherwise the transformation check runs on
  `func` bound to the parameter name. Sampling always uses a fixed PRNG key.
  """
  # Skip blocklisted distributions normally; when FLAGS.blocklists_only is
  # set, run only the blocklisted ones instead.
  if (dist_name in self.sample_blocklist) != FLAGS.blocklists_only:
    self.skipTest('Distribution currently broken.')

  def _draw_with_fixed_key(dist):
    return dist.sample(seed=random.PRNGKey(0))

  unconstrained, batch_shape = data.draw(
      dhps.base_distribution_unconstrained_params(
          enable_vars=False, dist_name=dist_name))

  generated = self._param_func_generator(
      data, dist_name, unconstrained, batch_shape, _draw_with_fixed_key)
  for param_name, param_value, make_dist, func in generated:
    reparam_type = make_dist(param_name, param_value).reparameterization_type
    if reparam_type != reparameterization.FULLY_REPARAMETERIZED:
      # Skip distributions that don't support differentiable sampling.
      self.skipTest('{} is not reparameterized.'.format(dist_name))
    self._test_transformation(
        functools.partial(func, param_name), param_value, msg=param_name)