# NB: throughout these snippets, helpers such as rmv (batched matrix-vector
# product), rvv (batched dot product), rtril, rinverse, lexpand, broadcast_cat
# and tensor_to_dict are assumed to come from pyro.contrib.util /
# pyro.ops.linalg / the surrounding module.


def linear_model_formula(self, y, design, target_labels):
    if self.use_softplus:
        mu = {l: rmv(self.softplus(self.regressor[l]), y) for l in target_labels}
    else:
        mu = {l: rmv(self.regressor[l], y) for l in target_labels}
    scale_tril = {l: rtril(self.scale_tril[l]) for l in target_labels}
    return mu, scale_tril
def linear_model_formula(self, y, design, target_labels):
    tikhonov_diag = torch.diag(self.softplus(self.tikhonov_diag))
    xtx = torch.matmul(design.transpose(-1, -2), design) + tikhonov_diag
    xtxi = rinverse(xtx, sym=True)
    mu = rmv(xtxi, rmv(design.transpose(-1, -2), y))
    # Extract sub-indices
    mu = tensor_to_dict(self.w_sizes, mu, subset=target_labels)
    scale_tril = {l: rtril(self.scale_tril[l]) for l in target_labels}
    return mu, scale_tril
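
# A minimal self-contained sketch (torch only, hypothetical shapes and names)
# of the Tikhonov-regularised closed form used above:
# mu = (X^T X + D)^{-1} X^T y, with D a diagonal regulariser. This is an
# illustration of the formula, not the library's implementation.
import torch

def ridge_solution(design, y, tikhonov_diag):
    # design: (n, p), y: (n,), tikhonov_diag: (p,)
    xtx = design.transpose(-1, -2) @ design + torch.diag(tikhonov_diag)
    xty = (design.transpose(-1, -2) @ y.unsqueeze(-1)).squeeze(-1)
    return torch.linalg.solve(xtx, xty.unsqueeze(-1)).squeeze(-1)

# Example: recover a 2-coefficient mean from 10 noisy observations.
X = torch.randn(10, 2)
w_true = torch.tensor([1., -2.])
y_obs = X @ w_true + 0.1 * torch.randn(10)
print(ridge_solution(X, y_obs, torch.full((2,), 0.01)))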
def dv_critic(design, trace, observation_labels, target_labels):
    y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}
    x = torch.cat(list(theta_dict.values()) + list(y_dict.values()), dim=-1)
    B = pyro.param("B", torch.zeros(5, 5))
    return rvv(x, rmv(B, x))
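
# Hedged sketch of how a quadratic critic like the one above typically enters
# the Donsker-Varadhan bound, I(theta; y) >= E_joint[T] - log E_product[exp T].
# Pure torch; the critic values below are synthetic stand-ins, not outputs of
# dv_critic itself, and dv_lower_bound is a hypothetical helper name.
import math
import torch

def dv_lower_bound(t_joint, t_product):
    # Mean of T on joint samples, minus log-mean-exp of T on samples from
    # the product of marginals.
    n = t_product.shape[0]
    return t_joint.mean() - (torch.logsumexp(t_product, dim=0) - math.log(n))

# e.g. with synthetic critic values:
estimate = dv_lower_bound(torch.randn(100) + 1., torch.randn(100))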
def posterior_guide(y_dict, design, observation_labels, target_labels):
    y = torch.cat(list(y_dict.values()), dim=-1)
    A = pyro.param("A", torch.zeros(2, 3))
    scale_tril = pyro.param("scale_tril", torch.tensor([[1., 0.], [0., 1.5]]),
                            constraint=torch.distributions.constraints.lower_cholesky)
    mu = rmv(A, y)
    pyro.sample("w", dist.MultivariateNormal(mu, scale_tril=scale_tril))
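
# What the guide above expresses, stripped of the Pyro machinery: an amortised
# Gaussian posterior q(w | y) = N(A y, L L^T) whose mean is linear in the
# observations. The shapes (2 targets, 3 observations) mirror the initial
# param values above; this is an illustration, not library code.
import torch
from torch.distributions import MultivariateNormal

A = torch.zeros(2, 3)
L = torch.tensor([[1., 0.], [0., 1.5]])
y = torch.randn(3)
q_w = MultivariateNormal(A @ y, scale_tril=L)
w_sample = q_w.sample()  # one posterior draw for "w"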
def likelihood_guide(theta_dict, design, observation_labels, target_labels):
    theta = torch.cat(list(theta_dict.values()), dim=-1)
    centre = rmv(design, theta)
    # The `_l` suffix avoids a param-name collision with the posterior guide
    mu = pyro.param("mu_l", torch.zeros(3))
    scale_tril = pyro.param("scale_tril_l", torch.eye(3),
                            constraint=torch.distributions.constraints.lower_cholesky)
    pyro.sample("y", dist.MultivariateNormal(centre + mu, scale_tril=scale_tril))
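
# The analogous picture for the likelihood guide: q(y | theta) is Gaussian,
# centred at the linear predictor X theta plus a learned offset. A torch-only
# sketch with hypothetical sizes (3 observations, 2 targets):
import torch
from torch.distributions import MultivariateNormal

design = torch.randn(3, 2)
theta = torch.randn(2)
q_y = MultivariateNormal(design @ theta + torch.zeros(3), scale_tril=torch.eye(3))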
def true_model(design):
    w1 = torch.tensor([-1., 1.])
    w2 = torch.tensor([-.5, .5, -.5, .5, -.5, 2., -2., 2., -2., 0.])
    w = torch.cat([w1, w2], dim=-1)
    k = torch.tensor(.1)
    response_mean = rmv(design, w)
    base_dist = dist.Normal(response_mean, torch.tensor(1.)).to_event(1)
    k = k.expand(response_mean.shape)
    transforms = [AffineTransform(loc=0., scale=k), SigmoidTransform()]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    return pyro.sample("y", response_dist)
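
# Usage sketch for `true_model` (hypothetical design): the design's final
# dimension must be 12 to match the concatenated [w1, w2]. Outside an
# inference context, pyro.sample simply draws from the response distribution,
# so the returned responses lie in (0, 1) after the sigmoid transform.
import torch

design = torch.randn(100, 12)
y_synthetic = true_model(design)  # shape (100,), values in (0, 1)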
def get_params(self, y_dict, design, target_labels):
    y = torch.cat(list(y_dict.values()), dim=-1)
    coefficient_labels = [label for label in target_labels if label != self.tau_label]
    mu, scale_tril = self.linear_model_formula(y, design, coefficient_labels)
    mu_vec = torch.cat(list(mu.values()), dim=-1)
    yty = rvv(y, y)
    ytxmu = rvv(y, rmv(design, mu_vec))
    beta = self.b0 + 0.5 * (yty - ytxmu)
    return mu, scale_tril, self.alpha, beta
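
# The beta returned above is the standard normal-inverse-Gamma conjugate
# update, beta_n = beta_0 + 0.5 * (y^T y - y^T X mu). When mu is the
# (regularised) least-squares mean, y^T y - y^T X mu is a residual
# sum-of-squares-type term. A quick numeric illustration, torch only,
# hypothetical sizes:
import torch

X = torch.randn(8, 2)
y = torch.randn(8)
mu = torch.linalg.lstsq(X, y.unsqueeze(-1)).solution.squeeze(-1)
beta_update = 0.5 * (y @ y - y @ (X @ mu))  # equals 0.5 * ||y - X mu||^2 here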
def model(design):
    batch_shape = design.shape[:-2]
    k_shape = batch_shape + (sigmoid_design.shape[-1],)
    k = pyro.sample(
        sigmoid_label,
        dist.Gamma(
            sigmoid_alpha.expand(k_shape),
            sigmoid_beta.expand(k_shape)
        ).to_event(1),
    )
    k_assigned = rmv(sigmoid_design, k)
    return bayesian_linear_model(
        design,
        w_means=OrderedDict([(coef1_label, coef1_mean), (coef2_label, coef2_mean)]),
        w_sqrtlambdas={
            coef1_label: 1.0 / (observation_sd * coef1_sd),
            coef2_label: 1.0 / (observation_sd * coef2_sd),
        },
        obs_sd=observation_sd,
        response="sigmoid",
        response_label=observation_label,
        k=k_assigned,
    )
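
# Shape note for the model above: `k` is sampled with one Gamma-distributed
# slope per sigmoid feature (hence .to_event(1) over the final dimension),
# and rmv(sigmoid_design, k) turns it into one slope per observation before
# it reaches the sigmoid response in bayesian_linear_model.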
def bayesian_linear_model(design, w_means={}, w_sqrtlambdas={}, re_group_sizes={},
                          re_alphas={}, re_betas={}, obs_sd=None,
                          alpha_0=None, beta_0=None, response="normal",
                          response_label="y", k=None):
    """
    A pyro model for Bayesian linear regression.

    If :param:`response` is `"normal"` this corresponds to a linear regression
    model :math:`Y = Xw + \\epsilon` with :math:`\\epsilon` i.i.d. zero-mean
    Gaussian. The observation standard deviation (:param:`obs_sd`) may be known
    or unknown. If unknown, the observation variance is assumed to follow an
    inverse Gamma distribution with parameters :param:`alpha_0` and
    :param:`beta_0` (equivalently, the precision follows a Gamma distribution).

    If the response type is `"bernoulli"` we instead have
    :math:`Y \\sim Bernoulli(p)` with :math:`logit(p) = Xw`.

    Given parameter groups in :param:`w_means` and :param:`w_sqrtlambdas`, the
    fixed effects regression coefficient is taken to be Gaussian with mean
    `w_mean` and standard deviation given by :math:`\\sigma / \\sqrt{\\lambda}`,
    corresponding to the normal inverse Gamma family.

    The random effects coefficient is constructed as follows. For each random
    effect group, precisions for that group are sampled from a Gamma
    distribution and converted to standard deviations. For each group, a random
    effect coefficient is then sampled from a zero mean Gaussian with those
    standard deviations.

    :param torch.Tensor design: a tensor with last two dimensions `n` and `p`
        corresponding to observations and features respectively.
    :param OrderedDict w_means: map from variable names to tensors of fixed
        effect means.
    :param OrderedDict w_sqrtlambdas: map from variable names to tensors of
        square root :math:`\\lambda` values for fixed effects.
    :param OrderedDict re_group_sizes: map from variable names to int
        representing the group size
    :param OrderedDict re_alphas: map from variable names to `torch.Tensor`,
        the tensor consists of Gamma dist :math:`\\alpha` values
    :param OrderedDict re_betas: map from variable names to `torch.Tensor`,
        the tensor consists of Gamma dist :math:`\\beta` values
    :param torch.Tensor obs_sd: the observation standard deviation (if assumed
        known). This is still relevant in the case of Bernoulli observations
        when coefficients are sampled using `w_sqrtlambdas`.
    :param torch.Tensor alpha_0: Gamma :math:`\\alpha` parameter for unknown
        observation covariance.
    :param torch.Tensor beta_0: Gamma :math:`\\beta` parameter for unknown
        observation covariance.
    :param str response: Emission distribution. May be `"normal"`,
        `"bernoulli"` or `"sigmoid"`.
    :param str response_label: Variable label for response.
    :param torch.Tensor k: Only used for a sigmoid response. The slope of the
        sigmoid transformation.
    """
    # design is size batch x n x p
    # tau is size batch
    batch_shape = design.shape[:-2]
    with ExitStack() as stack:
        for plate in iter_plates_to_shape(batch_shape):
            stack.enter_context(plate)

        if obs_sd is None:
            # First, sample tau (observation precision)
            tau_prior = dist.Gamma(alpha_0.unsqueeze(-1),
                                   beta_0.unsqueeze(-1)).to_event(1)
            tau = pyro.sample("tau", tau_prior)
            obs_sd = 1. / torch.sqrt(tau)
        elif alpha_0 is not None or beta_0 is not None:
            warnings.warn("Values of `alpha_0` and `beta_0` unused because "
                          "`obs_sd` was specified already.")

        obs_sd = obs_sd.expand(batch_shape + (1,))

        # Build the regression coefficient
        w = []
        # Allow different names for different coefficient groups
        # Process fixed effects
        for name, w_sqrtlambda in w_sqrtlambdas.items():
            w_mean = w_means[name]
            # Place a normal prior on the regression coefficient
            w_prior = dist.Normal(w_mean, obs_sd / w_sqrtlambda).to_event(1)
            w.append(pyro.sample(name, w_prior))
        # Process random effects
        for name, group_size in re_group_sizes.items():
            # Sample `G` once for this group
            alpha, beta = re_alphas[name], re_betas[name]
            G_prior = dist.Gamma(alpha, beta).to_event(1)
            G = 1. / torch.sqrt(pyro.sample("G_" + name, G_prior))
            # Repeat `G` for each group
            repeat_shape = tuple(1 for _ in batch_shape) + (group_size,)
            u_prior = dist.Normal(torch.tensor(0.),
                                  G.repeat(repeat_shape)).to_event(1)
            w.append(pyro.sample(name, u_prior))
        # Regression coefficient `w` is batch x p
        w = broadcast_cat(w)

        # Run the regressor forward conditioned on inputs
        prediction_mean = rmv(design, w)
        if response == "normal":
            # y is an n-vector: hence use .to_event(1)
            return pyro.sample(response_label,
                               dist.Normal(prediction_mean, obs_sd).to_event(1))
        elif response == "bernoulli":
            return pyro.sample(response_label,
                               dist.Bernoulli(logits=prediction_mean).to_event(1))
        elif response == "sigmoid":
            base_dist = dist.Normal(prediction_mean, obs_sd).to_event(1)
            # You can add loc via the linear model itself
            k = k.expand(prediction_mean.shape)
            transforms = [AffineTransform(loc=torch.tensor(0.), scale=k),
                          SigmoidTransform()]
            response_dist = dist.TransformedDistribution(base_dist, transforms)
            return pyro.sample(response_label, response_dist)
        else:
            raise ValueError(
                "Unknown response distribution: '{}'".format(response))
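
# Usage sketch for bayesian_linear_model (hypothetical sizes and labels): a
# fixed-effects-only normal model with known observation noise. The group
# names "w1"/"w2" are illustrative, not required by the API.
from collections import OrderedDict
import torch

design = torch.randn(20, 3)
y = bayesian_linear_model(
    design,
    w_means=OrderedDict([("w1", torch.zeros(2)), ("w2", torch.zeros(1))]),
    w_sqrtlambdas={"w1": torch.ones(2), "w2": torch.ones(1)},
    obs_sd=torch.tensor(1.),
)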
def test_rmv(A, b):
    assert_equal(rmv(A, b), A.mv(b), prec=1e-8)
    batched_A = lexpand(A, 5, 4)
    batched_b = lexpand(b, 5, 4)
    expected_Ab = lexpand(A.mv(b), 5, 4)
    assert_equal(rmv(batched_A, batched_b), expected_Ab, prec=1e-8)
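
# For reference: rmv is a batched matrix-vector product over the rightmost
# dimensions; an equivalent pure-torch spelling (matching the behaviour the
# test above checks) is
#     torch.matmul(A, b.unsqueeze(-1)).squeeze(-1)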