def model(y):
    """Zero-mean multivariate normal model with an LKJ correlation prior.

    Per-variable scales get a HalfCauchy prior and the correlation matrix
    gets a uniform LKJ prior; observations are scored jointly under a
    MultivariateNormal parameterized by the covariance Cholesky factor.
    Assumes ``y`` is a 2-D tensor of shape (N, d) — inferred from the
    ``y.shape[0]`` / ``y.shape[1]`` reads below.
    """
    num_obs = y.shape[0]
    num_vars = y.shape[1]
    tensor_opts = dict(dtype=y.dtype, device=y.device)

    # Vector of scale parameters, one per variable.
    theta = pyro.sample("theta",
                        dist.HalfCauchy(torch.ones(num_vars, **tensor_opts)))

    # Lower Cholesky factor of a correlation matrix; concentration eta = 1
    # implies a uniform distribution over correlation matrices.
    eta = torch.ones(1, **tensor_opts)
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(num_vars, eta))

    # Scale the correlation factor into the covariance Cholesky factor.
    # For inference with SVI, one might prefer
    # torch.bmm(theta.sqrt().diag_embed(), L_omega).
    L_Omega = torch.diag(theta.sqrt()).mm(L_omega)

    # Vector of expectations (zero mean).
    mu = torch.zeros(num_vars, **tensor_opts)

    with pyro.plate("observations", num_obs):
        obs = pyro.sample("obs",
                          dist.MultivariateNormal(mu, scale_tril=L_Omega),
                          obs=y)
    return obs
def model():
    """Draw the lower-Cholesky factor of a random 2x2 correlation matrix.

    Concentration 1 makes the LKJ prior uniform over correlation matrices.
    """
    concentration = torch.tensor(1.)
    return pyro.sample('x', dist.LKJCorrCholesky(2, concentration))
def model(self, zero_data, covariates):
    """Bivariate seasonal forecasting model with a GaussianHMM noise model.

    Builds a repeated-plus-drifting seasonal prediction and a joint
    (transition + observation) correlated noise model, then registers both
    via ``self.predict``.  ``zero_data`` supplies the (duration, dim)
    shape; ``covariates`` is not referenced in this body.
    NOTE(review): period = 24 * 7 assumes weekly seasonality over hourly
    steps — confirm against the data's sampling frequency.
    """
    period = 24 * 7
    duration, dim = zero_data.shape[-2:]
    assert dim == 2  # Data is bivariate: (arrivals, departures).

    # Sample global parameters.
    noise_scale = pyro.sample(
        "noise_scale",
        dist.LogNormal(torch.full((dim, ), -3.), 1.).to_event(1))
    assert noise_scale.shape[-1:] == (dim, )
    trans_timescale = pyro.sample(
        "trans_timescale", dist.LogNormal(torch.zeros(dim), 1).to_event(1))
    assert trans_timescale.shape[-1:] == (dim, )

    trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
    # Broadcast the scalar location across both series.
    trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim, ))
    assert trans_loc.shape[-1:] == (dim, )
    trans_scale = pyro.sample(
        "trans_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
    # LKJ prior on the correlation couples the two series' transition noise.
    trans_corr = pyro.sample("trans_corr",
                             dist.LKJCorrCholesky(dim, torch.ones(())))
    trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
    assert trans_scale_tril.shape[-2:] == (dim, dim)

    obs_scale = pyro.sample(
        "obs_scale", dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
    obs_corr = pyro.sample("obs_corr",
                           dist.LKJCorrCholesky(dim, torch.ones(())))
    obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
    assert obs_scale_tril.shape[-2:] == (dim, dim)

    # Note the initial seasonality should be sampled in a plate with the
    # same dim as the time_plate, dim=-1. That way we can repeat the dim
    # below using periodic_repeat().
    with pyro.plate("season_plate", period, dim=-1):
        season_init = pyro.sample(
            "season_init", dist.Normal(torch.zeros(dim), 1).to_event(1))
        assert season_init.shape[-2:] == (period, dim)

    # Sample independent noise at each time step.
    with self.time_plate:
        season_noise = pyro.sample("season_noise",
                                   dist.Normal(0, noise_scale).to_event(1))
        assert season_noise.shape[-2:] == (duration, dim)

    # Construct a prediction. This prediction has an exactly repeated
    # seasonal part plus slow seasonal drift. We use two deterministic,
    # linear functions to transform our diagonal Normal noise to nontrivial
    # samples from a Gaussian process.
    prediction = (periodic_repeat(season_init, duration, dim=-2) +
                  periodic_cumsum(season_noise, period, dim=-2))
    assert prediction.shape[-2:] == (duration, dim)

    # Construct a joint noise model. This model is a GaussianHMM, whose
    # .rsample() and .log_prob() methods are parallelized over time; thus
    # this entire model is parallelized over time.
    init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
    trans_mat = trans_timescale.neg().exp().diag_embed()
    trans_dist = dist.MultivariateNormal(trans_loc,
                                         scale_tril=trans_scale_tril)
    obs_mat = torch.eye(dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                       scale_tril=obs_scale_tril)
    noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                   obs_mat, obs_dist, duration=duration)
    assert noise_model.event_shape == (duration, dim)

    # The final statement registers our noise model and prediction.
    self.predict(noise_model, prediction)
def model(x, y, alt_av, alt_ids_cuda):
    """Mixed logit (MXL) discrete-choice model with correlated random effects.

    Samples global fixed-effect means (``alpha``), random-effect means
    (``beta_mu``) and an LKJ-based covariance for per-respondent
    coefficients, then scores observed choices with a Categorical
    likelihood over alternative utilities.

    Relies on names defined elsewhere in the module: ``diagonal_alpha``,
    ``diagonal_beta_mu``, ``non_mix_params``, ``mix_params``, ``num_resp``,
    ``num_alternatives``, ``beta_to_params_map``, ``zeros_vec``, ``T``,
    ``BATCH_SIZE``.  NOTE(review): the indexing below suggests
    x is (num_resp, T, features), alt_ids_cuda/alt_av are
    (num_resp, T, num_alternatives) and y is (num_resp, T) — confirm
    against the caller.
    """
    # global parameters in the model
    if diagonal_alpha:
        alpha_mu = pyro.sample(
            "alpha",
            dist.Normal(torch.zeros(len(non_mix_params), device=x.device),
                        1).to_event(1))
    else:
        alpha_mu = pyro.sample(
            "alpha",
            dist.MultivariateNormal(
                torch.zeros(len(non_mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(non_mix_params), device=x.device))))
    if diagonal_beta_mu:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(torch.zeros(len(mix_params), device=x.device),
                        1.).to_event(1))
    else:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(
                torch.zeros(len(mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(mix_params), device=x.device))))

    # Vector of variances for each of the d variables
    theta = pyro.sample(
        "theta",
        dist.HalfCauchy(
            10. * torch.ones(len(mix_params), device=x.device)).to_event(1))
    # Lower cholesky factor of a correlation matrix
    eta = 1. * torch.ones(
        1, device=x.device
    )  # Implies a uniform distribution over correlation matrices
    L_omega = pyro.sample("L_omega",
                          dist.LKJCorrCholesky(len(mix_params), eta))
    # Lower cholesky factor of the covariance matrix
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)

    # local parameters in the model
    random_params = pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_mu.repeat(num_resp, 1),
                                scale_tril=L_Omega).to_event(1))

    # vector of respondent parameters: global + local (respondent)
    params_resp = torch.cat([alpha_mu.repeat(num_resp, 1), random_params],
                            dim=-1)

    # vector of betas of MXL (may repeat the same learnable parameter
    # multiple times; random + fixed effects)
    beta_resp = torch.cat([
        params_resp[:, beta_to_params_map[i]] for i in range(num_alternatives)
    ],
                          dim=-1)

    with pyro.plate("locals", len(x), subsample_size=BATCH_SIZE) as ind:
        with pyro.plate("data_resp", T):
            # compute utilities for each alternative: scatter each
            # feature's contribution into its alternative's utility slot
            utilities = torch.scatter_add(
                zeros_vec[:, ind, :], 2,
                alt_ids_cuda[ind, :, :].transpose(0, 1),
                torch.mul(x[ind, :, :].transpose(0, 1), beta_resp[ind, :]))
            # adjust utility for unavailable alternatives
            utilities += alt_av[ind, :, :].transpose(0, 1)
            # likelihood
            pyro.sample("obs",
                        dist.Categorical(logits=utilities),
                        obs=y[ind, :].transpose(0, 1))
def model(self, data):
    """Hierarchical model with Kronecker-structured covariances.

    Builds treatment/time/marker covariance factors (LKJ + HalfCauchy
    priors, exponential-distance time kernels), combines them with an
    einsum-based Kronecker product into ``n_prs x n_prs`` covariances for
    global and per-individual parameter vectors, then scores ``data['Y']``
    under ``self.get_distr(data, pars)``.

    ``data`` is a dict providing at least 'n_ind', 'n_trt', 'n_tms',
    'n_mrk' and 'Y'.  NOTE(review): ``fun`` is presumably
    torch.nn.functional — confirm the module-level import.  NOTE(review):
    only 'dt0'/'dt1' are sampled inside the times plate here; the
    remaining sites are taken to be plate-free (their batch shapes would
    conflict with the size-n_tms plate at dim=-1) — confirm against the
    original source.
    """
    n_ind = data['n_ind']
    n_trt = data['n_trt']
    n_tms = data['n_tms']
    n_mrk = data['n_mrk']
    # Total number of parameters per vector: one per
    # (treatment, time, marker) triple.
    n_prs = n_trt * n_tms * n_mrk
    plt_ind = pyro.plate('individuals', n_ind, dim=-3)
    plt_trt = pyro.plate('treatments', n_trt, dim=-2)
    plt_tms = pyro.plate('times', n_tms, dim=-1)
    pars = {}
    # covariance factors
    with plt_tms:
        # learning dt time step sizes
        # if k(t1,t2) is independent of time, can instead learn scales and
        # variances for RBF kernels that use data['time_vals']
        pars['dt0'] = pyro.sample('dt0', dist.Normal(0, 1))
        pars['dt1'] = pyro.sample('dt1', dist.Normal(0, 1))
    pars['theta_trt0'] = pyro.sample('theta_trt0',
                                     dist.HalfCauchy(torch.ones(n_trt)))
    pars['theta_mrk0'] = pyro.sample('theta_mrk0',
                                     dist.HalfCauchy(torch.ones(n_mrk)))
    pars['theta_trt1'] = pyro.sample('theta_trt1',
                                     dist.HalfCauchy(torch.ones(n_trt)))
    pars['L_omega_trt1'] = pyro.sample(
        'L_omega_trt1', dist.LKJCorrCholesky(n_trt, torch.ones(1)))
    pars['theta_mrk1'] = pyro.sample('theta_mrk1',
                                     dist.HalfCauchy(torch.ones(n_mrk)))
    pars['L_omega_mrk1'] = pyro.sample(
        'L_omega_mrk1', dist.LKJCorrCholesky(n_mrk, torch.ones(1)))
    # Cumulative softplus of the dt's gives strictly increasing time
    # points, anchored at 0 by the left-pad; the trailing point is dropped.
    times0 = fun.pad(torch.cumsum(pars['dt0'].exp().log1p(), 0), (1, 0),
                     value=0)[:-1].unsqueeze(1)
    times1 = fun.pad(torch.cumsum(pars['dt1'].exp().log1p(), 0), (1, 0),
                     value=0)[:-1].unsqueeze(1)
    # exp(-|t_i - t_j|) covariance over the learned time points.
    cov_t0 = (-torch.cdist(times0, times0)).exp()
    cov_t1 = (-torch.cdist(times1, times1)).exp()
    # Diagonal (independent) treatment/marker factors for the global level;
    # full LKJ-correlated factors for the individual level.
    cov_i0 = pars['theta_trt0'].diag()
    L_Omega_trt = torch.mm(torch.diag(pars['theta_trt1'].sqrt()),
                           pars['L_omega_trt1'])
    cov_i1 = L_Omega_trt.mm(L_Omega_trt.t())
    cov_m0 = pars['theta_mrk0'].diag()
    L_Omega_mrk = torch.mm(torch.diag(pars['theta_mrk1'].sqrt()),
                           pars['L_omega_mrk1'])
    cov_m1 = L_Omega_mrk.mm(L_Omega_mrk.t())
    # kronecker product of the factors
    cov_itm0 = torch.einsum('ij,tu,mn->itmjun',
                            [cov_i0, cov_t0, cov_m0]).view(n_prs, n_prs)
    cov_itm1 = torch.einsum('ij,tu,mn->itmjun',
                            [cov_i1, cov_t1, cov_m1]).view(n_prs, n_prs)
    # global and individual level params of each marker, treatment, and
    # time point
    pars['glb'] = pyro.sample(
        'glb', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm0))
    with plt_ind:
        pars['ind'] = pyro.sample(
            'ind', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm1))
    # observation noise, time series bias and scale
    pars['noise_scale'] = pyro.sample('noise_scale',
                                      dist.HalfCauchy(torch.ones(n_mrk)))
    pars['t0_scale'] = pyro.sample('t0_scale',
                                   dist.HalfCauchy(torch.ones(n_mrk)))
    with plt_ind:
        pars['t0'] = pyro.sample(
            't0',
            dist.MultivariateNormal(torch.zeros(n_mrk),
                                    pars['t0_scale'].diag()))
    with plt_trt, plt_tms:
        pars['noise'] = pyro.sample(
            'noise',
            dist.MultivariateNormal(torch.zeros(n_mrk),
                                    pars['noise_scale'].diag()))
    # likelihood of the data
    distr = self.get_distr(data, pars)
    pyro.sample('obs', distr, obs=data['Y'])