def iarange_model(subsample_size):
    loc = torch.zeros(20)
    scale = torch.ones(20)
    with pyro.iarange('iarange', 20, subsample_size) as batch:
        pyro.sample("x", dist.Normal(loc[batch], scale[batch]))
        result = list(batch.data)
    return result
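# Usage sketch (illustrative, not part of the original snippet): with the
# legacy pyro.iarange API used above, each call draws a fresh random
# subsample of indices from range(20) and returns them as a list.
indices = iarange_model(5)
assert len(indices) == 5 and all(0 <= i < 20 for i in indices)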
def guide(subsample_size):
    mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
    sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
    with pyro.iarange("data", len(data), subsample_size) as ind:
        mu = mu[ind]
        sigma = sigma.expand(subsample_size)
        pyro.sample("z", dist.Normal(mu, sigma, reparameterized=reparameterized))
def guide():
    mu_q = pyro.param("mu_q",
                      Variable(self.analytic_mu_n.data + 0.334 * torch.ones(2),
                               requires_grad=True))
    log_sig_q = pyro.param("log_sig_q",
                           Variable(self.analytic_log_sig_n.data - 0.29 * torch.ones(2),
                                    requires_grad=True))
    mu_q_prime = pyro.param("mu_q_prime",
                            Variable(torch.Tensor([-0.34, 0.52]), requires_grad=True))
    kappa_q = pyro.param("kappa_q", Variable(torch.Tensor([0.74]), requires_grad=True))
    log_sig_q_prime = pyro.param("log_sig_q_prime",
                                 Variable(-0.5 * torch.log(1.2 * self.lam0.data),
                                          requires_grad=True))
    sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
    mu_latent_dist = dist.Normal(mu_q, sig_q, reparameterized=repa2)
    mu_latent = pyro.sample("mu_latent", mu_latent_dist,
                            baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))
    mu_latent_prime_dist = dist.Normal(kappa_q.expand_as(mu_latent) * mu_latent + mu_q_prime,
                                       sig_q_prime, reparameterized=repa1)
    pyro.sample("mu_latent_prime", mu_latent_prime_dist,
                baseline=dict(nn_baseline=mu_prime_baseline,
                              nn_baseline_input=mu_latent,
                              use_decaying_avg_baseline=use_decaying_avg_baseline))
    return mu_latent
def guide(self, x):
    # register PyTorch module `encoder` with Pyro
    pyro.module("encoder", self.encoder)
    # use the encoder to get the parameters used to define q(z|x)
    z_mu, z_sigma = self.encoder.forward(x)
    # sample the latent code z
    pyro.sample("latent", dist.normal, z_mu, z_sigma)
def model(self):
    self.set_mode("model")

    Xu = self.get_param("Xu")
    u_loc = self.get_param("u_loc")
    u_scale_tril = self.get_param("u_scale_tril")

    M = Xu.shape[0]
    Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
    Luu = Kuu.potrf(upper=False)

    zero_loc = Xu.new_zeros(u_loc.shape)
    u_name = param_with_module_name(self.name, "u")
    if self.whiten:
        Id = torch.eye(M, out=Xu.new_empty(M, M))
        pyro.sample(u_name,
                    dist.MultivariateNormal(zero_loc, scale_tril=Id)
                        .independent(zero_loc.dim() - 1))
    else:
        pyro.sample(u_name,
                    dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                        .independent(zero_loc.dim() - 1))

    f_loc, f_var = conditional(self.X, Xu, self.kernel, u_loc, u_scale_tril, Luu,
                               full_cov=False, whiten=self.whiten, jitter=self.jitter)
    f_loc = f_loc + self.mean_function(self.X)
    if self.y is None:
        return f_loc, f_var
    else:
        with poutine.scale(None, self.num_data / self.X.shape[0]):
            return self.likelihood(f_loc, f_var, self.y)
def guide(subsample):
    loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
    scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
    with pyro.iarange("particles", num_particles):
        with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
            loc_ind = loc[ind].unsqueeze(-1).expand(-1, num_particles)
            pyro.sample("z", Normal(loc_ind, scale))
def guide():
    p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
    outer_irange = pyro.irange("irange_0", 3, subsample_size)
    inner_irange = pyro.irange("irange_1", 3, subsample_size)
    for j in inner_irange:
        for i in outer_irange:
            pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))
def model():
    p = torch.tensor(0.5)
    outer_irange = pyro.irange("irange_0", 3, subsample_size)
    inner_irange = pyro.irange("irange_1", 3, subsample_size)
    for i in outer_irange:
        for j in inner_irange:
            pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))
def guide():
    pyro.module("mymodule", pt_guide)
    mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
    sigma = torch.pow(tau_q, -0.5)
    pyro.sample("mu_latent",
                dist.Normal(mu_q, sigma, reparameterized=reparameterized),
                baseline=dict(use_decaying_avg_baseline=True))
def sample_ws(name, width):
    alpha_w_q = pyro.param("log_alpha_w_q_%s" % name,
                           lambda: rand_tensor((width), self.alpha_init, self.sigma_init))
    mean_w_q = pyro.param("log_mean_w_q_%s" % name,
                          lambda: rand_tensor((width), self.mean_init, self.sigma_init))
    alpha_w_q, mean_w_q = self.softplus(alpha_w_q), self.softplus(mean_w_q)
    pyro.sample("w_%s" % name, Gamma(alpha_w_q, alpha_w_q / mean_w_q))
def model(self, data):
    loc = self.loc_0
    lambda_prec = self.lambda_prec
    for i in range(1, self.chain_len + 1):
        loc = pyro.sample('loc_{}'.format(i), dist.Normal(loc=loc, scale=lambda_prec))
    pyro.sample('obs', dist.Normal(loc, lambda_prec), obs=data)
def _register_param(self, param, mode="model"):
    """
    Registers a parameter to Pyro. It can be seen as a wrapper for
    :func:`pyro.param` and :func:`pyro.sample` primitives.

    :param str param: Name of the parameter.
    :param str mode: Either "model" or "guide".
    """
    if param in self._fixed_params:
        self._registered_params[param] = self._fixed_params[param]
        return
    prior = self._priors.get(param)
    if self.name is None:
        param_name = param
    else:
        param_name = param_with_module_name(self.name, param)

    if prior is None:
        constraint = self._constraints.get(param)
        default_value = getattr(self, param)
        if constraint is None:
            p = pyro.param(param_name, default_value)
        else:
            p = pyro.param(param_name, default_value, constraint=constraint)
    elif mode == "model":
        p = pyro.sample(param_name, prior)
    else:  # prior != None and mode == "guide"
        MAP_param_name = param_name + "_MAP"
        # TODO: consider to init parameter from a prior call instead of mean
        MAP_param = pyro.param(MAP_param_name, prior.mean.detach())
        p = pyro.sample(param_name, dist.Delta(MAP_param))

    self._registered_params[param] = p
def guide():
    q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
    q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
    with pyro.iarange("particles", num_particles):
        y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]),
                        infer={"enumerate": enumerate1})
        if include_z:
            pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0))
def model(self):
    self.set_mode("model")

    f_loc = self.get_param("f_loc")
    f_scale_tril = self.get_param("f_scale_tril")

    N = self.X.shape[0]
    Kff = self.kernel(self.X) + (torch.eye(N, out=self.X.new_empty(N, N)) * self.jitter)
    Lff = Kff.potrf(upper=False)

    zero_loc = self.X.new_zeros(f_loc.shape)
    f_name = param_with_module_name(self.name, "f")
    if self.whiten:
        Id = torch.eye(N, out=self.X.new_empty(N, N))
        pyro.sample(f_name,
                    dist.MultivariateNormal(zero_loc, scale_tril=Id)
                        .independent(zero_loc.dim() - 1))
        f_scale_tril = Lff.matmul(f_scale_tril)
    else:
        pyro.sample(f_name,
                    dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                        .independent(zero_loc.dim() - 1))

    f_var = f_scale_tril.pow(2).sum(dim=-1)
    if self.whiten:
        f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1)
    f_loc = f_loc + self.mean_function(self.X)
    if self.y is None:
        return f_loc, f_var
    else:
        return self.likelihood(f_loc, f_var, self.y)
def model(num_particles):
    with pyro.iarange("particles", num_particles):
        q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
        q4 = pyro.param("q4", torch.tensor(0.5 * (pi1 + pi2), requires_grad=True))
        z = pyro.sample("z", dist.Normal(q3, 1.0).expand_by([num_particles]))
        zz = torch.exp(z) / (1.0 + torch.exp(z))
        pyro.sample("y", dist.Bernoulli(q4 * zz))
def guide(num_particles):
    q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
    q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
    with pyro.iarange("particles", num_particles):
        z = pyro.sample("z", dist.Normal(q2, 1.0).expand_by([num_particles]))
        zz = torch.exp(z) / (1.0 + torch.exp(z))
        pyro.sample("y", dist.Bernoulli(q1 * zz))
def sample_zs(name, width):
    alpha_z_q = pyro.param("log_alpha_z_q_%s" % name,
                           lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init))
    mean_z_q = pyro.param("log_mean_z_q_%s" % name,
                          lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init))
    alpha_z_q, mean_z_q = self.softplus(alpha_z_q), self.softplus(mean_z_q)
    pyro.sample("z_%s" % name, Gamma(alpha_z_q, alpha_z_q / mean_z_q).independent(1))
def guide(self, xs, ys=None):
    """
    The guide corresponds to the following:
    q(y|x) = categorical(alpha(x))          # infer digit from an image
    q(z|x,y) = normal(mu(x,y), sigma(x,y))  # infer handwriting style from an image and the digit
    mu, sigma are given by a neural network `encoder_z`
    alpha is given by a neural network `encoder_y`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels, i.e. the digit
               corresponding to the image(s)
    :return: None
    """
    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.iarange("independent"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            ys = pyro.sample("y", dist.categorical, alpha)
        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(mu(x,y), sigma(x,y))
        mu, sigma = self.encoder_z.forward([xs, ys])
        zs = pyro.sample("z", dist.normal, mu, sigma)  # noqa: F841
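# Usage note (illustrative, not from the original source): in the
# semi-supervised setting this guide is run on both kinds of minibatch.
# When ys is None it samples the digit y from q(y|x) and then the style z;
# when labels are given it samples z only:
#     guide(xs)        # unsupervised batch: samples y and z
#     guide(xs, ys)    # supervised batch: samples z only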
def model():
    with pyro.iarange("num_particles", 10, dim=-3):
        with pyro.iarange("components", 2, dim=-1):
            p = pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
            assert p.shape == torch.Size((10, 1, 2))
        with pyro.iarange("data", data.shape[0], dim=-2):
            pyro.sample("obs", dist.Bernoulli(p), obs=data)
def guide_step(self, t, n, prev, inputs):
    rnn_input = torch.cat((inputs['embed'], prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = self.rnn(rnn_input, (prev.h, prev.c))
    z_pres_p, z_where_loc, z_where_scale = self.predict(h)

    # Compute baseline estimates for discrete choice z_pres.
    bl_value, bl_h, bl_c = self.baseline_step(prev, inputs)

    # Sample presence.
    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.Bernoulli(z_pres_p * prev.z_pres).independent(1),
                         infer=dict(baseline=dict(baseline_value=bl_value.squeeze(-1))))

    sample_mask = z_pres if self.use_masking else torch.tensor(1.0)

    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.Normal(z_where_loc + self.z_where_loc_prior,
                                      z_where_scale * self.z_where_scale_prior)
                              .mask(sample_mask)
                              .independent(1))

    # Figure 2 of [1] shows x_att depending on z_where and h,
    # rather than z_where and x as here, but I think this is correct.
    x_att = image_to_window(z_where, self.window_size, self.x_size, inputs['raw'])

    # Encode attention windows.
    z_what_loc, z_what_scale = self.encode(x_att)

    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.Normal(z_what_loc, z_what_scale)
                             .mask(sample_mask)
                             .independent(1))

    return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c,
                      z_pres=z_pres, z_where=z_where, z_what=z_what)
def model():
    p2 = torch.ones(2) / 2
    p3 = torch.ones(3) / 3
    x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
    x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
    assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
    assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
def bernoulli_normal_model():
    bern_0 = pyro.sample('bern_0', dist.Bernoulli(torch.zeros(1) * 1e-2))
    loc = torch.ones(1) if bern_0.item() else -torch.ones(1)
    normal_0 = torch.ones(1)
    pyro.sample('normal_0', dist.Normal(loc, torch.ones(1) * 1e-2), obs=normal_0)
    return [bern_0, normal_0]
def guide():
    alpha_q_log = pyro.param("alpha_q_log",
                             Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
    beta_q_log = pyro.param("beta_q_log",
                            Variable(self.log_beta_n.data - 0.143, requires_grad=True))
    alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
    pyro.sample("p_latent", dist.beta, alpha_q, beta_q)
    pyro.map_data("aaa", self.data, lambda i, x: None, batch_size=self.batch_size)
def guide(self, x):
    # register PyTorch module `encoder` with Pyro
    pyro.module("encoder", self.encoder)
    with pyro.iarange("data", x.size(0)):
        # use the encoder to get the parameters used to define q(z|x)
        z_loc, z_scale = self.encoder.forward(x)
        # sample the latent code z
        pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))
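# Usage sketch (illustrative, not from the original source): a guide like the
# one above is typically paired with a matching model in SVI. Assumes a `vae`
# object exposing .model and .guide, and a minibatch `x`.
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(vae.model, vae.guide, Adam({"lr": 1.0e-3}), loss=Trace_ELBO())
loss = svi.step(x)  # one gradient step on the ELBO for this minibatch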
def guide():
    q = pyro.param("q")
    with pyro.iarange("particles", num_particles):
        pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]),
                    infer={"enumerate": enumerate1})
        for i in pyro.irange("irange", irange_dim):
            pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]),
                        infer={"enumerate": enumerate2})
def model_sample(self, batch_size=1):
    # sample the handwriting style from the constant prior distribution
    prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
    prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
    zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)
    mu = self.decoder.forward(zs)
    xs = pyro.sample("sample", dist.bernoulli, mu)
    return xs, mu
def irange_model(subsample_size):
    loc = torch.zeros(20)
    scale = torch.ones(20)
    result = []
    for i in pyro.irange('irange', 20, subsample_size):
        pyro.sample("x_{}".format(i), dist.Normal(loc[i], scale[i]))
        result.append(i)
    return result
def model(data):
    # Gamma prior over per-gene rates (placeholder constants; the original
    # author noted these should eventually come from a neural network).
    idka = torch.tensor(4.0)  # where is my neural network?
    idkb = torch.tensor(4.0)
    # The original assigned the Gamma distribution object directly to
    # genelambdas and passed it into Poisson, which fails at runtime;
    # sampling the rates from the prior is the evident intent.
    genelambdas = pyro.sample('genelambdas', dist.Gamma(idka, idkb, batch_size=19795))
    for celltype in range(data.size(0)):  # this one's 56, right?
        with pyro.iarange('observe_{}'.format(celltype)):
            # sample-site names must be unique, so index them by celltype
            pyro.sample('indiv_{}'.format(celltype), dist.Poisson(genelambdas),
                        obs=data[celltype])
def guide():
    q = pyro.param("q")
    with pyro.iarange("particles", num_particles):
        pyro.sample("y", dist.Bernoulli(q).expand_by([num_particles]),
                    infer={"enumerate": enumerate1})
        with pyro.iarange("iarange", iarange_dim):
            pyro.sample("z", dist.Bernoulli(q).expand_by([iarange_dim, num_particles]),
                        infer={"enumerate": enumerate2})
def obs_inner(i, _i, _x):
    for k in range(n_superfluous_top):
        pyro.sample("z_%d_%d" % (i, k),
                    dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1),
                                reparameterized=False))
    pyro.observe("obs_%d" % i, dist.normal, _x, mu_latent, torch.pow(self.lam, -0.5))
    for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
        pyro.sample("z_%d_%d" % (i, k),
                    dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1),
                                reparameterized=False))
def guide():
    pyro.module("mymodule", pt_guide)
    mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
    sigma = torch.pow(tau_q, -0.5)
    pyro.sample("mu_latent", dist.Normal(mu_q, sigma, reparameterized=reparameterized))
def guide(x):
    x = torch.reshape(x, [320, 4096])

    with pyro.plate("w_top_plate", 4000):
        # ============ sample_ws
        alpha_w_q = pyro.param("log_alpha_w_q_top",
                               alpha_init * torch.ones(4000) + sigma_init * torch.randn(4000))
        mean_w_q = pyro.param("log_mean_w_q_top",
                              mean_init * torch.ones(4000) + sigma_init * torch.randn(4000))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_top", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        # ============ sample_ws

    with pyro.plate("w_mid_plate", 600):
        # ============ sample_ws
        alpha_w_q = pyro.param("log_alpha_w_q_mid",
                               alpha_init * torch.ones(600) + sigma_init * torch.randn(600))
        mean_w_q = pyro.param("log_mean_w_q_mid",
                              mean_init * torch.ones(600) + sigma_init * torch.randn(600))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_mid", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        # ============ sample_ws

    with pyro.plate("w_bottom_plate", 61440):
        # ============ sample_ws
        alpha_w_q = pyro.param("log_alpha_w_q_bottom",
                               alpha_init * torch.ones(61440) + sigma_init * torch.randn(61440))
        mean_w_q = pyro.param("log_mean_w_q_bottom",
                              mean_init * torch.ones(61440) + sigma_init * torch.randn(61440))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_bottom", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        # ============ sample_ws

    with pyro.plate("data", 320):
        # ============ sample_zs
        alpha_z_q = pyro.param("log_alpha_z_q_top",
                               alpha_init * torch.ones(320, 100) + sigma_init * torch.randn(320, 100))
        mean_z_q = pyro.param("log_mean_z_q_top",
                              mean_init * torch.ones(320, 100) + sigma_init * torch.randn(320, 100))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_top", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        # ============ sample_zs

        # ============ sample_zs
        alpha_z_q = pyro.param("log_alpha_z_q_mid",
                               alpha_init * torch.ones(320, 40) + sigma_init * torch.randn(320, 40))
        mean_z_q = pyro.param("log_mean_z_q_mid",
                              mean_init * torch.ones(320, 40) + sigma_init * torch.randn(320, 40))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_mid", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        # ============ sample_zs

        # ============ sample_zs
        alpha_z_q = pyro.param("log_alpha_z_q_bottom",
                               alpha_init * torch.ones(320, 15) + sigma_init * torch.randn(320, 15))
        mean_z_q = pyro.param("log_mean_z_q_bottom",
                              mean_init * torch.ones(320, 15) + sigma_init * torch.randn(320, 15))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_bottom", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        # ============ sample_zs
def model(self):
    # If the kernel IS NOT random, we declare the kernel within the model
    # if not self.random_kernel:
    #     self.kernel = pydmn.kernels.RBF()

    # Covariance matrix of observed times entailed by our kernel
    # Kff = self.kernel(self.Y_time.reshape(-1, 1))
    # Kff.view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
    # Lff = Kff.cholesky()  # cholesky lower triangular
    Lff = self.Lff_ini

    ## Sampling system-wide connectivity and average weights ##
    with pyro.plate('gp_system_all', self.K_net * self.n_w):
        # Mean function of the GPs
        gp_system_mean = pyro.sample("gp_system_mean",
                                     dist.Normal(torch.zeros(self.K_net * self.n_w),
                                                 torch.tensor([0.1])))
        # Demeaned GPs
        gp_system_demean = pyro.sample("gp_system_demean",
                                       dist.MultivariateNormal(
                                           torch.zeros(self.K_net * self.n_w, self.T_net),
                                           scale_tril=Lff))
    gp_system_mean = gp_system_mean.reshape(self.K_net, self.n_w)
    gp_system_demean = gp_system_demean.reshape(self.K_net, self.n_w, self.T_net)
    # Latent systemic evolution
    gp_system = (gp_system_mean.expand(self.T_net, self.K_net, self.n_w).permute(1, 2, 0)
                 + gp_system_demean)

    ## Sampling latent coordinates ##
    if self.coord:
        n_coord = self.V_net * self.H_dim * self.K_net * self.n_dir * self.n_w
        with pyro.plate('gp_coord_all', n_coord):
            # Mean function of the GPs
            gp_coord_mean = pyro.sample("gp_coord_mean",
                                        dist.Normal(torch.zeros(n_coord),
                                                    torch.tensor([0.1])))
            # Demeaned GPs
            gp_coord_demean = pyro.sample("gp_coord_demean",
                                          dist.MultivariateNormal(
                                              torch.zeros(n_coord, self.T_net),
                                              scale_tril=Lff))
        gp_coord_mean = gp_coord_mean.reshape(self.V_net, self.H_dim, self.K_net,
                                              self.n_dir, self.n_w)
        gp_coord_demean = gp_coord_demean.reshape(self.V_net, self.H_dim, self.K_net,
                                                  self.n_dir, self.n_w, self.T_net)
        # Latent coordinates
        gp_coord = (gp_coord_mean.expand(self.T_net, self.V_net, self.H_dim, self.K_net,
                                         self.n_dir, self.n_w)
                    .permute(1, 2, 3, 4, 5, 0) + gp_coord_demean)

    ## Sampling sociability and popularity terms ##
    if self.socpop:
        n_socpop = self.V_net * self.K_net * self.n_dir * self.n_w
        with pyro.plate('gp_socpop_all', n_socpop):
            # Mean function of the GPs
            gp_socpop_mean = pyro.sample("gp_socpop_mean",
                                         dist.Normal(torch.zeros(n_socpop),
                                                     torch.tensor([0.1])))
            # Demeaned GPs
            gp_socpop_demean = pyro.sample("gp_socpop_demean",
                                           dist.MultivariateNormal(
                                               torch.zeros(n_socpop, self.T_net),
                                               scale_tril=Lff))
        gp_socpop_mean = gp_socpop_mean.reshape(self.V_net, self.K_net, self.n_dir, self.n_w)
        gp_socpop_demean = gp_socpop_demean.reshape(self.V_net, self.K_net, self.n_dir,
                                                    self.n_w, self.T_net)
        # Latent sociability/popularity processes
        gp_socpop = (gp_socpop_mean.expand(self.T_net, self.V_net, self.K_net,
                                           self.n_dir, self.n_w)
                     .permute(1, 2, 3, 4, 0) + gp_socpop_demean)

    ### Linear predictor ###
    # Systemic component
    Y_linpred = (gp_system.expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net)
                 .permute(0, 1, 4, 2, 3))
    # Distance between agents
    if self.coord:
        Y_linpred = Y_linpred + torch.einsum('uhkwt,vhkwt->uvtkw',
                                             gp_coord[:, :, :, 0, :, :],
                                             gp_coord[:, :, :, self.n_dir - 1, :, :])
    # Sociability and popularity effects
    if self.socpop:
        gp_soc = (gp_socpop[:, :, 0, :, :]
                  .expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net)
                  .transpose(0, 1))
        gp_pop = gp_socpop[:, :, self.n_dir - 1, :, :].expand(self.V_net, self.V_net,
                                                              self.K_net, self.n_w, self.T_net)
        Y_linpred = Y_linpred + gp_soc.permute(0, 1, 4, 2, 3) + gp_pop.permute(0, 1, 4, 2, 3)

    ### Link propensity (probability of occurrence) ###
    Y_link_prob = torch.sigmoid(Y_linpred[:, :, :, :, 0])
    Y_link_prob_valid = Y_link_prob.flatten()[self.Y_valid_id.flatten() == 1]
    with pyro.plate("data", Y_link_prob_valid.shape[0]):
        pyro.sample("obs", dist.Bernoulli(Y_link_prob_valid),
                    obs=self.Y_link.flatten()[self.cond_Y_link])

    ### Link expected weight (weight given occurrence) ###
    if self.weighted:
        with pyro.plate("sigma_k_ind", self.K_net):
            sigma_k = pyro.sample('sigma_k',
                                  dist.InverseGamma(self.sigma_k_prior_param[0].expand(self.K_net),
                                                    self.sigma_k_prior_param[1].expand(self.K_net)))
        Y_link_SDw = sigma_k.expand(self.V_net, self.V_net, self.T_net, self.K_net)
        Y_link_Ew = Y_linpred[:, :, :, :, 1]
        # cond_Y_w: condition of being positive and valid weights (defined in set_data())
        Y_link_Ew_valid = Y_link_Ew.flatten()[self.cond_Y_w]
        Y_link_SDw_valid = Y_link_SDw.flatten()[self.cond_Y_w]
        with pyro.plate("data_w", Y_link_Ew_valid.shape[0]):
            pyro.sample("obs_w", dist.Normal(Y_link_Ew_valid, Y_link_SDw_valid),
                        obs=self.Y.flatten()[self.cond_Y_w])
def model():
    p = pyro.param("p", Variable(torch.Tensor([0.05])))
    ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
    x = pyro.sample("x", dist.Bernoulli(p))
    y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
    return dict(x=x, y=y)
def guide():
    p = pyro.param("p", Variable(torch.Tensor([0.0, 0.5, 1.0]), requires_grad=True))
    pyro.sample("z", dist.Bernoulli(p))
def init_params(data): params = {} params["beta"] = init_vector("beta", dims=(2)) # vector params["sigma"] = pyro.sample("sigma", dist.Uniform(0., 1000.)) return params
def generate_number_grammar(input_symbols, grammar=None):
    zeroRule = False
    connectingWordsProb = 0.2
    ZeroProb = 0.2
    exceptionProb = 0.3
    tenThousandWordProb = 0.3

    input_symbol_options = copy.deepcopy(input_symbols)
    # random.shuffle(input_symbol_options)

    rules, intRules, input_symbol_options, oneWord = generateOneToTen(input_symbol_options)

    for base in [10000, 1000, 100, 10]:
        if base == 10 and tp == 'irregular':
            tp = 'irregular'
        else:
            if base in [100, 10]:
                tp = selectFromList(['regular', 'irregular'], f"regularity_{base}", obs=None)
            else:
                tp = 'regular'
        # IMPORTANT: if 100s are irregular, 10s can't be regular ... probably fine
        # TODO: teens and twenties? maybe teens are irregular if 10s are?
        if tp == 'irregular':
            for i in range(1, 10):
                num = str(i * base)
                word, input_symbol_options = popFromList(input_symbol_options,
                                                         name=f"irreg_word_{base}_{i}",
                                                         obs=None)
                rules.append(Rule(word, num))
                intRules.append([str(num), '->', word])
            # french situation??
            # do irregular teens half the time:
            if base == 10 and pyro.sample('irreg_teens', pyro.distributions.Bernoulli(0.5)):
                for i in range(11, 20):
                    num = str(i)
                    word, input_symbol_options = popFromList(input_symbol_options,
                                                             name=f"teen_{i}", obs=None)
                    rules.append(Rule(word, num))
                    intRules.append([str(num), '->', word])
            intRules.append(['>' + str(base), '->',
                             '[x - x%' + str(base) + ']', '[x%' + str(base) + ']'])
        elif tp == 'regular':
            # we have a word for ten thousand sometimes
            if base == 10000:
                if not pyro.sample(f'ten_thousand_word',
                                   pyro.distributions.Bernoulli(tenThousandWordProb),
                                   obs=None):
                    base = 1000000
            baseWord, input_symbol_options = popFromList(input_symbol_options,
                                                         name=f"word_{base}", obs=None)
            rule = Rule('x1 ' + baseWord + ' y1', '[x1]*' + str(base) + ' [y1]')
            rules.append(rule)
            intRules.append(['>' + str(base), '->',
                             '[x//' + str(base) + ']', baseWord, '[x%' + str(base) + ']'])
            if pyro.sample(f'one_exception_{base}',
                           pyro.distributions.Bernoulli(exceptionProb), obs=None):
                # for now, only one exception
                oneException = True
                if pyro.sample(f'one_exception_change_{base}',
                               pyro.distributions.Bernoulli(0.3), obs=None):
                    oneExceptionWord, input_symbol_options = popFromList(
                        input_symbol_options, name=f"oneWord_{base}", obs=None)
                    exceptionRule = Rule(' '.join([oneExceptionWord, 'y1']),
                                         str(base) + '* 1' + ' [y1]')
                    intRule = ['//' + str(base) + '==' + str(1), '->',
                               oneExceptionWord, '[x%' + str(base) + ']']
                else:
                    exceptionRule = Rule(' '.join([baseWord, 'y1']),
                                         str(base) + '* 1' + ' [y1]')
                    intRule = ['//' + str(base) + '==' + str(1), '->',
                               baseWord, '[x%' + str(base) + ']']
                rules.insert(-1, exceptionRule)
                intRules.insert(-1, intRule)
                explicitExceptionRule = Rule(baseWord, str(base))
            else:
                explicitExceptionRule = Rule(' '.join([oneWord, baseWord]), str(base))
            rules.insert(0, explicitExceptionRule)
            if pyro.sample(f'exception_{base}',
                           pyro.distributions.Bernoulli(exceptionProb), obs=None):
                # for now, only one exception
                exceptionNum = selectFromList(list(range(2, 10)),
                                              name=f"exception_num_{base}", obs=None)
                exceptionWord, input_symbol_options = popFromList(
                    input_symbol_options, name=f"exception_name_{base}", obs=None)
                exceptionRule = Rule(' '.join([exceptionWord, baseWord, 'y1']),
                                     str(base) + '* ' + str(exceptionNum) + ' [y1]')
                intRule = ['//' + str(base) + '==' + str(exceptionNum), '->',
                           exceptionWord, baseWord, '[x%' + str(base) + ']']
                # TODO: other direction
                rules.insert(-1, exceptionRule)
                intRules.insert(-1, intRule)
        else:
            assert False

    if pyro.sample(f'connecting_word',
                   pyro.distributions.Bernoulli(connectingWordsProb), obs=None):
        connectingWord, input_symbol_options = popFromList(input_symbol_options,
                                                           name=f"connecting_word_val",
                                                           obs=None)
        rules.append(Rule('u1 ' + connectingWord + ' x1', '[u1] [x1]'))
        intRules.append(['>10', '->', '[x - x%10]', connectingWord, '[x%10]'])
        # assert False, "need to figure out where to put connectingWord"

    if pyro.sample(f'zero', pyro.distributions.Bernoulli(ZeroProb), obs=None):
        zeroWord, input_symbol_options = popFromList(input_symbol_options,
                                                     name=f"zero_word", obs=None)
        rules.append(Rule(zeroWord, '0'))
        intRules.append(['0', '->', zeroWord])
        zeroRule = True

    concatRule = Rule('u1 x1', '[u1] [x1]')
    rules.append(concatRule)
    # rules could be shuffled here, but it doesn't matter; do that at example sampling time
    return NumberGrammar(rules, input_symbols), IntGrammar(intRules, zeroRule)  # makeIntG(intRules, intG)
def init_vector(name, dims=None):
    return pyro.sample(name,
                       dist.Normal(torch.zeros(dims), 0.2 * torch.ones(dims)).to_event(1))
def forward(self, x, y=None):
    y_pr = self.seq(x)
    if y is not None:
        with pyro.plate("data", y.shape[0]):
            # Normal log-density is symmetric in (value, loc), so scoring the
            # prediction y_pr under a Normal centered at the data y is
            # equivalent to the usual loc=y_pr, obs=y form.
            pyro.sample("obs", dist.Normal(y, self.target_std).to_event(1), obs=y_pr)
    return y_pr.detach()
def guide(self, data):
    encoder = pyro.module('encoder', self.vae_encoder)
    with pyro.plate('data', data.size(0)):
        z_mean, z_var = encoder.forward(data)
        pyro.sample('latent', Normal(z_mean, z_var.sqrt()).to_event(1))
def model(game, observer, action):
    if game.turn == observer:
        return None
    known_cards = get_known_cards(game, observer)
    if action:
        known_cards.update(action.cards)
    unknown_cards = list(set(game.unused_cards) - known_cards)
    idx_to_id = idx_2_id(observer)
    id_to_idx = id_2_idx(observer)
    num_cards_in_hand = {
        i: len(set(game.players[idx_to_id[i]].hand) - known_cards)
        for i in idx_to_id
    }

    probs = []
    for card in unknown_cards:
        # FIXME: prior distribution of cards?
        theta = [torch.tensor(global_card_dist[card][i]) for i in idx_to_id]
        player_probs = pyro.sample('{}_probs'.format(card),
                                   dist.Dirichlet(torch.stack(theta)))
        normalized_player_probs = player_probs  # / torch.sum(player_probs)
        probs.append(normalized_player_probs)
    probs = torch.stack(probs)

    hands = {i: list() for i in idx_to_id}
    card_probs = {i: list() for i in idx_to_id}
    for i, card in random.sample(tuple(enumerate(unknown_cards)), len(unknown_cards)):
        assigned = False
        while not assigned:
            player = torch.distributions.Categorical(probs=probs[i]).sample()
            # player = pyro.sample('{}_locs'.format(card), dist.Categorical(probs=probs[i]))
            if len(hands[int(player)]) < num_cards_in_hand[int(player)]:
                hands[int(player)].append(card)
                card_probs[int(player)].append(probs[i][int(player)])
                assigned = True
    # for i in idx_to_id:
    #     pyro.sample(
    #         '{}_card_assignment'.format(i),
    #         dist.Bernoulli(probs=torch.prod(torch.tensor(card_probs[i]))),
    #         obs=torch.tensor(1.)
    #     )

    ai_player = GreedyPlayer(game, game.turn)
    ai_player.hand = hands[id_to_idx[game.turn]]
    if action:
        ai_player.hand += action.cards
    actions, action_probs = tuple([list(t) for t in zip(*ai_player.action_probs())])
    action_dist = dist.Categorical(probs=torch.tensor(action_probs))
    pyro.sample('action', action_dist, obs=torch.tensor(actions.index(action)))
    return hands
def normal_product(loc, scale):
    z1 = pyro.sample("z1", pyro.distributions.Normal(loc, scale))
    z2 = pyro.sample("z2", pyro.distributions.Normal(loc, scale))
    y = z1 * z2
    return y
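# Usage sketch (illustrative): normal_product is itself a small Pyro model;
# calling it draws z1 and z2 at the given location and scale and returns
# their product.
y = normal_product(torch.tensor(0.0), torch.tensor(1.0))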
def model_1():
    a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))
def model(data):
    w = pyro.sample("w", dist.Normal(0, 1))
    with pyro.plate("p", len(data)):
        x = pyro.sample("x", dist.Normal(0, 1))
        y = pyro.sample("y", dist.Normal(0, 1))
        pyro.sample("z", dist.Normal(w + x + y, 1), obs=data)
def model():
    p = Variable(torch.Tensor([0.0, 0.5, 1.0]))
    pyro.sample("z", dist.Bernoulli(p))
def model(data):
    with pyro.plate("p", len(data)):
        x = pyro.sample("x", dist.Normal(0, 1))
        y = pyro.sample("y", dist.Normal(0, 1))
    # z collapses the plate dimension, so it is sampled outside the plate
    pyro.sample("z", dist.Normal(x.sum(), y.sum().exp()), obs=data.sum())
def guide(self):
    # Posterior over the kernel hyperparameters
    # if self.random_kernel:
    #     self.kernel_param = pyro.param("kernel_param", 50 * torch.ones((2, 2)),
    #                                    constraint=constraints.positive)
    #     pyro.sample("kernel.lengthscale",
    #                 dist.InverseGamma(self.kernel_param[0, 0], self.kernel_param[0, 1]))
    #     pyro.sample("kernel.variance",
    #                 dist.InverseGamma(self.kernel_param[1, 0], self.kernel_param[1, 1]))

    # Sampling systemic components #
    self.gp_system_mean_loc = pyro.param("gp_system_mean_loc", self.gp_system_mean_ini)
    self.gp_system_mean_scale = pyro.param("gp_system_mean_scale",
                                           0.1 * torch.ones((self.K_net, self.n_w)),
                                           constraint=constraints.positive)
    self.gp_system_demean = pyro.param("gp_system_demean_loc", self.gp_system_demean_ini)
    # Posterior covariance of the GP
    self.gp_system_cov_tril = pyro.param(
        "gp_system_cov_tril",
        self.Lff_ini.expand(self.K_net, self.n_w, self.T_net, self.T_net),
        constraint=constraints.lower_cholesky)
    with pyro.plate('gp_system_all', self.K_net * self.n_w):
        # Posterior GP (mean function params)
        pyro.sample("gp_system_mean",
                    dist.Normal(self.gp_system_mean_loc.reshape(self.K_net * self.n_w),
                                self.gp_system_mean_scale.reshape(self.K_net * self.n_w)))
        # Posterior GP (demeaned)
        pyro.sample("gp_system_demean",
                    dist.MultivariateNormal(
                        self.gp_system_demean.reshape(self.K_net * self.n_w, self.T_net),
                        scale_tril=self.gp_system_cov_tril.reshape(
                            self.K_net * self.n_w, self.T_net, self.T_net)))

    # Sampling coordinates #
    if self.coord:
        self.gp_coord_mean_loc = pyro.param("gp_coord_mean_loc", self.gp_coord_mean_ini)
        self.gp_coord_mean_scale = pyro.param(
            "gp_coord_mean_scale",
            0.1 * torch.ones((self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w)),
            constraint=constraints.positive)
        self.gp_coord_demean = pyro.param("gp_coord_demean_loc", self.gp_coord_demean_ini)
        # Posterior covariance of the GP
        self.gp_coord_cov_tril = pyro.param(
            "gp_coord_cov_tril",
            self.Lff_ini.expand(self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w,
                                self.T_net, self.T_net),
            constraint=constraints.lower_cholesky)
        n_coord = self.V_net * self.H_dim * self.K_net * self.n_dir * self.n_w
        with pyro.plate('gp_coord_all', n_coord):
            # Posterior GP (mean function params)
            pyro.sample("gp_coord_mean",
                        dist.Normal(self.gp_coord_mean_loc.reshape(n_coord),
                                    self.gp_coord_mean_scale.reshape(n_coord)))
            # Posterior GP (demeaned)
            pyro.sample("gp_coord_demean",
                        dist.MultivariateNormal(
                            self.gp_coord_demean.reshape(n_coord, self.T_net),
                            scale_tril=self.gp_coord_cov_tril.reshape(
                                n_coord, self.T_net, self.T_net)))

    # Sampling sociability and popularity #
    if self.socpop:
        self.gp_socpop_mean_loc = pyro.param("gp_socpop_mean_loc", self.gp_socpop_mean_ini)
        self.gp_socpop_mean_scale = pyro.param(
            "gp_socpop_mean_scale",
            0.1 * torch.ones((self.V_net, self.K_net, self.n_dir, self.n_w)),
            constraint=constraints.positive)
        self.gp_socpop_demean = pyro.param("gp_socpop_demean_loc", self.gp_socpop_demean_ini)
        # Posterior covariance of the GP
        self.gp_socpop_cov_tril = pyro.param(
            "gp_socpop_cov_tril",
            self.Lff_ini.expand(self.V_net, self.K_net, self.n_dir, self.n_w,
                                self.T_net, self.T_net),
            constraint=constraints.lower_cholesky)
        n_socpop = self.V_net * self.K_net * self.n_dir * self.n_w
        with pyro.plate('gp_socpop_all', n_socpop):
            # Posterior GP (mean function params)
            pyro.sample("gp_socpop_mean",
                        dist.Normal(self.gp_socpop_mean_loc.reshape(n_socpop),
                                    self.gp_socpop_mean_scale.reshape(n_socpop)))
            # Posterior GP (demeaned)
            pyro.sample("gp_socpop_demean",
                        dist.MultivariateNormal(
                            self.gp_socpop_demean.reshape(n_socpop, self.T_net),
                            scale_tril=self.gp_socpop_cov_tril.reshape(
                                n_socpop, self.T_net, self.T_net)))

    # Sampling variance of weights
    if self.weighted:
        self.sigma_k_post_loc = pyro.param("sigma_k_post_loc", torch.ones([1]),
                                           constraint=constraints.positive)
        self.sigma_k_post_scale = pyro.param("sigma_k_post_scale", torch.ones([1]),
                                             constraint=constraints.positive)
        with pyro.plate("sigma_k_ind", self.K_net):
            sigma_k = pyro.sample('sigma_k',
                                  dist.InverseGamma(self.sigma_k_post_loc,
                                                    self.sigma_k_post_scale))
def guide():
    p = pyro.param("p", Variable(torch.ones(1), requires_grad=True))
    pyro.sample("mu_q", dist.normal, ng_zeros(1), p)
    pyro.sample("mu_q_2", dist.normal, ng_zeros(1), p)
def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))
def model_obs_dup():
    # intentionally reuses the site name "mu_q" for both a sample and an
    # observe statement (a fixture for duplicate-site-name detection)
    pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
    pyro.observe("mu_q", dist.normal, ng_zeros(1), ng_ones(1), ng_zeros(1))
def model(): a = pyro.sample("a", dist.Dirichlet(torch.ones(3))) b = pyro.sample("b", dist.Categorical(a)) c = pyro.sample("c", dist.Normal(torch.zeros(3), 1).to_event(1)) d = pyro.sample("d", dist.Poisson(c[b].exp())) pyro.sample("e", dist.Normal(d, 1), obs=torch.ones(()))
def make_normal_normal():
    mu_latent = pyro.sample("mu_latent", pyro.distributions.Normal(0, 1))
    fn = lambda scale: normal_product(mu_latent, scale)
    return fn
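# Usage sketch (illustrative): make_normal_normal draws mu_latent once and
# returns a closure over it; calling the closure with a scale runs
# normal_product (defined earlier) at that scale.
fn = make_normal_normal()
y = fn(torch.tensor(1.0))  # y = z1 * z2 with z1, z2 ~ Normal(mu_latent, 1.0)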
def model():
    pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
def model(): a = pyro.sample("a", dist.Normal(0, 1)) pyro.factor("b", torch.tensor(0.0)) pyro.factor("c", a)
def model_dup():
    # intentionally reuses the name "mu_q" for both a param and a sample site
    # (a fixture for duplicate-name detection)
    pyro.param("mu_q", Variable(torch.ones(1), requires_grad=True))
    pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
def model_3():
    with pyro.plate("p", 5):
        a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))
def model():
    lambda_latent = pyro.sample("lambda_latent", dist.gamma, self.alpha0, self.beta0)
    pyro.observe("obs0", dist.exponential, self.data[0], lambda_latent)
    pyro.observe("obs1", dist.exponential, self.data[1], lambda_latent)
    return lambda_latent
def model_2():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
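# Usage sketch (illustrative, not from the original source): small models like
# model_1/model_2/model_3 in this collection can be inspected by recording an
# execution with poutine.trace.
from pyro import poutine

trace = poutine.trace(model_2).get_trace()
print(list(trace.nodes.keys()))  # ['_INPUT', 'a', 'b', 'c', 'd', '_RETURN']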
def init_params(data): params = {} params["sigma_a1"] = pyro.sample("sigma_a1", dist.HalfCauchy(2.5)) params["sigma_a2"] = pyro.sample("sigma_a2", dist.HalfCauchy(2.5)) params["sigma_y"] = pyro.sample("sigma_y", dist.HalfCauchy(2.5)) return params
def ice_cream_sales():
    cloudy, temp = weather()
    expected_sales = 200. if cloudy == 'sunny' and temp > 80.0 else 50.
    ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.0))
    return ice_cream
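# The snippet above assumes a stochastic helper weather() that returns a
# (cloudy, temp) pair. A minimal sketch in the style of the Pyro intro
# tutorial (the exact helper is not shown in this collection):
def weather():
    cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
    return cloudy, temp.item()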