def forward_importance_sampled(galaxy_vae, resid_image, recon_vars, was_on,
                                seq_tensor, use_importance_sample = True):

    assert len(was_on) == len(seq_tensor)
    assert resid_image.shape[0] == len(seq_tensor)
    assert recon_vars.shape[0] == len(seq_tensor)

    # get class weights
    class_weights = galaxy_vae.get_pixel_probs(resid_image, recon_vars)
    assert_diff(class_weights.sum(1), torch.Tensor([1.0]).to(device), tol = 1e-4)

    # get importance sampling weights
    if use_importance_sample:
        attn_offset = galaxy_vae.attn_offset
        prob_off = class_weights.detach()[:, -1].view(-1, 1)
        importance_weights = \
            get_importance_weights(resid_image.detach(), attn_offset, prob_off)
    else:
        importance_weights = class_weights.detach()

    # sample from the importance weights; observations that were never "on"
    # are assigned the final category (the off class)
    a_sample = common_utils.sample_class_weights(importance_weights.detach())
    a_sample[was_on == 0.] = importance_weights.shape[-1] - 1

    recon_mean, recon_var, is_on, kl_z = galaxy_vae.sample_conditional_a(
        resid_image, recon_vars, a_sample)

    return (recon_mean, recon_var, is_on, kl_z, importance_weights,
            class_weights, a_sample)
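
# Reference sketch (not part of the original code): sample_class_weights from
# common_utils is assumed to draw one category index per row of a
# (batchsize x n_categories) weight matrix. A minimal version could look like
# the following, though the project's actual implementation may differ.
def _sample_class_weights_sketch(class_weights):
    # one index per row, drawn proportionally to the (normalized) row weights
    return torch.multinomial(class_weights, num_samples=1).squeeze(-1)
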
def reinforce_w_double_sample_baseline(conditional_loss_fun, log_class_weights,
                                       class_weights_detached, seq_tensor,
                                       z_sample, epoch, data,
                                       grad_estimator_kwargs=None):
    # This is what we call REINFORCE+ in our paper: the baseline is the
    # conditional loss evaluated at a second, independent sample from the
    # discrete distribution.

    assert len(z_sample) == log_class_weights.shape[0]

    # compute the loss at the given categories (z_sample)
    n_classes = log_class_weights.shape[1]
    one_hot_z_sample = get_one_hot_encoding_from_int(z_sample, n_classes)
    conditional_loss_fun_i = conditional_loss_fun(one_hot_z_sample)
    assert len(conditional_loss_fun_i) == log_class_weights.shape[0]

    # get log class_weights
    log_class_weights_i = log_class_weights[seq_tensor, z_sample]

    # get baseline
    z_sample2 = sample_class_weights(class_weights_detached)
    one_hot_z_sample2 = get_one_hot_encoding_from_int(z_sample2, n_classes)
    baseline = conditional_loss_fun(one_hot_z_sample2)

    return get_reinforce_grad_sample(conditional_loss_fun_i,
                                     log_class_weights_i,
                                     baseline) + conditional_loss_fun_i
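
# Reference sketch (not part of the original code): get_reinforce_grad_sample
# is assumed to build the standard score-function (REINFORCE) pseudo-loss,
# with the baseline-subtracted loss detached so that gradients only flow
# through the log class weight; callers above add the conditional loss back
# separately. The actual helper may differ in detail.
def _get_reinforce_grad_sample_sketch(conditional_loss, log_class_weight,
                                      baseline=0.0):
    return (conditional_loss - baseline).detach() * log_class_weight
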
def reinforce_wr(conditional_loss_fun,
                 log_class_weights,
                 class_weights_detached,
                 seq_tensor,
                 z_sample,
                 epoch,
                 data,
                 n_samples=1,
                 baseline_constant=None):
    # z_sample should be a vector of categories; it is ignored here because
    # this function draws its own samples (with replacement).
    # conditional_loss_fun takes a one-hot encoding of z and returns the
    # per-observation loss.

    assert len(z_sample) == log_class_weights.shape[0]

    # Sample with replacement
    ind = torch.stack([
        sample_class_weights(class_weights_detached) for _ in range(n_samples)
    ], -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(z_i, n_classes))
        for z_i in ind.t()
    ], -1)

    if baseline_constant is None:
        # leave-one-out baseline: requires n_samples >= 2, otherwise the
        # scaling below divides by zero
        adv = (costs - costs.mean(-1, keepdim=True)) * n_samples / (n_samples - 1)
    else:
        adv = costs - baseline_constant

    # Add the costs in case there is a direct dependency on the parameters
    return (adv.detach() * log_p + costs).mean(-1)
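
# Sanity-check sketch (not part of the original code): the scaling by
# n_samples / (n_samples - 1) above turns the centred costs into leave-one-out
# advantages, i.e. each sample's cost minus the mean of the other samples.
def _check_leave_one_out_identity():
    costs = torch.tensor([1.0, 2.0, 4.0])
    S = costs.shape[0]
    adv_scaled = (costs - costs.mean()) * S / (S - 1)
    adv_loo = torch.stack([costs[k] - costs[torch.arange(S) != k].mean()
                           for k in range(S)])
    assert torch.allclose(adv_scaled, adv_loo)
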
def get_importance_sampled_loss(f_z, log_q,
                                importance_weights = None,
                                use_baseline = True):
    # class weights from the variational distribution
    assert (log_q.detach() < 0).all()
    class_weights = torch.exp(log_q.detach())

    # why is the tolerance so bad?
    assert_diff(class_weights.sum(1), torch.Tensor([1.0]).to(device), tol = 1e-4)

    seq_tensor = torch.arange(class_weights.shape[0])

    # sample from conditional distribution
    if importance_weights is not None:
        assert importance_weights.shape[0] == log_q.shape[0]
        assert importance_weights.shape[1] == log_q.shape[1]
        # why is the tolerance so bad?
        assert_diff(importance_weights.sum(1), torch.Tensor([1.0]).to(device), tol = 1e-4)

        # sample from importance weights
        z_sample = common_utils.sample_class_weights(importance_weights)

        # reweight accordingly
        importance_weighting = class_weights[seq_tensor, z_sample] / \
                                    importance_weights[seq_tensor, z_sample]
        assert len(importance_weighting) == len(z_sample)
    else:
        z_sample = common_utils.sample_class_weights(class_weights)
        importance_weighting = 1.0

    f_z_i_sample = f_z(z_sample)
    assert len(f_z_i_sample) == len(z_sample)
    log_q_i_sample = log_q[seq_tensor, z_sample]

    if use_baseline:
        z_sample2 = common_utils.sample_class_weights(class_weights)
        baseline = f_z(z_sample2).detach()
    else:
        baseline = 0.0

    reinforce_grad_sample = common_utils.get_reinforce_grad_sample(
        f_z_i_sample, log_q_i_sample, baseline) + f_z_i_sample
    assert len(reinforce_grad_sample) == len(z_sample)

    return (reinforce_grad_sample * importance_weighting).sum()
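
# Sanity-check sketch (not part of the original code): the reweighting by
# class_weights / importance_weights above is the standard importance-sampling
# correction, E_q[f(z)] = E_r[(q(z) / r(z)) f(z)], illustrated on a toy
# three-category example.
def _check_importance_reweighting():
    torch.manual_seed(0)
    q = torch.tensor([0.2, 0.5, 0.3])   # target (variational) distribution
    r = torch.tensor([0.4, 0.4, 0.2])   # proposal / importance distribution
    f = torch.tensor([1.0, -2.0, 3.0])  # arbitrary per-category values
    exact = (q * f).sum()
    idx = torch.multinomial(r, 100000, replacement=True)
    mc = ((q / r)[idx] * f[idx]).mean()
    assert (mc - exact).abs() < 0.05
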
def get_baseline(conditional_loss_fun, log_class_weights, n_samples,
                 deterministic):

    with torch.no_grad():
        n_classes = log_class_weights.size(-1)
        if deterministic:
            # Compute baseline deterministically as normalized weighted topk elements
            weights, ind = log_class_weights.topk(n_samples, -1)
            weights = weights - weights.logsumexp(-1,
                                                  keepdim=True)  # normalize
            baseline = (torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(ind[:, i], n_classes))
                for i in range(n_samples)
            ], -1) * weights.exp()).sum(-1)
        else:
            # Sample with replacement baseline and compute mean
            baseline = torch.stack([
                conditional_loss_fun(
                    get_one_hot_encoding_from_int(
                        sample_class_weights(log_class_weights.exp()),
                        n_classes)) for i in range(n_samples)
            ], -1).mean(-1)
    return baseline
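
# Reference sketch (not part of the original code): get_one_hot_encoding_from_int
# is assumed to map a vector of integer categories to a float one-hot matrix of
# shape (batchsize x n_categories); the actual helper may differ.
def _get_one_hot_encoding_from_int_sketch(z, n_classes):
    return torch.nn.functional.one_hot(z.long(), num_classes=n_classes).float()
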
def get_raoblackwell_ps_loss(
        conditional_loss_fun,
        log_class_weights,
        topk,
        grad_estimator,
        grad_estimator_kwargs={'grad_estimator_kwargs': None},
        epoch=None,
        data=None):
    """
    Returns a pseudo_loss such that the gradient obtained by calling
    pseudo_loss.backward() is an unbiased estimate of the gradient of the
    true loss.

    Parameters
    ----------
    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding
        matrix (batchsize x n_categories) and return a vector of
        losses, one for each observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights
    topk : int
        The number of categories to sum over explicitly.
    grad_estimator : function
        A function that returns the pseudo-loss, that is, the loss which
        gives a gradient estimator when .backward() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict
        Keyword arguments passed to the gradient estimator.
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an
        annealing rate).
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIL and RELAX, which
        have a data-dependent baseline).

    Returns
    -------
    ps_loss :
        A value such that calling ps_loss.backward() produces an unbiased
        estimate of the gradient of the true loss.
        In general, ps_loss might not equal the actual loss.
    """

    # class weights from the variational distribution
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_class_weights.detach())

    # this is the indicator C_k
    concentrated_mask, topk_domain, seq_tensor = \
        get_concentrated_mask(class_weights, topk)
    concentrated_mask = concentrated_mask.float().detach()

    ############################
    # compute the summed term
    summed_term = 0.0

    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]

        # compute gradient estimate
        grad_summed = grad_estimator(conditional_loss_fun, log_class_weights,
                                     class_weights, seq_tensor,
                                     z_sample=summed_indx,
                                     epoch=epoch,
                                     data=data,
                                     **grad_estimator_kwargs)

        # sum
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + \
                        (grad_summed * summed_weights).sum()

    ############################
    # compute sampled term
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1,
                               keepdim=True)

    if topk != class_weights.shape[1]:
        # if we didn't sum over every category,
        # sample from the remaining ones

        # class weights conditioned on being in the diffuse set
        conditional_class_weights = (class_weights + 1e-12) * \
                    (1 - concentrated_mask)  / (sampled_weight + 1e-12)

        # sample from conditional distribution
        conditional_z_sample = sample_class_weights(conditional_class_weights)

        grad_sampled = grad_estimator(conditional_loss_fun,
                                      log_class_weights,
                                      class_weights,
                                      seq_tensor,
                                      z_sample=conditional_z_sample,
                                      epoch=epoch,
                                      data=data,
                                      **grad_estimator_kwargs)

    else:
        grad_sampled = 0.

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term
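
# Hypothetical usage sketch (not part of the original code): wiring one of the
# estimators above into get_raoblackwell_ps_loss. The toy logits and
# conditional_loss_fun below are illustrative only.
def _example_raoblackwell_usage():
    batch_size, n_classes = 8, 10
    logits = torch.randn(batch_size, n_classes, requires_grad=True)
    log_class_weights = torch.log_softmax(logits, dim=-1)

    def conditional_loss_fun(one_hot_z):
        # toy per-observation loss for a given one-hot assignment
        return (one_hot_z * torch.arange(n_classes, dtype=torch.float)).sum(-1)

    ps_loss = get_raoblackwell_ps_loss(
        conditional_loss_fun, log_class_weights, topk=3,
        grad_estimator=reinforce_w_double_sample_baseline)
    ps_loss.backward()
    # logits.grad now holds an unbiased estimate of the gradient of the
    # expected loss under the categorical distribution defined by logits
    return logits.grad
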
def get_partial_marginal_loss(f_z, log_q, alpha, topk,
                                use_baseline = True,
                                use_term_one_baseline = True):

    # class weights from the variational distribution
    assert np.all(log_q.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_q.detach())
    # assert np.all(np.abs(class_weights.cpu().sum(1).numpy() - 1.0) < 1e-6), \
    #             np.max(np.abs(class_weights.cpu().sum(1).numpy() - 1.0))

    # this is the indicator C_\alpha
    concentrated_mask, topk_domain, seq_tensor = \
        get_concentrated_mask(class_weights, alpha, topk)
    concentrated_mask = concentrated_mask.float().detach()

    # the summed term
    summed_term = 0.0

    for i in range(topk):
        summed_indx = topk_domain[:, i]
        f_z_i = f_z(summed_indx)
        assert len(f_z_i) == log_q.shape[0]

        log_q_i = log_q[seq_tensor, summed_indx]

        if (use_term_one_baseline) and (use_baseline):
            # print('using term 1 baseline')
            z_sample2 = common_utils.sample_class_weights(class_weights)
            baseline = f_z(z_sample2).detach()

        else:
            baseline = 0.0

        reinforce_grad_sample = \
                common_utils.get_reinforce_grad_sample(f_z_i, log_q_i, baseline)

        summed_term = summed_term + \
                        ((reinforce_grad_sample + f_z_i) * \
                        class_weights[seq_tensor, summed_indx].squeeze()).sum()

    # sampled term
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask), dim = 1, keepdim = True)

    if topk != class_weights.shape[1]:
        conditional_class_weights = \
            class_weights * (1 - concentrated_mask) / sampled_weight

        conditional_z_sample = common_utils.sample_class_weights(conditional_class_weights)
        # print(conditional_z_sample)

        # just for my own sanity ...
        assert np.all((1 - concentrated_mask)[seq_tensor, conditional_z_sample].cpu().numpy() == 1.), \
                    'sampled_weight {}'.format(sampled_weight)

        f_z_i_sample = f_z(conditional_z_sample)
        log_q_i_sample = log_q[seq_tensor, conditional_z_sample]

        if use_baseline:
            if not use_term_one_baseline:
                # print('using alt. covariate')
                # sample from the conditional distribution instead
                z_sample2 = common_utils.sample_class_weights(conditional_class_weights)
                baseline2 = f_z(z_sample2).detach()

            else:
                z_sample2 = common_utils.sample_class_weights(class_weights)
                baseline2 = f_z(z_sample2).detach()
        else:
            baseline2 = 0.0

        sampled_term = common_utils.get_reinforce_grad_sample(f_z_i_sample,
                                    log_q_i_sample, baseline2) + f_z_i_sample

    else:
        sampled_term = 0.

    return (sampled_term * sampled_weight.squeeze()).sum() + summed_term
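
# Reference sketch (not part of the original code): get_concentrated_mask is
# assumed to return a binary (batchsize x n_categories) mask marking each row's
# top-k weights, the corresponding top-k indices, and a row-index tensor for
# advanced indexing. This sketches the two-argument variant used by
# get_raoblackwell_ps_loss; the three-argument variant taking alpha above may
# additionally threshold on the weights.
def _get_concentrated_mask_sketch(class_weights, topk):
    mask = torch.zeros_like(class_weights)
    _, topk_domain = class_weights.topk(topk, dim=-1)
    mask.scatter_(-1, topk_domain, 1.0)
    seq_tensor = torch.arange(class_weights.shape[0])
    return mask, topk_domain, seq_tensor
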