# Shared imports for the snippets below (added; the original listing omits them).
# Helpers such as get_one_hot_encoding_from_int, sample_class_weights,
# get_reinforce_grad_sample, compute_log_R_O_nfac, compute_summed_terms,
# compute_sum_and_sample_loss, all_perms, log_pl_rec and log1mexp, and the
# modules gumbel_softmax_lib and rb_lib, come from the surrounding repository;
# a few are sketched further down where explicitly marked as assumptions.
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Gumbel


def reinforce_w_double_sample_baseline(conditional_loss_fun, log_class_weights,
                                       class_weights_detached, seq_tensor,
                                       z_sample, epoch, data,
                                       grad_estimator_kwargs=None):
    # This is what we call REINFORCE+ in our paper,
    # where we use a second, independent sample from the discrete distribution
    # to use as a baseline

    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class_weights
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # get baseline
    z_sample2 = sample_class_weights(class_weights_detached)
    one_hot_z_sample2 = get_one_hot_encoding_from_int(z_sample2, n_classes)
    baseline = conditional_loss_fun(one_hot_z_sample2)

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline) + conditional_loss_fun_i
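# NOTE (added): `get_reinforce_grad_sample` is used throughout this listing but
# not shown. The body below is a minimal sketch inferred from its call sites
# (the repository's actual implementation may differ): it returns the
# score-function surrogate whose gradient is (f(z) - baseline) * grad log q(z).
def get_reinforce_grad_sample(conditional_loss_fun_i, log_class_weights_i,
                              baseline=0.0):
    # the loss and baseline are treated as constants; gradients flow only
    # through the log class weights of the sampled categories
    return (conditional_loss_fun_i.detach() - baseline) * log_class_weights_i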
def reinforce(conditional_loss_fun,
              log_class_weights,
              class_weights_detached,
              seq_tensor,
              z_sample,
              epoch,
              data,
              grad_estimator_kwargs=None):
    # z_sample should be a vector of categories
    # conditional_loss_fun is a function that takes in a one hot encoding
    # of z and returns the loss

    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class_weights
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                    log_class_weights_i, baseline = 0.0) + \
                        conditional_loss_fun_i
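# Usage sketch (added, toy setup): every estimator in this listing returns a
# per-datapoint surrogate loss; summing it and calling backward() yields an
# unbiased estimate of the gradient of E_q[f(z)] with respect to the
# class-weight parameters. Assumes 4 datapoints and 3 categories.
def _reinforce_usage_sketch():
    logits = torch.randn(4, 3, requires_grad=True)
    log_q = F.log_softmax(logits, dim=-1)
    q_detached = log_q.exp().detach()
    seq_tensor = torch.arange(4)
    z_sample = sample_class_weights(q_detached)
    # toy conditional loss: the index of the sampled category, as a float
    f = lambda one_hot: (one_hot * torch.arange(3.0)).sum(-1)
    surrogate = reinforce(f, log_q, q_detached, seq_tensor, z_sample,
                          epoch=0, data=None)
    surrogate.sum().backward()  # logits.grad now holds the REINFORCE gradient
    return logits.grad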
def get_full_loss(conditional_loss_fun, class_weights):
    """
    Returns the loss averaged over the class weights.

    Parameters
    ----------
    conditional_loss_fun : function
        Function that takes input a one-hot encoding of the discrete random
        variable z and outputs the loss conditional on z
    class_weights : torch.Tensor
        Array of class weights, with each row corresponding to a datapoint,
        each column corresponding to its weight
    Returns
    -------
    full_loss : float
        The loss averaged over the class weights of the discrete random variable
    """

    full_loss = 0.0

    for i in range(class_weights.shape[1]):

        i_rep = (torch.ones(class_weights.shape[0]) * i).type(torch.LongTensor)
        one_hot_i = get_one_hot_encoding_from_int(i_rep,
                                                  class_weights.shape[1])

        conditional_loss = conditional_loss_fun(one_hot_i)
        assert len(conditional_loss) == class_weights.shape[0]

        full_loss = full_loss + class_weights[:, i] * conditional_loss

    return full_loss.sum()
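# Worked sketch (added): get_full_loss computes the exact expectation
# sum_i q_i * f(i) for each datapoint and sums over the batch. With 2
# datapoints, 2 classes, uniform weights and f(i) = i this gives
# 2 * (0.5 * 0 + 0.5 * 1) = 1.0.
def _full_loss_sketch():
    class_weights = torch.full((2, 2), 0.5)
    f = lambda one_hot: (one_hot * torch.arange(2.0)).sum(-1)
    return get_full_loss(f, class_weights)  # tensor(1.)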
def reinforce_wr(conditional_loss_fun,
                 log_class_weights,
                 class_weights_detached,
                 seq_tensor,
                 z_sample,
                 epoch,
                 data,
                 n_samples=1,
                 baseline_constant=None):
    # z_sample should be a vector of categories, but is ignored as this function samples itself
    # conditional_loss_fun is a function that takes in a one hot encoding
    # of z and returns the loss

    assert len(z_sample) == log_class_weights.shape[0]

    # Sample with replacement
    ind = torch.stack([
        sample_class_weights(class_weights_detached) for _ in range(n_samples)
    ], -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    if baseline_constant is None:
        # leave-one-out baseline: subtracting the mean and rescaling by
        # n / (n - 1) equals subtracting the mean of the other n - 1 costs
        adv = (costs - costs.mean(-1, keepdim=True)) * n_samples / (n_samples - 1)
    else:
        adv = costs - baseline_constant

    # Add the costs in case there is a direct dependency on the parameters
    return (adv.detach() * log_p + costs).mean(-1)
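# NOTE (added): `sample_class_weights`, used above, is not reproduced in this
# listing. A minimal sketch consistent with its use (one sampled category per
# row of a matrix of class probabilities) could be:
def _sample_class_weights_sketch(class_weights):
    # returns a LongTensor of shape (batch,) with one category index per row
    return torch.multinomial(class_weights, num_samples=1).squeeze(-1)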
def get_class_label_cross_entropy(log_class_weights, labels):
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    assert log_class_weights.shape[0] == len(labels)
    n_classes = log_class_weights.shape[1]

    return torch.sum(
        -log_class_weights * \
        get_one_hot_encoding_from_int(labels, n_classes), dim = 1)
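# NOTE (added): `get_one_hot_encoding_from_int`, used throughout, is also not
# reproduced here. A minimal sketch with the behaviour the callers assume
# (LongTensor of category indices -> float one-hot matrix) is:
def _one_hot_from_int_sketch(z, n_classes):
    return F.one_hot(z, n_classes).float()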
def relax(conditional_loss_fun,
          class_weights,
          class_weights_detached,
          seq_tensor,
          z_sample,
          epoch,
          data,
          temperature=torch.Tensor([1.0]),
          eta=1.,
          c_phi=lambda x: torch.Tensor([0.0])):
    # with the default c_phi value, this is just REBAR
    # RELAX adds a learned component c_phi

    log_class_weights = torch.log(class_weights)
    # sample gumbel
    gumbel_sample = log_class_weights + \
        gumbel_softmax_lib.sample_gumbel(log_class_weights.size())

    # get hard z
    _, z_sample = gumbel_sample.max(dim=-1)
    n_classes = log_class_weights.shape[1]
    z_one_hot = get_one_hot_encoding_from_int(z_sample, n_classes)
    temperature = torch.clamp(temperature, 0.01, 5.0)

    # get softmax z
    z_softmax = F.softmax(gumbel_sample / temperature[0], dim=-1)

    # conditional softmax z
    z_cond_softmax = \
        gumbel_softmax_lib.gumbel_softmax_conditional_sample(\
            log_class_weights, temperature[0], z_one_hot)

    # get log class_weights
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # reinforce term
    f_z_hard = conditional_loss_fun(z_one_hot.detach())
    f_z_softmax = conditional_loss_fun(z_softmax)
    f_z_cond_softmax = conditional_loss_fun(z_cond_softmax)

    # baseline terms
    c_softmax = c_phi(z_softmax).squeeze()
    z_cond_softmax_detached = \
        gumbel_softmax_lib.gumbel_softmax_conditional_sample(\
            log_class_weights, temperature[0], z_one_hot, detach = True)
    c_cond_softmax = c_phi(z_cond_softmax_detached).squeeze()

    reinforce_term = \
        (f_z_hard - eta * (f_z_cond_softmax - c_cond_softmax)).detach() * \
                        log_class_weights_i + \
                        log_class_weights_i * eta * c_cond_softmax

    # correction term
    correction_term = eta * (f_z_softmax - c_softmax) - \
                        eta * (f_z_cond_softmax - c_cond_softmax)

    return reinforce_term + correction_term + f_z_hard
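# Usage sketch (added, hypothetical helper): with the default c_phi the
# estimator above reduces to REBAR; RELAX instead learns c_phi, e.g. a small
# network mapping the relaxed (softmax) sample to a scalar control variate,
# trained to minimize the variance of the resulting gradient estimate.
def _make_c_phi_sketch(n_classes, hidden=16):
    return torch.nn.Sequential(
        torch.nn.Linear(n_classes, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 1))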
def get_baseline(conditional_loss_fun, log_class_weights, n_samples,
                 deterministic):

    with torch.no_grad():
        n_classes = log_class_weights.size(-1)
        if deterministic:
            # Compute baseline deterministically as normalized weighted topk elements
            weights, ind = log_class_weights.topk(n_samples, -1)
            weights = weights - weights.logsumexp(-1,
                                                  keepdim=True)  # normalize
            baseline = (torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(ind[:, i], n_classes))
                for i in range(n_samples)
            ], -1) * weights.exp()).sum(-1)
        else:
            # Sample with replacement baseline and compute mean
            baseline = torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(
                        sample_class_weights(log_class_weights.exp()),
                        n_classes)) for i in range(n_samples)
            ], -1).mean(-1)
    return baseline
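# Usage sketch (added, hypothetical wrapper): the deterministic variant
# reweights the losses of the top-k classes by their renormalized
# probabilities, while the stochastic variant is a plain Monte Carlo average
# over k independently resampled categories.
def _baseline_usage_sketch(conditional_loss_fun, log_class_weights):
    bl_topk = get_baseline(conditional_loss_fun, log_class_weights,
                           n_samples=2, deterministic=True)
    bl_mc = get_baseline(conditional_loss_fun, log_class_weights,
                         n_samples=2, deterministic=False)
    return bl_topk, bl_mc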
def reinforce_unordered(conditional_loss_fun,
                        log_class_weights,
                        class_weights_detached,
                        seq_tensor,
                        z_sample,
                        epoch,
                        data,
                        n_samples=1,
                        baseline_separate=False,
                        baseline_n_samples=1,
                        baseline_deterministic=False,
                        baseline_constant=None):

    # Sample without replacement using Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # log_R_s, log_R_ss = compute_log_R(log_p)
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)

        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        elif log_p.size(-1) > 1:
            # Compute built in baseline
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp() *
                       costs[:, None, :]).sum(-1)
        else:
            bl_vals = 0.  # No bl
        adv = costs - bl_vals
    # Also add the costs (with the unordered estimator) in case there is a direct dependency on the parameters
    loss = ((log_p + log_R_s).exp() * adv.detach() +
            (log_p + log_R_s).exp().detach() * costs).sum(-1)
    return loss
def stratified(conditional_loss_fun,
               log_class_weights,
               class_weights_detached,
               seq_tensor,
               z_sample,
               epoch,
               data,
               n_samples=1,
               systematic=False):
    # z_sample should be a vector of categories, but is ignored as this function samples itself
    # conditional_loss_fun is a function that takes in a one hot encoding
    # of z and returns the loss

    assert len(z_sample) == log_class_weights.shape[0]

    cum_p = class_weights_detached.cumsum(-1)
    cum_p = cum_p / cum_p[:, -1, None]  # Normalize

    def stratified_sample(cum_p, i, k, u=None):
        # inverse-CDF sample from stratum [i/k, (i+1)/k) of each row's distribution
        assert i < k
        if u is None:
            u = torch.rand(cum_p.size(0))
        us = (i + u) / k
        # broadcast each row's uniform against its cumulative probabilities
        return (us.unsqueeze(-1) >= cum_p).sum(-1)

    u = torch.rand(
        cum_p.size(0)
    ) if systematic else None  # Common random numbers if systematic sampling

    # Sample with replacement
    ind = torch.stack(
        [stratified_sample(cum_p, i, n_samples, u) for i in range(n_samples)],
        -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    adv = costs

    # Add the costs in case there is a direct dependency on the parameters
    return (adv.detach() * log_p + costs).mean(-1)
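# Worked sketch (added): with one row of probabilities (0.2, 0.5, 0.3) and
# k = 2 strata, the common uniform u = 0.4 gives u'_0 = 0.2 and u'_1 = 0.7,
# i.e. categories 1 and 2; each stratum covers a disjoint part of the CDF.
def _stratified_demo_sketch():
    cum_p = torch.tensor([[0.2, 0.7, 1.0]])
    u = torch.tensor([0.4])
    return [(((i + u) / 2).unsqueeze(-1) >= cum_p).sum(-1) for i in range(2)]
    # -> [tensor([1]), tensor([2])]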
    def get_rb_loss(self,
                    image,
                    grad_estimator,
                    grad_estimator_kwargs={'grad_estimator_kwargs': None},
                    epoch=None,
                    topk=0,
                    n_samples=1,
                    true_pixel_2d=None):

        if true_pixel_2d is None:
            log_class_weights = self.pixel_attention(image)
            class_weights = torch.exp(log_class_weights)
        else:
            class_weights = self._get_class_weights_from_pixel_2d(
                true_pixel_2d)
            log_class_weights = torch.log(class_weights)

        # kl term
        kl_pixel_probs = (class_weights * log_class_weights).sum()

        self.cache_id_conv_image(image)
        f_pixel = lambda i : self.get_loss_cond_pixel_1d(i, image, \
                            use_cached_image = True)

        avg_pm_loss = 0.0
        # TODO: n_samples would be more elegant as an
        # argument to get_partial_marginal_loss
        for k in range(n_samples):
            pm_loss = rb_lib.get_raoblackwell_ps_loss(f_pixel,
                                                      log_class_weights,
                                                      topk,
                                                      grad_estimator,
                                                      grad_estimator_kwargs,
                                                      epoch,
                                                      data=image)

            avg_pm_loss += pm_loss / n_samples

        map_locations = torch.argmax(log_class_weights.detach(), dim=1)
        one_hot_map_locations = get_one_hot_encoding_from_int(map_locations, \
                                                            self.full_slen**2)
        map_cond_losses = f_pixel(one_hot_map_locations).sum()

        return avg_pm_loss + image.shape[0] * kl_pixel_probs, map_cond_losses
def nvil(conditional_loss_fun, log_class_weights, class_weights_detached,
         seq_tensor, z_sample, epoch, data, baseline_nn):

    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class_weights
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # get baseline
    baseline = baseline_nn(data).squeeze()

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                    log_class_weights_i, baseline = baseline) + \
                        conditional_loss_fun_i + \
                        (conditional_loss_fun_i.detach() - baseline)**2
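# Usage sketch (added, hypothetical helper): `baseline_nn` is a data-dependent
# neural baseline; the appended squared-error term trains it towards the
# expected conditional loss. Any module mapping `data` to shape (batch, 1)
# fits the interface, e.g. for inputs flattened to d features:
def _make_nvil_baseline_sketch(d, hidden=64):
    return torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(d, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, 1))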
def reinforce_w_double_sample_baseline(conditional_loss_fun, log_class_weights,
                                       class_weights_detached, seq_tensor,
                                       z_sample, epoch, data,
                                       baseline_reuse=False,
                                       baseline_n_samples=1,
                                       baseline_deterministic=False,
                                       grad_estimator_kwargs=None):
    # This is what we call REINFORCE+ in our paper,
    # where we use a second, independent sample from the discrete distribution
    # to use as a baseline

    assert len(z_sample) == log_class_weights.shape[0]

    # compute loss from those categories
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class_weights
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    if baseline_reuse and data is rf_cache['prev_data']:
        # This way we use the same baseline for all topk and the sample
        baseline = rf_cache['prev_baseline']
    else:
        baseline = get_baseline(conditional_loss_fun, log_class_weights,
                                baseline_n_samples, baseline_deterministic)
        # get baseline
        # z_sample2 = sample_class_weights(class_weights_detached)
        # one_hot_z_sample2 = get_one_hot_encoding_from_int(z_sample2, n_classes)
        # baseline = conditional_loss_fun(one_hot_z_sample2)
        if baseline_reuse:
            rf_cache['prev_baseline'] = baseline
            rf_cache['prev_data'] = data

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline) + conditional_loss_fun_i
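# NOTE (added): `rf_cache` above is not defined in this listing; presumably it
# is a module-level dict used to reuse a single baseline across the summed
# top-k terms and the sampled term of the Rao-Blackwellized estimator. A
# minimal assumed initialization:
rf_cache = {'prev_data': None, 'prev_baseline': None}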
    def get_one_hot_encoding_from_label(self, label):
        return get_one_hot_encoding_from_int(label, self.n_classes)
def reinforce_sum_and_sample(conditional_loss_fun,
                             log_class_weights,
                             class_weights_detached,
                             seq_tensor,
                             z_sample,
                             epoch,
                             data,
                             n_samples=1,
                             baseline_separate=False,
                             baseline_n_samples=1,
                             baseline_deterministic=False,
                             rao_blackwellize=False):

    # Sample without replacement using Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()

    _, ind = g_phi.topk(n_samples, -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        else:
            assert baseline_n_samples < n_samples
            k = baseline_n_samples - 1
            # weight of the final, sampled term: 1 minus the total probability
            # of the first k (summed) terms
            bl_sampled_weight = log1mexp(log_p[:, :k].logsumexp(-1)).exp().detach()
            bl_vals = (log_p[:, :k].exp() * costs[:, :k]).sum(-1) \
                      + bl_sampled_weight * costs[:, k]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly more than 1 - 1e-5
    # in which case we can simply sum all the terms and the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = use_exact == 0

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]

    if rao_blackwellize:
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap,
                                                  bl_vals_ap)

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # If they are summed we can simply concatenate but for consistency it is best to place them in order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
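# NOTE (added): `log1mexp`, used in the baseline computation above, is not
# reproduced here. A standard numerically stable implementation of
# log(1 - exp(x)) for x <= 0 is:
def _log1mexp_sketch(x):
    # use log(-expm1(x)) near zero and log1p(-exp(x)) otherwise (Maechler, 2012)
    return torch.where(x > -0.6931471805599453,  # -log(2)
                       torch.log(-torch.expm1(x)),
                       torch.log1p(-torch.exp(x)))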