import numpy as np
import torch


def moments(model, args=None, **kwargs):
    '''Choose a partition by "moment spacing": find betas whose expected log
    weights are evenly spaced between the endpoint expectations at β=0 and β=1.'''
    args = model.args if args is None else args
    start = 0
    stop = 1
    threshold = 0.05

    if not args.per_sample and not args.per_batch:
        log_iw = get_total_log_weight(model, args, args.valid_S)
    else:
        log_iw = model.elbo()

    targets = np.linspace(0.0, 1.0, num=args.K + 1, endpoint=True)

    # expected log weight at the endpoints β = 0 and β = 1
    left = calc_exp(log_iw, start, all_sample_mean=not args.per_sample)
    right = calc_exp(log_iw, stop, all_sample_mean=not args.per_sample)
    left = torch.mean(left, dim=0, keepdim=True) if args.per_batch else left
    right = torch.mean(right, dim=0, keepdim=True) if args.per_batch else right

    moment_avg = right - left

    beta_result = []
    for target in targets:
        if target == 0.0 or target == 1.0:
            # endpoints are fixed: β = 0 or β = 1 exactly
            beta_result.append(
                target * (torch.ones_like(log_iw[:, 0]) if args.per_sample else 1))
        else:
            # interpolate the target expectation, then invert η(β) for it
            moment = left + target * moment_avg
            start = torch.zeros_like(log_iw[:, 0]) if args.per_sample else torch.zeros_like(left)
            stop = torch.ones_like(log_iw[:, 0]) if args.per_sample else torch.ones_like(left)
            beta_result.append(_moment_binary_search(
                moment, log_iw, start=start, stop=stop,
                threshold=threshold, per_sample=args.per_sample))

    if args.per_sample:
        beta_result = torch.cat([b.unsqueeze(1) for b in beta_result], dim=1).unsqueeze(1)
        beta_result, _ = torch.sort(beta_result, -1)
    else:
        beta_result = torch.cuda.FloatTensor(beta_result)

    return beta_result
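
# A minimal sketch of the quantity `moments` inverts, with hypothetical
# stand-ins for the repo helpers (calc_exp, get_total_log_weight). Under
# self-normalized importance sampling, eta(beta) = E_{pi_beta}[log w] is
# estimated by reweighting the S samples with softmax(beta * log w).
def eta_sketch(log_iw, beta):
    # log_iw: [N, S] log importance weights; returns the batch mean
    snis = torch.softmax(beta * log_iw, dim=1)   # SNIS weights per sample
    return (snis * log_iw).sum(dim=1).mean()     # E_{pi_beta}[log w]


# Usage: `moments` picks interior betas so eta(beta_k) lands on targets
# evenly spaced between eta(0) and eta(1).
log_iw_demo = torch.randn(64, 50) - 1.0          # synthetic log weights
eta0, eta1 = eta_sketch(log_iw_demo, 0.0), eta_sketch(log_iw_demo, 1.0)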
def _moment_binary_search(target, log_iw, start=0, stop=1, threshold=0.1,
                          recursion=0, per_sample=False, min_beta=0.001):
    '''Bisect for the β whose expected log weight η(β) matches `target`.'''
    beta_guess = 0.5 * (stop + start)
    eta_guess = calc_exp(log_iw, beta_guess, all_sample_mean=not per_sample).squeeze()
    target = torch.ones_like(eta_guess) * target.squeeze()

    # shrink the bracket: move the endpoint on whichever side the guess missed
    start_ = torch.where(eta_guess < target, beta_guess, start)
    stop_ = torch.where(eta_guess > target, beta_guess, stop)

    if torch.sum(torch.abs(eta_guess - target) > threshold).item() == 0:
        # every element is within tolerance
        return beta_guess
    elif recursion > 500:
        return beta_guess
    else:
        # keep the caller's tolerance through the recursion
        return _moment_binary_search(target, log_iw, start=start_, stop=stop_,
                                     threshold=threshold, recursion=recursion + 1,
                                     per_sample=per_sample)
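
# A minimal scalar sketch of the same bracketing search, self-contained.
# It relies on the monotonicity that justifies bisection here:
# d/dbeta E_{pi_beta}[log w] = Var_{pi_beta}[log w] >= 0.
def bisect_sketch(f, target, lo=0.0, hi=1.0, threshold=1e-3, max_iter=500):
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        val = f(mid)
        if abs(val - target) <= threshold:
            return mid
        if val < target:
            lo = mid      # guess too low: raise the left bracket
        else:
            hi = mid      # guess too high: lower the right bracket
    return 0.5 * (lo + hi)


log_iw_1d = torch.randn(1, 200)
f = lambda b: (torch.softmax(b * log_iw_1d, dim=1) * log_iw_1d).sum().item()
beta_mid = bisect_sketch(f, target=0.5 * (f(0.0) + f(1.0)))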
def beta_gradient_descent(model, args, cpu=True, diffs=False):
    '''Perform manual gradient descent on the β partition.

    - get_beta_derivative_single / get_beta_derivative_diffs return dTVO/dβ
      (β tracking is reset afterwards via model.reset_track_beta)
    - safe_step clips each update
    - for 'beta_batch_gradient', expectations and variances are re-calculated
      per batch after the θ update
    '''
    if args.schedule == 'beta_batch_gradient':
        # per-batch update: re-calculate expectations / variances post θ-update
        log_weight = model.elbo()
        tvo_exps = calc_exp(log_weight, args, all_sample_mean=True)
        tvo_vars = calc_var_given_betas(log_weight, args, all_sample_mean=True)
        model.exp_meter.step(tvo_exps.data)
        model.var_meter.step(tvo_vars.data)
    elif model.exp_last is None or isinstance(model.exp_meter.mean, int):
        # first epoch after burn-in on beta_gradient_descent: leave betas unchanged
        return args.partition

    # derivative of the TVO objective w.r.t. each interior β
    if diffs:
        beta_derivatives = get_beta_derivative_diffs(model, args)
    else:
        beta_derivatives = get_beta_derivative_single(model, args)

    # also resets expectation differences
    model.reset_track_beta()

    # gradient descent step on the interior betas only
    sliced_partition = args.partition.data[1:-1]
    sliced_partition = sliced_partition.cpu() if cpu else sliced_partition

    new_partition = safe_step(sliced_partition,
                              args.beta_step_size * beta_derivatives,
                              max_step=args.max_beta_step,
                              adaptive=args.adaptive_beta_step)

    # pad the fixed endpoints β = 0 and β = 1
    new_partition = torch.cat([torch.zeros_like(new_partition[0]).unsqueeze(0),
                               new_partition,
                               torch.ones_like(new_partition[0]).unsqueeze(0)])

    print(args.partition)
    print("new partition ", new_partition)
    print("beta steps ", args.beta_step_size * beta_derivatives)

    return new_partition.cuda() if cpu else new_partition
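
# `safe_step` is repo-internal; this is only a plausible sketch of the
# clipped update it is described as performing (the names and the adaptive
# behavior here are assumptions, not the actual helper).
def safe_step_sketch(betas, step, max_step=0.05, eps=1e-4, adaptive=False):
    if adaptive:
        # hypothetical adaptive variant: shrink the cap near the endpoints
        max_step = max_step * torch.minimum(betas, 1.0 - betas).clamp(min=eps)
    step = torch.clamp(step, -max_step, max_step)   # clip each coordinate
    new = betas - step                              # descent step
    new = torch.clamp(new, eps, 1.0 - eps)          # keep betas strictly in (0, 1)
    return torch.sort(new)[0]                       # keep the partition monotone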
def track_beta_grads(self, log_weight):
    if self.exp_last is None or self.var_last is None:
        if self.args.schedule == 'beta_gradient_descent':
            # initial calculation over the entire dataset for stability
            log_weight = util.get_total_log_weight(self, self.args, self.args.S).data
        # else: 'beta_batch_gradient' should be over the batch only
        tvo_exps = calc_exp(log_weight, self.args, all_sample_mean=True)
        tvo_vars = calc_var_given_betas(log_weight, self.args, all_sample_mean=True)
        self.exp_last = tvo_exps
        self.var_last = tvo_vars
    else:
        tvo_exps = calc_exp(log_weight, self.args, all_sample_mean=True)
        tvo_vars = calc_var_given_betas(log_weight, self.args, all_sample_mean=True)
        # the β gradient includes a telescoping sum, which reduces to
        # E_{β_k}(t=T) - E_{β_k}(t=0) = expectation_diffs
        self.exp_meter.step(tvo_exps.data)
        self.var_meter.step(tvo_vars.data)
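
# `exp_meter` / `var_meter` are repo-internal; a minimal running-average
# meter with the same `step` interface might look like this. Note that
# `mean` starting as the int 0 is what the isinstance(model.exp_meter.mean,
# int) burn-in check in beta_gradient_descent relies on.
class AverageMeterSketch:
    def __init__(self):
        self.mean = 0     # stays an int until the first step
        self.count = 0

    def step(self, value):
        self.count += 1
        self.mean = self.mean + (value - self.mean) / self.count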
def record_stats(self, loss=None, record_partition=False, epoch=None, batch_idx=None):
    '''Records (across β): expectation / variance / 3rd / 4th derivatives,
    curvature, the IWAE β estimator, and intermediate TVO integrals (WIP).
    Also used in the BNN to record classification metrics.

    Leaves open the possibility of a different, standardized partition for
    evaluation: may run with `record_partition` or args.record_partition.
    '''
    # always use the validation sample size
    S = self.args.valid_S

    with torch.no_grad():
        if self.args.record_partition is not None:
            partition = self.args.record_partition
        elif record_partition:
            partition = torch.linspace(0, 1.0, 101, device='cuda')
        else:
            partition = self.args.partition

        log_iw = self.elbo()
        log_iw = log_iw.unsqueeze(-1) if len(log_iw.shape) < 3 else log_iw

        heated_log_weight = log_iw * partition
        snis = util.exponentiate_and_normalize(heated_log_weight, dim=1)

        # leaving open the possibility of calculations on the batch dim (mean=False)
        tvo_expectations = util.calc_exp(log_iw, partition, snis=snis, all_sample_mean=True)
        tvo_vars = util.calc_var(log_iw, partition, snis=snis, all_sample_mean=True)
        tvo_thirds = util.calc_third(log_iw, partition, snis=snis, all_sample_mean=True)
        tvo_fourths = util.calc_fourth(log_iw, partition, snis=snis, all_sample_mean=True)

        curvature = tvo_thirds / torch.pow(1 + torch.pow(tvo_vars, 2), 1.5)
        iwae_beta = torch.mean(torch.logsumexp(heated_log_weight, dim=1) - np.log(S), dim=0)

        # record via average meters
        self.record_results['tvo_exp'].step(tvo_expectations.cpu())
        self.record_results['tvo_var'].step(tvo_vars.cpu())
        self.record_results['tvo_third'].step(tvo_thirds.cpu())
        self.record_results['tvo_fourth'].step(tvo_fourths.cpu())
        # per-sample curvature by β (recorded as mean over batches)
        self.record_results['curvature'].step(curvature.cpu())
        # [K]-length vector of MC estimators of log Z_β
        self.record_results['iwae_beta'].step(iwae_beta.cpu())
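
# A sketch of the SNIS normalization assumed from util: a numerically
# stable softmax of the heated log weights beta * log w along the sample
# dimension, giving the self-normalized importance weights of pi_beta.
def exponentiate_and_normalize_sketch(values, dim=0):
    # subtracting logsumexp before exponentiating avoids overflow
    return torch.exp(values - torch.logsumexp(values, dim=dim, keepdim=True))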
def record_stats(self, eval=False, extra_record=False):
    '''Records (across β): expectation / variance / 3rd / 4th derivatives,
    curvature, the IWAE β estimator, and intermediate TVO integrals (WIP).

    Leaves open the possibility of a different, standardized partition for
    evaluation: may run with `eval` or args.eval_partition specified.
    '''
    # always use the validation sample size
    S = self.args.valid_S

    if eval:
        partition = self.args.eval_partition \
            if self.args.eval_partition is not None \
            else torch.linspace(0, 1.0, 101, device='cuda')
    else:
        partition = self.args.partition

    log_iw = self.elbo()
    log_iw = log_iw.unsqueeze(-1) if len(log_iw.shape) < 3 else log_iw
    log_iw = log_iw.detach()

    heated_log_weight = log_iw * partition
    snis = mlh.exponentiate_and_normalize(heated_log_weight, dim=1)

    # leaving open the possibility of calculations on the batch dim (mean=False)
    tvo_expectations = util.calc_exp(log_iw, partition, snis=snis, all_sample_mean=True)
    tvo_vars = util.calc_var(log_iw, partition, snis=snis, all_sample_mean=True)

    # record via average meters
    self.record_results['tvo_exp'].update(tvo_expectations.squeeze().cpu())
    self.record_results['tvo_var'].update(tvo_vars.squeeze().cpu())

    if extra_record:
        tvo_thirds = util.calc_third(log_iw, partition, snis=snis, all_sample_mean=True)
        tvo_fourths = util.calc_fourth(log_iw, partition, snis=snis, all_sample_mean=True)

        curvature = tvo_thirds / torch.pow(1 + torch.pow(tvo_vars, 2), 1.5)
        iwae_beta = torch.mean(torch.logsumexp(heated_log_weight, dim=1) - np.log(S), dim=0)

        self.record_results['tvo_third'].update(tvo_thirds.squeeze().cpu())
        self.record_results['tvo_fourth'].update(tvo_fourths.squeeze().cpu())
        # per-sample curvature by β (recorded as mean over batches)
        self.record_results['curvature'].update(curvature.squeeze().cpu())
        # [K]-length vector of MC estimators of log Z_β
        self.record_results['iwae_beta'].update(iwae_beta.squeeze().cpu())

        if eval:  # these diagnostics use iwae_beta, so they run under extra_record
            left_riemann = util._get_multiplier(partition, 'left')
            right_riemann = util._get_multiplier(partition, 'right')

            # left/right Riemann estimates of log p(x)
            log_px_left = torch.sum(left_riemann * tvo_expectations)
            log_px_right = torch.sum(right_riemann * tvo_expectations)

            # KL_Q: direct calculation via iwae_beta
            kl_q = partition * tvo_expectations - iwae_beta

            # KL_LR / KL_RL: integrals of the variance (β * tvo_var * dβ)
            kl_lr = torch.cumsum(partition * tvo_vars * right_riemann, dim=0)
            kl_rl = torch.flip(
                torch.cumsum(torch.flip((1 - partition) * tvo_vars * left_riemann, [0]), dim=0),
                [0])

            self.record_results['direct_kl_lr'].update(kl_q.squeeze().cpu())
            self.record_results['integral_kl_lr'].update(kl_lr.squeeze().cpu())
            self.record_results['integral_kl_rl'].update(kl_rl.squeeze().cpu())

            # find log p(x) by argmin |kl_lr - kl_rl|, i.e. KL[π_α || q] = KL[π_α || p]
            kl_diffs = torch.abs(kl_lr - kl_rl)
            _, min_ind = torch.min(kl_diffs, dim=0)
            log_px_jensen = tvo_expectations[min_ind]
            log_px_beta = partition[min_ind]

            # TVO intermediate upper / lower bounds
            tvo_left = torch.cumsum(tvo_expectations * left_riemann, dim=0)
            tvo_right = torch.cumsum(tvo_expectations * right_riemann, dim=0)

            self.record_results['log_px_via_jensen'].update(log_px_jensen.squeeze().cpu())
            self.record_results['log_px_beta'].update(log_px_beta.squeeze().cpu())
            self.record_results['log_px_left_tvo'].update(log_px_left.squeeze().cpu())
            self.record_results['log_px_right_tvo'].update(log_px_right.squeeze().cpu())
            # all intermediate log Z_β via left/right integration (compare with iwae_beta)
            self.record_results['left_tvo'].update(tvo_left.squeeze().cpu())
            self.record_results['right_tvo'].update(tvo_right.squeeze().cpu())
            self.record_results['log_iw_var'].update(tvo_vars[..., 0].squeeze().cpu())
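
# `_get_multiplier` is repo-internal; a plausible sketch of left/right
# Riemann interval widths over a partition 0 = beta_0 < ... < beta_K = 1
# (an assumption inferred from how left_riemann / right_riemann are used
# above). Since E_beta[log w] is nondecreasing in beta, the left sum
# lower-bounds and the right sum upper-bounds the integral, i.e. log p(x).
def get_multiplier_sketch(partition, direction='left'):
    widths = partition[1:] - partition[:-1]      # interval widths, length K
    pad = torch.zeros(1, device=partition.device)
    if direction == 'left':
        return torch.cat([widths, pad])          # weight each left endpoint
    return torch.cat([pad, widths])              # weight each right endpoint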