Esempio n. 1
0
def test_factor():
    """``pyro.factor`` sites are reported alongside sample sites.

    The factor ``"c"`` consumes the latent ``"a"``. It therefore lists ``"a"``
    as a prior dependency, and ``"a"`` lists ``"c"`` as a posterior
    dependency. The constant factor ``"b"`` depends only on itself.
    """

    def model():
        latent = pyro.sample("a", dist.Normal(0, 1))
        pyro.factor("b", torch.tensor(0.0))
        pyro.factor("c", latent)

    def deps(*names):
        # Every entry in this test carries an empty plate set.
        return {name: set() for name in names}

    expected_prior = {
        "a": deps("a"),
        "b": deps("b"),
        "c": deps("c", "a"),
    }
    expected_posterior = {
        "a": deps("a", "c"),
    }

    result = get_dependencies(model)
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 2
0
def test_docstring_example_3():
    """A plated latent ``"a"`` is summed into one observed scalar ``"b"``.

    Because ``"b"`` lies outside plate ``"p"``, the posterior self-dependency
    of ``"a"`` is expected to carry the plate name ``"p"``.
    """

    def model_3():
        with pyro.plate("p", 5):
            plated = pyro.sample("a", dist.Normal(0, 1))
        pyro.sample("b", dist.Normal(plated.sum(), 1), obs=torch.tensor(0.0))

    empty = set()
    expected_prior = {
        "a": {"a": empty},
        "b": {"a": empty, "b": empty},
    }
    expected_posterior = {
        "a": {"a": {"p"}, "b": empty},
    }

    result = get_dependencies(model_3)
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 3
0
def test_nested_plate_collider():
    """Two inner plates (``"j"``, ``"k"``) collide at ``"c"`` inside plate ``"i"``.

    Each collapsed dimension is expected to appear in the plate set of the
    corresponding posterior self-dependency.
    """
    # a a       b b
    #  a a     b b
    #    \\   //
    #      c c
    #       |
    #       d

    def model():
        plate_i = pyro.plate("i", 2, dim=-1)
        plate_j = pyro.plate("j", 3, dim=-2)
        plate_k = pyro.plate("k", 3, dim=-2)

        with plate_i:
            with plate_j:
                a_draw = pyro.sample("a", dist.Normal(0, 1))
            with plate_k:
                b_draw = pyro.sample("b", dist.Normal(0, 1))
            c_draw = pyro.sample(
                "c", dist.Normal(a_draw.sum(0) + b_draw.sum([0, 1]), 1))
        pyro.sample("d", dist.Normal(c_draw.sum(), 1), obs=torch.zeros(()))

    empty = set()
    expected_prior = {
        "a": {"a": empty},
        "b": {"b": empty},
        "c": {"c": empty, "a": empty, "b": empty},
        "d": {"d": empty, "c": empty},
    }
    expected_posterior = {
        "a": {"a": {"j"}, "b": empty, "c": empty},
        "b": {"b": {"k"}, "c": empty},
        "c": {"c": {"i"}, "d": empty},
    }

    result = get_dependencies(model)
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 4
0
def test_plate_collider():
    """Latents in crossed plates ``"i"`` and ``"j"`` meet at an observed ``"z"``.

    No plate names are expected in any dependency set for this model.
    """
    #   x x    y y
    #     \\  //
    #      zzzz
    #
    # This results in posterior dependency structure:
    #
    #     x x y y z z z z
    #   x ?   ? ? ? ?
    #   x   ? ? ?     ? ?
    #   y     ?   ?   ?
    #   y       ?   ?   ?

    def model(data):
        i_plate = pyro.plate("i", data.shape[0], dim=-2)
        j_plate = pyro.plate("j", data.shape[1], dim=-1)

        with i_plate:
            loc = pyro.sample("x", dist.Normal(0, 1))
        with j_plate:
            log_scale = pyro.sample("y", dist.Normal(0, 1))
        with i_plate, j_plate:
            pyro.sample("z", dist.Normal(loc, log_scale.exp()), obs=data)

    def deps(*names):
        # All dependency sets in this test are plate-free.
        return {name: set() for name in names}

    observations = torch.randn(3, 2)
    result = get_dependencies(model, (observations, ))
    assert result == {
        "prior_dependencies": {
            "x": deps("x"),
            "y": deps("y"),
            "z": deps("x", "y", "z"),
        },
        "posterior_dependencies": {
            "x": deps("x", "y", "z"),
            "y": deps("y", "z"),
        },
    }
Esempio n. 5
0
def test_plate_coupling_3():
    """A latent in plates ``"i"`` x ``"j"`` is reduced along each axis separately.

    ``"y"`` sums over ``"j"`` and ``"z"`` sums over ``"i"``. The posterior
    self-dependency of ``"x"`` is therefore expected to carry both plate names.
    """
    #    x x x x
    #     // \\
    #   y y   z z
    #
    # This results in posterior dependency structure:
    #
    #     x x y y z
    #   x ? ? ? ? ?
    #   x ? ? ? ? ?
    #   y     ? ? ?
    #   y     ? ? ?

    def model(data):
        i_plate = pyro.plate("i", data.shape[0], dim=-2)
        j_plate = pyro.plate("j", data.shape[1], dim=-1)
        with i_plate, j_plate:
            grid = pyro.sample("x", dist.Normal(0, 1))
        with i_plate:
            pyro.sample("y",
                        dist.Normal(grid.sum(-1, True), 1),
                        obs=data.sum(-1, True))
        with j_plate:
            pyro.sample("z",
                        dist.Normal(grid.sum(-2, True), 1),
                        obs=data.sum(-2, True))

    empty = set()
    expected_prior = {
        "x": {"x": empty},
        "y": {"y": empty, "x": empty},
        "z": {"z": empty, "x": empty},
    }
    expected_posterior = {
        "x": {"x": {"i", "j"}, "y": empty, "z": empty},
    }

    observations = torch.randn(3, 2)
    result = get_dependencies(model, (observations, ))
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 6
0
def test_discrete():
    """Dependencies pass through a Categorical index used to pick a vector entry.

    ``"d"`` depends on both the index ``"b"`` and the indexed vector ``"c"``.
    """

    def model():
        probs = pyro.sample("a", dist.Dirichlet(torch.ones(3)))
        idx = pyro.sample("b", dist.Categorical(probs))
        vec = pyro.sample("c", dist.Normal(torch.zeros(3), 1).to_event(1))
        count = pyro.sample("d", dist.Poisson(vec[idx].exp()))
        pyro.sample("e", dist.Normal(count, 1), obs=torch.ones(()))

    def deps(*names):
        # No plates are involved, so every plate set is empty.
        return {name: set() for name in names}

    expected_prior = {
        "a": deps("a"),
        "b": deps("a", "b"),
        "c": deps("c"),
        "d": deps("b", "c", "d"),
        "e": deps("d", "e"),
    }
    expected_posterior = {
        "a": deps("a", "b"),
        "b": deps("b", "c", "d"),
        "c": deps("c", "d"),
        "d": deps("d", "e"),
    }

    result = get_dependencies(model)
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 7
0
def test_plate_coupling_2():
    """Two latents in the same plate ``"p"`` are summed into one observed ``"z"``.

    The posterior couplings between plated sites are expected to carry ``"p"``.
    """
    #   x x
    #     \\   y y
    #      \\ //
    #        z
    #
    # This results in posterior dependency structure:
    #
    #     x x y y z
    #   x ? ? ? ? ?
    #   x ? ? ? ? ?
    #   y     ? ? ?
    #   y     ? ? ?

    def model(data):
        with pyro.plate("p", len(data)):
            loc = pyro.sample("x", dist.Normal(0, 1))
            log_scale = pyro.sample("y", dist.Normal(0, 1))
        pyro.sample("z",
                    dist.Normal(loc.sum(), log_scale.sum().exp()),
                    obs=data.sum())

    empty = set()
    expected_prior = {
        "x": {"x": empty},
        "y": {"y": empty},
        "z": {"z": empty, "x": empty, "y": empty},
    }
    expected_posterior = {
        "x": {"x": {"p"}, "y": {"p"}, "z": empty},
        "y": {"y": {"p"}, "z": empty},
    }

    observations = torch.randn(2)
    result = get_dependencies(model, (observations, ))
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 8
0
def test_docstring_example_2():
    """Chain ``a, b -> c -> d`` with ``d`` observed and no plates.

    Conditioning on ``"c"`` is expected to couple its two parents ``"a"`` and
    ``"b"`` in the posterior.
    """

    def model_2():
        loc = pyro.sample("a", dist.Normal(0, 1))
        scale = pyro.sample("b", dist.LogNormal(0, 1))
        middle = pyro.sample("c", dist.Normal(loc, scale))
        pyro.sample("d", dist.Normal(middle, 1), obs=torch.tensor(0.0))

    def deps(*names):
        # No plates are involved, so every plate set is empty.
        return {name: set() for name in names}

    expected_prior = {
        "a": deps("a"),
        "b": deps("b"),
        "c": deps("a", "b", "c"),
        "d": deps("c", "d"),
    }
    expected_posterior = {
        "a": deps("a", "b", "c"),
        "b": deps("b", "c"),
        "c": deps("c", "d"),
    }

    result = get_dependencies(model_2)
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 9
0
def test_plate_dependency():
    """A global ``"w"`` plus plated ``"x"``/``"y"`` all feed a plated observed ``"z"``.

    Every element of the plate observes its own ``z``, so no plate names are
    expected in any dependency set.
    """
    #   w                              w
    #     \  x1 x2      unroll    x1  / \  x2
    #      \  || y1 y2  =====>  y1 | /   \ | y2
    #       \ || //               \|/     \|/
    #        z1 z2                z1       z2
    #
    # This allows posterior dependency structure:
    #
    #     w x x y y z z
    #   w ? ? ? ? ? ? ?
    #   x   ?   ?   ?
    #   x     ?   ?   ?
    #   y       ?   ?
    #   y         ?   ?

    def model(data):
        shared = pyro.sample("w", dist.Normal(0, 1))
        with pyro.plate("p", len(data)):
            first = pyro.sample("x", dist.Normal(0, 1))
            second = pyro.sample("y", dist.Normal(0, 1))
            pyro.sample("z", dist.Normal(shared + first + second, 1), obs=data)

    def deps(*names):
        # All dependency sets in this test are plate-free.
        return {name: set() for name in names}

    observations = torch.rand(2)
    result = get_dependencies(model, (observations, ))
    assert result == {
        "prior_dependencies": {
            "w": deps("w"),
            "x": deps("x"),
            "y": deps("y"),
            "z": deps("w", "x", "y", "z"),
        },
        "posterior_dependencies": {
            "w": deps("w", "x", "y", "z"),
            "x": deps("x", "y", "z"),
            "y": deps("y", "z"),
        },
    }
Esempio n. 10
0
def test_discrete_obs():
    """Event-shaped latents feed a plated Dirichlet with a Categorical observation.

    Mixes a ``to_event(1)`` Normal and a MultivariateNormal, both shifted by
    ``"a"``, whose sum parameterizes ``"d"`` inside plate ``"i"``.
    """

    def model():
        shift = pyro.sample("a", dist.Normal(0, 1))
        b_vec = pyro.sample("b",
                            dist.Normal(shift[..., None],
                                        torch.ones(3)).to_event(1))
        c_vec = pyro.sample(
            "c",
            dist.MultivariateNormal(
                torch.zeros(3) + shift[..., None], torch.eye(3)))
        with pyro.plate("i", 2):
            probs = pyro.sample("d", dist.Dirichlet((b_vec + c_vec).exp()))
            pyro.sample("e",
                        dist.Categorical(logits=probs),
                        obs=torch.tensor([0, 0]))
        return shift, b_vec, c_vec, probs

    def deps(*names):
        # No plate names are expected anywhere in this model's result.
        return {name: set() for name in names}

    expected_prior = {
        "a": deps("a"),
        "b": deps("a", "b"),
        "c": deps("a", "c"),
        "d": deps("b", "c", "d"),
        "e": deps("d", "e"),
    }
    expected_posterior = {
        "a": deps("a", "b", "c"),
        "b": deps("b", "c", "d"),
        "c": deps("c", "d"),
        "d": deps("d", "e"),
    }

    result = get_dependencies(model)
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }
Esempio n. 11
0
def test_get_dependencies(grad_enabled):
    """End-to-end dependency check on a model mixing several site kinds.

    Covers a non-reparameterized site, an observed Bernoulli, and a plate
    holding a deterministic site, an observed Delta and an observed Normal.
    The deterministic ``"i"`` and the Delta ``"j"`` are not keys of the
    expected structure. ``get_dependencies`` runs with autograd enabled or
    disabled according to ``grad_enabled``.
    """

    def model(data):
        a = pyro.sample("a", dist.Normal(0, 1))
        b = pyro.sample("b", NonreparameterizedNormal(a, 0))
        c = pyro.sample("c", dist.Normal(b, 1))
        d = pyro.sample("d", dist.Normal(a, c.exp()))

        e = pyro.sample("e", dist.Normal(0, 1))
        f = pyro.sample("f", dist.Normal(0, 1))
        g = pyro.sample("g",
                        dist.Bernoulli(logits=e + f),
                        obs=torch.tensor(0.0))

        with pyro.plate("p", len(data)):
            d_detached = d.detach()  # this results in a known failure
            h = pyro.sample("h", dist.Normal(c, d_detached.exp()))
            i = pyro.deterministic("i", h + 1)
            j = pyro.sample("j", dist.Delta(h + 1), obs=h + 1)
            k = pyro.sample("k", dist.Normal(a, j.exp()), obs=data)

        return [a, b, c, d, e, f, g, h, i, j, k]

    def deps(*names):
        # Every dependency set in this test is plate-free.
        return {name: set() for name in names}

    expected_prior = {
        "a": deps("a"),
        "b": deps("b", "a"),
        "c": deps("c", "b"),
        "d": deps("d", "c", "a"),
        "e": deps("e"),
        "f": deps("f"),
        "g": deps("g", "e", "f"),
        "h": deps("h", "c", "d"),
        "k": deps("k", "a", "h"),
    }
    expected_posterior = {
        "a": deps("a", "b", "c", "d", "h", "k"),
        "b": deps("b", "c"),
        "c": deps("c", "d", "h"),
        "d": deps("d", "h"),
        "e": deps("e", "g", "f"),
        "f": deps("f", "g"),
        "h": deps("h", "k"),
    }

    observations = torch.randn(3)
    with torch.set_grad_enabled(grad_enabled):
        result = get_dependencies(model, (observations, ))
    assert result == {
        "prior_dependencies": expected_prior,
        "posterior_dependencies": expected_posterior,
    }