import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Gumbel

# Repo-local helpers (get_one_hot_encoding_from_int, sample_class_weights,
# get_reinforce_grad_sample, gumbel_softmax_lib, rb_lib, compute_log_R_O_nfac,
# compute_summed_terms, compute_sum_and_sample_loss, all_perms, log_pl_rec,
# log1mexp, ...) are defined elsewhere in the repository; minimal sketches of
# some of them appear below.


def reinforce_w_double_sample_baseline(
        conditional_loss_fun, log_class_weights,
        class_weights_detached, seq_tensor, z_sample,
        epoch, data, grad_estimator_kwargs=None):
    # This is what we call REINFORCE+ in our paper,
    # where we use a second, independent sample from the discrete
    # distribution as a baseline.
    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class weights of the sampled categories
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # get baseline from an independent second sample
    z_sample2 = sample_class_weights(class_weights_detached)
    one_hot_z_sample2 = get_one_hot_encoding_from_int(z_sample2, n_classes)
    baseline = conditional_loss_fun(one_hot_z_sample2)

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline) + conditional_loss_fun_i

def reinforce(conditional_loss_fun, log_class_weights,
              class_weights_detached, seq_tensor, z_sample,
              epoch, data, grad_estimator_kwargs=None):
    # z_sample should be a vector of categories;
    # conditional_loss_fun is a function that takes in a one-hot encoding
    # of z and returns the loss
    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class weights of the sampled categories
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline=0.0) + conditional_loss_fun_i

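# The estimators in this file assume a few repo-local helpers. The following
# is a minimal sketch of plausible implementations, inferred from how they
# are used here (log_class_weights of shape (batch, n_classes)); the
# repository's actual versions may differ.

def get_one_hot_encoding_from_int(z, n_classes):
    # Convert a LongTensor of category indices, shape (batch,), into a
    # float one-hot matrix of shape (batch, n_classes).
    return F.one_hot(z, n_classes).float()


def sample_class_weights(class_weights):
    # Draw one categorical sample per row of the (batch, n_classes)
    # probability matrix; returns a LongTensor of shape (batch,).
    return torch.multinomial(class_weights, num_samples=1).squeeze(-1)


def get_reinforce_grad_sample(conditional_loss, log_class_weights_i,
                              baseline=0.0):
    # Score-function (REINFORCE) surrogate: the advantage is detached so
    # that gradients flow only through log q(z | x).
    return (conditional_loss - baseline).detach() * log_class_weights_i
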
def get_full_loss(conditional_loss_fun, class_weights):
    """
    Returns the loss averaged over the class weights.

    Parameters
    ----------
    conditional_loss_fun : function
        Function that takes as input a one-hot encoding of the discrete
        random variable z and outputs the loss conditional on z.
    class_weights : torch.Tensor
        Array of class weights, with each row corresponding to a datapoint
        and each column to that datapoint's weight for one class.

    Returns
    -------
    full_loss : float
        The loss averaged over the class weights of the discrete
        random variable.
    """
    full_loss = 0.0
    for i in range(class_weights.shape[1]):
        i_rep = (torch.ones(class_weights.shape[0]) * i).type(torch.LongTensor)
        one_hot_i = get_one_hot_encoding_from_int(i_rep, class_weights.shape[1])
        conditional_loss = conditional_loss_fun(one_hot_i)
        assert len(conditional_loss) == class_weights.shape[0]

        full_loss = full_loss + class_weights[:, i] * conditional_loss

    return full_loss.sum()

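# Toy usage sketch of get_full_loss (the cost table and loss function here
# are hypothetical, purely for illustration). It enumerates every class and
# weights each conditional loss by its probability; the sampling-based
# estimators in this file approximate the gradient of this quantity.

def _demo_get_full_loss():
    cost_table = torch.tensor([[0.1, 1.0, 2.0],
                               [3.0, 0.5, 0.2]])
    conditional_loss_fun = lambda one_hot_z: (one_hot_z * cost_table).sum(-1)
    class_weights = torch.softmax(torch.randn(2, 3), dim=-1)
    return get_full_loss(conditional_loss_fun, class_weights)
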
def reinforce_wr(conditional_loss_fun, log_class_weights,
                 class_weights_detached, seq_tensor, z_sample,
                 epoch, data, n_samples=1, baseline_constant=None):
    # z_sample should be a vector of categories, but is ignored since this
    # function draws its own samples;
    # conditional_loss_fun is a function that takes in a one-hot encoding
    # of z and returns the loss
    assert len(z_sample) == log_class_weights.shape[0]

    # Sample with replacement
    ind = torch.stack([
        sample_class_weights(class_weights_detached)
        for _ in range(n_samples)
    ], -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    if baseline_constant is None:
        # Leave-one-out mean baseline; the n / (n - 1) factor corrects for
        # each sample's own cost appearing in the mean, so this branch
        # requires n_samples > 1.
        adv = (costs - costs.mean(-1, keepdim=True)) * n_samples / (n_samples - 1)
    else:
        adv = costs - baseline_constant

    # Add the costs in case there is a direct dependency on the parameters
    return (adv.detach() * log_p + costs).mean(-1)

def get_class_label_cross_entropy(log_class_weights, labels):
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    assert log_class_weights.shape[0] == len(labels)
    n_classes = log_class_weights.shape[1]

    return torch.sum(
        -log_class_weights * get_one_hot_encoding_from_int(labels, n_classes),
        dim=1)

def relax(conditional_loss_fun, class_weights,
          class_weights_detached, seq_tensor, z_sample,
          epoch, data,
          temperature=torch.Tensor([1.0]),
          eta=1.,
          c_phi=lambda x: torch.Tensor([0.0])):
    # With the default c_phi value, this is just REBAR;
    # RELAX adds a learned component c_phi.

    log_class_weights = torch.log(class_weights)

    # sample gumbel
    gumbel_sample = log_class_weights + \
        gumbel_softmax_lib.sample_gumbel(log_class_weights.size())

    # get hard z
    _, z_sample = gumbel_sample.max(dim=-1)
    n_classes = log_class_weights.shape[1]
    z_one_hot = get_one_hot_encoding_from_int(z_sample, n_classes)

    temperature = torch.clamp(temperature, 0.01, 5.0)

    # get softmax z
    z_softmax = F.softmax(gumbel_sample / temperature[0], dim=-1)

    # conditional softmax z
    z_cond_softmax = gumbel_softmax_lib.gumbel_softmax_conditional_sample(
        log_class_weights, temperature[0], z_one_hot)

    # get log class weights of the sampled categories
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # reinforce term
    f_z_hard = conditional_loss_fun(z_one_hot.detach())
    f_z_softmax = conditional_loss_fun(z_softmax)
    f_z_cond_softmax = conditional_loss_fun(z_cond_softmax)

    # baseline terms
    c_softmax = c_phi(z_softmax).squeeze()
    z_cond_softmax_detached = gumbel_softmax_lib.gumbel_softmax_conditional_sample(
        log_class_weights, temperature[0], z_one_hot, detach=True)
    c_cond_softmax = c_phi(z_cond_softmax_detached).squeeze()

    reinforce_term = \
        (f_z_hard - eta * (f_z_cond_softmax - c_cond_softmax)).detach() * \
        log_class_weights_i + \
        log_class_weights_i * eta * c_cond_softmax

    # correction term
    correction_term = eta * (f_z_softmax - c_softmax) - \
        eta * (f_z_cond_softmax - c_cond_softmax)

    return reinforce_term + correction_term + f_z_hard

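# relax depends on gumbel_softmax_lib for Gumbel noise and for the relaxed
# sample conditioned on the hard sample b, as used by REBAR/RELAX. Below is
# a minimal sketch of what those helpers plausibly compute, following the
# conditional Gumbel reparameterization from the REBAR paper (Tucker et al.,
# 2017); the repository's actual versions may differ in details such as
# epsilon handling, so treat these as illustrative only.

def _sample_gumbel_sketch(shape, eps=1e-20):
    # Standard Gumbel noise via inverse CDF: -log(-log U).
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)


def _gumbel_softmax_conditional_sample_sketch(log_class_weights, temperature,
                                              z_one_hot, detach=False,
                                              eps=1e-20):
    # Sample the relaxed variable conditioned on the hard sample b:
    #   g_b = -log(-log v_b)
    #   g_k = -log(-log(v_k) / theta_k - log v_b)   for k != b
    theta = log_class_weights.exp()
    if detach:
        theta = theta.detach()
    v = torch.rand_like(theta)
    log_v_b = (torch.log(v + eps) * z_one_hot).sum(-1, keepdim=True)
    g_other = -torch.log(-torch.log(v + eps) / (theta + eps) - log_v_b + eps)
    g_b = -torch.log(-log_v_b + eps)
    g = z_one_hot * g_b + (1.0 - z_one_hot) * g_other
    return F.softmax(g / temperature, dim=-1)
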
def get_baseline(conditional_loss_fun, log_class_weights, n_samples,
                 deterministic):
    with torch.no_grad():
        n_classes = log_class_weights.size(-1)
        if deterministic:
            # Compute baseline deterministically as the normalized, weighted
            # top-k elements
            weights, ind = log_class_weights.topk(n_samples, -1)
            weights = weights - weights.logsumexp(-1, keepdim=True)  # normalize
            baseline = (torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(ind[:, i], n_classes))
                for i in range(n_samples)
            ], -1) * weights.exp()).sum(-1)
        else:
            # Sample with replacement and use the mean cost as baseline
            baseline = torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(
                        sample_class_weights(log_class_weights.exp()),
                        n_classes))
                for i in range(n_samples)
            ], -1).mean(-1)
    return baseline

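# Usage sketch of get_baseline, reusing the hypothetical cost table and
# conditional_loss_fun from the _demo_get_full_loss sketch above:
#
#   bl_det = get_baseline(conditional_loss_fun, class_weights.log(),
#                         n_samples=2, deterministic=True)
#   # -> probability-weighted cost of each row's two most likely classes
#   bl_mc = get_baseline(conditional_loss_fun, class_weights.log(),
#                        n_samples=2, deterministic=False)
#   # -> mean cost over two independent samples drawn with replacement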
def reinforce_unordered(conditional_loss_fun, log_class_weights,
                        class_weights_detached, seq_tensor, z_sample,
                        epoch, data, n_samples=1,
                        baseline_separate=False, baseline_n_samples=1,
                        baseline_deterministic=False, baseline_constant=None):
    # Sample without replacement using the Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # log_R_s, log_R_ss = compute_log_R(log_p)
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)
        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same baseline for all samples, so add a dimension
            bl_vals = bl_vals[:, None]
        elif log_p.size(-1) > 1:
            # Compute the built-in (leave-one-out) baseline
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp()
                       * costs[:, None, :]).sum(-1)
        else:
            bl_vals = 0.  # No baseline
        adv = costs - bl_vals

    # Also add the costs (weighted by the unordered estimator) in case there
    # is a direct dependency on the parameters
    loss = ((log_p + log_R_s).exp() * adv.detach()
            + (log_p + log_R_s).exp().detach() * costs).sum(-1)
    return loss

def stratified(conditional_loss_fun, log_class_weights,
               class_weights_detached, seq_tensor, z_sample,
               epoch, data, n_samples=1, systematic=False):
    # z_sample should be a vector of categories, but is ignored since this
    # function draws its own samples;
    # conditional_loss_fun is a function that takes in a one-hot encoding
    # of z and returns the loss
    assert len(z_sample) == log_class_weights.shape[0]

    cum_p = class_weights_detached.cumsum(-1)
    cum_p = cum_p / cum_p[:, -1, None]  # Normalize

    def stratified_sample(cum_p, i, k, u=None):
        # Inverse-CDF sample from stratum i of k, i.e. with the uniform
        # rescaled into [i/k, (i+1)/k)
        assert i < k
        if u is None:
            u = torch.rand(cum_p.size(0))
        us = (i + u) / k
        # Broadcast against the per-row CDF to look up the class index
        return (us[:, None] >= cum_p).sum(-1)

    # Common random numbers if systematic sampling
    u = torch.rand(cum_p.size(0)) if systematic else None

    # One sample per stratum
    ind = torch.stack(
        [stratified_sample(cum_p, i, n_samples, u) for i in range(n_samples)],
        -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    adv = costs
    # Add the costs in case there is a direct dependency on the parameters
    return (adv.detach() * log_p + costs).mean(-1)

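# Standalone sketch of the stratum lookup above, with hypothetical numbers:
# the uniform u is rescaled into stratum [i/k, (i+1)/k) and looked up in the
# per-row CDF by counting how many CDF entries it passes.

def _demo_stratified_lookup():
    cum_p = torch.tensor([[0.2, 0.5, 1.0]])  # CDF over 3 classes, batch of 1
    u = torch.tensor([0.5])
    k = 2
    for i in range(k):
        us = (i + u) / k                       # 0.25, then 0.75
        idx = (us[:, None] >= cum_p).sum(-1)   # class 1, then class 2
        print(i, idx)
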
def get_rb_loss(self, image, grad_estimator,
                grad_estimator_kwargs={'grad_estimator_kwargs': None},
                epoch=None, topk=0, n_samples=1, true_pixel_2d=None):
    if true_pixel_2d is None:
        log_class_weights = self.pixel_attention(image)
        class_weights = torch.exp(log_class_weights)
    else:
        class_weights = self._get_class_weights_from_pixel_2d(true_pixel_2d)
        log_class_weights = torch.log(class_weights)

    # kl term
    kl_pixel_probs = (class_weights * log_class_weights).sum()

    self.cache_id_conv_image(image)
    f_pixel = lambda i: self.get_loss_cond_pixel_1d(i, image,
                                                    use_cached_image=True)

    avg_pm_loss = 0.0
    # TODO: n_samples would be more elegant as an
    # argument to get_partial_marginal_loss
    for k in range(n_samples):
        pm_loss = rb_lib.get_raoblackwell_ps_loss(f_pixel, log_class_weights,
                                                  topk, grad_estimator,
                                                  grad_estimator_kwargs,
                                                  epoch, data=image)
        avg_pm_loss += pm_loss / n_samples

    map_locations = torch.argmax(log_class_weights.detach(), dim=1)
    one_hot_map_locations = get_one_hot_encoding_from_int(map_locations,
                                                          self.full_slen**2)
    map_cond_losses = f_pixel(one_hot_map_locations).sum()

    return avg_pm_loss + image.shape[0] * kl_pixel_probs, map_cond_losses

def nvil(conditional_loss_fun, log_class_weights,
         class_weights_detached, seq_tensor, z_sample,
         epoch, data, baseline_nn):
    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class weights of the sampled categories
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # get input-dependent baseline from the baseline network
    baseline = baseline_nn(data).squeeze()

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline=baseline) + \
        conditional_loss_fun_i + \
        (conditional_loss_fun_i.detach() - baseline)**2

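# nvil expects baseline_nn to map the raw data to one scalar baseline per
# datapoint; the squared-error term above trains it to track the conditional
# loss (which is detached, so that term does not affect the encoder). A
# minimal sketch of a plausible baseline network, assuming input that can be
# flattened to dimension d; the architecture here is purely illustrative.

class _BaselineNNSketch(torch.nn.Module):
    def __init__(self, d, hidden=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(d, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 1))

    def forward(self, data):
        # Returns shape (batch, 1); nvil squeezes it to (batch,)
        return self.net(data)
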
# Module-level cache so the baseline can be shared across calls on the same
# data batch (see baseline_reuse below)
rf_cache = {'prev_data': None, 'prev_baseline': None}


# Extended variant of reinforce_w_double_sample_baseline with a cached
# and/or deterministic baseline
def reinforce_w_double_sample_baseline(
        conditional_loss_fun, log_class_weights,
        class_weights_detached, seq_tensor, z_sample,
        epoch, data,
        baseline_reuse=False, baseline_n_samples=1,
        baseline_deterministic=False, grad_estimator_kwargs=None):
    # This is what we call REINFORCE+ in our paper,
    # where we use a second, independent sample from the discrete
    # distribution as a baseline
    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class weights of the sampled categories
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    if baseline_reuse and data is rf_cache['prev_data']:
        # This way we use the same baseline for all topk terms and the sample
        baseline = rf_cache['prev_baseline']
    else:
        baseline = get_baseline(conditional_loss_fun, log_class_weights,
                                baseline_n_samples, baseline_deterministic)
        # get baseline
        # z_sample2 = sample_class_weights(class_weights_detached)
        # one_hot_z_sample2 = get_one_hot_encoding_from_int(z_sample2, n_classes)
        # baseline = conditional_loss_fun(one_hot_z_sample2)

    if baseline_reuse:
        rf_cache['prev_baseline'] = baseline
        rf_cache['prev_data'] = data

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline) + conditional_loss_fun_i

def get_one_hot_encoding_from_label(self, label):
    return get_one_hot_encoding_from_int(label, self.n_classes)

def reinforce_sum_and_sample(conditional_loss_fun, log_class_weights,
                             class_weights_detached, seq_tensor, z_sample,
                             epoch, data, n_samples=1,
                             baseline_separate=False, baseline_n_samples=1,
                             baseline_deterministic=False,
                             rao_blackwellize=False):
    # Sample without replacement using the Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)
    log_p = log_class_weights.gather(-1, ind)

    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_sample, n_classes))
        for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same baseline for all samples, so add a dimension
            bl_vals = bl_vals[:, None]
        else:
            assert baseline_n_samples < n_samples
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples - 1].logsumexp(-1)).exp().detach()
            bl_vals = (log_p[:, :baseline_n_samples - 1].exp()
                       * costs[:, :baseline_n_samples - 1]).sum(-1) \
                + bl_sampled_weight * costs[:, baseline_n_samples - 1]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly
    # more than 1 - 1e-5, in which case we can simply sum all the terms and
    # the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = ~use_exact

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]
    if rao_blackwellize:
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap,
                                                  bl_vals_ap)

        # Compute probabilities for the permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # The losses are summed, so we could simply concatenate, but for
    # consistency it is best to place them back in their original order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses

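# reinforce_sum_and_sample relies on log1mexp, a numerically stable
# log(1 - exp(x)) for x <= 0, used above to obtain the probability mass left
# over after the summed terms. Below is a standard stable implementation
# (a sketch; the repository's own version may differ).

import math


def _log1mexp_sketch(x):
    # Use log(-expm1(x)) near zero and log1p(-exp(x)) for very negative x,
    # switching at -log(2) to keep both branches well conditioned.
    return torch.where(x > -math.log(2),
                       torch.log(-torch.expm1(x)),
                       torch.log1p(-torch.exp(x)))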