Example #1
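The snippets below are excerpted from Pyro's hmm.py example (Example #2 appears to come from Pyro's test suite instead), so their shared import block is not shown. A sketch of the imports the excerpts rely on; aliasing pyro.poutine as handlers is an assumption made here to match the handlers.mask calls:

import torch

import pyro
import pyro.distributions as dist
from pyro import poutine as handlers  # assumed alias so handlers.mask(...) resolves
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings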
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with handlers.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
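In hmm.py these models are trained with SVI, enumerating out the discrete x sites so the guide only needs the global probs_ parameters. A minimal sketch of that setup for model_0, assuming sequences, lengths, and args are prepared as in hmm.py's data loading; optimizer settings are illustrative:

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# Guide covers only the global parameters; discrete states are enumerated.
guide = AutoDelta(
    poutine.block(model_0, expose_fn=lambda msg: msg["name"].startswith("probs_"))
)
# max_plate_nesting=1: only the tones plate is vectorized in model_0.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model_0, guide, Adam({"lr": 0.05}), elbo)
loss = svi.step(sequences, lengths, args, batch_size=8)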
Example #2
def model():
    x_plate = pyro.plate("x_plate",
                         5,
                         subsample_size=2 if subsampling else None,
                         dim=-1)
    y_plate = pyro.plate("y_plate",
                         6,
                         subsample_size=3 if subsampling else None,
                         dim=-2)
    with pyro.plate("num_particles", 50, dim=-3):
        with x_plate:
            b = pyro.sample(
                "b", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
        with y_plate:
            c = pyro.sample("c", dist.Bernoulli(0.5))
        with x_plate, y_plate:
            d = pyro.sample("d", dist.Bernoulli(b))

    # check shapes
    if enumerate_ == "parallel":
        assert b.shape == (50, 1, x_plate.subsample_size)
        assert c.shape == (2, 1, 1, 1)
        assert d.shape == (2, 1, 1, 1, 1)
    elif enumerate_ == "sequential":
        assert b.shape == (50, 1, x_plate.subsample_size)
        assert c.shape in ((), (1, 1, 1))  # both are valid
        assert d.shape in ((), (1, 1, 1))  # both are valid
    else:
        assert b.shape == (50, 1, x_plate.subsample_size)
        assert c.shape == (50, y_plate.subsample_size, 1)
        assert d.shape == (50, y_plate.subsample_size,
                           x_plate.subsample_size)
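The free variables subsampling and enumerate_ are parameters supplied by the enclosing test. A hypothetical harness that would exercise the "parallel" branch (placeholder values; the real test wires this up differently): enumerated sites get fresh dimensions to the left of max_plate_nesting=3, which is why c and d acquire the (2, 1, 1, ...) shapes asserted above.

from pyro.infer import TraceEnum_ELBO, config_enumerate

subsampling = False      # placeholder for the test parameter
enumerate_ = "parallel"  # placeholder for the test parameter

guide = config_enumerate(model, default=enumerate_)
elbo = TraceEnum_ELBO(max_plate_nesting=3)
elbo.loss(model, guide)  # the shape assertions run inside model()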
Example #3
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
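model_3 is a factorial HMM: the two hidden chains w and x evolve independently and are coupled only through the emission, so args.hidden_dim is split as int(args.hidden_dim**0.5) states per chain (e.g. 16 becomes 4 and 4). Training mirrors the model_0 sketch above, except that both the sequences plate (dim=-2) and the tones plate (dim=-1) are now vectorized; a hedged sketch:

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

guide = AutoDelta(
    poutine.block(model_3, expose_fn=lambda msg: msg["name"].startswith("probs_"))
)
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # two nested vectorized plates
svi = SVI(model_3, guide, Adam({"lr": 0.05}), elbo)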
Example #4
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
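The trailing .long() matters: observed tone values come back as float tensors (the support of Bernoulli), but y is reused on the next step as an index into probs_y, and PyTorch requires integer index tensors. A toy illustration of that indexing pattern, with shapes and values assumed for the example:

import torch

probs_y = torch.rand(16, 2, 88)  # (hidden_dim, previous tone on/off, data_dim)
x = torch.zeros(88, dtype=torch.long)               # current hidden state per tone
y = torch.bernoulli(torch.full((88,), 0.1)).long()  # previous tones, cast for indexing
tones = torch.arange(88)                            # supplied by the tones plate
assert probs_y[x, y, tones].shape == (88,)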
Example #5
File: hmm.py Project: pyro-ppl/pyro
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
            .expand_by([hidden_dim])
            .to_event(2),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: w and x are initialized as long
        # tensors rather than Python ints so that Vindex(probs_x)[w, x] below
        # can index on both of them at once, and the tones plate supplies the
        # trailing index for probs_y, keeping each sample site's
        # distribution shape correct under enumeration.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample(
                    "w_{}".format(t),
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(Vindex(probs_x)[w, x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate as tones:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[w, x, tones]),
                        obs=sequences[batch, t],
                    )
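The difference from model_3 is the joint transition Vindex(probs_x)[w, x]: here x_t depends on both w_t and x_{t-1}, which is why probs_x gains an extra batch dimension via expand_by([hidden_dim]). Vindex applies NumPy-style vectorized indexing so the batch (and enumeration) dimensions of w and x broadcast correctly; a toy illustration with assumed sizes:

import torch
from pyro.ops.indexing import Vindex

probs_x = torch.rand(4, 4, 4)  # (w, x_prev, x_curr) transition table
w = torch.tensor([0, 1, 2])
x = torch.tensor([3, 2, 1])
rows = Vindex(probs_x)[w, x]   # selects probs_x[w[i], x[i], :] for each i
assert rows.shape == (3, 4)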
Example #6
File: hmm.py Project: pyro-ppl/pyro
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations of the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=sequences[batch, t],
                    )
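When args.jit is set, hmm.py pairs this fixed program structure with the jit-compiled ELBO, so the trace is compiled once and reused across steps. A sketch of that switch, with the guide constructed as in the model_0 sketch above:

from pyro.infer import JitTraceEnum_ELBO, TraceEnum_ELBO

Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
elbo = Elbo(max_plate_nesting=2)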
Example #7
File: hmm.py Project: pyro-ppl/pyro
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Note that since we're using dim=-2 for the time dimension, we need
    # to batch sequences over a different dimension, here dim=-3.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-3) as batch:
        lengths = lengths[batch]
        batch = batch[:, None]
        x_prev = 0
        # To vectorize time dimension we use pyro.vectorized_markov(name=...).
        # With the help of Vindex and additional unsqueezes we can ensure that
        # dimensions line up properly.
        for t in pyro.vectorized_markov(
            name="time", size=int(max_length if args.jit else lengths.max()), dim=-2
        ):
            with handlers.mask(mask=(t < lengths.unsqueeze(-1)).unsqueeze(-1)):
                x_curr = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x_curr.squeeze(-1)]),
                        obs=Vindex(sequences)[batch, t],
                    )