def compute_stochastic_elbo(a, b, nu, omega, x, y, a_0, b_0, mu_0): """ Return a monte-carlo estimate of the ELBO, using a single sample from Q(sigma^-2, beta) a, b are the Gamma 'shape' and 'rate' parameters for the variational posterior over *precision*: q(tau) = q(sigma^-2) nu_k, omega_k are Normal 'mean' and 'precision' parameters for the variational posterior over weights: q(beta_k) x is an n by k matrix, where each row contains the regression inputs [1, x, x^2, x^3] y is an n by 1 values a_0, b_0 the parameters for the Gamma prior over precision P(tau) = P(sigma^-2) mu_0 is the mean of the Gamma prior on weights beta """ # Define mean field variational distribution over (beta, tau). Q_beta = Normal(nu, omega**-0.5) Q_tau = Gamma(a, b) # Sample from variational distribution: (tau, beta) ~ Q # Use rsample to make sure that the result is differentiable. tau = Q_tau.rsample() sigma = tau**-0.5 beta = Q_beta.rsample() # Create a single sample monte-carlo estimate of ELBO. P_tau = Gamma(a_0, b_0) P_beta = Normal(mu_0, sigma) P_y = Normal((beta[None, :]*x).sum(dim=1, keepdim=True), sigma) kl_tau = Q_tau.log_prob(tau) - P_tau.log_prob(tau) kl_beta = Q_beta.log_prob(beta).sum() - P_beta.log_prob(beta).sum() log_likelihood = P_y.log_prob(y).sum() elbo = log_likelihood - kl_tau - kl_beta return elbo
def log_likelihood(self, x_norm, y_norm): mean, var, shape, rate, mixture_var = self(x_norm) norm_dist = Normal(mean, torch.sqrt(var)) gamma_dist = Gamma(shape, rate) y = y_norm * self.y_std + self.y_mean + 10**(-4) only_normal_bool = (torch.abs(1 - mixture_var) < 10**(-4)).type( torch.float) only_gamma_bool = (mixture_var < 10**(-4)).type(torch.float) normal_component = norm_dist.log_prob(y_norm) + torch.log(mixture_var) gamma_component = gamma_dist.log_prob(y) + torch.log(1 - mixture_var) combined_tensor = torch.stack((normal_component, gamma_component), dim=0) output = torch.logsumexp(combined_tensor, dim=0) if mixture_var < 0.9: logging.debug("Mixture var: {}".format(float(mixture_var.mean()))) logging.debug("NLLs: {:.3f}, {:.3f}".format( -float(norm_dist.log_prob(y_norm).mean()), -float(gamma_dist.log_prob(y).mean()), )) logging.debug("Combined NLL: {:.3f} or {:.3f}".format( -float(output.mean()), -float(old_output))) return output.mean()
def gamma_ll(target_vals, v): """ Evaluate gamma-bernoulli mixture likelihood Parameters: ---------- v: torch.Tensor(batch,86,channels) parameters from model [rho, alpha, beta] target_vals: torch.Tensor(batch,86) target vals to eval at """ # Reshape target_vals = target_vals.reshape(-1) v = v.reshape(-1, 3) # Deal with cases where data is missing for a station v = v[~torch.isnan(target_vals), :] target_vals = target_vals[~torch.isnan(target_vals)] # Make r mask r, target_vals = make_r_mask(target_vals) gamma = Gamma(concentration=v[:, 1], rate=v[:, 2]) logp = gamma.log_prob(target_vals) total = r * (torch.log(v[:, 0]) + logp) + (1 - r) * torch.log(1 - v[:, 0]) return torch.mean(total)
def test_loss(self, test_data): """ outputs the losses the test data """ x, y = test_data[:] if not test_data.x_normalised: constant_x = self.x_std == 0 x = (x - self.x_mean) / self.x_std x[:, constant_x] = 0 self.train(False) shape, rate = self(x) y = y.squeeze() shape = shape.squeeze() rate = rate.squeeze() gamma_dist = Gamma(shape, rate) test_nll = -gamma_dist.log_prob(y + 10**(-8)).mean() test_rmse = (((y - gamma_dist.mean)**2).mean())**0.5 calibration_arr = self.calibration_test(y.detach().numpy(), shape.detach().numpy(), rate.detach().numpy()) return float(test_nll), float(test_rmse), calibration_arr
def gamma_logpdf(inputs, loc, scale, reduction=None): """Gamma log-density. Args: inputs (tensor): Inputs. mean (tensor): Mean. sigma (tensor): Standard deviation. reduction (str, optional): Reduction. Defaults to no reduction. Possible values are "sum", "mean", and "batched_mean". Returns: tensor: Log-density. """ dist = Gamma(concentration=loc, rate=scale) logp = dist.log_prob(inputs) if not reduction: return logp elif reduction == 'sum': return torch.sum(logp) elif reduction == 'mean': return torch.mean(logp) elif reduction == 'batched_mean': return torch.mean(torch.sum(logp, 1)) else: raise RuntimeError(f'Unknown reduction "{reduction}".')
def fit(self, train_data): self.train(True) self.x_mean, self.x_std = train_data.normalise_x() data_generator = data.DataLoader(train_data, batch_size=self.batch_size) optimiser = torch.optim.Adam(self.parameters(), lr=self.lr) for _ in torch.arange(self.n_epochs): for i, sample in enumerate(data_generator): x, y = sample shape, rate = self(x) gamma_dist = Gamma(shape, rate) optimiser.zero_grad() loss = -gamma_dist.log_prob(y.squeeze() + 10 ** (-8)).mean() loss.backward() optimiser.step()
def tbi_func(x, v): """ Evaluate gamma-GP-Bernoulli mixture likelihood Parameters: ---------- v: torch.Tensor(batch*86, channels) parameters from model x: torch.Tensor(batch*86) target vals to eval at """ # Gamma distribution g = Gamma(concentration=v[:, 2], rate=v[:, 3]) gamma = torch.exp(torch.clamp(g.log_prob(x), min=-1e5, max=1e5)) # Weight term weight_term = (1 / 2) + (1 / np.pi) * torch.atan((x - v[:, 5]) / v[:, 6]) # GP distribution gp = (1 / v[:, 4]) * (1 + (v[:, 1] * x / v[:, 4]))**((-1 / v[:, 1]) - 1) # total tbi = gamma * (1 - weight_term) + gp * weight_term return torch.clamp(tbi, min=1e-5)
class MLLGP(): def __init__(self, model_gp, likelihood_gp, hyperpriors: dict) -> None: self.model_gp = model_gp self.likelihood_gp = likelihood_gp self.hyperpriors = hyperpriors a_beta = self.hyperpriors["lengthscales"].kwds["a"] b_beta = self.hyperpriors["lengthscales"].kwds["b"] self.Beta_tmp = Beta(concentration1=a_beta, concentration0=b_beta) a_gg = self.hyperpriors["outputscale"].kwds["a"] b_gg = self.hyperpriors["outputscale"].kwds["scale"] self.Gamma_tmp = Gamma(concentration=a_gg, rate=1. / b_gg) def log_marginal(self, lengthscales, outputscale) -> float: """ """ # print("lengthscales.shape:",lengthscales.shape) # print("outputscale.shape:",outputscale.shape) if lengthscales.dim() == 3 or outputscale.dim() == 3: Nels = lengthscales.shape[0] loss_vec = torch.zeros(Nels) for k in range(Nels): loss_vec[k] = self.log_marginal(lengthscales[k, 0, :], outputscale[k, 0, :]) return loss_vec assert lengthscales.dim() <= 1 and outputscale.dim() <= 1 assert not torch.any(torch.isnan(lengthscales)) and not torch.any( torch.isinf(lengthscales)), "lengthscales is inf or NaN" assert not torch.isnan(outputscale) and not torch.isinf( outputscale), "outputscale is inf or NaN" # Update hyperparameters: self.model_gp.covar_module.outputscale = outputscale self.model_gp.covar_module.base_kernel.lengthscale = lengthscales # self.model_gp.display_hyperparameters() # Get the log prob of the marginal distribution: function_dist = self.model_gp(self.model_gp.train_inputs[0]) output = self.likelihood_gp(function_dist) loss_val = output.log_prob(self.model_gp.train_targets).view(1) # if self.debug == True: # pdb.set_trace() loss_lengthscales_hyperprior = torch.sum( self.Beta_tmp.log_prob(lengthscales)).view(1) loss_outputscale_hyperprior = self.Gamma_tmp.log_prob(outputscale) # loss_lengthscales_hyperprior = sum(self.hyperpriors["lengthscales"].logpdf(lengthscales)) # loss_outputscale_hyperprior = self.hyperpriors["outputscale"].logpdf(outputscale).item() loss_val += loss_lengthscales_hyperprior + loss_outputscale_hyperprior try: assert not torch.any(torch.isnan(loss_val)) and not torch.any( torch.isinf(loss_val)), "loss_val is Inf or NaN" except: # debug TODO DEBUG logger.info("loss_val: {0:s}".format(str(loss_val))) logger.info("loss_lengthscales_hyperprior: {0:s}".format( str(loss_lengthscales_hyperprior))) logger.info("loss_outputscale_hyperprior: {0:s}".format( str(loss_outputscale_hyperprior))) return loss_val def __call__(self, pars_in): # Slice only last dimension: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.narrow lengthscales = pars_in.narrow( dim=-1, start=self.model_gp.idx_hyperpars["lengthscales"][0], length=len(self.model_gp.idx_hyperpars["lengthscales"])) outputscale = pars_in.narrow( dim=-1, start=self.model_gp.idx_hyperpars["outputscale"][0], length=len(self.model_gp.idx_hyperpars["outputscale"])) return -self.log_marginal( lengthscales, outputscale) # Use minus (-) when minizing the marginal likelihood
class BayesianNN: def __init__(self, X_train, y_train, batch_size, num_particles, hidden_dim): self.gamma_prior = Gamma(torch.tensor(1., device=device), torch.tensor(1 / 0.1, device=device)) self.lambda_prior = Gamma(torch.tensor(1., device=device), torch.tensor(1 / 0.1, device=device)) self.X_train = X_train self.y_train = y_train self.batch_size = batch_size self.num_particles = num_particles self.n_features = X_train.shape[1] self.hidden_dim = hidden_dim def forward(self, inputs, theta): # Unpack theta w1 = theta[:, 0:self.n_features * self.hidden_dim].reshape( -1, self.n_features, self.hidden_dim) b1 = theta[:, self.n_features * self.hidden_dim:(self.n_features + 1) * self.hidden_dim].unsqueeze(1) w2 = theta[:, (self.n_features + 1) * self.hidden_dim:(self.n_features + 2) * self.hidden_dim].unsqueeze(2) b2 = theta[:, -3].reshape(-1, 1, 1) # log_gamma, log_lambda = theta[-2], theta[-1] # num_particles times of forward inputs = inputs.unsqueeze(0).repeat(self.num_particles, 1, 1) inter = F.relu(torch.bmm(inputs, w1) + b1) out = torch.bmm(inter, w2) + b2 out = out.squeeze() return out def log_prob(self, theta): model_gamma = torch.exp(theta[:, -2]) model_lambda = torch.exp(theta[:, -1]) model_w = theta[:, :-2] # w_prior should be decided based on current lambda (not sure) w_prior = Normal( 0, torch.sqrt(torch.ones_like(model_lambda) / model_lambda)) random_idx = random.sample([i for i in range(self.X_train.shape[0])], self.batch_size) X_batch = self.X_train[random_idx] y_batch = self.y_train[random_idx] outputs = self.forward(X_batch, theta) # [num_particles, batch_size] model_gamma_repeat = model_gamma.unsqueeze(1).repeat( 1, self.batch_size) y_batch_repeat = y_batch.unsqueeze(0).repeat(self.num_particles, 1) distribution = Normal( outputs, torch.sqrt( torch.ones_like(model_gamma_repeat) / model_gamma_repeat)) log_p_data = distribution.log_prob(y_batch_repeat).sum(dim=1) log_p0 = w_prior.log_prob( model_w.t()).sum(dim=0) + self.gamma_prior.log_prob( model_gamma) + self.lambda_prior.log_prob(model_lambda) log_p = log_p0 + log_p_data * (self.X_train.shape[0] / self.batch_size ) # (8) in paper return log_p
class MAMLParticles(MetaNetwork): """ Object that contains all the particles. """ def __init__(self, feature_extractor_params, lr_chaser=0.001, lr_leader=None, n_epochs_chaser=1, n_epochs_predict=0, s_epochs_leader=1, m_particles=2, kernel_function='rbf', n_samples=10, a_likelihood=2., b_likelihood=.2, a_prior=2., b_prior=.2, use_mse=False): """ Initialises the object. Parameters ---------- feature_extractor_params: dict Parameters for the feature extractor. lr_chaser: float Learning rate for the chaser lr_leader: float Learning rate for the leader n_epochs_chaser: int Number of steps to be performed by the chaser. s_epochs_leader: int Number of steps to be performed by the leader. m_particles: Number of particles. kernel_function: str, {'rbf', 'quadratic'} The kernel function to use. use_mse: bool Whether to use MSE loss or Chaser loss. """ super(MAMLParticles, self).__init__() self.kernel_function = kernel_function self.n_epochs_chaser = n_epochs_chaser self.s_epochs_leader = s_epochs_leader self.n_epochs_predict = n_epochs_predict if lr_leader is None: lr_leader = lr_chaser / 10 self.lr = { 'chaser': lr_chaser, 'leader': lr_leader, } self.m_particles = m_particles self.n_samples = n_samples self.feature_extractor = FeaturesExtractorFactory()( **feature_extractor_params) self.fe_output_dim = self.feature_extractor.output_dim self.gamma_likelihood = Gamma(a_likelihood, b_likelihood) self.gamma_prior = Gamma(a_prior, b_prior) # The particles only implement the last (linear) layer. # The first two columns are the kappas (likelihood then prior) self.particles = nn.Parameter( torch.cat(( self.gamma_likelihood.sample((m_particles, 1)), self.gamma_prior.sample((m_particles, 1)), nn.init.kaiming_uniform( torch.empty((m_particles, self.fe_output_dim + 1))), ), dim=1)) self.loss = 0 self.use_mse = use_mse @property def return_var(self): return True def kernel(self, weights): """ Computes the cross-particle kernel. Given the stacked parameter vectors of the particles, outputs the kernel (be it RBF or quadratic). Parameters ---------- weights: torch.Tensor B * M * M * (D + 1) tensor. Expanded versions of the weights. Returns ------- kernel: torch.Tensor B * M * M tensor representing the cross-particle kernel. """ def rbf_kernel(pv): """ Computes the RBF kernel for a set of parameter vectors. Parameters ---------- pv: torch.Tensor Stack of flatten parameters for each particle. Returns ------- kernel: m x m torch.Tensor A m x m torch tensor representing the kernel. """ x = pv - pv.transpose(1, 2) x = -x.norm(2, dim=3).pow(2) / 2 x = x.exp() return x def quadratic_kernel(pv): """ Computes the RBF kernel for a set of parameter vectors. Parameters ---------- pv: torch.Tensor Stack of flatten parameters for each particle. Returns ------- kernel: m x m torch.Tensor A m x m torch tensor representing the kernel. """ x = pv - pv.transpose(1, 2) x = -x.norm(2, dim=3).pow(2) x = 1 / x return x kernel_functions = {'rbf': rbf_kernel, 'quadratic': quadratic_kernel} kernel = kernel_functions[self.kernel_function] return kernel(weights) @staticmethod def compute_predictions(features, parameters): """ Parameters ---------- features: torch.Tensor B * N * D tensor representing the features. parameters: torch.Tensor B * M * (D + 3) tensor representing the M particles (including the bias-feature trick and two kappa vectors). Returns ------- predictions: torch.Tensor B * M * N tensor, representing the predictions. """ # Obtains the weights weights = parameters[..., 2:] # Implements the bias-feature trick features = torch.cat((features, torch.ones_like(features[..., :1])), dim=2) predictions = torch.bmm(weights, features.transpose(1, 2)) return predictions def compute_mean_std(self, features, parameters): """ Parameters ---------- features: torch.Tensor B * N * D tensor representing the features. parameters: torch.Tensor B * M * (D + 3) tensor representing the M particles (including the bias-feature trick and two kappa vectors). Returns ------- predictions: torch.Tensor B * M * N tensor, representing the predictions. """ # Obtains the kappas (B * M) kappa_likelihood = parameters[..., 0] # Computes the predictions (B * M * N) predictions = self.compute_predictions(features, parameters) # Transposes the predictions to B * N * M predictions = predictions.transpose(1, 2) # Computes the mean mean = predictions.mean(dim=2) # Adds the variability variability = torch.randn( (*predictions.size(), self.n_samples)).to(mean.device) variability = variability / kappa_likelihood.unsqueeze(1).unsqueeze( 3).pow(.5) predictions = predictions.unsqueeze(3) + variability # Reshapes the predictions to B * N * (M x S), where S is the number of samples predictions = predictions.view(*predictions.shape[:2], -1) # mean = predictions.mean(dim=2) std = predictions.std(dim=2) return mean, std def posterior(self, predictions, targets, mask, weights, kappa_likelihood, kappa_prior): r""" Computes the posterior of the configuration. Parameters ---------- predictions: torch.Tensor B * M * N tensor representing the prediction made by the network. targets: torch.Tensor B * N * 1 tensor representing the targets. mask: torch.Tensor B * N mask of the examples (some tasks have less than N examples). weights: torch.Tensor B * M * (D + 1) tensor representing the weights, including the bias-feature trick kappa_likelihood: torch.Tensor: B * M tensor representing $\kappa_{likelihood}$. kappa_prior: torch.Tensor: B * M tensor representing $\kappa_{prior}$. Returns ------- objective: torch.Tensor B * M tensor, representing the posterior of each particle, for each batch. """ # Computing the log-likelihood log_likelihood = log_pdf(predictions - targets.transpose(1, 2), kappa_likelihood) # B * M * N log_likelihood = log_likelihood * mask.unsqueeze( 1) # Keep only the actual examples log_likelihood = log_likelihood.sum(dim=2) # We enforce a Gaussian prior on the weights log_prior = log_pdf(weights[..., :-1], kappa_prior).sum(dim=2) # Gamma prior on the kappas log_prior_kappa = self.gamma_likelihood.log_prob(kappa_likelihood) log_prior_kappa = log_prior_kappa + self.gamma_prior.log_prob( kappa_prior) objective = log_likelihood + log_prior + log_prior_kappa return objective def svgd(self, features, targets, mask, parameters, update_type='chaser'): r""" Performs the Stein Variational Gradient Update on the particles. For each particle, the update is given by :math:`\theta_{t+1} \gets \theta_t + \varepsilon_t \phi(\theta_t)` where: .. math:: \phi(\theta_t) = \frac{1}{M} \sum_{m=1}^M \left[ k(\theta_t^{(m)}, \theta_t) \nabla_{\theta_t^{(m)}} \log p(\theta_t^{(m)}) + \nabla_{\theta_t^{(m)}} k(\theta_t^{(m)}, \theta_t) \right] Parameters ---------- features: torch.Tensor B * N * D tensor. The precomputed features associated with the dataset. targets: torch.Tensor B * N * 1 tensor. The targets associated to the features. Useful to compute the posterior. mask: torch.Tensor B * N mask of the examples (some tasks have less than N examples). parameters: torch.Tensor B * M * (D + 3) tensor containing the full parameters, already expanded along a batch dimension. update_type: str, 'chaser' or 'leader' Defines which learning rate to use. """ # Expands the parameters : B * M * (D + 3) -> B * M * M * (D + 3) expanded_parameters = parameters.unsqueeze(1) expanded_parameters = expanded_parameters.expand( (parameters.size(0), self.m_particles, *parameters.shape[1:])) # Splits the different parameters kappa_likelihood = parameters[..., 0] kappa_prior = parameters[..., 1] weights = parameters[..., 2:] expanded_weights = expanded_parameters[..., 2:] # weights is B * M * (D + 1), features is B * N * D # predictions is B * M * N predictions = self.compute_predictions(features, parameters) # B * M * M kernel = self.kernel(expanded_weights) # B * M objectives = self.posterior( predictions=predictions, targets=targets, mask=mask, weights=weights, kappa_likelihood=kappa_likelihood, kappa_prior=kappa_prior, ) # Computes the gradients for the objective (B * M * (D + 3)) objective_grads = autograd.grad(objectives.sum(), parameters, create_graph=True)[0] # Computes the gradients for the kernel, using the expanded parameters (B * M * M * (D + 3)) kernel_grads = autograd.grad(kernel.sum(), expanded_parameters, create_graph=True)[0] # Computes the update # The matmul term multiplies batches of matrices that are B * M * M and B * M * (D + 3) update = torch.matmul( kernel, objective_grads) / self.m_particles + kernel_grads.mean( dim=2) # Performs the update new_parameters = parameters + self.lr[update_type] * update # We need to make sure that the kappas remain in the right range for numerical stability new_parameters = torch.cat([ torch.clamp(new_parameters[..., :2], min=1e-8), new_parameters[..., 2:] ], dim=2) return new_parameters def forward(self, episodes, train=None, test=None, query=None, trim_ends=True): """ Performs a forward and backward pass on a single episode. To keep memory load low, the backward pass is done simultaneously. Parameters ---------- episodes: list A batch of meta-learning episodes. train: dataset The train dataset. test: dataset The test dataset. query: dataset The query dataset. trim_ends: bool Whether to trim the results. Returns ------- results: list(tuple) A list of tuples containing the mean and standard deviation computed by the network for each episodes. query_results: list(tuple) A list of tuples containing the mean and standard deviation computed by the network for each episodes of the query set. """ if episodes is not None: train, test = pack_episodes(episodes, return_ys_test=True, return_query=False) x_test, y_test, len_test, mask_test = test query = None else: assert (train is not None) and (test is not None) x_test, len_test, mask_test = test # x is B * N * D dimensional, y is B * N * 1 dimensional x_train, y_train, len_train, mask_train = train b, n, d = x_train.size() train_features = self.feature_extractor(x_train.reshape( -1, d)).reshape(b, -1, self.fe_output_dim) test_features = self.feature_extractor(x_test.reshape(-1, d)).reshape( b, -1, self.fe_output_dim) # Expands the parameters along the batch dimension : M * (D + 3) -> B * M * (D + 3) parameters = self.particles.unsqueeze(0).expand( (b, *self.particles.size())) with autograd.enable_grad(): # Initialise the chaser as a new tensor chaser = parameters + 0. for i in range(self.n_epochs_chaser): chaser = self.svgd(train_features, y_train, mask_train, parameters=chaser, update_type='chaser') if self.training and not self.use_mse: full_features = torch.cat((train_features, test_features), dim=1) y_full = torch.cat((y_train, y_test), dim=1) mask_full = torch.cat((mask_train, mask_test), dim=1) leader = chaser + 0. for i in range(self.s_epochs_leader): leader = self.svgd(full_features, y_full, mask_full, parameters=leader, update_type='leader') # Added stability self.loss = (leader.detach() - chaser)[..., 2:].pow(2).sum() / b with autograd.enable_grad(): for i in range(self.n_epochs_predict): chaser = self.svgd(train_features, y_train, mask_train, parameters=chaser, update_type='chaser') # Computes the mean and standard deviation mean, std = self.compute_mean_std(test_features, chaser) # Unsqueezes the results to keep the same shape as the targets mean = mean.unsqueeze(2) std = std.unsqueeze(2) # Re-organises the results in the episodic form mean = [m[:n] for m, n in zip(mean, len_test)] std = [s[:n] for s, n in zip(std, len_test)] results = [(m[:n], s[:n]) for m, s, n in zip(mean, std, len_test) ] if trim_ends else (mean, std) if query is None: return results x_query, _, len_query, mask_query = query query_features = self.feature_extractor(x_query.reshape( -1, d)).reshape(b, -1, self.fe_output_dim) mean, std = self.compute_mean_std(query_features, chaser) # Unsqueezes the results to keep the same shape as the targets mean = mean.unsqueeze(2) std = std.unsqueeze(2) query_results = [ (m[:n], s[:n]) for m, s, n in zip(mean, std, len_test) ] if trim_ends else (mean, std) return results, query_results