Example 1
    def model(self, X):
        _id = self._id
        N, D = X.shape
        global_shrinkage_prior_scale_init = self.param_init[f'global_shrinkage_prior_scale_init_{_id}']
        cov_diag_prior_loc_init = self.param_init[f'cov_diag_prior_loc_init_{_id}']
        cov_diag_prior_scale_init = self.param_init[f'cov_diag_prior_scale_init_{_id}']


        global_shrinkage_prior_scale = pyro.param(f'global_shrinkage_scale_prior_{_id}', global_shrinkage_prior_scale_init, constraint=constraints.positive)
        tau = pyro.sample(f'global_shrinkage_{_id}', dist.HalfNormal(global_shrinkage_prior_scale))
        
        # Inverse-gamma/inverse-gamma mixture: marginally, lambda is half-Cauchy
        # (the usual horseshoe auxiliary-variable trick).
        b = pyro.sample('b', dist.InverseGamma(0.5, 1. / torch.ones(D) ** 2).to_event(1))
        lambdasquared = pyro.sample(f'local_shrinkage_{_id}', dist.InverseGamma(0.5, 1. / b).to_event(1))
        
        cov_diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', cov_diag_prior_loc_init)
        cov_diag_scale = pyro.param(f'cov_diag_prior_scale_{_id}', cov_diag_prior_scale_init, constraint=constraints.positive)
        cov_diag = pyro.sample(f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale).to_event(1))
        #cov_diag = cov_diag*torch.ones(D)
        cov_diag = cov_diag + jitter  # NB: `jitter` is assumed to be defined at module scope
        
        lambdasquared = lambdasquared.squeeze()
        if lambdasquared.dim() == 1:
            # outer product
            cov_factor_scale = torch.outer(torch.sqrt(lambdasquared), tau.repeat((tau.dim() - 1) * (1,) + (D,)))
        else:
            # batch outer product
            cov_factor_scale = torch.einsum('bp, br->bpr', torch.sqrt(lambdasquared),tau.repeat((tau.dim()-1)*(1,)+(D,)))
        cov_factor = pyro.sample(f'cov_factor_{_id}', dist.Normal(0., cov_factor_scale).to_event(2))
        cov_factor = cov_factor.transpose(-2,-1)
        with pyro.plate(f'N_{_id}', size=N, subsample_size=self.batch_size, dim=-1) as ind:
            X = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
        return X
Example 2
    def __init__( self, edgelist, H_dim=3, random_kernel=False, jitter=1e-3 ):

        super(dmn_toy, self).__init__()

        Y, [all_nodes, Y_time, all_layers] = pydmn.util.edgelist_to_tensor( edgelist = edgelist )
        self.all_layers = all_layers

        self.set_data(Y=Y, Y_time=Y_time, H_dim=H_dim )

        self.weighted = False
        self.directed = False
        self.coord = True
        self.socpop = False

        self.random_kernel = random_kernel
        self.jitter = jitter
        # self.kernel = pydmn.kernels.RBF( random_param=self.random_kernel )

        ### Variational Parameters ###
        self.gp_mean_param = torch.ones((self.V_net,self.H_dim,2))
        self.gp_coord_demean = torch.zeros((self.V_net,self.H_dim,self.T_net))
        self.gp_cov_tril = torch.eye(self.T_net).expand((self.V_net,self.H_dim,self.T_net,self.T_net))

        # If the kernel IS random, we use PyroSample to put priors on its hyperparameters
        if self.random_kernel:
            self.kernel_param = torch.ones((2, 2))
            self.kernel = pydmn.kernels.RBF()
            self.kernel.lengthscale = pyro.nn.PyroSample(dist.InverseGamma(torch.tensor([4.]), torch.tensor([30.])))
            self.kernel.variance = pyro.nn.PyroSample(dist.InverseGamma(torch.tensor([11.]), torch.tensor([10.])))
Example 3
def model(X, Y, hypers, jitter=1.0e-4):
    S, P, N = hypers['expected_sparsity'], X.size(1), X.size(0)

    sigma = pyro.sample("sigma", dist.HalfNormal(hypers['alpha3']))
    phi = sigma * (S / math.sqrt(N)) / (P - S)
    eta1 = pyro.sample("eta1", dist.HalfCauchy(phi))

    msq = pyro.sample("msq",
                      dist.InverseGamma(hypers['alpha1'], hypers['beta1']))
    xisq = pyro.sample("xisq",
                       dist.InverseGamma(hypers['alpha2'], hypers['beta2']))

    eta2 = eta1.pow(2.0) * xisq.sqrt() / msq

    lam = pyro.sample(
        "lambda",
        dist.HalfCauchy(torch.ones(P, device=X.device)).to_event(1))
    kappa = msq.sqrt() * lam / (msq + (eta1 * lam).pow(2.0)).sqrt()
    kX = kappa * X

    # compute the kernel for the given hyperparameters
    k = kernel(
        kX, kX, eta1, eta2,
        hypers['c']) + (sigma**2 + jitter) * torch.eye(N, device=X.device)

    # observe the outputs Y
    pyro.sample("Y",
                dist.MultivariateNormal(torch.zeros(N, device=X.device),
                                        covariance_matrix=k),
                obs=Y)
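
The `kernel` helper used above is not shown. A plausible definition, modeled on Pyro's sparse-regression tutorial (treat the exact form as an assumption), is the kernel corresponding to a quadratic regressor:

import torch

def dot(X, Z):
    return torch.mm(X, Z.t())

def kernel(X, Z, eta1, eta2, c):
    # Kernel of a quadratic regressor; jitter is added by the caller above.
    eta1sq = eta1.pow(2.0)
    eta2sq = eta2.pow(2.0)
    k1 = 0.5 * eta2sq * (1.0 + dot(X, Z)).pow(2.0)
    k2 = -0.5 * eta2sq * dot(X.pow(2.0), Z.pow(2.0))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = c ** 2 - 0.5 * eta2sq
    return k1 + k2 + k3 + k4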
Example 4
    def model(self, data, batch_idx):
        # The beta's are conditionally independent.
        with pyro.plate("beta_plate", self.T - 1):
            beta = pyro.sample("beta", dist.Beta(1, self.alpha))

        # The components are conditionally independent.
        # to_event(1) declares the second dimension of the tensor to be dependent (part of the event).
        with pyro.plate("mu_plate", self.T):
            mu_sd = pyro.sample(
                "musd",
                dist.InverseGamma(torch.ones_like(self.sd_q1),
                                  torch.ones_like(self.sd_q2)).to_event(1))
            mu_c = pyro.sample(
                "mu",
                dist.Normal(self.mu_c,
                            mu_sd * torch.ones_like(self.mu_c)).to_event(1))

        # The data is conditionally independent.
        with pyro.plate("data", size=self.num_obs, subsample=batch_idx):
            ys = pyro.sample("cat",
                             dist.Categorical(mix_weights(beta)),
                             infer={"enumerate": "parallel"})
            pyro.sample("obs",
                        dist.Normal(mu_c[ys], mu_sd[ys]).to_event(1),
                        obs=data[batch_idx])
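
The `mix_weights` helper is not defined in this snippet. A standard stick-breaking construction (as in Pyro's Dirichlet-process mixture tutorial) would be:

import torch.nn.functional as F

def mix_weights(beta):
    # Turn stick-breaking fractions beta (length T-1) into T mixture weights.
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)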
Example 5
    def guide(self, X):
        _id = self._id
        N, D = X.shape
        global_shrinkage_loc_init = self.param_init[f'global_shrinkage_loc_init_{_id}']
        global_shrinkage_scale_init = self.param_init[f'global_shrinkage_scale_init_{_id}']
        local_shrinkage_loc_init = self.param_init[f'local_shrinkage_loc_init_{_id}']
        local_shrinkage_scale_init = self.param_init[f'local_shrinkage_scale_init_{_id}']
        cov_diag_loc_init = self.param_init[f'cov_diag_loc_init_{_id}']
        cov_diag_scale_init = self.param_init[f'cov_diag_scale_init_{_id}']
        cov_factor_loc_init = self.param_init[f'cov_factor_loc_init_{_id}']

        global_shrinkage_loc = pyro.param(f'global_shrinkage_loc_{_id}', global_shrinkage_loc_init)
        global_shrinkage_scale = pyro.param(f'global_shrinkage_scale_{_id}', global_shrinkage_scale_init, constraint=constraints.positive)
        tau = pyro.sample(f'global_shrinkage_{_id}', dist.LogNormal(global_shrinkage_loc,global_shrinkage_scale))

        b = pyro.sample('b', dist.InverseGamma(0.5,1./torch.ones(D)**2).to_event(1))
        local_shrinkage_loc = pyro.param(f'local_shrinkage_loc_{_id}', local_shrinkage_loc_init)
        local_shrinkage_scale = pyro.param(f'local_shrinkage_scale_{_id}', local_shrinkage_scale_init, constraint=constraints.positive)
        lambdasquared = pyro.sample(f'local_shrinkage_{_id}', dist.LogNormal(local_shrinkage_loc,local_shrinkage_scale).to_event(1))

        cov_diag_loc = pyro.param(f'cov_diag_loc_{_id}', cov_diag_loc_init)
        cov_diag_scale = pyro.param(f'cov_diag_scale_{_id}', cov_diag_scale_init, constraint=constraints.positive)
        cov_diag = pyro.sample(f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale).to_event(1))
        cov_diag = cov_diag*torch.ones(D)

        lambdasquared = lambdasquared.squeeze()
        if lambdasquared.dim() == 1:
            cov_factor_scale = torch.outer(torch.sqrt(lambdasquared), tau.repeat((tau.dim() - 1) * (1,) + (D,)))
        else:
            cov_factor_scale = torch.einsum('bp, br->bpr', torch.sqrt(lambdasquared),tau.repeat((tau.dim()-1)*(1,)+(D,)))
        cov_factor_loc = pyro.param(f'cov_factor_loc_{_id}', cov_factor_loc_init)
        cov_factor = pyro.sample(f'cov_factor_{_id}', dist.Normal(cov_factor_loc, cov_factor_scale).to_event(2))
        return tau, lambdasquared, cov_factor, cov_diag
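
Examples 1 and 5 form a model/guide pair. A minimal SVI loop wiring them together might look like this (a sketch; `fa` stands in for an instance of the class these methods belong to, and is not defined in the source):

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(fa.model, fa.guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step(X)  # X: an (N, D) data tensor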
Example 6
    def model(self, X, y):
        input_size = X.shape[0]
        input_dim = X.shape[1]


        # initialise
        w_mu_init = torch.zeros(input_dim)
        w_sigma_init = torch.eye(input_dim)
        # pyro.param tensors must be floating point so they can receive gradients
        self.sigma_concentration = pyro.param('sigma_concentration',
                                              torch.tensor(10.0),
                                              constraint=constraints.positive)
        self.sigma_rate = pyro.param('sigma_rate',
                                     torch.tensor(50.0),
                                     constraint=constraints.positive)

        # priors
        self.bias = pyro.sample('bias', dist.Normal(0, 10))
        self.weights = pyro.sample('weights', dist.MultivariateNormal(w_mu_init, w_sigma_init))

        self.sigma = pyro.sample('sigma',
                                 dist.InverseGamma(concentration=self.sigma_concentration,
                                                   rate=self.sigma_rate))

        # expected and obs
        prior_mean = self.bias + X @ self.weights
        with pyro.plate('data', input_size):
            return pyro.sample('obs', dist.Normal(prior_mean, self.sigma), obs=y)
Example 7
    def __init__(self, name=''):
        """
        Parameters
        ----------
        name : str
            Name of this module. (The originally documented ``resp_model``
            parameter, "callable that generates a tuple (surface_area, layers)",
            does not appear in this signature.)
        """
        super().__init__(name=name)
        self.obs_scale = PyroSample(dist.InverseGamma(torch.tensor(2.), 0.5))
Example 8
    def guide(self):

        # Posterior over the kernel hyperparameters; the site names match the
        # PyroSample sites created in __init__ (cf. Example 2)
        if self.random_kernel:
            self.kernel_param = pyro.param("kernel_param", torch.ones((2, 2)), constraint=constraints.positive)
            pyro.sample("kernel.lengthscale", dist.InverseGamma(self.kernel_param[0, 0], self.kernel_param[0, 1]))
            pyro.sample("kernel.variance", dist.InverseGamma(self.kernel_param[1, 0], self.kernel_param[1, 1]))

        # Posterior GP (mean function params)
        self.gp_mean_loc = pyro.param("gp_mean_loc", torch.zeros((self.V_net, self.H_dim)))
        self.gp_mean_scale = pyro.param("gp_mean_scale", torch.ones((self.V_net, self.H_dim)), constraint=constraints.positive)
        # Posterior GP (demeaned)
        self.gp_coord_demean = pyro.param("gp_coord_demean_loc", torch.zeros((self.V_net, self.H_dim, self.T_net)))
        # Posterior covariance of the GP
        self.gp_cov_tril = pyro.param("gp_cov_tril", torch.eye(self.T_net).expand(self.V_net, self.H_dim, self.T_net, self.T_net),
                                      constraint=constraints.lower_cholesky)
        with pyro.plate('gp_coord_all', self.V_net * self.H_dim):
            pyro.sample("gp_mean", dist.Normal(self.gp_mean_loc.reshape(self.V_net * self.H_dim),
                                               self.gp_mean_scale.reshape(self.V_net * self.H_dim)))
            pyro.sample("gp_coord_demean",
                        dist.MultivariateNormal(self.gp_coord_demean.reshape(self.V_net * self.H_dim, self.T_net),
                                                scale_tril=self.gp_cov_tril.reshape(self.V_net * self.H_dim, self.T_net, self.T_net)))
Example 9
def pyrocov_model_relaxed(dataset):
    # Tensor shapes are commented at the end of some lines.
    features = dataset["features"]
    local_time = dataset["local_time"][..., None]  # [T, P, 1]
    T, P, _ = local_time.shape
    S, F = features.shape
    weekly_strains = dataset["weekly_strains"]
    assert weekly_strains.shape == (T, P, S)

    # Sample global random variables.
    coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))[..., None]
    rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))[..., None]
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))[..., None]
    init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))[..., None]
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))[..., None]

    # Assume relative growth rate depends strongly on mutations and weakly on place.
    coef_loc = torch.zeros(F)
    coef = pyro.sample("coef", dist.Logistic(coef_loc, coef_scale).to_event(1))  # [F]
    rate_loc = pyro.sample(
        "rate_loc",
        dist.Normal(0.01 * coef @ features.T, rate_loc_scale).to_event(1),
    )  # [S]

    # Assume initial infections depend strongly on strain and place.
    init_loc = pyro.sample(
        "init_loc", dist.Normal(torch.zeros(S), init_loc_scale).to_event(1)
    )  # [S]
    with pyro.plate("place", P, dim=-1):
        rate = pyro.sample(
            "rate", dist.Normal(rate_loc, rate_scale).to_event(1)
        )  # [P, S]
        init = pyro.sample(
            "init", dist.Normal(init_loc, init_scale).to_event(1)
        )  # [P, S]

        # Finally observe counts.
        with pyro.plate("time", T, dim=-2):
            logits = init + rate * local_time  # [T, P, S]
            pyro.sample(
                "obs",
                dist.Multinomial(logits=logits, validate_args=False),
                obs=weekly_strains,
            )
Example 10
def pyrocov_model_plated(dataset):
    # Tensor shapes are commented at the end of some lines.
    features = dataset["features"]
    local_time = dataset["local_time"][..., None]  # [T, P, 1]
    T, P, _ = local_time.shape
    S, F = features.shape
    weekly_strains = dataset["weekly_strains"]  # [T, P, S]
    assert weekly_strains.shape == (T, P, S)
    feature_plate = pyro.plate("feature", F, dim=-1)
    strain_plate = pyro.plate("strain", S, dim=-1)
    place_plate = pyro.plate("place", P, dim=-2)
    time_plate = pyro.plate("time", T, dim=-3)

    # Sample global random variables.
    coef_scale = pyro.sample("coef_scale", dist.InverseGamma(5e3, 1e2))
    rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))
    rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
    init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))
    init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))

    with feature_plate:
        coef = pyro.sample("coef", dist.Logistic(0, coef_scale))  # [F]
    rate_loc_loc = 0.01 * coef @ features.T
    with strain_plate:
        rate_loc = pyro.sample(
            "rate_loc", dist.Normal(rate_loc_loc, rate_loc_scale)
        )  # [S]
        init_loc = pyro.sample("init_loc", dist.Normal(0, init_loc_scale))  # [S]
    with place_plate, strain_plate:
        rate = pyro.sample("rate", dist.Normal(rate_loc, rate_scale))  # [P, S]
        init = pyro.sample("init", dist.Normal(init_loc, init_scale))  # [P, S]

    # Finally observe counts.
    with time_plate, place_plate:
        logits = (init + rate * local_time)[..., None, :]  # [T, P, 1, S]
        pyro.sample(
            "obs",
            dist.Multinomial(logits=logits, validate_args=False),
            obs=weekly_strains[..., None, :],
        )
Example 11
    def guide(self, data, batch_idx):

        # Define all the variational parameters
        alpha_q = pyro.param('alpha_q',
                             self.alpha_q,
                             constraint=constraints.positive)
        rho = pyro.param('rho', self.rho, constraint=constraints.simplex)
        mu_q = pyro.param("mu_q", self.mu_c)
        # InverseGamma parameters must be positive; constrain them so SVI cannot
        # push them out of the distribution's support.
        sd_q1 = pyro.param('sd_q1', self.sd_q1, constraint=constraints.positive)
        sd_q2 = pyro.param('sd_q2', self.sd_q2, constraint=constraints.positive)

        with pyro.plate("beta_plate", self.T - 1):
            f_beta = pyro.sample("beta",
                                 dist.Beta(torch.ones(self.T - 1), alpha_q))

        with pyro.plate("mu_plate", self.T):
            mu_sd = pyro.sample("musd",
                                dist.InverseGamma(sd_q1, sd_q2).to_event(1))
            mu_c = pyro.sample("mu", dist.Normal(mu_q, mu_sd).to_event(1))

        with pyro.plate("data", size=self.num_obs, subsample=batch_idx):
            f_cat = pyro.sample("cat", dist.Categorical(rho[batch_idx]))
Example 12
import torch as th
import numpy as np  # needed for np.array below
import pyro
from pyro import distributions

pyro.set_rng_seed(1)

# ---------- #
# GIVEN DATA #
# ---------- #
my_mean = 180.0

prioalpha = 38.0
priobeta = 1110.0

my_invgamma = distributions.InverseGamma(prioalpha, priobeta)

my_sample = np.array(
    [
        183.0,
        173.0,
        181.0,
        170.0,
        176.0,
        180.0,
        187.0,
        176.0,
        171.0,
        190.0,
        184.0,
        173.0,
    ]
)  # NOTE: the original snippet is truncated here; the array is closed so it parses
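
Given the known mean and the inverse-gamma prior on the variance, the presumable next step is the standard conjugate update for a Gaussian with known mean; a sketch, using only the (truncated) data above:

x = th.tensor(my_sample)
n = x.numel()
post_alpha = prioalpha + n / 2.0
post_beta = priobeta + 0.5 * ((x - my_mean) ** 2).sum()
posterior = distributions.InverseGamma(post_alpha, post_beta)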
Example 13
def linear(xes, yes):
    slope = pyro.sample("slope", dist.Normal(5, 10))
    intercept = pyro.sample("intercept", dist.Normal(0, 10))
    var = pyro.sample("var", dist.InverseGamma(3, 0.1))
    mean = slope * xes + intercept
    # Assumed intent: the original body stopped at `x = slope * xes` and never
    # used `yes`; here `yes` is observed around the regression line, with `var`
    # treated as the observation variance.
    pyro.sample("obs", dist.Normal(mean, var.sqrt()), obs=yes)
    return slope
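
A minimal usage sketch for the model above (synthetic placeholder data; NUTS is one of several workable inference choices):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

xes = torch.linspace(0., 1., 20)
yes = 3.0 * xes + 1.0 + 0.1 * torch.randn(20)
mcmc = MCMC(NUTS(linear), num_samples=200, warmup_steps=200)
mcmc.run(xes, yes)
print(mcmc.get_samples()["slope"].mean())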
Example 14
    def guide(self):

        # Posterior Covariance of the GP
        # if self.random_kernel:
        #     self.kernel_param = pyro.param("kernel_param", 50*torch.ones((2,2)), constraint=constraints.positive)
        #     pyro.sample( "kernel.lengthscale", dist.InverseGamma( self.kernel_param[0,0], self.kernel_param[0,1] ) )
        #     pyro.sample( "kernel.variance", dist.InverseGamma( self.kernel_param[1,0], self.kernel_param[1,1] ) )

        # Sampling Systemic components #
        self.gp_system_mean_loc = pyro.param("gp_system_mean_loc", self.gp_system_mean_ini)
        self.gp_system_mean_scale = pyro.param("gp_system_mean_scale", 0.1*torch.ones((self.K_net,self.n_w)), constraint=constraints.positive)
        self.gp_system_demean = pyro.param("gp_system_demean_loc", self.gp_system_demean_ini)
        # Posterior Covariance of the GP
        self.gp_system_cov_tril = pyro.param("gp_system_cov_tril", self.Lff_ini.expand(self.K_net,self.n_w,self.T_net,self.T_net),
                                        constraint=constraints.lower_cholesky)
        with pyro.plate('gp_system_all', self.K_net*self.n_w ):
            # Posterior GP (mean function params) #
            pyro.sample( "gp_system_mean", dist.Normal( self.gp_system_mean_loc.reshape(self.K_net*self.n_w),
                                                self.gp_system_mean_scale.reshape(self.K_net*self.n_w) ) )
            # Posterior GP (demeaned) #
            pyro.sample( f"gp_system_demean",
                                    dist.MultivariateNormal( self.gp_system_demean.reshape(self.K_net*self.n_w , self.T_net),
                                                            scale_tril=self.gp_system_cov_tril.reshape(self.K_net*self.n_w , self.T_net, self.T_net) ) )

        # Sampling coordinates #
        if self.coord:
            self.gp_coord_mean_loc = pyro.param("gp_coord_mean_loc", self.gp_coord_mean_ini )
            self.gp_coord_mean_scale = pyro.param("gp_coord_mean_scale", 0.1*torch.ones((self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w)), constraint=constraints.positive)
            self.gp_coord_demean = pyro.param( f"gp_coord_demean_loc", self.gp_coord_demean_ini )
            # Posterior Covariance of the GP
            self.gp_coord_cov_tril = pyro.param( "gp_coord_cov_tril", self.Lff_ini.expand(self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w,self.T_net,self.T_net),
                                            constraint=constraints.lower_cholesky )
            with pyro.plate('gp_coord_all', self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w ):
                # Posterior GP (mean function params) #
                pyro.sample( "gp_coord_mean", dist.Normal( self.gp_coord_mean_loc.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w),
                                                    self.gp_coord_mean_scale.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w) ) )
                # Posterior GP (demeaned) #
                pyro.sample( f"gp_coord_demean",
                                        dist.MultivariateNormal( self.gp_coord_demean.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w , self.T_net),
                                                                scale_tril=self.gp_coord_cov_tril.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w , self.T_net, self.T_net) ) )

        # Sampling sociability and popularity #
        if self.socpop:
            self.gp_socpop_mean_loc = pyro.param("gp_socpop_mean_loc", self.gp_socpop_mean_ini )
            self.gp_socpop_mean_scale = pyro.param("gp_socpop_mean_scale", 0.1*torch.ones((self.V_net,self.K_net,self.n_dir,self.n_w)), constraint=constraints.positive)
            self.gp_socpop_demean = pyro.param( f"gp_socpop_demean_loc", self.gp_socpop_demean_ini )
            # Posterior Covariance of the GP
            self.gp_socpop_cov_tril = pyro.param( "gp_socpop_cov_tril", self.Lff_ini.expand(self.V_net,self.K_net,self.n_dir,self.n_w,self.T_net,self.T_net),
                                            constraint=constraints.lower_cholesky )
            with pyro.plate('gp_socpop_all', self.V_net*self.K_net*self.n_dir*self.n_w ):
                # Posterior GP (mean function params) #
                pyro.sample( "gp_socpop_mean", dist.Normal( self.gp_socpop_mean_loc.reshape(self.V_net*self.K_net*self.n_dir*self.n_w),
                                                    self.gp_socpop_mean_scale.reshape(self.V_net*self.K_net*self.n_dir*self.n_w) ) )
                # Posterior GP (demeaned) #
                pyro.sample( f"gp_socpop_demean",
                                        dist.MultivariateNormal( self.gp_socpop_demean.reshape(self.V_net*self.K_net*self.n_dir*self.n_w , self.T_net),
                                                                scale_tril=self.gp_socpop_cov_tril.reshape(self.V_net*self.K_net*self.n_dir*self.n_w , self.T_net, self.T_net) ) )

        # pyro.sample( "kernel.variance", dist.InverseGamma( self.kernel_param[1,0], self.kernel_param[1,1] ) )

        # Sampling variance of weights
        if self.weighted:
            self.sigma_k_post_loc = pyro.param("sigma_k_post_loc", torch.ones([1]), constraint=constraints.positive )
            self.sigma_k_post_scale = pyro.param("sigma_k_post_scale", torch.ones([1]), constraint=constraints.positive )
            with pyro.plate( "sigma_k_ind", self.K_net):
                sigma_k = pyro.sample( 'sigma_k', dist.InverseGamma( self.sigma_k_post_loc, self.sigma_k_post_scale ) )
Example 15
    def model(self):

        # If the Kernel IS NOT random, we declare the kernel within the model
        # if (not self.random_kernel):
        #     self.kernel = pydmn.kernels.RBF()

        # Covariance matrix of observed times entailed by our kernel
        # Kff = self.kernel(self.Y_time.reshape(-1,1))
        # Kff.view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
        # Lff = Kff.cholesky() # cholesky lower triangular
        Lff = self.Lff_ini

        ## Sampling system-wide connectivity and average weights ##
        with pyro.plate('gp_system_all', self.K_net*self.n_w ):
            # Mean function of the GPs
            gp_system_mean = pyro.sample( "gp_system_mean",
                                    dist.Normal( torch.zeros( (self.K_net*self.n_w) ),
                                                torch.tensor([0.1]) ) )
            # Demeaned GPs
            gp_system_demean = pyro.sample( "gp_system_demean",
                                            dist.MultivariateNormal( torch.zeros( (self.K_net*self.n_w, self.T_net) ),
                                                                        scale_tril=Lff ) )
        gp_system_mean = gp_system_mean.reshape(self.K_net,self.n_w)
        gp_system_demean = gp_system_demean.reshape(self.K_net, self.n_w, self.T_net)
        # Latent systemic evolution
        gp_system = gp_system_mean.expand(self.T_net, self.K_net, self.n_w).permute(1,2,0) + gp_system_demean

        ## Sampling latent coordinates ##
        if self.coord:
            with pyro.plate('gp_coord_all', self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w ):
                # Mean function of the GPs
                gp_coord_mean = pyro.sample( "gp_coord_mean",
                                        dist.Normal( torch.zeros( (self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w) ),
                                                    torch.tensor([0.1]) ) )
                # Demeaned GPs
                gp_coord_demean = pyro.sample( "gp_coord_demean",
                                                dist.MultivariateNormal( torch.zeros( (self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w, self.T_net) ),
                                                                            scale_tril=Lff ) )

            gp_coord_mean = gp_coord_mean.reshape(self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w)
            gp_coord_demean = gp_coord_demean.reshape(self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w, self.T_net)
            # Latent coordinates
            gp_coord = gp_coord_mean.expand(self.T_net, self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w).permute(1,2,3,4,5,0) + gp_coord_demean

        ## Sampling Sociability and Popularity terms ##
        if self.socpop:
            with pyro.plate('gp_socpop_all', self.V_net*self.K_net*self.n_dir*self.n_w ):
                # Mean function of the GPs
                gp_socpop_mean = pyro.sample( "gp_socpop_mean",
                                        dist.Normal( torch.zeros( (self.V_net*self.K_net*self.n_dir*self.n_w) ),
                                                    torch.tensor([0.1]) ) )
                # Demeaned GPs
                gp_socpop_demean = pyro.sample( "gp_socpop_demean",
                                                dist.MultivariateNormal( torch.zeros( (self.V_net*self.K_net*self.n_dir*self.n_w, self.T_net) ),
                                                                            scale_tril=Lff ) )

            gp_socpop_mean = gp_socpop_mean.reshape(self.V_net,self.K_net,self.n_dir,self.n_w)
            gp_socpop_demean = gp_socpop_demean.reshape(self.V_net, self.K_net, self.n_dir, self.n_w, self.T_net)
            # Latent coordinates
            gp_socpop = gp_socpop_mean.expand(self.T_net, self.V_net, self.K_net, self.n_dir, self.n_w).permute(1,2,3,4,0) + gp_socpop_demean

        ### Linear Predictor ###
        # Systemic component
        Y_linpred = gp_system.expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net).permute(0,1,4,2,3)
        # Distance between agents
        if self.coord:
            Y_linpred = Y_linpred + torch.einsum('uhkwt,vhkwt->uvtkw', gp_coord[:,:,:,0,:,:], gp_coord[:,:,:,self.n_dir-1,:,:])
        # Sociability and Popularity effects
        if self.socpop:
            gp_soc = gp_socpop[:,:,0,:,:].expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net).transpose(0,1)
            gp_pop = gp_socpop[:,:,self.n_dir-1,:,:].expand(self.V_net, self.V_net, self.K_net, self.n_w,self.T_net)
            Y_linpred = Y_linpred + gp_soc.permute(0,1,4,2,3) + gp_pop.permute(0,1,4,2,3)

        ### Link propensity (probability of occur) ###
        Y_link_prob = torch.sigmoid(Y_linpred[:,:,:,:,0])
        Y_link_prob_valid = Y_link_prob.flatten()[self.Y_valid_id.flatten()==1]

        with pyro.plate( "data", Y_link_prob_valid.shape[0]):
            pyro.sample( "obs", dist.Bernoulli(Y_link_prob_valid), obs=self.Y_link.flatten()[self.cond_Y_link] )

        ### Link expected weight (weight given occurrence) ###
        if self.weighted:
            with pyro.plate( "sigma_k_ind", self.K_net):
                sigma_k = pyro.sample( 'sigma_k', dist.InverseGamma( self.sigma_k_prior_param[0].expand(self.K_net), self.sigma_k_prior_param[1].expand(self.K_net) ) )
            Y_link_SDw = sigma_k.expand(self.V_net,self.V_net,self.T_net,self.K_net)

            Y_link_Ew = Y_linpred[:,:,:,:,1]
            # cond_Y_w: condition of being positive and valid weights (defined in set_data())
            Y_link_Ew_valid = Y_link_Ew.flatten()[self.cond_Y_w]
            Y_link_SDw_valid = Y_link_SDw.flatten()[self.cond_Y_w]
            with pyro.plate( "data_w", Y_link_Ew_valid.shape[0] ):
                pyro.sample( "obs_w", dist.Normal(Y_link_Ew_valid,Y_link_SDw_valid), obs=self.Y.flatten()[self.cond_Y_w] )