import numpy as np
import torch
from torch.distributions import Gumbel

# The helper functions used below (get_one_hot_encoding_from_int, compute_log_R_O_nfac,
# get_baseline, log1mexp, compute_summed_terms, all_perms, compute_sum_and_sample_loss,
# log_pl_rec, get_concentrated_mask, sample_class_weights) are defined elsewhere in
# this codebase.


def reinforce_unordered(conditional_loss_fun,
                        log_class_weights,
                        class_weights_detached,
                        seq_tensor,
                        z_sample,
                        epoch,
                        data,
                        n_samples=1,
                        baseline_separate=False,
                        baseline_n_samples=1,
                        baseline_deterministic=False,
                        baseline_constant=None):
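    # REINFORCE estimator over a set of categories sampled without replacement.
    # Several arguments (class_weights_detached, seq_tensor, z_sample, epoch,
    # data) are unused here; they are accepted only so that this function
    # matches the common grad_estimator interface used by
    # get_raoblackwell_ps_loss below.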

    # Sample without replacement using Gumbel top-k trick
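    # (Adding i.i.d. Gumbel(0, 1) noise to the log-probabilities and taking the
    # top-k perturbed values yields k distinct categories distributed as a
    # sample without replacement from the categorical distribution.)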
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
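    # costs will have shape (batch, n_samples): the conditional loss of each
    # observation under each of the sampled categories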
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z, n_classes))
        for z in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # log_R_s, log_R_ss = compute_log_R(log_p)
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)
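        # log_R_s (batch x n_samples) and log_R_ss (batch x n_samples x
        # n_samples) appear to be the importance corrections for the
        # without-replacement sample: below, sample i is weighted by
        # p_i * exp(log_R_s_i) and the leave-one-out baseline for sample i
        # weights cost j by p_j * exp(log_R_ss_ij).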

        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        elif log_p.size(-1) > 1:
            # Compute built in baseline
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp() *
                       costs[:, None, :]).sum(-1)
        else:
            bl_vals = 0.  # No bl
        adv = costs - bl_vals
    # Also add the (reweighted) costs themselves in case they depend directly on
    # the parameters; the weights are detached so this term only propagates
    # gradients through the costs
    loss = ((log_p + log_R_s).exp() * adv.detach() +
            (log_p + log_R_s).exp().detach() * costs).sum(-1)
    return loss
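

# Usage sketch for reinforce_unordered (hypothetical names: `log_q` is a
# (batch x n_categories) tensor of log class weights with gradients attached,
# and `loss_fn` maps a one-hot (batch x n_categories) tensor to a vector of
# per-example losses). The unused interface arguments can simply be None:
#
#     surrogate = reinforce_unordered(loss_fn, log_q, log_q.exp().detach(),
#                                     seq_tensor=None, z_sample=None,
#                                     epoch=None, data=None, n_samples=4)
#     surrogate.sum().backward()

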
def reinforce_sum_and_sample(conditional_loss_fun,
                             log_class_weights,
                             class_weights_detached,
                             seq_tensor,
                             z_sample,
                             epoch,
                             data,
                             n_samples=1,
                             baseline_separate=False,
                             baseline_n_samples=1,
                             baseline_deterministic=False,
                             rao_blackwellize=False):
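    # Sum-and-sample style estimator: n_samples categories are drawn without
    # replacement; compute_sum_and_sample_loss (defined elsewhere) turns the
    # drawn log-probabilities, costs and baseline into a surrogate loss, and
    # rao_blackwellize additionally averages that surrogate over all orderings
    # of the drawn samples.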

    # Sample without replacement using Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()

    _, ind = g_phi.topk(n_samples, -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z, n_classes))
        for z in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        else:
            assert baseline_n_samples < n_samples
            # Built-in sum-and-sample baseline: sum the first
            # baseline_n_samples - 1 costs exactly, and weight the next cost by
            # the remaining probability mass (1 - sum of the summed probabilities)
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples - 1].logsumexp(-1)).exp().detach()
            bl_vals = ((log_p[:, :baseline_n_samples - 1].exp()
                        * costs[:, :baseline_n_samples - 1]).sum(-1)
                       + bl_sampled_weight * costs[:, baseline_n_samples - 1])
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the summed probability of the drawn
    # samples is roughly more than 1 - 1e-5, in which case we can simply sum
    # all the terms and the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = ~use_exact

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])
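    # For the remaining rows the drawn samples do not cover (nearly) all of the
    # probability mass, so the sampled surrogate below is used instead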

    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]

    if rao_blackwellize:
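        # Rao-Blackwellize over the order of the samples: evaluate the surrogate
        # for every permutation of the drawn samples and average, weighting each
        # permutation by its (normalized) probability of being the sampling order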
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap,
                                                  bl_vals_ap)

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # Since the losses are eventually summed we could simply concatenate them,
    # but for consistency place them back in their original batch order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
Example #3
def get_raoblackwell_ps_loss(
        conditional_loss_fun,
        log_class_weights,
        topk,
        sample_topk,
        grad_estimator,
        grad_estimator_kwargs={'grad_estimator_kwargs': None},
        epoch=None,
        data=None):
    """
    Returns a pseudo_loss such that the gradient obtained by calling
    pseudo_loss.backward() is an unbiased estimate of the gradient of the
    true loss

    Parameters
    ----------
    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding
        matrix (batchsize x n_categories) and return a vector of
        losses, one for each observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights
    topk : int
        The number of categories whose contribution is summed exactly
    sample_topk : bool
        If True, the summed categories are drawn stochastically without
        replacement (by Gumbel perturbation of the log weights) instead of
        being selected deterministically via get_concentrated_mask
    grad_estimator : function
        A function that returns the pseudo loss, that is, the loss which
        gives a gradient estimator when .backward() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict
        keyword arguments to gradient estimator
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an annealing rate)
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIL and RELAX, which have
        a data-dependent baseline)

    Returns
    -------
    ps_loss : torch.Tensor
        a scalar such that calling ps_loss.backward() produces an unbiased
        estimate of the gradient.
        In general, ps_loss might not equal the actual loss.
    """

    # class weights from the variational distribution
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_class_weights.detach())

    if sample_topk:
        # perturb the log_class_weights
        phi = log_class_weights.detach()
        g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()

        _, ind = g_phi.topk(topk + 1, dim=-1)

        topk_domain = ind[..., :-1]
        concentrated_mask = torch.zeros_like(phi).scatter(-1, topk_domain,
                                                          1).detach()
        sample_ind = ind[..., -1]  # The last index is used as the actual sample
        seq_tensor = torch.arange(class_weights.size(0),
                                  dtype=torch.long,
                                  device=class_weights.device)
    else:
        # this is the indicator C_k
        concentrated_mask, topk_domain, seq_tensor = \
            get_concentrated_mask(class_weights, topk)
        concentrated_mask = concentrated_mask.float().detach()

    ############################
    # compute the summed term
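    # The contribution of the topk categories in the concentrated set is summed
    # exactly below; the remaining ("diffuse") probability mass is handled by a
    # single sample further down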
    summed_term = 0.0

    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]

        # compute gradient estimate for this category
        grad_summed = grad_estimator(conditional_loss_fun, log_class_weights,
                                     class_weights, seq_tensor,
                                     z_sample=summed_indx,
                                     epoch=epoch,
                                     data=data,
                                     **grad_estimator_kwargs)

        # sum, weighting each term by its class probability
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + (grad_summed * summed_weights).sum()

    ############################
    # compute sampled term
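    # sampled_weight is the total probability mass outside the concentrated
    # (top-k) set; it reweights the single sampled gradient term so that the
    # overall estimator remains unbiased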
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1,
                               keepdim=True)

    if topk != class_weights.shape[1]:
        # if we didn't sum everything,
        # we sample from the remaining terms

        if not sample_topk:
            # class weights conditioned on being in the diffuse set
            conditional_class_weights = (
                (class_weights + 1e-12) * (1 - concentrated_mask)
                / (sampled_weight + 1e-12))

            # sample from conditional distribution
            conditional_z_sample = sample_class_weights(
                conditional_class_weights)
        else:
            conditional_z_sample = sample_ind  # We have already sampled it

        grad_sampled = grad_estimator(conditional_loss_fun,
                                      log_class_weights,
                                      class_weights,
                                      seq_tensor,
                                      z_sample=conditional_z_sample,
                                      epoch=epoch,
                                      data=data,
                                      **grad_estimator_kwargs)

    else:
        grad_sampled = 0.

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term
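

# Usage sketch (hypothetical names: `loss_fn` maps a one-hot batch to per-example
# losses and `log_q` is the batch x n_categories tensor of log class weights).
# `reinforce_simple` is a stand-in for one of the grad_estimator functions from
# baselines_lib: a plain REINFORCE surrogate that also keeps any direct
# dependency of the cost on the parameters.
#
#     def reinforce_simple(conditional_loss_fun, log_class_weights,
#                          class_weights_detached, seq_tensor, z_sample,
#                          epoch, data):
#         n_classes = log_class_weights.shape[1]
#         one_hot = get_one_hot_encoding_from_int(z_sample, n_classes)
#         cost = conditional_loss_fun(one_hot)
#         log_q_z = log_class_weights[seq_tensor, z_sample]
#         return cost.detach() * log_q_z + cost
#
#     ps_loss = get_raoblackwell_ps_loss(loss_fn, log_q, topk=3,
#                                        sample_topk=False,
#                                        grad_estimator=reinforce_simple,
#                                        grad_estimator_kwargs={})
#     ps_loss.backward()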