Beispiel #1
0
def moments(model, args=None, **kwargs):
    """Place the K+1 partition betas so the TVO integrand's expectation is
    (approximately) equally spaced between its values at beta=0 and beta=1.

    For each interior fraction t in {1/K, ..., (K-1)/K} a binary search finds
    beta with E_beta[log w] ~= E_0 + t * (E_1 - E_0).

    Args:
        model: model exposing ``elbo()`` and (optionally) ``args``.
        args: experiment arguments; falls back to ``model.args`` when None.
            Reads ``per_sample``, ``per_batch``, ``valid_S``, ``K``.
        **kwargs: unused; accepted for scheduler-interface compatibility.

    Returns:
        Tensor of K+1 betas in [0, 1], sorted along the last dim in the
        per-sample case; a CUDA float tensor otherwise.
    """
    args = model.args if args is None else args
    threshold = 0.05  # convergence tolerance for the binary search

    # Full-dataset weights unless betas are chosen per sample / per batch.
    if not args.per_sample and not args.per_batch:
        log_iw = get_total_log_weight(model, args, args.valid_S)
    else:
        log_iw = model.elbo()

    # K+1 equally spaced target fractions in [0, 1] (endpoints included).
    targets = np.linspace(0.0, 1.0, num=args.K + 1, endpoint=True)

    # Expectations at the two endpoints of the path (beta=0 and beta=1).
    left = calc_exp(log_iw, 0, all_sample_mean=not args.per_sample)
    right = calc_exp(log_iw, 1, all_sample_mean=not args.per_sample)
    left = torch.mean(left, axis=0, keepdims=True) if args.per_batch else left
    right = torch.mean(right, axis=0,
                       keepdims=True) if args.per_batch else right
    moment_avg = right - left

    beta_result = []
    for target in targets:
        if target == 0.0 or target == 1.0:
            # Endpoints are known exactly: beta = 0 and beta = 1.
            beta_result.append(
                target * (torch.ones_like(log_iw[:, 0]) if args.per_sample
                          else 1))  # zero if target == 0
        else:
            # Interpolated expectation this beta should achieve.
            moment = left + target * moment_avg

            lo = torch.zeros_like(
                log_iw[:, 0]) if args.per_sample else torch.zeros_like(left)
            hi = torch.ones_like(
                log_iw[:, 0]) if args.per_sample else torch.ones_like(left)

            beta_result.append(
                _moment_binary_search(moment, log_iw, start=lo, stop=hi,
                                      threshold=threshold,
                                      per_sample=args.per_sample))

    if args.per_sample:
        beta_result = torch.cat([b.unsqueeze(1) for b in beta_result],
                                axis=1).unsqueeze(1)
        beta_result, _ = torch.sort(beta_result, -1)
    else:
        # NOTE(review): requires CUDA; torch.cuda.FloatTensor is a legacy
        # constructor — torch.tensor(..., device='cuda') is the modern form.
        beta_result = torch.cuda.FloatTensor(beta_result)

    return beta_result
Beispiel #2
0
def _moment_binary_search(target,
                          log_iw,
                          start=0,
                          stop=1,
                          threshold=0.1,
                          recursion=0,
                          per_sample=False,
                          min_beta=0.001):
    """Binary search for beta such that E_beta[log w] matches ``target``.

    Args:
        target: desired expectation(s); tensor broadcastable against the
            output of ``calc_exp``.
        log_iw: log importance weights, passed through to ``calc_exp``.
        start, stop: current lower / upper search bounds (scalar or tensor).
        threshold: absolute tolerance on |E_beta - target| for convergence.
        recursion: starting iteration count; kept for interface compatibility
            with the original recursive version.
        per_sample: if True, search element-wise instead of on the mean.
        min_beta: unused; kept for backward compatibility.

    Returns:
        The midpoint beta once every element is within ``threshold`` of the
        target, or after the iteration budget (500) is exhausted.
    """
    # Iterative rewrite of the original recursion. This (a) avoids a
    # ~500-deep Python call stack, and (b) fixes a bug where ``threshold``
    # was not forwarded on the recursive call, silently resetting to the
    # default 0.1 after the first level even when the caller passed 0.05.
    iteration = recursion
    while True:
        beta_guess = .5 * (stop + start)
        eta_guess = calc_exp(log_iw, beta_guess,
                             all_sample_mean=not per_sample).squeeze()
        target_b = torch.ones_like(eta_guess) * (target.squeeze())

        # Tighten whichever bound the midpoint violates, element-wise.
        start = torch.where(eta_guess < target_b, beta_guess, start)
        stop = torch.where(eta_guess > target_b, beta_guess, stop)

        converged = torch.sum(
            torch.abs(eta_guess - target_b) > threshold).item() == 0
        if converged or iteration > 500:
            return beta_guess
        iteration += 1
Beispiel #3
0
def beta_gradient_descent(model, args, cpu = True, diffs = False):
    '''
    Perform manual gradient descent on the partition betas.

    - get_beta_derivative_* returns dTVO / dbeta; beta tracking is reset after
    - safe_step clips the update magnitude

    For args.schedule == 'beta_batch_gradient', expectations and variances are
    re-calculated here (post θ-update) and pushed into the model's meters.
    Returns the new (0- and 1-padded) partition tensor.
    '''
    if args.schedule == 'beta_batch_gradient':
        # per-batch update path: refresh expectation / variance meters
        # with statistics computed after the θ gradient step
        log_weight = model.elbo()
        model.exp_meter.step(
            calc_exp(log_weight, args, all_sample_mean=True).data)
        model.var_meter.step(
            calc_var_given_betas(log_weight, args, all_sample_mean=True).data)
    elif model.exp_last is None or isinstance(model.exp_meter.mean, int):
        # first epoch after burn-in: nothing tracked yet, keep the partition
        return args.partition

    # derivative of the TVO objective w.r.t. each interior beta
    # (also resets expectation differences via reset_track_beta below)
    derivative_fn = get_beta_derivative_diffs if diffs \
        else get_beta_derivative_single
    beta_derivatives = derivative_fn(model, args)
    model.reset_track_beta()

    # gradient-descent step on the interior betas (endpoints stay 0 and 1)
    interior = args.partition.data[1:-1]
    if cpu:
        interior = interior.cpu()

    step = args.beta_step_size * beta_derivatives
    updated = safe_step(interior, step,
                        max_step=args.max_beta_step,
                        adaptive=args.adaptive_beta_step)

    # re-attach the fixed endpoints 0 and 1
    zero_pad = torch.zeros_like(updated[0]).unsqueeze(0)
    one_pad = torch.ones_like(updated[0]).unsqueeze(0)
    new_partition = torch.cat([zero_pad, updated, one_pad])

    print(args.partition)
    print("new partition ", new_partition)
    print("beta steps ", step)

    return new_partition.cuda() if cpu else new_partition
Beispiel #4
0
    def track_beta_grads(self, log_weight):
        """Track per-beta expectations/variances used for beta gradients.

        On the first call (no ``exp_last``/``var_last`` yet) under the
        'beta_gradient_descent' schedule, statistics are computed over the
        entire dataset for stability; 'beta_batch_gradient' stays batch-only.
        The first call also seeds ``exp_last`` / ``var_last``; every call
        pushes the fresh statistics into the running meters.

        Args:
            log_weight: log importance weights for the current batch.
        """
        first_update = self.exp_last is None or self.var_last is None
        if first_update and self.args.schedule == 'beta_gradient_descent':
            # initial calculation over entire dataset for stability
            # ('beta_batch_gradient' should remain over the batch only)
            log_weight = util.get_total_log_weight(
                self, self.args, self.args.S).data

        # Deduplicated: the original computed these identically in both the
        # first-update and subsequent-update branches.
        tvo_exps = calc_exp(log_weight, self.args, all_sample_mean=True)
        tvo_vars = calc_var_given_betas(log_weight, self.args,
                                        all_sample_mean=True)

        if first_update:
            self.exp_last = tvo_exps
            self.var_last = tvo_vars

        # beta gradient includes a telescoping sum, reduces to
        # (E_βκ(t=T) - E_βκ(t=0)) = expectation_diffs
        self.exp_meter.step(tvo_exps.data)
        self.var_meter.step(tvo_vars.data)
Beispiel #5
0
    def record_stats(self, loss=None, record_partition=False, epoch=None, batch_idx=None):
        '''
            Records (across β) : expectation / variance / 3rd / 4th derivatives
                curvature, IWAE β estimator
                intermediate TVO integrals (WIP)
            Also used in BNN to record classification metrics

            All statistics are accumulated into ``self.record_results`` meters;
            nothing is returned. A standardized evaluation partition may be
            supplied via args.record_partition or the record_partition flag.
        '''
        # Always use validation sample size for the IWAE-β estimator.
        S = self.args.valid_S

        with torch.no_grad():
            # Partition priority: explicit args.record_partition, then a
            # dense 101-point grid, then the current training partition.
            if self.args.record_partition is not None:
                partition = self.args.record_partition
            elif record_partition:
                partition = torch.linspace(0, 1.0, 101, device='cuda')
            else:
                partition = self.args.partition

            # Ensure a trailing partition dimension on the log weights.
            log_iw = self.elbo().unsqueeze(-1) if len(
                self.elbo().shape) < 3 else self.elbo()

            # β * log w for every β in the partition.
            heated_log_weight = log_iw * partition

            # Self-normalized importance weights over the sample dimension.
            snis = util.exponentiate_and_normalize(heated_log_weight, dim=1)

            # Leaving open possibility of addl calculations on batch dim
            # (mean = False); all four moments share the same call shape.
            shared = dict(snis=snis, all_sample_mean=True)
            tvo_expectations = util.calc_exp(log_iw, partition, **shared)
            tvo_vars = util.calc_var(log_iw, partition, **shared)
            tvo_thirds = util.calc_third(log_iw, partition, **shared)
            tvo_fourths = util.calc_fourth(log_iw, partition, **shared)

            curvature = tvo_thirds / torch.pow(1 + torch.pow(tvo_vars, 2), 1.5)
            # [K]-length vector of MC estimators of log Z_β.
            iwae_beta = torch.mean(
                torch.logsumexp(heated_log_weight, dim=1) - np.log(S), axis=0)

            # Push every statistic into its averaging meter (per-sample
            # curvature gets recorded as a mean over batches).
            for key, stat in (('tvo_exp', tvo_expectations),
                              ('tvo_var', tvo_vars),
                              ('tvo_third', tvo_thirds),
                              ('tvo_fourth', tvo_fourths),
                              ('curvature', curvature),
                              ('iwae_beta', iwae_beta)):
                self.record_results[key].step(stat.cpu())
Beispiel #6
0
    def record_stats(self, eval=False, extra_record=False):  # , data_loader):
        '''
            Records (across β) : expectation / variance / 3rd / 4th derivatives
                curvature, IWAE β estimator
                intermediate TVO integrals (WIP)

            eval: use the evaluation partition (or a dense 101-point grid) and
                additionally record KL / log p(x) integral-based estimators.
            extra_record: also record 3rd/4th moments, curvature, iwae_beta.

            All statistics are accumulated into self.record_results meters;
            nothing is returned.
        '''
        '''Possibility of different, standardized partition for evaluation?
            - may run with record_partition specified or overall arg'''

        # Always use validation sample size
        S = self.args.valid_S

        # if self.args.record_partition is not None:
        #    partition = self.args.record_partition
        if eval:  # record_partition:
            # evaluation: prefer an explicit eval_partition, else a dense grid
            partition = self.args.eval_partition \
                if self.args.eval_partition is not None \
                else torch.linspace(0, 1.0, 101, device='cuda')
        else:
            partition = self.args.partition

        # Ensure a trailing partition dim; assumes elbo() returns
        # [batch, S] or [batch, S, 1] — TODO confirm against elbo().
        log_iw = self.elbo().unsqueeze(-1) if len(
            self.elbo().shape) < 3 else self.elbo()

        # recording only — detach so no gradients flow through these stats
        log_iw = log_iw.detach()
        # β * log w for every β in the partition
        heated_log_weight = log_iw * partition

        # self-normalized importance weights along the sample dimension
        snis = mlh.exponentiate_and_normalize(heated_log_weight, dim=1)

        # Leaving open possibility of addl calculations on batch dim (mean = False)
        tvo_expectations = util.calc_exp(log_iw,
                                         partition,
                                         snis=snis,
                                         all_sample_mean=True)
        tvo_vars = util.calc_var(log_iw,
                                 partition,
                                 snis=snis,
                                 all_sample_mean=True)

        # # Using average meter
        # torch.mean(tvo_expectations, dim=0)
        self.record_results['tvo_exp'].update(tvo_expectations.squeeze().cpu())
        self.record_results['tvo_var'].update(
            tvo_vars.squeeze().cpu())  # torch.mean(tvo_vars, dim=0)
        if extra_record:
            # higher moments only when explicitly requested (extra cost)
            tvo_thirds = util.calc_third(log_iw,
                                         partition,
                                         snis=snis,
                                         all_sample_mean=True)
            tvo_fourths = util.calc_fourth(log_iw,
                                           partition,
                                           snis=snis,
                                           all_sample_mean=True)

            curvature = tvo_thirds / (torch.pow(1 + torch.pow(tvo_vars, 2),
                                                1.5))
            # MC estimator of log Z_β at every β in the partition
            iwae_beta = torch.mean(torch.logsumexp(heated_log_weight, dim=1) -
                                   np.log(S),
                                   axis=0)

            self.record_results['tvo_third'].update(
                tvo_thirds.squeeze().cpu())  # torch.mean(tvo_thirds, dim=0)
            self.record_results['tvo_fourth'].update(tvo_fourths.squeeze().cpu(
            ))  # torch.mean(tvo_fourths, dim = 0)
            # per sample curvature by beta (gets recorded as mean over batches)
            self.record_results['curvature'].update(curvature.squeeze().cpu())
            # [K] length vector of MC estimators of log Z_β
            self.record_results['iwae_beta'].update(iwae_beta.squeeze().cpu())

            if eval:
                # Riemann-sum multipliers (interval widths) for left/right
                # integration over the partition
                left_riemann = util._get_multiplier(partition, 'left')
                right_riemann = util._get_multiplier(partition, 'right')

                # TVO lower/upper log p(x) bounds via left/right Riemann sums
                log_px_left = torch.sum(left_riemann * tvo_expectations)
                log_px_right = torch.sum(right_riemann * tvo_expectations)

                # self.record_results['betas']=partition

                # KL_Q = direct calculation via iwae_beta
                kl_q = partition * tvo_expectations - iwae_beta
                # KL_Q = KL_LR = integral of variances
                # beta * tvo_var * dbeta
                kl_lr = torch.cumsum(partition * tvo_vars * right_riemann,
                                     dim=0)
                # reverse-direction KL: integrate from β=1 backwards
                # (flip, cumsum, flip back)
                kl_rl = torch.flip(
                    torch.cumsum(torch.flip(
                        (1 - partition) * tvo_vars * left_riemann, [0]),
                                 dim=0), [0])
                #kl_rl_2 = torch.stack([torch.sum(tvo_vars[i:]*left_riemann[i:]) for i in range(tvo_vars.shape[0])])

                self.record_results['direct_kl_lr'].update(
                    kl_q.squeeze().cpu())
                self.record_results['integral_kl_lr'].update(
                    kl_lr.squeeze().cpu())
                self.record_results['integral_kl_rl'].update(
                    kl_rl.squeeze().cpu())

                # FIND log p(x) by argmin abs(kl_lr - kl_rl) => KL[π_α || q] = KL[π_α || p]
                kl_diffs = torch.abs(kl_lr - kl_rl)
                min_val, min_ind = torch.min(kl_diffs, dim=0)
                log_px_jensen = tvo_expectations[min_ind]
                log_px_beta = partition[min_ind]

                # TVO intermediate UB / LB
                tvo_left = torch.cumsum(tvo_expectations * left_riemann, dim=0)
                tvo_right = torch.cumsum(tvo_expectations * right_riemann,
                                         dim=0)

                self.record_results['log_px_via_jensen'].update(
                    log_px_jensen.squeeze().cpu())
                self.record_results['log_px_beta'].update(
                    log_px_beta.squeeze().cpu())
                self.record_results['log_px_left_tvo'].update(
                    log_px_left.squeeze().cpu())
                self.record_results['log_px_right_tvo'].update(
                    log_px_right.squeeze().cpu())
                # all intermediate log Z_β via left/right integration (compare with iwae_beta)
                self.record_results['left_tvo'].update(
                    tvo_left.squeeze().cpu())
                self.record_results['right_tvo'].update(
                    tvo_right.squeeze().cpu())

                # variance of the raw log weights at β=0 (first partition point)
                self.record_results['log_iw_var'].update(
                    tvo_vars[..., 0].squeeze().cpu())