def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with handlers.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample(
                "x_{}_{}".format(i, t),
                dist.Categorical(probs_x[x]),
                infer={"enumerate": "parallel"},
            )
            with tones_plate:
                pyro.sample(
                    "y_{}_{}".format(i, t),
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )
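# A minimal training sketch for model_0 (illustrative, not part of the
# original file; the guide choice, learning rate, and step count are
# assumptions, and imports assume the standard Pyro backend). TraceEnum_ELBO
# sums out the enumerated x sites; max_plate_nesting=1 matches the single
# "tones" plate above, since the "sequences" plate here is sequential.
def train_model_0_sketch(sequences, lengths, args, num_steps=10):
    import pyro.optim
    from pyro import poutine
    from pyro.infer import SVI, TraceEnum_ELBO
    from pyro.infer.autoguide import AutoDelta

    # Only the global parameters need a guide; the discrete x's are enumerated.
    guide = AutoDelta(poutine.block(model_0, expose=["probs_x", "probs_y"]))
    svi = SVI(
        model_0,
        guide,
        pyro.optim.Adam({"lr": 0.05}),
        TraceEnum_ELBO(max_plate_nesting=1),
    )
    for _ in range(num_steps):
        svi.step(sequences, lengths, args, batch_size=8)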
def model():
    # `subsampling` and `enumerate_` are free variables supplied by the
    # enclosing test scope.
    x_plate = pyro.plate(
        "x_plate", 5, subsample_size=2 if subsampling else None, dim=-1
    )
    y_plate = pyro.plate(
        "y_plate", 6, subsample_size=3 if subsampling else None, dim=-2
    )
    with pyro.plate("num_particles", 50, dim=-3):
        with x_plate:
            b = pyro.sample("b", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
        with y_plate:
            c = pyro.sample("c", dist.Bernoulli(0.5))
        with x_plate, y_plate:
            d = pyro.sample("d", dist.Bernoulli(b))

    # check shapes
    if enumerate_ == "parallel":
        assert b.shape == (50, 1, x_plate.subsample_size)
        assert c.shape == (2, 1, 1, 1)
        assert d.shape == (2, 1, 1, 1, 1)
    elif enumerate_ == "sequential":
        assert b.shape == (50, 1, x_plate.subsample_size)
        assert c.shape in ((), (1, 1, 1))  # both are valid
        assert d.shape in ((), (1, 1, 1))  # both are valid
    else:
        assert b.shape == (50, 1, x_plate.subsample_size)
        assert c.shape == (50, y_plate.subsample_size, 1)
        assert d.shape == (50, y_plate.subsample_size, x_plate.subsample_size)
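# A hypothetical harness for the shape checks above (the enclosing test
# normally supplies `subsampling` and `enumerate_`; the values below are
# assumptions for illustration). Enumerated sites are allocated dims to the
# left of the three plate dims (-1, -2, -3), hence first_available_dim=-4.
subsampling, enumerate_ = False, "parallel"


def run_shape_check_sketch():
    from pyro import poutine
    from pyro.infer import config_enumerate

    enum_model = config_enumerate(model, default=enumerate_)
    poutine.trace(poutine.enum(enum_model, first_available_dim=-4)).get_trace()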
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = pyro.sample(
            "probs_x", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample(
                    "w_{}".format(t),
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate as tones:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[w, x, tones]),
                        obs=sequences[batch, t],
                    )
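# Why hidden_dim = int(args.hidden_dim**0.5): model_3 factors one chain over
# args.hidden_dim states into two independent chains w and x, preserving the
# size of the joint state space while shrinking the transition parameters.
# Illustrative arithmetic (not code from the original example):
def factorial_hmm_sizes_sketch(full_hidden_dim=16):
    hidden_dim = int(full_hidden_dim**0.5)  # 4 states each for w and x
    assert hidden_dim * hidden_dim == full_hidden_dim  # joint space preserved
    flat_transition_params = full_hidden_dim**2  # 256 for one flat chain
    factored_transition_params = 2 * hidden_dim**2  # 32 for the two chains
    return flat_transition_params, factored_transition_params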
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x, y, tones]),
                        obs=sequences[batch, t],
                    ).long()
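# Why model_2 casts the observed y with .long(): Bernoulli returns float
# tensors, but y is fed back as an index into probs_y on the next time step,
# and PyTorch requires long (or bool/byte) index tensors. A minimal repro
# (shapes are illustrative assumptions):
def long_cast_sketch():
    probs = torch.rand(4, 2, 88)  # (x, previous y, tone)
    y = torch.distributions.Bernoulli(0.5).sample([88])
    # probs[0, y, torch.arange(88)] would raise IndexError for float y;
    # casting restores valid advanced indexing:
    return probs[0, y.long(), torch.arange(88)]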
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
            .expand_by([hidden_dim])
            .to_event(2),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we initialize w and x as tensors
        # so that Vindex(probs_x)[w, x] below can broadcast them against each
        # other, ensuring the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample(
                    "w_{}".format(t),
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(Vindex(probs_x)[w, x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate as tones:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[w, x, tones]),
                        obs=sequences[batch, t],
                    )
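# A minimal illustration of why Vindex is used above (shapes are assumptions
# for demonstration): under parallel enumeration w and x acquire distinct
# enumeration dims on the left, and Vindex gathers with broadcasting while
# keeping the trailing "next state" dim as a slice.
def vindex_shape_sketch():
    H = 4
    probs_x = torch.rand(H, H, H)  # (w, x, next-x) transition table
    w = torch.arange(H).view(H, 1, 1, 1)  # as if enumerated at dim -4
    x = torch.arange(H).view(H, 1, 1, 1, 1)  # as if enumerated at dim -5
    assert Vindex(probs_x)[w, x].shape == (H, H, 1, 1, 1, H)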
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each new program structure would
        # trigger a fresh jit compilation.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=sequences[batch, t],
                    )
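# A sketch of how the args.jit flag referenced above would typically select
# an ELBO implementation (illustrative; the original file's training loop is
# not shown here). The jitted variant compiles a single program structure,
# which is why the loop above fixes its length to max_length under the jit.
def make_elbo_sketch(args):
    from pyro.infer import JitTraceEnum_ELBO, TraceEnum_ELBO

    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    # max_plate_nesting=2 covers the "sequences" (dim=-2) and "tones" (dim=-1)
    # plates used by model_1.
    return Elbo(max_plate_nesting=2)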
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Note that since we're using dim=-2 for the time dimension, we need
    # to batch sequences over a different dimension, here dim=-3.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-3) as batch:
        lengths = lengths[batch]
        batch = batch[:, None]
        x_prev = 0
        # To vectorize the time dimension we use pyro.vectorized_markov(name=...).
        # With the help of Vindex and additional unsqueezes we can ensure that
        # dimensions line up properly.
        for t in pyro.vectorized_markov(
            name="time", size=int(max_length if args.jit else lengths.max()), dim=-2
        ):
            with handlers.mask(mask=(t < lengths.unsqueeze(-1)).unsqueeze(-1)):
                x_curr = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x_curr.squeeze(-1)]),
                        obs=Vindex(sequences)[batch, t],
                    )
                # Carry the sampled state into the next Markov step; without
                # this update the chain would never advance from x_prev = 0.
                x_prev = x_curr
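# Hedged sketch: model_7's pyro.vectorized_markov sites need an ELBO that
# understands vectorized Markov dependencies. The import path below is an
# assumption (TraceMarkovEnum_ELBO ships with the funsor-based backend rather
# than core pyro.infer); max_plate_nesting=3 covers the sequences, time, and
# tones dims at -3, -2, -1.
def make_model_7_elbo_sketch():
    from pyro.contrib.funsor.infer import TraceMarkovEnum_ELBO  # assumed path

    return TraceMarkovEnum_ELBO(max_plate_nesting=3)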