def test_normal(shape, dim, smooth):
    """Check that ``DiscreteCosineReparam`` preserves a Normal site.

    Draws 10000 particles of ``x ~ Normal(loc, scale)`` both with and without
    the DCT reparametrizer and compares the moment probes returned by
    ``get_moments``. It then compares the gradients of the first ten probe
    entries with respect to ``loc`` and ``scale``.
    """
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                base = dist.Normal(loc, scale).expand(shape)
                pyro.sample("x", base.to_event(-dim))

    # Reference moments from the plain (non-reparametrized) model.
    baseline = poutine.trace(model).get_trace()
    expected_probe = get_moments(baseline.nodes["x"]["value"])

    # Route "x" through the DCT reparametrizer and inspect the new sites.
    reparametrizer = DiscreteCosineReparam(dim=dim, smooth=smooth)
    dct_model = poutine.reparam(model, {"x": reparametrizer})
    nodes = poutine.trace(dct_model).get_trace().nodes
    assert isinstance(nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)

    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)

    # Gradients w.r.t. the parameters must agree as well; retain_graph lets
    # each probe entry be differentiated against the shared graph repeatedly.
    params = [loc, scale]
    for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]):
        want = grad(expected_m.sum(), params, retain_graph=True)
        got = grad(actual_m.sum(), params, retain_graph=True)
        for got_g, want_g in zip(got, want):
            assert_close(got_g, want_g, atol=0.05)
def test_init(shape, dim, smooth):
    """Run ``check_init_reparam`` on a Normal model with a DCT reparam.

    ``check_init_reparam`` is defined elsewhere; presumably it verifies that
    initialization behaves consistently under the reparametrizer (confirm
    against its definition).
    """
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            normal = dist.Normal(loc, scale)
            return pyro.sample("x", normal.to_event(-dim))

    reparametrizer = DiscreteCosineReparam(dim=dim, smooth=smooth)
    check_init_reparam(model, reparametrizer)
def time_reparam_haar(msg):
    """
    EXPERIMENTAL Configures ``poutine.reparam()`` to use a
    ``DiscreteCosineReparam`` for all sites inside the ``time`` plate.

    Returns ``None`` (leave the site alone) for observed sites and for sites
    not enclosed by a plate named ``"time"``. Otherwise it returns a
    reparametrizer whose ``dim`` is the first such frame's ``dim`` minus the
    site distribution's ``event_dim``.

    NOTE(review): despite the ``haar`` in its name, this returns a
    ``DiscreteCosineReparam``, not a Haar-wavelet reparam. Confirm this is
    intended.
    """
    if msg["is_observed"]:
        return None
    time_frame = next(
        (frame for frame in msg["cond_indep_stack"] if frame.name == "time"),
        None,
    )
    if time_frame is None:
        return None
    # Presumably frame.dim counts batch dims from the right; subtracting the
    # event_dim turns it into an index into the full sample shape.
    return DiscreteCosineReparam(
        dim=time_frame.dim - msg["fn"].event_dim,
        experimental_allow_batch=True,
    )
def test_uniform(shape, dim, smooth):
    """Check that ``DiscreteCosineReparam`` preserves a Uniform(0, 1) site.

    Draws 10000 particles of ``x`` with and without the DCT reparametrizer
    and compares the moment probes returned by ``get_moments``.
    """
    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                base = dist.Uniform(0, 1).expand(shape)
                pyro.sample("x", base.to_event(-dim))

    # Reference moments from the plain (non-reparametrized) model.
    baseline = poutine.trace(model).get_trace()
    expected_probe = get_moments(baseline.nodes["x"]["value"])

    # Same model with "x" routed through the DCT reparametrizer.
    reparametrizer = DiscreteCosineReparam(dim=dim, smooth=smooth)
    dct_model = poutine.reparam(model, {"x": reparametrizer})
    nodes = poutine.trace(dct_model).get_trace().nodes
    assert isinstance(nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)

    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)