def model_8(weeks_data, days_data, history, vectorized):
    """Run two separate categorical Markov chains, one per data stream.

    The first chain samples a latent ``x_i`` per entry of ``weeks_data`` and
    scores ``weeks_data[i]`` as ``y_i``; the second does the same with
    ``w_j`` / ``z_j`` over ``days_data``.

    Args:
        weeks_data: observations for the first chain, indexed by step.
        days_data: observations for the second chain, indexed by step.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    x_dim, y_dim, w_dim, z_dim = 3, 2, 2, 3
    simplex = constraints.simplex

    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_trans = pyro.param(
        "x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex
    )
    y_probs = pyro.param(
        "y_probs", lambda: torch.rand(x_dim, y_dim), constraint=simplex
    )
    w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=simplex)
    w_trans = pyro.param(
        "w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=simplex
    )
    z_probs = pyro.param(
        "z_probs", lambda: torch.rand(w_dim, z_dim), constraint=simplex
    )

    # ---- first chain: weeks ----
    if vectorized:
        weeks_loop = pyro.vectorized_markov(
            name="weeks", size=len(weeks_data), dim=-1, history=history
        )
    else:
        weeks_loop = pyro.markov(range(len(weeks_data)), history=history)

    x_prev = None
    for step in weeks_loop:
        # Only a plain integer step equal to 0 uses the initial distribution.
        at_start = isinstance(step, int) and step == 0
        x_probs = x_init if at_start else Vindex(x_trans)[x_prev]
        x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))
        y_dist = dist.Categorical(Vindex(y_probs)[x_curr])
        pyro.sample("y_{}".format(step), y_dist, obs=weeks_data[step])
        x_prev = x_curr

    # ---- second chain: days ----
    if vectorized:
        days_loop = pyro.vectorized_markov(
            name="days", size=len(days_data), dim=-1, history=history
        )
    else:
        days_loop = pyro.markov(range(len(days_data)), history=history)

    w_prev = None
    for step in days_loop:
        at_start = isinstance(step, int) and step == 0
        w_probs = w_init if at_start else Vindex(w_trans)[w_prev]
        w_curr = pyro.sample("w_{}".format(step), dist.Categorical(w_probs))
        z_dist = dist.Categorical(Vindex(z_probs)[w_curr])
        pyro.sample("z_{}".format(step), z_dist, obs=days_data[step])
        w_prev = w_curr
def model_2(data, history, vectorized):
    """Markov model whose emission also depends on the previous emission.

    At each step a latent ``x_i`` is sampled; inside the ``"tones"`` plate the
    observation ``y_i`` is scored with probabilities selected by the current
    latent and, after the first step, by the previous observation site.

    Args:
        data: observations; indexed by step, last dimension sized for the
            ``"tones"`` plate.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    x_dim, y_dim = 3, 2
    simplex = constraints.simplex

    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_trans = pyro.param(
        "x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex
    )
    y_init = pyro.param(
        "y_init", lambda: torch.rand(x_dim, y_dim), constraint=simplex
    )
    y_trans = pyro.param(
        "y_trans", lambda: torch.rand((x_dim, y_dim, y_dim)), constraint=simplex
    )

    if vectorized:
        markov_loop = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history
        )
    else:
        markov_loop = pyro.markov(range(len(data)), history=history)

    x_prev = None
    y_prev = None
    for step in markov_loop:
        # Initial parameters apply only to a plain integer step below 1.
        is_first = isinstance(step, int) and step < 1

        if is_first:
            x_probs = x_init
        else:
            x_probs = x_trans[x_prev]
        x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            if is_first:
                emit_probs = y_init[x_curr]
            else:
                emit_probs = Vindex(y_trans)[x_curr, y_prev]
            y_curr = pyro.sample(
                "y_{}".format(step), dist.Categorical(emit_probs), obs=data[step]
            )

        x_prev = x_curr
        y_prev = y_curr
def model_0(data, history, vectorized):
    """Batched HMM with Normal emissions, one chain per sequence.

    A ``"sequences"`` plate (dim=-3) wraps the time loop; each step samples a
    categorical latent ``x_i`` and, inside the ``"tones"`` plate, scores the
    matching slice of ``data`` under a unit-scale Normal centred on
    ``locs[x_i]``.

    Args:
        data: observations; ``data.shape[0]``, ``data.shape[1]`` and
            ``data.shape[2]`` size the sequences plate, the time loop and the
            tones plate respectively.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    x_dim = 3
    simplex = constraints.simplex

    init = pyro.param("init", lambda: torch.rand(x_dim), constraint=simplex)
    trans = pyro.param(
        "trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex
    )
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    num_steps = data.shape[1]
    with pyro.plate("sequences", data.shape[0], dim=-3) as seq_idx:
        # Extra trailing axis on the plate indices before they are used to
        # index ``data`` together with the time index.
        seq_idx = seq_idx[:, None]

        if vectorized:
            markov_loop = pyro.vectorized_markov(
                name="time", size=num_steps, dim=-2, history=history
            )
        else:
            markov_loop = pyro.markov(range(num_steps), history=history)

        x_prev = None
        for step in markov_loop:
            if isinstance(step, int) and step < 1:
                x_probs = init
            else:
                x_probs = trans[x_prev]
            x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))

            with pyro.plate("tones", data.shape[2], dim=-1):
                emission = dist.Normal(Vindex(locs)[..., x_curr], 1.)
                pyro.sample(
                    "y_{}".format(step),
                    emission,
                    obs=Vindex(data)[seq_idx, step],
                )

            x_prev = x_curr
def model_6(data, history, vectorized):
    """HMM with a separate transition matrix for every time step.

    ``x_trans`` holds ``len(data) - 1`` transition matrices; step ``i`` uses
    matrix ``i - 1``. Emissions are unit-scale Normals centred on
    ``locs[x_i]``, scored inside the ``"tones"`` plate.

    Args:
        data: observations; indexed by step, last dimension sized for the
            ``"tones"`` plate.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    x_dim = 3
    simplex = constraints.simplex

    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_trans = pyro.param(
        "x_trans",
        lambda: torch.rand((len(data) - 1, x_dim, x_dim)),
        constraint=simplex,
    )
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    if vectorized:
        markov_loop = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history
        )
    else:
        markov_loop = pyro.markov(range(len(data)), history=history)

    x_prev = None
    for step in markov_loop:
        if not isinstance(step, int):
            # Non-integer step: index the per-step matrices with an added
            # trailing axis on the shifted step index.
            x_probs = Vindex(x_trans)[(step - 1)[:, None], x_prev]
        elif step < 1:
            x_probs = x_init
        else:
            x_probs = x_trans[step - 1, x_prev]

        x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            emission = dist.Normal(Vindex(locs)[..., x_curr], 1.)
            pyro.sample("y_{}".format(step), emission, obs=data[step])

        x_prev = x_curr
def model_1(data, history, vectorized):
    """Single-sequence HMM with categorical latents and Normal emissions.

    Each step samples ``x_i`` (from ``init`` on the first step, otherwise from
    the ``trans`` row selected by the previous latent) and, inside the
    ``"tones"`` plate, scores ``data[i]`` under ``Normal(locs[x_i], 1.0)``.

    Args:
        data: observations; indexed by step, last dimension sized for the
            ``"tones"`` plate.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    x_dim = 3
    simplex = constraints.simplex

    init = pyro.param("init", lambda: torch.rand(x_dim), constraint=simplex)
    trans = pyro.param(
        "trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex
    )
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    if vectorized:
        markov_loop = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history
        )
    else:
        markov_loop = pyro.markov(range(len(data)), history=history)

    x_prev = None
    for step in markov_loop:
        if isinstance(step, int) and step < 1:
            x_probs = init
        else:
            x_probs = trans[x_prev]
        x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            emission = dist.Normal(Vindex(locs)[..., x_curr], 1.0)
            pyro.sample("y_{}".format(step), emission, obs=data[step])

        x_prev = x_curr
def model_5(data, history, vectorized):
    """Second-order Markov model: each latent depends on the two before it.

    Step 0 draws from ``x_init``, step 1 from ``x_init_2`` indexed by the
    previous latent, and every later step from ``x_trans`` indexed by the two
    previous latents. Observations are categorical, scored inside the
    ``"tones"`` plate.

    Args:
        data: observations; indexed by step, last dimension sized for the
            ``"tones"`` plate.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    x_dim, y_dim = 3, 2
    simplex = constraints.simplex

    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_init_2 = pyro.param(
        "x_init_2", lambda: torch.rand(x_dim, x_dim), constraint=simplex
    )
    x_trans = pyro.param(
        "x_trans", lambda: torch.rand((x_dim, x_dim, x_dim)), constraint=simplex
    )
    y_probs = pyro.param(
        "y_probs", lambda: torch.rand(x_dim, y_dim), constraint=simplex
    )

    if vectorized:
        markov_loop = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history
        )
    else:
        markov_loop = pyro.markov(range(len(data)), history=history)

    x_prev = None
    x_prev_2 = None
    for step in markov_loop:
        int_step = isinstance(step, int)
        if int_step and step == 0:
            x_probs = x_init
        elif int_step and step == 1:
            x_probs = Vindex(x_init_2)[x_prev]
        else:
            x_probs = Vindex(x_trans)[x_prev_2, x_prev]

        x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            emission = dist.Categorical(Vindex(y_probs)[x_curr])
            pyro.sample("y_{}".format(step), emission, obs=data[step])

        # Shift the two-step window forward.
        x_prev_2, x_prev = x_prev, x_curr
def model_10(data, history, vectorized):
    """Two-state HMM with a fixed uniform initial distribution.

    The transition and emission tables are Pyro params initialised to
    ``[[0.75, 0.25], [0.25, 0.75]]``. The first latent is drawn from
    ``init_probs``; later ones from the transition row of the previous
    latent.

    Args:
        data: observations, indexed by step.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    init_probs = torch.tensor([0.5, 0.5])
    transition_probs = pyro.param(
        "transition_probs",
        torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
        constraint=constraints.simplex,
    )
    emission_probs = pyro.param(
        "emission_probs",
        torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
        constraint=constraints.simplex,
    )

    if vectorized:
        markov_loop = pyro.vectorized_markov(
            name="time", size=len(data), history=history
        )
    else:
        markov_loop = pyro.markov(range(len(data)), history=history)

    state = None
    for step in markov_loop:
        # No previous state yet means this is the first draw of the chain.
        if state is None:
            probs = init_probs
        else:
            probs = transition_probs[state]
        state = pyro.sample("x_{}".format(step), dist.Categorical(probs))
        emission = dist.Categorical(emission_probs[state])
        pyro.sample("y_{}".format(step), emission, obs=data[step])
def model_4(data, history, vectorized):
    """Markov model with two coupled latent chains, ``w`` and ``x``.

    ``w_i`` follows its own Markov chain; ``x_i`` depends on the current
    ``w_i`` and the previous ``x``; the categorical observation depends on
    both current latents and is scored inside the ``"tones"`` plate.

    Args:
        data: observations; indexed by step, last dimension sized for the
            ``"tones"`` plate.
        history: forwarded as ``history`` to the markov context.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``;
            otherwise with a sequential ``pyro.markov`` loop.
    """
    w_dim, x_dim, y_dim = 2, 3, 2
    simplex = constraints.simplex

    w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=simplex)
    w_trans = pyro.param(
        "w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=simplex
    )
    x_init = pyro.param(
        "x_init", lambda: torch.rand(w_dim, x_dim), constraint=simplex
    )
    x_trans = pyro.param(
        "x_trans", lambda: torch.rand((w_dim, x_dim, x_dim)), constraint=simplex
    )
    y_probs = pyro.param(
        "y_probs", lambda: torch.rand(w_dim, x_dim, y_dim), constraint=simplex
    )

    if vectorized:
        markov_loop = pyro.vectorized_markov(
            name="time", size=len(data), dim=-2, history=history
        )
    else:
        markov_loop = pyro.markov(range(len(data)), history=history)

    w_prev = None
    x_prev = None
    for step in markov_loop:
        is_first = isinstance(step, int) and step < 1

        if is_first:
            w_probs = w_init
        else:
            w_probs = w_trans[w_prev]
        w_curr = pyro.sample("w_{}".format(step), dist.Categorical(w_probs))

        if is_first:
            x_probs = x_init[w_curr]
        else:
            x_probs = x_trans[w_curr, x_prev]
        x_curr = pyro.sample("x_{}".format(step), dist.Categorical(x_probs))

        with pyro.plate("tones", data.shape[-1], dim=-1):
            emission = dist.Categorical(Vindex(y_probs)[w_curr, x_curr])
            pyro.sample("y_{}".format(step), emission, obs=data[step])

        x_prev = x_curr
        w_prev = w_curr
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM over batched, variable-length sequences with a vectorized time axis.

    Samples a transition table ``probs_x`` and an emission table ``probs_y``,
    then, for each (sub-sampled) sequence, iterates over time with
    ``pyro.vectorized_markov``. Each step samples an enumerated hidden state
    ``x_t`` from the transition row of the previous state and scores the
    observed tones under a Bernoulli. Steps at or beyond a sequence's length
    are masked out.

    Args:
        sequences: observation tensor whose shape unpacks to
            ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; must have shape ``(num_sequences,)``
            and a maximum no greater than ``max_length``.
        args: object providing ``hidden_dim`` and ``jit``.
        batch_size: optional subsample size for the ``"sequences"`` plate.
        include_prior: mask applied to the ``probs_x`` / ``probs_y`` sample
            sites.

    Raises:
        AssertionError: if ``lengths`` does not match ``sequences``.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Note that since we're using dim=-2 for the time dimension, we need
    # to batch sequences over a different dimension, here dim=-3.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-3) as batch:
        lengths = lengths[batch]
        batch = batch[:, None]
        x_prev = 0
        # To vectorize time dimension we use pyro.vectorized_markov(name=...).
        # With the help of Vindex and additional unsqueezes we can ensure that
        # dimensions line up properly.
        for t in pyro.vectorized_markov(
            name="time", size=int(max_length if args.jit else lengths.max()), dim=-2
        ):
            with handlers.mask(mask=(t < lengths.unsqueeze(-1)).unsqueeze(-1)):
                x_curr = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x_curr.squeeze(-1)]),
                        obs=Vindex(sequences)[batch, t],
                    )
                # Carry the state forward so the next step's transition row is
                # selected by it; otherwise every step would condition on the
                # fixed initial state 0 and the chain would not be Markov.
                x_prev = x_curr
def model(self):
    """
    **Generative Model**

    Samples the global sites (``gain``, ``init``, ``trans``, ``lamda``,
    ``proximity``), then, inside the ``channels`` and ``aois`` plates and
    under the data mask, the background statistics. For every frame it then
    samples the background intensity, the hidden state ``z``, the index
    ``theta``, the per-spot variables (``m``, ``height``, ``width``, ``x``,
    ``y``) and finally scores the observed data with ``KSMOGN``.
    """
    priors = self.priors
    n_states = self.S + 1
    # Lower / upper bound used for both the x and y position distributions.
    pos_min = -(self.data.P + 1) / 2
    pos_max = (self.data.P + 1) / 2

    # global parameters
    gain = pyro.sample("gain", dist.HalfNormal(priors["gain_std"]))

    init_prior = dist.Dirichlet(torch.ones(self.Q, n_states) / n_states)
    init = expand_offtarget(pyro.sample("init", init_prior.to_event(1)))

    trans_prior = dist.Dirichlet(
        torch.ones(self.Q, n_states, n_states) / n_states
    )
    trans = expand_offtarget(pyro.sample("trans", trans_prior.to_event(2)))

    lamda_rates = torch.full((self.Q,), priors["lamda_rate"])
    lamda = pyro.sample("lamda", dist.Exponential(lamda_rates).to_event(1))

    proximity = pyro.sample(
        "proximity", dist.Exponential(priors["proximity_rate"])
    )
    size = torch.stack(
        (
            torch.full_like(proximity, 2.0),
            ((self.data.P + 1) / (2 * proximity)) ** 2 - 1,
        ),
        dim=-1,
    )

    # spots
    spots = pyro.plate("spots", self.K)
    # aoi sites
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-3,
    )
    # time frames
    if self.vectorized:
        frames = pyro.vectorized_markov(name="frames", size=self.data.F, dim=-2)
    else:
        frames = pyro.markov(range(self.data.F))
    # color channels
    channels = pyro.plate("channels", self.data.C, dim=-1)

    with channels as cdx, aois as ndx:
        ndx = ndx[:, None, None]
        data_mask = Vindex(self.data.mask)[ndx].to(self.device)
        with handlers.mask(mask=data_mask):
            # background mean and std
            background_mean = pyro.sample(
                "background_mean",
                dist.HalfNormal(priors["background_mean_std"]),
            )
            background_std = pyro.sample(
                "background_std",
                dist.HalfNormal(priors["background_std_std"]),
            )

            z_prev = None
            for frame in frames:
                if self.vectorized:
                    # The vectorized iterator yields a pair: the first item
                    # is used in site names, the second to index the data.
                    fsx, frame_index = frame
                    fdx = torch.as_tensor(frame_index).unsqueeze(-1)
                else:
                    fsx = fdx = frame

                # fetch data
                obs, target_locs, is_ontarget = self.data.fetch(ndx, fdx, cdx)
                ontarget = is_ontarget.long()

                # sample background intensity
                bg_concentration = (background_mean / background_std) ** 2
                bg_rate = background_mean / background_std**2
                background = pyro.sample(
                    f"background_f{fsx}",
                    dist.Gamma(bg_concentration, bg_rate),
                )

                # sample hidden model state (1+S,)
                if z_prev is None:
                    z_probs = Vindex(init)[..., cdx, :, ontarget]
                else:
                    z_probs = Vindex(trans)[..., cdx, z_prev, :, ontarget]
                z_curr = pyro.sample(f"z_f{fsx}", dist.Categorical(z_probs))

                theta_table = probs_theta(self.K, self.device)
                theta_probs = Vindex(theta_table)[
                    torch.clamp(z_curr, min=0, max=1)
                ]
                theta = pyro.sample(
                    f"theta_f{fsx}",
                    dist.Categorical(theta_probs),
                    infer={"enumerate": "parallel"},
                )
                onehot_theta = one_hot(theta, num_classes=1 + self.K)

                presences = []
                spot_heights = []
                spot_widths = []
                spot_xs = []
                spot_ys = []
                for kdx in spots:
                    specific = onehot_theta[..., 1 + kdx]

                    # spot presence
                    m_probs = Vindex(probs_m(lamda, self.K))[
                        ..., cdx, theta, kdx
                    ]
                    presence_probs = torch.stack((1 - m_probs, m_probs), -1)
                    m = pyro.sample(
                        f"m_k{kdx}_f{fsx}", dist.Categorical(presence_probs)
                    )

                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        height = pyro.sample(
                            f"height_k{kdx}_f{fsx}",
                            dist.HalfNormal(priors["height_std"]),
                        )
                        width = pyro.sample(
                            f"width_k{kdx}_f{fsx}",
                            AffineBeta(
                                1.5,
                                2,
                                priors["width_min"],
                                priors["width_max"],
                            ),
                        )
                        x_pos = pyro.sample(
                            f"x_k{kdx}_f{fsx}",
                            AffineBeta(
                                0, Vindex(size)[..., specific], pos_min, pos_max
                            ),
                        )
                        y_pos = pyro.sample(
                            f"y_k{kdx}_f{fsx}",
                            AffineBeta(
                                0, Vindex(size)[..., specific], pos_min, pos_max
                            ),
                        )

                    # append
                    presences.append(m)
                    spot_heights.append(height)
                    spot_widths.append(width)
                    spot_xs.append(x_pos)
                    spot_ys.append(y_pos)

                # observed data
                likelihood = KSMOGN(
                    torch.stack(spot_heights, -1),
                    torch.stack(spot_widths, -1),
                    torch.stack(spot_xs, -1),
                    torch.stack(spot_ys, -1),
                    target_locs,
                    background,
                    gain,
                    self.data.offset.samples,
                    self.data.offset.logits.to(self.dtype),
                    self.data.P,
                    torch.stack(torch.broadcast_tensors(*presences), -1),
                    use_pykeops=self.use_pykeops,
                )
                pyro.sample(f"data_f{fsx}", likelihood, obs=obs)

                z_prev = z_curr
def guide(self):
    """
    **Variational Distribution**

    Samples every latent site from a distribution whose parameters are read
    with ``pyro.param``. Global sites come first; the per-AOI, per-frame,
    per-channel and per-spot sites follow inside the ``channels`` / ``aois``
    plates and under the data mask. ``z`` and ``m`` are marked for parallel
    enumeration.
    """
    # Lower / upper bound used for both the x and y position distributions.
    pos_min = -(self.data.P + 1) / 2
    pos_max = (self.data.P + 1) / 2

    # global parameters
    gain_concentration = pyro.param("gain_loc") * pyro.param("gain_beta")
    pyro.sample("gain", dist.Gamma(gain_concentration, pyro.param("gain_beta")))

    init_concentration = pyro.param("init_mean") * pyro.param("init_size")
    pyro.sample("init", dist.Dirichlet(init_concentration).to_event(1))

    trans_concentration = pyro.param("trans_mean") * pyro.param("trans_size")
    pyro.sample("trans", dist.Dirichlet(trans_concentration).to_event(2))

    lamda_concentration = pyro.param("lamda_loc") * pyro.param("lamda_beta")
    pyro.sample(
        "lamda",
        dist.Gamma(lamda_concentration, pyro.param("lamda_beta")).to_event(1),
    )

    pyro.sample(
        "proximity",
        AffineBeta(
            pyro.param("proximity_loc"),
            pyro.param("proximity_size"),
            0,
            (self.data.P + 1) / math.sqrt(12),
        ),
    )

    # spots
    spots = pyro.plate("spots", self.K)
    # aoi sites
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-3,
    )
    # time frames
    if self.vectorized:
        frames = pyro.vectorized_markov(name="frames", size=self.data.F, dim=-2)
    else:
        frames = pyro.markov(range(self.data.F))
    # color channels
    channels = pyro.plate("channels", self.data.C, dim=-1)

    with channels as cdx, aois as ndx:
        ndx = ndx[:, None, None]
        data_mask = Vindex(self.data.mask)[ndx].to(self.device)
        with handlers.mask(mask=data_mask):
            static_idx = (ndx, 0, cdx)
            pyro.sample(
                "background_mean",
                dist.Delta(Vindex(pyro.param("background_mean_loc"))[static_idx]),
            )
            pyro.sample(
                "background_std",
                dist.Delta(Vindex(pyro.param("background_std_loc"))[static_idx]),
            )

            z_prev = None
            for frame in frames:
                if self.vectorized:
                    # The vectorized iterator yields a pair: the first item
                    # is used in site names, the second to index the params.
                    fsx, frame_index = frame
                    fdx = torch.as_tensor(frame_index).unsqueeze(-1)
                else:
                    fsx = fdx = frame

                frame_idx = (ndx, fdx, cdx)

                # sample background intensity
                b_concentration = (
                    Vindex(pyro.param("b_loc"))[frame_idx]
                    * Vindex(pyro.param("b_beta"))[frame_idx]
                )
                b_rate = Vindex(pyro.param("b_beta"))[frame_idx]
                pyro.sample(
                    f"background_f{fsx}", dist.Gamma(b_concentration, b_rate)
                )

                # sample hidden model state; without a previous state the
                # row at index 0 of "z_trans" is used.
                z_from = 0 if z_prev is None else z_prev
                z_probs = Vindex(pyro.param("z_trans"))[frame_idx + (z_from,)]
                z_curr = pyro.sample(
                    f"z_f{fsx}",
                    dist.Categorical(z_probs),
                    infer={"enumerate": "parallel"},
                )

                for kdx in spots:
                    spot_idx = (kdx,) + frame_idx

                    # spot presence
                    m_probs = Vindex(pyro.param("m_probs"))[(z_curr,) + spot_idx]
                    presence_probs = torch.stack((1 - m_probs, m_probs), -1)
                    m = pyro.sample(
                        f"m_k{kdx}_f{fsx}",
                        dist.Categorical(presence_probs),
                        infer={"enumerate": "parallel"},
                    )

                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        h_concentration = (
                            Vindex(pyro.param("h_loc"))[spot_idx]
                            * Vindex(pyro.param("h_beta"))[spot_idx]
                        )
                        h_rate = Vindex(pyro.param("h_beta"))[spot_idx]
                        pyro.sample(
                            f"height_k{kdx}_f{fsx}",
                            dist.Gamma(h_concentration, h_rate),
                        )
                        pyro.sample(
                            f"width_k{kdx}_f{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("w_mean"))[spot_idx],
                                Vindex(pyro.param("w_size"))[spot_idx],
                                self.priors["width_min"],
                                self.priors["width_max"],
                            ),
                        )
                        pyro.sample(
                            f"x_k{kdx}_f{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("x_mean"))[spot_idx],
                                Vindex(pyro.param("size"))[spot_idx],
                                pos_min,
                                pos_max,
                            ),
                        )
                        pyro.sample(
                            f"y_k{kdx}_f{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("y_mean"))[spot_idx],
                                Vindex(pyro.param("size"))[spot_idx],
                                pos_min,
                                pos_max,
                            ),
                        )

                z_prev = z_curr
def guide(self):
    """
    **Variational Distribution**

    Samples every latent site from a distribution whose parameters are read
    with ``pyro.param``. Global sites come first; the per-AOI, per-frame and
    per-spot sites follow inside the ``aois`` plate. ``z`` and ``m`` are
    marked for parallel enumeration.
    """
    # Lower / upper bound used for both the x and y position distributions.
    pos_min = -(self.data.P + 1) / 2
    pos_max = (self.data.P + 1) / 2

    # global parameters
    gain_concentration = pyro.param("gain_loc") * pyro.param("gain_beta")
    pyro.sample("gain", dist.Gamma(gain_concentration, pyro.param("gain_beta")))

    init_concentration = pyro.param("init_mean") * pyro.param("init_size")
    pyro.sample("init", dist.Dirichlet(init_concentration))

    trans_concentration = pyro.param("trans_mean") * pyro.param("trans_size")
    pyro.sample("trans", dist.Dirichlet(trans_concentration).to_event(1))

    lamda_concentration = pyro.param("lamda_loc") * pyro.param("lamda_beta")
    pyro.sample(
        "lamda", dist.Gamma(lamda_concentration, pyro.param("lamda_beta"))
    )

    pyro.sample(
        "proximity",
        AffineBeta(
            pyro.param("proximity_loc"),
            pyro.param("proximity_size"),
            0,
            (self.data.P + 1) / math.sqrt(12),
        ),
    )

    # spots
    spots = pyro.plate("spots", self.K)
    # aoi sites
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-2,
    )
    # time frames
    if self.vectorized:
        frames = pyro.vectorized_markov(name="frames", size=self.data.F, dim=-1)
    else:
        frames = pyro.markov(range(self.data.F))

    with aois as ndx:
        ndx = ndx[:, None]
        static_idx = (ndx, 0)
        pyro.sample(
            "background_mean",
            dist.Delta(Vindex(pyro.param("background_mean_loc"))[static_idx]),
        )
        pyro.sample(
            "background_std",
            dist.Delta(Vindex(pyro.param("background_std_loc"))[static_idx]),
        )

        z_prev = None
        for frame in frames:
            if self.vectorized:
                # The vectorized iterator yields a pair: the first item is
                # used in site names, the second to index the params.
                fsx, fdx = frame
            else:
                fsx = fdx = frame

            frame_idx = (ndx, fdx)

            # sample background intensity
            b_concentration = (
                Vindex(pyro.param("b_loc"))[frame_idx]
                * Vindex(pyro.param("b_beta"))[frame_idx]
            )
            b_rate = Vindex(pyro.param("b_beta"))[frame_idx]
            pyro.sample(f"background_{fsx}", dist.Gamma(b_concentration, b_rate))

            # sample hidden model state; a plain integer frame index below 1
            # uses the row at index 0 of "z_trans".
            if isinstance(fdx, int) and fdx < 1:
                z_from = 0
            else:
                z_from = z_prev
            z_probs = Vindex(pyro.param("z_trans"))[frame_idx + (z_from,)]
            z_curr = pyro.sample(
                f"z_{fsx}",
                dist.Categorical(z_probs),
                infer={"enumerate": "parallel"},
            )

            for kdx in spots:
                spot_idx = (kdx,) + frame_idx

                # spot presence
                m_probs = Vindex(pyro.param("m_probs"))[(z_curr,) + spot_idx]
                presence_probs = torch.stack((1 - m_probs, m_probs), -1)
                m = pyro.sample(
                    f"m_{kdx}_{fsx}",
                    dist.Categorical(presence_probs),
                    infer={"enumerate": "parallel"},
                )

                with handlers.mask(mask=m > 0):
                    # sample spot variables
                    h_concentration = (
                        Vindex(pyro.param("h_loc"))[spot_idx]
                        * Vindex(pyro.param("h_beta"))[spot_idx]
                    )
                    h_rate = Vindex(pyro.param("h_beta"))[spot_idx]
                    pyro.sample(
                        f"height_{kdx}_{fsx}",
                        dist.Gamma(h_concentration, h_rate),
                    )
                    pyro.sample(
                        f"width_{kdx}_{fsx}",
                        AffineBeta(
                            Vindex(pyro.param("w_mean"))[spot_idx],
                            Vindex(pyro.param("w_size"))[spot_idx],
                            0.75,
                            2.25,
                        ),
                    )
                    pyro.sample(
                        f"x_{kdx}_{fsx}",
                        AffineBeta(
                            Vindex(pyro.param("x_mean"))[spot_idx],
                            Vindex(pyro.param("size"))[spot_idx],
                            pos_min,
                            pos_max,
                        ),
                    )
                    pyro.sample(
                        f"y_{kdx}_{fsx}",
                        AffineBeta(
                            Vindex(pyro.param("y_mean"))[spot_idx],
                            Vindex(pyro.param("size"))[spot_idx],
                            pos_min,
                            pos_max,
                        ),
                    )

            z_prev = z_curr