def model_ternary(X, y, drop_last=True, initial_noise_var=1e-4):
    """Set up a GP regression model for a single target.

    Rows whose target is non-finite (NaN/inf) are removed before the model
    is built.

    Args:
        X: composition inputs, one row per sample (indexed as a 2-D tensor).
        y: targets aligned with the rows of ``X``.
        drop_last: if True, discard the last composition column of ``X``.
        initial_noise_var: initial value of the likelihood noise.

    Returns:
        A ``gp.models.GPRegression`` with priors on the kernel lengthscale,
        the kernel variance and the likelihood noise.
    """
    if drop_last:
        # Ignore the last composition column.
        X = X[:, :-1]

    # Keep only observations with a finite target.
    finite = torch.isfinite(y)
    X, y = X[finite], y[finite]

    # ARD RBF (squared-exponential) kernel with input_dim=2 and one
    # lengthscale per input column.  NOTE(review): assumes X has two columns
    # after dropping; confirm for drop_last=False.  gp.kernels.Matern52 takes
    # the same arguments if a rougher kernel is wanted.
    kernel = gp.kernels.RBF(
        input_dim=2,
        variance=torch.tensor(1.),
        lengthscale=torch.tensor([1.0, 1.0]),
    )
    model = gp.models.GPRegression(
        X, y, kernel, noise=torch.tensor(initial_noise_var), jitter=1e-8
    )

    # Weakly-informative lengthscale prior: half-normal(0, dx/3).
    dx = 1.0
    model.kernel.set_prior("lengthscale", dist.HalfNormal(dx / 3))
    model.kernel.set_prior("variance", dist.Gamma(2.0, 1 / 2.0))

    # Noise prior scaled by the variance of the (filtered) observed targets.
    model.set_prior('noise', dist.HalfNormal(model.y.var() / 2))
    return model
def model(self, home_team, away_team):
    """Poisson goals model with correlated per-team attack/defence abilities.

    Each team gets a pair ``(log_a, log_b)`` drawn from a bivariate normal
    with mean ``(0, mu_b)``, standard deviations ``(sigma_a, sigma_b)`` and
    correlation ``rho``.  Expected goals per match are::

        home_rate = exp(log_a[home] + log_b[away] + log_gamma)
        away_rate = exp(log_a[away] + log_b[home])

    Args:
        home_team: home-team identifiers (keys of ``self.team_to_index``),
            one per match.
        away_team: away-team identifiers, aligned with ``home_team``.

    NOTE(review): the ``home_goals``/``away_goals`` sites carry no ``obs``;
    presumably the caller conditions them (e.g. ``pyro.condition``).
    """
    sigma_a = pyro.sample("sigma_a", dist.HalfNormal(1.0))
    sigma_b = pyro.sample("sigma_b", dist.HalfNormal(1.0))
    mu_b = pyro.sample("mu_b", dist.Normal(0.0, 1.0))
    # Map rho_raw in (0, 1) to a correlation in (-1, 1).
    rho_raw = pyro.sample("rho_raw", dist.Beta(2, 2))
    rho = pyro.deterministic("rho", 2.0 * rho_raw - 1.0)
    log_gamma = pyro.sample("log_gamma", dist.Normal(0, 1))

    # Build the mean and covariance with torch.stack, not torch.tensor:
    # torch.tensor([...]) copies the sampled values into a fresh leaf tensor,
    # which cuts the autograd graph so no gradient reaches sigma_a, sigma_b,
    # rho or mu_b.
    cov_ab = rho * sigma_a * sigma_b
    loc = torch.stack([torch.zeros_like(mu_b), mu_b], dim=-1)
    covariance = torch.stack(
        [
            torch.stack([sigma_a ** 2.0, cov_ab], dim=-1),
            torch.stack([cov_ab, sigma_b ** 2.0], dim=-1),
        ],
        dim=-2,
    )

    with pyro.plate("teams", self.n_teams):
        abilities = pyro.sample(
            "abilities",
            dist.MultivariateNormal(loc, covariance_matrix=covariance),
        )

    log_a = abilities[:, 0]
    log_b = abilities[:, 1]

    home_inds = torch.tensor([self.team_to_index[team] for team in home_team])
    away_inds = torch.tensor([self.team_to_index[team] for team in away_team])

    home_rate = torch.exp(log_a[home_inds] + log_b[away_inds] + log_gamma)
    away_rate = torch.exp(log_a[away_inds] + log_b[home_inds])

    pyro.sample("home_goals", dist.Poisson(home_rate))
    pyro.sample("away_goals", dist.Poisson(away_rate))
def guide_3DA(data):
    """Variational guide for the four dihedral-angle sites of ``model_3DA``.

    For every angle it learns the bounds ``a_<angle>``/``b_<angle>`` of a
    uniform guide over the mean ``mu_<angle>``, and a positive scale
    ``x_<angle>`` of a half-normal guide over ``inv_kappa_<angle>``.

    ``data`` is unused; it is accepted so the guide has the same signature
    as the model.

    NOTE(review): the bounds are constrained to (-3.15, 3.15), which is
    slightly wider than (-pi, pi), and nothing enforces a < b.  Confirm that
    guide samples cannot leave the model's Uniform(-pi, pi) support.
    """
    angle_names = ("psi_1", "phi_2", "psi_2", "phi_3")

    # Register every learnable hyperparameter first, angle by angle.
    learned = {}
    for angle in angle_names:
        lower = pyro.param(f"a_{angle}", torch.tensor(-np.pi),
                           constraint=constraints.greater_than(-3.15))
        upper = pyro.param(f"b_{angle}", torch.tensor(np.pi),
                           constraint=constraints.less_than(3.15))
        spread = pyro.param(f"x_{angle}", torch.tensor(2.),
                            constraint=constraints.positive)
        learned[angle] = (lower, upper, spread)

    # Then draw mu and inv_kappa for each angle, in the same order.
    for angle in angle_names:
        lower, upper, spread = learned[angle]
        pyro.sample(f"mu_{angle}", dist.Uniform(lower, upper))
        pyro.sample(f"inv_kappa_{angle}", dist.HalfNormal(spread))
def model(self, X):
    """Low-rank multivariate-normal model with shrinkage on the factor loadings.

    Most param/sample-site names are suffixed with ``self._id`` (the ``'b'``
    and ``'obs'`` sites are not).

    Args:
        X: data matrix, unpacked below as ``N, D = X.shape``.  Rows are drawn
            in minibatches of ``self.batch_size`` through the ``N_{_id}``
            plate.

    Returns:
        The value of the ``'obs'`` site, i.e. the subsampled rows of ``X``.
    """
    _id = self._id
    N, D = X.shape
    # Initial values for the learnable prior hyperparameters, looked up by instance id.
    global_shrinkage_prior_scale_init = self.param_init[f'global_shrinkage_prior_scale_init_{_id}']
    cov_diag_prior_loc_init = self.param_init[f'cov_diag_prior_loc_init_{_id}']
    cov_diag_prior_scale_init = self.param_init[f'cov_diag_prior_scale_init_{_id}']
    global_shrinkage_prior_scale = pyro.param(f'global_shrinkage_scale_prior_{_id}', global_shrinkage_prior_scale_init, constraint=constraints.positive)
    # Global shrinkage tau ~ HalfNormal(learned scale).
    tau = pyro.sample(f'global_shrinkage_{_id}', dist.HalfNormal(global_shrinkage_prior_scale))
    # Auxiliary InverseGamma variable for the local shrinkage.  1./torch.ones(D)**2
    # is a vector of ones, so each b_d ~ InverseGamma(0.5, 1).
    # NOTE(review): site name 'b' has no _{_id} suffix, unlike the other sites;
    # two instances inside one model would collide — confirm.
    b = pyro.sample('b', dist.InverseGamma(0.5,1./torch.ones(D)**2).to_event(1))
    # Local shrinkage lambda_d^2 ~ InverseGamma(0.5, 1/b_d) (presumably the
    # inverse-gamma mixture form of a half-Cauchy on lambda_d — confirm intent).
    lambdasquared = pyro.sample(f'local_shrinkage_{_id}', dist.InverseGamma(0.5,1./b).to_event(1))
    cov_diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', cov_diag_prior_loc_init)
    cov_diag_scale = pyro.param(f'cov_diag_prior_scale_{_id}', cov_diag_prior_scale_init, constraint=constraints.positive)
    # Diagonal (noise) part of the covariance, LogNormal with learnable loc/scale.
    cov_diag = pyro.sample(f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale).to_event(1))
    #cov_diag = cov_diag*torch.ones(D)
    # `jitter` is not defined in this block; assumed to be a small module-level constant.
    cov_diag = cov_diag + jitter
    # Drop singleton dims so the branch below can tell unbatched (1-D) from batched input.
    lambdasquared = lambdasquared.squeeze()
    if lambdasquared.dim() == 1:
        # outer product: scale[d, r] = sqrt(lambda_d^2) * tau, with tau tiled D times
        # along its last axis (for a 0-dim tau this gives a (D, D) matrix).
        cov_factor_scale = torch.ger(torch.sqrt(lambdasquared),tau.repeat((tau.dim()-1)*(1,)+(D,)))
    else:
        # batch outer product over a leading batch axis
        cov_factor_scale = torch.einsum('bp, br->bpr', torch.sqrt(lambdasquared),tau.repeat((tau.dim()-1)*(1,)+(D,)))
    # Factor loadings ~ Normal(0, shrinkage scale), treated as one 2-D event.
    cov_factor = pyro.sample(f'cov_factor_{_id}', dist.Normal(0., cov_factor_scale).to_event(2))
    cov_factor = cov_factor.transpose(-2,-1)
    # Minibatch the rows of X; `ind` holds the subsampled row indices.
    with pyro.plate(f'N_{_id}', size=N, subsample_size=self.batch_size, dim=-1) as ind:
        X = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
    return X
def __init__(self, ode_op, ode_model):
    """Keep the ODE operator/model and declare priors on three ODE parameters.

    Args:
        ode_op: operator whose ``apply`` integrates the ODE.
        ode_model: ODE model object passed to ``ode_op``.
    """
    super().__init__()
    self._ode_op, self._ode_model = ode_op, ode_model

    # PyroSample attributes: each one is a prior that Pyro samples when the
    # attribute is read.
    self.ode_params1 = PyroSample(dist.Beta(2, 1))
    self.ode_params2 = PyroSample(dist.HalfNormal(1))
    self.ode_params3 = PyroSample(dist.Beta(1, 2))
def model(self, data): max_var, data1, data2 = data ### 1. prior over mean M M = pyro.sample( "M", dist.StudentT(1, 0, 3).expand_by([data1.size(0), data1.size(1)]).to_event(2)) ### 2. Prior over variances for the normal distribution U = pyro.sample( "U", dist.HalfNormal(1).expand_by([data1.size(0)]).to_event(1)) U = U.reshape(data1.size(0), 1).repeat(1, 3).view( -1) #Triplicate the rows for the subsequent mean calculation ## 3. prior over translations T_i: Sample translations for each of the x,y,z coordinates T2 = pyro.sample("T2", dist.Normal(0, 1).expand_by([3]).to_event(1)) ## 4. prior over rotations R_i ri_vec = pyro.sample("ri_vec", dist.Uniform(0, 1).expand_by( [3]).to_event(1)) # Uniform distribution R = self.sample_R(ri_vec) M_T1 = M M_R2_T2 = M @ R + T2 # 5. Likelihood with pyro.plate("plate_univariate", data1.size(0) * data1.size(1), dim=-1): pyro.sample("X1", dist.StudentT(1, M_T1.view(-1), U), obs=data1.view(-1)) pyro.sample("X2", dist.StudentT(1, M_R2_T2.view(-1), U), obs=data2.view(-1))
def model_normal(X, y, column_names):
    """Bayesian linear regression with a Normal likelihood.

    ``mean = intercept + sum_i X[:, i] * beta_i``, and ``y ~ Normal(mean, sigma)``
    independently for each observation.

    Args:
        X: design matrix; column ``i`` is scaled by the coefficient sampled
            at site ``"beta_<column_names[i]>"``.
        y: observed targets (conditioned on at site ``"obs"``); ``y.shape[0]``
            sets the size of the data plate.
        column_names: labels for the per-column coefficient sites.  Must
            have at least ``X.shape[1]`` entries.
    """
    # Intercept with a standard-normal prior.
    linear_combination = pyro.sample("beta_intercept", dist.Normal(0.0, 1.0))

    # One standard-normal coefficient per column of X.
    for i in range(X.shape[1]):
        beta_coef = pyro.sample(f"beta_{column_names[i]}", dist.Normal(0.0, 1.0))
        linear_combination = linear_combination + (X[:, i] * beta_coef)

    # Scale of the random error.
    sigma = pyro.sample("sigma", dist.HalfNormal(scale=10.0))

    # Each observation is Normal around the linear predictor, with standard
    # deviation sigma, conditioned on the observed y.
    with pyro.plate("data", y.shape[0]):
        pyro.sample("obs", dist.Normal(linear_combination, sigma), obs=y)
def forward(self):
    """Zero-inflated negative-binomial model of ``self.gene_expr``.

    The log expression rate combines three weighted terms:
    - upstream weights, decayed with distance;
    - downstream weights, decayed with distance;
    - summed promoter weights;
    plus an offset ``b``.  The rate is scaled by ``self.read_depth``.  Rows
    are subsampled 64 at a time from ``self.N``.

    All site names are prefixed with ``self.name``.
    """
    def RP(weights, distances, d):
        # Sum of weights decayed by 0.5 ** (distance / (1e3 * d)), scaled by 1e4.
        return 1e4 * (weights * torch.pow(0.5, distances / (1e3 * d))).sum(-1)

    # One non-negative weight per term: upstream, downstream, promoter.
    with pyro.plate(self.name + "_regions", 3):
        a = pyro.sample(self.name + "_a", dist.HalfNormal(12.))

    # Decay distances (sampled on log scale) for upstream and downstream.
    with pyro.plate(self.name + "_upstream-downstream", 2):
        d = torch.exp(pyro.sample(self.name + '_logdistance', dist.Normal(np.e, 2.)))

    b = pyro.sample(self.name + "_b", dist.Normal(-10., 3.))
    theta = pyro.sample(self.name + "_theta", dist.Gamma(2., 0.5))
    psi = pyro.sample(self.name + "_dropout", dist.Beta(1., 10.))

    with pyro.plate(self.name + "_data", self.N, subsample_size=64) as ind:
        expr_rate = a[0] * RP(self.upstream_weights.index_select(0, ind), self.upstream_distances, d[0]) \
            + a[1] * RP(self.downstream_weights.index_select(0, ind), self.downstream_distances, d[1]) \
            + a[2] * 1e4 * self.promoter_weights.index_select(0, ind).sum(-1) \
            + b

        mu = torch.multiply(self.read_depth.index_select(0, ind), torch.exp(expr_rate))
        # probs = mu / (mu + theta) gives a negative binomial with mean mu.
        # Cap it just below 1.  clamp(max=...) is device-agnostic; the previous
        # torch.minimum(..., torch.tensor([0.99999])) mixed in a 1-element CPU
        # tensor and failed when the model ran on CUDA.
        p = torch.clamp(mu / (mu + theta), max=0.99999)
        pyro.sample(self.name + '_obs',
                    dist.ZeroInflatedNegativeBinomial(total_count=theta, probs=p, gate=psi),
                    obs=self.gene_expr.index_select(0, ind))
def model(X, Y, hypers, jitter=1.0e-4):
    """Sparse kernel regression: a GP marginal likelihood over ``Y``.

    The covariance is ``kernel(kappa * X, kappa * X, eta1, eta2, c)`` plus
    ``(sigma**2 + jitter)`` on the diagonal.  ``kappa`` is a per-dimension
    shrinkage built from half-Cauchy local scales ``lambda``.

    Args:
        X: inputs indexed as (N, P) via ``X.size(0)``/``X.size(1)``.
        Y: observed outputs of length N.
        hypers: dict providing 'expected_sparsity', 'alpha1', 'beta1',
            'alpha2', 'beta2', 'alpha3' and 'c'.
        jitter: extra diagonal term added for numerical stability.

    ``kernel`` is defined elsewhere in the module.
    """
    device = X.device
    n_obs, n_dims = X.size(0), X.size(1)
    expected_sparsity = hypers['expected_sparsity']

    sigma = pyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    global_scale = sigma * (expected_sparsity / math.sqrt(n_obs)) / (n_dims - expected_sparsity)
    eta1 = pyro.sample("eta1", dist.HalfCauchy(global_scale))

    msq = pyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = pyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))
    eta2 = eta1.pow(2.0) * xisq.sqrt() / msq

    local_scales = dist.HalfCauchy(torch.ones(n_dims, device=device)).to_event(1)
    lam = pyro.sample("lambda", local_scales)
    kappa = msq.sqrt() * lam / (msq + (eta1 * lam).pow(2.0)).sqrt()

    scaled_X = kappa * X
    diagonal = (sigma**2 + jitter) * torch.eye(n_obs, device=device)
    covariance = kernel(scaled_X, scaled_X, eta1, eta2, hypers['c']) + diagonal

    likelihood = dist.MultivariateNormal(torch.zeros(n_obs, device=device), covariance_matrix=covariance)
    pyro.sample("Y", likelihood, obs=Y)
def forward(self, data):
    """Integrate the ODE and score each row of ``data`` with a Normal likelihood.

    ``self.ode_params`` is reshaped into rows of 6 parameters before being
    handed to ``self._ode_op.apply`` together with ``self._ode_model``.
    Row ``i`` of ``data`` is observed at site ``obs_<i>``, centred on
    ``states[..., i, :]``.

    Returns:
        The simulated states.
    """
    noise_scale = pyro.sample("scale", dist.HalfNormal(0.1))
    # Flatten the scale into a column so it broadcasts against each state row.
    noise_sd = noise_scale.view((-1, )).unsqueeze(1)

    n_ode_params = 6
    flat_params = self.ode_params.view((-1, n_ode_params))
    states = self._ode_op.apply(flat_params, (self._ode_model,))

    for step in range(len(data)):
        likelihood = dist.Normal(loc=states[..., step, :], scale=noise_sd).to_event(1)
        pyro.sample("obs_{}".format(step), likelihood, obs=data[step, :])
    return states
def model_3DA(data):
    """Von Mises model for four dihedral angles of a short peptide.

    Each angle gets a uniform prior on its mean over (-pi, pi), and a
    concentration ``kappa = 100 + 1 / inv_kappa`` with
    ``inv_kappa ~ HalfNormal(1)`` (so kappa is always above 100).

    Args:
        data: indexed as ``data[residue, :, column]``.  Column 0 feeds the
            phi angles and column 1 the psi angles.
    """
    # (angle name, residue index, column) for every observed angle.
    observed_angles = (
        ("psi_1", 0, 1),
        ("phi_2", 1, 0),
        ("psi_2", 1, 1),
        ("phi_3", 2, 0),
    )

    # Draw mu and inv_kappa angle by angle.
    locations, concentrations = {}, {}
    for angle, _, _ in observed_angles:
        locations[angle] = pyro.sample(f"mu_{angle}", dist.Uniform(-np.pi, np.pi))
        inv_kappa = pyro.sample(f"inv_kappa_{angle}", dist.HalfNormal(1.))
        concentrations[angle] = 100 + 1/inv_kappa

    # Observations are conditionally independent given mu and kappa.
    with pyro.plate('dihedral_angles'):
        for angle, residue, column in observed_angles:
            pyro.sample(f"obs_{angle}",
                        dist.VonMises(locations[angle], concentrations[angle]),
                        obs=data[residue, :, column])
def summary_prior():
    """Plot histograms of the priors on F_diag, F_rest, G, H and b.

    Three panels: F_diag/F_rest on [-1, 2], G on [-0.75, 0.75], and H/b on
    [-1, 4].  Each histogram uses 50,000 draws.
    """
    n_draws = 50_000

    def _hist(prior, value_range, **kwargs):
        # Draw all samples in one vectorized call instead of 50,000
        # Python-level .sample() calls.
        draws = prior.sample((n_draws,)).numpy()
        plt.hist(draws, range=value_range, bins=200, density=True, **kwargs)

    plt.figure(figsize=(12, 2.5))

    plt.subplot(131)
    _hist(dist.Beta(2.5, 2.5), (-1, 2), alpha=0.7)  # F_diag
    _hist(dist.Normal(0.0, 0.1), (-1, 2), alpha=0.7)  # F_rest
    plt.legend(['F_diag', 'F_rest'])
    plt.xlim(-1, 2)

    plt.subplot(132)
    _hist(dist.Normal(0.0, 0.1), (-0.75, 0.75))  # G
    plt.legend(['G'])
    plt.xlim(-0.75, 0.75)

    plt.subplot(133)
    _hist(dist.HalfNormal(1.0), (-1, 4), alpha=0.7)  # H
    _hist(dist.Normal(0.0, 0.2), (-1, 4), alpha=0.7)  # b
    plt.legend(['H', 'b'])
    plt.xlim(-1, 4)
    plt.show()
def model_gamma(X, y, column_names):
    """Bayesian GLM with a log link and a Gamma likelihood.

    ``mean = exp(intercept + sum_i X[:, i] * beta_i)``.  Each observation
    is ``Gamma(shape=mean * rate, rate)``, so its expected value is
    ``shape / rate = mean``.

    Args:
        X: design matrix with a floating-point dtype (its ``finfo`` sets the
            clamping bounds).
        y: observed positive targets (conditioned on at site ``"obs"``).
        column_names: labels for the per-column coefficient sites.  Must
            have at least ``X.shape[1]`` entries.
    """
    # NOTE: this sets Pyro's global validation flag on every call.
    pyro.enable_validation(True)
    # Bounds that keep the Gamma parameters strictly positive and finite.
    min_value = torch.finfo(X.dtype).eps
    max_value = torch.finfo(X.dtype).max

    # Linear predictor: standard-normal intercept and coefficients.
    linear_combination = pyro.sample("beta_intercept", dist.Normal(0.0, 1.0))
    for i in range(X.shape[1]):
        beta_coef = pyro.sample(f"beta_{column_names[i]}", dist.Normal(0.0, 1.0))
        linear_combination = linear_combination + (X[:, i] * beta_coef)

    # Log link: mean = e^{linear combination}, clamped into (0, max].
    mean = torch.exp(linear_combination).clamp(min=min_value, max=max_value)

    # Rate parameter, clamped away from zero.
    rate = pyro.sample("rate", dist.HalfNormal(scale=10.0)).clamp(min=min_value)

    # Since mean = shape / rate, shape = mean * rate.
    # NOTE(review): this product can still overflow to inf when mean is near
    # max_value and rate > 1.
    shape = mean * rate

    # Condition the Gamma likelihood on the observations.
    with pyro.plate("data", y.shape[0]):
        pyro.sample("obs", dist.Gamma(shape, rate), obs=y)
def forward(self, data):
    """Simulate the ODE with sampled parameters and score each row of ``data``.

    ``self.ode_params1`` and ``self.ode_params2`` are flattened and stacked
    column-wise into an (n, 2) parameter tensor.  Row ``i`` of ``data`` is
    observed at site ``obs_<i>`` under a Normal centred on
    ``simple_sim[..., i, :]``.

    If building or observing a step's Normal raises an ordinary exception,
    that step is reported and skipped (best-effort).  NOTE(review): a skipped
    step contributes nothing to the likelihood.

    Returns:
        The simulated trajectory.
    """
    scale = pyro.sample("scale", dist.HalfNormal(0.001))
    # Flatten the scale into a column so it broadcasts against each state row.
    sd = scale.view((-1, )).unsqueeze(1)

    p1 = self.ode_params1.view((-1, ))
    p2 = self.ode_params2.view((-1, ))
    ode_params = torch.stack([p1, p2], dim=1)
    simple_sim = self._ode_op.apply(ode_params, (self._ode_model, ))

    for i in range(len(data)):
        try:
            # TODO: Which distribution to use? (an Exponential likelihood was also considered)
            pyro.sample("obs_{}".format(i),
                        dist.Normal(loc=simple_sim[..., i, :], scale=sd).to_event(1),
                        obs=data[i, :])
        except Exception as err:
            # Previously a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; those now propagate.
            print(simple_sim[..., i, :])
            print("ERROR (invalid parameter for Normal...!): ")
            print(err)
    return simple_sim
def model(self, home_team, away_team, gameweek):
    """Poisson goals model with attack/defence ratings that random-walk over gameweeks.

    Each team's log attack ``log_a`` and log defence ``log_b`` start from
    per-team initial values.  They then accumulate Normal(0, sigma_rw)
    increments, one per later gameweek, with ``sigma_rw`` sampled per team.

    Args:
        home_team: home-team identifiers (keys of ``self.team_to_index``), one per match.
        away_team: away-team identifiers, aligned with ``home_team``.
        gameweek: gameweek index of each match.  Used both to size the
            random walk (``max(gameweek) + 1``) and to index the ratings.
            Presumably 0-based integers — confirm.

    NOTE(review): ``home_goals``/``away_goals`` carry no ``obs`` here;
    presumably the caller conditions them — confirm.
    """
    n_gameweeks = max(gameweek) + 1
    # Home advantage; it is added directly to the home log-rate below.
    gamma = pyro.sample("gamma", dist.LogNormal(0, 1))
    mu_b = pyro.sample("mu_b", dist.Normal(0, 1))
    with pyro.plate("teams", self.n_teams):
        log_a0 = pyro.sample("log_a0", dist.Normal(0, 1))
        log_b0 = pyro.sample("log_b0", dist.Normal(mu_b, 1))
        sigma_rw = pyro.sample("sigma_rw", dist.HalfNormal(0.1))
        # Nested plate: increments are laid out (gameweek - 1, team), matching the
        # axis-0 concatenation with log_a0[None, :] below.
        with pyro.plate("random_walk", n_gameweeks - 1):
            diffs_a = pyro.sample("diff_a", dist.Normal(0, sigma_rw))
            diffs_b = pyro.sample("diff_b", dist.Normal(0, sigma_rw))
    # Prepend the initial values as gameweek 0 (adding the leading axis only when it
    # is missing), then cumulative-sum along gameweeks to get the walk.
    log_a0_t = log_a0 if log_a0.dim() == 2 else log_a0[None, :]
    diffs_a = torch.cat((log_a0_t, diffs_a), axis=0)
    log_a = torch.cumsum(diffs_a, axis=0)
    log_b0_t = log_b0 if log_b0.dim() == 2 else log_b0[None, :]
    diffs_b = torch.cat((log_b0_t, diffs_b), axis=0)
    log_b = torch.cumsum(diffs_b, axis=0)
    # Record the deterministic trajectories in the trace (Delta observed at its own value).
    pyro.sample("log_a", dist.Delta(log_a), obs=log_a)
    pyro.sample("log_b", dist.Delta(log_b), obs=log_b)
    home_inds = torch.tensor(
        [self.team_to_index[team] for team in home_team])
    away_inds = torch.tensor(
        [self.team_to_index[team] for team in away_team])
    # Log-rates are clamped to [-7, 2] before exponentiating.
    home_rate = torch.clamp(
        log_a[gameweek, home_inds] - log_b[gameweek, away_inds] + gamma, -7, 2)
    away_rate = torch.clamp(
        log_a[gameweek, away_inds] - log_b[gameweek, home_inds], -7, 2)
    pyro.sample("home_goals", dist.Poisson(torch.exp(home_rate)))
    pyro.sample("away_goals", dist.Poisson(torch.exp(away_rate)))
def update_posterior(model, x_new=None, y_new=None, lr=1e-3, num_steps=150, optimize_noise_variance=True):
    """Optionally add observations to a GP model, re-draw its hyperparameters, and retrain it.

    Args:
        model: Pyro GP model exposing ``X``, ``y``, ``set_data``,
            ``set_prior`` and a ``kernel`` with variance/lengthscale priors
            already set.
        x_new: new input(s).  A 1-D tensor is treated as a single point.
        y_new: new target(s); appended only when ``x_new`` is also given.
        lr: Adam learning rate.
        num_steps: number of training steps passed to ``gp.util.train``.
        optimize_noise_variance: if True, re-draw the noise from its prior
            and train it.  Otherwise, parameters whose name contains
            'noise' are left out of the optimizer.

    Returns:
        The losses returned by ``gp.util.train``.
    """
    if x_new is not None and y_new is not None:
        if x_new.ndimension() == 1:
            # Promote a single point to a one-row batch.
            x_new = x_new.unsqueeze(0)
        X = torch.cat([model.X, x_new])
        # y = torch.cat([model.y, y_new.squeeze(1)])
        y = torch.cat([model.y, y_new])
        model.set_data(X, y)
    # update model noise prior based on variance of observed data
    model.set_prior('noise', dist.HalfNormal(model.y.var()))
    # reinitialize hyperparameters from prior
    # (calling a Pyro distribution draws a sample; the lengthscale draw takes the
    # current lengthscale's size as its sample shape)
    p = model.kernel._priors
    model.kernel.variance = p['variance']()
    model.kernel.lengthscale = p['lengthscale'](
        model.kernel.lengthscale.size())
    if optimize_noise_variance:
        model.noise = model._priors['noise']()
        optimizer = optim.Adam(model.parameters(), lr=lr)
    else:
        # Exclude noise parameters (matched by name) from optimization.
        optimizer = optim.Adam([
            param for name, param in model.named_parameters()
            if 'noise' not in name
        ], lr=lr)
    losses = gp.util.train(model, optimizer, num_steps=num_steps)
    return losses
def sds(one_sd_scale, two_sd_scale, three_sd_scale):
    """Sample three half-normal standard deviations.

    Returns:
        dict mapping 'one', 'two' and 'three' to the values sampled at sites
        'one_sd', 'two_sd' and 'three_sd' (drawn in that order).
    """
    return {
        'one': pyro.sample('one_sd', dst.HalfNormal(one_sd_scale)),
        'two': pyro.sample('two_sd', dst.HalfNormal(two_sd_scale)),
        'three': pyro.sample('three_sd', dst.HalfNormal(three_sd_scale)),
    }
# Setting kappa parameters phi_2_inv_kappa_model = phi_2_scale phi_3_inv_kappa_model = phi_3_scale psi_1_inv_kappa_model = psi_1_scale psi_2_inv_kappa_model = psi_2_scale # Initialising distributions mu_phi_2_dist_model = dist.Uniform(phi_2_mu_start_model, phi_2_mu_end_model) mu_phi_3_dist_model = dist.Uniform(phi_3_mu_start_model, phi_3_mu_end_model) mu_psi_1_dist_model = dist.Uniform(psi_1_mu_start_model, psi_1_mu_end_model) mu_psi_2_dist_model = dist.Uniform(psi_2_mu_start_model, psi_2_mu_end_model) inv_kappa_phi_2_dist_model = dist.HalfNormal(phi_2_inv_kappa_model) inv_kappa_phi_3_dist_model = dist.HalfNormal(phi_3_inv_kappa_model) inv_kappa_psi_1_dist_model = dist.HalfNormal(psi_1_inv_kappa_model) inv_kappa_psi_2_dist_model = dist.HalfNormal(psi_2_inv_kappa_model) # ### Sampling from Von Mises (model) # In[24]: def Sampler_VM_model(n_samples, n_AA, mus, kappas, bias=0, used_angles=[True, True, True]): ''' n_samples -> number of samples. n_AA -> length of the peptides in AA.
def halfnormal(stdev, **kwargs):
    """Sample from a half-normal prior of scale ``stdev``.

    Extra keyword arguments are forwarded unchanged to the module's
    ``sample`` helper.
    """
    prior = dist.HalfNormal(stdev)
    return sample(prior, **kwargs)
def model(u_seq, z_seq, batch_size=None):
    """Linear dynamical system with inputs, fit to variable-length sequences.

    Per time step: ``x_t ~ Normal(F @ x_{t-1} + G @ u_t, Q)`` and
    ``z_t ~ Normal(H @ x_t + b, R)``, with ``x`` starting at zero.
    Sequences are padded to the longest one; steps past a sequence's own
    length are masked out.

    Args:
        u_seq: list of input sequences (tensors indexed by time first).
        z_seq: list of observation sequences, aligned with ``u_seq``.
        batch_size: if given, subsample this many sequences per call.

    Relies on module-level names not defined here: ``device``, ``udim``,
    ``xdim``, ``zdim``, ``jit`` and ``nn``.
    """
    # Move input and output to model device
    u_seq = list(map(lambda x: x.to(device), u_seq))
    z_seq = list(map(lambda x: x.to(device), z_seq))
    # Process sequence lengths
    lengths = torch.tensor(list(map(len, u_seq)), device=device)
    num_sequences, max_length = len(lengths), lengths.max()

    def pad_fn(x):
        # Zero-pad the time axis (second-to-last) at the end up to max_length.
        # Assumes each sequence is at least 2-D, (time, features) — TODO confirm.
        return nn.functional.pad(x, (0, 0, 0, max_length - len(x)))

    # Stack into (num_sequences, max_length, features, 1).
    u_seq = torch.stack(list(map(pad_fn, u_seq)), dim=0).unsqueeze(-1)
    z_seq = torch.stack(list(map(pad_fn, z_seq)), dim=0).unsqueeze(-1)
    # Plates over matrix rows (dim=-2) / columns (dim=-1), and over sequences (dim=-3).
    plate_u1 = pyro.plate("plate_u1", udim, dim=-1)
    plate_x1 = pyro.plate("plate_x1", xdim, dim=-1)
    plate_x2 = pyro.plate("plate_x2", xdim, dim=-2)
    plate_z1 = pyro.plate("plate_z1", zdim, dim=-1)
    plate_z2 = pyro.plate("plate_z2", zdim, dim=-2)
    plate_seq = pyro.plate("plate_seq", num_sequences, batch_size, dim=-3)
    # Parameter priors
    # with poutine.mask(mask=True):
    # Noise variance (fixed, not learned): 0.1 per state / observation dimension, as columns.
    Q = 0.1 * torch.ones(xdim, device=device).unsqueeze(-1)
    R = 0.1 * torch.ones(zdim, device=device).unsqueeze(-1)
    # State transition matrix: Beta(2.5, 2.5) diagonal, Normal(0, 0.1) off-diagonal.
    with plate_x1:
        F_diag = pyro.sample(
            "F_diag",
            dist.Beta(torch.tensor(2.5, device=device),
                      torch.tensor(2.5, device=device)))
        # NOTE(review): nesting of plate_x2 inside plate_x1 reconstructed; if F_rest is
        # sampled under plate_x2 alone it is (xdim, 1) and each row shares one off-diagonal value.
        with plate_x2:
            F_rest = pyro.sample(
                "F_rest",
                dist.Normal(torch.tensor(0.0, device=device),
                            torch.tensor(0.1, device=device)))
    mask_diag = torch.eye(xdim, device=device)
    F = F_diag * mask_diag + F_rest * (1 - mask_diag)
    # Input filter matrix, (xdim, udim)
    with plate_x2, plate_u1:
        G = pyro.sample(
            "G",
            dist.Normal(torch.tensor(0.0, device=device),
                        torch.tensor(0.1, device=device)))
    # Measurement matrix, (zdim, xdim), non-negative entries
    with plate_z2, plate_x1:
        H = pyro.sample("H", dist.HalfNormal(torch.tensor(1.0, device=device)))
    # Measurement bias, reshaped to a (zdim, 1) column
    with plate_z1:
        b = pyro.sample(
            "b",
            dist.Normal(torch.tensor(0.0, device=device),
                        torch.tensor(0.2, device=device))).unsqueeze(-1)
    # We subsample batch_size items out of num_sequences items.
    with plate_seq as batch:
        lengths = lengths[batch]
        num_sequences = num_sequences if batch_size is None else batch_size
        x = torch.zeros((num_sequences, xdim, 1), device=device)
        # Under jit, loop to the padded length (fixed trip count); otherwise
        # only to the longest sequence in this batch.
        for t in pyro.markov(range(max_length if jit else lengths.max())):
            # Mask out steps beyond each sequence's own length.
            with pyro.poutine.mask(
                    mask=(t < lengths).unsqueeze(-1).unsqueeze(-1)):
                x = pyro.sample(
                    f"x_{t}", dist.Normal(F @ x + G @ u_seq[batch, t], Q))
                with plate_z2:
                    pyro.sample(f"z_{t}", dist.Normal(H @ x + b, R),
                                obs=z_seq[batch, t])
def HalfNormalFromInterval(high):
    """Build a HalfNormal whose 90th percentile lies at ``high``.

    For a half-normal with scale s, P(X <= x) = 2 * Phi(x / s) - 1.  Setting
    this to 0.9 gives Phi(x / s) = 0.95, so x / s is about 1.645 and
    s = high / 1.645.
    """
    return dist.HalfNormal(high / 1.645)