def reinforce_unordered(conditional_loss_fun, log_class_weights,
                        class_weights_detached, seq_tensor, z_sample,
                        epoch, data,
                        n_samples=1,
                        baseline_separate=False,
                        baseline_n_samples=1,
                        baseline_deterministic=False,
                        baseline_constant=None):
    # Sample without replacement using the Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # log_R_s, log_R_ss = compute_log_R(log_p)
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)
        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        elif log_p.size(-1) > 1:
            # Compute built-in baseline
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp()
                       * costs[:, None, :]).sum(-1)
        else:
            bl_vals = 0.  # No bl
        adv = costs - bl_vals

    # Also add the costs (with the unordered estimator) in case there is
    # a direct dependency on the parameters
    loss = ((log_p + log_R_s).exp() * adv.detach()
            + (log_p + log_R_s).exp().detach() * costs).sum(-1)
    return loss
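
# Illustrative sketch (not used by the estimator above): the first step of
# reinforce_unordered is the Gumbel top-k trick, which draws k distinct
# categories without replacement by perturbing the log-probabilities with
# i.i.d. Gumbel(0, 1) noise and taking the top-k indices. Names and shapes
# below are made up for illustration.
def _gumbel_topk_sketch(batch_size=4, n_classes=10, k=3):
    log_probs = torch.log_softmax(torch.randn(batch_size, n_classes), dim=-1)
    g = Gumbel(log_probs, torch.ones_like(log_probs)).rsample()  # perturbed log-probs
    _, ind = g.topk(k, dim=-1)             # k distinct class indices per row
    log_p = log_probs.gather(-1, ind)      # log-probabilities of the selected classes
    return ind, log_p
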
def reinforce_sum_and_sample(conditional_loss_fun, log_class_weights,
                             class_weights_detached, seq_tensor, z_sample,
                             epoch, data,
                             n_samples=1,
                             baseline_separate=False,
                             baseline_n_samples=1,
                             baseline_deterministic=False,
                             rao_blackwellize=False):
    # Sample without replacement using the Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        else:
            assert baseline_n_samples < n_samples
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples - 1].logsumexp(-1)).exp().detach()
            bl_vals = (log_p[:, :baseline_n_samples - 1].exp()
                       * costs[:, :baseline_n_samples - 1]).sum(-1) \
                + bl_sampled_weight * costs[:, baseline_n_samples - 1]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly
    # more than 1 - 1e-5, in which case we can simply sum all the terms and
    # the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = use_exact == 0

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]
    if rao_blackwellize:
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap, bl_vals_ap)

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # If they are summed we can simply concatenate, but for consistency
    # it is best to place them in order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
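
# Sketch of the log1mexp helper referenced above (the repo's actual
# implementation lives elsewhere and may differ): a numerically stable
# log(1 - exp(x)) for x <= 0, following Maechler (2012) -- use
# log(-expm1(x)) when x > -log(2) and log1p(-exp(x)) otherwise.
def _log1mexp_sketch(x):
    # -0.6931471805599453 = -log(2); switch formulas there for stability
    return torch.where(x > -0.6931471805599453,
                       torch.log(-torch.expm1(x)),
                       torch.log1p(-torch.exp(x)))
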
def get_raoblackwell_ps_loss(conditional_loss_fun, log_class_weights,
                             topk, sample_topk, grad_estimator,
                             grad_estimator_kwargs={'grad_estimator_kwargs': None},
                             epoch=None,
                             data=None):
    """
    Returns a pseudo-loss such that the gradient obtained by calling
    pseudo_loss.backward() is an unbiased estimate of the gradient of
    the true loss.

    Parameters
    ----------
    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding
        matrix (batchsize x n_categories) and return a vector of losses,
        one for each observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights
    topk : int
        The number of categories to sum over
    sample_topk : bool
        If True, the summed categories are drawn without replacement using
        the Gumbel top-k trick (and the (k+1)-th perturbed category is used
        as the sample); otherwise the top-k categories by weight are summed.
    grad_estimator : function
        A function that returns the pseudo-loss, that is, the loss which
        gives a gradient estimator when .backward() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict
        Keyword arguments to the gradient estimator
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an
        annealing rate)
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIL and RELAX, which
        have a data-dependent baseline)

    Returns
    -------
    ps_loss :
        A value such that ps_loss.backward() returns an estimate of the
        gradient. In general, ps_loss might not equal the actual loss.
    """
    # class weights from the variational distribution
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_class_weights.detach())

    if sample_topk:
        # perturb the log_class_weights with Gumbel noise
        phi = log_class_weights.detach()
        g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
        _, ind = g_phi.topk(topk + 1, dim=-1)
        topk_domain = ind[..., :-1]
        concentrated_mask = torch.zeros_like(phi).scatter(
            -1, topk_domain, 1).detach()
        sample_ind = ind[..., -1]  # the last perturbed category is used as the actual sample
        seq_tensor = torch.arange(class_weights.size(0), dtype=torch.long,
                                  device=class_weights.device)
    else:
        # this is the indicator C_k
        concentrated_mask, topk_domain, seq_tensor = \
            get_concentrated_mask(class_weights, topk)
        concentrated_mask = concentrated_mask.float().detach()

    ############################
    # compute the summed term
    summed_term = 0.0

    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]

        # compute gradient estimate
        grad_summed = grad_estimator(
            conditional_loss_fun, log_class_weights, class_weights, seq_tensor,
            z_sample=summed_indx, epoch=epoch, data=data,
            **grad_estimator_kwargs)

        # sum
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + (grad_summed * summed_weights).sum()

    ############################
    # compute sampled term
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1, keepdim=True)

    if not (topk == class_weights.shape[1]):
        # if we didn't sum everything, we sample from the remaining terms
        if not sample_topk:
            # class weights conditioned on being in the diffuse set
            conditional_class_weights = (class_weights + 1e-12) * \
                (1 - concentrated_mask) / (sampled_weight + 1e-12)

            # sample from the conditional distribution
            conditional_z_sample = sample_class_weights(conditional_class_weights)
        else:
            conditional_z_sample = sample_ind  # we have already sampled it

        grad_sampled = grad_estimator(
            conditional_loss_fun, log_class_weights, class_weights, seq_tensor,
            z_sample=conditional_z_sample, epoch=epoch, data=data,
            **grad_estimator_kwargs)
    else:
        grad_sampled = 0.

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term
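
# Hypothetical usage sketch (not part of the repo): it shows the calling
# convention that get_raoblackwell_ps_loss expects from grad_estimator and
# conditional_loss_fun. toy_reinforce below is a plain REINFORCE pseudo-loss
# written against that interface; the real estimators live in baselines_lib
# and differ (baselines, control variates, etc.). The target tensor and the
# dimensions are made up for illustration.
def _rao_blackwell_usage_sketch():
    batch_size, n_categories = 8, 20
    logits = torch.randn(batch_size, n_categories, requires_grad=True)
    log_class_weights = torch.log_softmax(logits, dim=-1)
    target = torch.rand(batch_size, n_categories)

    def toy_conditional_loss(one_hot_z):
        # one_hot_z: (batch_size x n_categories) one-hot encoding of z
        return ((one_hot_z.float() - target) ** 2).sum(dim=-1)

    def toy_reinforce(conditional_loss_fun, log_class_weights,
                      class_weights_detached, seq_tensor, z_sample,
                      epoch, data):
        # Score-function (REINFORCE) pseudo-loss for the given categories,
        # plus the cost itself in case it depends directly on the parameters.
        n_classes = log_class_weights.shape[1]
        cost = conditional_loss_fun(
            get_one_hot_encoding_from_int(z_sample, n_classes))
        log_q = log_class_weights[seq_tensor, z_sample]
        return cost.detach() * log_q + cost

    ps_loss = get_raoblackwell_ps_loss(
        toy_conditional_loss, log_class_weights,
        topk=5, sample_topk=False,
        grad_estimator=toy_reinforce, grad_estimator_kwargs={},
        epoch=0, data=None)
    ps_loss.backward()  # logits.grad now holds an unbiased gradient estimate
    return logits.grad
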