# Shared imports assumed by the snippets below (an addition; the originals are
# fragments from Pyro's enumeration tests and the examples/hmm.py tutorial).
# `pyro`, `dist`, `handlers`, `infer`, and `optim` refer to the
# pyroapi-dispatched names used throughout those files.
import contextlib
from collections import OrderedDict, defaultdict

import torch
from torch.distributions import constraints

import funsor
from funsor.tensor import Tensor

import pyro.contrib.funsor  # noqa: F401  registers the "contrib.funsor" backend
from pyroapi import distributions as dist
from pyroapi import handlers, infer, optim, pyro, pyro_backend

from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings


def model():
    with pyro.markov() as m:
        with pyro.markov():
            with m:  # error here: m cannot be re-entered without keep=True
                pyro.sample("x",
                            dist.Categorical(torch.ones(4)),
                            infer={"enumerate": "parallel"})

def testing():
    with pyro.markov():
        v1 = pyro.to_data(
            Tensor(torch.ones(2), OrderedDict([(str(1), funsor.Bint[2])]), 'real'))
        print(1, v1.shape)  # shapes should alternate
        assert v1.shape == (2,)

        with pyro.markov():
            v2 = pyro.to_data(
                Tensor(torch.ones(2), OrderedDict([(str(2), funsor.Bint[2])]), 'real'))
            print(2, v2.shape)  # shapes should alternate
            assert v2.shape == (2, 1)

            with pyro.markov():
                v3 = pyro.to_data(
                    Tensor(torch.ones(2), OrderedDict([(str(3), funsor.Bint[2])]), 'real'))
                print(3, v3.shape)  # shapes should alternate
                assert v3.shape == (2,)

                with pyro.markov():
                    v4 = pyro.to_data(
                        Tensor(torch.ones(2), OrderedDict([(str(4), funsor.Bint[2])]), 'real'))
                    print(4, v4.shape)  # shapes should alternate
                    assert v4.shape == (2, 1)

def model_8(weeks_data, days_data, history, vectorized):
    x_dim, y_dim, w_dim, z_dim = 3, 2, 2, 3
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)),
                         constraint=constraints.simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(x_dim, y_dim),
                         constraint=constraints.simplex)
    w_init = pyro.param("w_init", lambda: torch.rand(w_dim),
                        constraint=constraints.simplex)
    w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)),
                         constraint=constraints.simplex)
    z_probs = pyro.param("z_probs", lambda: torch.rand(w_dim, z_dim),
                         constraint=constraints.simplex)

    x_prev = None
    weeks_loop = (pyro.vectorized_markov(name="weeks", size=len(weeks_data),
                                         dim=-1, history=history)
                  if vectorized
                  else pyro.markov(range(len(weeks_data)), history=history))
    for i in weeks_loop:
        if isinstance(i, int) and i == 0:
            x_probs = x_init
        else:
            x_probs = Vindex(x_trans)[x_prev]
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs))
        pyro.sample("y_{}".format(i),
                    dist.Categorical(Vindex(y_probs)[x_curr]),
                    obs=weeks_data[i])
        x_prev = x_curr

    w_prev = None
    days_loop = (pyro.vectorized_markov(name="days", size=len(days_data),
                                        dim=-1, history=history)
                 if vectorized
                 else pyro.markov(range(len(days_data)), history=history))
    for j in days_loop:
        if isinstance(j, int) and j == 0:
            w_probs = w_init
        else:
            w_probs = Vindex(w_trans)[w_prev]
        w_curr = pyro.sample("w_{}".format(j), dist.Categorical(w_probs))
        pyro.sample("z_{}".format(j),
                    dist.Categorical(Vindex(z_probs)[w_curr]),
                    obs=days_data[j])
        w_prev = w_curr

def model(): p = pyro.param("p", torch.ones(3, 3)) x = pyro.sample("x", dist.Categorical(p[0])) y = x for i in pyro.markov(range(10)): y = pyro.sample("y_{}".format(i), dist.Categorical(p[y])) z = y for j in pyro.markov(range(10)): z = pyro.sample("z_{}_{}".format(i, j), dist.Categorical(p[z]))
def model(): p = pyro.param("p_leaf", torch.ones(2, 2, 2)) x = defaultdict(lambda: torch.tensor(0)) y_axis = pyro.markov(range(grid_size), keep=True) for i in pyro.markov(range(grid_size)): for j in y_axis: if use_vindex: probs = Vindex(p)[x[i - 1, j], x[i, j - 1]] else: ind = torch.arange(2, dtype=torch.long) probs = p[x[i - 1, j].unsqueeze(-1), x[i, j - 1].unsqueeze(-1), ind] x[i, j] = pyro.sample("x_{}_{}".format(i, j), dist.Categorical(probs))
def model(): p = pyro.param("p", torch.ones(3, 3)) q = pyro.param("q", torch.tensor([0.5, 0.5])) plate_x = pyro.plate("plate_x", 4, subsample_size=3 if subsampling else None, dim=-1) plate_y = pyro.plate("plate_y", 5, subsample_size=3 if subsampling else None, dim=-1) plate_z = pyro.plate("plate_z", 6, subsample_size=3 if subsampling else None, dim=-2) a = pyro.sample("a", dist.Bernoulli(q[0])).long() w = 0 for i in pyro.markov(range(4)): w = pyro.sample("w_{}".format(i), dist.Categorical(p[w])) with plate_x: b = pyro.sample("b", dist.Bernoulli(q[a])).long() x = 0 for i in pyro.markov(range(4)): x = pyro.sample("x_{}".format(i), dist.Categorical(p[x])) with plate_y: c = pyro.sample("c", dist.Bernoulli(q[a])).long() y = 0 for i in pyro.markov(range(4)): y = pyro.sample("y_{}".format(i), dist.Categorical(p[y])) with plate_z: d = pyro.sample("d", dist.Bernoulli(q[a])).long() z = 0 for i in pyro.markov(range(4)): z = pyro.sample("z_{}".format(i), dist.Categorical(p[z])) with plate_x, plate_z: # this part is tricky: how do we know to preserve b's dimension? # also, how do we know how to make b and d have different dimensions? e = pyro.sample("e", dist.Bernoulli(q[b if reuse_plate else a])).long() xz = 0 for i in pyro.markov(range(4)): xz = pyro.sample("xz_{}".format(i), dist.Categorical(p[xz])) return a, b, c, d, e
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, 2, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x, y = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Note the broadcasting tricks here: to index probs_y on tensors
                # x and y, we also need a final tensor for the tones dimension.
                # This is conveniently provided by the plate associated with
                # that dimension.
                with tones_plate as tones:
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(probs_y[x, y, tones]),
                                    obs=sequences[batch, t]).long()

def model_5(data, history, vectorized):
    x_dim, y_dim = 3, 2
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_init_2 = pyro.param("x_init_2", lambda: torch.rand(x_dim, x_dim),
                          constraint=constraints.simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim, x_dim)),
                         constraint=constraints.simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(x_dim, y_dim),
                         constraint=constraints.simplex)

    x_prev = x_prev_2 = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data),
                                          dim=-2, history=history)
                   if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        if isinstance(i, int) and i == 0:
            x_probs = x_init
        elif isinstance(i, int) and i == 1:
            x_probs = Vindex(x_init_2)[x_prev]
        else:
            x_probs = Vindex(x_trans)[x_prev_2, x_prev]
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Categorical(Vindex(y_probs)[x_curr]),
                        obs=data[i])
        x_prev_2, x_prev = x_prev, x_curr

def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed; `tones_generator` and
    # `TonesGenerator` are defined in the enclosing module.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Since each tone depends on all tones at the previous time
                # step, the tones at different time steps now need to live in
                # separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample("y_{}".format(t),
                                    dist.Bernoulli(logits=tones_generator(x, y)),
                                    obs=sequences[batch, t])

def model():
    # `history` is a free variable supplied by the enclosing (test) scope.
    p = torch.tensor([[0.2, 0.8], [0.1, 0.9]])
    xs = [0]
    for t in pyro.markov(range(100), history=history):
        xs.append(pyro.sample("x_{}".format(t), dist.Categorical(p[xs[-1]])))
    assert all(x.dim() <= history + 1 for x in xs[1:])

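# A minimal driver sketch (an assumption, not from the original test): the
# in-model dim assertion is only exercised under parallel enumeration, where
# pyro.markov recycles enumeration dims with period history + 1. This assumes
# `history` is bound in the enclosing scope, as in the parametrized test.
def run_history_model():
    enum_model = infer.config_enumerate(model, default="parallel")
    with handlers.enum(first_available_dim=-1):
        enum_model()
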
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with handlers.mask(mask=include_prior):
        # Our prior on transition probabilities will be:
        # stay in the same state with 90% probability; uniformly jump to another
        # state with 10% probability.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # We put a weak prior on the conditional probability of a tone sounding.
        # We know that on average about 4 of 88 tones are active, so we'll set a
        # rough weak prior of 10% of the notes being active at any one time.
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    # In this first model we'll sequentially iterate over sequences in a
    # minibatch; this will make it easy to reason about tensor shapes.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
        for t in pyro.markov(range(length)):
            # On the next line, we'll overwrite the value of x with an updated
            # value. If we wanted to record all x values, we could instead
            # write x[t] = pyro.sample(...x[t-1]...).
            x = pyro.sample("x_{}_{}".format(i, t),
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t),
                            dist.Bernoulli(probs_y[x.squeeze(-1)]),
                            obs=sequence[t])

def model_0(data, history, vectorized):
    x_dim = 3
    init = pyro.param("init", lambda: torch.rand(x_dim),
                      constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)),
                       constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    with pyro.plate("sequences", data.shape[0], dim=-3) as sequences:
        sequences = sequences[:, None]
        x_prev = None
        markov_loop = (pyro.vectorized_markov(name="time", size=data.shape[1],
                                              dim=-2, history=history)
                       if vectorized
                       else pyro.markov(range(data.shape[1]), history=history))
        for i in markov_loop:
            x_curr = pyro.sample(
                "x_{}".format(i),
                dist.Categorical(init if isinstance(i, int) and i < 1
                                 else trans[x_prev]))
            with pyro.plate("tones", data.shape[2], dim=-1):
                pyro.sample("y_{}".format(i),
                            dist.Normal(Vindex(locs)[..., x_curr], 1.),
                            obs=Vindex(data)[sequences, i])
            x_prev = x_curr

def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])

def model_1(data, history, vectorized):
    x_dim = 3
    init = pyro.param("init", lambda: torch.rand(x_dim),
                      constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)),
                       constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    x_prev = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data),
                                          dim=-2, history=history)
                   if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(init if isinstance(i, int) and i < 1
                             else trans[x_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Normal(Vindex(locs)[..., x_curr], 1.0),
                        obs=data[i])
        x_prev = x_curr

def model_6(data, history, vectorized):
    x_dim = 3
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    # Time-inhomogeneous: one transition matrix per time step.
    x_trans = pyro.param("x_trans",
                         lambda: torch.rand((len(data) - 1, x_dim, x_dim)),
                         constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    x_prev = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data),
                                          dim=-2, history=history)
                   if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        if isinstance(i, int) and i < 1:
            x_probs = x_init
        elif isinstance(i, int):
            x_probs = x_trans[i - 1, x_prev]
        else:
            # In the vectorized branch, i is a tensor of time indices.
            x_probs = Vindex(x_trans)[(i - 1)[:, None], x_prev]
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(x_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Normal(Vindex(locs)[..., x_curr], 1.),
                        obs=data[i])
        x_prev = x_curr

def model_2(data, history, vectorized):
    x_dim, y_dim = 3, 2
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)),
                         constraint=constraints.simplex)
    y_init = pyro.param("y_init", lambda: torch.rand(x_dim, y_dim),
                        constraint=constraints.simplex)
    y_trans = pyro.param("y_trans", lambda: torch.rand((x_dim, y_dim, y_dim)),
                         constraint=constraints.simplex)

    x_prev = y_prev = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data),
                                          dim=-2, history=history)
                   if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(x_init if isinstance(i, int) and i < 1
                             else x_trans[x_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            y_curr = pyro.sample(
                "y_{}".format(i),
                dist.Categorical(y_init[x_curr] if isinstance(i, int) and i < 1
                                 else Vindex(y_trans)[x_curr, y_prev]),
                obs=data[i])
        x_prev, y_prev = x_curr, y_curr

def testing():
    for i in pyro.markov(range(12)):
        if i % 4 == 0:
            fv2 = pyro.to_funsor(torch.zeros(2), funsor.Real,
                                 dim_to_name={-1: 'a'})
            v2 = pyro.to_data(fv2)
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)

def model():
    # `history` is a free variable supplied by the enclosing (test) scope.
    p = torch.tensor([[0.2, 0.8], [0.1, 0.9]])
    xs = [0]
    for t in pyro.markov(range(10), history=history):
        xs.append(pyro.sample(
            "x_{}".format(t),
            dist.Categorical(p[xs[-1]]),
            infer={"enumerate": ("sequential", "parallel")[t % 2]}))
    assert all(x.dim() <= history + 1 for x in xs[1:])

def testing():
    for i in pyro.markov(range(12)):
        if i % 4 == 0:
            v2 = pyro.to_data(Tensor(torch.zeros(2),
                                     OrderedDict([('a', funsor.Bint[2])]),
                                     'real'))
            fv2 = pyro.to_funsor(v2, funsor.Real)
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)

def model(): p = pyro.param("p", 0.25 * torch.ones(2, 2)) q = pyro.param("q", 0.25 * torch.ones(2)) x_prev = torch.tensor(0) x_curr = torch.tensor(0) for t in pyro.markov(range(10), history=history): probs = p[x_prev, x_curr] x_prev, x_curr = x_curr, pyro.sample("x_{}".format(t), dist.Bernoulli(probs)).long() pyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=torch.tensor(0.))
def model():
    # `history` is a free variable supplied by the enclosing (test) scope.
    p = torch.tensor([[0.2, 0.8], [0.1, 0.9]])
    xs = [0]
    c = pyro.markov(history=history)
    with contextlib.ExitStack() as stack:
        # Enter the same markov context once per time step, nesting rather
        # than iterating.
        for t in range(100):
            stack.enter_context(c)
            xs.append(pyro.sample("x_{}".format(t),
                                  dist.Categorical(p[xs[-1]])))
        assert all(x.dim() <= history + 1 for x in xs[1:])

def hmm(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition[states[-1]])))
        data[t] = pyro.sample("obs_{}".format(t),
                              dist.Normal(means[states[-1]], 1.0),
                              obs=data[t])
    return states, data

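# Illustrative usage (an assumption, not from the original file): because
# every latent states_t site is discrete, TraceEnum_ELBO with an empty guide
# sums them out exactly, so the ELBO loss equals the negative marginal
# log-likelihood of the data.
def hmm_marginal_loglik(data):
    def empty_guide(data, hidden_dim=10):
        pass  # all discrete latents are enumerated out in the model
    enum_hmm = infer.config_enumerate(hmm, default="parallel")
    elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
    return -elbo.loss(enum_hmm, empty_guide, data)
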
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences,)
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim

    if not args.raftery_parameterization:
        # Explicitly parameterize the full tensor of transition probabilities,
        # which has hidden_dim cubed entries.
        probs_x = pyro.param("probs_x",
                             torch.rand(hidden_dim, hidden_dim, hidden_dim),
                             constraint=constraints.simplex)
    else:
        # Use the more parsimonious "Raftery" parameterization of
        # the tensor of transition probabilities. See reference:
        # Raftery, A. E. A model for high-order Markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param("probs_x1", torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        probs_x2 = pyro.param("probs_x2", torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        mix_lambda = pyro.param("mix_lambda", torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # We use broadcasting to combine two tensors of shape
        # (hidden_dim, hidden_dim) and (hidden_dim, 1, hidden_dim) to obtain
        # a tensor of shape (hidden_dim, hidden_dim, hidden_dim).
        probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)

    probs_y = pyro.param("probs_y", torch.rand(hidden_dim, data_dim),
                         constraint=constraints.unit_interval)
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # We need to pass the argument `history=2` to `pyro.markov()`
        # since our model is now second-order Markov.
        for t in pyro.markov(range(lengths.max()), history=2):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x_t),
                    infer={"enumerate": "parallel"})
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y_t),
                                obs=sequences[batch, t])

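# A quick, illustrative shape check of the Raftery mixing step above:
# broadcasting an (H, H) matrix against an (H, 1, H) tensor yields the full
# (H, H, H) transition tensor; probs_x1 enters as a function of the most
# recent state and probs_x2 as a function of the state before it.
def check_raftery_shapes(H=4):
    probs_x1 = torch.rand(H, H)
    probs_x2 = torch.rand(H, H)
    mixed = 0.5 * probs_x1 + 0.5 * probs_x2.unsqueeze(-2)
    assert mixed.shape == (H, H, H)
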
def testing():
    for i in pyro.markov(range(5)):
        v1 = pyro.to_data(Tensor(torch.ones(2),
                                 OrderedDict([(str(i), funsor.Bint[2])]),
                                 'real'))
        v2 = pyro.to_data(Tensor(torch.zeros(2),
                                 OrderedDict([('a', funsor.Bint[2])]),
                                 'real'))
        fv1 = pyro.to_funsor(v1, funsor.Real)
        fv2 = pyro.to_funsor(v2, funsor.Real)
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)

def testing():
    for i in pyro.markov(range(5)):
        fv1 = pyro.to_funsor(torch.zeros(2), funsor.Real,
                             dim_to_name={-1: str(i)})
        fv2 = pyro.to_funsor(torch.ones(2), funsor.Real,
                             dim_to_name={-1: "a"})
        v1 = pyro.to_data(fv1)
        v2 = pyro.to_data(fv2)
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)

def model(): p = pyro.param("p", torch.ones(3, 3)) q = pyro.param("q", torch.ones(2)) r = pyro.param("r", torch.ones(3, 2, 4)) x = 0 times = pyro.markov(range(100)) if markov else range(11) for t in times: x = pyro.sample("x_{}".format(t), dist.Categorical(p[x])) y = pyro.sample("y_{}".format(t), dist.Categorical(q)) if use_vindex: probs = Vindex(r)[x, y] else: z_ind = torch.arange(4, dtype=torch.long) probs = r[x.unsqueeze(-1), y.unsqueeze(-1), z_ind] pyro.sample("z_{}".format(t), dist.Categorical(probs), obs=torch.tensor(0.))
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
            .expand_by([hidden_dim])
            .to_event(2))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange
        # and ensure that w and x are always tensors so we can unsqueeze them
        # below, thus ensuring that the x sample sites have correct
        # distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])

def model_10(data, history, vectorized):
    init_probs = torch.tensor([0.5, 0.5])
    transition_probs = pyro.param("transition_probs",
                                  torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                  constraint=constraints.simplex)
    emission_probs = pyro.param("emission_probs",
                                torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                constraint=constraints.simplex)
    x = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data),
                                          history=history)
                   if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        probs = init_probs if x is None else transition_probs[x]
        x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
        pyro.sample("y_{}".format(i),
                    dist.Categorical(emission_probs[x]),
                    obs=data[i])

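# A sketch of how such models are typically checked (an assumption, not from
# the original file): under the funsor backend, the vectorized and sequential
# unrollings of model_10 should yield the same loss, with TraceMarkovEnum_ELBO
# handling the pyro.vectorized_markov sites. max_plate_nesting=4 is a
# generous bound that also covers the time dimension.
def check_model_10(data, history=1):
    def guide(data, history, vectorized):
        pass  # all discrete latents are enumerated out in the model
    with pyro_backend("contrib.funsor"):
        enum_model = infer.config_enumerate(model_10, default="parallel")
        expected = infer.TraceEnum_ELBO(max_plate_nesting=4).differentiable_loss(
            enum_model, guide, data, history, False)
        actual = infer.TraceMarkovEnum_ELBO(max_plate_nesting=4).differentiable_loss(
            enum_model, guide, data, history, True)
    assert torch.allclose(actual, expected)
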
def model_4(data, history, vectorized):
    w_dim, x_dim, y_dim = 2, 3, 2
    w_init = pyro.param("w_init", lambda: torch.rand(w_dim),
                        constraint=constraints.simplex)
    w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)),
                         constraint=constraints.simplex)
    x_init = pyro.param("x_init", lambda: torch.rand(w_dim, x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((w_dim, x_dim, x_dim)),
                         constraint=constraints.simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(w_dim, x_dim, y_dim),
                         constraint=constraints.simplex)

    w_prev = x_prev = None
    markov_loop = (pyro.vectorized_markov(name="time", size=len(data),
                                          dim=-2, history=history)
                   if vectorized
                   else pyro.markov(range(len(data)), history=history))
    for i in markov_loop:
        w_curr = pyro.sample(
            "w_{}".format(i),
            dist.Categorical(w_init if isinstance(i, int) and i < 1
                             else w_trans[w_prev]))
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(x_init[w_curr] if isinstance(i, int) and i < 1
                             else x_trans[w_curr, x_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Categorical(Vindex(y_probs)[w_curr, x_curr]),
                        obs=data[i])
        x_prev, w_prev = x_curr, w_curr

def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the tones plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each new program structure would
        # trigger a new jit compile stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])

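# An illustrative training loop for model_1 above, loosely following the
# hmm.py example these models come from; the AutoDelta guide over the probs_
# sites, the optimizer, and the learning rate are assumptions.
def train_model_1(sequences, lengths, args, num_steps=100):
    from pyro.infer.autoguide import AutoDelta
    # Fit point estimates of probs_x and probs_y while TraceEnum_ELBO sums out
    # the discrete x_t sites; max_plate_nesting=2 covers the "sequences" (-2)
    # and "tones" (-1) plates.
    guide = AutoDelta(handlers.block(
        model_1, expose_fn=lambda msg: msg["name"].startswith("probs_")))
    elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
    svi = infer.SVI(model_1, guide, optim.Adam({"lr": 0.05}), elbo)
    for step in range(num_steps):
        svi.step(sequences, lengths, args, batch_size=8)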