Ejemplo n.º 1
0
def test_dirichlet_multinomial_density(batch_shape, event_shape):
    """Compare ``dist.DirichletMultinomial`` against the backend log-density.

    The reference value is produced by wrapping the backend distribution's
    ``log_prob`` with ``funsor.function``. The funsor distribution under test
    must agree with it on randomly generated batched data.

    Args:
        batch_shape: tuple of batch sizes; up to three of its dimensions become
            ``Bint`` inputs named ``'i'``, ``'j'``, ``'k'`` in order.
        event_shape: tuple shape of the ``concentration`` and ``value`` events.
    """
    max_count = 10
    dim_names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(zip(dim_names, (Bint[n] for n in batch_shape)))

    @funsor.function
    def dirichlet_multinomial(concentration: Reals[event_shape], total_count: Real,
                              value: Reals[event_shape]) -> Real:
        # Reference density computed directly by the backend distribution.
        return backend_dist.DirichletMultinomial(concentration, total_count).log_prob(value)

    expected_arg_domains = {
        'concentration': Reals[event_shape],
        'total_count': Real,
        'value': Reals[event_shape],
    }
    check_funsor(dirichlet_multinomial, expected_arg_domains, Real)

    # The random draws happen in a fixed order: randn, then two randint calls.
    # Reordering them would change the generated data.
    full_shape = batch_shape + event_shape
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    value_data = ops.astype(randint(0, max_count, size=full_shape), 'float32')
    extra_trials = ops.astype(randint(0, max_count, size=batch_shape), 'float32')
    # total_count is the observed per-event sum plus a non-negative surplus.
    total_count_data = value_data.sum(-1) + extra_trials
    value = Tensor(value_data, inputs)
    total_count = Tensor(total_count_data, inputs)

    expected = dirichlet_multinomial(concentration, total_count, value)
    check_funsor(expected, inputs, Real)
    actual = dist.DirichletMultinomial(concentration=concentration,
                                       total_count=total_count,
                                       value=value)
    check_funsor(actual, inputs, Real)
    assert_close(actual, expected)
Ejemplo n.º 2
0
def test_multinomial_density(batch_shape, event_shape):
    """Compare ``dist.Multinomial`` against the backend log-density.

    The reference value is produced by wrapping the backend distribution's
    ``log_prob`` with ``funsor.function``. The funsor distribution under test
    must agree with it on randomly generated batched data.

    Args:
        batch_shape: tuple of batch sizes; up to three of its dimensions become
            ``Bint`` inputs named ``'i'``, ``'j'``, ``'k'`` in order.
        event_shape: tuple shape of the ``probs`` and ``value`` events.
    """
    max_count = 10
    dim_names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(zip(dim_names, (Bint[n] for n in batch_shape)))

    @funsor.function
    def multinomial(total_count: Real, probs: Reals[event_shape], value: Reals[event_shape]) -> Real:
        # Under the torch backend, the batched total_count is collapsed to one
        # Python scalar (its max). This is presumably because torch's
        # Multinomial takes a scalar int total_count. NOTE(review): confirm.
        trials = total_count.max().item() if get_backend() == "torch" else total_count
        return backend_dist.Multinomial(trials, probs).log_prob(value)

    check_funsor(multinomial,
                 {'total_count': Real, 'probs': Reals[event_shape], 'value': Reals[event_shape]},
                 Real)

    full_shape = batch_shape + event_shape
    # Random probabilities, normalized along the last (event) dimension.
    unnormalized = rand(full_shape)
    probs = Tensor(unnormalized / unnormalized.sum(-1)[..., None], inputs)
    counts = ops.astype(randint(0, max_count, size=full_shape), 'float')
    value = Tensor(counts, inputs)
    # Each total_count equals the sum of its observed counts.
    total_count = Tensor(counts.sum(-1), inputs)

    expected = multinomial(total_count, probs, value)
    check_funsor(expected, inputs, Real)
    actual = dist.Multinomial(total_count=total_count, probs=probs, value=value)
    check_funsor(actual, inputs, Real)
    assert_close(actual, expected)
Ejemplo n.º 3
0
def test_dirichlet_multinomial_conjugate(batch_shape, size):
    """Check Dirichlet/Multinomial conjugacy rules in funsor.

    Builds the joint ``Dirichlet(concentration) + Multinomial(probs=prior)``
    and checks the following:

    * Reducing out ``"value"`` leaves a ``Dirichlet``.
    * Reducing out ``"prior"`` yields a ``DirichletMultinomial`` with the same
      ``concentration`` and ``total_count``.
    * Conditioning on an observation gives ``Dirichlet(concentration + obs)``.

    Args:
        batch_shape: tuple of batch sizes; up to three of its dimensions become
            ``Bint`` inputs named ``'i'``, ``'j'``, ``'k'`` in order.
        size: length of the event (category) dimension.
    """
    max_count = 10
    dim_names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(zip(dim_names, (Bint[n] for n in batch_shape)))
    full_shape = batch_shape + (size,)

    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    counts = ops.astype(randint(0, max_count, size=full_shape), 'float32')
    obs = Tensor(counts, inputs)
    total_count = Tensor(counts.sum(-1), inputs)

    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Multinomial(probs=prior, total_count=total_count)
    joint = latent + conditional

    # Summing out the observed counts leaves only the Dirichlet prior.
    assert isinstance(joint.reduce(ops.logaddexp, {"value"}), dist.Dirichlet)

    # Summing out the latent probabilities yields the compound distribution,
    # parameterized by the prior's concentration and the likelihood's trials.
    compound = joint.reduce(ops.logaddexp, {"prior"})
    assert isinstance(compound, dist.DirichletMultinomial)
    assert_close(compound.concentration, concentration)
    assert_close(compound.total_count, total_count)

    # joint - marginal, evaluated at the observation, is the posterior.
    posterior = (joint - compound)(value=obs)
    assert isinstance(posterior, dist.Dirichlet)
    assert_close(posterior.concentration, concentration + obs)

    _assert_conjugate_density_ok(latent, conditional, obs)
Ejemplo n.º 4
0
def test_dirichlet_multinomial_conjugate_plate(batch_shape, size):
    """Check Dirichlet/Multinomial conjugacy with several plated observations.

    Each batch element has seven observations indexed by a ``'plate'`` input.
    Their Multinomial log-densities are combined with ``ops.add`` over the
    plate, and the prior is then reduced out; the result must be a plain
    ``Tensor``.

    Args:
        batch_shape: tuple of batch sizes; up to three of its dimensions become
            ``Bint`` inputs named ``'i'``, ``'j'``, ``'k'`` in order.
        size: length of the event (category) dimension.
    """
    max_count = 10
    num_obs = 7
    dim_names = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict(zip(dim_names, (Bint[n] for n in batch_shape)))
    full_shape = batch_shape + (size,)

    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    # Observation data carries an extra plate axis just before the event axis.
    counts = ops.astype(randint(0, max_count, size=batch_shape + (num_obs, size)), 'float32')
    obs_inputs = OrderedDict(inputs, plate=Bint[num_obs])
    obs = Tensor(counts, obs_inputs)
    total_count = Tensor(counts.sum(-1), obs_inputs)

    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Multinomial(probs=prior, total_count=total_count, value=obs)
    # Sum the per-observation log-densities over the plate before adding the prior.
    joint = latent + conditional.reduce(ops.add, 'plate')
    marginal = joint.reduce(ops.logaddexp, 'prior')
    assert isinstance(marginal, Tensor)

    _assert_conjugate_density_ok(latent, conditional, obs)