    def _max_margin(self, mu_q, sigma_q, pos_mu_p, pos_sigma_p, neg_mu_p,
                    neg_sigma_p, mask):
        """
        Computes a sum over context words margin(hinge loss).
        :param pos_context_words:  a tensor with true context words ids [batch_size x window_size]
        :param neg_context_words: a tensor with negative context words ids [batch_size x window_size]
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :return: tensor [batch_size x 1]
        """
        batch_size, num_context_ids, embed_dim = pos_mu_p.size()

        mu_q_tiled = mu_q.unsqueeze(1).repeat(1, num_context_ids, 1)
        sigma_q_tiled = sigma_q.unsqueeze(1).repeat(1, num_context_ids, 1)
        mu_q_flat = mu_q_tiled.view(batch_size * num_context_ids, -1)
        sigma_q_flat = sigma_q_tiled.view(batch_size * num_context_ids, -1)

        pos_mu_p_flat = pos_mu_p.view(batch_size * num_context_ids, -1)
        pos_sigma_p_flat = pos_sigma_p.view(batch_size * num_context_ids, -1)
        neg_mu_p_flat = neg_mu_p.view(batch_size * num_context_ids, -1)
        neg_sigma_p_flat = neg_sigma_p.view(batch_size * num_context_ids, -1)

        kl_pos = compute_kl(mu_q_flat,
                            sigma_q_flat,
                            pos_mu_p_flat,
                            pos_sigma_p_flat,
                            device=self.device).view(batch_size, -1)
        kl_neg = compute_kl(mu_q_flat,
                            sigma_q_flat,
                            neg_mu_p_flat,
                            neg_sigma_p_flat,
                            device=self.device).view(batch_size, -1)

        hinge_loss = (kl_pos - kl_neg + self.margin).clamp_min_(0)
        hinge_loss.masked_fill_(mask, 0)
        return hinge_loss.sum(1)
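Note: compute_kl is called throughout these examples but its definition is not included. A minimal sketch of what it might look like, assuming diagonal Gaussians parameterized by mean and standard deviation and a trailing singleton output dimension (the device argument is accepted only for API compatibility):

import torch

def compute_kl(mu_q, sigma_q, mu_p, sigma_p, device=None, eps=1e-8):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) for diagonal Gaussians,
    summed over the embedding dimension; shapes broadcast against each other."""
    var_q = sigma_q.pow(2) + eps
    var_p = sigma_p.pow(2) + eps
    kl = 0.5 * (torch.log(var_p / var_q) + (var_q + (mu_q - mu_p).pow(2)) / var_p - 1.0)
    return kl.sum(-1, keepdim=True)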
Example #2
    def forward(self, center_ids, center_metadata_ids, context_ids, context_metadata_ids, neg_ids, neg_metadata_ids,
                num_contexts, num_metadata_samples=None):
        """
        :param center_ids: batch_size
        :param center_metadata_ids: batch_size
        :param context_ids: batch_size, 2 * context_window
        :param context_metadata_ids: batch_size, 2 * context_window, metadata_samples
        :param neg_ids: batch_size, 2 * context_window
        :param neg_metadata_ids: batch_size, 2 * context_window, metadata_samples
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :return: cost components: the max margin (hinge) loss and the reconstruction KL-Divergence (q(z|w,c) || p(z|w))
        """
        # Mask padded context ids
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch_size, num_context_ids = context_ids.size()
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(device)

        m_samples = context_metadata_ids.size()[-1]
        assert m_samples == num_metadata_samples

        # Compute center words
        mu_center_q, sigma_center_q, _ = self.encoder(center_ids, center_metadata_ids, context_ids, mask)
        mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
        sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
        mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)
        sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)

        # Compute decoded representations of (w, d), E(c), E(n)
        mu_center, sigma_center = self.decoder(center_ids, center_metadata_ids)
        mu_pos, sigma_pos = self._compute_marginal(context_ids, context_metadata_ids)
        mu_neg, sigma_neg = self._compute_marginal(neg_ids, neg_metadata_ids)

        # Flatten positive context
        mu_pos_flat = mu_pos.view(batch_size * num_context_ids * m_samples, -1)
        sigma_pos_flat = sigma_pos.view(batch_size * num_context_ids * m_samples, -1)

        # Flatten negative context
        mu_neg_flat = mu_neg.view(batch_size * num_context_ids * m_samples, -1)
        sigma_neg_flat = sigma_neg.view(batch_size * num_context_ids * m_samples, -1)

        # Compute KL-divergence between center words and positive/negative contexts, then reshape
        kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
        kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
        kl_pos = kl_pos_flat.view(batch_size, num_context_ids, m_samples).mean(-1)
        kl_neg = kl_neg_flat.view(batch_size, num_context_ids, m_samples).mean(-1)

        hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
        hinge_loss.masked_fill_(mask, 0)
        hinge_loss = hinge_loss.sum(1)

        recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
        return hinge_loss.mean(), recon_loss.mean()
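Note: mask_2D is also assumed rather than shown. A plausible sketch, given how the mask is consumed by masked_fill_ above (True marks padded positions beyond each row's true context count):

import torch

def mask_2D(size, num_items):
    """Boolean mask of shape `size`; entry [i, j] is True when j >= num_items[i],
    i.e. the position is padding and should be zeroed / ignored."""
    batch_size, max_items = size
    num_items = torch.as_tensor(num_items)
    positions = torch.arange(max_items, device=num_items.device).unsqueeze(0)
    return positions >= num_items.unsqueeze(1)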
    def forward(self, sf_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct,
                num_outputs):
        """
        :param sf_ids: batch_size
        :param context_ids: batch_size, num_context_ids
        :param lf_ids: batch_size, max_output_size, max_lf_len
        :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
        :param target_lf_ids: batch_size
        :param num_outputs: batch_size
        :return: scores for each candidate LF [batch_size x max_output_size] and target_lf_ids
        """
        batch_size, num_context_ids = context_ids.size()
        max_output_size = lf_ids.size()[1]

        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs)

        # First thing is to pass the SF with the context to the encoder
        mask = torch.BoolTensor(torch.Size([batch_size, num_context_ids]))
        mask.fill_(0)
        sf_mu, sf_sigma = self.encoder(sf_ids, context_ids, mask)

        # Next is to get prior representations for each LF in lf_ids
        lf_mu, lf_sigma = self._compute_priors(lf_ids)

        # Summarize LFs
        normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
        lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(
            -2) / normalizer

        # Tile SFs across each LF and flatten both SFs and LFs
        sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(
            batch_size * max_output_size, -1)
        sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(
            1, max_output_size, 1).view(batch_size * max_output_size, -1)
        lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
        lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

        kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat,
                        lf_sigma_flat).view(batch_size, max_output_size)
        # min_kl, max_kl = kl.min(-1)[0], kl.max(-1)[0]
        # normalizer = (max_kl - min_kl).unsqueeze(-1)
        # numerator = max_kl.unsqueeze(-1).repeat(1, kl.size()[-1]) - kl
        # score = numerator  # / normalizer
        score = -kl
        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids
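Filling padded candidates with -inf before ranking, as done above, drives their probability to zero under a softmax. A self-contained toy illustration:

import torch
import torch.nn as nn

# Two rows of candidate scores; the second row has only two valid candidates.
score = torch.tensor([[1.0, 0.5, -0.2],
                      [0.3, 0.1, 9.9]])
output_mask = torch.tensor([[False, False, False],
                            [False, False, True]])
score = score.masked_fill(output_mask, float('-inf'))

probs = nn.Softmax(dim=-1)(score)  # the masked slot gets probability 0
pred = probs.argmax(-1)            # tensor([0, 0])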
Example #4
    def forward(self, token_ids, sec_ids, cat_ids, context_ids,
                neg_context_ids, num_contexts):
        """
        :param token_ids: batch_size
        :param sec_ids: batch_size (section ids)
        :param cat_ids: batch_size (note category ids)
        :param context_ids: batch_size, 2 * context_window
        :param neg_context_ids: batch_size, 2 * context_window
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
        """
        # Mask padded context ids
        batch_size, num_context_ids = context_ids.size()
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)

        center_ids = token_ids
        if self.multi_bsg:
            center_id_candidates = torch.cat([
                token_ids.unsqueeze(0),
                sec_ids.unsqueeze(0),
                cat_ids.unsqueeze(0)
            ])
            input_sample = torch.multinomial(self.input_weights,
                                             batch_size,
                                             replacement=True).to(self.device)
            center_ids = center_id_candidates.gather(
                0, input_sample.unsqueeze(0)).squeeze(0)

        mu_q, sigma_q = self.encoder(center_ids,
                                     context_ids,
                                     mask,
                                     token_mask_p=self.mask_p)
        mu_p, sigma_p = self._compute_priors(token_ids)

        pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
        neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

        kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
        max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p,
                                      neg_mu_p, neg_sigma_p, mask).mean()
        return kl, max_margin
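The multi_bsg branch above randomly substitutes the section or category id for the token id as the "center" of each example. A self-contained illustration of the multinomial/gather pattern (all ids and weights below are made up):

import torch

batch_size = 5
token_ids = torch.tensor([10, 11, 12, 13, 14])
sec_ids = torch.tensor([20, 21, 22, 23, 24])
cat_ids = torch.tensor([30, 31, 32, 33, 34])

# Candidate rows: shape [3, batch_size]
center_id_candidates = torch.stack([token_ids, sec_ids, cat_ids])

# Sample one source per example (0=token, 1=section, 2=category)
input_weights = torch.tensor([0.7, 0.2, 0.1])
input_sample = torch.multinomial(input_weights, batch_size, replacement=True)

# Select the sampled row for every column
center_ids = center_id_candidates.gather(0, input_sample.unsqueeze(0)).squeeze(0)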
    def forward(self, center_ids, context_ids, neg_context_ids, num_contexts):
        """
        :param center_ids: batch_size
        :param context_ids: batch_size, 2 * context_window
        :param neg_context_ids: batch_size, 2 * context_window
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
        """
        # Mask padded context ids
        batch_size, num_context_ids = context_ids.size()
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)

        mu_q, sigma_q = self.encoder(center_ids, context_ids, mask)

        mu_p, sigma_p = self._compute_priors(center_ids)

        pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
        neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

        kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
        max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p,
                                      neg_mu_p, neg_sigma_p, mask).mean()
        return kl, max_margin
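Note: _compute_priors is another helper these snippets rely on without showing. One plausible, hypothetical implementation keeps an embedding table for the means and a softplus-activated embedding for a scalar standard deviation per word:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PriorEmbeddings(nn.Module):
    """Hypothetical stand-in for the _compute_priors helper: maps token ids to
    the parameters (mu, sigma) of each word's prior Gaussian."""
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.mu_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.sigma_embeddings = nn.Embedding(vocab_size, 1, padding_idx=0)

    def forward(self, ids):
        mu = self.mu_embeddings(ids)
        sigma = F.softplus(self.sigma_embeddings(ids))  # keep sigma positive
        return mu, sigma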
    def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids,
                target_lf_ids, lf_token_ct, lf_metadata_ids, num_outputs,
                num_contexts):
        """
        :param sf_ids: LongTensor of batch_size
        :param section_ids: LongTensor of batch_size
        :param category_ids: LongTensor of batch_size
        :param context_ids: LongTensor of batch_size x 2 * context_window
        :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
        :param target_lf_ids: LongTensor of batch_size representing the index in lf_ids at which the target LF lies
        :param lf_token_ct: LongTensor of batch_size x max_output_size - normalizer for lf_ids
        :param lf_metadata_ids: batch_size x max_output_size x max_num_metadata. Ids for every metadata the LF appears in
        :param num_outputs: list representing the number of target LFs for each row in the batch
        :param num_contexts: LongTensor of batch_size. The actual window size of the SF context
        :return: scores for each candidate LF [batch_size x max_output_size], target_lf_ids, and None
        """
        batch_size, max_output_size, _ = lf_ids.size()

        # Next is to get prior representations for each LF in lf_ids
        lf_mu, lf_sigma = self._compute_priors(lf_ids)
        # Summarize LFs
        normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
        lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(
            -2) / normalizer

        combined_mu = []
        combined_sigma = []

        # Encode SFs in context
        sf_mu_tokens, sf_sigma_tokens = self.encode_context(
            sf_ids, context_ids, num_contexts)
        combined_mu.append(sf_mu_tokens)
        combined_sigma.append(sf_sigma_tokens)

        # For the MBSGE ensemble method, we leverage section ids and note category ids
        if len(section_ids.nonzero()) > 0:
            sf_mu_sec, sf_sigma_sec = self.encode_context(
                section_ids, context_ids, num_contexts)
            combined_mu.append(sf_mu_sec)
            combined_sigma.append(sf_sigma_sec)

        if len(category_ids.nonzero()) > 0:
            sf_mu_cat, sf_sigma_cat = self.encode_context(
                category_ids, context_ids, num_contexts)
            combined_mu.append(sf_mu_cat)
            combined_sigma.append(sf_sigma_cat)

        combined_mu = torch.cat(list(map(lambda x: x.unsqueeze(1),
                                         combined_mu)),
                                axis=1)
        combined_sigma = torch.cat(list(
            map(lambda x: x.unsqueeze(1), combined_sigma)),
                                   axis=1)

        sf_mu = combined_mu.mean(1)
        sf_sigma = combined_sigma.mean(1)

        # Tile SFs across each LF and flatten both SFs and LFs
        sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(
            batch_size * max_output_size, -1)
        sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(
            1, max_output_size, 1).view(batch_size * max_output_size, -1)
        lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
        lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs).to(self.device)

        kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat,
                        lf_sigma_flat).view(batch_size, max_output_size)
        score = -kl
        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids, None
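The unsqueeze/cat/mean pattern above averages the Gaussian parameters produced from however many views (token, section, category) were encoded. An equivalent, self-contained formulation with torch.stack:

import torch

# Three hypothetical per-view encodings of shape [batch_size, embed_dim]
batch_size, embed_dim = 4, 8
combined_mu = [torch.randn(batch_size, embed_dim) for _ in range(3)]

# Same result as cat(unsqueeze(1), axis=1) followed by mean(1)
sf_mu = torch.stack(combined_mu, dim=1).mean(1)
assert sf_mu.shape == (batch_size, embed_dim)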
Example #7
        print(headers[i])
        for p in np.arange(0, 1.25, 0.25):
            rw = [p, 1.0 - p]
            rel_weights = torch.FloatTensor([rw]).to(device_str)
            with torch.no_grad():
                mu_q, sigma_q, weights = model.encoder(center_word_tens,
                                                       header_tens,
                                                       context_tens,
                                                       mask,
                                                       center_mask_p=None,
                                                       context_mask_p=None,
                                                       metadata_mask_p=None,
                                                       rel_weights=rel_weights)

            scores = tensor_to_np(
                nn.Softmax(-1)(-compute_kl(mu_q, sigma_q, mu_compare,
                                           sigma_compare).squeeze(1)))
            order = np.argsort(-scores)
            weight_str = 'Relative Weights --> Word={}.  Section={}'.format(
                rw[1], rw[0])
            print('\t{}'.format(weight_str))
            for i in order[:min(10, len(order))]:
                print('\t\t{} --> {}'.format(compare_words[i], scores[i]))

    section_df = pd.read_csv(
        os.path.join(home_dir,
                     'preprocess/data/mimic/section_freq.csv')).dropna()
    section_names = list(
        sorted(
            set(
                list(
                    section_df.nlargest(
Example #8
    pos_center_tens = torch.LongTensor(pos_center_id).clamp_min(0)
    neg_center_tens = torch.LongTensor(neg_center_id).clamp_min(0)
    pos_mu, pos_sigma = model._compute_priors(pos_center_tens)
    neg_mu, neg_sigma = model._compute_priors(neg_center_tens)

    window_tens = torch.LongTensor([vocab.get_id(t) for t in window_toks
                                    ]).clamp_min(0).unsqueeze(0)
    window_mask = torch.BoolTensor(torch.Size([1, len(window_toks)]))
    window_mask.fill_(0)

    with torch.no_grad():
        z_orig_mu, z_orig_sigma = model.encoder(ab_tens, window_tens,
                                                window_mask)

    orig_kl_pos = compute_kl(z_orig_mu, z_orig_sigma, pos_mu, pos_sigma).item()
    orig_kl_neg = compute_kl(z_orig_mu, z_orig_sigma, neg_mu, neg_sigma).item()

    extra_context_tokens = []
    extra_kls = []
    for idx, (mu, sigma) in enumerate([(pos_mu, pos_sigma),
                                       (neg_mu, neg_sigma)]):
        lf = pos_center[0] if idx == 0 else neg_center[0]

        posterior_kld_full_tokens = compute_kl(mu, sigma, full_priors['mu'],
                                               full_priors['sigma']).squeeze(1)
        closest_token_idxs = posterior_kld_full_tokens.squeeze().argsort(
        )[:EXTRA_CONTEXTS].numpy()
        closest_tokens = []
        for idx in closest_token_idxs:
            closest_tokens.append(full_tokens[idx])
Example #9
    def forward(self, context_ids, center_ids, pos_ids, neg_ids, context_token_type_ids,
                num_contexts, context_mask, center_mask, pos_mask, neg_mask, num_metadata_samples=None):
        """
        :param context_ids: batch_size x context_len (wp ids for BERT-style context word sequence with center word
        and metadata special token)
        :param center_ids: batch_size x num_context_ids (wp ids for center words and their metadata)
        :param pos_ids: batch_size x window_size * 2 x num_context_ids (wp ids for context words)
        :param neg_ids: batch_size x window_size * 2 x num_context_ids (wp ids for negatively sampled words
        and MC metadata sample)
        :param context_token_type_ids: batch_size x context_len (demarcating word tokens from metadata tokens in context_ids)
        :param num_contexts: batch_size (number of non-padded pos_ids / neg_ids in each row in batch)
        :param context_mask: batch_size x context_len (BERT attention mask for context_ids)
        :param center_mask: batch_size x num_context_ids
        :param pos_mask: batch_size x window_size * 2 x num_context_ids
        :param neg_mask: batch_size x window_size * 2 x num_context_ids
        :param num_metadata_samples: Number of MC samples approximating the marginal distribution over metadata.
        Affects how the BERT-style sequence is segmented into distinct token_type_id segments
        :return: a tuple of 1-size tensors representing the hinge likelihood loss and the reconstruction kl-loss
        """
        batch_size, num_context_ids, max_decoder_len = pos_ids.size()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(device)

        # Compute center words
        mu_center_q, sigma_center_q, _ = self.encoder(
            input_ids=context_ids, attention_mask=context_mask, token_type_ids=context_token_type_ids)
        mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
        sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
        mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids, -1)
        sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids, -1)

        # Compute decoded representations of (w, d), E(c), E(n)
        n = batch_size * num_context_ids
        pos_ids_flat, pos_mask_flat = pos_ids.view(n, -1), pos_mask.view(n, -1)
        neg_ids_flat, neg_mask_flat = neg_ids.view(n, -1), neg_mask.view(n, -1)

        joint_ids = torch.cat([center_ids, pos_ids_flat, neg_ids_flat], axis=0)
        joint_mask = torch.cat([center_mask, pos_mask_flat, neg_mask_flat], axis=0)

        center_sep_idx = 1 + 2
        other_sep_idx = num_metadata_samples + 2
        decoder_type_ids = torch.zeros([joint_ids.size()[0], max_decoder_len]).long().to(device)
        decoder_type_ids[:batch_size, -center_sep_idx:] = 1
        decoder_type_ids[batch_size:, -other_sep_idx:] = 1

        mu_joint, sigma_joint = self.decoder(
            input_ids=joint_ids, attention_mask=joint_mask, token_type_ids=decoder_type_ids)

        mu_center, sigma_center = mu_joint[:batch_size], sigma_joint[:batch_size]
        s = batch_size * (num_context_ids + 1)
        mu_pos_flat, sigma_pos_flat = mu_joint[batch_size:s], sigma_joint[batch_size:s]
        mu_neg_flat, sigma_neg_flat = mu_joint[s:], sigma_joint[s:]

        # Compute KL-divergence between center words and positive/negative contexts, then reshape
        kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
        kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
        kl_pos = kl_pos_flat.view(batch_size, num_context_ids)
        kl_neg = kl_neg_flat.view(batch_size, num_context_ids)

        hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
        hinge_loss.masked_fill_(mask, 0)
        hinge_loss = hinge_loss.sum(1)

        recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
        return hinge_loss.mean(), recon_loss.mean()
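The token_type_id construction above flags the trailing metadata/separator positions of every decoder sequence with segment id 1, using a longer suffix for the context and negative rows than for the center rows. A toy illustration with arbitrary sizes:

import torch

batch_size, num_context_ids, max_decoder_len = 2, 3, 6
num_metadata_samples = 2

# joint_ids stacks center rows, then positive rows, then negative rows
n_rows = batch_size * (1 + 2 * num_context_ids)

center_sep_idx = 1 + 2                    # suffix length marked as segment 1 for center rows
other_sep_idx = num_metadata_samples + 2  # suffix length for context / negative rows

decoder_type_ids = torch.zeros(n_rows, max_decoder_len).long()
decoder_type_ids[:batch_size, -center_sep_idx:] = 1
decoder_type_ids[batch_size:, -other_sep_idx:] = 1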
Example #10
    def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids,
                target_lf_ids, lf_token_ct, lf_metadata_ids, lf_metadata_p,
                num_outputs, num_contexts):
        """
        :param sf_ids: LongTensor of batch_size
        :param section_ids: LongTensor of batch_size
        :param category_ids: LongTensor of batch_size
        :param context_ids: LongTensor of batch_size x 2 * context_window
        :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
        :param target_lf_ids: LongTensor of batch_size representing the index in lf_ids at which the target LF lies
        :param lf_token_ct: LongTensor of batch_size x max_output_size.  N-gram count for each LF (used for masking)
        :param lf_metadata_ids: batch_size x max_output_size x max_num_metadata. Ids for every metadata LF appears in
        :param lf_metadata_p: batch_size x max_output_size x max_num_metadata
        Empirical probability for lf_metadata_ids ~ p(metadata|LF)
        :param num_outputs: list representing the number of target LFs for each row in batch.
        Used for masking to avoid returning invalid predictions.
        :param num_contexts: LongTensor of batch_size.  The actual window size of the SF context.
        Many are shorter than target of 2 * context_window.
        :return: scores for each candidate LF, target_lf_ids, rel_weights (output of encoder gating function)
        """
        batch_size, max_output_size, max_lf_ngram = lf_ids.size()
        _, num_context_ids = context_ids.size()

        # Compute SF contexts
        # Mask padded context ids
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)
        sf_mu, sf_sigma, rel_weights = self.encoder(sf_ids,
                                                    section_ids,
                                                    context_ids,
                                                    mask,
                                                    center_mask_p=None,
                                                    context_mask_p=None)

        num_metadata = lf_metadata_ids.size()[-1]

        # Tile SFs across each LF and flatten both SFs and LFs
        sf_mu_flat = sf_mu.unsqueeze(1).repeat(
            1, max_output_size * num_metadata,
            1).view(batch_size * max_output_size * num_metadata, -1)
        sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(
            1, max_output_size * num_metadata,
            1).view(batch_size * max_output_size * num_metadata, -1)

        # Compute E[LF]
        normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
        lf_mu, lf_sigma = self._compute_marginal(lf_ids,
                                                 lf_metadata_ids,
                                                 normalizer=normalizer)
        lf_mu_flat = lf_mu.view(batch_size * max_output_size * num_metadata,
                                -1)
        lf_sigma_flat = lf_sigma.view(
            batch_size * max_output_size * num_metadata, -1)
        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs).to(self.device)

        kl_marginal = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat,
                                 lf_sigma_flat).view(batch_size,
                                                     max_output_size,
                                                     num_metadata)
        kl = (kl_marginal * lf_metadata_p).sum(2)
        score = -kl

        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids, rel_weights
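The final score above is the negative expected KL under the empirical metadata distribution p(metadata | LF): the per-metadata KLs are weighted by lf_metadata_p and summed out. A toy, self-contained version of that weighting:

import torch

batch_size, max_output_size, num_metadata = 2, 3, 4

# kl_marginal[b, o, m] = KL between the SF posterior and the (LF o, metadata m) prior
kl_marginal = torch.rand(batch_size, max_output_size, num_metadata)

# Empirical p(metadata | LF); rows sum to 1 over the metadata dimension
lf_metadata_p = torch.softmax(torch.rand(batch_size, max_output_size, num_metadata), dim=-1)

expected_kl = (kl_marginal * lf_metadata_p).sum(-1)  # [batch_size, max_output_size]
score = -expected_kl                                 # lower expected KL -> higher score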
def evaluate_mimic_acronyms(prev_args, model, vocab):
    print('Evaluating context-dependent representations via acronym disambiguation...')
    label_df = pd.read_csv('eval_data/mimic_acronym_expansion_labels.csv')
    expansion_df = pd.read_csv('eval_data/acronym_expansions.csv')
    expansion_df = expansion_df[expansion_df['lf_count'] > 0]

    target_window_size = prev_args.window

    sf_lf_map = defaultdict(set)
    for row_idx, row in expansion_df.iterrows():
        row = row.to_dict()
        sf_lf_map[row['sf']].add(row['lf'])

    contexts = label_df['context'].tolist()
    context_ids = []
    for row_idx, context in enumerate(contexts):
        sc = context.split()
        for cidx in range(len(sc)):
            if sc[cidx] == 'TARGETWORD':
                break

        left_idx = max(0, cidx - target_window_size)
        right_idx = min(cidx + target_window_size + 1, len(sc))
        context_tokens = sc[left_idx:cidx] + sc[cidx + 1: right_idx]
        context_ids.append([
            vocab.get_id(token.lower()) for token in context_tokens
        ])

    prior_kls = []
    posterior_kls = []

    for row_idx, row in label_df.iterrows():
        row = row.to_dict()
        sf, target_lfs = row['sf'], row['lf']
        context_id_seq = context_ids[row_idx]
        center_id = vocab.get_id(sf.lower())

        center_id_tens = torch.LongTensor([center_id]).clamp_min_(0).to(prev_args.device)
        context_id_tens = torch.LongTensor(context_id_seq).unsqueeze(0).clamp_min_(0).to(prev_args.device)

        mask = torch.BoolTensor(torch.Size([1, len(context_id_seq)])).to(prev_args.device)
        mask.fill_(0)

        p_mu, p_sigma = model._compute_priors(center_id_tens)
        min_prior_kl, min_posterior_kl = float('inf'), float('inf')

        with torch.no_grad():
            z_mu, z_sigma = model.encoder(center_id_tens, context_id_tens, mask)

        for target_expansion in target_lfs.split('|'):
            lf_ids = [vocab.get_id(token.lower()) for token in target_expansion.split()]
            lf_tens = torch.LongTensor(lf_ids).clamp_min_(0).to(prev_args.device)

            lf_mu, lf_sigma = model._compute_priors(lf_tens)

            avg_lf_mu = lf_mu.mean(axis=0)
            avg_lf_sigma = lf_sigma.mean(axis=0)

            prior_kl = compute_kl(p_mu, p_sigma, avg_lf_mu, avg_lf_sigma).item()
            posterior_kl = compute_kl(z_mu, z_sigma, avg_lf_mu, avg_lf_sigma).item()

            min_prior_kl = min(prior_kl, min_prior_kl)
            min_posterior_kl = min(posterior_kl, min_posterior_kl)
        prior_kls.append(min_prior_kl)
        posterior_kls.append(min_posterior_kl)

    avg_prior_kl = sum(prior_kls) / float(len(prior_kls))
    avg_posterior_distances = sum(posterior_kls) / float(len(posterior_kls))
    print('Avg Prior KLD={}. Avg Posterior KLD={}'.format(avg_prior_kl, avg_posterior_distances))