def model(self):
    """
    **Generative Model**

    Declares, in order: the global parameters, the per-AOI/per-channel
    background statistics and then, frame by frame, the hidden state ``z``
    (drawn from ``init`` on the first frame and from ``trans`` conditioned
    on the previous frame's state afterwards), the per-spot variables and
    the observed images.
    """
    priors = self.priors
    num_states = self.S + 1
    # Symmetric bound used for the x/y position priors.
    half_extent = (self.data.P + 1) / 2

    # ---- global parameters ----
    gain = pyro.sample("gain", dist.HalfNormal(priors["gain_std"]))

    init_prior = dist.Dirichlet(torch.ones(self.Q, num_states) / num_states)
    init = expand_offtarget(pyro.sample("init", init_prior.to_event(1)))

    trans_prior = dist.Dirichlet(
        torch.ones(self.Q, num_states, num_states) / num_states
    )
    trans = expand_offtarget(pyro.sample("trans", trans_prior.to_event(2)))

    lamda_rates = torch.full((self.Q,), priors["lamda_rate"])
    lamda = pyro.sample("lamda", dist.Exponential(lamda_rates).to_event(1))

    proximity = pyro.sample(
        "proximity", dist.Exponential(priors["proximity_rate"])
    )
    # Two candidate values for the second AffineBeta argument of the x/y
    # priors; they are selected below by the 0/1 ``specific`` indicator.
    size_off = torch.full_like(proximity, 2.0)
    size_on = ((self.data.P + 1) / (2 * proximity)) ** 2 - 1
    size = torch.stack((size_off, size_on), dim=-1)

    def position_prior(specific):
        # Shared prior for the x and y coordinates of one spot.
        return AffineBeta(
            0, Vindex(size)[..., specific], -half_extent, half_extent
        )

    # ---- plates / markov context ----
    spot_plate = pyro.plate("spots", self.K)
    aoi_plate = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-3,
    )
    if self.vectorized:
        frame_iter = pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-2
        )
    else:
        frame_iter = pyro.markov(range(self.data.F))
    channel_plate = pyro.plate("channels", self.data.C, dim=-1)

    with channel_plate as cdx, aoi_plate as ndx:
        # Add two trailing singleton dims to the AOI index.
        ndx = ndx[:, None, None]
        aoi_mask = Vindex(self.data.mask)[ndx].to(self.device)
        with handlers.mask(mask=aoi_mask):
            # ---- background statistics ----
            background_mean = pyro.sample(
                "background_mean",
                dist.HalfNormal(priors["background_mean_std"]),
            )
            background_std = pyro.sample(
                "background_std",
                dist.HalfNormal(priors["background_std_std"]),
            )

            z_prev = None
            for frame in frame_iter:
                # ``fsx`` labels the sample sites, ``fdx`` indexes the data.
                if self.vectorized:
                    fsx, fdx = frame
                    fdx = torch.as_tensor(fdx).unsqueeze(-1)
                else:
                    fsx = fdx = frame

                obs, target_locs, is_ontarget = self.data.fetch(
                    ndx, fdx, cdx
                )

                # ---- background intensity ----
                # Gamma written in terms of its mean and std.
                bg_concentration = (background_mean / background_std) ** 2
                bg_rate = background_mean / background_std**2
                background = pyro.sample(
                    f"background_f{fsx}",
                    dist.Gamma(bg_concentration, bg_rate),
                )

                # ---- hidden state (1+S,) ----
                if z_prev is None:
                    z_probs = Vindex(init)[..., cdx, :, is_ontarget.long()]
                else:
                    z_probs = Vindex(trans)[
                        ..., cdx, z_prev, :, is_ontarget.long()
                    ]
                z_curr = pyro.sample(f"z_f{fsx}", dist.Categorical(z_probs))

                theta_table = probs_theta(self.K, self.device)
                theta_probs = Vindex(theta_table)[
                    torch.clamp(z_curr, min=0, max=1)
                ]
                theta = pyro.sample(
                    f"theta_f{fsx}",
                    dist.Categorical(theta_probs),
                    infer={"enumerate": "parallel"},
                )
                onehot_theta = one_hot(theta, num_classes=1 + self.K)

                ms, heights, widths, xs, ys = [], [], [], [], []
                for kdx in spot_plate:
                    # 1 where theta selects spot ``kdx`` (index 1 + kdx).
                    specific = onehot_theta[..., 1 + kdx]

                    # ---- spot presence ----
                    m_probs = Vindex(probs_m(lamda, self.K))[
                        ..., cdx, theta, kdx
                    ]
                    presence_probs = torch.stack((1 - m_probs, m_probs), -1)
                    m = pyro.sample(
                        f"m_k{kdx}_f{fsx}",
                        dist.Categorical(presence_probs),
                    )

                    # ---- spot variables, masked where the spot is absent ----
                    with handlers.mask(mask=m > 0):
                        height = pyro.sample(
                            f"height_k{kdx}_f{fsx}",
                            dist.HalfNormal(priors["height_std"]),
                        )
                        width = pyro.sample(
                            f"width_k{kdx}_f{fsx}",
                            AffineBeta(
                                1.5,
                                2,
                                priors["width_min"],
                                priors["width_max"],
                            ),
                        )
                        x = pyro.sample(
                            f"x_k{kdx}_f{fsx}", position_prior(specific)
                        )
                        y = pyro.sample(
                            f"y_k{kdx}_f{fsx}", position_prior(specific)
                        )

                    ms.append(m)
                    heights.append(height)
                    widths.append(width)
                    xs.append(x)
                    ys.append(y)

                # ---- observed data ----
                image_dist = KSMOGN(
                    torch.stack(heights, -1),
                    torch.stack(widths, -1),
                    torch.stack(xs, -1),
                    torch.stack(ys, -1),
                    target_locs,
                    background,
                    gain,
                    self.data.offset.samples,
                    self.data.offset.logits.to(self.dtype),
                    self.data.P,
                    torch.stack(torch.broadcast_tensors(*ms), -1),
                    use_pykeops=self.use_pykeops,
                )
                pyro.sample(f"data_f{fsx}", image_dist, obs=obs)

                # The next frame's transition is conditioned on this state.
                z_prev = z_curr
def model(self):
    r"""
    Generative Model

    Declares, in order: the global parameters (``gain``, ``alpha``, ``pi``,
    ``lamda``, ``proximity``), the per-AOI background statistics, and,
    inside the ``frames`` plate, the background intensity, the hidden state
    ``z``/``theta`` for each of the ``Q`` indices, the ``K`` spots of each
    index and finally the observed images.

    Per-spot variables are collected in flat lists in ``(q, k)`` order and
    regrouped into tensors with trailing dimensions ``(Q, K)`` before being
    handed to ``KSMOGN``.
    """

    def _stack_by_q(flat):
        # ``flat`` holds Q * K entries filled with q as the outer and k as
        # the inner loop, so entry (q, k) lives at index q * K + k.  Stack
        # the K entries of each q on the last dim, then the Q groups on the
        # second-to-last dim.
        return torch.stack(
            [
                torch.stack(flat[q * self.K:(1 + q) * self.K], -1)
                for q in range(self.Q)
            ],
            -2,
        )

    # global parameters
    gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
    # Dirichlet concentration is 10 on the diagonal and 1 elsewhere.
    # NOTE: adding eye(Q) to a (Q, C) tensor only broadcasts when Q == C.
    alpha = pyro.sample(
        "alpha",
        dist.Dirichlet(
            torch.ones((self.Q, self.data.C)) + torch.eye(self.Q) * 9
        ).to_event(1),
    )
    pi = pyro.sample(
        "pi",
        dist.Dirichlet(
            torch.ones((self.Q, self.S + 1)) / (self.S + 1)
        ).to_event(1),
    )
    pi = expand_offtarget(pi)
    lamda = pyro.sample(
        "lamda",
        dist.Exponential(
            torch.full((self.Q,), self.priors["lamda_rate"])
        ).to_event(1),
    )
    proximity = pyro.sample(
        "proximity", dist.Exponential(self.priors["proximity_rate"])
    )
    # Two candidate values for the second AffineBeta argument of the x/y
    # priors; selected below by the 0/1 ``specific`` indicator.
    size = torch.stack(
        (
            torch.full_like(proximity, 2.0),
            (((self.data.P + 1) / (2 * proximity)) ** 2 - 1),
        ),
        dim=-1,
    )

    # aoi sites
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-2,
    )
    # time frames
    frames = pyro.plate(
        "frames",
        self.data.F,
        subsample=self.f,
        subsample_size=self.fbatch_size,
        dim=-1,
    )

    with aois as ndx:
        ndx = ndx[:, None]
        mask = Vindex(self.data.mask)[ndx].to(self.device)
        with handlers.mask(mask=mask):
            # background mean and std (one value per entry of data.C)
            background_mean = pyro.sample(
                "background_mean",
                dist.HalfNormal(self.priors["background_mean_std"])
                .expand((self.data.C,))
                .to_event(1),
            )
            background_std = pyro.sample(
                "background_std",
                dist.HalfNormal(self.priors["background_std_std"])
                .expand((self.data.C,))
                .to_event(1),
            )
            with frames as fdx:
                # fetch data
                obs, target_locs, is_ontarget = self.data.fetch(
                    ndx.unsqueeze(-1),
                    fdx.unsqueeze(-1),
                    torch.arange(self.data.C),
                )
                # sample background intensity
                # (Gamma written in terms of its mean and std)
                background = pyro.sample(
                    "background",
                    dist.Gamma(
                        (background_mean / background_std) ** 2,
                        background_mean / background_std**2,
                    ).to_event(1),
                )

                ms, heights, widths, xs, ys = [], [], [], [], []
                is_ontarget = is_ontarget.squeeze(-1)
                for qdx in range(self.Q):
                    # sample hidden model state (1+S,)
                    z_probs = Vindex(pi)[..., qdx, :, is_ontarget.long()]
                    z = pyro.sample(
                        f"z_q{qdx}",
                        dist.Categorical(z_probs),
                        infer={"enumerate": "parallel"},
                    )
                    theta = pyro.sample(
                        f"theta_q{qdx}",
                        dist.Categorical(
                            Vindex(probs_theta(self.K, self.device))[
                                torch.clamp(z, min=0, max=1)
                            ]
                        ),
                        infer={"enumerate": "parallel"},
                    )
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    for kdx in range(self.K):
                        # 1 where theta selects spot ``kdx``
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence
                        m = pyro.sample(
                            f"m_k{kdx}_q{qdx}",
                            dist.Bernoulli(
                                Vindex(probs_m(lamda, self.K))[
                                    ..., qdx, theta, kdx
                                ]
                            ),
                        )
                        # spot variables are masked out where the spot
                        # is absent
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}_q{qdx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}_q{qdx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            x = pyro.sample(
                                f"x_k{kdx}_q{qdx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}_q{qdx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                # Regroup the flat (q, k) lists into (..., Q, K) tensors.
                # Every list must be sliced with the same stride K; ys and
                # ms previously started their slices at q * Q, which picks
                # the wrong entries whenever Q != K.
                heights = _stack_by_q(heights)
                widths = _stack_by_q(widths)
                xs = _stack_by_q(xs)
                ys = _stack_by_q(ys)
                # The m samples may differ in shape (enumeration dims), so
                # broadcast them to a common shape before stacking.
                ms = _stack_by_q(torch.broadcast_tensors(*ms))

                # observed data
                pyro.sample(
                    "data",
                    KSMOGN(
                        heights,
                        widths,
                        xs,
                        ys,
                        target_locs,
                        background,
                        gain,
                        self.data.offset.samples,
                        self.data.offset.logits.to(self.dtype),
                        self.data.P,
                        ms,
                        alpha,
                        use_pykeops=self.use_pykeops,
                    ),
                    obs=obs,
                )
def model(self):
    r"""
    **Generative Model**

    Model parameters:

    +------------------+-----------+-------------------------------------+
    | Parameter        | Shape     | Description                         |
    +==================+===========+=====================================+
    | |g| - :math:`g`  | (1,)      | camera gain                         |
    +------------------+-----------+-------------------------------------+
    | |sigma| - |prox| | (1,)      | proximity                           |
    +------------------+-----------+-------------------------------------+
    | ``lamda`` - |ld| | (1,)      | average rate of target-nonspecific  |
    |                  |           | binding                             |
    +------------------+-----------+-------------------------------------+
    | ``pi`` - |pi|    | (S + 1,)  | average binding probability of      |
    |                  |           | target-specific binding             |
    +------------------+-----------+-------------------------------------+
    | |bg| - |b|       | (N, F)    | background intensity                |
    +------------------+-----------+-------------------------------------+
    | |z| - :math:`z`  | (N, F)    | target-specific spot presence       |
    +------------------+-----------+-------------------------------------+
    | |t| - |theta|    | (N, F)    | target-specific spot index          |
    +------------------+-----------+-------------------------------------+
    | |m| - :math:`m`  | (K, N, F) | spot presence indicator             |
    +------------------+-----------+-------------------------------------+
    | |h| - :math:`h`  | (K, N, F) | spot intensity                      |
    +------------------+-----------+-------------------------------------+
    | |w| - :math:`w`  | (K, N, F) | spot width                          |
    +------------------+-----------+-------------------------------------+
    | |x| - :math:`x`  | (K, N, F) | spot position on x-axis             |
    +------------------+-----------+-------------------------------------+
    | |y| - :math:`y`  | (K, N, F) | spot position on y-axis             |
    +------------------+-----------+-------------------------------------+
    | |D| - :math:`D`  | |shape|   | observed images                     |
    +------------------+-----------+-------------------------------------+

    .. |ps| replace:: :math:`p(\mathsf{specific})`
    .. |theta| replace:: :math:`\theta`
    .. |prox| replace:: :math:`\sigma^{xy}`
    .. |ld| replace:: :math:`\lambda`
    .. |b| replace:: :math:`b`
    .. |shape| replace:: (N, F, P, P)
    .. |sigma| replace:: ``proximity``
    .. |bg| replace:: ``background``
    .. |h| replace:: ``height``
    .. |w| replace:: ``width``
    .. |D| replace:: ``data``
    .. |m| replace:: ``m``
    .. |z| replace:: ``z``
    .. |t| replace:: ``theta``
    .. |x| replace:: ``x``
    .. |y| replace:: ``y``
    .. |pi| replace:: :math:`\pi`
    .. |g| replace:: ``gain``

    Full joint distribution:

    .. math::

        \begin{aligned}
            p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
            \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b)
            \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}}
            p(b | \mu^b, \sigma^b) p(z | \pi) p(\theta | z)
            \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}}
            \cdot \right. \right. \\
            &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}}
            p(m | \theta, \lambda) p(h) p(w)
            p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right]
            \left. \left.
            \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}
            \sum_{\delta} p(\delta) p(D | \mu^I, g, \delta) \right] \right]
        \end{aligned}

    :math:`z` and :math:`\theta` marginalized joint distribution:

    .. math::

        \begin{aligned}
            \sum_{z, \theta} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi)
            p(\lambda)
            \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b)
            \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}}
            p(b | \mu^b, \sigma^b) \sum_{z} p(z | \pi)
            \sum_{\theta} p(\theta | z)
            \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}}
            \cdot \right. \right. \\
            &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}}
            p(m | \theta, \lambda) p(h) p(w)
            p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right]
            \left. \left.
            \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}
            \sum_{\delta} p(\delta) p(D | \mu^I, g, \delta) \right] \right]
        \end{aligned}
    """
    # global parameters
    gain = pyro.sample("gain", dist.HalfNormal(self.gain_std))
    pi = pyro.sample(
        "pi", dist.Dirichlet(torch.ones(self.S + 1) / (self.S + 1))
    )
    pi = expand_offtarget(pi)
    lamda = pyro.sample("lamda", dist.Exponential(self.lamda_rate))
    proximity = pyro.sample(
        "proximity", dist.Exponential(self.proximity_rate)
    )
    # Two candidate values for the second AffineBeta argument of the x/y
    # priors; selected below by the 0/1 ``specific`` indicator.
    size = torch.stack(
        (
            torch.full_like(proximity, 2.0),
            (((self.data.P + 1) / (2 * proximity)) ** 2 - 1),
        ),
        dim=-1,
    )

    # spots
    spots = pyro.plate("spots", self.K)
    # aoi sites
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-2,
    )
    # time frames
    frames = pyro.plate(
        "frames",
        self.data.F,
        subsample=self.f,
        subsample_size=self.fbatch_size,
        dim=-1,
    )

    with aois as ndx:
        # Add a trailing singleton dim to the AOI index.
        ndx = ndx[:, None]
        # background mean and std
        background_mean = pyro.sample(
            "background_mean", dist.HalfNormal(self.background_mean_std)
        )
        background_std = pyro.sample(
            "background_std", dist.HalfNormal(self.background_std_std)
        )
        with frames as fdx:
            # fetch data
            obs, target_locs, is_ontarget = self.data.fetch(
                ndx, fdx, self.cdx
            )
            # sample background intensity
            # (Gamma written in terms of its mean and std)
            background = pyro.sample(
                "background",
                dist.Gamma(
                    (background_mean / background_std) ** 2,
                    background_mean / background_std**2,
                ),
            )

            # sample hidden model state (1+S,)
            z = pyro.sample(
                "z",
                dist.Categorical(Vindex(pi)[..., :, is_ontarget.long()]),
                infer={"enumerate": "parallel"},
            )
            theta = pyro.sample(
                "theta",
                dist.Categorical(
                    Vindex(probs_theta(self.K, self.device))[
                        torch.clamp(z, min=0, max=1)
                    ]
                ),
                infer={"enumerate": "parallel"},
            )
            onehot_theta = one_hot(theta, num_classes=1 + self.K)

            ms, heights, widths, xs, ys = [], [], [], [], []
            for kdx in spots:
                # 1 where theta selects spot ``kdx`` (index 1 + kdx)
                specific = onehot_theta[..., 1 + kdx]
                # spot presence
                m = pyro.sample(
                    f"m_{kdx}",
                    dist.Bernoulli(
                        Vindex(probs_m(lamda, self.K))[..., theta, kdx]
                    ),
                )
                # spot variables are masked out where the spot is absent
                with handlers.mask(mask=m > 0):
                    # sample spot variables
                    height = pyro.sample(
                        f"height_{kdx}",
                        dist.HalfNormal(self.height_std),
                    )
                    width = pyro.sample(
                        f"width_{kdx}",
                        AffineBeta(
                            1.5,
                            2,
                            self.width_min,
                            self.width_max,
                        ),
                    )
                    x = pyro.sample(
                        f"x_{kdx}",
                        AffineBeta(
                            0,
                            Vindex(size)[..., specific],
                            -(self.data.P + 1) / 2,
                            (self.data.P + 1) / 2,
                        ),
                    )
                    y = pyro.sample(
                        f"y_{kdx}",
                        AffineBeta(
                            0,
                            Vindex(size)[..., specific],
                            -(self.data.P + 1) / 2,
                            (self.data.P + 1) / 2,
                        ),
                    )

                # append
                ms.append(m)
                heights.append(height)
                widths.append(width)
                xs.append(x)
                ys.append(y)

            # observed data
            pyro.sample(
                "data",
                KSMOGN(
                    torch.stack(heights, -1),
                    torch.stack(widths, -1),
                    torch.stack(xs, -1),
                    torch.stack(ys, -1),
                    target_locs,
                    background,
                    gain,
                    self.data.offset.samples,
                    self.data.offset.logits.to(self.dtype),
                    self.data.P,
                    torch.stack(torch.broadcast_tensors(*ms), -1),
                    # Passed by keyword so the flag cannot land in another
                    # optional positional slot that follows ``m``.
                    use_pykeops=self.use_pykeops,
                ),
                obs=obs,
            )