def model():
    with pyro.markov() as m:
        with pyro.markov():
            with m:  # error here: re-entering the outer markov context inside the inner one is invalid
                pyro.sample("x",
                            dist.Categorical(torch.ones(4)),
                            infer={"enumerate": "parallel"})
Example #2
def testing():
    # Nested markov contexts recycle tensor dimensions, so the funsor input
    # dims (and hence the data shapes) alternate with nesting depth.
    with pyro.markov():
        v1 = pyro.to_data(
            Tensor(torch.ones(2), OrderedDict([(str(1), funsor.Bint[2])]),
                   'real'))
        print(1, v1.shape)  # shapes should alternate
        assert v1.shape == (2, )

        with pyro.markov():
            v2 = pyro.to_data(
                Tensor(torch.ones(2),
                       OrderedDict([(str(2), funsor.Bint[2])]), 'real'))
            print(2, v2.shape)  # shapes should alternate
            assert v2.shape == (2, 1)

            with pyro.markov():
                v3 = pyro.to_data(
                    Tensor(torch.ones(2),
                           OrderedDict([(str(3), funsor.Bint[2])]),
                           'real'))
                print(3, v3.shape)  # shapes should alternate
                assert v3.shape == (2, )

                with pyro.markov():
                    v4 = pyro.to_data(
                        Tensor(torch.ones(2),
                               OrderedDict([(str(4), funsor.Bint[2])]),
                               'real'))
                    print(4, v4.shape)  # shapes should alternate

                    assert v4.shape == (2, 1)
Example #3
def model_8(weeks_data, days_data, history, vectorized):
    x_dim, y_dim, w_dim, z_dim = 3, 2, 2, 3
    x_init = pyro.param("x_init",
                        lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans",
                         lambda: torch.rand((x_dim, x_dim)),
                         constraint=constraints.simplex)
    y_probs = pyro.param("y_probs",
                         lambda: torch.rand(x_dim, y_dim),
                         constraint=constraints.simplex)
    w_init = pyro.param("w_init",
                        lambda: torch.rand(w_dim),
                        constraint=constraints.simplex)
    w_trans = pyro.param("w_trans",
                         lambda: torch.rand((w_dim, w_dim)),
                         constraint=constraints.simplex)
    z_probs = pyro.param("z_probs",
                         lambda: torch.rand(w_dim, z_dim),
                         constraint=constraints.simplex)

    x_prev = None
    weeks_loop = (pyro.vectorized_markov(
        name="weeks", size=len(weeks_data), dim=-1, history=history)
                  if vectorized else pyro.markov(range(len(weeks_data)),
                                                 history=history))
    for i in weeks_loop:
        if isinstance(i, int) and i == 0:
            x_probs = x_init
        else:
            x_probs = Vindex(x_trans)[x_prev]

        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs))
        pyro.sample(
            "y_{}".format(i),
            dist.Categorical(Vindex(y_probs)[x_curr]),
            obs=weeks_data[i],
        )
        x_prev = x_curr

    w_prev = None
    days_loop = (pyro.vectorized_markov(
        name="days", size=len(days_data), dim=-1, history=history)
                 if vectorized else pyro.markov(range(len(days_data)),
                                                history=history))
    for j in days_loop:
        if isinstance(j, int) and j == 0:
            w_probs = w_init
        else:
            w_probs = Vindex(w_trans)[w_prev]

        w_curr = pyro.sample("w_{}".format(j), dist.Categorical(w_probs))
        pyro.sample(
            "z_{}".format(j),
            dist.Categorical(Vindex(z_probs)[w_curr]),
            obs=days_data[j],
        )
        w_prev = w_curr
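A hedged run sketch for the serial/vectorized toggle above (assumptions throughout: API names follow the funsor-backed frontend, where TraceMarkovEnum_ELBO handles vectorized_markov models; the toy data are illustrative; all latent sites are discrete, so an empty guide suffices):

import funsor
import torch
import pyro.contrib.funsor  # noqa: F401  registers the "contrib.funsor" backend
from pyroapi import infer, pyro_backend

funsor.set_backend("torch")

weeks_data = torch.zeros(6, dtype=torch.long)   # illustrative observations
days_data = torch.ones(10, dtype=torch.long)

def empty_guide(*args, **kwargs):
    pass

with pyro_backend("contrib.funsor"):
    enum_model = infer.config_enumerate(model_8, default="parallel")
    elbo = infer.TraceMarkovEnum_ELBO(max_plate_nesting=1)
    loss = elbo.differentiable_loss(enum_model, empty_guide, weeks_data,
                                    days_data, history=1, vectorized=True)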
Example #4
def model():
    p = pyro.param("p", torch.ones(3, 3))
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = x
    for i in pyro.markov(range(10)):
        y = pyro.sample("y_{}".format(i), dist.Categorical(p[y]))
        z = y
        for j in pyro.markov(range(10)):
            z = pyro.sample("z_{}_{}".format(i, j), dist.Categorical(p[z]))
Example #5
def model():
    # grid_size and use_vindex are supplied by the enclosing test.
    p = pyro.param("p_leaf", torch.ones(2, 2, 2))
    x = defaultdict(lambda: torch.tensor(0))
    y_axis = pyro.markov(range(grid_size), keep=True)
    for i in pyro.markov(range(grid_size)):
        for j in y_axis:
            if use_vindex:
                probs = Vindex(p)[x[i - 1, j], x[i, j - 1]]
            else:
                ind = torch.arange(2, dtype=torch.long)
                probs = p[x[i - 1, j].unsqueeze(-1),
                          x[i, j - 1].unsqueeze(-1), ind]
            x[i, j] = pyro.sample("x_{}_{}".format(i, j),
                                  dist.Categorical(probs))
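The defaultdict above supplies the off-grid boundary values, and keep=True lets the same y_axis context be re-iterated on every pass of the outer loop while reusing its dimension allocation. A minimal check of the boundary trick (self-contained):

from collections import defaultdict

import torch

x = defaultdict(lambda: torch.tensor(0))
assert x[-1, 0].item() == 0   # unassigned neighbor: defaults to state 0
x[0, 0] = torch.tensor(1)
assert x[0, 0].item() == 1    # sampled cells shadow the default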
Example #6
def model():
    # subsampling and reuse_plate are supplied by the enclosing test.
    p = pyro.param("p", torch.ones(3, 3))
    q = pyro.param("q", torch.tensor([0.5, 0.5]))
    plate_x = pyro.plate("plate_x",
                         4,
                         subsample_size=3 if subsampling else None,
                         dim=-1)
    plate_y = pyro.plate("plate_y",
                         5,
                         subsample_size=3 if subsampling else None,
                         dim=-1)
    plate_z = pyro.plate("plate_z",
                         6,
                         subsample_size=3 if subsampling else None,
                         dim=-2)

    a = pyro.sample("a", dist.Bernoulli(q[0])).long()
    w = 0
    for i in pyro.markov(range(4)):
        w = pyro.sample("w_{}".format(i), dist.Categorical(p[w]))

    with plate_x:
        b = pyro.sample("b", dist.Bernoulli(q[a])).long()
        x = 0
        for i in pyro.markov(range(4)):
            x = pyro.sample("x_{}".format(i), dist.Categorical(p[x]))

    with plate_y:
        c = pyro.sample("c", dist.Bernoulli(q[a])).long()
        y = 0
        for i in pyro.markov(range(4)):
            y = pyro.sample("y_{}".format(i), dist.Categorical(p[y]))

    with plate_z:
        d = pyro.sample("d", dist.Bernoulli(q[a])).long()
        z = 0
        for i in pyro.markov(range(4)):
            z = pyro.sample("z_{}".format(i), dist.Categorical(p[z]))

    with plate_x, plate_z:
        # this part is tricky: how do we know to preserve b's dimension?
        # also, how do we know how to make b and d have different dimensions?
        e = pyro.sample("e",
                        dist.Bernoulli(q[b if reuse_plate else a])).long()
        xz = 0
        for i in pyro.markov(range(4)):
            xz = pyro.sample("xz_{}".format(i), dist.Categorical(p[xz]))

    return a, b, c, d, e
Example #7
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension
                # (see the shape check after this example).
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
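The broadcasting trick in the comment above can be checked with plain tensors (a sketch; dims are illustrative): indexing probs_y with two batched indices plus the plate's arange keeps the tones dimension while broadcasting over the batch dims.

import torch

probs_y = torch.rand(3, 2, 88)          # hidden x prev-tone x tones
x = torch.tensor([[0], [2]])            # e.g. enumerated values, shape (2, 1)
y = torch.tensor([[1], [0]])
tones = torch.arange(88)                # what the tones plate provides
assert probs_y[x, y, tones].shape == (2, 88)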
Example #8
def model_5(data, history, vectorized):
    x_dim, y_dim = 3, 2
    x_init = pyro.param("x_init",
                        lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_init_2 = pyro.param("x_init_2",
                          lambda: torch.rand(x_dim, x_dim),
                          constraint=constraints.simplex)
    x_trans = pyro.param(
        "x_trans",
        lambda: torch.rand((x_dim, x_dim, x_dim)),
        constraint=constraints.simplex,
    )
    y_probs = pyro.param("y_probs",
                         lambda: torch.rand(x_dim, y_dim),
                         constraint=constraints.simplex)

    x_prev = x_prev_2 = None
    markov_loop = (pyro.vectorized_markov(
        name="time", size=len(data), dim=-2, history=history) if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        if isinstance(i, int) and i == 0:
            x_probs = x_init
        elif isinstance(i, int) and i == 1:
            x_probs = Vindex(x_init_2)[x_prev]
        else:
            x_probs = Vindex(x_trans)[x_prev_2, x_prev]

        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Categorical(Vindex(y_probs)[x_curr]),
                        obs=data[i])
        x_prev_2, x_prev = x_prev, x_curr
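A shape sketch for the order-2 transition lookup above (dims illustrative): Vindex applies both state indices to the leading dims of x_trans and leaves a distribution over the next state.

import torch
from pyro.ops.indexing import Vindex

x_dim = 3
x_trans = torch.rand(x_dim, x_dim, x_dim)
x_prev_2, x_prev = torch.tensor(1), torch.tensor(2)
probs = Vindex(x_trans)[x_prev_2, x_prev]
assert probs.shape == (x_dim,)          # a categorical over the next state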
Example #9
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note that since each tone depends on all tones at a previous time step
                # the tones at different time steps now need to live in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(logits=tones_generator(x, y)),
                        obs=sequences[batch, t])
Example #10
def model():
    # `history` is supplied by the enclosing test.
    p = torch.tensor([[0.2, 0.8], [0.1, 0.9]])

    xs = [0]
    for t in pyro.markov(range(100), history=history):
        xs.append(pyro.sample("x_{}".format(t), dist.Categorical(p[xs[-1]])))
    assert all(x.dim() <= history + 1 for x in xs[1:])
Example #11
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with handlers.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...), as in the sketch after
            # this example.
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
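The comment above mentions recording every state instead of overwriting x; a self-contained sketch of that variant (the helper name is hypothetical):

import torch
import pyro
import pyro.distributions as dist

def chain_recording_states(probs_x, length):
    # Keep all x values in a dict keyed by time; x[-1] = 0 is the initial state.
    x = {-1: torch.tensor(0)}
    for t in pyro.markov(range(length)):
        x[t] = pyro.sample("x_{}".format(t),
                           dist.Categorical(probs_x[x[t - 1]]),
                           infer={"enumerate": "parallel"})
    return x

# e.g. chain_recording_states(torch.ones(3, 3), 10)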
Example #12
def model_0(data, history, vectorized):
    x_dim = 3
    init = pyro.param("init",
                      lambda: torch.rand(x_dim),
                      constraint=constraints.simplex)
    trans = pyro.param("trans",
                       lambda: torch.rand((x_dim, x_dim)),
                       constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    with pyro.plate("sequences", data.shape[0], dim=-3) as sequences:
        sequences = sequences[:, None]
        x_prev = None
        markov_loop = (pyro.vectorized_markov(
            name="time", size=data.shape[1], dim=-2, history=history)
                       if vectorized else pyro.markov(range(data.shape[1]),
                                                      history=history))
        for i in markov_loop:
            x_curr = pyro.sample(
                "x_{}".format(i),
                dist.Categorical(
                    init if isinstance(i, int) and i < 1 else trans[x_prev]))
            with pyro.plate("tones", data.shape[2], dim=-1):
                pyro.sample("y_{}".format(i),
                            dist.Normal(Vindex(locs)[..., x_curr], 1.),
                            obs=Vindex(data)[sequences, i])
            x_prev = x_curr
Example #13
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
Example #14
def model_1(data, history, vectorized):
    x_dim = 3
    init = pyro.param("init",
                      lambda: torch.rand(x_dim),
                      constraint=constraints.simplex)
    trans = pyro.param("trans",
                       lambda: torch.rand((x_dim, x_dim)),
                       constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    x_prev = None
    markov_loop = (pyro.vectorized_markov(
        name="time", size=len(data), dim=-2, history=history) if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(
                init if isinstance(i, int) and i < 1 else trans[x_prev]),
        )
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample(
                "y_{}".format(i),
                dist.Normal(Vindex(locs)[..., x_curr], 1.0),
                obs=data[i],
            )
        x_prev = x_curr
Example #15
def model_6(data, history, vectorized):
    x_dim = 3
    x_init = pyro.param("x_init",
                        lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans",
                         lambda: torch.rand((len(data) - 1, x_dim, x_dim)),
                         constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    x_prev = None
    markov_loop = (pyro.vectorized_markov(
        name="time", size=len(data), dim=-2, history=history)
                   if vectorized else pyro.markov(range(len(data)),
                                                  history=history))
    for i in markov_loop:
        if isinstance(i, int) and i < 1:
            x_probs = x_init
        elif isinstance(i, int):
            x_probs = x_trans[i - 1, x_prev]
        else:
            x_probs = Vindex(x_trans)[(i - 1)[:, None], x_prev]

        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Normal(Vindex(locs)[..., x_curr], 1.),
                        obs=data[i])
        x_prev = x_curr
Example #16
def model_2(data, history, vectorized):
    x_dim, y_dim = 3, 2
    x_init = pyro.param("x_init",
                        lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans",
                         lambda: torch.rand((x_dim, x_dim)),
                         constraint=constraints.simplex)
    y_init = pyro.param("y_init",
                        lambda: torch.rand(x_dim, y_dim),
                        constraint=constraints.simplex)
    y_trans = pyro.param("y_trans",
                         lambda: torch.rand((x_dim, y_dim, y_dim)),
                         constraint=constraints.simplex)

    x_prev = y_prev = None
    markov_loop = (pyro.vectorized_markov(
        name="time", size=len(data), dim=-2, history=history)
                   if vectorized else pyro.markov(range(len(data)),
                                                  history=history))
    for i in markov_loop:
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(
                x_init if isinstance(i, int) and i < 1 else x_trans[x_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            y_curr = pyro.sample(
                "y_{}".format(i),
                dist.Categorical(y_init[x_curr] if isinstance(i, int) and i < 1
                                 else Vindex(y_trans)[x_curr, y_prev]),
                obs=data[i])
        x_prev, y_prev = x_curr, y_curr
Example #17
def testing():
    for i in pyro.markov(range(12)):
        if i % 4 == 0:
            fv2 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: 'a'})
            v2 = pyro.to_data(fv2)
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)
Example #18
def model():
    # `history` is supplied by the enclosing test.
    p = torch.tensor([[0.2, 0.8], [0.1, 0.9]])

    xs = [0]
    for t in pyro.markov(range(10), history=history):
        xs.append(pyro.sample("x_{}".format(t), dist.Categorical(p[xs[-1]]),
                              infer={"enumerate": ("sequential", "parallel")[t % 2]}))
    assert all(x.dim() <= history + 1 for x in xs[1:])
Example #19
def testing():
    for i in pyro.markov(range(12)):
        if i % 4 == 0:
            v2 = pyro.to_data(Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real'))
            fv2 = pyro.to_funsor(v2, funsor.Real)
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)
Example #20
def model():
    # `history` is supplied by the enclosing test.
    p = pyro.param("p", 0.25 * torch.ones(2, 2))
    q = pyro.param("q", 0.25 * torch.ones(2))
    x_prev = torch.tensor(0)
    x_curr = torch.tensor(0)
    for t in pyro.markov(range(10), history=history):
        probs = p[x_prev, x_curr]
        x_prev, x_curr = x_curr, pyro.sample("x_{}".format(t),
                                             dist.Bernoulli(probs)).long()
        pyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]),
                    obs=torch.tensor(0.))
Example #21
def model():
    # `history` is supplied by the enclosing test; the ExitStack enters the
    # same markov context once per time step, unwinding all of them at the end.
    p = torch.tensor([[0.2, 0.8], [0.1, 0.9]])

    xs = [0]
    c = pyro.markov(history=history)
    with contextlib.ExitStack() as stack:
        for t in range(100):
            stack.enter_context(c)
            xs.append(pyro.sample("x_{}".format(t), dist.Categorical(p[xs[-1]])))
        assert all(x.dim() <= history + 1 for x in xs[1:])
Example #22
def hmm(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(
            pyro.sample("states_{}".format(t),
                        dist.Categorical(transition[states[-1]])))
        data[t] = pyro.sample("obs_{}".format(t),
                              dist.Normal(means[states[-1]], 1.0),
                              obs=data[t])
    return states, data
Example #23
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences, )
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim

    if not args.raftery_parameterization:
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = pyro.param("probs_x",
                             torch.rand(hidden_dim, hidden_dim, hidden_dim),
                             constraint=constraints.simplex)
    else:
        # Use the more parsimonious "Raftery" parameterization of
        # the tensor of transition probabilities. See reference:
        # Raftery, A. E. A model for high-order markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param("probs_x1",
                              torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        probs_x2 = pyro.param("probs_x2",
                              torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        mix_lambda = pyro.param("mix_lambda",
                                torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # We use broadcasting to combine two tensors of shapes (hidden_dim, hidden_dim)
        # and (hidden_dim, 1, hidden_dim) into one of shape
        # (hidden_dim, hidden_dim, hidden_dim); see the shape check after this example.
        probs_x = mix_lambda * probs_x1 + (1.0 -
                                           mix_lambda) * probs_x2.unsqueeze(-2)

    probs_y = pyro.param("probs_y",
                         torch.rand(hidden_dim, data_dim),
                         constraint=constraints.unit_interval)
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # we need to pass the argument `history=2` to `pyro.markov()`
        # since our model is now second-order Markov
        for t in pyro.markov(range(lengths.max()), history=2):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x_t),
                    infer={"enumerate": "parallel"})
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y_t),
                                obs=sequences[batch, t])
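The broadcasting in the Raftery branch can be verified directly; a minimal shape check (dims illustrative):

import torch

hidden_dim = 4
probs_x1 = torch.rand(hidden_dim, hidden_dim)
probs_x2 = torch.rand(hidden_dim, hidden_dim)
mix_lambda = torch.tensor(0.5)
# (h, h) + (h, 1, h) broadcast to the full (h, h, h) transition tensor.
probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)
assert probs_x.shape == (hidden_dim, hidden_dim, hidden_dim)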
Example #24
def testing():
    for i in pyro.markov(range(5)):
        v1 = pyro.to_data(Tensor(torch.ones(2), OrderedDict([(str(i), funsor.Bint[2])]), 'real'))
        v2 = pyro.to_data(Tensor(torch.zeros(2), OrderedDict([('a', funsor.Bint[2])]), 'real'))
        fv1 = pyro.to_funsor(v1, funsor.Real)
        fv2 = pyro.to_funsor(v2, funsor.Real)
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)
Example #25
def testing():
    for i in pyro.markov(range(5)):
        fv1 = pyro.to_funsor(torch.zeros(2), funsor.Real, dim_to_name={-1: str(i)})
        fv2 = pyro.to_funsor(torch.ones(2), funsor.Real, dim_to_name={-1: "a"})
        v1 = pyro.to_data(fv1)
        v2 = pyro.to_data(fv2)
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)
Example #26
def model():
    # `markov` and `use_vindex` are supplied by the enclosing test.
    p = pyro.param("p", torch.ones(3, 3))
    q = pyro.param("q", torch.ones(2))
    r = pyro.param("r", torch.ones(3, 2, 4))

    x = 0
    times = pyro.markov(range(100)) if markov else range(11)
    for t in times:
        x = pyro.sample("x_{}".format(t), dist.Categorical(p[x]))
        y = pyro.sample("y_{}".format(t), dist.Categorical(q))
        if use_vindex:
            probs = Vindex(r)[x, y]
        else:
            z_ind = torch.arange(4, dtype=torch.long)
            probs = r[x.unsqueeze(-1), y.unsqueeze(-1), z_ind]
        pyro.sample("z_{}".format(t), dist.Categorical(probs),
                    obs=torch.tensor(0.))
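The two branches above are equivalent; a quick self-contained check of the identity Vindex relies on (shapes illustrative):

import torch
from pyro.ops.indexing import Vindex

r = torch.rand(3, 2, 4)
x = torch.tensor([0, 2])
y = torch.tensor([1, 0])
z_ind = torch.arange(4, dtype=torch.long)
manual = r[x.unsqueeze(-1), y.unsqueeze(-1), z_ind]
assert (Vindex(r)[x, y] == manual).all()
assert Vindex(r)[x, y].shape == (2, 4)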
Example #27
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
            .expand_by([hidden_dim])
            .to_event(2),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape
        # (see the sketch after this example).
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample(
                    "w_{}".format(t),
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(Vindex(probs_x)[w, x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate as tones:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[w, x, tones]),
                        obs=sequences[batch, t],
                    )
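Per the broadcasting comment above, w and x stay tensors so that enumeration can add dims on the left; a sketch of the Vindex lookup with enumerated-style indices (dims illustrative):

import torch
from pyro.ops.indexing import Vindex

probs_x = torch.rand(4, 4, 4)
w = torch.arange(4).reshape(4, 1, 1)   # like an enumerated w (dims on the left)
x = torch.arange(4).reshape(1, 4, 1)   # like an enumerated x
# Both indices broadcast against each other; the last dim of probs_x survives.
assert Vindex(probs_x)[w, x].shape == (4, 4, 1, 4)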
Example #28
def model_10(data, history, vectorized):
    init_probs = torch.tensor([0.5, 0.5])
    transition_probs = pyro.param("transition_probs",
                                  torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                  constraint=constraints.simplex)
    emission_probs = pyro.param("emission_probs",
                                torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                constraint=constraints.simplex)
    x = None
    markov_loop = (pyro.vectorized_markov(
        name="time", size=len(data), history=history)
                   if vectorized else pyro.markov(range(len(data)),
                                                  history=history))
    for i in markov_loop:
        probs = init_probs if x is None else transition_probs[x]
        x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
        pyro.sample("y_{}".format(i),
                    dist.Categorical(emission_probs[x]),
                    obs=data[i])
Example #29
def model_4(data, history, vectorized):
    w_dim, x_dim, y_dim = 2, 3, 2
    w_init = pyro.param("w_init",
                        lambda: torch.rand(w_dim),
                        constraint=constraints.simplex)
    w_trans = pyro.param("w_trans",
                         lambda: torch.rand((w_dim, w_dim)),
                         constraint=constraints.simplex)
    x_init = pyro.param("x_init",
                        lambda: torch.rand(w_dim, x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param(
        "x_trans",
        lambda: torch.rand((w_dim, x_dim, x_dim)),
        constraint=constraints.simplex,
    )
    y_probs = pyro.param(
        "y_probs",
        lambda: torch.rand(w_dim, x_dim, y_dim),
        constraint=constraints.simplex,
    )

    w_prev = x_prev = None
    markov_loop = (pyro.vectorized_markov(
        name="time", size=len(data), dim=-2, history=history) if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        w_curr = pyro.sample(
            "w_{}".format(i),
            dist.Categorical(
                w_init if isinstance(i, int) and i < 1 else w_trans[w_prev]),
        )
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(x_init[w_curr] if isinstance(i, int) and i < 1
                             else x_trans[w_curr, x_prev]),
        )
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample(
                "y_{}".format(i),
                dist.Categorical(Vindex(y_probs)[w_curr, x_curr]),
                obs=data[i],
            )
        x_prev, w_prev = x_curr, w_curr
Example #30
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=sequences[batch, t],
                    )
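A hedged end-to-end training sketch for the HMM-style models above (an assumption: the guide, optimizer settings, and toy data are illustrative; max_plate_nesting=2 covers the sequences plate at dim -2 and the tones plate at dim -1):

import argparse

import torch
import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

args = argparse.Namespace(hidden_dim=16, jit=False)
sequences = torch.zeros(10, 50, 88)                 # batch x time x tones (toy)
lengths = torch.full((10,), 50, dtype=torch.long)

# Point-estimate the global probs_* sites; the discrete x_t are enumerated out.
guide = AutoDelta(poutine.block(
    model_1, expose_fn=lambda msg: msg["name"].startswith("probs_")))
svi = SVI(model_1, guide, Adam({"lr": 0.05}),
          TraceEnum_ELBO(max_plate_nesting=2))
for step in range(10):
    loss = svi.step(sequences, lengths, args, batch_size=8)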