def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])

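# Note (editorial): the snippets below assume roughly the following common imports. This is a
# sketch of the likely preamble rather than a copy of the original source; project-specific
# helpers used later (AffineBeta, KSMOGN, expand_offtarget, probs_m, probs_theta, handlers,
# one_hot, snr, hpdi, quantile, ...) come from the enclosing packages and are not shown here.
import math

import torch
from torch.distributions import constraints
from torch.distributions.transforms import AffineTransform

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.ops.indexing import Vindex
from pyro.util import ignore_jit_warnings
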
def model_1(data, history, vectorized):
    x_dim = 3
    init = pyro.param("init", lambda: torch.rand(x_dim),
                      constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)),
                       constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    x_prev = None
    markov_loop = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \
        else pyro.markov(range(len(data)), history=history)
    for i in markov_loop:
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(
                init if isinstance(i, int) and i < 1 else trans[x_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Normal(Vindex(locs)[..., x_curr], 1.),
                        obs=data[i])
        x_prev = x_curr

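# A minimal sketch (editorial assumption, not part of the original source) of fitting model_1:
# the discrete x_t chain is summed out by TraceEnum_ELBO, so the guide can be empty and the
# pyro.param sites are learned directly. `data` is a hypothetical (T, num_tones) tensor and
# the optimizer settings are illustrative.
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam


def train_model_1(data, history=1, num_steps=100):
    pyro.clear_param_store()
    enum_model = config_enumerate(model_1)      # enumerate discrete sites in parallel
    empty_guide = lambda *args, **kwargs: None  # all latents are enumerated or pyro.param's
    svi = SVI(enum_model, empty_guide, Adam({"lr": 0.05}),
              TraceEnum_ELBO(max_plate_nesting=2))
    for _ in range(num_steps):
        svi.step(data, history, False)  # vectorized=True would require the funsor backend
    return svi
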
def model_7(data, history, vectorized):
    w_dim, x_dim, y_dim = 2, 3, 2
    w_init = pyro.param("w_init", lambda: torch.rand(w_dim),
                        constraint=constraints.simplex)
    w_trans = pyro.param("w_trans", lambda: torch.rand((x_dim, w_dim)),
                         constraint=constraints.simplex)
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((w_dim, x_dim)),
                         constraint=constraints.simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(w_dim, x_dim, y_dim),
                         constraint=constraints.simplex)

    w_prev = x_prev = None
    markov_loop = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \
        else pyro.markov(range(len(data)), history=history)
    for i in markov_loop:
        w_curr = pyro.sample(
            "w_{}".format(i),
            dist.Categorical(
                w_init if isinstance(i, int) and i < 1 else w_trans[x_prev]))
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(
                x_init if isinstance(i, int) and i < 1 else x_trans[w_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Categorical(Vindex(y_probs)[w_curr, x_curr]),
                        obs=data[i])
        x_prev, w_prev = x_curr, w_curr

def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) * prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) * prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) * prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) * prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(length_shape,
                                                        length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state],
                                              phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state],
                                              psi_kappas[state]),
                                obs=obs_psi)

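# A minimal sketch (editorial assumption) of fitting torus_dbn with SVI: the continuous global
# sites get a MAP (AutoDelta) guide, while the enumerated state_t sites are marginalized out by
# TraceEnum_ELBO. phis, psis, and lengths are the tensors described in the model signature;
# the learning rate and step count are illustrative.
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam


def train_torus_dbn(phis, psis, lengths, num_steps=100):
    pyro.clear_param_store()
    guide = AutoDelta(
        poutine.block(torus_dbn,
                      hide_fn=lambda msg: msg["name"].startswith("state")))
    svi = SVI(torus_dbn, guide, Adam({"lr": 0.01}),
              TraceEnum_ELBO(max_plate_nesting=2))
    for _ in range(num_steps):
        svi.step(phis=phis, psis=psis, lengths=lengths)
    return guide
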
def model_generic(config):
    """Hierarchical mixed-effects hidden Markov model"""

    MISSING = config["MISSING"]
    N_v = config["sizes"]["random"]
    N_state = config["sizes"]["state"]

    # initialize group-level random effect parameters
    if config["group"]["random"] == "discrete":
        probs_e_g = pyro.param("probs_e_group",
                               lambda: torch.randn((N_v,)).abs(),
                               constraint=constraints.simplex)
        theta_g = pyro.param("theta_group",
                             lambda: torch.randn((N_v, N_state**2)))
    elif config["group"]["random"] == "continuous":
        loc_g = torch.zeros((N_state**2,))
        scale_g = torch.ones((N_state**2,))

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "discrete":
        probs_e_i = pyro.param("probs_e_individual",
                               lambda: torch.randn((N_c, N_v)).abs(),
                               constraint=constraints.simplex)
        theta_i = pyro.param("theta_individual",
                             lambda: torch.randn((N_c, N_v, N_state**2)))
    elif config["individual"]["random"] == "continuous":
        loc_i = torch.zeros((N_c, N_state**2))
        scale_i = torch.ones((N_c, N_state**2))

    # initialize likelihood parameters
    # observation 1: step size (step ~ Gamma)
    step_zi_param = pyro.param("step_zi_param",
                               lambda: torch.ones((N_state, 2)))
    step_concentration = pyro.param("step_param_concentration",
                                    lambda: torch.randn((N_state,)).abs(),
                                    constraint=constraints.positive)
    step_rate = pyro.param("step_param_rate",
                           lambda: torch.randn((N_state,)).abs(),
                           constraint=constraints.positive)

    # observation 2: step angle (angle ~ VonMises)
    angle_concentration = pyro.param("angle_param_concentration",
                                     lambda: torch.randn((N_state,)).abs(),
                                     constraint=constraints.positive)
    angle_loc = pyro.param("angle_param_loc",
                           lambda: torch.randn((N_state,)).abs())

    # observation 3: dive activity (omega ~ Beta)
    omega_zi_param = pyro.param("omega_zi_param",
                                lambda: torch.ones((N_state, 2)))
    omega_concentration0 = pyro.param("omega_param_concentration0",
                                      lambda: torch.randn((N_state,)).abs(),
                                      constraint=constraints.positive)
    omega_concentration1 = pyro.param("omega_param_concentration1",
                                      lambda: torch.randn((N_state,)).abs(),
                                      constraint=constraints.positive)

    # initialize gamma to uniform
    gamma = torch.zeros((N_state**2,))

    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):
        # group-level random effects
        if config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})
        else:
            eps_g = 0.

        # add group-level random effect to gamma
        gamma = gamma + eps_g

        N_s = config["sizes"]["individual"]
        with pyro.plate("individual", N_s, dim=-2), \
                poutine.mask(mask=config["individual"]["mask"]):
            # individual-level random effects
            if config["individual"]["random"] == "discrete":
                # individual-level discrete effect
                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
                eps_i = Vindex(theta_i)[..., e_i, :]
                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
            elif config["individual"]["random"] == "continuous":
                eps_i = pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
            else:
                eps_i = 0.
            # add individual-level random effect to gamma
            gamma = gamma + eps_i

            y = torch.tensor(0).long()
            N_t = config["sizes"]["timesteps"]
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    gamma_t = gamma  # per-timestep variable

                    # finally, reshape gamma as batch of transition matrices
                    gamma_t = gamma_t.reshape(
                        tuple(gamma_t.shape[:-1]) + (N_state, N_state))

                    # we've accounted for all effects, now actually compute gamma_y
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t),
                                    dist.Categorical(logits=gamma_y))

                    # observation 1: step size
                    step_dist = dist.Gamma(
                        concentration=Vindex(step_concentration)[..., y],
                        rate=Vindex(step_rate)[..., y])

                    # zero-inflation with MaskedMixture
                    step_zi = Vindex(step_zi_param)[..., y, :]
                    step_zi_mask = pyro.sample(
                        "step_zi_{}".format(t),
                        dist.Categorical(logits=step_zi),
                        obs=(config["observations"]["step"][..., t] == MISSING))
                    step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist,
                                                      step_zi_zero_dist)
                    pyro.sample("step_{}".format(t),
                                step_zi_dist,
                                obs=config["observations"]["step"][..., t])

                    # observation 2: step angle
                    angle_dist = dist.VonMises(
                        concentration=Vindex(angle_concentration)[..., y],
                        loc=Vindex(angle_loc)[..., y])
                    pyro.sample("angle_{}".format(t),
                                angle_dist,
                                obs=config["observations"]["angle"][..., t])

                    # observation 3: dive activity
                    omega_dist = dist.Beta(
                        concentration0=Vindex(omega_concentration0)[..., y],
                        concentration1=Vindex(omega_concentration1)[..., y])

                    # zero-inflation with MaskedMixture
                    omega_zi = Vindex(omega_zi_param)[..., y, :]
                    omega_zi_mask = pyro.sample(
                        "omega_zi_{}".format(t),
                        dist.Categorical(logits=omega_zi),
                        obs=(config["observations"]["omega"][..., t] == MISSING))
                    omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    omega_zi_dist = dist.MaskedMixture(omega_zi_mask, omega_dist,
                                                       omega_zi_zero_dist)
                    pyro.sample("omega_{}".format(t),
                                omega_zi_dist,
                                obs=config["observations"]["omega"][..., t])

def model(self):
    r"""
    Generative Model
    """
    # global parameters
    gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
    alpha = pyro.sample(
        "alpha",
        dist.Dirichlet(
            torch.ones((self.Q, self.data.C)) + torch.eye(self.Q) * 9).to_event(1),
    )
    pi = pyro.sample(
        "pi",
        dist.Dirichlet(torch.ones((self.Q, self.S + 1)) / (self.S + 1)).to_event(1),
    )
    pi = expand_offtarget(pi)
    lamda = pyro.sample(
        "lamda",
        dist.Exponential(torch.full((self.Q,),
                                    self.priors["lamda_rate"])).to_event(1),
    )
    proximity = pyro.sample(
        "proximity", dist.Exponential(self.priors["proximity_rate"]))
    size = torch.stack(
        (
            torch.full_like(proximity, 2.0),
            (((self.data.P + 1) / (2 * proximity))**2 - 1),
        ),
        dim=-1,
    )

    # aoi sites
    aois = pyro.plate(
        "aois",
        self.data.Nt,
        subsample=self.n,
        subsample_size=self.nbatch_size,
        dim=-2,
    )
    # time frames
    frames = pyro.plate(
        "frames",
        self.data.F,
        subsample=self.f,
        subsample_size=self.fbatch_size,
        dim=-1,
    )

    with aois as ndx:
        ndx = ndx[:, None]
        mask = Vindex(self.data.mask)[ndx].to(self.device)
        with handlers.mask(mask=mask):
            # background mean and std
            background_mean = pyro.sample(
                "background_mean",
                dist.HalfNormal(self.priors["background_mean_std"]).expand(
                    (self.data.C,)).to_event(1),
            )
            background_std = pyro.sample(
                "background_std",
                dist.HalfNormal(self.priors["background_std_std"]).expand(
                    (self.data.C,)).to_event(1),
            )
            with frames as fdx:
                # fetch data
                obs, target_locs, is_ontarget = self.data.fetch(
                    ndx.unsqueeze(-1), fdx.unsqueeze(-1),
                    torch.arange(self.data.C))
                # sample background intensity
                background = pyro.sample(
                    "background",
                    dist.Gamma(
                        (background_mean / background_std)**2,
                        background_mean / background_std**2,
                    ).to_event(1),
                )
                ms, heights, widths, xs, ys = [], [], [], [], []
                is_ontarget = is_ontarget.squeeze(-1)
                for qdx in range(self.Q):
                    # sample hidden model state (1+S,)
                    z_probs = Vindex(pi)[..., qdx, :, is_ontarget.long()]
                    z = pyro.sample(
                        f"z_q{qdx}",
                        dist.Categorical(z_probs),
                        infer={"enumerate": "parallel"},
                    )
                    theta = pyro.sample(
                        f"theta_q{qdx}",
                        dist.Categorical(
                            Vindex(probs_theta(self.K, self.device))[
                                torch.clamp(z, min=0, max=1)]),
                        infer={"enumerate": "parallel"},
                    )
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    for kdx in range(self.K):
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence
                        m = pyro.sample(
                            f"m_k{kdx}_q{qdx}",
                            dist.Bernoulli(
                                Vindex(probs_m(lamda, self.K))[..., qdx, theta, kdx]),
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}_q{qdx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}_q{qdx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            x = pyro.sample(
                                f"x_k{kdx}_q{qdx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}_q{qdx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                # regroup the flat per-(channel, spot) lists into (..., Q, K) tensors;
                # note: every slice steps by self.K (spots per channel), not self.Q, so
                # the grouping stays correct when the number of channels differs from
                # the number of spots (the original sliced ys and ms by self.Q)
                heights = torch.stack(
                    [torch.stack(heights[q * self.K:(1 + q) * self.K], -1)
                     for q in range(self.Q)],
                    -2,
                )
                widths = torch.stack(
                    [torch.stack(widths[q * self.K:(1 + q) * self.K], -1)
                     for q in range(self.Q)],
                    -2,
                )
                xs = torch.stack(
                    [torch.stack(xs[q * self.K:(1 + q) * self.K], -1)
                     for q in range(self.Q)],
                    -2,
                )
                ys = torch.stack(
                    [torch.stack(ys[q * self.K:(1 + q) * self.K], -1)
                     for q in range(self.Q)],
                    -2,
                )
                ms = torch.broadcast_tensors(*ms)
                ms = torch.stack(
                    [torch.stack(ms[q * self.K:(1 + q) * self.K], -1)
                     for q in range(self.Q)],
                    -2,
                )
                # observed data
                pyro.sample(
                    "data",
                    KSMOGN(
                        heights,
                        widths,
                        xs,
                        ys,
                        target_locs,
                        background,
                        gain,
                        self.data.offset.samples,
                        self.data.offset.logits.to(self.dtype),
                        self.data.P,
                        ms,
                        alpha,
                        use_pykeops=self.use_pykeops,
                    ),
                    obs=obs,
                )

def guide(self): r""" Variational Distribution """ # global parameters pyro.sample( "gain", dist.Gamma( pyro.param("gain_loc") * pyro.param("gain_beta"), pyro.param("gain_beta"), ), ) pyro.sample( "alpha", dist.Dirichlet( pyro.param("alpha_mean") * pyro.param("alpha_size")).to_event(1), ) pyro.sample( "pi", dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size")).to_event(1), ) pyro.sample( "lamda", dist.Gamma( pyro.param("lamda_loc") * pyro.param("lamda_beta"), pyro.param("lamda_beta"), ).to_event(1), ) pyro.sample( "proximity", AffineBeta( pyro.param("proximity_loc"), pyro.param("proximity_size"), 0, (self.data.P + 1) / math.sqrt(12), ), ) # aoi sites aois = pyro.plate( "aois", self.data.Nt, subsample=self.n, subsample_size=self.nbatch_size, dim=-2, ) # time frames frames = pyro.plate( "frames", self.data.F, subsample=self.f, subsample_size=self.fbatch_size, dim=-1, ) with aois as ndx: ndx = ndx[:, None] mask = Vindex(self.data.mask)[ndx].to(self.device) with handlers.mask(mask=mask): pyro.sample( "background_mean", dist.Delta( Vindex( pyro.param("background_mean_loc"))[ndx, 0]).to_event(1), ) pyro.sample( "background_std", dist.Delta( Vindex( pyro.param("background_std_loc"))[ndx, 0]).to_event(1), ) with frames as fdx: # sample background intensity pyro.sample( "background", dist.Gamma( Vindex(pyro.param("b_loc"))[ndx, fdx] * Vindex(pyro.param("b_beta"))[ndx, fdx], Vindex(pyro.param("b_beta"))[ndx, fdx], ).to_event(1), ) for qdx in range(self.Q): for kdx in range(self.K): # sample spot presence m m = pyro.sample( f"m_k{kdx}_q{qdx}", dist.Bernoulli( Vindex(pyro.param("m_probs"))[kdx, ndx, fdx, qdx]), infer={"enumerate": "parallel"}, ) with handlers.mask(mask=m > 0): # sample spot variables pyro.sample( f"height_k{kdx}_q{qdx}", dist.Gamma( Vindex(pyro.param("h_loc"))[kdx, ndx, fdx, qdx] * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, qdx], Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, qdx], ), ) pyro.sample( f"width_k{kdx}_q{qdx}", AffineBeta( Vindex(pyro.param("w_mean"))[kdx, ndx, fdx, qdx], Vindex(pyro.param("w_size"))[kdx, ndx, fdx, qdx], self.priors["width_min"], self.priors["width_max"], ), ) pyro.sample( f"x_k{kdx}_q{qdx}", AffineBeta( Vindex(pyro.param("x_mean"))[kdx, ndx, fdx, qdx], Vindex(pyro.param("size"))[kdx, ndx, fdx, qdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) pyro.sample( f"y_k{kdx}_q{qdx}", AffineBeta( Vindex(pyro.param("y_mean"))[kdx, ndx, fdx, qdx], Vindex(pyro.param("size"))[kdx, ndx, fdx, qdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), )
def guide(self): """ **Variational Distribution** """ # global parameters pyro.sample( "gain", dist.Gamma( pyro.param("gain_loc") * pyro.param("gain_beta"), pyro.param("gain_beta"), ), ) pyro.sample( "init", dist.Dirichlet(pyro.param("init_mean") * pyro.param("init_size"))) pyro.sample( "trans", dist.Dirichlet( pyro.param("trans_mean") * pyro.param("trans_size")).to_event(1), ) pyro.sample( "lamda", dist.Gamma( pyro.param("lamda_loc") * pyro.param("lamda_beta"), pyro.param("lamda_beta"), ), ) pyro.sample( "proximity", AffineBeta( pyro.param("proximity_loc"), pyro.param("proximity_size"), 0, (self.data.P + 1) / math.sqrt(12), ), ) # spots spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", self.data.Nt, subsample=self.n, subsample_size=self.nbatch_size, dim=-2, ) # time frames frames = (pyro.vectorized_markov( name="frames", size=self.data.F, dim=-1) if self.vectorized else pyro.markov(range(self.data.F))) with aois as ndx: ndx = ndx[:, None] pyro.sample( "background_mean", dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0]), ) pyro.sample( "background_std", dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0]), ) z_prev = None for fdx in frames: if self.vectorized: fsx, fdx = fdx else: fsx = fdx # sample background intensity pyro.sample( f"background_{fsx}", dist.Gamma( Vindex(pyro.param("b_loc"))[ndx, fdx] * Vindex(pyro.param("b_beta"))[ndx, fdx], Vindex(pyro.param("b_beta"))[ndx, fdx], ), ) # sample hidden model state z_probs = (Vindex(pyro.param("z_trans"))[ndx, fdx, 0] if isinstance(fdx, int) and fdx < 1 else Vindex( pyro.param("z_trans"))[ndx, fdx, z_prev]) z_curr = pyro.sample( f"z_{fsx}", dist.Categorical(z_probs), infer={"enumerate": "parallel"}, ) for kdx in spots: # spot presence m_probs = Vindex(pyro.param("m_probs"))[z_curr, kdx, ndx, fdx] m = pyro.sample( f"m_{kdx}_{fsx}", dist.Categorical( torch.stack((1 - m_probs, m_probs), -1)), infer={"enumerate": "parallel"}, ) with handlers.mask(mask=m > 0): # sample spot variables pyro.sample( f"height_{kdx}_{fsx}", dist.Gamma( Vindex(pyro.param("h_loc"))[kdx, ndx, fdx] * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx], Vindex(pyro.param("h_beta"))[kdx, ndx, fdx], ), ) pyro.sample( f"width_{kdx}_{fsx}", AffineBeta( Vindex(pyro.param("w_mean"))[kdx, ndx, fdx], Vindex(pyro.param("w_size"))[kdx, ndx, fdx], 0.75, 2.25, ), ) pyro.sample( f"x_{kdx}_{fsx}", AffineBeta( Vindex(pyro.param("x_mean"))[kdx, ndx, fdx], Vindex(pyro.param("size"))[kdx, ndx, fdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) pyro.sample( f"y_{kdx}_{fsx}", AffineBeta( Vindex(pyro.param("y_mean"))[kdx, ndx, fdx], Vindex(pyro.param("size"))[kdx, ndx, fdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) z_prev = z_curr
def model(self): """ **Generative Model** """ # global parameters gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"])) init = pyro.sample( "init", dist.Dirichlet(torch.ones(self.Q, self.S + 1) / (self.S + 1)).to_event(1), ) init = expand_offtarget(init) trans = pyro.sample( "trans", dist.Dirichlet( torch.ones(self.Q, self.S + 1, self.S + 1) / (self.S + 1)).to_event(2), ) trans = expand_offtarget(trans) lamda = pyro.sample( "lamda", dist.Exponential(torch.full( (self.Q, ), self.priors["lamda_rate"])).to_event(1), ) proximity = pyro.sample( "proximity", dist.Exponential(self.priors["proximity_rate"])) size = torch.stack( ( torch.full_like(proximity, 2.0), (((self.data.P + 1) / (2 * proximity))**2 - 1), ), dim=-1, ) # spots spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", self.data.Nt, subsample=self.n, subsample_size=self.nbatch_size, dim=-3, ) # time frames frames = (pyro.vectorized_markov( name="frames", size=self.data.F, dim=-2) if self.vectorized else pyro.markov(range(self.data.F))) # color channels channels = pyro.plate( "channels", self.data.C, dim=-1, ) with channels as cdx, aois as ndx: ndx = ndx[:, None, None] mask = Vindex(self.data.mask)[ndx].to(self.device) with handlers.mask(mask=mask): # background mean and std background_mean = pyro.sample( "background_mean", dist.HalfNormal(self.priors["background_mean_std"]), ) background_std = pyro.sample( "background_std", dist.HalfNormal(self.priors["background_std_std"])) z_prev = None for fdx in frames: if self.vectorized: fsx, fdx = fdx fdx = torch.as_tensor(fdx) fdx = fdx.unsqueeze(-1) else: fsx = fdx # fetch data obs, target_locs, is_ontarget = self.data.fetch( ndx, fdx, cdx) # sample background intensity background = pyro.sample( f"background_f{fsx}", dist.Gamma( (background_mean / background_std)**2, background_mean / background_std**2, ), ) # sample hidden model state (1+S,) z_probs = (Vindex(init)[..., cdx, :, is_ontarget.long()] if z_prev is None else Vindex(trans)[..., cdx, z_prev, :, is_ontarget.long()]) z_curr = pyro.sample(f"z_f{fsx}", dist.Categorical(z_probs)) theta = pyro.sample( f"theta_f{fsx}", dist.Categorical( Vindex(probs_theta( self.K, self.device))[torch.clamp(z_curr, min=0, max=1)]), infer={"enumerate": "parallel"}, ) onehot_theta = one_hot(theta, num_classes=1 + self.K) ms, heights, widths, xs, ys = [], [], [], [], [] for kdx in spots: specific = onehot_theta[..., 1 + kdx] # spot presence m_probs = Vindex(probs_m(lamda, self.K))[..., cdx, theta, kdx] m = pyro.sample( f"m_k{kdx}_f{fsx}", dist.Categorical( torch.stack((1 - m_probs, m_probs), -1)), ) with handlers.mask(mask=m > 0): # sample spot variables height = pyro.sample( f"height_k{kdx}_f{fsx}", dist.HalfNormal(self.priors["height_std"]), ) width = pyro.sample( f"width_k{kdx}_f{fsx}", AffineBeta( 1.5, 2, self.priors["width_min"], self.priors["width_max"], ), ) x = pyro.sample( f"x_k{kdx}_f{fsx}", AffineBeta( 0, Vindex(size)[..., specific], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) y = pyro.sample( f"y_k{kdx}_f{fsx}", AffineBeta( 0, Vindex(size)[..., specific], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) # append ms.append(m) heights.append(height) widths.append(width) xs.append(x) ys.append(y) # observed data pyro.sample( f"data_f{fsx}", KSMOGN( torch.stack(heights, -1), torch.stack(widths, -1), torch.stack(xs, -1), torch.stack(ys, -1), target_locs, background, gain, self.data.offset.samples, self.data.offset.logits.to(self.dtype), self.data.P, torch.stack(torch.broadcast_tensors(*ms), -1), 
use_pykeops=self.use_pykeops, ), obs=obs, ) z_prev = z_curr
def guide(self): """ **Variational Distribution** """ # global parameters pyro.sample( "gain", dist.Gamma( pyro.param("gain_loc") * pyro.param("gain_beta"), pyro.param("gain_beta"), ), ) pyro.sample( "init", dist.Dirichlet(pyro.param("init_mean") * pyro.param("init_size")).to_event(1), ) pyro.sample( "trans", dist.Dirichlet( pyro.param("trans_mean") * pyro.param("trans_size")).to_event(2), ) pyro.sample( "lamda", dist.Gamma( pyro.param("lamda_loc") * pyro.param("lamda_beta"), pyro.param("lamda_beta"), ).to_event(1), ) pyro.sample( "proximity", AffineBeta( pyro.param("proximity_loc"), pyro.param("proximity_size"), 0, (self.data.P + 1) / math.sqrt(12), ), ) # spots spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", self.data.Nt, subsample=self.n, subsample_size=self.nbatch_size, dim=-3, ) # time frames frames = (pyro.vectorized_markov( name="frames", size=self.data.F, dim=-2) if self.vectorized else pyro.markov(range(self.data.F))) # color channels channels = pyro.plate( "channels", self.data.C, dim=-1, ) with channels as cdx, aois as ndx: ndx = ndx[:, None, None] mask = Vindex(self.data.mask)[ndx].to(self.device) with handlers.mask(mask=mask): pyro.sample( "background_mean", dist.Delta( Vindex(pyro.param("background_mean_loc"))[ndx, 0, cdx]), ) pyro.sample( "background_std", dist.Delta( Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]), ) z_prev = None for fdx in frames: if self.vectorized: fsx, fdx = fdx fdx = torch.as_tensor(fdx) fdx = fdx.unsqueeze(-1) else: fsx = fdx # sample background intensity pyro.sample( f"background_f{fsx}", dist.Gamma( Vindex(pyro.param("b_loc"))[ndx, fdx, cdx] * Vindex(pyro.param("b_beta"))[ndx, fdx, cdx], Vindex(pyro.param("b_beta"))[ndx, fdx, cdx], ), ) # sample hidden model state z_probs = (Vindex(pyro.param("z_trans"))[ndx, fdx, cdx, 0] if z_prev is None else Vindex( pyro.param("z_trans"))[ndx, fdx, cdx, z_prev]) z_curr = pyro.sample( f"z_f{fsx}", dist.Categorical(z_probs), infer={"enumerate": "parallel"}, ) for kdx in spots: # spot presence m_probs = Vindex(pyro.param("m_probs"))[z_curr, kdx, ndx, fdx, cdx] m = pyro.sample( f"m_k{kdx}_f{fsx}", dist.Categorical( torch.stack((1 - m_probs, m_probs), -1)), infer={"enumerate": "parallel"}, ) with handlers.mask(mask=m > 0): # sample spot variables pyro.sample( f"height_k{kdx}_f{fsx}", dist.Gamma( Vindex(pyro.param("h_loc"))[kdx, ndx, fdx, cdx] * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx], Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx], ), ) pyro.sample( f"width_k{kdx}_f{fsx}", AffineBeta( Vindex(pyro.param("w_mean"))[kdx, ndx, fdx, cdx], Vindex(pyro.param("w_size"))[kdx, ndx, fdx, cdx], self.priors["width_min"], self.priors["width_max"], ), ) pyro.sample( f"x_k{kdx}_f{fsx}", AffineBeta( Vindex(pyro.param("x_mean"))[kdx, ndx, fdx, cdx], Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) pyro.sample( f"y_k{kdx}_f{fsx}", AffineBeta( Vindex(pyro.param("y_mean"))[kdx, ndx, fdx, cdx], Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) z_prev = z_curr
def model(self): r""" **Generative Model** Model parameters: +-----------------+-----------+-------------------------------------+ | Parameter | Shape | Description | +=================+===========+=====================================+ | |g| - :math:`g` | (1,) | camera gain | +-----------------+-----------+-------------------------------------+ | |sigma| - |prox|| (1,) | proximity | +-----------------+-----------+-------------------------------------+ | ``lamda`` - |ld|| (1,) | average rate of target-nonspecific | | | | binding | +-----------------+-----------+-------------------------------------+ | ``pi`` - |pi| | (1,) | average binding probability of | | | | target-specific binding | +-----------------+-----------+-------------------------------------+ | |bg| - |b| | (N, F) | background intensity | +-----------------+-----------+-------------------------------------+ | |z| - :math:`z` | (N, F) | target-specific spot presence | +-----------------+-----------+-------------------------------------+ | |t| - |theta| | (N, F) | target-specific spot index | +-----------------+-----------+-------------------------------------+ | |m| - :math:`m` | (K, N, F) | spot presence indicator | +-----------------+-----------+-------------------------------------+ | |h| - :math:`h` | (K, N, F) | spot intensity | +-----------------+-----------+-------------------------------------+ | |w| - :math:`w` | (K, N, F) | spot width | +-----------------+-----------+-------------------------------------+ | |x| - :math:`x` | (K, N, F) | spot position on x-axis | +-----------------+-----------+-------------------------------------+ | |y| - :math:`y` | (K, N, F) | spot position on y-axis | +-----------------+-----------+-------------------------------------+ | |D| - :math:`D` | |shape| | observed images | +-----------------+-----------+-------------------------------------+ .. |ps| replace:: :math:`p(\mathsf{specific})` .. |theta| replace:: :math:`\theta` .. |prox| replace:: :math:`\sigma^{xy}` .. |ld| replace:: :math:`\lambda` .. |b| replace:: :math:`b` .. |shape| replace:: (N, F, P, P) .. |sigma| replace:: ``proximity`` .. |bg| replace:: ``background`` .. |h| replace:: ``height`` .. |w| replace:: ``width`` .. |D| replace:: ``data`` .. |m| replace:: ``m`` .. |z| replace:: ``z`` .. |t| replace:: ``theta`` .. |x| replace:: ``x`` .. |y| replace:: ``y`` .. |pi| replace:: :math:`\pi` .. |g| replace:: ``gain`` Full joint distribution: .. math:: \begin{aligned} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda) \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) p(z | \pi) p(\theta | z) \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\ &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda) p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left. \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta) p(D | \mu^I, g, \delta) \right] \right] \end{aligned} :math:`z` and :math:`\theta` marginalized joint distribution: .. math:: \begin{aligned} \sum_{z, \theta} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda) \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) \sum_{z} p(z | \pi) \sum_{\theta} p(\theta | z) \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. 
\\ &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda) p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left. \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta) p(D | \mu^I, g, \delta) \right] \right] \end{aligned} """ # global parameters gain = pyro.sample("gain", dist.HalfNormal(self.gain_std)) pi = pyro.sample("pi", dist.Dirichlet(torch.ones(self.S + 1) / (self.S + 1))) pi = expand_offtarget(pi) lamda = pyro.sample("lamda", dist.Exponential(self.lamda_rate)) proximity = pyro.sample("proximity", dist.Exponential(self.proximity_rate)) size = torch.stack( ( torch.full_like(proximity, 2.0), (((self.data.P + 1) / (2 * proximity))**2 - 1), ), dim=-1, ) # spots spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", self.data.Nt, subsample=self.n, subsample_size=self.nbatch_size, dim=-2, ) # time frames frames = pyro.plate( "frames", self.data.F, subsample=self.f, subsample_size=self.fbatch_size, dim=-1, ) with aois as ndx: ndx = ndx[:, None] # background mean and std background_mean = pyro.sample( "background_mean", dist.HalfNormal(self.background_mean_std)) background_std = pyro.sample( "background_std", dist.HalfNormal(self.background_std_std)) with frames as fdx: # fetch data obs, target_locs, is_ontarget = self.data.fetch( ndx, fdx, self.cdx) # sample background intensity background = pyro.sample( "background", dist.Gamma( (background_mean / background_std)**2, background_mean / background_std**2, ), ) # sample hidden model state (1+S,) z = pyro.sample( "z", dist.Categorical(Vindex(pi)[..., :, is_ontarget.long()]), infer={"enumerate": "parallel"}, ) theta = pyro.sample( "theta", dist.Categorical( Vindex(probs_theta(self.K, self.device))[torch.clamp(z, min=0, max=1)]), infer={"enumerate": "parallel"}, ) onehot_theta = one_hot(theta, num_classes=1 + self.K) ms, heights, widths, xs, ys = [], [], [], [], [] for kdx in spots: specific = onehot_theta[..., 1 + kdx] # spot presence m = pyro.sample( f"m_{kdx}", dist.Bernoulli( Vindex(probs_m(lamda, self.K))[..., theta, kdx]), ) with handlers.mask(mask=m > 0): # sample spot variables height = pyro.sample( f"height_{kdx}", dist.HalfNormal(self.height_std), ) width = pyro.sample( f"width_{kdx}", AffineBeta( 1.5, 2, self.width_min, self.width_max, ), ) x = pyro.sample( f"x_{kdx}", AffineBeta( 0, Vindex(size)[..., specific], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) y = pyro.sample( f"y_{kdx}", AffineBeta( 0, Vindex(size)[..., specific], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) # append ms.append(m) heights.append(height) widths.append(width) xs.append(x) ys.append(y) # observed data pyro.sample( "data", KSMOGN( torch.stack(heights, -1), torch.stack(widths, -1), torch.stack(xs, -1), torch.stack(ys, -1), target_locs, background, gain, self.data.offset.samples, self.data.offset.logits.to(self.dtype), self.data.P, torch.stack(torch.broadcast_tensors(*ms), -1), self.use_pykeops, ), obs=obs, )
def guide(self): r""" **Variational Distribution** .. math:: \begin{aligned} q(\phi \setminus \{z, \theta\}) =~&q(g) q(\sigma^{xy}) q(\pi) q(\lambda) \cdot \\ &\prod_{\mathsf{AOI}} \left[ q(\mu^b) q(\sigma^b) \prod_{\mathsf{frame}} \left[ \vphantom{\prod_{F}} q(b) \prod_{\mathsf{spot}} q(m) q(h | m) q(w | m) q(x | m) q(y | m) \right] \right] \end{aligned} """ # global parameters pyro.sample( "gain", dist.Gamma( pyro.param("gain_loc") * pyro.param("gain_beta"), pyro.param("gain_beta"), ), ) pyro.sample( "pi", dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size"))) pyro.sample( "lamda", dist.Gamma( pyro.param("lamda_loc") * pyro.param("lamda_beta"), pyro.param("lamda_beta"), ), ) pyro.sample( "proximity", AffineBeta( pyro.param("proximity_loc"), pyro.param("proximity_size"), 0, (self.data.P + 1) / math.sqrt(12), ), ) # spots spots = pyro.plate("spots", self.K) # aoi sites aois = pyro.plate( "aois", self.data.Nt, subsample=self.n, subsample_size=self.nbatch_size, dim=-2, ) # time frames frames = pyro.plate( "frames", self.data.F, subsample=self.f, subsample_size=self.fbatch_size, dim=-1, ) with aois as ndx: ndx = ndx[:, None] pyro.sample( "background_mean", dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0]), ) pyro.sample( "background_std", dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0]), ) with frames as fdx: # sample background intensity pyro.sample( "background", dist.Gamma( Vindex(pyro.param("b_loc"))[ndx, fdx] * Vindex(pyro.param("b_beta"))[ndx, fdx], Vindex(pyro.param("b_beta"))[ndx, fdx], ), ) for kdx in spots: # sample spot presence m m = pyro.sample( f"m_{kdx}", dist.Bernoulli( Vindex(pyro.param("m_probs"))[kdx, ndx, fdx]), infer={"enumerate": "parallel"}, ) with handlers.mask(mask=m > 0): # sample spot variables pyro.sample( f"height_{kdx}", dist.Gamma( Vindex(pyro.param("h_loc"))[kdx, ndx, fdx] * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx], Vindex(pyro.param("h_beta"))[kdx, ndx, fdx], ), ) pyro.sample( f"width_{kdx}", AffineBeta( Vindex(pyro.param("w_mean"))[kdx, ndx, fdx], Vindex(pyro.param("w_size"))[kdx, ndx, fdx], 0.75, 2.25, ), ) pyro.sample( f"x_{kdx}", AffineBeta( Vindex(pyro.param("x_mean"))[kdx, ndx, fdx], Vindex(pyro.param("size"))[kdx, ndx, fdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), ) pyro.sample( f"y_{kdx}", AffineBeta( Vindex(pyro.param("y_mean"))[kdx, ndx, fdx], Vindex(pyro.param("size"))[kdx, ndx, fdx], -(self.data.P + 1) / 2, (self.data.P + 1) / 2, ), )
def m_probs(self) -> torch.Tensor:
    r"""
    Posterior spot presence probability :math:`q(m=1, z=z_\mathsf{MAP})`.
    """
    return Vindex(torch.permute(pyro.param("m_probs").data,
                                (1, 2, 3, 0)))[..., self.z_map.long()]

def fetch(self, ndx, fdx, cdx):
    return (
        Vindex(self.images)[ndx, fdx, cdx].to(self.device),
        Vindex(self.xy)[ndx, fdx, cdx].to(self.device),
        Vindex(self.is_ontarget)[ndx].to(self.device),
    )

def save_stats(model, path, CI=0.95, save_matlab=False): # global parameters global_params = model._global_params summary = pd.DataFrame( index=global_params, columns=["Mean", f"{int(100*CI)}% LL", f"{int(100*CI)}% UL"], ) # local parameters local_params = [ "height", "width", "x", "y", "background", ] ci_stats = defaultdict(partial(defaultdict, list)) num_samples = 10000 for param in global_params: if param == "gain": fn = dist.Gamma( pyro.param("gain_loc") * pyro.param("gain_beta"), pyro.param("gain_beta"), ) elif param == "pi": fn = dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size")) elif param == "lamda": fn = dist.Gamma( pyro.param("lamda_loc") * pyro.param("lamda_beta"), pyro.param("lamda_beta"), ) elif param == "proximity": fn = AffineBeta( pyro.param("proximity_loc"), pyro.param("proximity_size"), 0, (model.data.P + 1) / math.sqrt(12), ) elif param == "trans": fn = dist.Dirichlet( pyro.param("trans_mean") * pyro.param("trans_size") ).to_event(1) else: raise NotImplementedError samples = fn.sample((num_samples,)).data.squeeze() ci_stats[param] = {} LL, UL = hpdi( samples, CI, dim=0, ) ci_stats[param]["LL"] = LL.cpu() ci_stats[param]["UL"] = UL.cpu() ci_stats[param]["Mean"] = fn.mean.data.squeeze().cpu() # calculate Keq if param == "pi": ci_stats["Keq"] = {} LL, UL = hpdi(samples[:, 1] / (1 - samples[:, 1]), CI, dim=0) ci_stats["Keq"]["LL"] = LL.cpu() ci_stats["Keq"]["UL"] = UL.cpu() ci_stats["Keq"]["Mean"] = (samples[:, 1] / (1 - samples[:, 1])).mean().cpu() # this does not need to be very accurate num_samples = 1000 for param in local_params: LL, UL, Mean = [], [], [] for ndx in torch.split(torch.arange(model.data.Nt), model.nbatch_size): ndx = ndx[:, None] kdx = torch.arange(model.K)[:, None, None] ll, ul, mean = [], [], [] for fdx in torch.split(torch.arange(model.data.F), model.fbatch_size): if param == "background": fn = dist.Gamma( Vindex(pyro.param("b_loc"))[ndx, fdx] * Vindex(pyro.param("b_beta"))[ndx, fdx], Vindex(pyro.param("b_beta"))[ndx, fdx], ) elif param == "height": fn = dist.Gamma( Vindex(pyro.param("h_loc"))[kdx, ndx, fdx] * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx], Vindex(pyro.param("h_beta"))[kdx, ndx, fdx], ) elif param == "width": fn = AffineBeta( Vindex(pyro.param("w_mean"))[kdx, ndx, fdx], Vindex(pyro.param("w_size"))[kdx, ndx, fdx], 0.75, 2.25, ) elif param == "x": fn = AffineBeta( Vindex(pyro.param("x_mean"))[kdx, ndx, fdx], Vindex(pyro.param("size"))[kdx, ndx, fdx], -(model.data.P + 1) / 2, (model.data.P + 1) / 2, ) elif param == "y": fn = AffineBeta( Vindex(pyro.param("y_mean"))[kdx, ndx, fdx], Vindex(pyro.param("size"))[kdx, ndx, fdx], -(model.data.P + 1) / 2, (model.data.P + 1) / 2, ) else: raise NotImplementedError samples = fn.sample((num_samples,)).data l, u = hpdi( samples, CI, dim=0, ) m = fn.mean.data ll.append(l) ul.append(u) mean.append(m) else: LL.append(torch.cat(ll, -1)) UL.append(torch.cat(ul, -1)) Mean.append(torch.cat(mean, -1)) else: ci_stats[param]["LL"] = torch.cat(LL, -2).cpu() ci_stats[param]["UL"] = torch.cat(UL, -2).cpu() ci_stats[param]["Mean"] = torch.cat(Mean, -2).cpu() for param in global_params: if param == "pi": summary.loc[param, "Mean"] = ci_stats[param]["Mean"][1].item() summary.loc[param, "95% LL"] = ci_stats[param]["LL"][1].item() summary.loc[param, "95% UL"] = ci_stats[param]["UL"][1].item() # Keq summary.loc["Keq", "Mean"] = ci_stats["Keq"]["Mean"].item() summary.loc["Keq", "95% LL"] = ci_stats["Keq"]["LL"].item() summary.loc["Keq", "95% UL"] = ci_stats["Keq"]["UL"].item() elif param == "trans": 
summary.loc["kon", "Mean"] = ci_stats[param]["Mean"][0, 1].item() summary.loc["kon", "95% LL"] = ci_stats[param]["LL"][0, 1].item() summary.loc["kon", "95% UL"] = ci_stats[param]["UL"][0, 1].item() summary.loc["koff", "Mean"] = ci_stats[param]["Mean"][1, 0].item() summary.loc["koff", "95% LL"] = ci_stats[param]["LL"][1, 0].item() summary.loc["koff", "95% UL"] = ci_stats[param]["UL"][1, 0].item() else: summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item() summary.loc[param, "95% LL"] = ci_stats[param]["LL"].item() summary.loc[param, "95% UL"] = ci_stats[param]["UL"].item() ci_stats["m_probs"] = model.m_probs.data.cpu() ci_stats["theta_probs"] = model.theta_probs.data.cpu() ci_stats["z_probs"] = model.z_probs.data.cpu() ci_stats["z_map"] = model.z_map.data.cpu() # timestamps if model.data.time1 is not None: ci_stats["time1"] = model.data.time1 if model.data.ttb is not None: ci_stats["ttb"] = model.data.ttb model.params = ci_stats # snr summary.loc["SNR", "Mean"] = ( snr( model.data.images[:, :, model.cdx], ci_stats["width"]["Mean"], ci_stats["x"]["Mean"], ci_stats["y"]["Mean"], model.data.xy[:, :, model.cdx], ci_stats["background"]["Mean"], ci_stats["gain"]["Mean"], model.data.offset.mean, model.data.offset.var, model.data.P, model.theta_probs, ) .mean() .item() ) # classification statistics if model.data.labels is not None: pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel() true_labels = model.data.labels["z"][: model.data.N, :, model.cdx].ravel() with np.errstate(divide="ignore", invalid="ignore"): summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels) summary.loc["Recall", "Mean"] = recall_score( true_labels, pred_labels, zero_division=0 ) summary.loc["Precision", "Mean"] = precision_score( true_labels, pred_labels, zero_division=0 ) ( summary.loc["TN", "Mean"], summary.loc["FP", "Mean"], summary.loc["FN", "Mean"], summary.loc["TP", "Mean"], ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel() mask = torch.from_numpy(model.data.labels["z"][: model.data.N, :, model.cdx]) samples = torch.masked_select(model.z_probs[model.data.is_ontarget].cpu(), mask) if len(samples): z_ll, z_ul = hpdi(samples, CI) summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item() summary.loc["p(specific)", "95% LL"] = z_ll.item() summary.loc["p(specific)", "95% UL"] = z_ul.item() else: summary.loc["p(specific)", "Mean"] = 0.0 summary.loc["p(specific)", "95% LL"] = 0.0 summary.loc["p(specific)", "95% UL"] = 0.0 model.summary = summary if path is not None: path = Path(path) torch.save(ci_stats, path / f"{model.full_name}-params.tpqr") if save_matlab: from scipy.io import savemat for param, field in ci_stats.items(): if param in ( "m_probs", "theta_probs", "z_probs", "z_map", "time1", "ttb", ): ci_stats[param] = field.numpy() continue for stat, value in field.items(): ci_stats[param][stat] = value.cpu().numpy() savemat(path / f"{model.full_name}-params.mat", ci_stats) summary.to_csv( path / f"{model.full_name}-summary.csv", )
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences,)
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim

    if not args.raftery_parameterization:
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = pyro.param(
            "probs_x",
            torch.rand(hidden_dim, hidden_dim, hidden_dim),
            constraint=constraints.simplex,
        )
    else:
        # Use the more parsimonious "Raftery" parameterization of
        # the tensor of transition probabilities. See reference:
        # Raftery, A. E. A model for high-order markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param(
            "probs_x1",
            torch.rand(hidden_dim, hidden_dim),
            constraint=constraints.simplex,
        )
        probs_x2 = pyro.param(
            "probs_x2",
            torch.rand(hidden_dim, hidden_dim),
            constraint=constraints.simplex,
        )
        mix_lambda = pyro.param("mix_lambda",
                                torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # we use broadcasting to combine two tensors of shape (hidden_dim, hidden_dim) and
        # (hidden_dim, 1, hidden_dim) to obtain a tensor of shape (hidden_dim, hidden_dim, hidden_dim)
        probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)

    probs_y = pyro.param(
        "probs_y",
        torch.rand(hidden_dim, data_dim),
        constraint=constraints.unit_interval,
    )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # we need to pass the argument `history=2` to `pyro.markov()`
        # since our model is now 2-markov
        for t in pyro.markov(range(lengths.max()), history=2):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x_t),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y_t),
                        obs=sequences[batch, t],
                    )

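# A small standalone check (editorial, with an illustrative hidden_dim of 4) of the broadcasting
# used in the Raftery parameterization above: combining (hidden_dim, hidden_dim) with
# (hidden_dim, 1, hidden_dim) yields a full second-order transition tensor of shape
# (hidden_dim, hidden_dim, hidden_dim).
_p1 = torch.rand(4, 4)  # stands in for probs_x1
_p2 = torch.rand(4, 4)  # stands in for probs_x2
assert (0.5 * _p1 + 0.5 * _p2.unsqueeze(-2)).shape == (4, 4, 4)
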
def ttfb_model(data, control, Tmax):
    r"""
    Eq. 4 and Eq. 7 in::

        @article{friedman2015multi,
          title={Multi-wavelength single-molecule fluorescence analysis of transcription mechanisms},
          author={Friedman, Larry J and Gelles, Jeff},
          journal={Methods},
          volume={86},
          pages={27--36},
          year={2015},
          publisher={Elsevier}
        }

    :param data: time prior to the first binding at the target location
    :param control: time prior to the first binding at the control location
    :param Tmax: entire observation interval
    """
    ka = pyro.param(
        "ka",
        lambda: torch.full((data.shape[0], 1), 0.001),
        constraint=constraints.positive,
    )
    kns = pyro.param(
        "kns",
        lambda: torch.full((data.shape[0], 1), 0.001),
        constraint=constraints.positive,
    )
    Af = pyro.param(
        "Af",
        lambda: torch.full((data.shape[0], 1), 0.9),
        constraint=constraints.unit_interval,
    )
    k = torch.stack([kns, ka + kns])

    # on-target data
    mask = (data < Tmax) & (data > 0)
    tau = data.masked_fill(~mask, 1.0)
    with pyro.plate("bootstrap", data.shape[0], dim=-2) as bdx:
        with pyro.plate("N", data.shape[1], dim=-1):
            active = pyro.sample(
                "active", dist.Bernoulli(Af), infer={"enumerate": "parallel"}
            )
            with handlers.mask(mask=(data == Tmax)):
                pyro.factor("Tmax",
                            -Vindex(k)[active.long().squeeze(-1), bdx] * Tmax)
                # pyro.factor("Tmax", -k * Tmax)
            with handlers.mask(mask=mask):
                pyro.sample(
                    "tau",
                    dist.Exponential(Vindex(k)[active.long().squeeze(-1), bdx]),
                    obs=tau,
                )
                # pyro.sample("tau", dist.Exponential(k), obs=tau)

    # negative control data
    if control is not None:
        mask = (control < Tmax) & (control > 0)
        tauc = control.masked_fill(~mask, 1.0)
        with pyro.plate("bootstrapc", control.shape[0], dim=-2):
            with pyro.plate("Nc", control.shape[1], dim=-1):
                with handlers.mask(mask=(control == Tmax)):
                    pyro.factor("Tmaxc", -kns * Tmax)
                with handlers.mask(mask=mask):
                    pyro.sample("tauc", dist.Exponential(kns), obs=tauc)
