def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t), dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note that since each tone depends on all tones at a previous time step
                # the tones at different time steps now need to live in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(logits=tones_generator(x, y)),
                                    obs=sequences[batch, t])
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
                .expand_by([hidden_dim])
                .to_event(2))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9)
                .expand([hidden_dim, hidden_dim, data_dim])
                .to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
def test_arg_kwarg_error():

    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    with pytest.raises(ValueError, match="not callable"):
        with poutine.mask(False):
            model()

    with poutine.mask(mask=False):
        model()
def test_get_mask():
    assert get_mask() is None

    with poutine.mask(mask=True):
        assert get_mask() is True
    with poutine.mask(mask=False):
        assert get_mask() is False

    with pyro.plate("i", 2, dim=-1):
        mask1 = torch.tensor([False, True, True])
        mask2 = torch.tensor([True, True, False])
        with poutine.mask(mask=mask1):
            assert_equal(get_mask(), mask1)
            with poutine.mask(mask=mask2):
                assert_equal(get_mask(), mask1 & mask2)
def model(transition_alphas, emission_alphas, lengths,
          sequences=None, batch_size=None):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if sequences is not None:
            num_sequences, max_length, data_dim = map(int, sequences.shape)
            assert lengths.shape == (num_sequences,)
            assert lengths.max() <= max_length
        else:
            data_dim = emission_alphas.size(1)
            num_sequences = int(lengths.shape[0])
            max_length = int(lengths.max())
    transition_probs = pyro.sample('transition_probs',
                                   dist.Dirichlet(transition_alphas).to_event(1))
    emission_probs = pyro.sample('emission_probs',
                                 dist.Dirichlet(emission_alphas).to_event(2))
    element_plate = pyro.plate('elements', data_dim, dim=-1)
    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        state = 0
        for t in pyro.markov(range(max_length)):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                obs_element = Vindex(sequences)[batch, t] if sequences is not None else None
                with element_plate:
                    element = pyro.sample(f'element_{t}',
                                          dist.Categorical(emission_probs[state.squeeze(-1)]),
                                          obs=obs_element)
def model(num_particles=1, z=None):
    p = pyro.param("p", torch.tensor(0.25))
    with pyro.plate("num_particles", num_particles, dim=-2):
        z = pyro.sample("z", dist.Bernoulli(p), obs=z)
        logger.info("z.shape = {}".format(z.shape))
        with pyro.plate("data", 3), poutine.mask(mask=mask):
            pyro.sample("x", dist.Normal(z, 1.), obs=data)
def model_1(capture_history, sex):
    N, T = capture_history.shape
    phi = pyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    with pyro.plate("animals", N, dim=-1):
        z = torch.ones(N)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = torch.zeros(N).bool()
        for t in pyro.markov(range(T)):
            with poutine.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask.float() * phi * z + (
                    1 - first_capture_mask.float())
                # we use parallel enumeration to exactly sum out
                # the discrete states z_t.
                z = pyro.sample(
                    "z_{}".format(t),
                    dist.Bernoulli(mu_z_t),
                    infer={"enumerate": "parallel"},
                )
                mu_y_t = rho * z
                pyro.sample("y_{}".format(t),
                            dist.Bernoulli(mu_y_t),
                            obs=capture_history[:, t])
            first_capture_mask |= capture_history[:, t].bool()
def model_2(capture_history, sex):
    N, T = capture_history.shape
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    z = torch.ones(N)
    first_capture_mask = torch.zeros(N).bool()
    # we create the plate once, outside of the loop over t
    animals_plate = pyro.plate("animals", N, dim=-1)
    for t in pyro.markov(range(T)):
        # note that phi_t needs to be outside the plate, since
        # phi_t is shared across all N individuals
        phi_t = pyro.sample("phi_{}".format(t), dist.Uniform(0.0, 1.0)) if t > 0 \
            else 1.0
        with animals_plate, poutine.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask.float() * phi_t * z + (
                1 - first_capture_mask.float())
            # we use parallel enumeration to exactly sum out
            # the discrete states z_t.
            z = pyro.sample("z_{}".format(t),
                            dist.Bernoulli(mu_z_t),
                            infer={"enumerate": "parallel"})
            mu_y_t = rho * z
            pyro.sample("y_{}".format(t),
                        dist.Bernoulli(mu_y_t),
                        obs=capture_history[:, t])
        first_capture_mask |= capture_history[:, t].bool()
def model_3(capture_history, sex):

    def logit(p):
        return torch.log(p) - torch.log1p(-p)

    N, T = capture_history.shape
    phi_mean = pyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = pyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    z = torch.ones(N)
    first_capture_mask = torch.zeros(N).bool()
    # we create the plate once, outside of the loop over t
    animals_plate = pyro.plate("animals", N, dim=-1)
    for t in pyro.markov(range(T)):
        phi_logit_t = pyro.sample("phi_logit_{}".format(t),
                                  dist.Normal(phi_logit_mean, phi_sigma)) if t > 0 \
            else torch.tensor(0.0)
        phi_t = torch.sigmoid(phi_logit_t)
        with animals_plate, poutine.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask.float() * phi_t * z + (
                1 - first_capture_mask.float())
            # we use parallel enumeration to exactly sum out
            # the discrete states z_t.
            z = pyro.sample("z_{}".format(t),
                            dist.Bernoulli(mu_z_t),
                            infer={"enumerate": "parallel"})
            mu_y_t = rho * z
            pyro.sample("y_{}".format(t),
                        dist.Bernoulli(mu_y_t),
                        obs=capture_history[:, t])
        first_capture_mask |= capture_history[:, t].bool()
def model_4(capture_history, sex):
    N, T = capture_history.shape
    # survival probabilities for males/females
    phi_male = pyro.sample("phi_male", dist.Uniform(0.0, 1.0))
    phi_female = pyro.sample("phi_female", dist.Uniform(0.0, 1.0))
    # we construct a N-dimensional vector that contains the appropriate
    # phi for each individual given its sex (female = 0, male = 1)
    phi = sex * phi_male + (1.0 - sex) * phi_female
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    with pyro.plate("animals", N, dim=-1):
        z = torch.ones(N)
        # we use this mask to eliminate extraneous log probabilities
        # that arise for a given individual before its first capture.
        first_capture_mask = torch.zeros(N).bool()
        for t in pyro.markov(range(T)):
            with poutine.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask.float() * phi * z + (
                    1 - first_capture_mask.float())
                # we use parallel enumeration to exactly sum out
                # the discrete states z_t.
                z = pyro.sample("z_{}".format(t),
                                dist.Bernoulli(mu_z_t),
                                infer={"enumerate": "parallel"})
                mu_y_t = rho * z
                pyro.sample("y_{}".format(t),
                            dist.Bernoulli(mu_y_t),
                            obs=capture_history[:, t])
            first_capture_mask |= capture_history[:, t].bool()
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
    with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
        lengths = lengths[batch]
        y = sequences[batch] if args.jit else sequences[batch, :lengths.max()]
        x = torch.arange(args.hidden_dim)
        t = torch.arange(y.size(1))
        init_logits = torch.full((args.hidden_dim,), -float("inf"))
        init_logits[0] = 0
        trans_logits = probs_x.log()
        with ignore_jit_warnings():
            obs_dist = dist.Bernoulli(
                logits=tones_generator(x, y.unsqueeze(-2))).to_event(1)
        obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1))
        hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
        pyro.sample("y", hmm_dist, obs=y)
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
def model(detections, args):
    noise_scale = pyro.param('noise_scale')
    objects = pyro.param('objects_loc').squeeze(-1)
    num_detections, = detections.shape
    max_num_objects, = objects.shape

    # Existence part.
    p_exists = args.expected_num_objects / max_num_objects
    with pyro.plate('objects_plate', max_num_objects):
        exists = pyro.sample('exists', dist.Bernoulli(p_exists))
        with poutine.mask(mask=exists.bool()):
            pyro.sample('objects', dist.Normal(0., 1.), obs=objects)

    # Assignment part.
    p_fake = args.num_fake_detections / num_detections
    with pyro.plate('detections_plate', num_detections):
        assign_probs = torch.empty(max_num_objects + 1)
        assign_probs[:-1] = (1 - p_fake) / max_num_objects
        assign_probs[-1] = p_fake
        assign = pyro.sample('assign', dist.Categorical(logits=assign_probs))
        is_fake = (assign == assign.shape[-1] - 1)
        objects_plus_bogus = torch.zeros(max_num_objects + 1)
        objects_plus_bogus[:max_num_objects] = objects
        real_dist = dist.Normal(objects_plus_bogus[assign], noise_scale)
        fake_dist = dist.Normal(0., 1.)
        pyro.sample('detections',
                    dist.MaskedMixture(is_fake, real_dist, fake_dist),
                    obs=detections)
def model_5(capture_history, sex):
    N, T = capture_history.shape
    # phi_beta controls the survival probability differential
    # for males versus females (in logit space)
    phi_beta = pyro.sample("phi_beta", dist.Normal(0.0, 10.0))
    phi_beta = sex * phi_beta
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    z = torch.ones(N)
    first_capture_mask = torch.zeros(N).bool()
    # we create the plate once, outside of the loop over t
    animals_plate = pyro.plate("animals", N, dim=-1)
    for t in pyro.markov(range(T)):
        phi_gamma_t = pyro.sample("phi_gamma_{}".format(t), dist.Normal(0.0, 10.0)) if t > 0 \
            else 0.0
        phi_t = torch.sigmoid(phi_beta + phi_gamma_t)
        with animals_plate, poutine.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask.float() * phi_t * z + (
                1 - first_capture_mask.float())
            # we use parallel enumeration to exactly sum out
            # the discrete states z_t.
            z = pyro.sample("z_{}".format(t),
                            dist.Bernoulli(mu_z_t),
                            infer={"enumerate": "parallel"})
            mu_y_t = rho * z
            pyro.sample("y_{}".format(t),
                        dist.Bernoulli(mu_y_t),
                        obs=capture_history[:, t])
        first_capture_mask |= capture_history[:, t].bool()
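# --- Training sketch for the capture-recapture models above (not part of the
# original example code). Assumptions: synthetic `capture_history`/`sex`
# tensors and `model_1` as defined above. The enumerated z_t sites are summed
# out exactly by TraceEnum_ELBO, while a mean-field guide covers only the
# continuous sites ("phi", "rho").
import torch
import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

capture_history = torch.bernoulli(0.3 * torch.ones(50, 7))  # fake data
sex = torch.bernoulli(0.5 * torch.ones(50))

# Guide only over the continuous latents; the enumerated z_t sites are blocked.
guide = AutoDiagonalNormal(poutine.block(model_1, expose=["phi", "rho"]))
svi = SVI(model_1, guide, Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=1))
for step in range(100):
    loss = svi.step(capture_history, sex)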
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=sequences[batch, t],
                    )
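# --- Training sketch for the HMM models above (assumptions: a small `args`
# namespace and synthetic `sequences`/`lengths` tensors; not taken verbatim
# from the example). The global probs_* sites get MAP estimates via AutoDelta,
# while the local x_t sites are enumerated out by TraceEnum_ELBO.
from argparse import Namespace
import torch
import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

args = Namespace(hidden_dim=16, jit=False)
sequences = torch.bernoulli(0.1 * torch.ones(10, 20, 51))  # fake data
lengths = torch.full((10,), 20, dtype=torch.long)

# MAP-estimate only the global probs_* sites; discrete x_t sites are hidden
# from the guide and summed out exactly during inference.
guide = AutoDelta(poutine.block(
    model_1, expose_fn=lambda msg: msg["name"].startswith("probs_")))
svi = SVI(model_1, guide, Adam({"lr": 0.05}),
          TraceEnum_ELBO(max_plate_nesting=2))
for step in range(10):
    loss = svi.step(sequences, lengths, args, batch_size=5)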
def median(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
    """
    Returns the posterior median value of each latent variable.

    :return: A dict mapping sample site name to median tensor.
    :rtype: dict
    """
    with torch.no_grad(), poutine.mask(mask=False):
        aux_values = self._sample_aux_values(temperature=0.0)
        values, _ = self._transform_values(aux_values)
    return values
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences,)
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim
    hidden = torch.arange(hidden_dim, dtype=torch.long)

    if not args.raftery_parameterization:
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = pyro.param("probs_x",
                             torch.rand(hidden_dim, hidden_dim, hidden_dim),
                             constraint=constraints.simplex)
    else:
        # Use the more parsimonious "Raftery" parameterization of
        # the tensor of transition probabilities. See reference:
        # Raftery, A. E. A model for high-order markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param("probs_x1", torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        probs_x2 = pyro.param("probs_x2", torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        mix_lambda = pyro.param("mix_lambda", torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # we use broadcasting to combine two tensors of shape (hidden_dim, hidden_dim) and
        # (hidden_dim, 1, hidden_dim) to obtain a tensor of shape (hidden_dim, hidden_dim, hidden_dim)
        probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)

    probs_y = pyro.param("probs_y", torch.rand(hidden_dim, data_dim),
                         constraint=constraints.unit_interval)
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # we need to pass the argument `history=2` to `pyro.markov()`
        # since our model is now 2-markov
        for t in pyro.markov(range(lengths.max()), history=2):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                probs_x_t = probs_x[x_prev.unsqueeze(-1), x_curr.unsqueeze(-1), hidden]
                x_prev, x_curr = x_curr, pyro.sample("x_{}".format(t),
                                                     dist.Categorical(probs_x_t),
                                                     infer={"enumerate": "parallel"})
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y_t),
                                obs=sequences[batch, t])
def _masked_observe(name, fn, obs, obs_mask, *args, **kwargs):
    # Split into two auxiliary sample sites.
    with poutine.mask(mask=obs_mask):
        observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs)
    with poutine.mask(mask=~obs_mask):
        unobserved = sample(f"{name}_unobserved", fn, *args, **kwargs)
    # Interleave observed and unobserved events.
    shape = obs_mask.shape + (1,) * fn.event_dim
    batch_mask = obs_mask.reshape(shape)
    try:
        value = torch.where(batch_mask, observed, unobserved)
    except RuntimeError as e:
        if "must match the size of tensor" in str(e):
            shape = torch.broadcast_shapes(observed.shape, unobserved.shape)
            batch_shape = shape[:len(shape) - fn.event_dim]
            raise ValueError(
                f"Invalid obs_mask shape {tuple(obs_mask.shape)}; should be "
                f"broadcastable to batch_shape = {tuple(batch_shape)}") from e
        raise
    return deterministic(name, value)
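# --- Usage sketch (toy model, names are illustrative): the helper above backs
# the `obs_mask` keyword of pyro.sample. Passing a partially observed tensor
# plus a boolean mask creates the auxiliary "x_observed"/"x_unobserved" sites
# and a deterministic "x" site that interleaves them.
import torch
import pyro
import pyro.distributions as dist

def partially_observed_model(data, observed):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        return pyro.sample("x", dist.Normal(loc, 1.0),
                           obs=data, obs_mask=observed)

data = torch.tensor([0.1, 0.2, 0.0, 0.4])
observed = torch.tensor([True, True, False, True])  # third entry is missing
trace = pyro.poutine.trace(partially_observed_model).get_trace(data, observed)
assert "x_observed" in trace.nodes and "x_unobserved" in trace.nodes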
def guide_generic(config):
    """generic mean-field guide for continuous random effects"""
    N_state = config["sizes"]["state"]
    if config["group"]["random"] == "continuous":
        loc_g = pyro.param("loc_group", lambda: torch.zeros((N_state**2,)))
        scale_g = pyro.param(
            "scale_group",
            lambda: torch.ones((N_state**2,)),
            constraint=constraints.positive,
        )

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "continuous":
        loc_i = pyro.param(
            "loc_individual",
            lambda: torch.zeros((N_c, N_state**2)),
        )
        scale_i = pyro.param(
            "scale_individual",
            lambda: torch.ones((N_c, N_state**2)),
            constraint=constraints.positive,
        )

    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):
        if config["group"]["random"] == "continuous":
            pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})

        N_s = config["sizes"]["individual"]
        with pyro.plate("individual", N_s, dim=-2), \
                poutine.mask(mask=config["individual"]["mask"]):
            # individual-level random effects
            if config["individual"]["random"] == "continuous":
                pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
def transform_samples(self, aux_samples, save_params=None):
    """
    Given latent samples from the warped posterior (with a possible batch dimension),
    return a `dict` of samples from the latent sites in the model.

    :param dict aux_samples: Dict site name to tensor value for each latent
        auxiliary site (or if ``save_params`` is specified, then for only those
        latent auxiliary sites needed to compute requested params).
    :param list save_params: An optional list of site names to save. This is
        useful in models with large nuisance variables. Defaults to None,
        saving all params.
    :return: a `dict` of samples keyed by latent sites in the model.
    :rtype: dict
    """
    with poutine.condition(data=aux_samples), poutine.mask(mask=False):
        deltas = self.guide.get_deltas(save_params)
    return {name: delta.v for name, delta in deltas.items()}
def _forward_pyro_mean_field(self, features, trip_counts):
    total_hours = len(features)
    observed_hours, num_origins, num_destins = trip_counts.shape
    assert observed_hours <= total_hours
    assert num_origins == self.num_stations
    assert num_destins == self.num_stations
    time_plate = pyro.plate("time", observed_hours, dim=-3)
    origins_plate = pyro.plate("origins", num_origins, dim=-2)
    destins_plate = pyro.plate("destins", num_destins, dim=-1)

    init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
        self._dynamics(features[:observed_hours])

    # This is a parallelizable crf representation of the HMM.
    # We first pull random variables from the guide, masking all factors.
    with poutine.mask(mask=False):
        shape = (1 + observed_hours, self.args.state_dim)  # includes init
        state = pyro.sample("state",
                            dist.Normal(0, 1).expand(shape).to_event(2))
        shape = (observed_hours, 2 * num_origins * num_destins)
        gate_rate = pyro.sample(
            "gate_rate", dist.Normal(0, 1).expand(shape).to_event(2))

    # We then declare CRF factors.
    pyro.sample("init", init_dist, obs=state[0])
    pyro.sample("trans",
                trans_dist.expand((observed_hours,)).to_event(1),
                obs=state[..., 1:, :] - state[..., :-1, :] @ trans_matrix)
    pyro.sample("obs",
                obs_dist.expand((observed_hours,)).to_event(1),
                obs=gate_rate - state[..., 1:, :] @ obs_matrix)

    gate, rate = self._unpack_gate_rate(gate_rate, event_dim=2)
    with time_plate, origins_plate, destins_plate:
        pyro.sample("trip_count", dist.ZeroInflatedPoisson(gate, rate),
                    obs=trip_counts)

    # The second half of the model forecasts forward.
    if total_hours > observed_hours:
        return self._forward_pyro_forecast(features, trip_counts,
                                           origins_plate, destins_plate,
                                           state=state[..., -1, :])
def guide(sequences):
    theta = pyro.param("theta", torch.ones(16))
    alpha = pyro.param("alpha", torch.rand(1))
    beta = pyro.param("beta", torch.rand(1))
    p = pyro.param("p", torch.rand(1))
    q = pyro.param("q", torch.rand(1))
    w = p * torch.eye(16) + q
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x", Dirichlet(w).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(alpha, beta).expand([16, 51]).to_event(2))
    for i in pyro.plate("sequences", len(sequences), 8):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), Categorical(probs_x[x]))
def model(self, lengths=None, sequences=None, expected_string_length: int = 5):
    with ignore_jit_warnings():
        assert sequences is None or lengths is not None
        assert lengths is None or lengths.max() <= self.smct.max_chain_length
        assert sequences is None or (
            0 <= sequences.min() and sequences.max() < self.smct.alphabet_size)
    binom_prob = pyro.sample(
        'binom_prob',
        dist.Beta(
            min(1, expected_string_length),
            min(1, self.smct.max_chain_length - expected_string_length)))
    lengths_size = 1 if sequences is None else sequences.size(0)
    with pyro.plate('lengths_plate', size=lengths_size, dim=-1):
        lengths = pyro.sample(
            'lengths',
            dist.Binomial(self.smct.max_chain_length, binom_prob),
            obs=(lengths.float() if lengths is not None else lengths))
    if lengths.dim() == 0:
        lengths = lengths.unsqueeze(-1)
    sequence_size = 1 if sequences is None else sequences.size(0)
    with pyro.plate('sequences_plate', size=sequence_size, dim=-2) as batch:
        lengths = lengths[batch]
        prev = ()
        for t in pyro.markov(range(self.smct.max_chain_length),
                             history=self.smct.order):
            if len(prev) > self.smct.order:
                prev = prev[1:]
            probs_t = pyro.sample(
                f'probs_{t}',
                dist.Dirichlet(self.smct.get_pseudocounts(prev).unsqueeze(-2)))
            x_t = None if sequences is None else sequences[batch, t]
            with poutine.mask(mask=(t < lengths).unsqueeze(-1).unsqueeze(-1)):
                x_t = pyro.sample(f'x_{t}',
                                  dist.Categorical(probs=probs_t),
                                  obs=x_t)
            prev = (*prev, x_t)
def model(sequences):
    with poutine.mask(mask=False):
        probs_x = pyro.sample("probs_x",
                              Dirichlet(0.9 * torch.eye(16) + 0.1).to_event(1))
        probs_y = pyro.sample("probs_y",
                              Beta(0.1, 0.9).expand([16, 51]).to_event(2))
    tones_plate = pyro.plate("tones", 51, dim=-1)
    for i in pyro.plate("sequences", len(sequences)):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])
def test_get_mask_optimization():

    def model():
        x = pyro.sample("x", dist.Normal(0, 1))
        pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.0))
        called.add("model-always")
        if poutine.get_mask() is not False:
            called.add("model-sometimes")
            pyro.factor("f", x + 1)

    def guide():
        x = pyro.sample("x", dist.Normal(0, 1))
        called.add("guide-always")
        if poutine.get_mask() is not False:
            called.add("guide-sometimes")
            pyro.factor("g", 2 - x)

    called = set()
    trace = poutine.trace(guide).get_trace()
    poutine.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" in called
    assert "guide-sometimes" in called

    called = set()
    with poutine.mask(mask=False):
        trace = poutine.trace(guide).get_trace()
        poutine.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called

    called = set()
    Predictive(model, guide=guide, num_samples=2, parallel=True)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called
def model_generic(config):
    """Hierarchical mixed-effects hidden markov model"""

    MISSING = config["MISSING"]
    N_v = config["sizes"]["random"]
    N_state = config["sizes"]["state"]

    # initialize group-level random effect parameters
    if config["group"]["random"] == "discrete":
        probs_e_g = pyro.param("probs_e_group",
                               lambda: torch.randn((N_v,)).abs(),
                               constraint=constraints.simplex)
        theta_g = pyro.param("theta_group",
                             lambda: torch.randn((N_v, N_state**2)))
    elif config["group"]["random"] == "continuous":
        loc_g = torch.zeros((N_state**2,))
        scale_g = torch.ones((N_state**2,))

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "discrete":
        probs_e_i = pyro.param("probs_e_individual",
                               lambda: torch.randn((N_c, N_v)).abs(),
                               constraint=constraints.simplex)
        theta_i = pyro.param("theta_individual",
                             lambda: torch.randn((N_c, N_v, N_state**2)))
    elif config["individual"]["random"] == "continuous":
        loc_i = torch.zeros((N_c, N_state**2))
        scale_i = torch.ones((N_c, N_state**2))

    # initialize likelihood parameters
    # observation 1: step size (step ~ Gamma)
    step_zi_param = pyro.param("step_zi_param", lambda: torch.ones((N_state, 2)))
    step_concentration = pyro.param("step_param_concentration",
                                    lambda: torch.randn((N_state,)).abs(),
                                    constraint=constraints.positive)
    step_rate = pyro.param("step_param_rate",
                           lambda: torch.randn((N_state,)).abs(),
                           constraint=constraints.positive)
    # observation 2: step angle (angle ~ VonMises)
    angle_concentration = pyro.param("angle_param_concentration",
                                     lambda: torch.randn((N_state,)).abs(),
                                     constraint=constraints.positive)
    angle_loc = pyro.param("angle_param_loc",
                           lambda: torch.randn((N_state,)).abs())
    # observation 3: dive activity (omega ~ Beta)
    omega_zi_param = pyro.param("omega_zi_param", lambda: torch.ones((N_state, 2)))
    omega_concentration0 = pyro.param("omega_param_concentration0",
                                      lambda: torch.randn((N_state,)).abs(),
                                      constraint=constraints.positive)
    omega_concentration1 = pyro.param("omega_param_concentration1",
                                      lambda: torch.randn((N_state,)).abs(),
                                      constraint=constraints.positive)

    # initialize gamma to uniform
    gamma = torch.zeros((N_state**2,))

    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):

        # group-level random effects
        if config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})
        else:
            eps_g = 0.

        # add group-level random effect to gamma
        gamma = gamma + eps_g

        N_s = config["sizes"]["individual"]
        with pyro.plate("individual", N_s, dim=-2), \
                poutine.mask(mask=config["individual"]["mask"]):

            # individual-level random effects
            if config["individual"]["random"] == "discrete":
                # individual-level discrete effect
                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
                eps_i = Vindex(theta_i)[..., e_i, :]
                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
            elif config["individual"]["random"] == "continuous":
                eps_i = pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
            else:
                eps_i = 0.

            # add individual-level random effect to gamma
            gamma = gamma + eps_i

            y = torch.tensor(0).long()

            N_t = config["sizes"]["timesteps"]
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    gamma_t = gamma  # per-timestep variable

                    # finally, reshape gamma as batch of transition matrices
                    gamma_t = gamma_t.reshape(tuple(gamma_t.shape[:-1]) +
                                              (N_state, N_state))

                    # we've accounted for all effects, now actually compute gamma_y
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t),
                                    dist.Categorical(logits=gamma_y))

                    # observation 1: step size
                    step_dist = dist.Gamma(
                        concentration=Vindex(step_concentration)[..., y],
                        rate=Vindex(step_rate)[..., y])

                    # zero-inflation with MaskedMixture
                    step_zi = Vindex(step_zi_param)[..., y, :]
                    step_zi_mask = pyro.sample(
                        "step_zi_{}".format(t),
                        dist.Categorical(logits=step_zi),
                        obs=(config["observations"]["step"][..., t] == MISSING))
                    step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist,
                                                      step_zi_zero_dist)
                    pyro.sample("step_{}".format(t), step_zi_dist,
                                obs=config["observations"]["step"][..., t])

                    # observation 2: step angle
                    angle_dist = dist.VonMises(
                        concentration=Vindex(angle_concentration)[..., y],
                        loc=Vindex(angle_loc)[..., y])
                    pyro.sample("angle_{}".format(t), angle_dist,
                                obs=config["observations"]["angle"][..., t])

                    # observation 3: dive activity
                    omega_dist = dist.Beta(
                        concentration0=Vindex(omega_concentration0)[..., y],
                        concentration1=Vindex(omega_concentration1)[..., y])

                    # zero-inflation with MaskedMixture
                    omega_zi = Vindex(omega_zi_param)[..., y, :]
                    omega_zi_mask = pyro.sample(
                        "omega_zi_{}".format(t),
                        dist.Categorical(logits=omega_zi),
                        obs=(config["observations"]["omega"][..., t] == MISSING))
                    omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    omega_zi_dist = dist.MaskedMixture(omega_zi_mask, omega_dist,
                                                       omega_zi_zero_dist)
                    pyro.sample("omega_{}".format(t), omega_zi_dist,
                                obs=config["observations"]["omega"][..., t])
def torus_dbn(phis=None, psis=None, lengths=None, num_sequences=None,
              num_states=55, prior_conc=0.1, prior_loc=0.0,
              prior_length_shape=100., prior_length_rate=100.,
              prior_kappa_min=10., prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape', dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate', dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) * prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) * prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) * prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) * prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(length_shape, length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state], phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state], psi_kappas[state]),
                                obs=obs_psi)
def _predictive(model, posterior_samples, num_samples, return_sites=(),
                return_trace=False, parallel=False, model_args=(), model_kwargs={}):
    model = torch.no_grad()(poutine.mask(model, mask=False))
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    vectorize = pyro.plate("_num_predictive_samples", num_samples,
                           dim=-max_plate_nesting - 1)
    model_trace = prune_subsample_sites(
        poutine.trace(model).get_trace(*model_args, **model_kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples,) +
                                (1,) * (max_plate_nesting - len(sample_shape)) +
                                sample_shape)
        reshaped_samples[name] = sample

    if return_trace:
        trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = ()
        # include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args,
                                      model_kwargs, num_samples,
                                      return_site_shapes, return_trace=False)

    trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions
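# --- Usage sketch (toy model, names are illustrative; not from the Pyro
# source): the public wrapper around the helper above is pyro.infer.Predictive.
# It masks the model with poutine.mask(mask=False) so that no likelihood terms
# are scored while posterior-predictive samples are drawn from
# `posterior_samples`.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

def toy_model(y=None):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", 5):
        return pyro.sample("y", dist.Normal(loc, 1.0), obs=y)

# Pretend these came from MCMC or SVI: 100 posterior draws of "loc".
posterior_samples = {"loc": torch.randn(100)}
predictive = Predictive(toy_model, posterior_samples=posterior_samples,
                        parallel=True)
samples = predictive(y=None)  # dict containing draws of "y"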