Beispiel #1
0
def test_normal(shape, dim, smooth):
    """Compare a ``Normal`` site against its ``DiscreteCosineReparam`` version.

    Both the plain model and the reparametrized model draw 10000 particles
    of site ``"x"``.  The test checks that:

    * the reparametrized trace has an auxiliary ``"x_dct"`` site whose
      ``fn`` is a ``TransformedDistribution``, and that ``"x"`` is a ``Delta``;
    * the moments returned by ``get_moments`` agree within ``atol=0.1``;
    * gradients of the first ten moment entries with respect to ``loc`` and
      ``scale`` agree within ``atol=0.05``.

    ``dim`` is used both to slice the batch shape (``shape[:dim]``) and in
    ``to_event(-dim)``, so it is presumably a negative axis index.
    """
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()
    batch_shape = shape[:dim]

    def model():
        with pyro.plate_stack("plates", batch_shape):
            with pyro.plate("particles", 10000):
                normal = dist.Normal(loc, scale).expand(shape).to_event(-dim)
                pyro.sample("x", normal)

    baseline = poutine.trace(model).get_trace()
    expected_probe = get_moments(baseline.nodes["x"]["value"])

    dct_reparam = DiscreteCosineReparam(dim=dim, smooth=smooth)
    reparam_trace = poutine.trace(
        poutine.reparam(model, {"x": dct_reparam})).get_trace()
    nodes = reparam_trace.nodes
    assert isinstance(nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)
    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)

    # Only the first ten moment entries are differentiated (presumably to
    # bound the number of backward passes).
    params = [loc, scale]
    for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]):
        # retain_graph=True because the moment entries presumably share one
        # autograd graph that later iterations still need.
        want = grad(expected_m.sum(), params, retain_graph=True)
        got = grad(actual_m.sum(), params, retain_graph=True)
        for got_g, want_g in zip(got, want):
            assert_close(got_g, want_g, atol=0.05)
Beispiel #2
0
def test_init(shape, dim, smooth):
    """Run ``check_init_reparam`` for ``DiscreteCosineReparam`` on a Normal site.

    The model samples a single ``"x"`` site from ``Normal(loc, scale)``, with
    the trailing ``-dim`` axes treated as event dimensions inside a plate
    stack over ``shape[:dim]``.  ``dim`` is presumably a negative axis index.
    """
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        normal = dist.Normal(loc, scale)
        with pyro.plate_stack("plates", shape[:dim]):
            return pyro.sample("x", normal.to_event(-dim))

    reparam = DiscreteCosineReparam(dim=dim, smooth=smooth)
    check_init_reparam(model, reparam)
Beispiel #3
0
def time_reparam_haar(msg):
    """
    EXPERIMENTAL Configures ``poutine.reparam()`` to use a ``DiscreteCosineReparam`` for
    all sites inside the ``time`` plate.

    Returns ``None`` (no reparametrization) for observed sites and for sites
    that have no ``cond_indep_stack`` frame named ``"time"``.  Otherwise the
    reparam's ``dim`` is the first ``time`` frame's ``dim`` shifted left by
    the site's ``fn.event_dim``.

    NOTE(review): the function name says "haar", but the returned object is a
    ``DiscreteCosineReparam``. Confirm whether a Haar reparam was intended.
    """
    if msg["is_observed"]:
        return None
    time_frame = next(
        (frame for frame in msg["cond_indep_stack"] if frame.name == "time"),
        None,
    )
    if time_frame is None:
        return None
    shifted_dim = time_frame.dim - msg["fn"].event_dim
    return DiscreteCosineReparam(dim=shifted_dim, experimental_allow_batch=True)
Beispiel #4
0
def test_uniform(shape, dim, smooth):
    """Compare a ``Uniform(0, 1)`` site against its ``DiscreteCosineReparam`` version.

    Draws 10000 particles of ``"x"`` from the plain and reparametrized
    models. It checks that the reparametrized trace has a
    ``TransformedDistribution`` at ``"x_dct"`` and a ``Delta`` at ``"x"``,
    and that the ``get_moments`` probes agree within ``atol=0.1``.
    """
    def model():
        uniform = dist.Uniform(0, 1).expand(shape).to_event(-dim)
        with pyro.plate_stack("plates", shape[:dim]), \
                pyro.plate("particles", 10000):
            pyro.sample("x", uniform)

    baseline_value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(baseline_value)

    dct_reparam = DiscreteCosineReparam(dim=dim, smooth=smooth)
    reparam_model = poutine.reparam(model, {"x": dct_reparam})
    nodes = poutine.trace(reparam_model).get_trace().nodes
    x_node = nodes["x"]
    assert isinstance(nodes["x_dct"]["fn"], dist.TransformedDistribution)
    assert isinstance(x_node["fn"], dist.Delta)
    actual_probe = get_moments(x_node["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)