def test_mc_plate_gaussian():
    """Monte Carlo-integrate a batched Gaussian against a sampled measure (torch API).

    Builds a Gaussian log-measure over the real input ``'loc'``, draws a sample
    of ``'loc'`` from it, and integrates a Gaussian integrand that also carries
    a 100-element bounded ``'data'`` input. The per-datum results are then
    multiplied together over ``'data'``. The test asserts that no entry of that
    product is infinite.

    NOTE(review): if this module also defines a later ``test_mc_plate_gaussian``,
    that later definition shadows this one and pytest never collects it.
    """
    sampled_vars = frozenset({'loc'})
    num_data = 100

    # Gaussian over 'loc' with parameters [0.] and [[1.]] (second argument
    # presumably the precision). The added constant -0.9189 is presumably
    # -log(2*pi)/2, the log-normalizer of a 1-D standard normal.
    prior_first = torch.tensor([0.])
    prior_second = torch.tensor([[1.]])
    prior = Gaussian(prior_first, prior_second, (('loc', reals()),))
    log_measure = prior + torch.tensor(-0.9189)

    # One Gaussian factor per datum: random first parameter shifted by +3,
    # with an all-ones second parameter, batched over the 'data' plate.
    data_first = torch.randn((num_data, 1)) + 3.
    data_second = torch.ones((num_data, 1, 1))
    integrand = Gaussian(
        data_first,
        data_second,
        (('data', bint(num_data)), ('loc', reals())),
    )

    # The sample is drawn only after the integrand has been built, so the RNG
    # calls happen in the same order as in the one-expression form.
    draw = log_measure.sample(sampled_vars)
    per_datum = Integrate(draw, integrand, sampled_vars)
    total = per_datum.reduce(ops.mul, frozenset({'data'}))

    assert not torch.isinf(total).any()
def test_mc_plate_gaussian():
    """Monte Carlo-integrate a batched Gaussian against a sampled measure (any backend).

    Builds a Gaussian log-measure over the real input ``'loc'``, samples
    ``'loc'`` from it, and integrates a Gaussian integrand that also depends on
    a 100-element bounded ``'data'`` input. The per-datum results are
    multiplied together over ``'data'``. The test asserts that the product
    has no +inf or -inf entries.

    NOTE(review): the equality-based check does not detect NaN entries.
    """
    num_data = 100

    # Gaussian over 'loc' with parameters [0.] and [[1.]] (second argument
    # presumably the precision). The added constant -0.9189 is presumably
    # -log(2*pi)/2, the log-normalizer of a 1-D standard normal.
    prior_first = numeric_array([0.])
    prior_second = numeric_array([[1.]])
    prior = Gaussian(prior_first, prior_second, (('loc', Real),))
    log_measure = prior + numeric_array(-0.9189)

    # One Gaussian factor per datum: random first parameter shifted by +3,
    # with an all-ones second parameter, batched over the 'data' plate.
    data_first = randn((num_data, 1)) + 3.
    data_second = ones((num_data, 1, 1))
    integrand = Gaussian(
        data_first,
        data_second,
        (('data', Bint[num_data]), ('loc', Real)),
    )

    # Only the JAX backend receives an explicit PRNG key (raw uint32 pair);
    # every other backend is passed None.
    if get_backend() == 'jax':
        rng_key = np.array([0, 0], dtype=np.uint32)
    else:
        rng_key = None

    draw = log_measure.sample('loc', rng_key=rng_key)
    per_datum = Integrate(draw, integrand, 'loc')
    total = per_datum.reduce(ops.mul, 'data')

    is_pos_inf = total == float('inf')
    is_neg_inf = total == float('-inf')
    assert not (is_pos_inf | is_neg_inf).any()