import torch
import numpy as np


def calc_var(log_weight, args, snis=None, all_sample_mean=True):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition tensor [batch, 1, K partitions]
        snis : optionally feed SNIS weights to avoid recomputation
        all_sample_mean : returns mean over the batch dimension if True

    Returns:
        Variance across importance samples at each beta
        (2nd derivative of log Z_beta)
    """
    log_weight = log_weight.unsqueeze(-1) if len(
        log_weight.shape) < 3 else log_weight

    if snis is None:
        # accept either an args object with .partition or a partition tensor directly
        try:
            partition = args.partition
        except AttributeError:
            partition = args

        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp_ = torch.sum(snis * log_weight, dim=1)
    exp2 = torch.sum(snis * torch.pow(log_weight, 2), dim=1)

    to_return = exp2 - torch.pow(exp_, 2)

    # VM: may have to switch to the centered E[(X - EX)^2] form;
    # the raw-moment form has caused numerical issues in the past
    assert not torch.isnan(to_return).any(), \
        "NaN in calc_var() - switch to E[(X - EX)^2] form for numerical stability"

    return torch.mean(to_return, dim=0) if all_sample_mean else to_return
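
# `exponentiate_and_normalize` is used throughout this file; below is a minimal
# reference sketch of the assumed behavior (a numerically stable softmax along
# `dim`, i.e. self-normalized importance weights). This is a sketch of what the
# ml_helpers version is assumed to do, not necessarily its exact implementation.
def exponentiate_and_normalize(values, dim=0):
    """Exponentiate `values` and normalize along `dim` so the weights sum to one."""
    return torch.exp(values - torch.logsumexp(values, dim=dim, keepdim=True))
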
def calc_fourth(log_weight, args, snis=None, all_sample_mean=True):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition tensor [batch, 1, K partitions]
        snis : optionally feed SNIS weights to avoid recomputation
        all_sample_mean : returns mean over the batch dimension if True

    Returns:
        Fourth derivative of log Z_beta at each beta
        (fourth cumulant of log weight under pi_beta)
    """
    log_weight = log_weight.unsqueeze(-1) if len(
        log_weight.shape) < 3 else log_weight

    if snis is None:
        # accept either an args object with .partition or a partition tensor directly
        try:
            partition = args.partition
        except AttributeError:
            partition = args

        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp = torch.sum(snis * log_weight, dim=1)
    exp2 = torch.sum(snis * torch.pow(log_weight, 2), dim=1)
    exp3 = torch.sum(snis * torch.pow(log_weight, 3), dim=1)
    exp4 = torch.sum(snis * torch.pow(log_weight, 4), dim=1)

    # fourth cumulant in raw moments:
    # kappa_4 = E[X^4] - 4 E[X] E[X^3] - 3 E[X^2]^2 + 12 E[X]^2 E[X^2] - 6 E[X]^4
    to_return = exp4 - 4 * exp * exp3 - 3 * torch.pow(exp2, 2) \
        + 12 * exp2 * torch.pow(exp, 2) - 6 * torch.pow(exp, 4)

    return torch.mean(to_return, dim=0) if all_sample_mean else to_return
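
# Sanity-check sketch (hypothetical, not part of the training code): the
# raw-moment expression in calc_fourth should agree with the central-moment
# form of the fourth cumulant, kappa_4 = E[(X - EX)^4] - 3 Var(X)^2, under
# the same SNIS weights.
def _check_fourth_cumulant():
    torch.manual_seed(0)
    log_w = torch.randn(4, 32, 1)                    # [batch, samples, 1]
    betas = torch.linspace(0, 1, 5).view(1, 1, -1)   # [1, 1, K]
    snis = torch.softmax(log_w * betas, dim=1)       # SNIS weights per beta
    mu = torch.sum(snis * log_w, dim=1, keepdim=True)
    var = torch.sum(snis * (log_w - mu) ** 2, dim=1)
    mu4 = torch.sum(snis * (log_w - mu) ** 4, dim=1)
    central = mu4 - 3 * var ** 2
    raw = calc_fourth(log_w, betas, all_sample_mean=False)
    assert torch.allclose(central, raw, atol=1e-4)
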
def calc_exp(log_weight, args, all_sample_mean=True, snis=None):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition tensor [batch, 1, K partitions]
        snis : optionally feed SNIS weights to avoid recomputation
        all_sample_mean : returns mean over the batch dimension if True

    Returns:
        E_{pi_beta}[log weight] at each beta (1st derivative of log Z_beta)

    TODO: replace for cleaner integration into the codebase (pulled directly from Rob's)
    """
    log_weight = log_weight.unsqueeze(-1) if len(
        log_weight.shape) < 3 else log_weight

    if snis is None:
        # accept either an args object with .partition or a partition tensor directly
        try:
            partition = args.partition
        except AttributeError:
            partition = args

        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp = torch.sum(snis * log_weight, dim=1)

    return torch.mean(exp, dim=0) if all_sample_mean else exp
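
# Usage sketch (hypothetical shapes): calc_exp evaluates the TVO integrand
# E_{pi_beta}[log w] on a beta grid; by the thermodynamic identity
# log p(x) = integral_0^1 E_{pi_beta}[log w] dbeta, a left Riemann sum over
# the grid gives the TVO lower bound (the integrand is increasing in beta).
def _calc_exp_usage():
    log_w = torch.randn(8, 64, 1)                      # [batch, samples, 1]
    betas = torch.linspace(0, 1, 11)                   # K grid points incl. 0 and 1
    integrand = calc_exp(log_w, betas.view(1, 1, -1))  # [K] after batch mean
    tvo_lower_bound = torch.sum(integrand[:-1] * (betas[1:] - betas[:-1]))
    return tvo_lower_bound
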
def compute_tvo_loss(log_weight, log_p, log_q, args):
    """
    Args:
        log_weight : tensor of shape [batch_size, num_particles]
        log_p : tensor of shape [batch_size, num_particles]
        log_q : tensor of shape [batch_size, num_particles]
        args : provides
            partition : partition of [0, 1]; tensor of shape [num_partitions + 1]
                where partition[0] is zero and partition[-1] is one;
                see https://en.wikipedia.org/wiki/Partition_of_an_interval
            S : number of particles (int)
            integration : 'left', 'right' or 'trapz'

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    partition = args.partition
    num_particles = args.S
    integration = args.integration

    log_weight = log_weight.unsqueeze(-1) if len(
        log_weight.shape) < 3 else log_weight

    heated_log_weight = log_weight * partition
    heated_normalized_weight = exponentiate_and_normalize(
        heated_log_weight, dim=1)

    log_p = log_p.unsqueeze(-1)
    log_q = log_q.unsqueeze(-1)

    # log of the geometric-mixture path: pi_beta is proportional to p^beta * q^(1 - beta)
    thermo_logp = partition * log_p + (1 - partition) * log_q

    wf = heated_normalized_weight * log_weight
    w_detached = heated_normalized_weight.detach()

    # finite-sample correction for the covariance estimator
    correction = 1 if num_particles == 1 else num_particles / (num_particles - 1)

    cov = correction * torch.sum(
        w_detached
        * (log_weight - torch.sum(wf, dim=1, keepdim=True)).detach()
        * (thermo_logp - torch.sum(thermo_logp * w_detached, dim=1, keepdim=True)),
        dim=1)

    multiplier = _get_multiplier(partition, integration)

    loss = -torch.mean(torch.sum(
        multiplier * (cov + torch.sum(w_detached * log_weight, dim=1)),
        dim=1))

    return loss
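
# `_get_multiplier` is assumed to return the Riemann-sum weights for the chosen
# integration rule over the beta partition. A minimal sketch for a [K]-shaped
# (or trailing-K) partition tensor; the real version may differ in details.
# 'left' and 'right' correspond to the TVO lower and upper bounds respectively,
# since the integrand is increasing in beta.
def _get_multiplier(partition, integration):
    multiplier = torch.zeros_like(partition)
    deltas = partition[..., 1:] - partition[..., :-1]
    if integration == 'left':
        multiplier[..., :-1] = deltas
    elif integration == 'right':
        multiplier[..., 1:] = deltas
    elif integration == 'trapz':
        # trapezoid rule: each subinterval contributes half its width to both endpoints
        multiplier[..., :-1] += 0.5 * deltas
        multiplier[..., 1:] += 0.5 * deltas
    else:
        raise ValueError(f"unknown integration rule: {integration}")
    return multiplier
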
def get_total_beta_exp(model, args, betas, S):
    """Average E_{pi_beta}[log weight] at each beta over the training set.
    Note: averaging per-batch means assumes equally sized batches."""
    with torch.no_grad():
        batch_exps = []
        for obs in args.train_data_loader:
            model.set_internals(obs, S)
            elbo = model.elbo()
            heated = elbo.unsqueeze(-1) * betas
            snis = exponentiate_and_normalize(heated, dim=1)
            exp = torch.mean(torch.sum(snis * elbo.unsqueeze(-1), dim=1), dim=0)
            batch_exps.append(exp.unsqueeze(0))
        batch_exps = torch.cat(batch_exps, dim=0)
        return torch.mean(batch_exps, dim=0)
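
# Usage sketch with stand-in objects (hypothetical; the real `model` and `args`
# come from the surrounding training code and expose this same interface):
def _get_total_beta_exp_demo():
    class ToyModel:
        def set_internals(self, obs, S):
            self._elbo = torch.randn(obs.shape[0], S)  # fake per-sample log weights
        def elbo(self):
            return self._elbo

    class ToyArgs:
        train_data_loader = [torch.zeros(8, 784) for _ in range(3)]

    betas = torch.linspace(0, 1, 10).view(1, 1, -1)
    return get_total_beta_exp(ToyModel(), ToyArgs(), betas, S=16)  # [K]
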
def get_tvo_components(log_weight, log_p, log_q, args, heated_normalized_weight=None):
    """Return the SNIS-weighted log weights, the path log density, and detached
    SNIS weights used to assemble TVO-style gradient estimators."""
    partition = args.partition

    # feed heated_normalized_weight when doing importance resampling
    # (expectations are uniform at the resampled indices)
    if heated_normalized_weight is None:
        heated_log_weight = log_weight * partition
        heated_normalized_weight = exponentiate_and_normalize(
            heated_log_weight, dim=1)

    log_p = log_p.unsqueeze(-1) if len(log_p.shape) < 3 else log_p
    log_q = log_q.unsqueeze(-1) if len(log_q.shape) < 3 else log_q

    thermo_logp = partition * log_p + (1 - partition) * log_q

    snis_logw = heated_normalized_weight * log_weight
    snis_detach = heated_normalized_weight.detach()

    return snis_logw, thermo_logp, snis_detach
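
# Usage sketch (hypothetical tensors): the components returned above reassemble
# the covariance term of the TVO gradient estimator exactly as in compute_tvo_loss.
def _tvo_components_demo():
    from types import SimpleNamespace
    S = 16
    log_p, log_q = torch.randn(8, S), torch.randn(8, S)
    log_weight = (log_p - log_q).unsqueeze(-1)
    args = SimpleNamespace(partition=torch.linspace(0, 1, 6).view(1, 1, -1),
                           S=S, integration='left')
    snis_logw, thermo_logp, snis_detach = get_tvo_components(
        log_weight, log_p, log_q, args)
    correction = S / (S - 1)
    cov = correction * torch.sum(
        snis_detach
        * (log_weight - torch.sum(snis_logw, dim=1, keepdim=True)).detach()
        * (thermo_logp - torch.sum(thermo_logp * snis_detach, dim=1, keepdim=True)),
        dim=1)
    return cov  # [batch, K]
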
def calc_var(log_weight, args, snis=None, all_sample_mean=True, betas=None):
    """
    Args:
        log_weight : [batch, samples, *]
        args : args object
            *** Note: only args.partition is used (it can be replaced by the betas arg)
        snis : optionally feed SNIS weights to avoid recomputation
        all_sample_mean : returns mean over the batch dimension if True
        betas : beta points at which to evaluate the variance
            (shape: [K] or [batch or 1, 1, K])

    Returns:
        Variance across importance samples at each beta
        (2nd derivative of log Z_beta wrt beta)
    """
    from ml_helpers import exponentiate_and_normalize

    log_weight = log_weight.unsqueeze(-1) if len(
        log_weight.shape) < 3 else log_weight

    if snis is None:
        if betas is None:
            # accept either an args object with .partition or a partition tensor directly
            try:
                partition = args.partition
            except AttributeError:
                partition = args
        else:
            partition = betas

        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    snis_detach = snis.detach()

    # expected log weight under pi_beta
    exp_pi_beta_log_weight = torch.sum(
        snis_detach * log_weight, dim=1, keepdim=True)

    # centered form E[(X - EX)^2]; more numerically stable than E[X^2] - E[X]^2
    variance = torch.sum(
        snis_detach * torch.pow(log_weight - exp_pi_beta_log_weight, 2), dim=1)

    # variance is [batch, K betas]; all_sample_mean=True averages over the batch dim
    return torch.mean(variance, dim=0) if all_sample_mean else variance
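
# Quick numerical check (hypothetical): the centered form above agrees with the
# raw-moment form E[X^2] - E[X]^2 used by the older calc_var, up to float error.
def _check_var_forms():
    torch.manual_seed(0)
    log_w = torch.randn(4, 128, 1)
    betas = torch.linspace(0, 1, 7).view(1, 1, -1)
    snis = torch.softmax(log_w * betas, dim=1)
    raw = torch.sum(snis * log_w ** 2, dim=1) - torch.sum(snis * log_w, dim=1) ** 2
    centered = calc_var(log_w, None, all_sample_mean=False, betas=betas)
    assert torch.allclose(raw, centered, atol=1e-4)
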
def calc_exp(log_weight, args, all_sample_mean=True, snis=None):
    """
    Args:
        log_weight : [batch, samples, *]
        args : either args object or partition tensor [batch, 1, K partitions]
        snis : optionally feed SNIS weights to avoid recomputation
        all_sample_mean : returns mean over the batch dimension if True

    TODO: replace for cleaner integration into the codebase (pulled directly from Rob's)
    """
    log_weight = log_weight.unsqueeze(-1) if len(
        log_weight.shape) < 3 else log_weight

    if snis is None:
        # accept either an args object with .partition or a partition tensor directly
        try:
            partition = args.partition
        except AttributeError:
            partition = args

        # account for per-sample partitions (per_sample=True): a partition with a
        # leading batch dimension is reshaped to broadcast against [batch, samples, K]
        if not isinstance(partition, float) and len(partition.shape) > 0 \
                and partition.shape[0] == log_weight.shape[0]:
            while len(partition.shape) < len(log_weight.shape):
                partition = partition.unsqueeze(1)

        beta_iw = log_weight * partition
        snis = exponentiate_and_normalize(beta_iw, dim=1)

    exp = torch.sum(snis * log_weight, dim=1)

    return torch.mean(exp, dim=0) if all_sample_mean else exp
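
# Usage sketch for the per-sample branch above (hypothetical): one beta grid per
# batch element, e.g. from a per-sample beta scheduler, with shape [batch, K].
def _per_sample_calc_exp_demo():
    log_w = torch.randn(8, 64, 1)
    per_sample_betas = torch.sort(torch.rand(8, 5), dim=1).values  # [batch, K]
    return calc_exp(log_w, per_sample_betas, all_sample_mean=False)  # [batch, K]
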
def compute_wake_phi_loss(log_weight, log_q):
    """
    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
    """
    normalized_weight = exponentiate_and_normalize(log_weight, dim=1)
    return torch.mean(-torch.sum(normalized_weight.detach() * log_q, dim=1))
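
# Usage sketch (hypothetical tensors): the wake-phase phi update of reweighted
# wake-sleep trains q by SNIS-weighted maximum likelihood, so gradients flow
# only through log_q; the weights themselves are detached.
def _wake_phi_demo():
    log_q = torch.randn(8, 16, requires_grad=True)
    log_p = torch.randn(8, 16)
    loss = compute_wake_phi_loss((log_p - log_q).detach(), log_q)
    loss.backward()
    return log_q.grad
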
def record_stats(self, eval=False, extra_record=False):
    '''
    Records (across beta):
        expectation / variance / 3rd / 4th derivatives of log Z_beta,
        curvature, IWAE beta estimator,
        intermediate TVO integrals (WIP)

    Possibility of a different, standardized partition for evaluation:
    may run with record_partition specified or an overall arg.
    '''
    # always use validation sample size
    S = self.args.valid_S

    if eval:
        partition = self.args.eval_partition \
            if self.args.eval_partition is not None \
            else torch.linspace(0, 1.0, 101, device='cuda')
    else:
        partition = self.args.partition

    log_iw = self.elbo().unsqueeze(-1) if len(
        self.elbo().shape) < 3 else self.elbo()
    log_iw = log_iw.detach()

    heated_log_weight = log_iw * partition
    snis = mlh.exponentiate_and_normalize(heated_log_weight, dim=1)

    # leaving open the possibility of additional per-batch calculations
    # (all_sample_mean=False)
    tvo_expectations = util.calc_exp(
        log_iw, partition, snis=snis, all_sample_mean=True)
    tvo_vars = util.calc_var(
        log_iw, partition, snis=snis, all_sample_mean=True)

    self.record_results['tvo_exp'].update(tvo_expectations.squeeze().cpu())
    self.record_results['tvo_var'].update(tvo_vars.squeeze().cpu())

    if extra_record:
        tvo_thirds = util.calc_third(
            log_iw, partition, snis=snis, all_sample_mean=True)
        tvo_fourths = util.calc_fourth(
            log_iw, partition, snis=snis, all_sample_mean=True)

        # curvature of the TVO integrand: f'' / (1 + f'^2)^(3/2),
        # with f = E_{pi_beta}[log w], f' = tvo_vars, f'' = tvo_thirds
        curvature = tvo_thirds / torch.pow(1 + torch.pow(tvo_vars, 2), 1.5)

        # [K]-length vector of IWAE-style MC estimators of log Z_beta
        iwae_beta = torch.mean(
            torch.logsumexp(heated_log_weight, dim=1) - np.log(S), dim=0)

        self.record_results['tvo_third'].update(tvo_thirds.squeeze().cpu())
        self.record_results['tvo_fourth'].update(tvo_fourths.squeeze().cpu())

        # per-sample curvature by beta (recorded as mean over batches)
        self.record_results['curvature'].update(curvature.squeeze().cpu())
        self.record_results['iwae_beta'].update(iwae_beta.squeeze().cpu())

        if eval:
            left_riemann = util._get_multiplier(partition, 'left')
            right_riemann = util._get_multiplier(partition, 'right')

            # TVO lower / upper bounds on log p(x) via left / right Riemann sums
            log_px_left = torch.sum(left_riemann * tvo_expectations)
            log_px_right = torch.sum(right_riemann * tvo_expectations)

            # KL(pi_beta || q): direct calculation via iwae_beta
            kl_q = partition * tvo_expectations - iwae_beta

            # the same KL as an integral of variances: int_0^beta t * var(t) dt
            kl_lr = torch.cumsum(partition * tvo_vars * right_riemann, dim=0)
            # KL(pi_beta || p): int_beta^1 (1 - t) * var(t) dt, integrated
            # from the other end
            kl_rl = torch.flip(
                torch.cumsum(
                    torch.flip((1 - partition) * tvo_vars * left_riemann, [0]),
                    dim=0),
                [0])

            self.record_results['direct_kl_lr'].update(kl_q.squeeze().cpu())
            self.record_results['integral_kl_lr'].update(kl_lr.squeeze().cpu())
            self.record_results['integral_kl_rl'].update(kl_rl.squeeze().cpu())

            # find log p(x) via argmin |kl_lr - kl_rl|: at the beta where
            # KL[pi_beta || q] = KL[pi_beta || p], E_{pi_beta}[log w] = log p(x)
            kl_diffs = torch.abs(kl_lr - kl_rl)
            min_val, min_ind = torch.min(kl_diffs, dim=0)
            log_px_jensen = tvo_expectations[min_ind]
            log_px_beta = partition[min_ind]

            # TVO intermediate upper / lower bounds at each beta
            tvo_left = torch.cumsum(tvo_expectations * left_riemann, dim=0)
            tvo_right = torch.cumsum(tvo_expectations * right_riemann, dim=0)

            self.record_results['log_px_via_jensen'].update(
                log_px_jensen.squeeze().cpu())
            self.record_results['log_px_beta'].update(
                log_px_beta.squeeze().cpu())
            self.record_results['log_px_left_tvo'].update(
                log_px_left.squeeze().cpu())
            self.record_results['log_px_right_tvo'].update(
                log_px_right.squeeze().cpu())

            # all intermediate log Z_beta via left / right integration
            # (compare with iwae_beta)
            self.record_results['left_tvo'].update(tvo_left.squeeze().cpu())
            self.record_results['right_tvo'].update(tvo_right.squeeze().cpu())

            # variance of the log importance weights at beta = 0
            self.record_results['log_iw_var'].update(
                tvo_vars[..., 0].squeeze().cpu())
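
# Why argmin |kl_lr - kl_rl| recovers log p(x): for pi_beta proportional to
# q * w^beta,
#   KL(pi_beta || p) - KL(pi_beta || q) = log p(x) - E_{pi_beta}[log w],
# so at the beta where the two KL curves cross, E_{pi_beta}[log w] = log p(x).
# A minimal numerical illustration with a discrete toy target (hypothetical,
# chosen so log p(x) is computable in closed form):
def _kl_crossing_demo():
    q = torch.tensor([0.7, 0.2, 0.1])           # proposal over 3 latent states
    joint = torch.tensor([0.1, 0.3, 0.2])       # p(x, z); log p(x) = log(0.6)
    log_w = torch.log(joint) - torch.log(q)
    betas = torch.linspace(0, 1, 1001)
    log_pi = torch.log(q).unsqueeze(-1) + log_w.unsqueeze(-1) * betas  # [3, K]
    pi = torch.softmax(log_pi, dim=0)           # pi_beta over states, per beta
    exp_beta = torch.sum(pi * log_w.unsqueeze(-1), dim=0)
    kl_q = torch.sum(pi * (torch.log(pi) - torch.log(q).unsqueeze(-1)), dim=0)
    post = joint / joint.sum()
    kl_p = torch.sum(pi * (torch.log(pi) - torch.log(post).unsqueeze(-1)), dim=0)
    crossing = torch.argmin(torch.abs(kl_q - kl_p))
    assert torch.isclose(exp_beta[crossing], torch.log(joint.sum()), atol=1e-2)
    return betas[crossing]
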