def _max_margin(self, mu_q, sigma_q, pos_mu_p, pos_sigma_p, neg_mu_p, neg_sigma_p, mask):
    """
    Computes the max-margin (hinge) loss, summed over context words.
    :param mu_q, sigma_q: posterior parameters for the center words [batch_size x latent_dim]
    :param pos_mu_p, pos_sigma_p: prior parameters for true context words [batch_size x window_size x latent_dim]
    :param neg_mu_p, neg_sigma_p: prior parameters for negatively sampled words [batch_size x window_size x latent_dim]
    :param mask: BoolTensor marking padded context positions [batch_size x window_size]
    :return: tensor [batch_size]
    """
    batch_size, num_context_ids, embed_dim = pos_mu_p.size()

    mu_q_tiled = mu_q.unsqueeze(1).repeat(1, num_context_ids, 1)
    sigma_q_tiled = sigma_q.unsqueeze(1).repeat(1, num_context_ids, 1)
    mu_q_flat = mu_q_tiled.view(batch_size * num_context_ids, -1)
    sigma_q_flat = sigma_q_tiled.view(batch_size * num_context_ids, -1)

    pos_mu_p_flat = pos_mu_p.view(batch_size * num_context_ids, -1)
    pos_sigma_p_flat = pos_sigma_p.view(batch_size * num_context_ids, -1)
    neg_mu_p_flat = neg_mu_p.view(batch_size * num_context_ids, -1)
    neg_sigma_p_flat = neg_sigma_p.view(batch_size * num_context_ids, -1)

    kl_pos = compute_kl(mu_q_flat, sigma_q_flat, pos_mu_p_flat, pos_sigma_p_flat,
                        device=self.device).view(batch_size, -1)
    kl_neg = compute_kl(mu_q_flat, sigma_q_flat, neg_mu_p_flat, neg_sigma_p_flat,
                        device=self.device).view(batch_size, -1)

    hinge_loss = (kl_pos - kl_neg + self.margin).clamp_min_(0)
    hinge_loss.masked_fill_(mask, 0)
    return hinge_loss.sum(1)
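# `compute_kl` is referenced throughout this file but defined elsewhere. Below is a minimal,
# illustrative sketch of what it is assumed to compute: the closed-form KL divergence between
# two diagonal Gaussians, KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)), reduced over the last
# dimension. This is an assumption for readability only -- the project helper may treat sigma
# as a log-variance or a scalar (spherical) variance, and accepts an optional `device` argument.
import torch


def compute_kl_sketch(mu_q, sigma_q, mu_p, sigma_p, eps=1e-8):
    """Row-wise KL divergence between diagonal Gaussians parameterized by (mu, sigma)."""
    var_q = sigma_q.pow(2) + eps
    var_p = sigma_p.pow(2) + eps
    kl_per_dim = 0.5 * (torch.log(var_p / var_q) + (var_q + (mu_q - mu_p).pow(2)) / var_p - 1.0)
    return kl_per_dim.sum(-1, keepdim=True)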
def forward(self, center_ids, center_metadata_ids, context_ids, context_metadata_ids, neg_ids,
            neg_metadata_ids, num_contexts, num_metadata_samples=None):
    """
    :param center_ids: batch_size
    :param center_metadata_ids: batch_size
    :param context_ids: batch_size, 2 * context_window
    :param context_metadata_ids: batch_size, 2 * context_window, metadata_samples
    :param neg_ids: batch_size, 2 * context_window
    :param neg_metadata_ids: batch_size, 2 * context_window, metadata_samples
    :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
    :param num_metadata_samples: number of Monte Carlo metadata samples per context / negative word
    :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
    """
    # Mask padded context ids
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size, num_context_ids = context_ids.size()
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(device)

    m_samples = context_metadata_ids.size()[-1]
    assert m_samples == num_metadata_samples

    # Compute posterior over the center words q(z|w, c)
    mu_center_q, sigma_center_q, _ = self.encoder(center_ids, center_metadata_ids, context_ids, mask)
    mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
    sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
    mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)
    sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)

    # Compute decoded representations of (w, d), E(c), E(n)
    mu_center, sigma_center = self.decoder(center_ids, center_metadata_ids)
    mu_pos, sigma_pos = self._compute_marginal(context_ids, context_metadata_ids)
    mu_neg, sigma_neg = self._compute_marginal(neg_ids, neg_metadata_ids)

    # Flatten positive context
    mu_pos_flat = mu_pos.view(batch_size * num_context_ids * m_samples, -1)
    sigma_pos_flat = sigma_pos.view(batch_size * num_context_ids * m_samples, -1)

    # Flatten negative context
    mu_neg_flat = mu_neg.view(batch_size * num_context_ids * m_samples, -1)
    sigma_neg_flat = sigma_neg.view(batch_size * num_context_ids * m_samples, -1)

    # Compute KL-divergence between the center-word posterior and the positive / negative priors,
    # then average over the metadata samples
    kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
    kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
    kl_pos = kl_pos_flat.view(batch_size, num_context_ids, m_samples).mean(-1)
    kl_neg = kl_neg_flat.view(batch_size, num_context_ids, m_samples).mean(-1)

    hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
    hinge_loss.masked_fill_(mask, 0)
    hinge_loss = hinge_loss.sum(1)

    recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
    return hinge_loss.mean(), recon_loss.mean()
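# `mask_2D` is also defined elsewhere. A minimal sketch of its assumed behavior, consistent with
# how the mask is consumed above (`masked_fill_(mask, 0)`): it returns a BoolTensor of shape
# `size` in which True marks padded context slots, i.e. positions at or beyond each row's true
# length in `num_contexts`. The name and the explicit loop below are illustrative assumptions.
import torch


def mask_2D_sketch(size, num_contexts):
    mask = torch.zeros(size, dtype=torch.bool)
    for row, n in enumerate(num_contexts):
        mask[row, int(n):] = True
    return mask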
def forward(self, sf_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct, num_outputs):
    """
    :param sf_ids: batch_size
    :param context_ids: batch_size, num_context_ids
    :param lf_ids: batch_size, max_output_size, max_lf_len
    :param target_lf_ids: batch_size
    :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
    :param num_outputs: batch_size (number of valid LF candidates per row - used for masking)
    :return: scores for each candidate LF and target_lf_ids
    """
    batch_size, num_context_ids = context_ids.size()
    max_output_size = lf_ids.size()[1]

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs)

    # First, pass the SF with its context to the encoder
    mask = torch.BoolTensor(torch.Size([batch_size, num_context_ids]))
    mask.fill_(0)
    sf_mu, sf_sigma = self.encoder(sf_ids, context_ids, mask)

    # Next, get prior representations for each LF in lf_ids
    lf_mu, lf_sigma = self._compute_priors(lf_ids)

    # Summarize LFs by averaging over their tokens
    normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
    lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(-2) / normalizer

    # Tile SFs across each LF and flatten both SFs and LFs
    sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
    lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

    kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat, lf_sigma_flat).view(batch_size, max_output_size)

    # An earlier (disabled) variant rescaled the KL scores per row:
    # min_kl, max_kl = kl.min(-1)[0], kl.max(-1)[0]
    # normalizer = (max_kl - min_kl).unsqueeze(-1)
    # numerator = max_kl.unsqueeze(-1).repeat(1, kl.size()[-1]) - kl
    # score = numerator  # / normalizer
    score = -kl
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids
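# A hypothetical usage sketch (not project code): the returned `score` matrix can be treated as
# unnormalized log-probabilities over candidate LFs and trained with cross-entropy against
# `target_lf_ids`. Padded candidates are already set to -inf, so they receive zero probability
# under the softmax. The batch unpacking order below is an assumption.
import torch.nn.functional as F


def acronym_step_sketch(model, batch):
    sf_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct, num_outputs = batch
    score, target = model(sf_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct, num_outputs)
    loss = F.cross_entropy(score, target)
    accuracy = (score.argmax(-1) == target).float().mean()
    return loss, accuracy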
def forward(self, token_ids, sec_ids, cat_ids, context_ids, neg_context_ids, num_contexts):
    """
    :param token_ids: batch_size (center word ids)
    :param sec_ids: batch_size (section metadata ids)
    :param cat_ids: batch_size (note category metadata ids)
    :param context_ids: batch_size, 2 * context_window
    :param neg_context_ids: batch_size, 2 * context_window
    :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
    :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
    """
    # Mask padded context ids
    batch_size, num_context_ids = context_ids.size()
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)

    center_ids = token_ids
    if self.multi_bsg:
        # Randomly choose, per example, whether the word, section, or category id
        # serves as the pseudo-center, according to self.input_weights
        center_id_candidates = torch.cat([
            token_ids.unsqueeze(0),
            sec_ids.unsqueeze(0),
            cat_ids.unsqueeze(0)
        ])
        input_sample = torch.multinomial(self.input_weights, batch_size, replacement=True).to(self.device)
        center_ids = center_id_candidates.gather(0, input_sample.unsqueeze(0)).squeeze(0)

    mu_q, sigma_q = self.encoder(center_ids, context_ids, mask, token_mask_p=self.mask_p)

    mu_p, sigma_p = self._compute_priors(token_ids)
    pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
    neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

    kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
    max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p, neg_mu_p, neg_sigma_p, mask).mean()
    return kl, max_margin
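# `_compute_priors` is used by every variant in this file but defined elsewhere. A minimal sketch
# of one plausible implementation (an assumption, not the project's actual code): the prior
# p(z|w) is parameterized by two embedding tables mapping token ids to a mean vector and a
# positive sigma. The real model may use a scalar (spherical) sigma or a different positivity
# transform such as exp or a hard clamp.
import torch.nn as nn
import torch.nn.functional as F


class PriorSketch(nn.Module):
    def __init__(self, vocab_size, latent_dim):
        super().__init__()
        self.mu_embeddings = nn.Embedding(vocab_size, latent_dim, padding_idx=0)
        self.sigma_embeddings = nn.Embedding(vocab_size, latent_dim, padding_idx=0)

    def _compute_priors(self, ids):
        mu = self.mu_embeddings(ids)
        sigma = F.softplus(self.sigma_embeddings(ids))
        return mu, sigma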
def forward(self, center_ids, context_ids, neg_context_ids, num_contexts):
    """
    :param center_ids: batch_size
    :param context_ids: batch_size, 2 * context_window
    :param neg_context_ids: batch_size, 2 * context_window
    :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
    :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
    """
    # Mask padded context ids
    batch_size, num_context_ids = context_ids.size()
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)

    mu_q, sigma_q = self.encoder(center_ids, context_ids, mask)

    mu_p, sigma_p = self._compute_priors(center_ids)
    pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
    neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

    kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
    max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p, neg_mu_p, neg_sigma_p, mask).mean()
    return kl, max_margin
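# A minimal training-step sketch showing how the two returned loss components might be combined
# (an assumption -- the project may weight or anneal the KL term differently). `model` is the
# module defining the forward above; `optimizer` is any torch optimizer.
def bsg_train_step_sketch(model, optimizer, batch):
    center_ids, context_ids, neg_context_ids, num_contexts = batch
    kl, max_margin = model(center_ids, context_ids, neg_context_ids, num_contexts)
    loss = kl + max_margin  # joint variational objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), kl.item(), max_margin.item()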
def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids, target_lf_ids,
            lf_token_ct, lf_metadata_ids, num_outputs, num_contexts):
    """
    :param sf_ids: LongTensor of batch_size
    :param section_ids: LongTensor of batch_size (section metadata ids; all zeros when unused)
    :param category_ids: LongTensor of batch_size (note category metadata ids; all zeros when unused)
    :param context_ids: LongTensor of batch_size x 2 * context_window
    :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
    :param target_lf_ids: LongTensor of batch_size representing at which index in lf_ids the target LF lies
    :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
    :param lf_metadata_ids: not used by this ensemble variant
    :param num_outputs: list representing the number of target LFs for each row in batch.
    :param num_contexts: LongTensor of batch_size. The actual context window size of each SF.
    :return: scores for each candidate LF, target_lf_ids, None (no gating weights for this model)
    """
    batch_size, max_output_size, _ = lf_ids.size()

    # Get prior representations for each LF in lf_ids
    lf_mu, lf_sigma = self._compute_priors(lf_ids)

    # Summarize LFs by averaging over their tokens
    normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
    lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(-2) / normalizer

    combined_mu = []
    combined_sigma = []

    # Encode SFs in context
    sf_mu_tokens, sf_sigma_tokens = self.encode_context(sf_ids, context_ids, num_contexts)
    combined_mu.append(sf_mu_tokens)
    combined_sigma.append(sf_sigma_tokens)

    # For the MBSGE ensemble method, we also leverage section ids and note category ids
    if len(section_ids.nonzero()) > 0:
        sf_mu_sec, sf_sigma_sec = self.encode_context(section_ids, context_ids, num_contexts)
        combined_mu.append(sf_mu_sec)
        combined_sigma.append(sf_sigma_sec)
    if len(category_ids.nonzero()) > 0:
        sf_mu_cat, sf_sigma_cat = self.encode_context(category_ids, context_ids, num_contexts)
        combined_mu.append(sf_mu_cat)
        combined_sigma.append(sf_sigma_cat)

    combined_mu = torch.cat([x.unsqueeze(1) for x in combined_mu], dim=1)
    combined_sigma = torch.cat([x.unsqueeze(1) for x in combined_sigma], dim=1)
    sf_mu = combined_mu.mean(1)
    sf_sigma = combined_sigma.mean(1)

    # Tile SFs across each LF and flatten both SFs and LFs
    sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
    lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs).to(self.device)

    kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat, lf_sigma_flat).view(batch_size, max_output_size)
    score = -kl
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids, None
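# `encode_context` is called above but defined elsewhere. A sketch of the assumed behavior, based
# on the sibling forward passes in this file: build the padding mask for the context window and
# run the (pseudo-)center ids plus context through the shared encoder to obtain a posterior mean
# and sigma. The free-function signature here is an illustrative assumption.
import torch


def encode_context_sketch(model, center_ids, context_ids, num_contexts):
    batch_size, num_context_ids = context_ids.size()
    mask = mask_2D(torch.Size([batch_size, num_context_ids]), num_contexts).to(model.device)
    return model.encoder(center_ids, context_ids, mask)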
print(headers[i])
for p in np.arange(0, 1.25, 0.25):
    rw = [p, 1.0 - p]
    rel_weights = torch.FloatTensor([rw]).to(device_str)
    with torch.no_grad():
        mu_q, sigma_q, weights = model.encoder(center_word_tens, header_tens, context_tens, mask,
                                               center_mask_p=None, context_mask_p=None,
                                               metadata_mask_p=None, rel_weights=rel_weights)
    scores = tensor_to_np(nn.Softmax(-1)(-compute_kl(mu_q, sigma_q, mu_compare, sigma_compare).squeeze(1)))
    order = np.argsort(-scores)
    weight_str = 'Relative Weights --> Word={}. Section={}'.format(rw[1], rw[0])
    print('\t{}'.format(weight_str))
    for top_idx in order[:min(10, len(order))]:
        print('\t\t{} --> {}'.format(compare_words[top_idx], scores[top_idx]))

section_df = pd.read_csv(os.path.join(home_dir, 'preprocess/data/mimic/section_freq.csv')).dropna()
section_names = list(
    sorted(
        set(
            list(
                section_df.nlargest(
pos_center_tens = torch.LongTensor(pos_center_id).clamp_min(0)
neg_center_tens = torch.LongTensor(neg_center_id).clamp_min(0)
pos_mu, pos_sigma = model._compute_priors(pos_center_tens)
neg_mu, neg_sigma = model._compute_priors(neg_center_tens)

window_tens = torch.LongTensor([vocab.get_id(t) for t in window_toks]).clamp_min(0).unsqueeze(0)
window_mask = torch.BoolTensor(torch.Size([1, len(window_toks)]))
window_mask.fill_(0)
with torch.no_grad():
    z_orig_mu, z_orig_sigma = model.encoder(ab_tens, window_tens, window_mask)
orig_kl_pos = compute_kl(z_orig_mu, z_orig_sigma, pos_mu, pos_sigma).item()
orig_kl_neg = compute_kl(z_orig_mu, z_orig_sigma, neg_mu, neg_sigma).item()

extra_context_tokens = []
extra_kls = []
for idx, (mu, sigma) in enumerate([(pos_mu, pos_sigma), (neg_mu, neg_sigma)]):
    lf = pos_center[0] if idx == 0 else neg_center[0]
    posterior_kld_full_tokens = compute_kl(mu, sigma, full_priors['mu'], full_priors['sigma']).squeeze(1)
    closest_token_idxs = posterior_kld_full_tokens.squeeze().argsort()[:EXTRA_CONTEXTS].numpy()
    closest_tokens = []
    for token_idx in closest_token_idxs:
        closest_tokens.append(full_tokens[token_idx])
def forward(self, context_ids, center_ids, pos_ids, neg_ids, context_token_type_ids, num_contexts,
            context_mask, center_mask, pos_mask, neg_mask, num_metadata_samples=None):
    """
    :param context_ids: batch_size x context_len (wp ids for BERT-style context word sequence with center word
    and metadata special token)
    :param center_ids: batch_size x num_context_ids (wp ids for center words and their metadata)
    :param pos_ids: batch_size x window_size * 2 x num_context_ids (wp ids for context words)
    :param neg_ids: batch_size x window_size * 2 x num_context_ids (wp ids for negatively sampled words and
    MC metadata sample)
    :param context_token_type_ids: batch_size x context_len (demarcating tokens from metadata in context_ids
    and MC metadata sample)
    :param num_contexts: batch_size (number of non-padded pos_ids / neg_ids in each row in batch)
    :param context_mask: batch_size x context_len (BERT attention mask for context_ids)
    :param center_mask: batch_size x num_context_ids
    :param pos_mask: batch_size x window_size * 2 x num_context_ids
    :param neg_mask: batch_size x window_size * 2 x num_context_ids
    :param num_metadata_samples: number of MC samples approximating the marginal distribution over metadata.
    Determines how the BERT-style sequence is segmented into distinct token_type_id segments.
    :return: a tuple of 1-size tensors representing the hinge (max-margin) loss and the reconstruction KL loss
    """
    batch_size, num_context_ids, max_decoder_len = pos_ids.size()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(device)

    # Compute posterior over the center words q(z|w, c)
    mu_center_q, sigma_center_q, _ = self.encoder(
        input_ids=context_ids, attention_mask=context_mask, token_type_ids=context_token_type_ids)
    mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
    sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
    mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids, -1)
    sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids, -1)

    # Compute decoded representations of (w, d), E(c), E(n) in a single decoder pass
    n = batch_size * num_context_ids
    pos_ids_flat, pos_mask_flat = pos_ids.view(n, -1), pos_mask.view(n, -1)
    neg_ids_flat, neg_mask_flat = neg_ids.view(n, -1), neg_mask.view(n, -1)

    joint_ids = torch.cat([center_ids, pos_ids_flat, neg_ids_flat], dim=0)
    joint_mask = torch.cat([center_mask, pos_mask_flat, neg_mask_flat], dim=0)

    center_sep_idx = 1 + 2
    other_sep_idx = num_metadata_samples + 2
    decoder_type_ids = torch.zeros([joint_ids.size()[0], max_decoder_len]).long().to(device)
    decoder_type_ids[:batch_size, -center_sep_idx:] = 1
    decoder_type_ids[batch_size:, -other_sep_idx:] = 1

    mu_joint, sigma_joint = self.decoder(
        input_ids=joint_ids, attention_mask=joint_mask, token_type_ids=decoder_type_ids)
    mu_center, sigma_center = mu_joint[:batch_size], sigma_joint[:batch_size]
    s = batch_size * (num_context_ids + 1)
    mu_pos_flat, sigma_pos_flat = mu_joint[batch_size:s], sigma_joint[batch_size:s]
    mu_neg_flat, sigma_neg_flat = mu_joint[s:], sigma_joint[s:]

    # Compute KL-divergence between the center-word posterior and the positive / negative priors, then reshape
    kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
    kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
    kl_pos = kl_pos_flat.view(batch_size, num_context_ids)
    kl_neg = kl_neg_flat.view(batch_size, num_context_ids)

    hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
    hinge_loss.masked_fill_(mask, 0)
    hinge_loss = hinge_loss.sum(1)

    recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
    return hinge_loss.mean(), recon_loss.mean()
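# An illustrative shape walk-through of the joint decoder batch built above (toy numbers, not
# project settings): with batch_size=2 and num_context_ids=3, the decoder sees 2 center rows
# followed by 2*3 positive rows and 2*3 negative rows stacked along dim 0, and the slices taken
# after decoding recover each block.
batch_size_demo, num_context_ids_demo = 2, 3
joint_rows_demo = batch_size_demo * (1 + 2 * num_context_ids_demo)   # 2 + 6 + 6 = 14
s_demo = batch_size_demo * (num_context_ids_demo + 1)                # positive block ends at row 8
assert joint_rows_demo == 14 and s_demo == 8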
def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids, target_lf_ids,
            lf_token_ct, lf_metadata_ids, lf_metadata_p, num_outputs, num_contexts):
    """
    :param sf_ids: LongTensor of batch_size
    :param section_ids: LongTensor of batch_size
    :param category_ids: LongTensor of batch_size
    :param context_ids: LongTensor of batch_size x 2 * context_window
    :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
    :param target_lf_ids: LongTensor of batch_size representing at which index in lf_ids the target LF lies
    :param lf_token_ct: LongTensor of batch_size x max_output_size. N-gram count for each LF (used for masking)
    :param lf_metadata_ids: batch_size x max_output_size x max_num_metadata. Ids for every metadata in which the LF appears
    :param lf_metadata_p: batch_size x max_output_size x max_num_metadata.
    Empirical probability of lf_metadata_ids ~ p(metadata|LF)
    :param num_outputs: list representing the number of target LFs for each row in batch.
    Used for masking to avoid returning invalid predictions.
    :param num_contexts: LongTensor of batch_size. The actual window size of the SF context.
    Many are shorter than the target of 2 * context_window.
    :return: scores for each candidate LF, target_lf_ids, rel_weights (output of encoder gating function)
    """
    batch_size, max_output_size, max_lf_ngram = lf_ids.size()
    _, num_context_ids = context_ids.size()

    # Compute SF contexts
    # Mask padded context ids
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)
    sf_mu, sf_sigma, rel_weights = self.encoder(sf_ids, section_ids, context_ids, mask,
                                                center_mask_p=None, context_mask_p=None)

    num_metadata = lf_metadata_ids.size()[-1]

    # Tile SFs across each (LF, metadata) pair and flatten
    sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size * num_metadata, 1).view(
        batch_size * max_output_size * num_metadata, -1)
    sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(1, max_output_size * num_metadata, 1).view(
        batch_size * max_output_size * num_metadata, -1)

    # Compute E[LF]
    normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
    lf_mu, lf_sigma = self._compute_marginal(lf_ids, lf_metadata_ids, normalizer=normalizer)
    lf_mu_flat = lf_mu.view(batch_size * max_output_size * num_metadata, -1)
    lf_sigma_flat = lf_sigma.view(batch_size * max_output_size * num_metadata, -1)

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs).to(self.device)

    kl_marginal = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat, lf_sigma_flat).view(
        batch_size, max_output_size, num_metadata)
    # Expected KL under the empirical metadata distribution p(metadata|LF)
    kl = (kl_marginal * lf_metadata_p).sum(2)

    score = -kl
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids, rel_weights
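# An illustrative check of the marginalization above (toy shapes, not project config): the
# per-metadata KL terms are averaged under the empirical distribution p(metadata | LF), i.e.
# the score for each candidate LF is an expected KL over the metadata it appears in.
import torch

bs_demo, n_lf_demo, n_meta_demo = 2, 4, 3
kl_marginal_demo = torch.rand(bs_demo, n_lf_demo, n_meta_demo)
lf_metadata_p_demo = torch.softmax(torch.rand(bs_demo, n_lf_demo, n_meta_demo), dim=-1)
expected_kl_demo = (kl_marginal_demo * lf_metadata_p_demo).sum(2)  # bs_demo x n_lf_demo
assert expected_kl_demo.shape == (bs_demo, n_lf_demo)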
def evaluate_mimic_acronyms(prev_args, model, vocab):
    print('Evaluating context-dependent representations via acronym disambiguation...')
    label_df = pd.read_csv('eval_data/mimic_acronym_expansion_labels.csv')
    expansion_df = pd.read_csv('eval_data/acronym_expansions.csv')
    expansion_df = expansion_df[expansion_df['lf_count'] > 0]

    target_window_size = prev_args.window

    sf_lf_map = defaultdict(set)
    for row_idx, row in expansion_df.iterrows():
        row = row.to_dict()
        sf_lf_map[row['sf']].add(row['lf'])

    contexts = label_df['context'].tolist()
    context_ids = []
    for row_idx, context in enumerate(contexts):
        sc = context.split()
        for cidx in range(len(sc)):
            if sc[cidx] == 'TARGETWORD':
                break
        left_idx = max(0, cidx - target_window_size)
        right_idx = min(cidx + target_window_size + 1, len(sc))
        context_tokens = sc[left_idx:cidx] + sc[cidx + 1:right_idx]
        context_ids.append([vocab.get_id(token.lower()) for token in context_tokens])

    prior_kls = []
    posterior_kls = []

    for row_idx, row in label_df.iterrows():
        row = row.to_dict()
        sf, target_lfs = row['sf'], row['lf']
        context_id_seq = context_ids[row_idx]
        center_id = vocab.get_id(sf.lower())

        center_id_tens = torch.LongTensor([center_id]).clamp_min_(0).to(prev_args.device)
        context_id_tens = torch.LongTensor(context_id_seq).unsqueeze(0).clamp_min_(0).to(prev_args.device)

        mask = torch.BoolTensor(torch.Size([1, len(context_id_seq)])).to(prev_args.device)
        mask.fill_(0)

        p_mu, p_sigma = model._compute_priors(center_id_tens)

        min_prior_kl, min_posterior_kl = float('inf'), float('inf')
        with torch.no_grad():
            z_mu, z_sigma = model.encoder(center_id_tens, context_id_tens, mask)
        for target_expansion in target_lfs.split('|'):
            lf_ids = [vocab.get_id(token.lower()) for token in target_expansion.split()]
            lf_tens = torch.LongTensor(lf_ids).clamp_min_(0).to(prev_args.device)
            lf_mu, lf_sigma = model._compute_priors(lf_tens)
            avg_lf_mu = lf_mu.mean(axis=0)
            avg_lf_sigma = lf_sigma.mean(axis=0)

            prior_kl = compute_kl(p_mu, p_sigma, avg_lf_mu, avg_lf_sigma).item()
            posterior_kl = compute_kl(z_mu, z_sigma, avg_lf_mu, avg_lf_sigma).item()
            min_prior_kl = min(prior_kl, min_prior_kl)
            min_posterior_kl = min(posterior_kl, min_posterior_kl)
        prior_kls.append(min_prior_kl)
        posterior_kls.append(min_posterior_kl)

    avg_prior_kl = sum(prior_kls) / float(len(prior_kls))
    avg_posterior_kl = sum(posterior_kls) / float(len(posterior_kls))
    print('Avg Prior KLD={}. Avg Posterior KLD={}'.format(avg_prior_kl, avg_posterior_kl))
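# A complementary metric sketch (an assumption, not part of the original evaluation): treat the
# candidate LF with the smallest posterior KL as the model's prediction and report
# disambiguation accuracy. `kl_by_candidate` is a hypothetical dict mapping each candidate LF
# string to its posterior KL for a single example.
def pick_expansion_sketch(kl_by_candidate):
    return min(kl_by_candidate, key=kl_by_candidate.get)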