Example #1
def calc_var(log_weight, args, snis=None, all_sample_mean=True):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition [batch, 1, K partitions]
        all_sample_mean : averages over the batch dimension if True
        snis : optionally pass precomputed SNIS weights to avoid recomputation
    Returns:
        Variance across importance samples at each beta (2nd derivative of log Z_β)
    """
    if log_weight.dim() < 3:
        log_weight = log_weight.unsqueeze(-1)

    if snis is None:
        # args may be an args object with .partition, or the partition tensor itself
        partition = getattr(args, 'partition', args)
        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp_ = torch.sum(snis * log_weight, dim=1)
    exp2 = torch.sum(snis * torch.pow(log_weight, 2), dim=1)

    to_return = exp2 - torch.pow(exp_, 2)

    # VM: May have to switch to_return to E[(X-EX)(X-EX)] form, had numerical issues in the past
    assert not torch.isnan(to_return).any(), "Nan in calc_var() - switch to E[(X-EX)(X-EX)] form for numerical stability"

    return torch.mean(to_return, dim=0) if all_sample_mean else to_return
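All of these snippets assume `import torch` and a helper `exponentiate_and_normalize` (imported from ml_helpers in Example #7) that is never shown. A minimal sketch of what it presumably computes, namely self-normalized importance (SNIS) weights via a numerically stable softmax:

import torch


def exponentiate_and_normalize(values, dim=0):
    # exp(values) normalized to sum to 1 along `dim`; subtracting the
    # logsumexp first avoids overflow when exponentiating
    return torch.exp(values - torch.logsumexp(values, dim=dim, keepdim=True))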
Example #2
def calc_fourth(log_weight, args, snis=None, all_sample_mean=True):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition [batch, 1, K partitions]
        all_sample_mean : averages over the batch dimension if True
        snis : optionally pass precomputed SNIS weights to avoid recomputation
    Returns:
        Fourth derivative of log Z_β at each beta (the fourth cumulant of log w under π_β)
    """
    if log_weight.dim() < 3:
        log_weight = log_weight.unsqueeze(-1)

    if snis is None:
        # args may be an args object with .partition, or the partition tensor itself
        partition = getattr(args, 'partition', args)
        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp = torch.sum(snis * log_weight, dim=1)
    exp2 = torch.sum(snis * torch.pow(log_weight, 2), dim=1)
    exp3 = torch.sum(snis * torch.pow(log_weight, 3), dim=1)
    exp4 = torch.sum(snis * torch.pow(log_weight, 4), dim=1)

    # fourth cumulant in terms of raw moments:
    # k4 = m4 - 4*m1*m3 - 3*m2^2 + 12*m2*m1^2 - 6*m1^4
    to_return = exp4 - 4*exp*exp3 - 3*torch.pow(exp2, 2) \
        + 12*exp2*torch.pow(exp, 2) - 6*torch.pow(exp, 4)
    return torch.mean(to_return, dim=0) if all_sample_mean else to_return
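The expression above is the fourth cumulant written in raw moments. A quick self-contained check (with made-up uniform weights) against the central-moment form k4 = E[(X-μ)⁴] - 3·Var(X)²:

import torch

torch.manual_seed(0)
x = torch.randn(1000)
w = torch.full_like(x, 1.0 / x.numel())  # uniform SNIS weights

m1, m2 = (w * x).sum(), (w * x**2).sum()
m3, m4 = (w * x**3).sum(), (w * x**4).sum()
k4_raw = m4 - 4*m1*m3 - 3*m2**2 + 12*m2*m1**2 - 6*m1**4

c = x - m1                                            # centered form
k4_central = (w * c**4).sum() - 3 * ((w * c**2).sum())**2
assert torch.allclose(k4_raw, k4_central, atol=1e-4)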
Example #3
def calc_exp(log_weight, args, all_sample_mean=True, snis=None):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition [batch, 1, K partitions]
        all_sample_mean : True averages over the batch dimension
        snis : optionally pass precomputed SNIS weights to avoid recomputation

    TODO: replace for cleaner integration into the codebase (pulled directly from Rob's code)
    """
    if log_weight.dim() < 3:
        log_weight = log_weight.unsqueeze(-1)

    if snis is None:
        # args may be an args object with .partition, or the partition tensor itself
        partition = getattr(args, 'partition', args)
        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp = torch.sum(snis * log_weight, dim=1)
    return torch.mean(exp, dim=0) if all_sample_mean else exp
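A hypothetical call, passing the partition tensor directly in place of an args object (shapes as documented above; `calc_exp` and `exponentiate_and_normalize` as defined in these snippets):

import torch

log_weight = torch.randn(32, 10, 1)                    # [batch, samples, 1]
partition = torch.linspace(0, 1, 11).view(1, 1, -1)    # [1, 1, K] β points

expectations = calc_exp(log_weight, partition)         # E_{π_β}[log w] per β
print(expectations.shape)                              # torch.Size([11])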
Example #4
def compute_tvo_loss(log_weight, log_p, log_q, args):
    """Args:
        log_weight: tensor of shape [batch_size, num_particles]
        log_p: tensor of shape [batch_size, num_particles]
        log_q: tensor of shape [batch_size, num_particles]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero and
            partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
        num_particles: int
        integration: left, right or trapz

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data

    """
    partition = args.partition
    num_particles = args.S
    integration = args.integration

    if log_weight.dim() < 3:
        log_weight = log_weight.unsqueeze(-1)

    heated_log_weight = log_weight * partition
    heated_normalized_weight = exponentiate_and_normalize(
        heated_log_weight, dim=1)

    log_p = log_p.unsqueeze(-1)
    log_q = log_q.unsqueeze(-1)

    thermo_logp = partition * log_p + \
        (1 - partition) * log_q

    wf = heated_normalized_weight * log_weight
    w_detached = heated_normalized_weight.detach()

    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    # covariance (under π_β) between log w and the geometric-mixture log density,
    # which supplies the score-function part of the TVO gradient estimator
    cov = correction * torch.sum(
        w_detached * (log_weight - torch.sum(wf, dim=1, keepdim=True)).detach() *
        (thermo_logp - torch.sum(thermo_logp * w_detached, dim=1, keepdim=True)),
        dim=1)

    multiplier = _get_multiplier(partition, integration)

    loss = -torch.mean(torch.sum(
        multiplier * (cov + torch.sum(
            w_detached * log_weight, dim=1)),
        dim=1))

    return loss
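`compute_tvo_loss` relies on `_get_multiplier` to turn the partition into per-β quadrature weights. That helper is not shown in these snippets; a plausible sketch consistent with the 'left', 'right' and 'trapz' options named in the docstring (an assumption, not necessarily the repository's exact code):

import torch

def _get_multiplier(partition, integration):
    # quadrature weights dβ for integrating E_{π_β}[log w] over [0, 1]
    multiplier = torch.zeros_like(partition)
    deltas = partition[..., 1:] - partition[..., :-1]
    if integration == 'left':
        multiplier[..., :-1] = deltas
    elif integration == 'right':
        multiplier[..., 1:] = deltas
    elif integration == 'trapz':
        multiplier[..., :-1] += 0.5 * deltas
        multiplier[..., 1:] += 0.5 * deltas
    else:
        raise ValueError(f'unknown integration rule: {integration}')
    return multiplier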
Example #5
def get_total_beta_exp(model, args, betas, S):
    """Average E_{π_β}[log w] over the training set at each β in betas."""
    with torch.no_grad():
        beta_exps = []
        for obs in args.train_data_loader:
            model.set_internals(obs, S)
            elbo = model.elbo()

            heated = elbo.unsqueeze(-1) * betas
            snis = exponentiate_and_normalize(heated, dim=1)
            exp = torch.mean(torch.sum(snis * elbo.unsqueeze(-1), dim=1), dim=0)
            beta_exps.append(exp.unsqueeze(0))
        beta_exps = torch.cat(beta_exps, dim=0)

    return torch.mean(beta_exps, dim=0)
Example #6
def get_tvo_components(log_weight, log_p, log_q, args, heated_normalized_weight=None):
    """Return the SNIS-weighted log weights, the geometric-mixture log density
    thermo_logp, and the detached SNIS weights used by the TVO estimators."""
    partition = args.partition

    # pass heated_normalized_weight explicitly when doing importance resampling,
    # where expectations are uniform over the resampled indices
    if heated_normalized_weight is None:
        heated_log_weight = log_weight * partition
        heated_normalized_weight = exponentiate_and_normalize(
            heated_log_weight, dim=1)

    log_p = log_p.unsqueeze(-1) if len(log_p.shape) < 3 else log_p
    log_q = log_q.unsqueeze(-1) if len(log_q.shape) < 3 else log_q
    thermo_logp = partition * log_p + \
        (1 - partition) * log_q

    snis_logw = heated_normalized_weight * log_weight
    snis_detach = heated_normalized_weight.detach()

    return snis_logw, thermo_logp, snis_detach
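A hypothetical call, with an argparse.Namespace standing in for the args object (only args.partition is read when heated_normalized_weight is None):

import argparse
import torch

args = argparse.Namespace(partition=torch.linspace(0, 1, 11).view(1, 1, -1))
log_weight = torch.randn(32, 10, 1)   # [batch, particles, 1]
log_p = torch.randn(32, 10)
log_q = torch.randn(32, 10)

snis_logw, thermo_logp, snis_detach = get_tvo_components(
    log_weight, log_p, log_q, args)
print(snis_logw.shape, thermo_logp.shape)   # both [32, 10, 11]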
Example #7
def calc_var(log_weight, args, snis=None, all_sample_mean=True, betas=None):
    """
    Args:
        log_weight : [batch, samples, *]
        args : args object
                *** Note: only args.partition is used (this can be replaced by betas arg)
        all_sample_mean : averages over the batch dimension if True
        snis : optionally pass precomputed SNIS weights to avoid recomputation
        betas : β points at which to evaluate variance (shape: [K β points] or [batch or 1, 1, K β points])
    Returns:
        Variance across importance samples at each beta (2nd derivative of log Z_β wrt β)
    """
    from ml_helpers import exponentiate_and_normalize

    if log_weight.dim() < 3:
        log_weight = log_weight.unsqueeze(-1)

    if betas is None:
        # args may be an args object with .partition, or the partition tensor itself
        partition = getattr(args, 'partition', args)
    else:
        partition = betas

    if snis is None:
        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)
    snis_detach = snis.detach()

    # expected log weights under π_β
    exp_pi_beta_log_weight = torch.sum(
        snis_detach*log_weight, dim=1, keepdim=True)

    variance = torch.sum(
        snis_detach*torch.pow(log_weight - exp_pi_beta_log_weight, 2), dim=1)

    # variance is [batch, # β].  set all_sample_mean = True to average over batch dimension
    return torch.mean(variance, dim=0) if all_sample_mean else variance
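This is the centered E[(X − EX)²] form that the assertion in the first calc_var recommends for numerical stability; both forms agree mathematically, as a quick check with arbitrary normalized weights shows:

import torch

torch.manual_seed(0)
x = torch.randn(500)
w = torch.softmax(torch.randn(500), dim=0)   # arbitrary SNIS-style weights

raw = (w * x**2).sum() - (w * x).sum()**2            # E[X²] − E[X]²
centered = (w * (x - (w * x).sum())**2).sum()        # E[(X − EX)²]
assert torch.allclose(raw, centered, atol=1e-5)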
Example #8
def calc_exp(log_weight, args, all_sample_mean=True, snis=None):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition [batch, 1, K partitions]
        all_sample_mean : True averages over the batch dimension
        snis : optionally pass precomputed SNIS weights to avoid recomputation

    TODO: replace for cleaner integration into the codebase (pulled directly from Rob's code)
    """

    if log_weight.dim() < 3:
        log_weight = log_weight.unsqueeze(-1)

    if snis is None:
        # args may be an args object with .partition, or the partition tensor itself
        partition = getattr(args, 'partition', args)

        # per-sample partitions arrive as [batch, ...]; add singleton dims so
        # they broadcast against log_weight (accounts for per_sample = True)
        if not isinstance(partition, float) and len(partition.shape) > 0 \
                and partition.shape[0] == log_weight.shape[0]:
            while len(partition.shape) < len(log_weight.shape):
                partition = partition.unsqueeze(1)

        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp = torch.sum(snis * log_weight, dim=1)
    return torch.mean(exp, dim=0) if all_sample_mean else exp
Example #9
def compute_wake_phi_loss(log_weight, log_q):
    """Args:
        log_weight: tensor of shape [batch_size, num_particles]
        log_q: tensor of shape [batch_size, num_particles]

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    normalized_weight = exponentiate_and_normalize(log_weight, dim=1)
    return torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
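This is the wake-phase ϕ objective from reweighted wake-sleep: the SNIS weights are detached, so gradients flow only through log_q, giving a self-normalized estimate of the gradient of KL(p(z|x) ∥ q_ϕ(z|x)) with respect to ϕ. A hypothetical call:

import torch

log_weight = torch.randn(32, 10)                   # [batch, particles]
log_q = torch.randn(32, 10, requires_grad=True)    # stands in for log q_ϕ(z|x)

loss = compute_wake_phi_loss(log_weight, log_q)
loss.backward()                                    # gradients only through log_q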
Example #10
    def record_stats(self, eval=False, extra_record=False):
        '''Records (across β): expectation / variance / 3rd / 4th derivatives,
        curvature, the IWAE β estimator, and intermediate TVO integrals (WIP).

        Leaves open the possibility of a different, standardized partition
        for evaluation, via a record_partition arg or an overall arg.
        '''

        # Always use validation sample size
        S = self.args.valid_S

        if eval:
            partition = self.args.eval_partition \
                if self.args.eval_partition is not None \
                else torch.linspace(0, 1.0, 101, device='cuda')
        else:
            partition = self.args.partition

        log_iw = self.elbo().detach()
        if log_iw.dim() < 3:
            log_iw = log_iw.unsqueeze(-1)
        heated_log_weight = log_iw * partition

        snis = mlh.exponentiate_and_normalize(heated_log_weight, dim=1)

        # Leaving open the possibility of additional per-batch calculations (all_sample_mean=False)
        tvo_expectations = util.calc_exp(log_iw,
                                         partition,
                                         snis=snis,
                                         all_sample_mean=True)
        tvo_vars = util.calc_var(log_iw,
                                 partition,
                                 snis=snis,
                                 all_sample_mean=True)

        self.record_results['tvo_exp'].update(tvo_expectations.squeeze().cpu())
        self.record_results['tvo_var'].update(tvo_vars.squeeze().cpu())
        if extra_record:
            tvo_thirds = util.calc_third(log_iw,
                                         partition,
                                         snis=snis,
                                         all_sample_mean=True)
            tvo_fourths = util.calc_fourth(log_iw,
                                           partition,
                                           snis=snis,
                                           all_sample_mean=True)

            curvature = tvo_thirds / (torch.pow(1 + torch.pow(tvo_vars, 2),
                                                1.5))
            iwae_beta = torch.mean(torch.logsumexp(heated_log_weight, dim=1) -
                                   np.log(S), dim=0)

            self.record_results['tvo_third'].update(tvo_thirds.squeeze().cpu())
            self.record_results['tvo_fourth'].update(tvo_fourths.squeeze().cpu())
            # per sample curvature by beta (gets recorded as mean over batches)
            self.record_results['curvature'].update(curvature.squeeze().cpu())
            # [K] length vector of MC estimators of log Z_β
            self.record_results['iwae_beta'].update(iwae_beta.squeeze().cpu())

            if eval:
                left_riemann = util._get_multiplier(partition, 'left')
                right_riemann = util._get_multiplier(partition, 'right')

                log_px_left = torch.sum(left_riemann * tvo_expectations)
                log_px_right = torch.sum(right_riemann * tvo_expectations)

                # KL_Q = direct calculation via iwae_beta
                kl_q = partition * tvo_expectations - iwae_beta
                # KL_Q = KL_LR = integral of the variances:
                # d/dβ KL(π_β || q) = β · Var_β[log w]
                kl_lr = torch.cumsum(partition * tvo_vars * right_riemann,
                                     dim=0)
                kl_rl = torch.flip(
                    torch.cumsum(torch.flip(
                        (1 - partition) * tvo_vars * left_riemann, [0]),
                                 dim=0), [0])

                self.record_results['direct_kl_lr'].update(
                    kl_q.squeeze().cpu())
                self.record_results['integral_kl_lr'].update(
                    kl_lr.squeeze().cpu())
                self.record_results['integral_kl_rl'].update(
                    kl_rl.squeeze().cpu())

                # FIND log p(x) by argmin abs(kl_lr - kl_rl) => KL[π_α || q] = KL[π_α || p]
                kl_diffs = torch.abs(kl_lr - kl_rl)
                min_ind = torch.argmin(kl_diffs, dim=0)
                log_px_jensen = tvo_expectations[min_ind]
                log_px_beta = partition[min_ind]

                # TVO intermediate UB / LB
                tvo_left = torch.cumsum(tvo_expectations * left_riemann, dim=0)
                tvo_right = torch.cumsum(tvo_expectations * right_riemann,
                                         dim=0)

                self.record_results['log_px_via_jensen'].update(
                    log_px_jensen.squeeze().cpu())
                self.record_results['log_px_beta'].update(
                    log_px_beta.squeeze().cpu())
                self.record_results['log_px_left_tvo'].update(
                    log_px_left.squeeze().cpu())
                self.record_results['log_px_right_tvo'].update(
                    log_px_right.squeeze().cpu())
                # all intermediate log Z_β via left/right integration (compare with iwae_beta)
                self.record_results['left_tvo'].update(
                    tvo_left.squeeze().cpu())
                self.record_results['right_tvo'].update(
                    tvo_right.squeeze().cpu())

                # variance of log w under q (β = 0)
                self.record_results['log_iw_var'].update(
                    tvo_vars[..., 0].squeeze().cpu())
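For reference, the identities behind kl_q, kl_lr and kl_rl above, where π_β ∝ q·w^β and w = p/q (standard thermodynamic-integration facts, stated here as a reading aid):

\mathrm{KL}(\pi_\beta \,\|\, q) = \beta\, \mathbb{E}_{\pi_\beta}[\log w] - \log Z_\beta
  = \int_0^\beta \beta' \,\mathrm{Var}_{\pi_{\beta'}}[\log w]\, \mathrm{d}\beta',
\qquad
\mathrm{KL}(\pi_\beta \,\|\, p) = \int_\beta^1 (1-\beta')\,\mathrm{Var}_{\pi_{\beta'}}[\log w]\, \mathrm{d}\beta'.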