def model(self, X):
    """Low-rank-plus-diagonal Gaussian model with shrinkage priors on the factors.

    Generative process (all site names carry the ``self._id`` suffix except
    ``'b'`` and ``'obs'``):

    * ``tau`` ~ HalfNormal(learned scale): global shrinkage.
    * ``b`` ~ InverseGamma(1/2, 1) and ``lambdasquared`` ~ InverseGamma(1/2, 1/b),
      one per column of ``X``. This two-level mixture is the auxiliary-variable
      form of a half-Cauchy(0, 1) prior on ``sqrt(lambdasquared)``.
    * ``cov_diag`` ~ LogNormal(learned loc, learned scale), then ``jitter`` is added.
    * ``cov_factor`` ~ Normal(0, outer(sqrt(lambdasquared), tau)), then transposed.
    * A (possibly subsampled) minibatch of rows of ``X`` is scored under a
      zero-mean ``LowRankMultivariateNormal(cov_factor, cov_diag)``.

    Args:
        X: data matrix of shape (N, D).

    Returns:
        The value of the ``'obs'`` site, i.e. the observed minibatch of rows.
    """
    _id = self._id
    N, D = X.shape

    global_shrinkage_prior_scale_init = self.param_init[f'global_shrinkage_prior_scale_init_{_id}']
    cov_diag_prior_loc_init = self.param_init[f'cov_diag_prior_loc_init_{_id}']
    cov_diag_prior_scale_init = self.param_init[f'cov_diag_prior_scale_init_{_id}']

    # Global shrinkage scale.
    global_shrinkage_prior_scale = pyro.param(
        f'global_shrinkage_scale_prior_{_id}',
        global_shrinkage_prior_scale_init,
        constraint=constraints.positive,
    )
    tau = pyro.sample(f'global_shrinkage_{_id}', dist.HalfNormal(global_shrinkage_prior_scale))

    # Local shrinkage via the inverse-gamma mixture.
    # NOTE(review): 'b' has no _id suffix, unlike the other latent sites, so two
    # instances with different _id in one model would collide. Any rename must
    # be mirrored in the guide.
    b = pyro.sample('b', dist.InverseGamma(0.5, 1. / torch.ones(D) ** 2).to_event(1))
    lambdasquared = pyro.sample(
        f'local_shrinkage_{_id}', dist.InverseGamma(0.5, 1. / b).to_event(1)
    )

    # Diagonal part of the covariance.
    cov_diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', cov_diag_prior_loc_init)
    cov_diag_scale = pyro.param(
        f'cov_diag_prior_scale_{_id}', cov_diag_prior_scale_init, constraint=constraints.positive
    )
    cov_diag = pyro.sample(
        f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale).to_event(1)
    )
    # NOTE(review): `jitter` is not defined in this method. It is presumably a
    # small positive module-level constant; confirm.
    cov_diag = cov_diag + jitter

    # Scale of every factor-loading entry: sqrt(lambdasquared_i) * tau.
    lambdasquared = lambdasquared.squeeze()
    # Repeat tau along a new/last axis of length D (shape (D,) for a scalar tau).
    tau_row = tau.repeat((tau.dim() - 1) * (1,) + (D,))
    if lambdasquared.dim() == 1:
        # Single draw: plain outer product (torch.ger is a deprecated alias).
        cov_factor_scale = torch.outer(torch.sqrt(lambdasquared), tau_row)
    else:
        # Batched draws (e.g. vectorized particles): batch outer product.
        cov_factor_scale = torch.einsum('bp, br->bpr', torch.sqrt(lambdasquared), tau_row)

    cov_factor = pyro.sample(
        f'cov_factor_{_id}', dist.Normal(0., cov_factor_scale).to_event(2)
    )
    cov_factor = cov_factor.transpose(-2, -1)

    # Score a minibatch of rows; `ind` holds the subsampled row indices.
    with pyro.plate(f'N_{_id}', size=N, subsample_size=self.batch_size, dim=-1) as ind:
        X = pyro.sample(
            'obs',
            dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor=cov_factor, cov_diag=cov_diag),
            obs=X.index_select(0, ind),
        )
    return X
def __init__( self, edgelist, H_dim=3, random_kernel=False, jitter=1e-3 ):
    """Build the toy dynamic network model from an edge list.

    Args:
        edgelist: edge list converted with ``pydmn.util.edgelist_to_tensor``.
        H_dim: latent-coordinate dimension, forwarded to ``set_data``.
        random_kernel: when True, an RBF kernel is created whose lengthscale
            and variance get InverseGamma priors through ``PyroSample``, and a
            (2, 2) ``kernel_param`` placeholder of ones is stored.
        jitter: stored as ``self.jitter``.
    """
    super(dmn_toy, self).__init__()

    Y, (_all_nodes, Y_time, layers) = pydmn.util.edgelist_to_tensor( edgelist = edgelist )
    self.all_layers = layers
    # NOTE(review): set_data presumably defines V_net, H_dim and T_net, which
    # are read below; confirm.
    self.set_data(Y=Y, Y_time=Y_time, H_dim=H_dim )

    # Fixed configuration of this toy variant.
    self.weighted = False
    self.directed = False
    self.coord = True
    self.socpop = False
    self.random_kernel = random_kernel
    self.jitter = jitter

    ### Variational Parameters (initial values) ###
    n_nodes, h_dim, n_times = self.V_net, self.H_dim, self.T_net
    self.gp_mean_param = torch.ones((n_nodes, h_dim, 2))
    self.gp_coord_demean = torch.zeros((n_nodes, h_dim, n_times))
    self.gp_cov_tril = torch.eye(n_times).expand((n_nodes, h_dim, n_times, n_times))

    # A random kernel gets priors on its hyper-parameters via PyroSample.
    if self.random_kernel:
        self.kernel_param = torch.ones((2, 2))
        rbf = pydmn.kernels.RBF()
        self.kernel = rbf
        self.kernel.lengthscale = pyro.nn.PyroSample(
            dist.InverseGamma(torch.tensor([4.]), torch.tensor([30.]))
        )
        self.kernel.variance = pyro.nn.PyroSample(
            dist.InverseGamma(torch.tensor([11.]), torch.tensor([10.]))
        )
def model(X, Y, hypers, jitter=1.0e-4):
    """Gaussian-process regression with learned per-feature shrinkage.

    Samples a noise scale ``sigma``, global scales ``eta1``, ``msq`` and
    ``xisq``, and per-feature HalfCauchy local scales ``lambda``. It rescales
    each feature of ``X`` by ``kappa`` and builds an N x N covariance from the
    external ``kernel`` function plus ``(sigma**2 + jitter)`` on the diagonal.
    Finally it observes ``Y`` under a zero-mean MultivariateNormal.

    Args:
        X: design matrix of shape (N, P).
        Y: observed responses of length N.
        hypers: dict with keys ``expected_sparsity``, ``alpha1``, ``beta1``,
            ``alpha2``, ``beta2``, ``alpha3`` and ``c``.
        jitter: constant added to the covariance diagonal for stability.
    """
    expected_sparsity = hypers['expected_sparsity']
    n_features, n_data = X.size(1), X.size(0)

    sigma = pyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    # Prior scale of eta1, tied to the expected number of relevant features.
    # NOTE(review): divides by zero when n_features == expected_sparsity.
    phi = sigma * (expected_sparsity / math.sqrt(n_data)) / (n_features - expected_sparsity)
    eta1 = pyro.sample("eta1", dist.HalfCauchy(phi))

    msq = pyro.sample("msq", dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = pyro.sample("xisq", dist.InverseGamma(hypers['alpha2'], hypers['beta2']))
    eta2 = eta1.pow(2.0) * xisq.sqrt() / msq

    device = X.device
    local_prior = dist.HalfCauchy(torch.ones(n_features, device=device)).to_event(1)
    lam = pyro.sample("lambda", local_prior)
    # Effective per-feature scale; saturates at sqrt(msq) / eta1 as lambda grows.
    kappa = msq.sqrt() * lam / (msq + (eta1 * lam).pow(2.0)).sqrt()
    scaled_X = kappa * X

    # Kernel matrix for the current hyper-parameters, plus noise on the diagonal.
    gram = kernel(scaled_X, scaled_X, eta1, eta2, hypers['c'])
    cov = gram + (sigma ** 2 + jitter) * torch.eye(n_data, device=device)

    # Observe the outputs Y.
    zero_mean = torch.zeros(n_data, device=device)
    pyro.sample("Y", dist.MultivariateNormal(zero_mean, covariance_matrix=cov), obs=Y)
def model(self, data, batch_idx):
    """Mixture model with Beta stick proportions and Normal components.

    * ``beta``: ``T - 1`` independent Beta(1, alpha) draws, turned into mixture
      weights by the external ``mix_weights`` (presumably truncated
      stick-breaking; not visible here).
    * ``musd`` / ``mu``: per-component InverseGamma(1, 1) scales and Normal
      means centred on ``self.mu_c``.
    * ``cat`` / ``obs``: for the minibatch ``batch_idx`` of ``self.num_obs``
      points, a component assignment (enumerated in parallel) and a Normal
      observation of ``data[batch_idx]``.
    """
    # Stick proportions are conditionally independent.
    with pyro.plate("beta_plate", self.T - 1):
        beta = pyro.sample("beta", dist.Beta(1, self.alpha))

    # Components are conditionally independent; to_event(1) makes the
    # trailing dimension of each draw one dependent event.
    with pyro.plate("mu_plate", self.T):
        sd_prior = dist.InverseGamma(torch.ones_like(self.sd_q1), torch.ones_like(self.sd_q2))
        mu_sd = pyro.sample("musd", sd_prior.to_event(1))
        loc_prior = dist.Normal(self.mu_c, mu_sd * torch.ones_like(self.mu_c))
        mu_c = pyro.sample("mu", loc_prior.to_event(1))

    # Data points are conditionally independent; only batch_idx is scored.
    with pyro.plate("data", size=self.num_obs, subsample=batch_idx):
        assignment = pyro.sample(
            "cat",
            dist.Categorical(mix_weights(beta)),
            infer={"enumerate": "parallel"},
        )
        pyro.sample(
            "obs",
            dist.Normal(mu_c[assignment], mu_sd[assignment]).to_event(1),
            obs=data[batch_idx],
        )
def guide(self, X):
    """Variational guide for the shrinkage factor model.

    * ``tau`` ~ LogNormal(learned loc, learned scale).
    * ``b`` is drawn from InverseGamma(1/2, 1) with no variational parameters.
    * ``lambdasquared`` ~ LogNormal(learned loc, learned scale), one per column.
    * ``cov_diag`` ~ LogNormal(learned loc, learned scale), broadcast to length D.
    * ``cov_factor`` ~ Normal(learned loc, outer(sqrt(lambdasquared), tau)). The
      scale is not a free parameter; it is derived from the shrinkage draws.

    Args:
        X: data matrix of shape (N, D); only D is used.

    Returns:
        Tuple ``(tau, lambdasquared, cov_factor, cov_diag)``. Here
        ``lambdasquared`` is squeezed, ``cov_diag`` is broadcast to D and
        ``cov_factor`` is untransposed.
    """
    _id = self._id
    _, D = X.shape

    global_shrinkage_loc_init = self.param_init[f'global_shrinkage_loc_init_{_id}']
    global_shrinkage_scale_init = self.param_init[f'global_shrinkage_scale_init_{_id}']
    local_shrinkage_loc_init = self.param_init[f'local_shrinkage_loc_init_{_id}']
    local_shrinkage_scale_init = self.param_init[f'local_shrinkage_scale_init_{_id}']
    cov_diag_loc_init = self.param_init[f'cov_diag_loc_init_{_id}']
    cov_diag_scale_init = self.param_init[f'cov_diag_scale_init_{_id}']
    cov_factor_loc_init = self.param_init[f'cov_factor_loc_init_{_id}']

    # Global shrinkage posterior.
    global_shrinkage_loc = pyro.param(f'global_shrinkage_loc_{_id}', global_shrinkage_loc_init)
    global_shrinkage_scale = pyro.param(
        f'global_shrinkage_scale_{_id}', global_shrinkage_scale_init, constraint=constraints.positive
    )
    tau = pyro.sample(
        f'global_shrinkage_{_id}', dist.LogNormal(global_shrinkage_loc, global_shrinkage_scale)
    )

    # Auxiliary variable: sampled from a fixed InverseGamma (no learned params).
    b = pyro.sample('b', dist.InverseGamma(0.5, 1. / torch.ones(D) ** 2).to_event(1))

    # Local shrinkage posterior.
    local_shrinkage_loc = pyro.param(f'local_shrinkage_loc_{_id}', local_shrinkage_loc_init)
    local_shrinkage_scale = pyro.param(
        f'local_shrinkage_scale_{_id}', local_shrinkage_scale_init, constraint=constraints.positive
    )
    lambdasquared = pyro.sample(
        f'local_shrinkage_{_id}',
        dist.LogNormal(local_shrinkage_loc, local_shrinkage_scale).to_event(1),
    )

    # Diagonal covariance posterior, broadcast to length D.
    cov_diag_loc = pyro.param(f'cov_diag_loc_{_id}', cov_diag_loc_init)
    cov_diag_scale = pyro.param(
        f'cov_diag_scale_{_id}', cov_diag_scale_init, constraint=constraints.positive
    )
    cov_diag = pyro.sample(
        f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale).to_event(1)
    )
    cov_diag = cov_diag * torch.ones(D)

    # Scale of every factor-loading entry: sqrt(lambdasquared_i) * tau.
    lambdasquared = lambdasquared.squeeze()
    tau_row = tau.repeat((tau.dim() - 1) * (1,) + (D,))
    if lambdasquared.dim() == 1:
        # Single draw: plain outer product (torch.ger is a deprecated alias).
        cov_factor_scale = torch.outer(torch.sqrt(lambdasquared), tau_row)
    else:
        # Batched draws: batch outer product.
        cov_factor_scale = torch.einsum('bp, br->bpr', torch.sqrt(lambdasquared), tau_row)

    cov_factor_loc = pyro.param(f'cov_factor_loc_{_id}', cov_factor_loc_init)
    cov_factor = pyro.sample(
        f'cov_factor_{_id}', dist.Normal(cov_factor_loc, cov_factor_scale).to_event(2)
    )
    return tau, lambdasquared, cov_factor, cov_diag
def model(self, X, y):
    """Bayesian linear regression ``y ~ Normal(bias + X @ weights, sigma)``.

    Args:
        X: design matrix of shape (n_samples, n_features).
        y: observed targets of length n_samples (``None`` lets Pyro sample them).

    Returns:
        The value of the ``'obs'`` sample site.

    Side effects:
        Stores the learnable InverseGamma hyper-parameters and the sampled
        ``bias``, ``weights`` and ``sigma`` as attributes on ``self``.
    """
    input_size = X.shape[0]
    input_dim = X.shape[1]

    # Standard multivariate-normal prior on the weights.
    w_mu_init = torch.zeros(input_dim)
    w_sigma_init = torch.eye(input_dim)

    # Learnable hyper-parameters of the noise prior. Float literals on purpose:
    # torch.tensor(10) is an int64 tensor, but learnable parameters must be
    # floating point to carry gradients.
    self.sigma_concentration = pyro.param(
        'sigma_concentration', torch.tensor(10.), constraint=constraints.positive
    )
    self.sigma_rate = pyro.param(
        'sigma_rate', torch.tensor(50.), constraint=constraints.positive
    )

    # Priors.
    self.bias = pyro.sample('bias', dist.Normal(0, 10))
    self.weights = pyro.sample('weights', dist.MultivariateNormal(w_mu_init, w_sigma_init))
    self.sigma = pyro.sample(
        'sigma',
        dist.InverseGamma(concentration=self.sigma_concentration, rate=self.sigma_rate),
    )

    # Likelihood: one conditionally independent observation per row of X.
    prior_mean = self.bias + X @ self.weights
    with pyro.plate('data', input_size):
        return pyro.sample('obs', dist.Normal(prior_mean, self.sigma), obs=y)
def __init__(self, name=''):
    """Initialise the module and give its observation noise scale a prior.

    Parameters
    ----------
    name : str
        Module name forwarded to the parent constructor.

    Notes
    -----
    ``obs_scale`` is declared as a ``PyroSample`` with an
    ``InverseGamma(2, 0.5)`` prior, so it is drawn from that prior when
    accessed inside a Pyro model rather than being a fixed value.
    """
    super().__init__(name=name)
    noise_prior = dist.InverseGamma(torch.tensor(2.), 0.5)
    self.obs_scale = PyroSample(noise_prior)
def guide(self):
    """Variational guide for the toy dynamic network model.

    When ``self.random_kernel`` is set, it first places InverseGamma
    posteriors on the kernel lengthscale and variance. It then registers a
    Normal posterior for the GP mean-function coefficients (``gp_mean``) and a
    MultivariateNormal posterior with a learned lower-Cholesky scale for the
    demeaned GP paths (``gp_coord_demean``). The registered parameters are
    also stored on ``self``.
    """
    # Posterior over the kernel hyper-parameters (random kernel only).
    if self.random_kernel:
        kernel_param = pyro.param("kernel_param", torch.ones((2, 2)), constraint=constraints.positive)
        self.kernel_param = kernel_param
        lengthscale_q = dist.InverseGamma(kernel_param[0, 0], kernel_param[0, 1])
        pyro.sample("kernel.lengthscale", lengthscale_q)
        variance_q = dist.InverseGamma(kernel_param[1, 0], kernel_param[1, 1])
        pyro.sample("kernel.variance", variance_q)

    n_nodes, h_dim, n_times = self.V_net, self.H_dim, self.T_net
    n_series = n_nodes * h_dim

    # Posterior over the GP mean-function parameters.
    mean_loc = pyro.param("gp_mean_loc", torch.zeros((n_nodes, h_dim)))
    mean_scale = pyro.param(
        "gp_mean_scale", torch.ones((n_nodes, h_dim)), constraint=constraints.positive
    )
    self.gp_mean_loc, self.gp_mean_scale = mean_loc, mean_scale

    # Posterior over the demeaned GP paths: location and lower-Cholesky scale.
    demean_loc = pyro.param("gp_coord_demean_loc", torch.zeros((n_nodes, h_dim, n_times)))
    cov_tril = pyro.param(
        "gp_cov_tril",
        torch.eye(n_times).expand(n_nodes, h_dim, n_times, n_times),
        constraint=constraints.lower_cholesky,
    )
    self.gp_coord_demean, self.gp_cov_tril = demean_loc, cov_tril

    # One independent series per (node, latent dimension) pair.
    with pyro.plate('gp_coord_all', n_series):
        pyro.sample(
            "gp_mean",
            dist.Normal(mean_loc.reshape(n_series), mean_scale.reshape(n_series)),
        )
        pyro.sample(
            "gp_coord_demean",
            dist.MultivariateNormal(
                demean_loc.reshape(n_series, n_times),
                scale_tril=cov_tril.reshape(n_series, n_times, n_times),
            ),
        )
def pyrocov_model_relaxed(dataset):
    """Strain-growth model using ``to_event`` for the strain axis.

    Args:
        dataset: mapping with ``"features"`` (S x F), ``"local_time"`` (T x P)
            and ``"weekly_strains"`` (T x P x S counts).

    The count likelihood is ``Multinomial(logits=...)`` with
    ``validate_args=False``, so observed count vectors are scored without
    checking them against the default ``total_count``.
    """
    # Tensor shapes are noted at the end of some lines.
    strain_features = dataset["features"]
    times = dataset["local_time"][..., None]  # [T, P, 1]
    n_times, n_places, _ = times.shape
    n_strains, n_features = strain_features.shape
    counts = dataset["weekly_strains"]  # [T, P, S]
    assert counts.shape == (n_times, n_places, n_strains)

    # Global scale variables; a trailing singleton dim lets each broadcast
    # against the event (strain/feature) dimension.
    global_priors = (
        ("coef_scale", dist.InverseGamma(5e3, 1e2)),
        ("rate_loc_scale", dist.LogNormal(-4, 2)),
        ("rate_scale", dist.LogNormal(-4, 2)),
        ("init_loc_scale", dist.LogNormal(0, 2)),
        ("init_scale", dist.LogNormal(0, 2)),
    )
    coef_scale, rate_loc_scale, rate_scale, init_loc_scale, init_scale = (
        pyro.sample(site, prior)[..., None] for site, prior in global_priors
    )

    # Relative growth rate depends strongly on mutations and weakly on place.
    coef_prior = dist.Logistic(torch.zeros(n_features), coef_scale)
    coef = pyro.sample("coef", coef_prior.to_event(1))  # [F]
    rate_loc_mean = 0.01 * coef @ strain_features.T  # [S]
    rate_loc = pyro.sample(
        "rate_loc", dist.Normal(rate_loc_mean, rate_loc_scale).to_event(1)
    )  # [S]

    # Initial infections depend strongly on strain and place.
    init_loc = pyro.sample(
        "init_loc", dist.Normal(torch.zeros(n_strains), init_loc_scale).to_event(1)
    )  # [S]
    with pyro.plate("place", n_places, dim=-1):
        rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale).to_event(1))  # [P, S]
        init = pyro.sample("init", dist.Normal(init_loc, init_scale).to_event(1))  # [P, S]

    # Observe the counts.
    with pyro.plate("time", n_times, dim=-2):
        logits = init + rate * times  # [T, P, S]
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits, validate_args=False),
            obs=counts,
        )
def pyrocov_model_plated(dataset):
    """Strain-growth model with every axis declared by an explicit plate.

    Args:
        dataset: mapping with ``"features"`` (S x F), ``"local_time"`` (T x P)
            and ``"weekly_strains"`` (T x P x S counts).

    Plate layout: feature and strain share ``dim=-1`` (never nested together),
    place is ``dim=-2`` and time is ``dim=-3``. The observation therefore
    inserts a singleton axis so its batch shape ``[T, P, 1]`` lines up with the
    time and place plates, while the strain axis is the Multinomial event.
    """
    # Tensor shapes are commented at the end of some lines.
    features = dataset["features"]
    local_time = dataset["local_time"][..., None]  # [T, P, 1]
    T, P, _ = local_time.shape
    S, F = features.shape
    weekly_strains = dataset["weekly_strains"]  # [T, P, S]
    assert weekly_strains.shape == (T, P, S)
    # Plates are created once and reused as context managers below.
    feature_plate = pyro.plate("feature", F, dim=-1)
    strain_plate = pyro.plate("strain", S, dim=-1)
    place_plate = pyro.plate("place", P, dim=-2)
    time_plate = pyro.plate("time", T, dim=-3)

    # Sample global random variables.
    coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))
    rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))

    # Per-feature mutation effects.
    with feature_plate:
        coef = pyro.sample("coef", dist.Logistic(0, coef_scale))  # [F]
    # Project mutation effects onto strains: [F] @ [F, S] -> [S].
    rate_loc_loc = 0.01 * coef @ features.T
    with strain_plate:
        rate_loc = pyro.sample(
            "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale)
        )  # [S]
        init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale))  # [S]
    with place_plate, strain_plate:
        rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))  # [P, S]
        init = pyro.sample("init", dist.Normal(init_loc, init_scale))  # [P, S]

    # Finally observe counts.
    with time_plate, place_plate:
        # The extra singleton axis moves time/place to dims -3/-2, matching the
        # plates; the last axis (strains) is the Multinomial event.
        logits = (init + rate * local_time)[..., None, :]  # [T, P, 1, S]
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits, validate_args=False),
            obs=weekly_strains[..., None, :],
        )
def guide(self, data, batch_idx):
    """Variational guide for the Beta/Normal mixture model.

    * ``beta``: Beta(1, alpha_q) stick proportions with learned positive ``alpha_q``.
    * ``musd``: InverseGamma(sd_q1, sd_q2) component scales.
    * ``mu``: Normal(mu_q, musd) component means.
    * ``cat``: Categorical(rho[batch_idx]) assignments for the minibatch. ``rho``
      is constrained to the simplex along its last axis.

    Args:
        data: full data tensor. It is unused here but kept so the guide's
            signature matches the model's.
        batch_idx: indices of the minibatch being scored.
    """
    # Variational parameters.
    alpha_q = pyro.param('alpha_q', self.alpha_q, constraint=constraints.positive)
    rho = pyro.param('rho', self.rho, constraint=constraints.simplex)
    mu_q = pyro.param("mu_q", self.mu_c)
    # InverseGamma concentration and rate must remain strictly positive.
    # Without a constraint the optimizer can drive them <= 0 and produce an
    # invalid distribution.
    sd_q1 = pyro.param('sd_q1', self.sd_q1, constraint=constraints.positive)
    sd_q2 = pyro.param('sd_q2', self.sd_q2, constraint=constraints.positive)

    with pyro.plate("beta_plate", self.T - 1):
        pyro.sample("beta", dist.Beta(torch.ones(self.T - 1), alpha_q))

    with pyro.plate("mu_plate", self.T):
        mu_sd = pyro.sample("musd", dist.InverseGamma(sd_q1, sd_q2).to_event(1))
        pyro.sample("mu", dist.Normal(mu_q, mu_sd).to_event(1))

    with pyro.plate("data", size=self.num_obs, subsample=batch_idx):
        pyro.sample("cat", dist.Categorical(rho[batch_idx]))
# Script prologue: imports, seeding and fixed data for an InverseGamma prior
# example. The listing is truncated in this view.
import torch as th
import pyro
from pyro import distributions

# Seed Pyro's random number generators so later sampling is reproducible.
pyro.set_rng_seed(1)

# ---------- #
# GIVEN DATA #
# ---------- #
my_mean = 180.0  # presumably the known/assumed mean of my_sample -- confirm
prioalpha = 38.0  # InverseGamma concentration (first positional argument)
priobeta = 1110.0  # InverseGamma rate (second positional argument)
my_invgamma = distributions.InverseGamma(prioalpha, priobeta)
# NOTE(review): `np` is not imported in the visible part of this script (only
# `torch as th`, `pyro` and `pyro.distributions` are). Add `import numpy as np`
# unless it is imported elsewhere.
my_sample = np.array(
    [
        183.0,
        173.0,
        181.0,
        170.0,
        176.0,
        180.0,
        187.0,
        176.0,
        171.0,
        190.0,
        184.0,
        173.0,
def linear(xes, yes):
    """Bayesian simple linear regression ``yes ~ Normal(intercept + slope * xes, sqrt(var))``.

    Args:
        xes: 1-D tensor of predictor values (assumed; ``len(xes)`` sizes the plate).
        yes: 1-D tensor of observed responses with the same length, or ``None``
            to sample responses from the model.

    Returns:
        The sampled ``slope``. This return value is kept for backward compatibility.
    """
    slope = pyro.sample("slope", dist.Normal(5, 10))
    intercept = pyro.sample("intercept", dist.Normal(0, 10))
    # InverseGamma is a prior on the noise *variance*; Normal takes a std-dev.
    var = pyro.sample("var", dist.InverseGamma(3, 0.1))
    mean = intercept + slope * xes
    # The likelihood ties the latent parameters to the data. Without it,
    # `intercept`, `var` and `yes` would be unused and the model would learn
    # nothing from observations.
    with pyro.plate("data", len(xes)):
        pyro.sample("obs", dist.Normal(mean, var.sqrt()), obs=yes)
    return slope
def guide(self):
    """Variational guide for the dynamic multilayer network model.

    For each enabled latent group it registers variational parameters (also
    stored on ``self``) and samples the matching sites:

    * system-wide GPs (always): ``gp_system_mean`` (Normal) and
      ``gp_system_demean`` (MultivariateNormal with a learned lower-Cholesky
      scale), over ``K_net * n_w`` series of length ``T_net``;
    * latent coordinates (if ``self.coord``): ``gp_coord_mean`` /
      ``gp_coord_demean`` over ``V_net * H_dim * K_net * n_dir * n_w`` series;
    * sociability/popularity (if ``self.socpop``): ``gp_socpop_mean`` /
      ``gp_socpop_demean`` over ``V_net * K_net * n_dir * n_w`` series;
    * weight noise (if ``self.weighted``): ``sigma_k`` ~ InverseGamma, one per
      index in ``K_net``.

    Initial locations come from the ``*_ini`` attributes, and the Cholesky
    initial value comes from ``self.Lff_ini``. Both are set elsewhere and are
    not visible here.
    """
    # Posterior Covariance of the GP
    # NOTE(review): the kernel hyper-parameter guide below is disabled.
    # if self.random_kernel:
    #     self.kernel_param = pyro.param("kernel_param", 50*torch.ones((2,2)), constraint=constraints.positive)
    #     pyro.sample( "kernel.lengthscale", dist.InverseGamma( self.kernel_param[0,0], self.kernel_param[0,1] ) )
    #     pyro.sample( "kernel.variance", dist.InverseGamma( self.kernel_param[1,0], self.kernel_param[1,1] ) )

    # Sampling Systemic components #
    self.gp_system_mean_loc = pyro.param("gp_system_mean_loc", self.gp_system_mean_ini )
    self.gp_system_mean_scale = pyro.param("gp_system_mean_scale", 0.1*torch.ones((self.K_net,self.n_w)), constraint=constraints.positive)
    self.gp_system_demean = pyro.param( f"gp_system_demean_loc", self.gp_system_demean_ini )
    # Posterior Covariance of the GP: one lower-Cholesky (T_net x T_net) factor
    # per series, all initialised from Lff_ini.
    self.gp_system_cov_tril = pyro.param( "gp_system_cov_tril", self.Lff_ini.expand(self.K_net,self.n_w,self.T_net,self.T_net), constraint=constraints.lower_cholesky )
    # Series are flattened so a single plate covers every (k, w) pair.
    with pyro.plate('gp_system_all', self.K_net*self.n_w ):
        # Posterior GP (mean function params) #
        pyro.sample( "gp_system_mean", dist.Normal( self.gp_system_mean_loc.reshape(self.K_net*self.n_w), self.gp_system_mean_scale.reshape(self.K_net*self.n_w) ) )
        # Posterior GP (demeaned) #
        pyro.sample( f"gp_system_demean", dist.MultivariateNormal( self.gp_system_demean.reshape(self.K_net*self.n_w , self.T_net), scale_tril=self.gp_system_cov_tril.reshape(self.K_net*self.n_w , self.T_net, self.T_net) ) )

    # Sampling coordinates #
    if self.coord:
        self.gp_coord_mean_loc = pyro.param("gp_coord_mean_loc", self.gp_coord_mean_ini )
        self.gp_coord_mean_scale = pyro.param("gp_coord_mean_scale", 0.1*torch.ones((self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w)), constraint=constraints.positive)
        self.gp_coord_demean = pyro.param( f"gp_coord_demean_loc", self.gp_coord_demean_ini )
        # Posterior Covariance of the GP
        self.gp_coord_cov_tril = pyro.param( "gp_coord_cov_tril", self.Lff_ini.expand(self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w,self.T_net,self.T_net), constraint=constraints.lower_cholesky )
        with pyro.plate('gp_coord_all', self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w ):
            # Posterior GP (mean function params) #
            pyro.sample( "gp_coord_mean", dist.Normal( self.gp_coord_mean_loc.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w), self.gp_coord_mean_scale.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w) ) )
            # Posterior GP (demeaned) #
            pyro.sample( f"gp_coord_demean", dist.MultivariateNormal( self.gp_coord_demean.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w , self.T_net), scale_tril=self.gp_coord_cov_tril.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w , self.T_net, self.T_net) ) )

    # Sampling sociability and popularity #
    if self.socpop:
        self.gp_socpop_mean_loc = pyro.param("gp_socpop_mean_loc", self.gp_socpop_mean_ini )
        self.gp_socpop_mean_scale = pyro.param("gp_socpop_mean_scale", 0.1*torch.ones((self.V_net,self.K_net,self.n_dir,self.n_w)), constraint=constraints.positive)
        self.gp_socpop_demean = pyro.param( f"gp_socpop_demean_loc", self.gp_socpop_demean_ini )
        # Posterior Covariance of the GP
        self.gp_socpop_cov_tril = pyro.param( "gp_socpop_cov_tril", self.Lff_ini.expand(self.V_net,self.K_net,self.n_dir,self.n_w,self.T_net,self.T_net), constraint=constraints.lower_cholesky )
        with pyro.plate('gp_socpop_all', self.V_net*self.K_net*self.n_dir*self.n_w ):
            # Posterior GP (mean function params) #
            pyro.sample( "gp_socpop_mean", dist.Normal( self.gp_socpop_mean_loc.reshape(self.V_net*self.K_net*self.n_dir*self.n_w), self.gp_socpop_mean_scale.reshape(self.V_net*self.K_net*self.n_dir*self.n_w) ) )
            # Posterior GP (demeaned) #
            pyro.sample( f"gp_socpop_demean", dist.MultivariateNormal( self.gp_socpop_demean.reshape(self.V_net*self.K_net*self.n_dir*self.n_w , self.T_net), scale_tril=self.gp_socpop_cov_tril.reshape(self.V_net*self.K_net*self.n_dir*self.n_w , self.T_net, self.T_net) ) )

    # NOTE(review): leftover from the disabled kernel guide above.
    # pyro.sample( "kernel.variance", dist.InverseGamma( self.kernel_param[1,0], self.kernel_param[1,1] ) )

    # Sampling variance of weights
    if self.weighted:
        # Despite the names, these are the InverseGamma concentration ("loc")
        # and rate ("scale"); each has shape [1] and broadcasts over the plate.
        self.sigma_k_post_loc = pyro.param("sigma_k_post_loc", torch.ones([1]), constraint=constraints.positive )
        self.sigma_k_post_scale = pyro.param("sigma_k_post_scale", torch.ones([1]), constraint=constraints.positive )
        with pyro.plate( "sigma_k_ind", self.K_net):
            # Only the sample site matters; the bound value is unused.
            sigma_k = pyro.sample( 'sigma_k', dist.InverseGamma( self.sigma_k_post_loc, self.sigma_k_post_scale ) )
def model(self):
    """Generative model for the dynamic multilayer network.

    Latent Gaussian-process paths over ``T_net`` time points share the
    Cholesky factor ``Lff = self.Lff_ini``:

    * system-wide paths ``gp_system`` of shape [K_net, n_w, T_net];
    * if ``self.coord``, latent coordinates ``gp_coord`` of shape
      [V_net, H_dim, K_net, n_dir, n_w, T_net];
    * if ``self.socpop``, sociability/popularity paths ``gp_socpop`` of shape
      [V_net, K_net, n_dir, n_w, T_net].

    These are combined into a linear predictor ``Y_linpred`` of shape
    [V_net, V_net, T_net, K_net, n_w]. Slot ``[..., 0]`` gives Bernoulli link
    probabilities (through a sigmoid). If ``self.weighted``, slot ``[..., 1]``
    gives the mean of Normal edge weights with per-k scale ``sigma_k``.
    """
    # If the Kernel IS NOT random, we declare the kernel within the model #
    # NOTE(review): the kernel-based covariance below is disabled; the model
    # uses the fixed Cholesky factor self.Lff_ini instead.
    # if (not self.random_kernel):
    #     self.kernel = pydmn.kernels.RBF()
    # Covariance matrix of observed times entailed by our kernel #
    # Kff = self.kernel(self.Y_time.reshape(-1,1))
    # Kff.view(-1)[::self.T_net + 1] += self.jitter # add jitter to the diagonal
    # Lff = Kff.cholesky() # cholesky lower triangular
    # Shared by every GP below, so all series have the same temporal covariance.
    Lff = self.Lff_ini

    ## Sampling system-wide connectivity and average weights ##
    with pyro.plate('gp_system_all', self.K_net*self.n_w ):
        # Mean function of the GPs (one constant per series, prior std 0.1)
        gp_system_mean = pyro.sample( "gp_system_mean", dist.Normal( torch.zeros( (self.K_net*self.n_w) ), torch.tensor([0.1]) ) )
        # Demeaned GPs
        gp_system_demean = pyro.sample( "gp_system_demean", dist.MultivariateNormal( torch.zeros( (self.K_net*self.n_w, self.T_net) ), scale_tril=Lff ) )
    gp_system_mean = gp_system_mean.reshape(self.K_net,self.n_w)
    gp_system_demean = gp_system_demean.reshape(self.K_net, self.n_w, self.T_net)
    # Latent systemic evolution: broadcast the constant mean over time, then
    # move time last -> [K_net, n_w, T_net].
    gp_system = gp_system_mean.expand(self.T_net, self.K_net, self.n_w).permute(1,2,0) + gp_system_demean

    ## Sampling latent coordinates ##
    if self.coord:
        with pyro.plate('gp_coord_all', self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w ):
            # Mean function of the GPs
            gp_coord_mean = pyro.sample( "gp_coord_mean", dist.Normal( torch.zeros( (self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w) ), torch.tensor([0.1]) ) )
            # Demeaned GPs
            gp_coord_demean = pyro.sample( "gp_coord_demean", dist.MultivariateNormal( torch.zeros( (self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w, self.T_net) ), scale_tril=Lff ) )
        gp_coord_mean = gp_coord_mean.reshape(self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w)
        gp_coord_demean = gp_coord_demean.reshape(self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w, self.T_net)
        # Latent coordinates -> [V_net, H_dim, K_net, n_dir, n_w, T_net]
        gp_coord = gp_coord_mean.expand(self.T_net, self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w).permute(1,2,3,4,5,0) + gp_coord_demean

    ## Sampling Sociability and Popularity terms ##
    if self.socpop:
        with pyro.plate('gp_socpop_all', self.V_net*self.K_net*self.n_dir*self.n_w ):
            # Mean function of the GPs
            gp_socpop_mean = pyro.sample( "gp_socpop_mean", dist.Normal( torch.zeros( (self.V_net*self.K_net*self.n_dir*self.n_w) ), torch.tensor([0.1]) ) )
            # Demeaned GPs
            gp_socpop_demean = pyro.sample( "gp_socpop_demean", dist.MultivariateNormal( torch.zeros( (self.V_net*self.K_net*self.n_dir*self.n_w, self.T_net) ), scale_tril=Lff ) )
        gp_socpop_mean = gp_socpop_mean.reshape(self.V_net,self.K_net,self.n_dir,self.n_w)
        gp_socpop_demean = gp_socpop_demean.reshape(self.V_net, self.K_net, self.n_dir, self.n_w, self.T_net)
        # Latent sociability/popularity -> [V_net, K_net, n_dir, n_w, T_net]
        gp_socpop = gp_socpop_mean.expand(self.T_net, self.V_net, self.K_net, self.n_dir, self.n_w).permute(1,2,3,4,0) + gp_socpop_demean

    ### Linear Predictor ###
    # Systemic component, shared by every (u, v) pair -> [V, V, T, K, n_w]
    Y_linpred = gp_system.expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net).permute(0,1,4,2,3)
    # Agent interaction: inner product over h of u's first-direction and v's
    # last-direction coordinates -> [V, V, T, K, n_w]
    if self.coord:
        Y_linpred = Y_linpred + torch.einsum('uhkwt,vhkwt->uvtkw', gp_coord[:,:,:,0,:,:], gp_coord[:,:,:,self.n_dir-1,:,:])
    # Sociability and Popularity effects
    if self.socpop:
        # After transpose(0,1), gp_soc[u, v] takes node u's first-direction term;
        # gp_pop[u, v] takes node v's last-direction term.
        gp_soc = gp_socpop[:,:,0,:,:].expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net).transpose(0,1)
        gp_pop = gp_socpop[:,:,self.n_dir-1,:,:].expand(self.V_net, self.V_net, self.K_net, self.n_w,self.T_net)
        Y_linpred = Y_linpred + gp_soc.permute(0,1,4,2,3) + gp_pop.permute(0,1,4,2,3)

    ### Link propensity (probability of occur) ###
    Y_link_prob = torch.sigmoid(Y_linpred[:,:,:,:,0])
    # NOTE(review): probabilities are masked with `Y_valid_id == 1` but
    # observations with `cond_Y_link`. Confirm both masks select the same
    # entries in the same order (both are defined outside this view).
    Y_link_prob_valid = Y_link_prob.flatten()[self.Y_valid_id.flatten()==1]
    with pyro.plate( "data", Y_link_prob_valid.shape[0]):
        pyro.sample( "obs", dist.Bernoulli(Y_link_prob_valid), obs=self.Y_link.flatten()[self.cond_Y_link] )

    ### Link expected weight (weight given occurrence) ###
    if self.weighted:
        with pyro.plate( "sigma_k_ind", self.K_net):
            sigma_k = pyro.sample( 'sigma_k', dist.InverseGamma(
                self.sigma_k_prior_param[0].expand(self.K_net), self.sigma_k_prior_param[1].expand(self.K_net) ) )
        # sigma_k is used directly as the Normal standard deviation, per k.
        Y_link_SDw = sigma_k.expand(self.V_net,self.V_net,self.T_net,self.K_net)
        Y_link_Ew = Y_linpred[:,:,:,:,1]
        # cond_Y_w: condition of being positive and valid weights (defined in set_data())
        Y_link_Ew_valid = Y_link_Ew.flatten()[self.cond_Y_w]
        Y_link_SDw_valid = Y_link_SDw.flatten()[self.cond_Y_w]
        with pyro.plate( "data_w", Y_link_Ew_valid.shape[0] ):
            pyro.sample( "obs_w", dist.Normal(Y_link_Ew_valid,Y_link_SDw_valid), obs=self.Y.flatten()[self.cond_Y_w] )