def test_dirichlet_multinomial_density(batch_shape, event_shape):
    """Compare ``dist.DirichletMultinomial`` with the backend log-density.

    The backend distribution's ``log_prob`` is lifted into a funsor with
    ``funsor.function`` and used as the reference. The funsor-native
    distribution is evaluated on the same randomly drawn, batched
    concentration / total_count / value tensors and must agree with it.
    """
    max_count = 10
    # One named batch input per entry of batch_shape ('i', then 'j', ...).
    names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(
        (name, Bint[extent]) for name, extent in zip(names, batch_shape)
    )

    @funsor.function
    def dirichlet_multinomial(concentration: Reals[event_shape],
                              total_count: Real,
                              value: Reals[event_shape]) -> Real:
        return backend_dist.DirichletMultinomial(
            concentration, total_count).log_prob(value)

    check_funsor(
        dirichlet_multinomial,
        dict(concentration=Reals[event_shape], total_count=Real,
             value=Reals[event_shape]),
        Real,
    )

    full_shape = batch_shape + event_shape
    # exp() keeps every concentration entry strictly positive.
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)

    counts = ops.astype(randint(0, max_count, size=full_shape), 'float32')
    # Extra count drawn from randint(0, max_count, ...) is added on top of
    # sum(counts), so total_count is never below the observed total.
    slack = ops.astype(randint(0, max_count, size=batch_shape), 'float32')
    value = Tensor(counts, inputs)
    total_count = Tensor(counts.sum(-1) + slack, inputs)

    reference = dirichlet_multinomial(concentration, total_count, value)
    check_funsor(reference, inputs, Real)

    result = dist.DirichletMultinomial(concentration, total_count, value)
    check_funsor(result, inputs, Real)
    assert_close(result, reference)
def test_multinomial_density(batch_shape, event_shape):
    """Compare ``dist.Multinomial`` with the backend log-density.

    Probabilities come from ``rand`` and are normalised along the last
    axis; ``total_count`` is set to the exact per-batch sum of the sampled
    counts. The funsor-native distribution must match the backend
    ``log_prob`` lifted through ``funsor.function``.
    """
    max_count = 10
    # One named batch input per entry of batch_shape ('i', then 'j', ...).
    names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(
        (name, Bint[extent]) for name, extent in zip(names, batch_shape)
    )

    @funsor.function
    def multinomial(total_count: Real,
                    probs: Reals[event_shape],
                    value: Reals[event_shape]) -> Real:
        # Under torch the backend receives one Python scalar (the batch
        # maximum) instead of a batched count; presumably its Multinomial
        # does not accept a tensor-valued total_count -- confirm.
        count = (total_count.max().item() if get_backend() == "torch"
                 else total_count)
        return backend_dist.Multinomial(count, probs).log_prob(value)

    check_funsor(
        multinomial,
        dict(total_count=Real, probs=Reals[event_shape],
             value=Reals[event_shape]),
        Real,
    )

    full_shape = batch_shape + event_shape
    raw_probs = rand(full_shape)
    # Normalise along the event axis so each batch entry sums to one.
    normaliser = raw_probs.sum(-1)[..., None]
    probs = Tensor(raw_probs / normaliser, inputs)

    counts = ops.astype(randint(0, max_count, size=full_shape), 'float')
    value = Tensor(counts, inputs)
    # total_count matches the sampled counts exactly.
    total_count = Tensor(counts.sum(-1), inputs)

    reference = multinomial(total_count, probs, value)
    check_funsor(reference, inputs, Real)

    result = dist.Multinomial(total_count, probs, value)
    check_funsor(result, inputs, Real)
    assert_close(result, reference)
def test_dirichlet_multinomial_conjugate(batch_shape, size):
    """Check Dirichlet-Multinomial conjugacy under funsor reductions.

    ``latent`` is a Dirichlet over the real variable ``prior``, and
    ``conditional`` is a Multinomial whose ``probs`` is that same variable.
    Their sum (the joint) must:

    * reduce over ``value`` (``logaddexp``) to a ``dist.Dirichlet``;
    * reduce over ``prior`` to a ``dist.DirichletMultinomial`` that keeps
      the original ``concentration`` and ``total_count``;
    * after subtracting that marginal and substituting ``value=obs``,
      become a ``dist.Dirichlet`` with concentration ``concentration + obs``.

    Further checks are delegated to ``_assert_conjugate_density_ok``.
    """
    max_count = 10
    # One named batch input per entry of batch_shape ('i', then 'j', ...).
    names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(
        (name, Bint[extent]) for name, extent in zip(names, batch_shape)
    )
    full_shape = batch_shape + (size,)

    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    counts = ops.astype(randint(0, max_count, size=full_shape), 'float32')
    obs = Tensor(counts, inputs)
    # total_count is the per-batch sum of the observed counts.
    total_count = Tensor(counts.sum(-1), inputs)

    latent = dist.Dirichlet(concentration, value=prior)
    # No value is passed here; the joint is later reduced / substituted
    # over the name "value" (presumably the distribution's default value
    # variable -- confirm).
    conditional = dist.Multinomial(probs=prior, total_count=total_count)
    joint = latent + conditional

    over_value = joint.reduce(ops.logaddexp, {"value"})
    assert isinstance(over_value, dist.Dirichlet)

    marginal = joint.reduce(ops.logaddexp, {"prior"})
    assert isinstance(marginal, dist.DirichletMultinomial)
    assert_close(marginal.concentration, concentration)
    assert_close(marginal.total_count, total_count)

    posterior = (joint - marginal)(value=obs)
    assert isinstance(posterior, dist.Dirichlet)
    assert_close(posterior.concentration, concentration + obs)

    _assert_conjugate_density_ok(latent, conditional, obs)
def test_dirichlet_multinomial_conjugate_plate(batch_shape, size):
    """Check conjugate elimination of a Dirichlet prior across a plate.

    Seven Multinomial observations, indexed by an extra ``plate`` input,
    share one Dirichlet-distributed ``prior``. The conditional's
    log-density is summed over ``plate``, added to the prior, and then
    ``prior`` is reduced with ``logaddexp``; the result must be a concrete
    ``Tensor``. Further checks are delegated to
    ``_assert_conjugate_density_ok``.
    """
    max_count = 10
    num_obs = 7
    # One named batch input per entry of batch_shape ('i', then 'j', ...).
    names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(
        (name, Bint[extent]) for name, extent in zip(names, batch_shape)
    )
    full_shape = batch_shape + (size,)

    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)

    counts = ops.astype(
        randint(0, max_count, size=batch_shape + (num_obs, size)), 'float32')
    # Observations carry every batch input plus a trailing 'plate' input.
    obs_inputs = OrderedDict(inputs)
    obs_inputs['plate'] = Bint[num_obs]
    obs = Tensor(counts, obs_inputs)
    total_count = Tensor(counts.sum(-1), obs_inputs)

    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Multinomial(probs=prior, total_count=total_count,
                                   value=obs)
    # Sum the per-observation log-densities across the plate first.
    joint = latent + conditional.reduce(ops.add, 'plate')
    marginal = joint.reduce(ops.logaddexp, 'prior')
    assert isinstance(marginal, Tensor)

    _assert_conjugate_density_ok(latent, conditional, obs)