# These methods rely on the file's module-level imports: torch, torch.nn (as nn),
# gpytorch (with RootLazyTensor, DiagLazyTensor, AddedDiagLazyTensor, and
# MultivariateNormal from its lazy/distributions modules), plus the package's
# flatten and unflatten_like helpers.

def full_logll(self, param_list, mean_list, var_list, cov_mat_root_list):
    # concatenate the per-parameter (K, d_i) deviation roots into one (K, d) root
    cov_mat_root = torch.cat(cov_mat_root_list, dim=1)
    mean_vector = flatten(mean_list)
    var_vector = flatten(var_list)
    param_vector = flatten(param_list)
    return self.compute_ll_for_block(
        param_vector, mean_vector, var_vector, cov_mat_root)
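# Note: `compute_ll_for_block` (below) evaluates a Gaussian whose covariance is
# diag(var) + D D^T, where the rows of the (K, d) root matrix are the stored
# deviation vectors; the 1/(max_num_models - 1) scaling of the low-rank term
# appears only at sampling time in `sample`.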
def sample(self, scale=1.0, cov=False, seed=None, fullrank=True):
    # NOTE: `fullrank` is accepted but unused in this body.
    if seed is not None:
        torch.manual_seed(seed)

    scale_sqrt = scale ** 0.5

    mean_list = []
    sq_mean_list = []
    if cov:
        cov_mat_sqrt_list = []

    # gather the stored first/second moments (and covariance roots) per parameter
    for (module, name) in self.params:
        mean = module.__getattr__('%s_mean' % name)
        sq_mean = module.__getattr__('%s_sq_mean' % name)
        if cov:
            cov_mat_sqrt = module.__getattr__('%s_cov_mat_sqrt' % name)
            cov_mat_sqrt_list.append(cov_mat_sqrt)
        mean_list.append(mean.cpu())
        sq_mean_list.append(sq_mean.cpu())

    mean = flatten(mean_list)
    sq_mean = flatten(sq_mean_list)

    # draw diagonal variance sample
    var = torch.clamp(sq_mean - mean ** 2, self.var_clamp)
    var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False)

    # if covariance, draw low-rank sample
    if cov:
        cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1)
        cov_sample = cov_mat_sqrt.t().matmul(
            cov_mat_sqrt.new_empty(
                (cov_mat_sqrt.size(0),), requires_grad=False).normal_())
        cov_sample /= (self.max_num_models - 1) ** 0.5
        rand_sample = var_sample + cov_sample
    else:
        rand_sample = var_sample

    # update sample with mean and scale
    sample = mean + scale_sqrt * rand_sample
    sample = sample.unsqueeze(0)

    # unflatten the new sample to match the per-parameter shapes of mean_list
    samples_list = unflatten_like(sample, mean_list)

    # write the sampled weights back into the wrapped modules
    # (assumes the model lives on the default CUDA device)
    for (module, name), sample in zip(self.params, samples_list):
        module.register_parameter(name, nn.Parameter(sample.cuda()))
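# A minimal usage sketch, kept as comments since the surrounding names are
# illustrative (`swag_model` wraps a base network whose modules are listed in
# `self.params`; a CUDA device is assumed by `sample`):
#
#   swag_model.sample(scale=0.5, cov=True, seed=0)  # writes sampled weights
#   with torch.no_grad():
#       preds = base_model(inputs)  # forward pass uses the sampled weights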
def compute_ll_for_block(self, vec, mean, var, cov_mat_root):
    vec = flatten(vec)
    mean = flatten(mean)
    var = flatten(var)

    # low-rank plus diagonal covariance: diag(var + 1e-6) + R R^T,
    # with R = cov_mat_root.t() of shape (d, K)
    cov_mat_lt = RootLazyTensor(cov_mat_root.t())
    var_lt = DiagLazyTensor(var + 1e-6)
    covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
    qdist = MultivariateNormal(mean, covar_lt)

    # run the log-prob solve with 1 trace sample and at most 25 CG iterations
    with gpytorch.settings.num_trace_samples(1), \
            gpytorch.settings.max_cg_iterations(25):
        return qdist.log_prob(vec)
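# Standalone sketch of the same low-rank-plus-diagonal construction on toy
# shapes (assumes a gpytorch version exposing the lazy-tensor API used above):
#
#   d, k = 10, 3
#   mean = torch.zeros(d)
#   var = torch.rand(d)       # diagonal variances
#   root = torch.randn(k, d)  # K stored deviation rows
#   covar = AddedDiagLazyTensor(DiagLazyTensor(var + 1e-6),
#                               RootLazyTensor(root.t()))
#   ll = MultivariateNormal(mean, covar).log_prob(torch.randn(d))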
def block_logdet(self, var, cov_mat_root):
    var = flatten(var)

    cov_mat_lt = RootLazyTensor(cov_mat_root.t())
    var_lt = DiagLazyTensor(var + 1e-6)
    covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)

    return covar_lt.log_det()
def compute_logdet(self, block=False):
    _, var_list, covar_mat_root_list = self.generate_mean_var_covar()

    if block:
        # sum the log-determinants of the per-parameter blocks
        full_logdet = 0
        for (var, cov_mat_root) in zip(var_list, covar_mat_root_list):
            block_logdet = self.block_logdet(var, cov_mat_root)
            full_logdet += block_logdet
    else:
        # treat all parameters as a single block
        var_vector = flatten(var_list)
        cov_mat_root = torch.cat(covar_mat_root_list, dim=1)
        full_logdet = self.block_logdet(var_vector, cov_mat_root)

    return full_logdet
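# Sketch of the two logdet modes (`swag_model` is an illustrative instance):
# `block=True` sums per-parameter log-determinants, which is cheaper but drops
# the cross-parameter coupling of the low-rank root; the default concatenates
# all roots and computes a single full log-determinant.
#
#   ld_block = swag_model.compute_logdet(block=True)
#   ld_full = swag_model.compute_logdet()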