Code example #1
    def forward(self, context_ids, context_token_ct, lf_ids, lf_token_ct,
                target_lf_ids, num_outputs):
        device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch_size, num_context_ids = context_ids.size()
        _, max_output_size, max_lf_len = lf_ids.size()

        context_mask = np.ones([batch_size, num_context_ids], dtype=int)
        lf_mask = np.ones([batch_size, max_output_size, max_lf_len], dtype=int)
        for i in range(batch_size):
            context_mask[i, context_token_ct[i]:] = 0
            for lf_idx in range(max_output_size):
                lf_mask[i, lf_idx, lf_token_ct[i, lf_idx]:] = 0
        context_mask = torch.FloatTensor(context_mask).to(device_str)
        lf_mask = torch.FloatTensor(lf_mask).to(device_str)

        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs).to(device_str)
        # context_position_ids = torch.arange(0, num_context_ids).unsqueeze(0).repeat(batch_size, 1).to(device_str).long()
        encoded_context = self.encode(context_ids, context_mask,
                                      context_token_ct)

        lf_ids_flat = lf_ids.view(batch_size * max_output_size, max_lf_len)
        lf_mask_flat = lf_mask.view(batch_size * max_output_size, max_lf_len)
        lf_token_ct_flat = lf_token_ct.view(batch_size * max_output_size)
        # lf_position_ids_flat = torch.arange(0, max_lf_len).unsqueeze(0).repeat(
        #     batch_size * max_output_size, 1).to(device_str).long()
        encoded_lfs_flat = self.encode(lf_ids_flat, lf_mask_flat,
                                       lf_token_ct_flat)
        encoded_lfs = encoded_lfs_flat.view(batch_size, max_output_size, -1)

        encoded_context_tiled = encoded_context.unsqueeze(1).repeat(
            1, max_output_size, 1)
        score = nn.CosineSimilarity(dim=-1)(encoded_context_tiled, encoded_lfs)
        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids
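Every snippet on this page leans on a mask_2D helper whose source is not shown. Judging from how its output is consumed (masked_fill_ with -inf on padded candidate slots, or with 0 on padded context positions), it appears to return a boolean mask that is True at padded positions. A minimal sketch of that assumed behavior, under a hypothetical name:

import torch

def mask_2D_sketch(size, lengths):
    # Assumed behavior of the project's mask_2D helper (not its actual code):
    # entry [i, j] is True iff j >= lengths[i], i.e. True marks padding.
    batch_size, max_len = size
    lengths = torch.as_tensor(lengths)
    positions = torch.arange(max_len).unsqueeze(0).expand(batch_size, max_len)
    return positions >= lengths.unsqueeze(1)

# Rows with 3 and 1 valid candidates out of 4:
mask = mask_2D_sketch(torch.Size([2, 4]), [3, 1])
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])
scores = torch.randn(2, 4)
scores.masked_fill_(mask, float('-inf'))  # a padded candidate can never win an argmax

With this reading, masked_fill_(mask, float('-inf')) keeps argmax away from padding, and masked_fill_(mask, 0) drops padded positions from sums, which matches how the mask is used in the losses below.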
Code example #2
File: bsg_acronym_expander.py  Project: griff4692/LMC
    def encode_context(self, sf_ids, context_ids, num_contexts):
        batch_size, num_context_ids = context_ids.size()

        # Mask padded context ids
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)

        sf_mu, sf_sigma = self.encoder(sf_ids,
                                       context_ids,
                                       mask,
                                       token_mask_p=None)
        return sf_mu, sf_sigma
Code example #3
File: lmc_model.py  Project: griff4692/LMC
    def forward(self, center_ids, center_metadata_ids, context_ids, context_metadata_ids, neg_ids, neg_metadata_ids,
                num_contexts, num_metadata_samples=None):
        """
        :param center_ids: batch_size
        :param center_metadata_ids: batch_size
        :param context_ids: batch_size, 2 * context_window
        :param context_metadata_ids: batch_size, 2 * context_window, metadata_samples
        :param neg_ids: batch_size, 2 * context_window
        :param neg_metadata_ids: batch_size, 2 * context_window, metadata_samples
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :param num_metadata_samples: number of metadata samples per context position (must equal the last dimension of context_metadata_ids)
        :return: cost components: the max-margin (hinge) loss and the KL-divergence reconstruction term KL(q(z|w,d,c) || p(z|w,d))
        """
        # Mask padded context ids
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch_size, num_context_ids = context_ids.size()
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(device)

        m_samples = context_metadata_ids.size()[-1]
        assert m_samples == num_metadata_samples

        # Compute center words
        mu_center_q, sigma_center_q, _ = self.encoder(center_ids, center_metadata_ids, context_ids, mask)
        mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
        sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
        mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)
        sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)

        # Compute decoded representations of (w, d), E(c), E(n)
        mu_center, sigma_center = self.decoder(center_ids, center_metadata_ids)
        mu_pos, sigma_pos = self._compute_marginal(context_ids, context_metadata_ids)
        mu_neg, sigma_neg = self._compute_marginal(neg_ids, neg_metadata_ids)

        # Flatten positive context
        mu_pos_flat = mu_pos.view(batch_size * num_context_ids * m_samples, -1)
        sigma_pos_flat = sigma_pos.view(batch_size * num_context_ids * m_samples, -1)

        # Flatten negative context
        mu_neg_flat = mu_neg.view(batch_size * num_context_ids * m_samples, -1)
        sigma_neg_flat = sigma_neg.view(batch_size * num_context_ids * m_samples, -1)

        # Compute KL-divergence between center words and negative and reshape
        kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
        kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
        kl_pos = kl_pos_flat.view(batch_size, num_context_ids, m_samples).mean(-1)
        kl_neg = kl_neg_flat.view(batch_size, num_context_ids, m_samples).mean(-1)

        hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
        hinge_loss.masked_fill_(mask, 0)
        hinge_loss = hinge_loss.sum(1)

        recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
        return hinge_loss.mean(), recon_loss.mean()
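compute_kl is another helper that is only used, never defined, in these snippets. Its call sites (mean/sigma pairs of shape (n, d) in, something squeezable along the last axis out) are consistent with the closed-form KL divergence between two diagonal Gaussians. A sketch of that assumed contract, under a hypothetical name (the project's helper may parameterize sigma differently, e.g. as a log-variance):

import torch

def compute_kl_sketch(mu_q, sigma_q, mu_p, sigma_p, eps=1e-8):
    # Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for diagonal Gaussians,
    # summed over the feature dimension so the result has shape (n, 1).
    var_q = sigma_q.pow(2) + eps
    var_p = sigma_p.pow(2) + eps
    kl = 0.5 * (torch.log(var_p / var_q) + (var_q + (mu_q - mu_p).pow(2)) / var_p - 1.0)
    return kl.sum(-1, keepdim=True)

An (n, 1) output is what makes the .view(batch_size, num_context_ids, m_samples) reshape and the .squeeze(-1) on the reconstruction term above line up.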
Code example #4
    def forward(self, sf_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct,
                num_outputs):
        """
        :param sf_ids: batch_size
        :param context_ids: batch_size, num_context_ids
        :param lf_ids: batch_size, max_output_size, max_lf_len
        :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
        :param target_lf_ids: batch_size
        :param num_outputs: batch_size
        :return: score for each candidate LF (batch_size x max_output_size, with -inf at padded candidates), target_lf_ids
        """
        batch_size, num_context_ids = context_ids.size()
        max_output_size = lf_ids.size()[1]

        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs)

        # First thing is to pass the SF with the context to the encoder
        mask = torch.zeros(batch_size, num_context_ids, dtype=torch.bool)  # no context positions are masked
        sf_mu, sf_sigma = self.encoder(sf_ids, context_ids, mask)

        # Next is to get prior representations for each LF in lf_ids
        lf_mu, lf_sigma = self._compute_priors(lf_ids)

        # Summarize LFs
        normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
        lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(
            -2) / normalizer

        # Tile SFs across each LF and flatten both SFs and LFs
        sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(
            batch_size * max_output_size, -1)
        sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(
            1, max_output_size, 1).view(batch_size * max_output_size, -1)
        lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
        lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

        kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat,
                        lf_sigma_flat).view(batch_size, max_output_size)
        # min_kl, max_kl = kl.min(-1)[0], kl.max(-1)[0]
        # normalizer = (max_kl - min_kl).unsqueeze(-1)
        # numerator = max_kl.unsqueeze(-1).repeat(1, kl.size()[-1]) - kl
        # score = numerator  # / normalizer
        score = -kl
        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids
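The (score, target_lf_ids) pair returned here is presumably consumed by an evaluation loop elsewhere in the project; a minimal, illustrative consumer (not the project's code) would be:

import torch

def accuracy_from_scores(score, target_lf_ids):
    # score: batch_size x max_output_size with -inf at masked (invalid) candidates,
    # so argmax can only pick a valid LF for each row.
    pred_lf_ids = score.argmax(dim=-1)
    return (pred_lf_ids == target_lf_ids).float().mean().item()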
Code example #5
    def forward(self, context_ids, context_token_ct, lf_ids, lf_token_ct,
                target_lf_ids, num_outputs):
        """
        :param context_ids: batch_size, num_context_ids, 50
        :param context_token_ct: batch_size - number of non-padded context tokens (normalizer for context_ids)
        :param lf_ids: batch_size, max_output_size, max_lf_len, 50
        :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
        :param target_lf_ids: batch_size
        :param num_outputs: batch_size
        :return: LF predictions, LF ground truths
        """
        device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch_size, num_context_ids, _ = context_ids.size()
        _, max_output_size, max_lf_len, _ = lf_ids.size()

        context_mask = np.ones([batch_size, num_context_ids], dtype=int)
        lf_mask = np.ones([batch_size, max_output_size, max_lf_len], dtype=int)
        for i in range(batch_size):
            context_mask[i, context_token_ct[i]:] = 0
            for lf_idx in range(max_output_size):
                lf_mask[i, lf_idx, lf_token_ct[i, lf_idx]:] = 0
        context_mask = torch.FloatTensor(context_mask).to(device_str)
        lf_mask = torch.FloatTensor(lf_mask).to(device_str)

        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs).to(device_str)
        encoded_context = self.elmo(context_ids)
        encoded_context = (encoded_context * context_mask.unsqueeze(-1)).sum(
            axis=1) / context_token_ct.unsqueeze(-1)
        encoded_lfs = self.elmo(
            lf_ids.view(batch_size * max_output_size, max_lf_len,
                        50)).view(batch_size, max_output_size, max_lf_len, -1)

        encoded_lfs = (encoded_lfs * lf_mask.unsqueeze(-1)
                       ).sum(2) / lf_token_ct.unsqueeze(-1)
        encoded_context_tiled = encoded_context.unsqueeze(1).repeat(
            1, max_output_size, 1)
        score = nn.CosineSimilarity(dim=-1)(encoded_context_tiled, encoded_lfs)
        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids
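The ELMo variant above mean-pools token embeddings under a padding mask before comparing context and LF vectors with cosine similarity. The pooling step, pulled out into a standalone helper for reference (names are illustrative; the clamp_min guard against zero-length rows is a small addition over the snippet):

import torch

def masked_mean(embeddings, mask, token_ct):
    # embeddings: ... x seq_len x dim; mask: ... x seq_len (1 for real tokens, 0 for padding);
    # token_ct: ... (count of real tokens).  Mirrors the pooling done in the snippet above.
    summed = (embeddings * mask.unsqueeze(-1)).sum(dim=-2)
    return summed / token_ct.unsqueeze(-1).clamp_min(1.0)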
Code example #6
File: bsg_model.py  Project: griff4692/LMC
    def forward(self, token_ids, sec_ids, cat_ids, context_ids,
                neg_context_ids, num_contexts):
        """
        :param token_ids: batch_size (center word ids)
        :param sec_ids: batch_size (section ids)
        :param cat_ids: batch_size (note category ids)
        :param context_ids: batch_size, 2 * context_window
        :param neg_context_ids: batch_size, 2 * context_window
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
        """
        # Mask padded context ids
        batch_size, num_context_ids = context_ids.size()
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)

        center_ids = token_ids
        if self.multi_bsg:
            center_id_candidates = torch.cat([
                token_ids.unsqueeze(0),
                sec_ids.unsqueeze(0),
                cat_ids.unsqueeze(0)
            ])
            input_sample = torch.multinomial(self.input_weights,
                                             batch_size,
                                             replacement=True).to(self.device)
            center_ids = center_id_candidates.gather(
                0, input_sample.unsqueeze(0)).squeeze(0)

        mu_q, sigma_q = self.encoder(center_ids,
                                     context_ids,
                                     mask,
                                     token_mask_p=self.mask_p)
        mu_p, sigma_p = self._compute_priors(token_ids)

        pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
        neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

        kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
        max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p,
                                      neg_mu_p, neg_sigma_p, mask).mean()
        return kl, max_margin
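In the multi_bsg branch of code example #6, each batch element's "center" input is sampled to be either the token id, the section id, or the note category id, with probabilities given by self.input_weights. The sampling step in isolation (input_weights here is an assumed length-3 probability vector):

import torch

input_weights = torch.tensor([0.7, 0.2, 0.1])   # assumed p(token), p(section), p(category)
token_ids = torch.tensor([11, 12, 13, 14])
sec_ids = torch.tensor([21, 22, 23, 24])
cat_ids = torch.tensor([31, 32, 33, 34])

candidates = torch.stack([token_ids, sec_ids, cat_ids])                          # 3 x batch_size
choice = torch.multinomial(input_weights, token_ids.size(0), replacement=True)   # batch_size values in {0, 1, 2}
center_ids = candidates.gather(0, choice.unsqueeze(0)).squeeze(0)                # one id per batch element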
Code example #7
    def forward(self, center_ids, context_ids, neg_context_ids, num_contexts):
        """
        :param center_ids: batch_size
        :param context_ids: batch_size, 2 * context_window
        :param neg_context_ids: batch_size, 2 * context_window
        :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
        :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
        """
        # Mask padded context ids
        batch_size, num_context_ids = context_ids.size()
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)

        mu_q, sigma_q = self.encoder(center_ids, context_ids, mask)

        mu_p, sigma_p = self._compute_priors(center_ids)

        pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
        neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

        kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
        max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p,
                                      neg_mu_p, neg_sigma_p, mask).mean()
        return kl, max_margin
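The _max_margin helper called in code examples #6 and #7 is not shown on this page. A plausible shape for it, modeled on the hinge construction that is visible in code examples #3 and #10, is a margin over KL divergences between the posterior q(z|w,c) and the positive versus negative context priors. This is an assumption about the project's helper, not its actual code; it reuses compute_kl (or the compute_kl_sketch shown after code example #3):

def _max_margin_sketch(mu_q, sigma_q, pos_mu_p, pos_sigma_p,
                       neg_mu_p, neg_sigma_p, mask, margin=1.0):
    # mu_q, sigma_q: batch_size x dim (posterior for the center word)
    # *_mu_p, *_sigma_p: batch_size x window x dim (priors for context / negative words)
    batch_size, window = pos_mu_p.size(0), pos_mu_p.size(1)
    mu_q_flat = mu_q.unsqueeze(1).expand_as(pos_mu_p).reshape(batch_size * window, -1)
    sigma_q_flat = sigma_q.unsqueeze(1).expand_as(pos_sigma_p).reshape(batch_size * window, -1)
    kl_pos = compute_kl(mu_q_flat, sigma_q_flat,
                        pos_mu_p.reshape(batch_size * window, -1),
                        pos_sigma_p.reshape(batch_size * window, -1)).view(batch_size, window)
    kl_neg = compute_kl(mu_q_flat, sigma_q_flat,
                        neg_mu_p.reshape(batch_size * window, -1),
                        neg_sigma_p.reshape(batch_size * window, -1)).view(batch_size, window)
    hinge = (kl_pos - kl_neg + margin).clamp_min(0)
    hinge = hinge.masked_fill(mask.bool(), 0)   # ignore padded context positions
    return hinge.sum(1)                         # per-example margin loss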
Code example #8
File: bsg_acronym_expander.py  Project: griff4692/LMC
    def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids,
                target_lf_ids, lf_token_ct, lf_metadata_ids, num_outputs,
                num_contexts):
        """
        :param sf_ids: LongTensor of batch_size
        :param section_ids: LongTensor of batch_size
        :param category_ids: LongTensor of batch_size
        :param context_ids: LongTensor of batch_size x 2 * context_window
        :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
        :param target_lf_ids: LongTensor of batch_size indicating the index in lf_ids at which the target LF lies
        :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
        :param lf_metadata_ids: not used in this forward pass
        :param num_outputs: list representing the number of target LFs for each row in batch.
        :param num_contexts: LongTensor of batch_size.  The actual window size of the SF context.
        :return: score for each candidate LF, target_lf_ids, None (this model produces no relative weights)
        """
        batch_size, max_output_size, _ = lf_ids.size()

        # Next is to get prior representations for each LF in lf_ids
        lf_mu, lf_sigma = self._compute_priors(lf_ids)
        # Summarize LFs
        normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
        lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(
            -2) / normalizer

        combined_mu = []
        combined_sigma = []

        # Encode SFs in context
        sf_mu_tokens, sf_sigma_tokens = self.encode_context(
            sf_ids, context_ids, num_contexts)
        combined_mu.append(sf_mu_tokens)
        combined_sigma.append(sf_sigma_tokens)

        # For MBSGE ensemble method, we leverage section ids and note category ids
        if len(section_ids.nonzero()) > 0:
            sf_mu_sec, sf_sigma_sec = self.encode_context(
                section_ids, context_ids, num_contexts)
            combined_mu.append(sf_mu_sec)
            combined_sigma.append(sf_sigma_sec)

        if len(category_ids.nonzero()) > 0:
            sf_mu_cat, sf_sigma_cat = self.encode_context(
                category_ids, context_ids, num_contexts)
            combined_mu.append(sf_mu_cat)
            combined_sigma.append(sf_sigma_cat)

        combined_mu = torch.cat(list(map(lambda x: x.unsqueeze(1),
                                         combined_mu)),
                                axis=1)
        combined_sigma = torch.cat(list(
            map(lambda x: x.unsqueeze(1), combined_sigma)),
                                   axis=1)

        sf_mu = combined_mu.mean(1)
        sf_sigma = combined_sigma.mean(1)

        # Tile SFs across each LF and flatten both SFs and LFs
        sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(
            batch_size * max_output_size, -1)
        sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(
            1, max_output_size, 1).view(batch_size * max_output_size, -1)
        lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
        lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs).to(self.device)

        kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat,
                        lf_sigma_flat).view(batch_size, max_output_size)
        score = -kl
        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids, None
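A small side note on code example #8: the unsqueeze-then-torch.cat pattern used to combine the ensemble components is equivalent to torch.stack along a new dimension, which reads a little more directly:

import torch

a, b = torch.randn(4, 8), torch.randn(4, 8)
via_cat = torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1)
via_stack = torch.stack([a, b], dim=1)
assert torch.equal(via_cat, via_stack)   # both are 4 x 2 x 8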
Code example #9
    center_word_tens = torch.LongTensor([token_vocab.get_id(center_word[0])
                                         ]).to(device_str)
    header_ids = list(map(lambda x: metadata_vocab.get_id(x), headers))
    compare_ids = list(map(lambda x: token_vocab.get_id(x), compare_words))
    pad_context = torch.zeros([1]).long().to(device_str)
    context_ids = list(map(lambda x: token_vocab.get_id(x), context_words))

    compare_tens = torch.LongTensor(compare_ids).to(device_str)
    context_tens = torch.LongTensor(context_ids).unsqueeze(0).to(device_str)
    mu_compare, sigma_compare = model.decoder(
        compare_tens, pad_context.repeat(len(compare_words)))

    mask = mask_2D(torch.Size([1, len(context_ids)]),
                   [len(context_ids)]).to(device_str)
    print('Interpolation between word={} and headers={}'.format(
        center_word, ', '.join(headers)))
    for i, header_id in enumerate(header_ids):
        header_tens = torch.LongTensor([header_id]).to(device_str)
        print(headers[i])
        for p in np.arange(0, 1.25, 0.25):
            rw = [p, 1.0 - p]
            rel_weights = torch.FloatTensor([rw]).to(device_str)
            with torch.no_grad():
                mu_q, sigma_q, weights = model.encoder(center_word_tens,
                                                       header_tens,
                                                       context_tens,
                                                       mask,
                                                       center_mask_p=None,
                                                       context_mask_p=None,
Code example #10
File: lmc_model.py  Project: griff4692/LMC
    def forward(self, context_ids, center_ids, pos_ids, neg_ids, context_token_type_ids,
                num_contexts, context_mask, center_mask, pos_mask, neg_mask, num_metadata_samples=None):
        """
        :param context_ids: batch_size x context_len (wp ids for BERT-style context word sequence with center word
        and metadata special token)
        :param center_ids: batch_size x num_context_ids (wp ids for center words and its metadata)
        :param pos_ids: batch_size x window_size * 2 x num_context_ids (wp ids for context words)
        :param neg_ids: batch_size x window_size * 2 x num_context_ids (wp ids for negatively sampled words
        and MC metadata sample)
        :param context_token_type_ids: batch_size x context_len (demarcating tokens from metadata in context_ids
        and MC metadata sample)
        :param num_contexts: batch_size (number of non-padded pos_ids / neg_ids in each row in batch)
        :param context_mask: batch_size x context_len (BERT attention mask for context_ids)
        :param center_mask: batch_size x num_context_ids
        :param pos_mask: batch_size x window_size * 2 x num_context_ids
        :param neg_mask: batch_size x window_size * 2 x num_context_ids
        :param num_metadata_samples:  Number of MC samples approximating marginal distribution over metadata
        Affects how the BERT-style sequence is segmented into distinct token_type_ids segments
        :return: a tuple of 1-size tensors representing the hinge likelihood loss and the reconstruction kl-loss
        """
        batch_size, num_context_ids, max_decoder_len = pos_ids.size()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(device)

        # Compute center words
        mu_center_q, sigma_center_q, _ = self.encoder(
            input_ids=context_ids, attention_mask=context_mask, token_type_ids=context_token_type_ids)
        mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
        sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
        mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids, -1)
        sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids, -1)

        # Compute decoded representations of (w, d), E(c), E(n)
        n = batch_size * num_context_ids
        pos_ids_flat, pos_mask_flat = pos_ids.view(n, -1), pos_mask.view(n, -1)
        neg_ids_flat, neg_mask_flat = neg_ids.view(n, -1), neg_mask.view(n, -1)

        joint_ids = torch.cat([center_ids, pos_ids_flat, neg_ids_flat], axis=0)
        joint_mask = torch.cat([center_mask, pos_mask_flat, neg_mask_flat], axis=0)

        center_sep_idx = 1 + 2
        other_sep_idx = num_metadata_samples + 2
        decoder_type_ids = torch.zeros([joint_ids.size()[0], max_decoder_len]).long().to(device)
        decoder_type_ids[:batch_size, -center_sep_idx:] = 1
        decoder_type_ids[batch_size:, -other_sep_idx:] = 1

        mu_joint, sigma_joint = self.decoder(
            input_ids=joint_ids, attention_mask=joint_mask, token_type_ids=decoder_type_ids)

        mu_center, sigma_center = mu_joint[:batch_size], sigma_joint[:batch_size]
        s = batch_size * (num_context_ids + 1)
        mu_pos_flat, sigma_pos_flat = mu_joint[batch_size:s], sigma_joint[batch_size:s]
        mu_neg_flat, sigma_neg_flat = mu_joint[s:], sigma_joint[s:]

        # Compute KL-divergence between center words and negative and reshape
        kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
        kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
        kl_pos = kl_pos_flat.view(batch_size, num_context_ids)
        kl_neg = kl_neg_flat.view(batch_size, num_context_ids)

        hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
        hinge_loss.masked_fill_(mask, 0)
        hinge_loss = hinge_loss.sum(1)

        recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
        return hinge_loss.mean(), recon_loss.mean()
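The token_type_ids construction in code example #10 is easier to see on toy shapes: the last few positions of each decoder row are marked as segment 1, with a longer segment for the flattened context/negative rows than for the center rows (the interpretation of the +2 offsets is an assumption based on the docstring's mention of metadata samples and special tokens):

import torch

batch_size, max_decoder_len, num_metadata_samples = 2, 8, 3
joint_rows = batch_size + 4                  # center rows followed by flattened pos/neg rows
center_sep_idx = 1 + 2
other_sep_idx = num_metadata_samples + 2

decoder_type_ids = torch.zeros([joint_rows, max_decoder_len]).long()
decoder_type_ids[:batch_size, -center_sep_idx:] = 1    # last 3 positions of the center rows
decoder_type_ids[batch_size:, -other_sep_idx:] = 1     # last 5 positions of the other rows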
Code example #11
    def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids,
                target_lf_ids, lf_token_ct, lf_metadata_ids, lf_metadata_p,
                num_outputs, num_contexts):
        """
        :param sf_ids: LongTensor of batch_size
        :param section_ids: LongTensor of batch_size
        :param category_ids: LongTensor of batch_size
        :param context_ids: LongTensor of batch_size x 2 * context_window
        :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
        :param target_lf_ids: LongTensor of batch_size indicating the index in lf_ids at which the target LF lies
        :param lf_token_ct: LongTensor of batch_size x max_output_size.  N-gram count for each LF (used for masking)
        :param lf_metadata_ids: batch_size x max_output_size x max_num_metadata. Ids for every metadata LF appears in
        :param lf_metadata_p: batch_size x max_output_size x max_num_metadata
        Empirical probability for lf_metadata_ids ~ p(metadata|LF)
        :param num_outputs: list representing the number of target LFs for each row in batch.
        Used for masking to avoid returning invalid predictions.
        :param num_contexts: LongTensor of batch_size.  The actual window size of the SF context.
        Many are shorter than target of 2 * context_window.
        :return: scores for each candidate LF, target_lf_ids, rel_weights (output of encoder gating function)
        """
        batch_size, max_output_size, max_lf_ngram = lf_ids.size()
        _, num_context_ids = context_ids.size()

        # Compute SF contexts
        # Mask padded context ids
        mask_size = torch.Size([batch_size, num_context_ids])
        mask = mask_2D(mask_size, num_contexts).to(self.device)
        sf_mu, sf_sigma, rel_weights = self.encoder(sf_ids,
                                                    section_ids,
                                                    context_ids,
                                                    mask,
                                                    center_mask_p=None,
                                                    context_mask_p=None)

        num_metadata = lf_metadata_ids.size()[-1]

        # Tile SFs across each LF and flatten both SFs and LFs
        sf_mu_flat = sf_mu.unsqueeze(1).repeat(
            1, max_output_size * num_metadata,
            1).view(batch_size * max_output_size * num_metadata, -1)
        sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(
            1, max_output_size * num_metadata,
            1).view(batch_size * max_output_size * num_metadata, -1)

        # Compute E[LF]
        normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
        lf_mu, lf_sigma = self._compute_marginal(lf_ids,
                                                 lf_metadata_ids,
                                                 normalizer=normalizer)
        lf_mu_flat = lf_mu.view(batch_size * max_output_size * num_metadata,
                                -1)
        lf_sigma_flat = lf_sigma.view(
            batch_size * max_output_size * num_metadata, 1)
        output_dim = torch.Size([batch_size, max_output_size])
        output_mask = mask_2D(output_dim, num_outputs).to(self.device)

        kl_marginal = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat,
                                 lf_sigma_flat).view(batch_size,
                                                     max_output_size,
                                                     num_metadata)
        kl = (kl_marginal * lf_metadata_p).sum(2)
        score = -kl

        score.masked_fill_(output_mask, float('-inf'))
        return score, target_lf_ids, rel_weights
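The scoring step at the end of code example #11 takes an expectation of the per-metadata KL under the empirical distribution p(metadata | LF) before negating it into a score. A tiny numeric check of that step:

import torch

# kl_marginal: batch_size x max_output_size x num_metadata
kl_marginal = torch.tensor([[[2.0, 4.0],
                             [1.0, 3.0]]])
# lf_metadata_p: empirical p(metadata | LF); each row sums to 1
lf_metadata_p = torch.tensor([[[0.75, 0.25],
                               [0.50, 0.50]]])
expected_kl = (kl_marginal * lf_metadata_p).sum(2)   # tensor([[2.5000, 2.0000]])
score = -expected_kl                                 # higher score = lower expected divergence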