def testLogProbParam(self, dist_name, data):
  """Runs `_test_transformation` on `log_prob` viewed as a function of params.

  Hypothesis-drawn unconstrained parameters are constrained and used to build
  a distribution. One fixed sample is drawn from it with
  `random.PRNGKey(0)`. `_log_prob` scores that fixed sample under a given
  distribution and is handed to `self._param_func_generator`. For every
  `(param_name, param, _, func)` the generator yields,
  `functools.partial(func, param_name)` is checked against `param` via
  `self._test_transformation`. What the generator's `func` does with
  `_log_prob` is defined outside this method; presumably it rebuilds the
  distribution with `param` substituted. Verify against
  `_param_func_generator`.

  Args:
    dist_name: Name of the distribution under test.
    data: Hypothesis data object used to draw parameters and distributions.
  """
  # The inequality acts as an XOR. Normally blocklisted names are skipped;
  # with FLAGS.blocklists_only set, only blocklisted names run.
  is_blocklisted = dist_name in self.logprob_param_blocklist
  if is_blocklisted != FLAGS.blocklists_only:
    self.skipTest('Distribution currently broken.')

  raw_params, batch_shape = data.draw(
      dhps.base_distribution_unconstrained_params(
          enable_vars=False, dist_name=dist_name))
  source_dist = data.draw(
      dhps.base_distributions(
          batch_shape=batch_shape,
          enable_vars=False,
          dist_name=dist_name,
          params=dhps.constrain_params(raw_params, dist_name)))
  fixed_sample = source_dist.sample(seed=random.PRNGKey(0))

  def _log_prob(dist):
    # Always score the same fixed sample, whatever distribution is passed.
    return dist.log_prob(fixed_sample)

  param_funcs = self._param_func_generator(
      data, dist_name, raw_params, batch_shape, _log_prob)
  for param_name, param, _, func in param_funcs:
    # The third element (a distribution-building function) is not needed.
    self._test_transformation(
        functools.partial(func, param_name), param, msg=param_name)
def testLogProbSample(self, dist_name, data):
  """Runs `_test_transformation` on `log_prob` viewed as a function of a sample.

  A distribution is built from hypothesis-drawn, constrained parameters, and
  one sample is drawn from it with `random.PRNGKey(0)`. The closure
  `x -> dist.log_prob(x)` is then checked against that sample via
  `self._test_transformation`.

  Distributions whose samples have an integer dtype are skipped, because
  there is no derivative with respect to such a sample.

  Args:
    dist_name: Name of the distribution under test.
    data: Hypothesis data object used to draw parameters and distributions.
  """
  # The inequality acts as an XOR. Normally blocklisted names are skipped;
  # with FLAGS.blocklists_only set, only blocklisted names run.
  is_blocklisted = dist_name in self.logprob_sample_blocklist
  if is_blocklisted != FLAGS.blocklists_only:
    self.skipTest('Distribution currently broken.')

  raw_params, batch_shape = data.draw(
      dhps.base_distribution_unconstrained_params(
          enable_vars=False, dist_name=dist_name))
  dist = data.draw(
      dhps.base_distributions(
          batch_shape=batch_shape,
          enable_vars=False,
          dist_name=dist_name,
          params=dhps.constrain_params(raw_params, dist_name)))
  drawn = dist.sample(seed=random.PRNGKey(0))

  has_integer_samples = np.issubdtype(drawn.dtype, np.integer)
  if has_integer_samples:
    self.skipTest(
        '{} has integer samples; no derivative.'.format(dist_name))

  def _log_prob(x):
    # Differentiate through log_prob with the distribution held fixed.
    return dist.log_prob(x)

  self._test_transformation(_log_prob, drawn)
def _make_distribution(self, dist_name, params, batch_shape,
                       override_params=None):
  """Builds the `dhps.base_distributions` result for possibly-overridden params.

  The pipeline runs in this order:
  1. Copy `params` into a new dict and overlay any `override_params`
     entries, replacing same-named keys.
  2. Constrain the merged dict with `dhps.constrain_params`.
  3. Adjust it with `dhps.modify_params(..., validate_args=False)`.
  4. Pass the result to `dhps.base_distributions` with `enable_vars=False`
     and `validate_args=False`.

  Neither `params` nor `override_params` is mutated.

  NOTE(review): the test methods in this file wrap
  `dhps.base_distributions(...)` in `data.draw(...)`. That suggests the value
  returned here is a hypothesis strategy rather than a distribution
  instance. Confirm callers draw from it.

  Args:
    dist_name: Name of the distribution to build.
    params: Mapping of unconstrained parameter names to values.
    batch_shape: Batch shape forwarded to `dhps.base_distributions`.
    override_params: Optional mapping whose entries replace those in
      `params`. `None` means no overrides.

  Returns:
    Whatever `dhps.base_distributions` returns for the final parameters.
  """
  merged = dict(params)
  # An empty or None override mapping leaves `merged` unchanged.
  merged.update(override_params or {})
  constrained = dhps.constrain_params(merged, dist_name)
  modified = dhps.modify_params(constrained, dist_name, validate_args=False)
  return dhps.base_distributions(
      enable_vars=False,
      dist_name=dist_name,
      params=modified,
      batch_shape=batch_shape,
      validate_args=False)