Example #1
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
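A minimal training sketch for an enumerated HMM like model_3 above, following the common Pyro recipe; the AutoDelta guide exposing only the global probs_* sites is an assumption, and sequences, lengths, and args are placeholders:

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# MAP-estimate the global probs_* sites; the discrete w/x states are enumerated out.
guide = AutoDelta(poutine.block(
    model_3, expose_fn=lambda msg: msg["name"].startswith("probs_")))
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # plates: sequences (dim=-2), tones (dim=-1)
svi = SVI(model_3, guide, Adam({"lr": 0.05}), elbo)
for step in range(100):
    loss = svi.step(sequences, lengths, args, batch_size=20)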
Example #2
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
Example #3
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note that since each tone depends on all tones at a previous time step
                # the tones at different time steps now need to live in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(logits=tones_generator(x, y)),
                        obs=sequences[batch, t])
Example #4
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample("probs_w",
                              dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                                  .expand_by([hidden_dim])
                                  .to_event(2))
        probs_y = pyro.sample("probs_y",
                              dist.Beta(0.1, 0.9)
                                  .expand([hidden_dim, hidden_dim, data_dim])
                                  .to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
Example #5
def test_arg_kwarg_error():
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a",
                    Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    with pytest.raises(ValueError, match="not callable"):
        with poutine.mask(False):
            model()

    with poutine.mask(mask=False):
        model()
Example #6
def test_get_mask():
    assert get_mask() is None

    with poutine.mask(mask=True):
        assert get_mask() is True
    with poutine.mask(mask=False):
        assert get_mask() is False

    with pyro.plate("i", 2, dim=-1):
        mask1 = torch.tensor([False, True, True])
        mask2 = torch.tensor([True, True, False])
        with poutine.mask(mask=mask1):
            assert_equal(get_mask(), mask1)
            with poutine.mask(mask=mask2):
                assert_equal(get_mask(), mask1 & mask2)
Example #7
def model(transition_alphas, emission_alphas, lengths,
          sequences=None, batch_size=None):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if sequences is not None:
            num_sequences, max_length, data_dim = map(int, sequences.shape)
            assert lengths.shape == (num_sequences,)
            assert lengths.max() <= max_length
        else:
            data_dim = emission_alphas.size(1)
            num_sequences = int(lengths.shape[0])
            max_length = int(lengths.max())
    transition_probs = pyro.sample('transition_probs',
                                   dist.Dirichlet(transition_alphas).to_event(1))
    emission_probs = pyro.sample('emission_probs',
                                 dist.Dirichlet(emission_alphas).to_event(2))
    element_plate = pyro.plate('elements', data_dim, dim=-1)
    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        state = 0
        for t in pyro.markov(range(max_length)):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}', dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                obs_element = Vindex(sequences)[batch, t] if sequences is not None else None
                with element_plate:
                    element = pyro.sample(f'element_{t}',
                                          dist.Categorical(emission_probs[state.squeeze(-1)]),
                                          obs=obs_element)
Example #8
def model(num_particles=1, z=None):
    # `data`, `mask`, and `logger` are assumed to come from the enclosing test scope.
    p = pyro.param("p", torch.tensor(0.25))
    with pyro.plate("num_particles", num_particles, dim=-2):
        z = pyro.sample("z", dist.Bernoulli(p), obs=z)
        logger.info("z.shape = {}".format(z.shape))
        with pyro.plate("data", 3), poutine.mask(mask=mask):
            pyro.sample("x", dist.Normal(z, 1.), obs=data)
Example #9
def model_1(capture_history, sex):
    N, T = capture_history.shape
    phi = pyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    with pyro.plate("animals", N, dim=-1):
        z = torch.ones(N)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = torch.zeros(N).bool()
        for t in pyro.markov(range(T)):
            with poutine.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask.float() * phi * z + (
                    1 - first_capture_mask.float())
                # we use parallel enumeration to exactly sum out
                # the discrete states z_t.
                z = pyro.sample(
                    "z_{}".format(t),
                    dist.Bernoulli(mu_z_t),
                    infer={"enumerate": "parallel"},
                )
                mu_y_t = rho * z
                pyro.sample("y_{}".format(t),
                            dist.Bernoulli(mu_y_t),
                            obs=capture_history[:, t])
            first_capture_mask |= capture_history[:, t].bool()
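A sketch of how this capture-recapture model might be fit, assuming the usual pattern of enumerating out the discrete z_t states while learning the continuous globals with a mean-field guide; capture_history and sex are placeholders:

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# The guide covers only the continuous sites; z_t is summed out by TraceEnum_ELBO.
guide = AutoDiagonalNormal(poutine.block(model_1, expose=["phi", "rho"]))
svi = SVI(model_1, guide, Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=1))
for step in range(1000):
    loss = svi.step(capture_history, sex)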
Example #10
def model_2(capture_history, sex):
    N, T = capture_history.shape
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    z = torch.ones(N)
    first_capture_mask = torch.zeros(N).bool()
    # we create the plate once, outside of the loop over t
    animals_plate = pyro.plate("animals", N, dim=-1)
    for t in pyro.markov(range(T)):
        # note that phi_t needs to be outside the plate, since
        # phi_t is shared across all N individuals
        phi_t = pyro.sample("phi_{}".format(t), dist.Uniform(0.0, 1.0)) if t > 0 \
                else 1.0
        with animals_plate, poutine.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask.float() * phi_t * z + (
                1 - first_capture_mask.float())
            # we use parallel enumeration to exactly sum out
            # the discrete states z_t.
            z = pyro.sample("z_{}".format(t),
                            dist.Bernoulli(mu_z_t),
                            infer={"enumerate": "parallel"})
            mu_y_t = rho * z
            pyro.sample("y_{}".format(t),
                        dist.Bernoulli(mu_y_t),
                        obs=capture_history[:, t])
        first_capture_mask |= capture_history[:, t].bool()
Example #11
def model_3(capture_history, sex):
    def logit(p):
        return torch.log(p) - torch.log1p(-p)

    N, T = capture_history.shape
    phi_mean = pyro.sample("phi_mean",
                           dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = pyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    z = torch.ones(N)
    first_capture_mask = torch.zeros(N).bool()
    # we create the plate once, outside of the loop over t
    animals_plate = pyro.plate("animals", N, dim=-1)
    for t in pyro.markov(range(T)):
        phi_logit_t = pyro.sample("phi_logit_{}".format(t),
                                  dist.Normal(phi_logit_mean, phi_sigma)) if t > 0 \
                      else torch.tensor(0.0)
        phi_t = torch.sigmoid(phi_logit_t)
        with animals_plate, poutine.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask.float() * phi_t * z + (
                1 - first_capture_mask.float())
            # we use parallel enumeration to exactly sum out
            # the discrete states z_t.
            z = pyro.sample("z_{}".format(t),
                            dist.Bernoulli(mu_z_t),
                            infer={"enumerate": "parallel"})
            mu_y_t = rho * z
            pyro.sample("y_{}".format(t),
                        dist.Bernoulli(mu_y_t),
                        obs=capture_history[:, t])
        first_capture_mask |= capture_history[:, t].bool()
Example #12
def model_4(capture_history, sex):
    N, T = capture_history.shape
    # survival probabilities for males/females
    phi_male = pyro.sample("phi_male", dist.Uniform(0.0, 1.0))
    phi_female = pyro.sample("phi_female", dist.Uniform(0.0, 1.0))
    # we construct a N-dimensional vector that contains the appropriate
    # phi for each individual given its sex (female = 0, male = 1)
    phi = sex * phi_male + (1.0 - sex) * phi_female
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    with pyro.plate("animals", N, dim=-1):
        z = torch.ones(N)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = torch.zeros(N).bool()
        for t in pyro.markov(range(T)):
            with poutine.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask.float() * phi * z + (
                    1 - first_capture_mask.float())
                # we use parallel enumeration to exactly sum out
                # the discrete states z_t.
                z = pyro.sample("z_{}".format(t),
                                dist.Bernoulli(mu_z_t),
                                infer={"enumerate": "parallel"})
                mu_y_t = rho * z
                pyro.sample("y_{}".format(t),
                            dist.Bernoulli(mu_y_t),
                            obs=capture_history[:, t])
            first_capture_mask |= capture_history[:, t].bool()
Example #13
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        lengths = lengths[batch]
        y = sequences[batch] if args.jit else sequences[batch, :lengths.max()]
        x = torch.arange(args.hidden_dim)
        t = torch.arange(y.size(1))
        init_logits = torch.full((args.hidden_dim, ), -float("inf"))
        init_logits[0] = 0
        trans_logits = probs_x.log()
        with ignore_jit_warnings():
            obs_dist = dist.Bernoulli(
                logits=tones_generator(x, y.unsqueeze(-2))).to_event(1)
            obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1))
            hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
        pyro.sample("y", hmm_dist, obs=y)
Example #14
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
Example #15
def model(detections, args):
    noise_scale = pyro.param('noise_scale')
    objects = pyro.param('objects_loc').squeeze(-1)
    num_detections, = detections.shape
    max_num_objects, = objects.shape

    # Existence part.
    p_exists = args.expected_num_objects / max_num_objects
    with pyro.plate('objects_plate', max_num_objects):
        exists = pyro.sample('exists', dist.Bernoulli(p_exists))
        with poutine.mask(mask=exists.bool()):
            pyro.sample('objects', dist.Normal(0., 1.), obs=objects)

    # Assignment part.
    p_fake = args.num_fake_detections / num_detections
    with pyro.plate('detections_plate', num_detections):
        assign_probs = torch.empty(max_num_objects + 1)
        assign_probs[:-1] = (1 - p_fake) / max_num_objects
        assign_probs[-1] = p_fake
        assign = pyro.sample('assign', dist.Categorical(logits=assign_probs))
        is_fake = (assign == assign.shape[-1] - 1)
        objects_plus_bogus = torch.zeros(max_num_objects + 1)
        objects_plus_bogus[:max_num_objects] = objects
        real_dist = dist.Normal(objects_plus_bogus[assign], noise_scale)
        fake_dist = dist.Normal(0., 1.)
        pyro.sample('detections',
                    dist.MaskedMixture(is_fake, real_dist, fake_dist),
                    obs=detections)
Example #16
def model_5(capture_history, sex):
    N, T = capture_history.shape

    # phi_beta controls the survival probability differential
    # for males versus females (in logit space)
    phi_beta = pyro.sample("phi_beta", dist.Normal(0.0, 10.0))
    phi_beta = sex * phi_beta
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    z = torch.ones(N)
    first_capture_mask = torch.zeros(N).bool()
    # we create the plate once, outside of the loop over t
    animals_plate = pyro.plate("animals", N, dim=-1)
    for t in pyro.markov(range(T)):
        phi_gamma_t = pyro.sample("phi_gamma_{}".format(t), dist.Normal(0.0, 10.0)) if t > 0 \
                      else 0.0
        phi_t = torch.sigmoid(phi_beta + phi_gamma_t)
        with animals_plate, poutine.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask.float() * phi_t * z + (
                1 - first_capture_mask.float())
            # we use parallel enumeration to exactly sum out
            # the discrete states z_t.
            z = pyro.sample("z_{}".format(t),
                            dist.Bernoulli(mu_z_t),
                            infer={"enumerate": "parallel"})
            mu_y_t = rho * z
            pyro.sample("y_{}".format(t),
                        dist.Bernoulli(mu_y_t),
                        obs=capture_history[:, t])
        first_capture_mask |= capture_history[:, t].bool()
Example #17
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=sequences[batch, t],
                    )
Example #18
    def median(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        with torch.no_grad(), poutine.mask(mask=False):
            aux_values = self._sample_aux_values(temperature=0.0)
            values, _ = self._transform_values(aux_values)
        return values
Example #19
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences, )
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim
    hidden = torch.arange(hidden_dim, dtype=torch.long)

    if not args.raftery_parameterization:
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = pyro.param("probs_x",
                             torch.rand(hidden_dim, hidden_dim, hidden_dim),
                             constraint=constraints.simplex)
    else:
        # Use the more parsimonious "Raftery" parameterization of
        # the tensor of transition probabilities. See reference:
        # Raftery, A. E. A model for high-order markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param("probs_x1",
                              torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        probs_x2 = pyro.param("probs_x2",
                              torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        mix_lambda = pyro.param("mix_lambda",
                                torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # we use broadcasting to combine two tensors of shape (hidden_dim, hidden_dim) and
        # (hidden_dim, 1, hidden_dim) to obtain a tensor of shape (hidden_dim, hidden_dim, hidden_dim)
        probs_x = mix_lambda * probs_x1 + (1.0 -
                                           mix_lambda) * probs_x2.unsqueeze(-2)

    probs_y = pyro.param("probs_y",
                         torch.rand(hidden_dim, data_dim),
                         constraint=constraints.unit_interval)
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # we need to pass the argument `history=2` to `pyro.markov()`
        # since our model is now 2-Markov
        for t in pyro.markov(range(lengths.max()), history=2):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                probs_x_t = probs_x[x_prev.unsqueeze(-1),
                                    x_curr.unsqueeze(-1), hidden]
                x_prev, x_curr = x_curr, pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x_t),
                    infer={"enumerate": "parallel"})
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y_t),
                                obs=sequences[batch, t])
Example #20
def _masked_observe(name, fn, obs, obs_mask, *args, **kwargs):
    # Split into two auxiliary sample sites.
    with poutine.mask(mask=obs_mask):
        observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs)
    with poutine.mask(mask=~obs_mask):
        unobserved = sample(f"{name}_unobserved", fn, *args, **kwargs)

    # Interleave observed and unobserved events.
    shape = obs_mask.shape + (1, ) * fn.event_dim
    batch_mask = obs_mask.reshape(shape)
    try:
        value = torch.where(batch_mask, observed, unobserved)
    except RuntimeError as e:
        if "must match the size of tensor" in str(e):
            shape = torch.broadcast_shapes(observed.shape, unobserved.shape)
            batch_shape = shape[:len(shape) - fn.event_dim]
            raise ValueError(
                f"Invalid obs_mask shape {tuple(obs_mask.shape)}; should be "
                f"broadcastable to batch_shape = {tuple(batch_shape)}") from e
        raise
    return deterministic(name, value)
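This helper backs the obs_mask argument of pyro.sample, which splits a site into name_observed and name_unobserved exactly as above. A minimal sketch with made-up data:

import torch
import pyro
import pyro.distributions as dist

data = torch.tensor([0.3, float("nan"), 1.2])
mask = ~torch.isnan(data)  # True where a value was actually observed

def model():
    with pyro.plate("data", 3):
        # Observed entries are conditioned on; missing entries are imputed as
        # latent draws at the auxiliary "x_unobserved" site.
        return pyro.sample("x", dist.Normal(0., 1.),
                           obs=data.nan_to_num(), obs_mask=mask)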
Example #21
def guide_generic(config):
    """generic mean-field guide for continuous random effects"""
    N_state = config["sizes"]["state"]

    if config["group"]["random"] == "continuous":
        loc_g = pyro.param("loc_group", lambda: torch.zeros((N_state**2, )))
        scale_g = pyro.param(
            "scale_group",
            lambda: torch.ones((N_state**2, )),
            constraint=constraints.positive,
        )

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "continuous":
        loc_i = pyro.param(
            "loc_individual",
            lambda: torch.zeros((
                N_c,
                N_state**2,
            )),
        )
        scale_i = pyro.param(
            "scale_individual",
            lambda: torch.ones((
                N_c,
                N_state**2,
            )),
            constraint=constraints.positive,
        )

    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):

        if config["group"]["random"] == "continuous":
            pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})

        N_s = config["sizes"]["individual"]
        with pyro.plate(
                "individual", N_s,
                dim=-2), poutine.mask(mask=config["individual"]["mask"]):

            # individual-level random effects
            if config["individual"]["random"] == "continuous":
                pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
Example #22
    def transform_samples(self, aux_samples, save_params=None):
        """
        Given latent samples from the warped posterior (with a possible batch dimension),
        return a `dict` of samples from the latent sites in the model.

        :param dict aux_samples: Dict site name to tensor value for each latent
            auxiliary site (or if ``save_params`` is specified, then for only
            those latent auxiliary sites needed to compute requested params).
        :param list save_params: An optional list of site names to save. This
            is useful in models with large nuisance variables. Defaults to
            None, saving all params.
        :return: a `dict` of samples keyed by latent sites in the model.
        :rtype: dict
        """
        with poutine.condition(data=aux_samples), poutine.mask(mask=False):
            deltas = self.guide.get_deltas(save_params)
        return {name: delta.v for name, delta in deltas.items()}
Example #23
    def _forward_pyro_mean_field(self, features, trip_counts):
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours <= total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        time_plate = pyro.plate("time", observed_hours, dim=-3)
        origins_plate = pyro.plate("origins", num_origins, dim=-2)
        destins_plate = pyro.plate("destins", num_destins, dim=-1)
        init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
            self._dynamics(features[:observed_hours])

        # This is a parallelizable crf representation of the HMM.
        # We first pull random variables from the guide, masking all factors.
        with poutine.mask(mask=False):
            shape = (1 + observed_hours, self.args.state_dim)  # includes init
            state = pyro.sample("state",
                                dist.Normal(0, 1).expand(shape).to_event(2))

            shape = (observed_hours, 2 * num_origins * num_destins)
            gate_rate = pyro.sample(
                "gate_rate",
                dist.Normal(0, 1).expand(shape).to_event(2))

        # We then declare CRF factors.
        pyro.sample("init", init_dist, obs=state[0])
        pyro.sample("trans",
                    trans_dist.expand((observed_hours, )).to_event(1),
                    obs=state[..., 1:, :] - state[..., :-1, :] @ trans_matrix)
        pyro.sample("obs",
                    obs_dist.expand((observed_hours, )).to_event(1),
                    obs=gate_rate - state[..., 1:, :] @ obs_matrix)
        gate, rate = self._unpack_gate_rate(gate_rate, event_dim=2)
        with time_plate, origins_plate, destins_plate:
            pyro.sample("trip_count",
                        dist.ZeroInflatedPoisson(gate, rate),
                        obs=trip_counts)

        # The second half of the model forecasts forward.
        if total_hours > observed_hours:
            return self._forward_pyro_forecast(features,
                                               trip_counts,
                                               origins_plate,
                                               destins_plate,
                                               state=state[..., -1, :])
Example #24
def guide(sequences):
    theta = pyro.param("theta", torch.ones(16))
    alpha = pyro.param("alpha", torch.rand(1))
    beta = pyro.param("beta", torch.rand(1))
    p = pyro.param("p", torch.rand(1))
    q = pyro.param("q", torch.rand(1))
    w = p * torch.eye(16) + q
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x", Dirichlet(w).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(alpha, beta).expand([16, 51]).to_event(2))

    for i in pyro.plate("sequences", len(sequences), 8):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), Categorical(probs_x[x]))
Example #25
    def model(self,
              lengths=None,
              sequences=None,
              expected_string_length: int = 5):
        with ignore_jit_warnings():
            assert sequences is None or lengths is not None
            assert lengths is None or lengths.max() <= self.smct.max_chain_length
            assert sequences is None or (
                0 <= sequences.min()
                and sequences.max() < self.smct.alphabet_size)
        binom_prob = pyro.sample(
            'binom_prob',
            dist.Beta(
                min(1, expected_string_length),
                min(1, self.smct.max_chain_length - expected_string_length)))
        lengths_size = 1 if sequences is None else sequences.size(0)
        with pyro.plate('lengths_plate', size=lengths_size, dim=-1):
            lengths = pyro.sample(
                'lengths',
                dist.Binomial(self.smct.max_chain_length, binom_prob),
                obs=(lengths.float() if lengths is not None else lengths))
        if lengths.dim() == 0:
            lengths = lengths.unsqueeze(-1)
        sequence_size = 1 if sequences is None else sequences.size(0)
        with pyro.plate('sequences_plate', size=sequence_size,
                        dim=-2) as batch:
            lengths = lengths[batch]
            prev = ()
            for t in pyro.markov(range(self.smct.max_chain_length),
                                 history=self.smct.order):
                if len(prev) > self.smct.order:
                    prev = prev[1:]
                probs_t = pyro.sample(
                    f'probs_{t}',
                    dist.Dirichlet(
                        self.smct.get_pseudocounts(prev).unsqueeze(-2)))
                x_t = None if sequences is None else sequences[batch, t]
                with poutine.mask(
                        mask=(t < lengths).unsqueeze(-1).unsqueeze(-1)):
                    x_t = pyro.sample(f'x_{t}',
                                      dist.Categorical(probs=probs_t),
                                      obs=x_t)
                prev = (*prev, x_t)
Example #26
def model(sequences):
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x",
                              Dirichlet(0.9 * torch.eye(16) + 0.1).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(0.1, 0.9).expand([16, 51]).to_event(2))
    tones_plate = pyro.plate("tones", 51, dim=-1)
    for i in pyro.plate("sequences", len(sequences)):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t),
                            Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
Example #27
def test_get_mask_optimization():
    def model():
        x = pyro.sample("x", dist.Normal(0, 1))
        pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.0))
        called.add("model-always")
        if poutine.get_mask() is not False:
            called.add("model-sometimes")
            pyro.factor("f", x + 1)

    def guide():
        x = pyro.sample("x", dist.Normal(0, 1))
        called.add("guide-always")
        if poutine.get_mask() is not False:
            called.add("guide-sometimes")
            pyro.factor("g", 2 - x)

    called = set()
    trace = poutine.trace(guide).get_trace()
    poutine.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" in called
    assert "guide-sometimes" in called

    called = set()
    with poutine.mask(mask=False):
        trace = poutine.trace(guide).get_trace()
        poutine.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called

    called = set()
    Predictive(model, guide=guide, num_samples=2, parallel=True)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called
Example #28
def model_generic(config):
    """Hierarchical mixed-effects hidden markov model"""

    MISSING = config["MISSING"]
    N_v = config["sizes"]["random"]
    N_state = config["sizes"]["state"]

    # initialize group-level random effect parameters
    if config["group"]["random"] == "discrete":
        probs_e_g = pyro.param("probs_e_group",
                               lambda: torch.randn((N_v, )).abs(),
                               constraint=constraints.simplex)
        theta_g = pyro.param("theta_group", lambda: torch.randn(
            (N_v, N_state**2)))
    elif config["group"]["random"] == "continuous":
        loc_g = torch.zeros((N_state**2, ))
        scale_g = torch.ones((N_state**2, ))

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "discrete":
        probs_e_i = pyro.param("probs_e_individual",
                               lambda: torch.randn((
                                   N_c,
                                   N_v,
                               )).abs(),
                               constraint=constraints.simplex)
        theta_i = pyro.param("theta_individual", lambda: torch.randn(
            (N_c, N_v, N_state**2)))
    elif config["individual"]["random"] == "continuous":
        loc_i = torch.zeros((
            N_c,
            N_state**2,
        ))
        scale_i = torch.ones((
            N_c,
            N_state**2,
        ))

    # initialize likelihood parameters
    # observation 1: step size (step ~ Gamma)
    step_zi_param = pyro.param("step_zi_param", lambda: torch.ones(
        (N_state, 2)))
    step_concentration = pyro.param("step_param_concentration",
                                    lambda: torch.randn((N_state, )).abs(),
                                    constraint=constraints.positive)
    step_rate = pyro.param("step_param_rate",
                           lambda: torch.randn((N_state, )).abs(),
                           constraint=constraints.positive)

    # observation 2: step angle (angle ~ VonMises)
    angle_concentration = pyro.param("angle_param_concentration",
                                     lambda: torch.randn((N_state, )).abs(),
                                     constraint=constraints.positive)
    angle_loc = pyro.param("angle_param_loc", lambda: torch.randn(
        (N_state, )).abs())

    # observation 3: dive activity (omega ~ Beta)
    omega_zi_param = pyro.param("omega_zi_param", lambda: torch.ones(
        (N_state, 2)))
    omega_concentration0 = pyro.param("omega_param_concentration0",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)
    omega_concentration1 = pyro.param("omega_param_concentration1",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)

    # initialize gamma to uniform
    gamma = torch.zeros((N_state**2, ))

    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):

        # group-level random effects
        if config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})
        else:
            eps_g = 0.

        # add group-level random effect to gamma
        gamma = gamma + eps_g

        N_s = config["sizes"]["individual"]
        with pyro.plate(
                "individual", N_s,
                dim=-2), poutine.mask(mask=config["individual"]["mask"]):

            # individual-level random effects
            if config["individual"]["random"] == "discrete":
                # individual-level discrete effect
                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
                eps_i = Vindex(theta_i)[..., e_i, :]
                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
            elif config["individual"]["random"] == "continuous":
                eps_i = pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
            else:
                eps_i = 0.

            # add individual-level random effect to gamma
            gamma = gamma + eps_i

            y = torch.tensor(0).long()

            N_t = config["sizes"]["timesteps"]
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    gamma_t = gamma  # per-timestep variable

                    # finally, reshape gamma as batch of transition matrices
                    gamma_t = gamma_t.reshape(
                        tuple(gamma_t.shape[:-1]) + (N_state, N_state))

                    # we've accounted for all effects, now actually compute gamma_y
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t),
                                    dist.Categorical(logits=gamma_y))

                    # observation 1: step size
                    step_dist = dist.Gamma(
                        concentration=Vindex(step_concentration)[..., y],
                        rate=Vindex(step_rate)[..., y])

                    # zero-inflation with MaskedMixture
                    step_zi = Vindex(step_zi_param)[..., y, :]
                    step_zi_mask = pyro.sample(
                        "step_zi_{}".format(t),
                        dist.Categorical(logits=step_zi),
                        obs=(config["observations"]["step"][...,
                                                            t] == MISSING))
                    step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist,
                                                      step_zi_zero_dist)

                    pyro.sample("step_{}".format(t),
                                step_zi_dist,
                                obs=config["observations"]["step"][..., t])

                    # observation 2: step angle
                    angle_dist = dist.VonMises(
                        concentration=Vindex(angle_concentration)[..., y],
                        loc=Vindex(angle_loc)[..., y])
                    pyro.sample("angle_{}".format(t),
                                angle_dist,
                                obs=config["observations"]["angle"][..., t])

                    # observation 3: dive activity
                    omega_dist = dist.Beta(
                        concentration0=Vindex(omega_concentration0)[..., y],
                        concentration1=Vindex(omega_concentration1)[..., y])

                    # zero-inflation with MaskedMixture
                    omega_zi = Vindex(omega_zi_param)[..., y, :]
                    omega_zi_mask = pyro.sample(
                        "omega_zi_{}".format(t),
                        dist.Categorical(logits=omega_zi),
                        obs=(config["observations"]["omega"][...,
                                                             t] == MISSING))

                    omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    omega_zi_dist = dist.MaskedMixture(omega_zi_mask,
                                                       omega_dist,
                                                       omega_zi_zero_dist)

                    pyro.sample("omega_{}".format(t),
                                omega_zi_dist,
                                obs=config["observations"]["omega"][..., t])
Example #29
def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(
                                          length_shape, length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state],
                                              phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state],
                                              psi_kappas[state]),
                                obs=obs_psi)
Example #30
def _predictive(model,
                posterior_samples,
                num_samples,
                return_sites=(),
                return_trace=False,
                parallel=False,
                model_args=(),
                model_kwargs={}):
    model = torch.no_grad()(poutine.mask(model, mask=False))
    max_plate_nesting = _guess_max_plate_nesting(model, model_args,
                                                 model_kwargs)
    vectorize = pyro.plate("_num_predictive_samples",
                           num_samples,
                           dim=-max_plate_nesting - 1)
    model_trace = prune_subsample_sites(
        poutine.trace(model).get_trace(*model_args, **model_kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples, ) + (1, ) *
                                (max_plate_nesting - len(sample_shape)) +
                                sample_shape)
        reshaped_samples[name] = sample

    if return_trace:
        trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(
            model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (num_samples, ) + (
            1, ) * append_ndim + model_trace.nodes[site]['value'].shape
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = ()
        # include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        shape = (num_samples, ) + value.shape if torch.is_tensor(
            value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model,
                                      posterior_samples,
                                      model_args,
                                      model_kwargs,
                                      num_samples,
                                      return_site_shapes,
                                      return_trace=False)

    trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions
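In normal use this internal helper is reached through pyro.infer.Predictive rather than called directly; a minimal usage sketch, where model, guide, and data are placeholders:

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=100,
                        return_sites=("y", "_RETURN"))
samples = predictive(data)  # dict mapping site name to posterior predictive draws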