def model():
    """Mixture model with fixed mixing weights and learnable component means.

    Each element of ``data`` (taken from the enclosing scope) receives a
    categorical component assignment ``x`` and is observed under a unit-scale
    Normal centred on that component's mean.
    """
    component_means = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
    mixing_weights = torch.tensor([0.2, 0.3, 0.5])
    num_obs = len(data)
    with pyro.plate("plate", num_obs, dim=-1):
        assignment = pyro.sample("x", dist.Categorical(mixing_weights))
        likelihood = dist.Normal(component_means[assignment], 1.)
        pyro.sample("obs", likelihood, obs=data)
def guide():
    """Guide drawing one categorical assignment ``x`` per element of ``data``.

    The class probabilities come from the learnable parameter ``"p"``.
    NOTE(review): ``"p"`` is registered without a simplex constraint, so
    optimisation is free to move it off the simplex — confirm intentional.
    """
    class_probs = pyro.param("p", torch.tensor([0.5, 0.3, 0.2]))
    num_obs = len(data)
    with pyro.plate("plate", num_obs, dim=-1):
        pyro.sample("x", dist.Categorical(class_probs))
def model(data):
    """Observe ``data`` under a Bernoulli with learnable success probability ``"p"``."""
    success_prob = pyro.param("p", torch.tensor(0.5))
    likelihood = dist.Bernoulli(success_prob)
    pyro.sample("x", likelihood, obs=data)
def model(data):
    """Observe ``data`` under a unit-scale Normal with learnable mean ``"loc"``."""
    mean = pyro.param("loc", torch.tensor(0.0))
    likelihood = dist.Normal(mean, 1.)
    pyro.sample("x", likelihood, obs=data)
def guide():
    """Two-level Gaussian guide: ``x`` is centred on a draw of ``y``.

    ``y`` is centred on the learnable parameter ``"loc"``; both levels use
    unit scale.
    """
    mean = pyro.param("loc", torch.tensor(0.))
    y_draw = pyro.sample("y", dist.Normal(mean, 1.))
    pyro.sample("x", dist.Normal(y_draw, 1.))
def test_elbo_enumerate_plate_7(backend):
    #  Guide    Model
    #                  a -----> b
    #                  |        |
    #                +-|--------|----------------+
    #                | V        V                |
    #                | c -----> d -----> e   N=2 |
    #                +---------------------------+
    # This tests a mixture of model and guide enumeration.
    #
    # The auto_* pair uses a vectorized plate; the hand_* pair unrolls the
    # same plate into a Python loop. Both losses and their gradients must
    # agree.
    with pyro_backend(backend):
        pyro.param("model_probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_c",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_d",
                   torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_e",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        # Initial value must lie on the simplex. It previously read
        # [0.35, 0.64]; the stick-breaking inverse only uses the first
        # K-1 entries, so the stored value was already [0.35, 0.65].
        pyro.param("guide_probs_a",
                   torch.tensor([0.35, 0.65]),
                   constraint=constraints.simplex)
        pyro.param("guide_probs_c",
                   torch.tensor([[0., 1.], [1., 0.]]),  # deterministic
                   constraint=constraints.simplex)

        def auto_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b", dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c[a]))
                d = pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data)

        def auto_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a", dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                pyro.sample("c", dist.Categorical(probs_c[a]))

        def hand_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b", dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))
                d = pyro.sample("d_{}".format(i),
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs_{}".format(i), dist.Categorical(probs_e[d]), obs=data[i])

        def hand_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a", dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))

        data = torch.tensor([0, 0])
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, auto_guide, data)
        # The hand-unrolled version has no plates, hence max_plate_nesting=0.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, hand_guide, data)
        _check_loss_and_grads(hand_loss, auto_loss)
def _init_parameters(self):
    """
    Parameters shared between different models.

    Registers the variational parameters common to all models in Pyro's
    param store. Each initial value is supplied as a callable that builds
    the tensor on ``self.device``.

    Initial values are written as float literals on purpose. An integer
    literal such as ``torch.tensor(100)`` or ``torch.full(shape, 100)``
    creates an ``int64`` tensor. That tensor only becomes floating point
    through implicit type promotion inside the constraint's inverse
    transform.
    """
    device = self.device
    data = self.data
    # Open-interval bounds are shrunk by the machine epsilon of self.dtype.
    eps = torch.finfo(self.dtype).eps
    # x_mean / y_mean live in (-(P + 1) / 2, (P + 1) / 2).
    half_width = (data.P + 1) / 2
    # Shape shared by all per-spot parameters below.
    spot_shape = (self.K, data.Nt, data.F, self.Q)

    pyro.param(
        "proximity_loc",
        lambda: torch.tensor(0.5, device=device),
        constraint=constraints.interval(
            0,
            (data.P + 1) / math.sqrt(12) - eps,
        ),
    )
    pyro.param(
        "proximity_size",
        lambda: torch.tensor(100.0, device=device),
        constraint=constraints.greater_than(2.0),
    )
    pyro.param(
        "lamda_loc",
        lambda: torch.full((self.Q,), 0.5, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "lamda_beta",
        lambda: torch.full((self.Q,), 100.0, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "gain_loc",
        lambda: torch.tensor(5.0, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "gain_beta",
        lambda: torch.tensor(100.0, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "background_mean_loc",
        lambda: (data.median.to(device) - data.offset.mean).expand(
            data.Nt, 1, data.C
        ),
        constraint=constraints.positive,
    )
    pyro.param(
        "background_std_loc",
        lambda: torch.ones(data.Nt, 1, data.C, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "b_loc",
        lambda: (data.median.to(device) - data.offset.mean).expand(
            data.Nt, data.F, data.C
        ),
        constraint=constraints.positive,
    )
    pyro.param(
        "b_beta",
        lambda: torch.ones(data.Nt, data.F, data.C, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "h_loc",
        lambda: torch.full(spot_shape, 2000.0, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "h_beta",
        lambda: torch.full(spot_shape, 0.001, device=device),
        constraint=constraints.positive,
    )
    pyro.param(
        "w_mean",
        lambda: torch.full(spot_shape, 1.5, device=device),
        constraint=constraints.interval(
            0.75 + eps,
            2.25 - eps,
        ),
    )
    pyro.param(
        "w_size",
        lambda: torch.full(spot_shape, 100.0, device=device),
        constraint=constraints.greater_than(2.0),
    )
    pyro.param(
        "x_mean",
        lambda: torch.zeros(spot_shape, device=device),
        constraint=constraints.interval(
            -half_width + eps,
            half_width - eps,
        ),
    )
    pyro.param(
        "y_mean",
        lambda: torch.zeros(spot_shape, device=device),
        constraint=constraints.interval(
            -half_width + eps,
            half_width - eps,
        ),
    )
    pyro.param(
        "size",
        lambda: torch.full(spot_shape, 200.0, device=device),
        constraint=constraints.greater_than(2.0),
    )
def guide():
    """Guide with a per-datum categorical whose probabilities are a plated param.

    ``"p"`` is declared inside the plate with ``event_dim=1``, so its trailing
    size-3 dimension is treated as the event dimension. Returns the parameter.
    """
    num_obs = len(data)
    with pyro.plate("plate", num_obs, dim=-1):
        per_datum_probs = pyro.param("p", torch.ones(num_obs, 3) / 3, event_dim=1)
        pyro.sample("x", dist.Categorical(per_datum_probs))
    return per_datum_probs
def model(z=None):
    """Two-class model with an optionally observed class label ``z``.

    ``x`` is observed as ``data`` under a 3-element plate. Masking uses
    ``mask``, and both ``data`` and ``mask`` come from the enclosing scope.
    The Normal mean is the class label cast to ``data``'s dtype.
    """
    class_weights = pyro.param("p", torch.tensor([0.75, 0.25]))
    z = pyro.sample("z", dist.Categorical(class_weights), obs=z)
    logger.info("z.shape = {}".format(z.shape))
    mean = z.type_as(data)
    with pyro.plate("data", 3), handlers.mask(mask=mask):
        pyro.sample("x", dist.Normal(mean, 1.0), obs=data)
def guide(self):
    r"""
    **Variational Distribution**

    Samples the guide's latent sites from the factorised variational
    distribution below. All variational parameters are read from Pyro's
    param store via ``pyro.param``.

    .. math::
        \begin{aligned}
            q(\phi \setminus \{z, \theta\}) =~&q(g) q(\sigma^{xy}) q(\pi) q(\lambda) \cdot \\
            &\prod_{\mathsf{AOI}} \left[ q(\mu^b) q(\sigma^b) \prod_{\mathsf{frame}}
            \left[ \vphantom{\prod_{F}} q(b) \prod_{\mathsf{spot}}
            q(m) q(h | m) q(w | m) q(x | m) q(y | m) \right] \right]
        \end{aligned}
    """
    # global parameters
    # Gamma(loc * beta, beta) has mean ``loc``; ``beta`` is the rate.
    pyro.sample(
        "gain",
        dist.Gamma(
            pyro.param("gain_loc") * pyro.param("gain_beta"),
            pyro.param("gain_beta"),
        ),
    )
    # Dirichlet concentration = mean * size (assumes ``pi_mean`` lies on the
    # simplex — confirm where it is initialised).
    pyro.sample(
        "pi",
        dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size")).to_event(1),
    )
    # Same mean/rate Gamma parameterisation; to_event(1) treats the trailing
    # dimension as one event.
    pyro.sample(
        "lamda",
        dist.Gamma(
            pyro.param("lamda_loc") * pyro.param("lamda_beta"),
            pyro.param("lamda_beta"),
        ).to_event(1),
    )
    # NOTE(review): assumes AffineBeta(mean, size, low, high), i.e. a Beta
    # rescaled to [0, (P + 1) / sqrt(12)] — confirm against its definition.
    pyro.sample(
        "proximity",
        AffineBeta(
            pyro.param("proximity_loc"),
            pyro.param("proximity_size"),
            0,
            (self.data.P + 1) / math.sqrt(12),
        ),
    )

    # spots
    spots = pyro.plate("spots", self.K)
    # aoi sites (subsampled with indices self.n, batch size self.nbatch_size)
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-3,
    )
    # time frames (subsampled with indices self.f, batch size self.fbatch_size)
    frames = pyro.plate(
        "frames",
        self.data.F,
        subsample=self.f,
        subsample_size=self.fbatch_size,
        dim=-2,
    )
    # color channels
    channels = pyro.plate(
        "channels",
        self.data.C,
        dim=-1,
    )

    with channels as cdx, aois as ndx:
        # Add two trailing singleton dims so the AOI indices broadcast against
        # the frame (dim -2) and channel (dim -1) indices used below.
        ndx = ndx[:, None, None]
        # Per-AOI mask; masked-out AOIs are excluded from every site below.
        mask = Vindex(self.data.mask)[ndx].to(self.device)
        with handlers.mask(mask=mask):
            # Point-mass (Delta) guides for the per-AOI background
            # hyperparameters (the middle index 0 selects the singleton
            # frame axis of these params).
            pyro.sample(
                "background_mean",
                dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0, cdx]),
            )
            pyro.sample(
                "background_std",
                dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]),
            )

            with frames as fdx:
                # Add a trailing singleton dim so frame indices broadcast
                # against the channel indices.
                fdx = fdx[:, None]
                # sample background intensity
                # (Gamma with mean b_loc and rate b_beta)
                pyro.sample(
                    "background",
                    dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx, cdx]
                        * Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                    ),
                )

                # Iterating the plate handles spots one at a time, with
                # separately named sites per spot index.
                for kdx in spots:
                    # sample spot presence m
                    # Marked for parallel enumeration rather than sampling.
                    m = pyro.sample(
                        f"m_k{kdx}",
                        dist.Bernoulli(
                            Vindex(pyro.param("m_probs"))[kdx, ndx, fdx, cdx]
                        ),
                        infer={"enumerate": "parallel"},
                    )
                    # Spot variables only contribute where the spot is present.
                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        pyro.sample(
                            f"height_k{kdx}",
                            dist.Gamma(
                                Vindex(pyro.param("h_loc"))[kdx, ndx, fdx, cdx]
                                * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                            ),
                        )
                        pyro.sample(
                            f"width_k{kdx}",
                            AffineBeta(
                                Vindex(pyro.param("w_mean"))[kdx, ndx, fdx, cdx],
                                Vindex(pyro.param("w_size"))[kdx, ndx, fdx, cdx],
                                self.priors["width_min"],
                                self.priors["width_max"],
                            ),
                        )
                        # Spot position, bounded to +/- (P + 1) / 2.
                        pyro.sample(
                            f"x_k{kdx}",
                            AffineBeta(
                                Vindex(pyro.param("x_mean"))[kdx, ndx, fdx, cdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
                        pyro.sample(
                            f"y_k{kdx}",
                            AffineBeta(
                                Vindex(pyro.param("y_mean"))[kdx, ndx, fdx, cdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
def model(z2=None):
    """Two-component Gaussian mixture over a 2-element plate.

    ``z2`` optionally observes the component assignments. ``x2`` is observed
    as ``data``, which comes from the enclosing scope.
    """
    mixture_weights = pyro.param("p", torch.tensor([0.25, 0.75]))
    component_locs = pyro.param("loc", torch.tensor([-1.0, 1.0]))
    with pyro.plate("data", 2):
        assignment = pyro.sample("z2", dist.Categorical(mixture_weights), obs=z2)
        pyro.sample("x2", dist.Normal(component_locs[assignment], 1.0), obs=data)
def ttfb_model(data, control, Tmax):
    r"""
    Time-to-first-binding model, Eq. 4 and Eq. 7 in::

        @article{friedman2015multi,
          title={Multi-wavelength single-molecule fluorescence analysis of transcription mechanisms},
          author={Friedman, Larry J and Gelles, Jeff},
          journal={Methods},
          volume={86},
          pages={27--36},
          year={2015},
          publisher={Elsevier}
        }

    Each row of ``data`` (and of ``control``) is one bootstrap sample with
    its own ``ka``, ``kns`` and ``Af`` parameters. Per entry:

    * Entries equal to ``Tmax`` are censored. They contribute ``-k * Tmax``,
      the log survival probability of an ``Exponential(k)`` at ``Tmax``.
    * Entries in ``(0, Tmax)`` are observed ``Exponential(k)`` waiting times.
    * Any other entries contribute nothing.

    On target, each entry's activity state is Bernoulli(``Af``) and is
    enumerated in parallel. The rate is ``ka + kns`` when active and ``kns``
    otherwise. The negative control uses ``kns`` only.

    :param data: time prior to the first binding at the target location
    :param control: time prior to the first binding at the control location
    :param Tmax: entire observation interval
    """
    num_boot = data.shape[0]

    def _per_row(value):
        # Lazy initializer: one value per row, with a trailing singleton dim.
        return lambda: torch.full((num_boot, 1), value)

    ka = pyro.param("ka", _per_row(0.001), constraint=constraints.positive)
    kns = pyro.param("kns", _per_row(0.001), constraint=constraints.positive)
    Af = pyro.param("Af", _per_row(0.9), constraint=constraints.unit_interval)
    # Row 0: nonspecific rate only; row 1: specific + nonspecific.
    k = torch.stack([kns, ka + kns])

    # on-target data
    observed = (data < Tmax) & (data > 0)
    # Placeholder for non-observed entries; they are masked out below.
    tau = data.masked_fill(~observed, 1.0)
    with pyro.plate("bootstrap", num_boot, dim=-2) as bdx:
        with pyro.plate("N", data.shape[1], dim=-1):
            active = pyro.sample(
                "active", dist.Bernoulli(Af), infer={"enumerate": "parallel"}
            )
            # Rate matching each (possibly enumerated) activity state.
            rate = Vindex(k)[active.long().squeeze(-1), bdx]
            with handlers.mask(mask=(data == Tmax)):
                pyro.factor("Tmax", -rate * Tmax)
            with handlers.mask(mask=observed):
                pyro.sample("tau", dist.Exponential(rate), obs=tau)

    # negative control data
    if control is None:
        return

    control_observed = (control < Tmax) & (control > 0)
    tauc = control.masked_fill(~control_observed, 1.0)
    with pyro.plate("bootstrapc", control.shape[0], dim=-2):
        with pyro.plate("Nc", control.shape[1], dim=-1):
            with handlers.mask(mask=(control == Tmax)):
                pyro.factor("Tmaxc", -kns * Tmax)
            with handlers.mask(mask=control_observed):
                pyro.sample("tauc", dist.Exponential(kns), obs=tauc)
def model():
    """Draw and return ``x`` from a Normal with learnable ``"loc"`` and ``"scale"``."""
    mu = pyro.param("loc", torch.tensor(2.0))
    sigma = pyro.param("scale", torch.tensor(1.0))
    return pyro.sample("x", dist.Normal(mu, sigma))
def guide():
    """Normal guide for ``x`` over a plate sized by the last dim of ``data``.

    ``data`` comes from the enclosing scope.
    """
    mu = pyro.param("loc", torch.tensor(0.))
    sigma = pyro.param("scale", torch.tensor(1.))
    plate_size = data.size(-1)
    with pyro.plate("plate_outer", plate_size, dim=-1):
        pyro.sample("x", dist.Normal(mu, sigma))
def m_probs(self) -> torch.Tensor:
    r"""
    Posterior spot presence probability :math:`q(m=1)`.

    Returns the current value of the ``"m_probs"`` parameter through
    ``.data``, i.e. without autograd tracking.
    """
    presence = pyro.param("m_probs")
    return presence.data
def model():
    """Mixture with uniform weights over three components and learnable means.

    Each element of ``data`` (from the enclosing scope) gets a component
    assignment ``x`` and is observed under a unit-scale Normal.
    """
    component_means = pyro.param("locs", torch.tensor([-1., 0., 1.]))
    num_obs = len(data)
    uniform_weights = torch.ones(3) / 3
    with pyro.plate("plate", num_obs, dim=-1):
        assignment = pyro.sample("x", dist.Categorical(uniform_weights))
        pyro.sample("obs", dist.Normal(component_means[assignment], 1.), obs=data)
def save_checkpoint(self, writer: SummaryWriter = None):
    """
    Save checkpoint.

    Steps, in order:

    1. Verify that every parameter in the param store is finite.
    2. Append the current values of ``self.conv_params`` to the rolling
       windows in ``self._rolling`` and update ``self.converged``.
    3. Write the model state to ``self.run_path / f"{self.name}-model.tpqr"``.
    4. When ``writer`` is given, log scalars to it.

    :param writer: SummaryWriter object. Optional; when ``None`` the
        TensorBoard logging step is skipped.
    :raises ValueError: if any parameter contains NaN or infinite values;
        nothing is saved in that case.
    """
    # save only if all parameter values are finite
    for k, v in pyro.get_param_store().items():
        if not torch.isfinite(v).all():
            raise ValueError(
                "Iteration #{}. Detected NaN or Inf values in {}".format(self.iter, k)
            )

    # update convergence criteria parameters
    for name in self.conv_params:
        if name == "-ELBO":
            self._rolling["-ELBO"].append(self.iter_loss)
            continue
        param = pyro.param(name)
        if param.ndim == 1:
            # Track each element of a vector parameter as its own series.
            for i, value in enumerate(param.tolist()):
                self._rolling[f"{name}_{i}"].append(value)
        else:
            self._rolling[name].append(param.item())

    # check convergence status: once the -ELBO window is full, converged when
    # for every tracked series the std over the whole window divided by the
    # std over its last 50 entries is below 1.05
    elbo_window = self._rolling["-ELBO"]
    self.converged = False
    if len(elbo_window) == elbo_window.maxlen:
        self.converged = all(
            torch.tensor(value).std() / torch.tensor(value)[-50:].std() < 1.05
            for value in self._rolling.values()
        )

    # save the model state
    torch.save(
        {
            "iter": self.iter,
            "params": pyro.get_param_store().get_state(),
            "optimizer": self.optim.get_state(),
            "rolling": dict(self._rolling),
            "convergence_status": self.converged,
        },
        self.run_path / f"{self.name}-model.tpqr",
    )

    # save global parameters for tensorboard
    # (writer defaults to None, so guard before dereferencing it)
    if writer is not None:
        writer.add_scalar("-ELBO", self.iter_loss, self.iter)
        for name, val in pyro.get_param_store().items():
            if val.dim() == 0:
                writer.add_scalar(name, val.item(), self.iter)
            elif val.dim() == 1 and len(val) <= self.S + 1:
                scalars = {str(i): v.item() for i, v in enumerate(val)}
                writer.add_scalars(name, scalars, self.iter)

        # NOTE(review): the leading ``False and`` disables this labelled-data
        # metrics block; it never runs as written.
        if False and self.data.labels is not None:
            pred_labels = (
                self.pspecific_map[self.data.is_ontarget].cpu().numpy().ravel()
            )
            true_labels = self.data.labels["z"].ravel()

            metrics = {}
            with np.errstate(divide="ignore", invalid="ignore"):
                metrics["MCC"] = matthews_corrcoef(true_labels, pred_labels)
            metrics["Recall"] = recall_score(true_labels, pred_labels, zero_division=0)
            metrics["Precision"] = precision_score(
                true_labels, pred_labels, zero_division=0
            )

            neg, pos = {}, {}
            neg["TN"], neg["FP"], pos["FN"], pos["TP"] = confusion_matrix(
                true_labels, pred_labels, labels=(0, 1)
            ).ravel()

            writer.add_scalars("ACCURACY", metrics, self.iter)
            writer.add_scalars("NEGATIVES", neg, self.iter)
            writer.add_scalars("POSITIVES", pos, self.iter)

    logger.debug(f"Iteration #{self.iter}: Successful.")
def guide():
    """Guide drawing ``x`` from a categorical with learnable probabilities ``"q"``.

    The random initial value is normalised with ``softmax`` so that it lies
    on the simplex required by the constraint. The previous unnormalised
    init, ``torch.randn(3).exp()``, was not a valid simplex point. The
    stick-breaking inverse transform then clamps it, which can yield a
    nearly one-hot starting distribution. The softmax form consumes the same
    random draw.
    """
    q = pyro.param("q", torch.randn(3).softmax(-1), constraint=constraints.simplex)
    pyro.sample("x", dist.Categorical(q))
def model(data=None):
    """Normal model with learnable ``"loc"``/``"scale"`` over a 1000-element plate.

    When ``data`` is given, ``x`` is observed; otherwise it is sampled. Returns
    the value of ``x``.
    """
    mu = pyro.param("loc", torch.tensor(2.0))
    sigma = pyro.param("scale", torch.tensor(1.0))
    likelihood = dist.Normal(mu, sigma)
    with pyro.plate("data", 1000, dim=-1):
        draws = pyro.sample("x", likelihood, obs=data)
    return draws
def test_elbo_enumerate_plates_1(backend):
    #  +-----------------+
    #  | a ----> b   M=2 |
    #  +-----------------+
    #  +-----------------+
    #  | c ----> d   N=3 |
    #  +-----------------+
    # This tests two unrelated plates.
    # Each should remain uncontracted.
    #
    # A vectorized-plate model and a hand-unrolled model must give the same
    # loss and gradients under an empty guide.
    with pyro_backend(backend):
        initial_probs = {
            "probs_a": [0.45, 0.55],
            "probs_b": [[0.6, 0.4], [0.4, 0.6]],
            "probs_c": [0.75, 0.25],
            "probs_d": [[0.4, 0.6], [0.3, 0.7]],
        }
        for param_name, value in initial_probs.items():
            pyro.param(param_name, torch.tensor(value), constraint=constraints.simplex)
        b_data = torch.tensor([0, 1])
        d_data = torch.tensor([0, 0, 1])

        def load_probs():
            # Fetch the four parameters in a fixed order: a, b, c, d.
            return tuple(
                pyro.param(param_name)
                for param_name in ("probs_a", "probs_b", "probs_c", "probs_d")
            )

        def auto_model():
            probs_a, probs_b, probs_c, probs_d = load_probs()
            with pyro.plate("a_axis", 2, dim=-1):
                a = pyro.sample("a", dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data)
            with pyro.plate("c_axis", 3, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data)

        def hand_model():
            probs_a, probs_b, probs_c, probs_d = load_probs()
            for i, b_obs in enumerate(b_data):
                a = pyro.sample(f"a_{i}", dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]), obs=b_obs)
            for j, d_obs in enumerate(d_data):
                c = pyro.sample(f"c_{j}", dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample(f"d_{j}", dist.Categorical(probs_d[c]), obs=d_obs)

        def guide():
            pass

        def compute_loss(model, max_plate_nesting):
            elbo = infer.TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)
            if backend == "pyro":
                return elbo.differentiable_loss(model, guide)
            return elbo(model, guide)

        auto_loss = compute_loss(auto_model, 1)
        # The hand-unrolled model has no plates.
        hand_loss = compute_loss(hand_model, 0)
        _check_loss_and_grads(hand_loss, auto_loss)
def model():
    """Single-draw Gaussian mixture with fixed weights and learnable locs/scales.

    Means are unconstrained and scales are positive; both are initialised
    randomly. ``data`` comes from the enclosing scope.
    """
    component_locs = pyro.param("locs", torch.randn(3), constraint=constraints.real)
    component_scales = pyro.param(
        "scales", torch.randn(3).exp(), constraint=constraints.positive
    )
    weights = torch.tensor([0.5, 0.3, 0.2])
    k = pyro.sample("x", dist.Categorical(weights))
    likelihood = dist.Normal(component_locs[k], component_scales[k])
    pyro.sample("obs", likelihood, obs=data)