def neighbors(x, partition_samples, edge_mat_samples, n_vertices, uniquely=False):
    """
    Collect the neighbors of a single point under every sampled partition.

    For each partition sample, 'x' is grouped with group_input, its neighbors
    are generated by _cartesian_neighbors using the matching entry of
    'edge_mat_samples', and the result is mapped back with ungroup_input.
    The per-sample results are concatenated row-wise.

    :param x: 1D Tensor
    :param partition_samples: sequence of partitions, one per sample
    :param edge_mat_samples: sequence indexed in parallel with
        'partition_samples'; entry i is passed to _cartesian_neighbors
    :param n_vertices: forwarded unchanged to group_input / ungroup_input
    :param uniquely: if True, a candidate row is appended only when it does
        not already occur among the rows gathered from earlier samples.
        Rows coming from the same sample are not compared with each other.
    :return: 2D Tensor with x.numel() columns, one row per collected neighbor
    """
    collected = x.new_empty((0, x.numel()))
    for sample_idx, partition in enumerate(partition_samples):
        grouped_point = group_input(x.unsqueeze(0), partition, n_vertices).squeeze(0)
        grouped_candidates = _cartesian_neighbors(grouped_point, edge_mat_samples[sample_idx])
        candidates = ungroup_input(grouped_candidates, partition, n_vertices)

        if not uniquely:
            collected = torch.cat([collected, candidates])
            continue

        # 'collected' is deliberately not updated while filtering, so the
        # membership test only sees rows from previous samples.
        fresh_rows = [
            row for row in range(candidates.size(0))
            if not torch.any(torch.all(collected == candidates[row], dim=1))
        ]
        if fresh_rows:
            collected = torch.cat([collected, candidates[fresh_rows]])
    return collected
def inference_sampling(input_data, output_data, n_vertices, hyper_samples, log_beta_samples,
                       partition_samples, freq_samples, basis_samples):
    """
    Build one Inference object per posterior sample.

    For sample s a DiffusionKernel is created from the per-subset sums of
    log_beta_samples[s] together with freq_samples[s] and basis_samples[s],
    wrapped in a GPRegression model whose parameters are loaded from
    hyper_samples[s], and paired with the training data grouped according to
    partition_samples[s].

    :param input_data: training inputs, grouped via group_input per sample
    :param output_data: training outputs, passed to Inference as-is
    :param n_vertices: forwarded unchanged to group_input
    :param hyper_samples: sequence of hyperparameter vectors (one per sample);
        its length determines how many Inference objects are built
    :param log_beta_samples: sequence of log_beta vectors, indexed by sample
    :param partition_samples: sequence of partitions (iterables of subsets),
        indexed by sample
    :param freq_samples: sequence of fourier frequency lists, indexed by sample
    :param basis_samples: sequence of fourier basis lists, indexed by sample
    :return: list of Inference objects, in sample order
    """
    inferences = []
    for s, hyper_vec in enumerate(hyper_samples):
        partition = partition_samples[s]
        # One kernel log-beta entry per subset: the sum over its members.
        subset_sums = [torch.sum(log_beta_samples[s][subset]) for subset in partition]
        kernel = DiffusionKernel(
            grouped_log_beta=torch.stack(subset_sums),
            fourier_freq_list=freq_samples[s],
            fourier_basis_list=basis_samples[s],
        )
        model = GPRegression(kernel=kernel)
        model.vec_to_param(hyper_vec)
        grouped_inputs = group_input(
            input_data=input_data,
            sorted_partition=partition,
            n_vertices=n_vertices,
        )
        inferences.append(Inference((grouped_inputs, output_data), model=model))
    return inferences
def acquisition_expectation(x, inference_samples, partition_samples, n_vertices, acquisition_func,
                            reference=None):
    """
    Evaluate an acquisition function under every posterior sample and sum the results.

    :param x: 1d or 2d tensor; a 1d tensor is promoted to a single-row 2d tensor
    :param inference_samples: sequence of inference objects, one per posterior
        sample; each must expose model.param_to_vec() and predict()
    :param partition_samples: sequence of partitions indexed in parallel with
        'inference_samples'
    :param n_vertices: forwarded unchanged to group_input
    :param acquisition_func: callable invoked as
        acquisition_func(mean, var, reference=reference)
    :param reference: forwarded unchanged to 'acquisition_func'
    :return: per-sample acquisition values stacked along dim 1 and summed over
        that dim with keepdim=True (a sum, not divided by the number of samples)
    """
    points = x.unsqueeze(0) if x.dim() == 1 else x

    per_sample_values = []
    for s, inference in enumerate(inference_samples):
        hyper_vec = inference.model.param_to_vec()
        grouped_points = group_input(points, sorted_partition=partition_samples[s], n_vertices=n_vertices)
        prediction = inference.predict(grouped_points, hyper=hyper_vec, verbose=False)
        # Detach so the acquisition values carry no autograd history.
        mean = prediction[0].detach()
        var = prediction[1].detach()
        per_sample_values.append(acquisition_func(mean[:, 0], var[:, 0], reference=reference))

    stacked = torch.stack(per_sample_values, 1)
    return stacked.sum(1, keepdim=True)
def acquisition_expectation(x, inference_samples, partition_samples, n_vertices,
                            acquisition_func=expected_improvement, reference=None):
    """
    Combine the acquisition function over posterior samples.

    The acquisition function is evaluated once per posterior sample and the
    per-sample values are summed. The sum is not divided by the number of
    samples, so the result is proportional to the sample average.

    :param x: 1d or 2d tensor; a 1d tensor is promoted to a single-row 2d tensor
    :param inference_samples: inference method for each posterior sample
    :param partition_samples: sequence of partitions indexed in parallel with
        'inference_samples'
    :param n_vertices: forwarded unchanged to group_input
    :param acquisition_func: callable invoked as
        acquisition_func(mean, var, reference=reference)
    :param reference: forwarded unchanged to 'acquisition_func'
    :return: per-sample acquisition values stacked along dim 1 and summed over
        that dim with keepdim=True
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)

    def _acquisition_for_sample(s, inference):
        # Acquisition values of 'x' under the s-th posterior sample.
        hyper_vec = inference.model.param_to_vec()
        grouped_points = group_input(x, sorted_partition=partition_samples[s], n_vertices=n_vertices)
        prediction = inference.predict(grouped_points, hyper=hyper_vec, verbose=False)
        mean = prediction[0].detach()
        var = prediction[1].detach()
        return acquisition_func(mean[:, 0], var[:, 0], reference=reference)

    per_sample_values = [
        _acquisition_for_sample(s, inference)
        for s, inference in enumerate(inference_samples)
    ]
    return torch.stack(per_sample_values, 1).sum(1, keepdim=True)
def slice_edgeweight(model, input_data, output_data, n_vertices, log_beta, sorted_partition,
                     fourier_freq_list, fourier_basis_list, ind):
    """
    Slice sampling the edgeweight(exp('log_beta')) at 'ind' in 'log_beta' vector

    Side effects: model.kernel members (grouped_log_beta, fourier_freq_list,
    fourier_basis_list) are overwritten, and 'log_beta' is modified in place
    at position 'ind'.

    :param model: model exposing 'kernel' and param_to_vec()
    :param input_data: training inputs, grouped via group_input
    :param output_data: training outputs, passed to Inference as-is
    :param n_vertices: 1d np.array
    :param log_beta: vector of log edge weights; entry 'ind' is resampled
    :param sorted_partition: Partition of {0, ..., K-1}, list of subsets(list)
    :param fourier_freq_list: assigned to model.kernel.fourier_freq_list
    :param fourier_basis_list: assigned to model.kernel.fourier_basis_list
    :param ind: index into 'log_beta' of the entry to resample
    :return: the same 'log_beta' object, with entry 'ind' replaced
    :raises ValueError: if no subset of 'sorted_partition' contains 'ind'
    """
    # Position of the subset that contains 'ind'.
    membership = [(ind in subset) for subset in sorted_partition]
    subset_pos = membership.index(True)
    # Contribution of the other members of that subset, fixed during sampling.
    rest_of_subset = torch.sum(log_beta[sorted_partition[subset_pos]]) - log_beta[ind]

    subset_sums = [torch.sum(log_beta[subset]) for subset in sorted_partition]
    model.kernel.grouped_log_beta = torch.stack(subset_sums)
    model.kernel.fourier_freq_list = fourier_freq_list
    model.kernel.fourier_basis_list = fourier_basis_list

    grouped_inputs = group_input(input_data=input_data, sorted_partition=sorted_partition,
                                 n_vertices=n_vertices)
    inference = Inference(train_data=(grouped_inputs, output_data), model=model)

    def logp(log_beta_i):
        """
        Unnormalized log posterior of a candidate value for log_beta[ind].
        Note that model.kernel.grouped_log_beta is updated in place.
        :param log_beta_i: numeric(float)
        :return: numeric(float)
        """
        prior = log_prior_edgeweight(log_beta_i)
        if np.isinf(prior):
            # Outside the prior support: skip the likelihood evaluation.
            return prior
        model.kernel.grouped_log_beta[subset_pos] = rest_of_subset + log_beta_i
        neg_log_lik = inference.negative_log_likelihood(hyper=model.param_to_vec())
        likelihood = float(-neg_log_lik)
        return prior + likelihood

    current = float(log_beta[ind])
    sampled = univariate_slice_sampling(logp, current)
    log_beta[ind] = sampled
    # logp may have left the kernel at a rejected candidate; sync it to the accepted value.
    model.kernel.grouped_log_beta[subset_pos] = rest_of_subset + sampled
    return log_beta
def slice_hyper(model, input_data, output_data, n_vertices, sorted_partition):
    """
    Run the three hyperparameter slice-sampling steps on one Inference object.

    The training inputs are grouped according to 'sorted_partition', an
    Inference object is built from them, and slice_constmean,
    slice_kernelamp and slice_noisevar are applied to it in that order.

    :param model: model handed to Inference
    :param input_data: training inputs, grouped via group_input
    :param output_data: training outputs, passed to Inference as-is
    :param n_vertices: forwarded unchanged to group_input
    :param sorted_partition: partition forwarded to group_input
    :return: None
    """
    train_inputs = group_input(
        input_data=input_data,
        sorted_partition=sorted_partition,
        n_vertices=n_vertices,
    )
    inference = Inference(train_data=(train_inputs, output_data), model=model)
    # Randomly shuffling order can be considered, here the order is in const_mean, kernel_amp, noise_var
    for sampler in (slice_constmean, slice_kernelamp, slice_noisevar):
        sampler(inference)
def prediction_statistic(x, inference_samples, partition_samples, n_vertices):
    """
    Average predictive mean, standard deviation and variance over posterior samples.

    :param x: 1d or 2d tensor; a 1d tensor is promoted to a single-row 2d tensor
    :param inference_samples: sequence of inference objects, one per posterior
        sample; each must expose predict()
    :param partition_samples: sequence of partitions indexed in parallel with
        'inference_samples'
    :param n_vertices: forwarded unchanged to group_input
    :return: tuple (mean, std, var); each is the per-sample prediction
        concatenated along dim 1 and averaged over that dim with keepdim=True.
        The std is the mean of per-sample square roots of the variance, not
        the square root of the averaged variance.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    mean_sample_list = []
    std_sample_list = []
    var_sample_list = []
    for s, inference in enumerate(inference_samples):
        grouped_x = group_input(input_data=x, sorted_partition=partition_samples[s], n_vertices=n_vertices)
        pred_dist = inference.predict(grouped_x)
        # detach() replaces the legacy '.data' attribute: same values, but it
        # stays visible to autograd's in-place checks. Detaching before the
        # square root also avoids building a graph for it.
        pred_mean_sample = pred_dist[0].detach()
        pred_var_sample = pred_dist[1].detach()
        mean_sample_list.append(pred_mean_sample)
        std_sample_list.append(pred_var_sample ** 0.5)
        var_sample_list.append(pred_var_sample)
    return (torch.cat(mean_sample_list, 1).mean(1, keepdim=True),
            torch.cat(std_sample_list, 1).mean(1, keepdim=True),
            torch.cat(var_sample_list, 1).mean(1, keepdim=True))
def function(self, variables):
    """
    Return the posterior-sample-averaged predictive mean and std at one point.

    :param variables: values convertible by torch.tensor; reshaped to a
        single row before prediction
    :return: (mean, std) as Python floats, each averaged over all entries of
        self.inference_samples. If the point is already in self.seen,
        (INF, -INF) is returned instead.
    """
    candidate = torch.tensor(variables).view(1, -1)
    # Membership is keyed on the tensor's repr string; presumably self.seen is
    # filled with the same representation elsewhere -- verify against caller.
    if repr(candidate.view(-1)) in self.seen:
        return INF, -INF

    sample_means = []
    sample_stds = []
    for s, inference in enumerate(self.inference_samples):
        hyper_vec = inference.model.param_to_vec()
        grouped_candidate = group_input(candidate, sorted_partition=self.partition_samples[s],
                                        n_vertices=self.n_vertices)
        prediction = inference.predict(grouped_candidate, hyper=hyper_vec, verbose=False)
        predicted_mean = prediction[0].detach()
        predicted_var = prediction[1].detach()
        sample_means.append(predicted_mean[:, 0])
        sample_stds.append(torch.sqrt(predicted_var[:, 0]))

    mean = torch.mean(torch.stack(sample_means)).item()
    std = torch.mean(torch.stack(sample_stds)).item()
    return mean, std