def _dynamics(self, features):
    """Compute dynamics parameters from time features.

    Builds the pieces of a linear-Gaussian state-space model:

    :param features: time-feature tensor fed to ``self.nn``; its last dim is
        consumed by the network.  (shape otherwise unconstrained here)
    :returns: tuple ``(init_dist, trans_matrix, trans_dist, obs_matrix,
        obs_dist)`` — initial-state distribution, transition matrix and noise,
        observation matrix, and the per-time-step observation distribution.
    """
    state_dim = self.args.state_dim
    # Two values (gate and rate) per ordered station pair.
    gate_rate_dim = 2 * self.num_stations**2
    init_loc = torch.zeros(state_dim)
    # Learned diagonal scale for the initial state, kept positive.
    init_scale_tril = pyro.param(
        "init_scale", torch.full((state_dim, ), 10.),
        constraint=constraints.positive).diag_embed()
    init_dist = dist.MultivariateNormal(init_loc, scale_tril=init_scale_tril)
    # Transition initialized near-identity (0.99 * I) for slowly-varying state.
    trans_matrix = pyro.param("trans_matrix", 0.99 * torch.eye(state_dim))
    trans_loc = torch.zeros(state_dim)
    trans_scale_tril = pyro.param(
        "trans_scale", 0.1 * torch.ones(state_dim),
        constraint=constraints.positive).diag_embed()
    trans_dist = dist.MultivariateNormal(trans_loc, scale_tril=trans_scale_tril)
    obs_matrix = pyro.param("obs_matrix", torch.randn(state_dim, gate_rate_dim))
    # In-place row normalization of the observation matrix (bypasses autograd
    # deliberately via .data).
    obs_matrix.data /= obs_matrix.data.norm(dim=-1, keepdim=True)
    loc_scale = self.nn(features)
    # Split the network output into loc and (pre-activation) scale halves.
    loc, scale = loc_scale.reshape(loc_scale.shape[:-1] + (2, gate_rate_dim)).unbind(-2)
    # Bounded exp keeps the scale positive but capped at 10.
    scale = bounded_exp(scale, bound=10.)
    obs_dist = dist.Normal(loc, scale).to_event(1)
    return init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist
def model(self):
    """GP model: sample the latent function values f at the training inputs X.

    Supports both the whitened parameterization (sample in identity-covariance
    space, then map through the kernel Cholesky factor) and the direct one.
    :returns: ``(f_loc, f_var)`` when ``self.y`` is None, otherwise the
        likelihood applied to the observed targets.
    """
    self.set_mode("model")
    N = self.X.size(0)
    Kff = self.kernel(self.X).contiguous()
    Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
    Lff = Kff.cholesky()
    zero_loc = self.X.new_zeros(self.f_loc.shape)
    if self.whiten:
        # Whitened: the sample site has identity scale; f_loc/f_scale_tril are
        # transported back through Lff.
        identity = eye_like(self.X, N)
        pyro.sample(
            self._pyro_get_fullname("f"),
            dist.MultivariateNormal(
                zero_loc, scale_tril=identity).to_event(zero_loc.dim() - 1))
        f_scale_tril = Lff.matmul(self.f_scale_tril)
        f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
    else:
        pyro.sample(
            self._pyro_get_fullname("f"),
            dist.MultivariateNormal(
                zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1))
        f_scale_tril = self.f_scale_tril
        f_loc = self.f_loc
    f_loc = f_loc + self.mean_function(self.X)
    # Marginal variances: row-wise squared norms of the Cholesky factor.
    f_var = f_scale_tril.pow(2).sum(dim=-1)
    if self.y is None:
        return f_loc, f_var
    else:
        return self.likelihood(f_loc, f_var, self.y)
def model(data):
    """Two-component 2-D Gaussian mixture model with data subsampling.

    :param torch.Tensor data: tensor of shape (num_points, 2).

    Fixes over the previous version:
    - ``data.index_select(ind)`` was missing the mandatory ``dim`` argument
      (TypeError); now ``index_select(0, ind)``.
    - ``locs`` was a Python list, which cannot be indexed by the
      ``assignment`` tensor; it is now stacked into a (K, 2) tensor.
    - The Bernoulli batch was sized ``len(data)`` although the plate
      subsamples 4 points, so assignments did not line up with the observed
      minibatch; it is now sized by the subsample indices.
    - Deprecated ``pyro.iarange`` replaced by ``pyro.plate`` (consistent with
      the rest of this file).
    """
    # Global variables.
    weights = pyro.param(
        "weights", torch.FloatTensor([0.5]), constraint=constraints.unit_interval
    )
    scales = pyro.param(
        "scales", torch.stack([torch.eye(2), torch.eye(2)]),
        constraint=constraints.positive
    )
    # Component means, stacked so they can be gathered by assignment tensor.
    locs = torch.stack([
        pyro.sample(
            "locs_{}".format(k),
            dist.MultivariateNormal(torch.zeros(2), 2 * torch.eye(2))
        )
        for k in range(K)
    ])
    with pyro.plate("data", data.size(0), 4) as ind:
        # Local variables: one component assignment per subsampled point.
        assignment = pyro.sample(
            "assignment", dist.Bernoulli(torch.ones(len(ind)) * weights)
        ).to(torch.int64)
        pyro.sample(
            "obs",
            dist.MultivariateNormal(locs[assignment], scales[assignment]),
            obs=data.index_select(0, ind)
        )
def model(self):
    """Gaussian mixture model: Dirichlet mixing weights, one mean and one
    isotropic scale per component, categorical assignment per data row."""
    # Global variables
    mix_weights = pyro.sample('weights',
                              dist.Dirichlet(0.5 * torch.ones(self.n_comp)))
    dim = self.shape[1]
    with pyro.plate('components', self.n_comp):
        comp_means = pyro.sample('locs', dist.MultivariateNormal(
            torch.zeros(dim), torch.eye(dim))
        )
        comp_scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    # One isotropic covariance matrix per component.
    covariances = torch.stack(
        [torch.eye(dim) * comp_scale[k] for k in range(self.n_comp)])
    with pyro.plate('data', self.shape[0]):
        # Local variables.
        comp_idx = pyro.sample('assignment', dist.Categorical(mix_weights))
        pyro.sample('obs',
                    dist.MultivariateNormal(comp_means[comp_idx],
                                            covariances[comp_idx]),
                    obs=self.tensor_train)
def guide(self):
    """Variational guide: one multivariate-normal posterior per GP function
    value, for the coord block and (optionally) the soc/pop block."""
    self.set_mode("guide")
    if self.coord:
        coord_gp = self.gp.coord
        for net in range(self.V_net):
            for layer in range(self.H_dim):
                for sr in range(self.sr_dim):
                    for lw in range(self.lw_dim):
                        loc = coord_gp.loc[net, layer, sr, lw, :]
                        tril = coord_gp.cov_tril[net, layer, sr, lw, :, :]
                        pyro.sample(
                            f'f_coord_v{net}_h{layer}_sr{sr}_lw{lw}',
                            dist.MultivariateNormal(
                                loc, scale_tril=tril).to_event(loc.dim() - 1))
    if self.socpop:
        socpop_gp = self.gp.socpop
        for net in range(self.V_net):
            for sp in range(2):
                for lw in range(self.lw_dim):
                    loc = socpop_gp.loc[net, sp, lw, :]
                    tril = socpop_gp.cov_tril[net, sp, lw, :, :]
                    pyro.sample(
                        f'f_socpop_v{net}_{["soc","pop"][sp]}_lw{lw}',
                        dist.MultivariateNormal(
                            loc, scale_tril=tril).to_event(loc.dim() - 1))
def model(self, seq):
    """Bayesian next-token sequence model.

    Samples a shared embedding mean ``muV`` and per-item embeddings ``V``,
    lifts them (plus the linear bias) into the network via
    ``pyro.random_module``, then scores shifted next-token targets on a
    subsampled minibatch.

    :param seq: integer token tensor of shape (num_sequences, seq_len);
        0 appears to be the padding token — TODO confirm with data pipeline.
    :returns: the lifted (sampled-parameter) model instance.
    """
    mu0 = torch.zeros(self.emb_dim).to(self.device)
    tri0 = self.tri0  # create this when initializing. (takes 4ms each time!)
    # Global mean of the item-embedding prior.
    muV = pyro.sample("muV", dist.MultivariateNormal(loc=mu0, scale_tril=tri0))
    with plate("item_loop", self.num_items):
        # One embedding row per item, centered on muV.
        V = pyro.sample(f"V", dist.MultivariateNormal(muV, scale_tril=tri0))
    # LIFT MODULE: replace these module parameters with the sampled values.
    prior = {
        'linear.bias': dist.Normal(0, 1),
        'V.weight': Deterministic_distr(V)
    }
    lifted_module = pyro.random_module("net", self, prior=prior)
    lifted_reg_model = lifted_module()
    lifted_reg_model.lstm.flatten_parameters()
    with pyro.plate("data", len(seq), subsample_size=self.batch_size) as ind:
        batch_seq = seq[ind, ]
        x = batch_seq[:, :-1]  # inputs: all tokens but the last
        y = batch_seq[:, 1:]   # targets: next-token shift of the inputs
        # Mask out padding positions (target == 0) from the likelihood.
        batch_mask = (y != 0).float()
        lprobs = lifted_reg_model(x)
        data = pyro.sample(
            "obs_x",
            dist.Categorical(logits=lprobs).mask(batch_mask).to_event(2),
            obs=y)
    return lifted_reg_model
def model(self):
    """Sparse variational GP model.

    Samples the inducing outputs ``u`` (whitened or direct parameterization),
    computes the conditional of ``f`` at the training inputs given the
    inducing points, and scores ``y`` through the likelihood with the
    minibatch log-likelihood rescaled to the full dataset.
    :returns: ``(f_loc, f_var)`` when ``self.y`` is None, otherwise the
        likelihood applied to the observed targets.
    """
    self.set_mode("model")
    M = self.Xu.size(0)
    Kuu = self.kernel(self.Xu).contiguous()
    Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
    Luu = Kuu.cholesky()
    zero_loc = self.Xu.new_zeros(self.u_loc.shape)
    if self.whiten:
        # Whitened prior: u is sampled with identity scale.
        identity = eye_like(self.Xu, M)
        pyro.sample(self._pyro_get_fullname("u"),
                    dist.MultivariateNormal(zero_loc, scale_tril=identity)
                    .to_event(zero_loc.dim() - 1))
    else:
        pyro.sample(self._pyro_get_fullname("u"),
                    dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                    .to_event(zero_loc.dim() - 1))
    # Conditional mean/variance of f at X given the inducing points
    # (diagonal covariance only: full_cov=False).
    f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc,
                               self.u_scale_tril, Luu, full_cov=False,
                               whiten=self.whiten, jitter=self.jitter)
    f_loc = f_loc + self.mean_function(self.X)
    if self.y is None:
        return f_loc, f_var
    else:
        # we would like to load likelihood's parameters outside poutine.scale context
        self.likelihood._load_pyro_samples()
        # Rescale the minibatch likelihood to the full dataset size.
        with poutine.scale(scale=self.num_data / self.X.size(0)):
            return self.likelihood(f_loc, f_var, self.y)
def model(self, seq):
    """Bayesian sequence model: hierarchical Gaussian prior over item
    embeddings, lifted into the network, scored on a subsampled minibatch
    (each sequence is its own target, padding masked out)."""
    bias_prior = dist.Normal(0, 1)
    device = self.device
    zero_mean = torch.zeros(self.emb_dim).to(device)
    prior_cov = torch.diag(torch.ones(self.emb_dim).to(device) * 2)
    # Shared mean of the item-embedding prior.
    muV = pyro.sample("muV", dist.MultivariateNormal(loc=zero_mean,
                                                     covariance_matrix=prior_cov))
    with plate("item_loop", self.num_items):
        V = pyro.sample("V", dist.MultivariateNormal(muV, prior_cov))
    # LIFT MODULE: swap the net's parameters for samples from the prior.
    lifted = pyro.random_module(
        "net", self,
        prior={'linear.bias': bias_prior, 'V.weight': Deterministic_distr(V)})()
    lifted.lstm.flatten_parameters()
    with pyro.plate("data", len(seq), subsample_size=self.batch_size) as ind:
        batch_seq = seq[ind, ]
        batch_mask = (batch_seq != 0).float()
        lprobs = lifted(batch_seq)
        pyro.sample("obs_x",
                    dist.Categorical(logits=lprobs).mask(batch_mask).to_event(2),
                    obs=batch_seq)
    return lifted
def guide(self):
    """Approximate posterior for the horseshoe prior.

    We assume posterior in the form of the multivariate normal distribution
    for the global mean and standard deviation and multivariate normal
    distribution for the parameters of each subject independently.
    """
    nsub = self.runs  # number of subjects
    npar = self.npar  # number of parameters

    trns = biject_to(constraints.positive)

    # Joint auxiliary site over (mu, tau) in unconstrained space.
    m_hyp = param('m_hyp', zeros(2 * npar))
    st_hyp = param('scale_tril_hyp', torch.eye(2 * npar),
                   constraint=constraints.lower_cholesky)
    hyp = sample('hyp',
                 dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                 infer={'is_auxiliary': True})
    unc_mu = hyp[..., :npar]
    unc_tau = hyp[..., npar:]

    # Map tau to the positive reals and carry the log-abs-det Jacobian so the
    # Delta sites score correctly in constrained space.
    c_tau = trns(unc_tau)
    ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
    ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

    sample("mu", dist.Delta(unc_mu, event_dim=1))
    sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

    # Independent multivariate-normal posterior per subject.
    m_locs = param('m_locs', zeros(nsub, npar))
    st_locs = param('scale_tril_locs', torch.eye(npar).repeat(nsub, 1, 1),
                    constraint=constraints.lower_cholesky)
    with plate('runs', nsub):
        sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
def model(data):
    """Gaussian mixture model with per-component diagonal covariances.

    :param data: tensor of shape (num_points, dim).

    Improvement: the Python loop building ``scales`` from ``torch.diag`` rows
    is replaced by a single ``torch.diag_embed`` call — identical result,
    vectorized and clearer.
    """
    _, dim = data.shape
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K)))
    with pyro.plate('c1', K):
        mus = pyro.sample(
            'mus',
            dist.MultivariateNormal(torch.zeros(dim),
                                    torch.diag(torch.ones(dim) * 10.0)))
        assert mus.size() == (K, dim)
    with pyro.plate('dim', dim):
        with pyro.plate('c2', K):
            lambdas = pyro.sample('lambdas', dist.LogNormal(0, 2))
            assert lambdas.size() == (K, dim)
    # (K, dim) -> (K, dim, dim): one diagonal covariance per component.
    scales = torch.diag_embed(lambdas)
    assert (K, dim, dim) == scales.size()
    with pyro.plate('data', len(data)):
        assignments = pyro.sample('assignments', dist.Categorical(weights))
        pyro.sample('obs',
                    dist.MultivariateNormal(mus[assignments],
                                            scales[assignments]),
                    obs=data)
def model(X, Y, U, V):
    """Matrix-factorization-style Bernoulli model with a shared offset d_i.

    :param X: (n, 2) integer index pairs (row of U, row of V).
    :param Y: binary observations.
    :param U, V: embedding matrices with rows of length E.

    Improvements: the duplicated CUDA/CPU branches are collapsed by building
    tensors directly on the right device; redundant ``.cuda()`` calls on
    tensors already on the GPU and the dead ``target = target.cuda()`` on an
    unused local are removed.  Behavior is unchanged.
    """
    phi = 1
    device = torch.device('cuda') if is_cuda else torch.device('cpu')
    # Shared offset added to every V row.
    d_i = pyro.sample(
        "d_i",
        dist.MultivariateNormal(torch.zeros(E, device=device),
                                phi * torch.eye(E, device=device)))
    with pyro.plate('observations', len(X)):
        u_rows = U[X[:, 0]]
        v_rows = V[X[:, 1]] + d_i
        # Batched row-by-column products: (n,1,E) x (n,E,1) -> logits.
        logit = torch.sum(
            torch.bmm(u_rows.view(u_rows.shape[0], 1, E),
                      v_rows.view(v_rows.shape[0], E, 1)),
            axis=1)
        pyro.sample('obs', dist.Bernoulli(logits=logit), obs=Y)
def model(self, X=None, y=None):
    """Mixture density network model: per-datapoint mixture of K
    multivariate normals whose parameters come from ``self.mdn(X)``.

    :param X: input tensor of shape (N, D_in); must not be None despite the
        default — the first line indexes its shape.
    :param y: optional observed targets of shape (N, D).
    :returns: ``(pi, loc, Sigma_tril, sample)``.
    """
    N = X.shape[0]
    D = X.shape[1]
    pyro.module("MDN", self)
    # Network outputs: weights pi (N, K), means loc, Cholesky factors
    # Sigma_tril (shapes pinned by the asserts below after transposition).
    pi, loc, Sigma_tril = self.mdn(X)
    # Move the mixture dimension K to the front so components can be selected
    # per datapoint.
    locT = torch.transpose(loc, 0, 1)
    Sigma_trilT = torch.transpose(Sigma_tril, 0, 1)
    assert pi.shape == (N, self.K)
    assert locT.shape == (self.K, N, D)
    assert Sigma_trilT.shape == (self.K, N, D, D)
    with pyro.plate("data", N):
        assignment = pyro.sample("assignment", dist.Categorical(pi))
        if len(assignment.shape) == 1:
            # Plain (non-enumerated) case: gather each datapoint's component.
            _mu = torch.gather(locT, 0, assignment.view(1, -1, 1))[0]
            _scale_tril = torch.gather(Sigma_trilT, 0,
                                       assignment.view(1, -1, 1, 1))[0]
            sample = pyro.sample('obs', dist.MultivariateNormal(
                _mu, scale_tril=_scale_tril), obs=y)
        else:
            # Extra leading dim — presumably parallel enumeration over
            # assignments; TODO confirm the [:, 0] selection against the
            # enumeration layout.
            _mu = locT[assignment][:, 0]
            _scale_tril = Sigma_trilT[assignment][:, 0]
            sample = pyro.sample('obs', dist.MultivariateNormal(
                _mu, scale_tril=_scale_tril), obs=y)
    return pi, loc, Sigma_tril, sample
def my_local_guide(x, y, alt_av, alt_ids):
    """Variational guide for a mixed-logit model.

    Global sites: ``alpha`` (non-mixed coefficients) and ``beta_mu`` (means of
    the mixed coefficients), each either a diagonal Normal or a
    full-covariance MVN depending on the module-level flags.  Per-respondent
    ``beta_resp`` is amortized: a predictor network maps each respondent's
    choices/attributes to the posterior mean.

    NOTE(review): relies on module-level globals (``diagonal_alpha``,
    ``non_mix_params``, ``num_resp``, ``T``, ``predictor``, ``alt_av_cuda``,
    ...); the ``alt_av`` parameter itself is unused here — ``alt_av_cuda`` is
    used instead, presumably a pre-moved copy — TODO confirm.
    """
    # alpha: diagonal Normal or full-covariance MVN posterior.
    if diagonal_alpha:
        alpha_loc = pyro.param(
            'alpha_loc', torch.randn(len(non_mix_params), device=x.device))
        alpha_scale = pyro.param(
            'alpha_scale',
            1 * torch.ones(len(non_mix_params), device=x.device),
            constraint=constraints.positive)
        alpha = pyro.sample("alpha",
                            dist.Normal(alpha_loc, alpha_scale).to_event(1))
    else:
        alpha_loc = pyro.param(
            'alpha_loc', torch.randn(len(non_mix_params), device=x.device))
        alpha_scale = pyro.param(
            "alpha_scale",
            torch.tril(1 * torch.eye(len(non_mix_params), device=x.device)),
            constraint=constraints.lower_cholesky)
        alpha = pyro.sample(
            "alpha",
            dist.MultivariateNormal(alpha_loc, scale_tril=alpha_scale))
    # beta_mu: same diagonal/full-covariance pattern.
    if diagonal_beta_mu:
        beta_mu_loc = pyro.param('beta_mu_loc',
                                 torch.randn(len(mix_params), device=x.device))
        beta_mu_scale = pyro.param(
            'beta_mu_scale',
            1 * torch.ones(len(mix_params), device=x.device),
            constraint=constraints.positive)
        beta_mu = pyro.sample(
            "beta_mu", dist.Normal(beta_mu_loc, beta_mu_scale).to_event(1))
    else:
        beta_mu_loc = pyro.param('beta_mu_loc',
                                 torch.randn(len(mix_params), device=x.device))
        beta_mu_scale = pyro.param(
            "beta_mu_scale",
            torch.tril(1 * torch.eye(len(mix_params), device=x.device)),
            constraint=constraints.lower_cholesky)
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(beta_mu_loc, scale_tril=beta_mu_scale))
    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    # One-hot encode the observed choices: (num_resp, T, num_alternatives).
    one_hot = torch.zeros(num_resp, T, num_alternatives, device=x.device,
                          dtype=torch.float)
    one_hot = one_hot.scatter(2, y.unsqueeze(2).long(), 1)
    # Concatenate choices, attributes and availability as predictor input.
    inference_data = torch.cat([one_hot, x, alt_av_cuda.float()], dim=-1)
    beta_loc = predictor.forward(inference_data.flatten(1, 2).unsqueeze(1))
    # Shared Cholesky scale across respondents; loc comes from the network.
    beta_scale = pyro.param(
        'beta_resp_scale',
        torch.tril(1. * torch.eye(len(mix_params), device=x.device)),
        constraint=constraints.lower_cholesky)
    pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_loc, scale_tril=beta_scale).to_event(1))
def normal_inv_gamma_family_guide(design, obs_sd, w_sizes, mf=False):
    """Normal inverse Gamma family guide.

    If `obs_sd` is known, this is a multivariate Normal family with separate
    parameters for each batch. `w` is sampled from a Gaussian with mean
    `mw_param` and covariance matrix derived from  `obs_sd * lambda_param`
    and the two parameters `mw_param` and `lambda_param` are learned.

    If `obs_sd=None`, this is a four-parameter family. The observation precision
    `tau` is sampled from a Gamma distribution with parameters `alpha`,
    `beta` (separate for each batch). We let `obs_sd = 1./torch.sqrt(tau)` and
    then proceed as above.

    :param torch.Tensor design: a tensor with last two dimensions `n` and `p`
        corresponding to observations and features respectively.
    :param torch.Tensor obs_sd: observation standard deviation, or `None` to use
        inverse Gamma
    :param OrderedDict w_sizes: map from variable names to torch.Size
    :param bool mf: if True, use the mean-field variant (scale_tril not
        multiplied by obs_sd).
    """
    # design is size batch x n x p
    # tau is size batch
    tau_shape = design.shape[:-2]
    # Enter one plate per batch dimension.
    with ExitStack() as stack:
        for plate in iter_plates_to_shape(tau_shape):
            stack.enter_context(plate)
        if obs_sd is None:
            # First, sample tau (observation precision).
            # softplus keeps the raw params unconstrained while alpha/beta > 0.
            alpha = softplus(
                pyro.param("invsoftplus_alpha", 20.0 * torch.ones(tau_shape))
            )
            beta = softplus(
                pyro.param("invsoftplus_beta", 20.0 * torch.ones(tau_shape))
            )
            # Global variable
            tau_prior = dist.Gamma(alpha, beta)
            tau = pyro.sample("tau", tau_prior)
            obs_sd = 1.0 / torch.sqrt(tau)

        # response will be shape batch x n
        obs_sd = obs_sd.expand(tau_shape).unsqueeze(-1)

        for name, size in w_sizes.items():
            w_shape = tau_shape + size
            # Set up mu and lambda
            mw_param = pyro.param("{}_guide_mean".format(name), torch.zeros(w_shape))
            scale_tril = pyro.param(
                "{}_guide_scale_tril".format(name),
                torch.eye(*size).expand(tau_shape + size + size),
                constraint=constraints.lower_cholesky,
            )
            # guide distributions for w
            if mf:
                # Mean-field: scale_tril used as-is, independent of obs_sd.
                w_dist = dist.MultivariateNormal(mw_param, scale_tril=scale_tril)
            else:
                # Covariance scales with the observation noise.
                w_dist = dist.MultivariateNormal(
                    mw_param,
                    scale_tril=obs_sd.unsqueeze(-1) * scale_tril
                )
            pyro.sample(name, w_dist)
def model(loc, cov):
    """Score three learned parameter tensors (unbatched, batch-3, batch-4)
    as observations under the same multivariate normal."""
    x = pyro.param("x", torch.randn(2))
    y = pyro.param("y", torch.randn(3, 2))
    z = pyro.param("z", torch.randn(4, 2).abs(),
                   constraint=constraints.greater_than(-1))
    obs_dist = dist.MultivariateNormal(loc, cov)
    pyro.sample("obs_x", obs_dist, obs=x)
    with pyro.plate("y_plate", 3):
        pyro.sample("obs_y", obs_dist, obs=y)
    with pyro.plate("z_plate", 4):
        pyro.sample("obs_z", obs_dist, obs=z)
def model(N):
    """Sum of two independent standard-bivariate-normal latents, observed
    through another bivariate normal, all inside one plate of size N."""
    with pyro.plate("x_plate", N):
        std_mvn = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))
        z1 = pyro.sample("z1", std_mvn)
        z2 = pyro.sample("z2", std_mvn)
        return pyro.sample("x",
                           dist.MultivariateNormal(z1 + z2, torch.eye(2)))
def my_local_guide(x, y, alt_av, alt_ids):
    """Variational guide for a mixed-logit model: global alpha and beta_mu
    sites (diagonal Normal or full-covariance MVN, per the module flags) and
    a shared full-covariance posterior over per-respondent betas."""

    def _gaussian_site(site, loc_name, scale_name, size, diagonal):
        # Shared pattern: learned loc + either a positive diagonal scale
        # (independent Normal) or a lower-Cholesky factor (MVN).
        loc = pyro.param(loc_name, torch.randn(size, device=x.device))
        if diagonal:
            scale = pyro.param(scale_name,
                               1. * torch.ones(size, device=x.device),
                               constraint=constraints.positive)
            return pyro.sample(site, dist.Normal(loc, scale).to_event(1))
        scale = pyro.param(scale_name,
                           torch.tril(1. * torch.eye(size, device=x.device)),
                           constraint=constraints.lower_cholesky)
        return pyro.sample(site,
                           dist.MultivariateNormal(loc, scale_tril=scale))

    alpha = _gaussian_site("alpha", 'alpha_loc', 'alpha_scale',
                           len(non_mix_params), diagonal_alpha)
    beta_mu = _gaussian_site("beta_mu", 'beta_mu_loc', 'beta_mu_scale',
                             len(mix_params), diagonal_beta_mu)

    # Per-respondent betas: individual means, one shared Cholesky scale.
    beta_loc = pyro.param(
        'beta_resp_loc',
        torch.randn(num_resp, len(mix_params), device=x.device))
    beta_scale = pyro.param(
        'beta_resp_scale',
        torch.tril(1. * torch.eye(len(mix_params), len(mix_params),
                                  device=x.device)),
        constraint=constraints.lower_cholesky)
    pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_loc, scale_tril=beta_scale).to_event(1))
def step(self, state, datum=None):
    """Advance the state-space model by one step: latent random-walk
    transition, then an emission (scored against `datum` when given)."""
    t = self.t
    state["z"] = pyro.sample(
        f"z_{t}",
        dist.MultivariateNormal(state["z"],
                                scale_tril=trans_dist.scale_tril))
    datum = pyro.sample(
        f"obs_{t}",
        dist.MultivariateNormal(state["z"],
                                scale_tril=obs_dist.scale_tril),
        obs=datum)
    self.t += 1
    return datum
def get_replicated_data(data, mu, cov, pi):
    """Draw a replicated dataset from a fitted Gaussian mixture, rejecting
    points below the observed data's per-coordinate lower bounds.

    :param data: observed (n, 2) tensor — sets the replicate count and bounds.
    :param mu, cov: per-component means and covariances, indexable by int.
    :param pi: mixture weights (sequence convertible to a tensor).
    :returns: (n, 2) tensor of replicated points.

    Improvements: the per-coordinate minima and the distribution objects are
    hoisted out of the loops (they were rebuilt/recomputed on every draw and
    every rejection check), and the tolist()/torch.tensor round-trip is
    replaced by torch.stack.
    """
    # Hoisted: lower bounds of the observed data, computed once.
    x_min = data[:, 0].min()
    y_min = data[:, 1].min()
    mixture = dist.Categorical(torch.tensor(pi))
    # NOTE(review): the 'category'/'obs' site names repeat across iterations,
    # which would error inside a pyro trace; fine for standalone sampling.
    data_rep = []
    for i in range(len(data)):
        cluster = pyro.sample('category', mixture)
        idx = cluster.item()
        component = dist.MultivariateNormal(mu[idx], cov[idx])
        sample = pyro.sample("obs", component)
        # Only sample valid points (rejection sampling against the bounds).
        while sample[0] < x_min or sample[1] < y_min:
            sample = pyro.sample("obs", component)
        data_rep.append(sample)
    return torch.stack(data_rep)
def guide(self):
    """Variational guide: multivariate-normal approximate posteriors over the
    latent functions f and g."""
    self.set_mode("guide")
    self._load_pyro_samples()
    for name, loc, tril in (("f", self.f_loc, self.f_scale_tril),
                            ("g", self.g_loc, self.g_scale_tril)):
        pyro.sample(
            self._pyro_get_fullname(name),
            dist.MultivariateNormal(
                loc, scale_tril=tril).to_event(loc.dim() - 1))
def get_samples(num_samples=100):
    """Draw `num_samples` points from each of two fixed 2-D Gaussians and
    return them concatenated (first component's draws first)."""
    specs = [
        (torch.tensor([0., 5.]), torch.tensor([[2., 0.], [0., 3.]]), "samples1"),
        (torch.tensor([5., 0.]), torch.tensor([[4., 0.], [0., 1.]]), "samples2"),
    ]
    batches = []
    for mean, covariance, site in specs:
        component = dist.MultivariateNormal(mean, covariance)
        draws = [pyro.sample(site, component) for _ in range(num_samples)]
        batches.append(torch.stack(draws))
    return torch.cat(batches)
def test_kl_independent_normal_mvn(batch_shape, size):
    """KL of a diagonal Normal (as an event-1 Independent) against an MVN must
    equal the KL of the equivalent diagonal-scale MVN against the same MVN."""
    mean = torch.randn(batch_shape + (size, ))
    stddev = torch.randn(batch_shape + (size, )).exp()
    diag_as_normal = dist.Normal(mean, stddev).to_event(1)
    diag_as_mvn = dist.MultivariateNormal(mean,
                                          scale_tril=stddev.diag_embed())

    q_loc = torch.randn(batch_shape + (size, ))
    root = torch.randn(batch_shape + (size, size))
    # Positive-definite covariance: root @ root^T plus a small ridge.
    q_cov = root @ root.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(q_loc, covariance_matrix=q_cov)

    actual = kl_divergence(diag_as_normal, q)
    expected = kl_divergence(diag_as_mvn, q)
    assert_close(actual, expected)
def guide(self):
    """Approximate posterior for the Dirichlet process prior.

    Gamma posterior over the DP concentration, Beta posteriors over the
    stick-breaking weights, per-class MVN posteriors over (mu, tau), and a
    per-subject mixture assignment (enumerated in parallel) with an MVN
    posterior over subject-level parameters.
    """
    nsub = self.runs  # number of subjects
    npar = self.npar  # number of parameters
    kmax = self.kmax  # maximum number of components

    # Concentration parameter alpha ~ Gamma(gaa, gaa / gar).
    gaa = param("ga_a", ones(1), constraint=constraints.positive)
    gar = param("ga_r", .1 * ones(1), constraint=constraints.positive)
    sample('alpha', dist.Gamma(gaa, gaa / gar))

    # Stick-breaking proportions (kmax - 1 sticks).
    gba = param("gb_beta_a", ones(kmax - 1), constraint=constraints.positive)
    gbb = param("gb_beta_b", ones(kmax - 1), constraint=constraints.positive)
    beta = sample("beta", dist.Beta(gba, gbb).to_event(1))

    with plate('classes', kmax, dim=-2):
        # Per-class posterior over means.
        m_mu = param('m_mu', zeros(kmax, npar))
        st_mu = param('scale_tril_mu', torch.eye(npar).repeat(kmax, 1, 1),
                      constraint=constraints.lower_cholesky)
        mu = sample("mu", dist.MultivariateNormal(m_mu, scale_tril=st_mu))
        # Per-class posterior over tau: exp-transformed MVN (log-normal-like).
        m_tau = param('m_hyp', zeros(kmax, npar))
        st_tau = param('scale_tril_hyp', torch.eye(npar).repeat(kmax, 1, 1),
                       constraint=constraints.lower_cholesky)
        mn = dist.MultivariateNormal(m_tau, scale_tril=st_tau)
        sample(
            "tau",
            dist.TransformedDistribution(mn, [dist.transforms.ExpTransform()]))

    # Subject-level posteriors and mixture responsibilities.
    m_locs = param('m_locs', zeros(nsub, npar))
    st_locs = param('scale_tril_locs', torch.eye(npar).repeat(nsub, 1, 1),
                    constraint=constraints.lower_cholesky)
    class_probs = param('class_probs', ones(nsub, kmax) / kmax,
                        constraint=constraints.simplex)
    with plate('subjects', nsub, dim=-2):
        # Discrete class assignment, marginalized by parallel enumeration.
        sample('class', dist.Categorical(class_probs),
               infer={"enumerate": "parallel"})
        sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
def guide_horseshoe_plus(self): npar = self.npars # number of parameters nsub = self.runs # number of subjects trns = biject_to(constraints.positive) m_hyp = param('m_hyp', zeros(2*npar)) st_hyp = param('scale_tril_hyp', torch.eye(2*npar), constraint=constraints.lower_cholesky) hyp = sample('hyp', dist.MultivariateNormal(m_hyp, scale_tril=st_hyp), infer={'is_auxiliary': True}) unc_mu = hyp[:npar] unc_sigma = hyp[npar:] c_sigma = trns(unc_sigma) ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma) ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1) mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1)) sigma_g = sample("sigma_g", dist.Delta(c_sigma, log_density=ld_sigma, event_dim=1)) m_tmp = param('m_tmp', zeros(nsub, 2*npar)) st_tmp = param('s_tmp', torch.eye(2*npar).repeat(nsub, 1, 1), constraint=constraints.lower_cholesky) with plate('subjects', nsub): tmp = sample('tmp', dist.MultivariateNormal(m_tmp, scale_tril=st_tmp), infer={'is_auxiliary': True}) unc_locs = tmp[..., :npar] unc_scale = tmp[..., npar:] c_scale = trns(unc_scale) ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale) ld_scale = sum_rightmost(ld_scale, ld_scale.dim() - c_scale.dim() + 1) x = sample("x", dist.Delta(unc_locs, event_dim=1)) sigma_x = sample("sigma_x", dist.Delta(c_scale, log_density=ld_scale, event_dim=1)) return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
def guide(self, data):
    """Amortized guide: the encoder maps (NaN-masked) response batches to the
    posterior over the latent x — diagonal Normal in the 1-feature case,
    full-covariance MVN otherwise."""
    total = self.sample_size
    batch = self.subsample_size
    pyro.module('encoder', self.encoder)

    def _encode(rows):
        # Replace NaNs with the encoder's missing-data sentinel (-1).
        nan_mask = torch.isnan(rows)
        if nan_mask.any():
            rows = torch.where(nan_mask, torch.full_like(rows, -1), rows)
        return self.encoder.forward(rows)

    if self.x_feature == 1:
        with pyro.plate("data", total, subsample_size=batch, dim=-2) as idx:
            x_local, x_scale = _encode(data[idx])
            pyro.sample('x', dist.Normal(x_local, x_scale))
    else:
        transform = LowerCholeskyTransform()
        with pyro.plate("data", total, subsample_size=batch) as idx:
            x_local, x_scale = _encode(data[idx])
            pyro.sample(
                'x',
                dist.MultivariateNormal(x_local,
                                        scale_tril=transform(x_scale)))
def __init__(self, mdisc_log_local=0, mdisc_log_scale=0.5, mdiff_local=0.5,
             mdiff_scale=1, x_feature=2, x_local=None, x_cov=None, D=1,
             *args, **kwargs):
    """Build a synthetic item-response dataset.

    Draws item discriminations (log-normal) and difficulties (normal), forms
    slopes ``self.a`` and intercepts ``self.b``, and samples latent traits
    ``self.x`` from a multivariate normal.

    :param mdisc_log_local, mdisc_log_scale: log-normal parameters for the
        item discriminations.
    :param mdiff_local, mdiff_scale: normal parameters for item difficulties.
    :param x_feature: dimensionality of the latent trait.
    :param x_local, x_cov: mean/covariance of the latent-trait prior;
        default to zeros / identity.
    :param D: scaling constant stored as ``self.D``.

    NOTE(review): ``self.item_size``, ``self.sample_size`` and ``self.gen_a``
    are assumed to come from the base class — confirm.
    """
    super().__init__(*args, **kwargs)
    # Item discriminations: log-normal draws, one per item.
    mdisc = torch.FloatTensor(self.item_size).log_normal_(
        mdisc_log_local, mdisc_log_scale)
    # Item difficulties: normal draws, one per item.
    mdiff = torch.FloatTensor(self.item_size).normal_(
        mdiff_local, mdiff_scale)
    self.a = self.gen_a(self.item_size, mdisc, x_feature)
    # Intercepts b = -difficulty * discrimination, stored as a (1, items) row.
    b = -mdiff * mdisc
    self.b = b.view(1, -1)
    self.x_feature = x_feature
    if x_local is None:
        x_local = torch.zeros((x_feature, ))
    if x_cov is None:
        x_cov = torch.eye(x_feature)
    # Latent traits, drawn once at construction: (sample_size, x_feature).
    self.x = dist.MultivariateNormal(x_local, x_cov).sample(
        (self.sample_size, ))
    self.D = D
def test_masked_mixture_multivariate(sample_shape, batch_shape):
    """MaskedMixture of MVN/Uniform components: check shape bookkeeping and
    that log_prob routes each element to the component its mask selects."""
    event_shape = torch.Size((8,))
    mvn_comp = dist.MultivariateNormal(
        torch.zeros(event_shape), torch.eye(event_shape[0])
    )
    unif_comp = dist.Uniform(
        torch.zeros(event_shape), torch.ones(event_shape)
    ).to_event(1)
    if batch_shape:
        mvn_comp = mvn_comp.expand_by(batch_shape)
        unif_comp = unif_comp.expand_by(batch_shape)
    mask = torch.empty(batch_shape).bernoulli_(0.5).bool()
    mixture = dist.MaskedMixture(mask, mvn_comp, unif_comp)

    # Shape bookkeeping.
    assert mixture.batch_shape == batch_shape
    assert mixture.event_shape == event_shape
    assert mixture.sample().shape == batch_shape + event_shape
    assert mixture.mean.shape == batch_shape + event_shape
    assert mixture.variance.shape == batch_shape + event_shape

    draw = mixture.sample(sample_shape)
    assert draw.shape == sample_shape + batch_shape + event_shape

    scores = mixture.log_prob(draw)
    assert scores.shape == sample_shape + batch_shape
    assert not torch_isnan(scores)

    # Where mask is True the mixture scores with component1 (Uniform),
    # elsewhere with component0 (MVN).
    mvn_scores = mvn_comp.log_prob(draw)
    unif_scores = unif_comp.log_prob(draw)
    mask = mask.expand(sample_shape + batch_shape)
    assert_equal(scores[mask], unif_scores[mask])
    assert_equal(scores[~mask], mvn_scores[~mask])
def model(cov):
    """Four independent wide-Normal latents (sizes 2,1,1,1) concatenated into
    one 5-vector and observed under a multivariate normal with the given
    covariance."""
    parts = {
        name: pyro.sample(name,
                          dist.Normal(0, 1000).expand([size]).to_event(1))
        for name, size in (("w", 2), ("x", 1), ("y", 1), ("z", 1))
    }
    combined = torch.cat([parts["w"], parts["x"], parts["y"], parts["z"]])
    pyro.sample("obs", dist.MultivariateNormal(torch.zeros(5), cov),
                obs=combined)
def model(self, bows, embeddings, article_ids):
    """Topic model over article navigation embeddings.

    Per article: sample topic proportions (log-normal in place of a
    Dirichlet), map them through the recognition net to a per-article
    Gaussian, and observe each of the article's nav embeddings under it.

    :param bows: bag-of-words tensor, one row per article in the batch.
    :param embeddings: embedding lookup indexable by nav ids.
    :param article_ids: ids of the articles in this batch, aligned with bows.
    """
    pyro.module("topic_recognition_net", self.topic_recognition_net)
    with pyro.plate("articles", bows.shape[0]):
        # instead of a Dirichlet prior, we use a log-normal distribution
        prop_mu = bows.new_zeros((bows.shape[0], self.nav_topics))
        prop_sigma = bows.new_ones((bows.shape[0], self.nav_topics))
        props = pyro.sample(
            "theta", dist.LogNormal(prop_mu, prop_sigma).to_event(1))
        # Per-article Gaussian over nav embeddings, from the recognition net.
        topics_mu, topics_sigma = self.topic_recognition_net(props)
        for batch_article_id, article_id in enumerate(article_ids):
            nav_embeddings = torch.tensor(
                embeddings[self.article_navs[article_id]],
                dtype=torch.float32).to(device)
            # Sequential plate over this article's navs; site names are keyed
            # by (article_id, nav index) so they stay unique across articles.
            for article_nav_id in pyro.plate(
                    "navs_{}".format(article_id),
                    len(self.article_navs[article_id])):
                pyro.sample("nav_{}_{}".format(article_id, article_nav_id),
                            dist.MultivariateNormal(
                                topics_mu[batch_article_id],
                                scale_tril=torch.diag(
                                    topics_sigma[batch_article_id])),
                            obs=nav_embeddings[article_nav_id])
def program_arbitrary(nn_model, p_tgt, std=0.05):
    """A probabilistic model for enforcing p = p_tgt.

    Samples z from a standard multivariate-normal prior, pushes it through
    ``nn_model.predict_from_latent``, and conditions each output probability
    on its target: u_i ~ Normal(p_i, std) observed at p_tgt[i], so the
    posterior is p(z | u_i = p_tgt[i]).

    :param nn_model: model exposing ``device``, ``latent_dim`` and
        ``predict_from_latent``.
    :param p_tgt: sequence of target probabilities, one per model output.
    :param float std: observation noise standard deviation.
    :raises ValueError: if ``nn_model.device`` is neither 'cuda' nor 'cpu'
        (previously this left ``typ`` unbound and crashed with NameError).
    """
    if nn_model.device == 'cuda':
        typ = torch.cuda.FloatTensor
    elif nn_model.device == 'cpu':
        typ = torch.FloatTensor
    else:
        raise ValueError(f"unsupported device: {nn_model.device!r}")
    torch.set_default_tensor_type(typ)

    std = torch.tensor(std).float()
    latent_dim = nn_model.latent_dim
    loc = torch.zeros(latent_dim)
    cov = torch.eye(latent_dim)
    z = pyro.sample('z', dist.MultivariateNormal(loc, cov))
    prob = nn_model.predict_from_latent(z)
    # Condition each predicted probability on its target value.
    for i, p_i in enumerate(prob):
        pyro.sample('u_%i' % i, dist.Normal(p_i, std),
                    obs=torch.tensor(p_tgt[i]))