예제 #1
0
def test_nested():
    """Check a two-level nested HaarReparam under several handler contexts.

    Site ``x`` is reparametrized along dim -1 (yielding ``x_haar``), and
    ``x_haar`` is reparametrized along dim -2 (yielding ``x_haar_haar``).
    """
    shape = (5, 6)
    site_names = ("x", "x_haar", "x_haar_haar")

    @poutine.reparam(config={
        "x": HaarReparam(dim=-1),
        "x_haar": HaarReparam(dim=-2),
    })
    def model():
        pyro.sample("x", dist.Normal(torch.zeros(shape), 1).to_event(2))

    def trace_nodes():
        # Trace the model under whatever handlers are currently active and
        # require that all three sites were recorded.
        nodes = poutine.trace(model).get_trace().nodes
        assert set(site_names).issubset(nodes)
        return nodes

    def check_observed(nodes, *flags):
        # Compare "is_observed" for the leading len(flags) entries of
        # site_names, in order.
        for name, flag in zip(site_names, flags):
            assert bool(nodes[name]["is_observed"]) is flag

    # No initialization, as in AutoGuide._setup_prototype().
    nodes = trace_nodes()
    check_observed(nodes, True, True, False)
    assert nodes["x"]["value"].shape == shape

    # Condition on the innermost coordinates, as in Predictive.
    x = torch.randn(shape)
    x_haar = HaarTransform(dim=-1)(x)
    x_haar_haar = HaarTransform(dim=-2)(x_haar)
    with poutine.condition(data={"x_haar_haar": x_haar_haar}):
        nodes = trace_nodes()
        check_observed(nodes, True, True, True)
        for name, expected in zip(site_names, (x, x_haar, x_haar_haar)):
            assert_close(nodes[name]["value"], expected)

    # Custom initialization of x; autoguides and MCMC rely on this.
    with InitMessenger(init_to_value(values={"x": x})):
        nodes = trace_nodes()
        check_observed(nodes, True, True, False)
        assert_close(nodes["x"]["value"], x)

    # Condition directly on a freshly drawn x.
    x_obs = torch.randn(shape)
    with poutine.condition(data={"x": x_obs}):
        nodes = trace_nodes()
        # Only x and x_haar are checked here: x_haar_haar is not currently
        # marked observed in this case. Whether to fix that is open, see
        # https://github.com/pyro-ppl/pyro/issues/2878
        check_observed(nodes, True, True)
        assert_close(nodes["x"]["value"], x_obs)
예제 #2
0
파일: test_haar.py 프로젝트: zeta1999/pyro
def test_normal(shape, dim, flip):
    """HaarReparam must preserve the moments of a Normal site and their
    gradients with respect to ``loc`` and ``scale``."""
    # Random leaf parameters that track gradients.
    loc = torch.empty(shape).uniform_(-1., 1.).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                normal = dist.Normal(loc, scale).expand(shape).to_event(-dim)
                pyro.sample("x", normal)

    # Reference moments from the model without reparametrization.
    baseline = poutine.trace(model).get_trace()
    expected_probe = get_moments(baseline.nodes["x"]["value"])

    # After reparametrization, x_haar carries a TransformedDistribution and
    # x becomes a Delta site.
    haar_model = poutine.reparam(model, {"x": HaarReparam(dim=dim, flip=flip)})
    nodes = poutine.trace(haar_model).get_trace().nodes
    assert isinstance(nodes["x_haar"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)
    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)

    # Gradients of the first ten probe entries w.r.t. (loc, scale) must agree.
    params = [loc, scale]
    for actual_m, expected_m in zip(actual_probe[:10], expected_probe[:10]):
        expected_grads = grad(expected_m.sum(), params, retain_graph=True)
        actual_grads = grad(actual_m.sum(), params, retain_graph=True)
        for actual_g, expected_g in zip(actual_grads, expected_grads):
            assert_close(actual_g, expected_g, atol=0.05)
예제 #3
0
def test_init(shape, dim, flip):
    """Run ``check_init_reparam`` on a Normal site reparametrized by
    ``HaarReparam``."""
    # Random leaf parameters that track gradients.
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()

    def model():
        # shape[:dim] is declared as plates and the trailing -dim dims form
        # the event (dim is expected to be negative here).
        with pyro.plate_stack("plates", shape[:dim]):
            normal = dist.Normal(loc, scale).to_event(-dim)
            return pyro.sample("x", normal)

    haar = HaarReparam(dim=dim, flip=flip)
    check_init_reparam(model, haar)
예제 #4
0
def time_reparam_dct(msg):
    """
    EXPERIMENTAL Choose a ``HaarReparam`` for sites inside the ``time`` plate,
    for use as a ``poutine.reparam()`` config function.

    Observed sites and sites with no plate named ``"time"`` in their
    ``cond_indep_stack`` yield ``None`` (no reparametrization). Otherwise the
    reparam acts along the first such frame's ``dim`` shifted left by the
    site's ``fn.event_dim``.

    NOTE(review): despite the ``_dct`` suffix, this returns a ``HaarReparam``;
    confirm that is intended.
    """
    if msg["is_observed"]:
        return None
    time_frame = next(
        (frame for frame in msg["cond_indep_stack"] if frame.name == "time"),
        None,
    )
    if time_frame is None:
        return None
    event_dim = msg["fn"].event_dim
    return HaarReparam(dim=time_frame.dim - event_dim,
                       experimental_allow_batch=True)
예제 #5
0
    def reparam(self, model):
        """
        Wrap ``model`` in ``poutine.reparam`` handlers.

        First, every site ``name`` in ``self.dims`` is reparametrized with
        ``HaarReparam(dim=self.dims[name], flip=True)``. Then, if
        ``self.split`` is truthy, each resulting ``name + "_haar"`` site is
        reparametrized with a ``SplitReparam`` of sizes
        ``[self.split, self.duration - self.split]`` along the same dim.

        :param model: the model callable to wrap.
        :returns: the wrapped model.
        """
        haar_config = {
            name: HaarReparam(dim=dim, flip=True)
            for name, dim in self.dims.items()
        }
        wrapped = poutine.reparam(model, haar_config)

        if not self.split:
            return wrapped

        # Split the Haar coefficients into low- and high-frequency parts.
        sizes = [self.split, self.duration - self.split]
        split_config = {
            name + "_haar": SplitReparam(sizes, dim=dim)
            for name, dim in self.dims.items()
        }
        return poutine.reparam(wrapped, split_config)
예제 #6
0
파일: test_haar.py 프로젝트: zeta1999/pyro
def test_uniform(shape, dim, flip):
    """HaarReparam must preserve the moments of a Uniform(0, 1) site."""
    def model():
        with pyro.plate_stack("plates", shape[:dim]):
            with pyro.plate("particles", 10000):
                uniform = dist.Uniform(0, 1).expand(shape).to_event(-dim)
                pyro.sample("x", uniform)

    # Reference moments from the model without reparametrization.
    baseline = poutine.trace(model).get_trace()
    expected_probe = get_moments(baseline.nodes["x"]["value"])

    # After reparametrization, x_haar carries a TransformedDistribution and
    # x becomes a Delta site whose moments must match the reference.
    config = {"x": HaarReparam(dim=dim, flip=flip)}
    nodes = poutine.trace(poutine.reparam(model, config)).get_trace().nodes
    assert isinstance(nodes["x_haar"]["fn"], dist.TransformedDistribution)
    assert isinstance(nodes["x"]["fn"], dist.Delta)
    actual_probe = get_moments(nodes["x"]["value"])
    assert_close(actual_probe, expected_probe, atol=0.1)