def forward_importance_sampled(galaxy_vae, resid_image, recon_vars, was_on,
                               seq_tensor, use_importance_sample=True):
    assert len(was_on) == len(seq_tensor)
    assert resid_image.shape[0] == len(seq_tensor)
    assert recon_vars.shape[0] == len(seq_tensor)

    # get class weights
    class_weights = galaxy_vae.get_pixel_probs(resid_image, recon_vars)
    assert_diff(class_weights.sum(1), torch.Tensor([1.0]).to(device), tol=1e-4)

    # get importance sampling weights
    if use_importance_sample:
        attn_offset = galaxy_vae.attn_offset
        prob_off = class_weights.detach()[:, -1].view(-1, 1)
        importance_weights = \
            get_importance_weights(resid_image.detach(), attn_offset, prob_off)
    else:
        importance_weights = class_weights.detach()

    # sample from importance weights
    a_sample = common_utils.sample_class_weights(importance_weights.detach())
    a_sample[was_on == 0.] = importance_weights.shape[-1] - 1

    recon_mean, recon_var, is_on, kl_z = \
        galaxy_vae.sample_conditional_a(resid_image, recon_vars, a_sample)

    return recon_mean, recon_var, is_on, kl_z, importance_weights, \
        class_weights, a_sample
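
# `assert_diff` and `get_importance_weights` are helpers defined elsewhere in
# the repo. Based only on how `assert_diff` is called above (checking that the
# class weights sum to one up to a tolerance), its assumed behavior is sketched
# below; the actual implementation may differ. Assumes `torch` is imported at
# module top, as the functions above already require.
def _assert_diff_sketch(x, target, tol=1e-6):
    # assert every entry of x is within `tol` of the (broadcastable) target
    max_diff = torch.max(torch.abs(x - target))
    assert max_diff <= tol, 'max diff: {}'.format(max_diff.item())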
def reinforce_w_double_sample_baseline(
        conditional_loss_fun, log_class_weights, class_weights_detached,
        seq_tensor, z_sample, epoch, data, grad_estimator_kwargs=None):
    # This is what we call REINFORCE+ in our paper: we use a second,
    # independent sample from the discrete distribution as a baseline.
    assert len(z_sample) == log_class_weights.shape[0]

    # compute the loss at the sampled categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class weights at the sampled categories
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # get baseline from a second, independent sample
    z_sample2 = sample_class_weights(class_weights_detached)
    one_hot_z_sample2 = get_one_hot_encoding_from_int(z_sample2, n_classes)
    baseline = conditional_loss_fun(one_hot_z_sample2)

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline) + conditional_loss_fun_i
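
# `get_reinforce_grad_sample` is defined elsewhere. From how it is used above
# (its output is added back to the conditional loss, and `baseline` acts as a
# control variate), a standard score-function surrogate consistent with that
# usage is sketched below; the real implementation may differ in detail.
def _reinforce_grad_sample_sketch(conditional_loss, log_class_weights_i,
                                  baseline=0.0):
    # treat (loss - baseline) as a constant so that gradients flow only
    # through log q(z); its gradient is the REINFORCE estimator
    return (conditional_loss - baseline).detach() * log_class_weights_i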
def reinforce_wr(conditional_loss_fun, log_class_weights,
                 class_weights_detached, seq_tensor, z_sample,
                 epoch, data, n_samples=1, baseline_constant=None):
    # z_sample should be a vector of categories, but is ignored since this
    # function draws its own samples.
    # conditional_loss_fun is a function that takes in a one-hot encoding
    # of z and returns the loss.
    # Note: with baseline_constant=None, the leave-one-out baseline below
    # requires n_samples >= 2 (it divides by n_samples - 1).
    assert len(z_sample) == log_class_weights.shape[0]

    # Sample with replacement
    ind = torch.stack([
        sample_class_weights(class_weights_detached)
        for _ in range(n_samples)
    ], -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    if baseline_constant is None:
        # leave-one-out baseline built from the other samples
        adv = (costs - costs.mean(-1, keepdim=True)) * n_samples / (n_samples - 1)
    else:
        adv = costs - baseline_constant

    # Add the costs in case there is a direct dependency on the parameters
    return (adv.detach() * log_p + costs).mean(-1)
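
# Hypothetical smoke test for reinforce_wr; the names `_toy_loss`, `logits`,
# and all shapes below are illustrative and not part of the original code.
# n_samples >= 2 keeps the leave-one-out baseline well defined.
def _reinforce_wr_example():
    torch.manual_seed(0)
    batch_size, n_classes = 4, 10
    logits = torch.randn(batch_size, n_classes, requires_grad=True)
    log_class_weights = torch.log_softmax(logits, dim=-1)
    class_weights_detached = log_class_weights.detach().exp()
    seq_tensor = torch.arange(batch_size)

    def _toy_loss(one_hot_z):
        # toy per-observation loss: distance of the sampled category from 3
        return ((one_hot_z.float() *
                 torch.arange(n_classes, dtype=torch.float)).sum(-1) - 3.).abs()

    z_sample = sample_class_weights(class_weights_detached)
    ps_loss = reinforce_wr(_toy_loss, log_class_weights,
                           class_weights_detached, seq_tensor, z_sample,
                           epoch=None, data=None, n_samples=4)
    ps_loss.sum().backward()  # gradient w.r.t. `logits` estimates the true gradient
    return logits.grad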
def get_importance_sampled_loss(f_z, log_q,
                                importance_weights=None,
                                use_baseline=True):
    # class weights from the variational distribution
    assert (log_q.detach() < 0).all()
    class_weights = torch.exp(log_q.detach())
    # why is the tolerance so bad?
    assert_diff(class_weights.sum(1), torch.Tensor([1.0]).to(device), tol=1e-4)

    seq_tensor = torch.LongTensor([i for i in range(class_weights.shape[0])])

    # sample from conditional distribution
    if importance_weights is not None:
        assert importance_weights.shape[0] == log_q.shape[0]
        assert importance_weights.shape[1] == log_q.shape[1]
        # why is the tolerance so bad?
        assert_diff(importance_weights.sum(1),
                    torch.Tensor([1.0]).to(device), tol=1e-4)

        # sample from importance weights
        z_sample = common_utils.sample_class_weights(importance_weights)

        # reweight accordingly
        importance_weighting = class_weights[seq_tensor, z_sample] / \
            importance_weights[seq_tensor, z_sample]

        assert len(importance_weighting) == len(z_sample)
    else:
        z_sample = common_utils.sample_class_weights(class_weights)
        importance_weighting = 1.0

    f_z_i_sample = f_z(z_sample)
    assert len(f_z_i_sample) == len(z_sample)
    log_q_i_sample = log_q[seq_tensor, z_sample]

    if use_baseline:
        z_sample2 = common_utils.sample_class_weights(class_weights)
        baseline = f_z(z_sample2).detach()
    else:
        baseline = 0.0

    reinforce_grad_sample = \
        common_utils.get_reinforce_grad_sample(f_z_i_sample, log_q_i_sample,
                                               baseline) + f_z_i_sample
    assert len(reinforce_grad_sample) == len(z_sample)

    return (reinforce_grad_sample * importance_weighting).sum()
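
# `common_utils.sample_class_weights` is assumed to draw one categorical
# sample per row of a (batchsize x n_categories) weight matrix and return the
# sampled indices as a LongTensor; a minimal sketch of that assumed behavior:
def _sample_class_weights_sketch(class_weights):
    # one multinomial draw per row of the (row-normalized) class weights
    return torch.multinomial(class_weights, num_samples=1).squeeze(-1)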
def get_baseline(conditional_loss_fun, log_class_weights, n_samples,
                 deterministic):
    with torch.no_grad():
        n_classes = log_class_weights.size(-1)
        if deterministic:
            # Compute the baseline deterministically as a normalized, weighted
            # average of the losses at the top-k elements
            weights, ind = log_class_weights.topk(n_samples, -1)
            weights = weights - weights.logsumexp(-1, keepdim=True)  # normalize
            baseline = (torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(ind[:, i], n_classes))
                for i in range(n_samples)
            ], -1) * weights.exp()).sum(-1)
        else:
            # Sample categories with replacement and average the losses
            baseline = torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(
                        sample_class_weights(log_class_weights.exp()),
                        n_classes))
                for i in range(n_samples)
            ], -1).mean(-1)
    return baseline
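
# Hypothetical usage of get_baseline (argument values are illustrative): with
# deterministic=True the baseline is a weighted average of the losses at the
# top-n_samples categories, with those weights renormalized to sum to one;
# with deterministic=False it is a Monte Carlo average over categories sampled
# with replacement.
def _baseline_example(conditional_loss_fun, log_class_weights):
    b_det = get_baseline(conditional_loss_fun, log_class_weights,
                         n_samples=3, deterministic=True)
    b_mc = get_baseline(conditional_loss_fun, log_class_weights,
                        n_samples=3, deterministic=False)
    return b_det, b_mc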
def get_raoblackwell_ps_loss(conditional_loss_fun, log_class_weights,
                             topk, grad_estimator,
                             grad_estimator_kwargs={'grad_estimator_kwargs': None},
                             epoch=None,
                             data=None):
    """
    Returns a pseudo-loss such that the gradient obtained by calling
    pseudo_loss.backward() is an unbiased estimate of the gradient of the
    true loss.

    Parameters
    ----------
    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding
        matrix (batchsize x n_categories) and return a vector of losses,
        one for each observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights
    topk : Integer
        The number of categories to sum over
    grad_estimator : function
        A function that returns the pseudo-loss, that is, the loss which
        gives a gradient estimator when .backward() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict
        Keyword arguments to the gradient estimator
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an
        annealing rate)
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIL and RELAX, which
        have a data-dependent baseline)

    Returns
    -------
    ps_loss :
        A value such that ps_loss.backward() gives an estimate of the
        gradient. In general, ps_loss might not equal the actual loss.
    """
    # class weights from the variational distribution
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_class_weights.detach())

    # this is the indicator C_k
    concentrated_mask, topk_domain, seq_tensor = \
        get_concentrated_mask(class_weights, topk)
    concentrated_mask = concentrated_mask.float().detach()

    ############################
    # compute the summed term
    summed_term = 0.0

    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]

        # compute gradient estimate
        grad_summed = \
            grad_estimator(conditional_loss_fun, log_class_weights,
                           class_weights, seq_tensor,
                           z_sample=summed_indx,
                           epoch=epoch,
                           data=data,
                           **grad_estimator_kwargs)

        # sum
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + (grad_summed * summed_weights).sum()

    ############################
    # compute sampled term
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1, keepdim=True)

    if not (topk == class_weights.shape[1]):
        # if we didn't sum over everything, sample from the remaining terms

        # class weights conditioned on being in the diffuse set
        conditional_class_weights = (class_weights + 1e-12) * \
            (1 - concentrated_mask) / (sampled_weight + 1e-12)

        # sample from the conditional distribution
        conditional_z_sample = sample_class_weights(conditional_class_weights)

        grad_sampled = grad_estimator(conditional_loss_fun, log_class_weights,
                                      class_weights, seq_tensor,
                                      z_sample=conditional_z_sample,
                                      epoch=epoch,
                                      data=data,
                                      **grad_estimator_kwargs)
    else:
        grad_sampled = 0.

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term
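
# Hypothetical end-to-end usage of get_raoblackwell_ps_loss with the
# REINFORCE+ estimator above as `grad_estimator`. The toy loss and shapes are
# illustrative, and the example assumes the helpers referenced above
# (get_concentrated_mask, sample_class_weights, get_one_hot_encoding_from_int)
# are available in the same module.
def _rao_blackwell_example():
    torch.manual_seed(0)
    batch_size, n_classes = 4, 10
    logits = torch.randn(batch_size, n_classes, requires_grad=True)
    log_class_weights = torch.log_softmax(logits, dim=-1)

    def _toy_loss(one_hot_z):
        # per-observation loss as a function of the one-hot category
        return ((one_hot_z.float() *
                 torch.arange(n_classes, dtype=torch.float)).sum(-1) - 3.).abs()

    ps_loss = get_raoblackwell_ps_loss(
        _toy_loss, log_class_weights, topk=3,
        grad_estimator=reinforce_w_double_sample_baseline)
    ps_loss.backward()  # unbiased estimate of the gradient w.r.t. `logits`
    return logits.grad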
def get_partial_marginal_loss(f_z, log_q, alpha, topk,
                              use_baseline=True,
                              use_term_one_baseline=True):
    # class weights from the variational distribution
    assert np.all(log_q.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_q.detach())
    # assert np.all(np.abs(class_weights.cpu().sum(1).numpy() - 1.0) < 1e-6), \
    #     np.max(np.abs(class_weights.cpu().sum(1).numpy() - 1.0))

    # this is the indicator C_\alpha
    concentrated_mask, topk_domain, seq_tensor = \
        get_concentrated_mask(class_weights, alpha, topk)
    concentrated_mask = concentrated_mask.float().detach()

    # the summed term
    summed_term = 0.0
    for i in range(topk):
        summed_indx = topk_domain[:, i]
        f_z_i = f_z(summed_indx)
        assert len(f_z_i) == log_q.shape[0]

        log_q_i = log_q[seq_tensor, summed_indx]

        if use_term_one_baseline and use_baseline:
            # print('using term 1 baseline')
            z_sample2 = common_utils.sample_class_weights(class_weights)
            baseline = f_z(z_sample2).detach()
        else:
            baseline = 0.0

        reinforce_grad_sample = \
            common_utils.get_reinforce_grad_sample(f_z_i, log_q_i, baseline)

        summed_term = summed_term + \
            ((reinforce_grad_sample + f_z_i) *
             class_weights[seq_tensor, summed_indx].squeeze()).sum()

    # sampled term
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1, keepdim=True)

    if not (topk == class_weights.shape[1]):
        conditional_class_weights = \
            class_weights * (1 - concentrated_mask) / sampled_weight

        conditional_z_sample = \
            common_utils.sample_class_weights(conditional_class_weights)
        # print(conditional_z_sample)

        # just for my own sanity ...
        assert np.all((1 - concentrated_mask)[seq_tensor, conditional_z_sample].cpu().numpy() == 1.), \
            'sampled_weight {}'.format(sampled_weight)

        f_z_i_sample = f_z(conditional_z_sample)
        log_q_i_sample = log_q[seq_tensor, conditional_z_sample]

        if use_baseline:
            if not use_term_one_baseline:
                # print('using alt. covariate')
                # sample from the conditional distribution instead
                z_sample2 = \
                    common_utils.sample_class_weights(conditional_class_weights)
                baseline2 = f_z(z_sample2).detach()
            else:
                z_sample2 = common_utils.sample_class_weights(class_weights)
                baseline2 = f_z(z_sample2).detach()
        else:
            baseline2 = 0.0

        sampled_term = common_utils.get_reinforce_grad_sample(
            f_z_i_sample, log_q_i_sample, baseline2) + f_z_i_sample
    else:
        sampled_term = 0.

    return (sampled_term * sampled_weight.squeeze()).sum() + summed_term
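
# Hypothetical usage of get_partial_marginal_loss. Note that here `f_z` takes
# a vector of integer category indices (not a one-hot matrix) and returns a
# per-observation loss; the `alpha` value below is a placeholder, since its
# exact semantics are set by get_concentrated_mask, which is defined elsewhere.
def _partial_marginal_example():
    torch.manual_seed(0)
    batch_size, n_classes = 4, 10
    logits = torch.randn(batch_size, n_classes, requires_grad=True)
    log_q = torch.log_softmax(logits, dim=-1)

    def _f_z(z_indices):
        # toy per-observation loss as a function of the sampled category index
        return (z_indices.float() - 3.).abs()

    ps_loss = get_partial_marginal_loss(_f_z, log_q, alpha=0.0, topk=3,
                                        use_baseline=True,
                                        use_term_one_baseline=True)
    ps_loss.backward()
    return logits.grad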