Ejemplo n.º 1
0
def test_mc_plate_gaussian():
    """Monte Carlo integrate a plated Gaussian integrand and check for infinities.

    Steps:

    1. Build a log-measure over a single real variable ``loc``: a Gaussian term
       plus the constant -0.9189, which is numerically about -0.5 * log(2 * pi).
       NOTE(review): presumably this is the log-normaliser; confirm.
    2. Sample that measure over ``loc``.
    3. Integrate a batch of 100 Gaussian factors indexed by the ``data`` plate
       against the sample.
    4. Multiply the per-datum results together.

    The only property asserted is that the final product contains no infinity.
    """
    loc_only = frozenset({'loc'})
    data_only = frozenset({'data'})

    prior = Gaussian(torch.tensor([0.]), torch.tensor([[1.]]),
                     (('loc', reals()),))
    log_measure = prior + torch.tensor(-0.9189)

    # One factor per data index: random offsets shifted by 3 and unit matrices.
    offsets = torch.randn((100, 1)) + 3.
    unit_matrices = torch.ones((100, 1, 1))
    integrand = Gaussian(offsets, unit_matrices,
                         (('data', bint(100)), ('loc', reals())))

    # Sampling happens after the randn call above, so the RNG stream is unchanged.
    sampled_measure = log_measure.sample(loc_only)
    per_datum = Integrate(sampled_measure, integrand, loc_only)
    total = per_datum.reduce(ops.mul, data_only)
    assert not torch.isinf(total).any()
Ejemplo n.º 2
0
def test_mc_plate_gaussian():
    """Backend-agnostic Monte Carlo integration of a plated Gaussian integrand.

    Steps:

    1. Build a log-measure over the real variable ``loc``: a Gaussian term plus
       the constant -0.9189, which is numerically about -0.5 * log(2 * pi).
       NOTE(review): presumably this is the log-normaliser; confirm.
    2. Sample that measure over ``loc``.
    3. Integrate 100 Gaussian factors indexed by the ``data`` plate against the
       sample.
    4. Multiply the results over ``data``.

    The test asserts only that no element of the product equals +inf or -inf.
    """
    prior = Gaussian(numeric_array([0.]), numeric_array([[1.]]),
                     (('loc', Real),))
    log_measure = prior + numeric_array(-0.9189)

    # One factor per data index: random offsets shifted by 3 and unit matrices.
    offsets = randn((100, 1)) + 3.
    unit_matrices = ones((100, 1, 1))
    integrand = Gaussian(offsets, unit_matrices,
                         (('data', Bint[100]), ('loc', Real)))

    # An explicit key is supplied only for the 'jax' backend; others get None.
    rng_key = None
    if get_backend() == 'jax':
        rng_key = np.array([0, 0], dtype=np.uint32)

    sampled_measure = log_measure.sample('loc', rng_key=rng_key)
    per_datum = Integrate(sampled_measure, integrand, 'loc')
    total = per_datum.reduce(ops.mul, 'data')

    pos_inf = float('inf')
    neg_inf = float('-inf')
    infinite_mask = (total == pos_inf) | (total == neg_inf)
    assert not infinite_mask.any()