Example #1
    def soft_attn_merge(self, entity_package, batch_type_ids, batch_type_emb,
                        sent_emb):
        """For each candidate, use a weighted average of the type embs based on
        type emb similarity to the contextualized mention embedding."""
        batch, M, K, num_types, dim = batch_type_emb.shape
        # we don't want to compute probabilities over padded types
        # when there are no types -- we'll just get an average over unk types (i.e. the unk type)
        mask = (batch_type_ids <
                (self.num_types_with_pad_and_unk - 1)).reshape(
                    batch * M * K, num_types)
        # Get alias tensor and expand to be for each candidate for soft attn
        alias_word_tensor = model_utils.select_alias_word_sent(
            entity_package.pos_in_sent, sent_emb, index=0)
        _, _, sent_dim = alias_word_tensor.shape
        alias_word_tensor = alias_word_tensor.unsqueeze(2).expand(
            batch, M, K, sent_dim)
        # Reshape for soft attn
        batch_type_emb = batch_type_emb.contiguous().reshape(
            batch * M * K, num_types, dim)
        alias_word_tensor = alias_word_tensor.contiguous().reshape(
            batch * M * K, sent_dim)
        # Get soft attn
        batch_type_emb = self.soft_attn(batch_type_emb,
                                        alias_word_tensor,
                                        mask=mask)
        # Convert batch back to original shape
        batch_type_emb = batch_type_emb.reshape(batch, M, K, dim)
        return batch_type_emb
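The soft_attn module itself is not shown on this page. Below is a minimal sketch of what it is assumed to do, namely masked dot-product attention of the mention vector over a candidate's type embeddings; the function name and the assumption that the query has already been projected to the type-embedding dimension are hypothetical.

import torch
import torch.nn.functional as F

def soft_attn_sketch(type_emb, query, mask):
    # type_emb: (B*M*K) x num_types x dim
    # query:    (B*M*K) x dim, assumed already projected to the type-emb dim
    # mask:     (B*M*K) x num_types, True for real (non-pad) types
    scores = torch.bmm(type_emb, query.unsqueeze(-1)).squeeze(-1)  # (B*M*K) x num_types
    scores = scores.masked_fill(~mask, -1e9)                       # ignore padded types
    weights = F.softmax(scores, dim=-1)
    # Weighted average of the type embeddings -> (B*M*K) x dim
    return torch.bmm(weights.unsqueeze(1), type_emb).squeeze(1)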
Example #2
File: layers.py Project: syyunn/bootleg
    def forward(self, sent_emb, entity_package, entity_embs):
        batch, M, K = entity_package.tensor.shape
        # Get alias tensor and expand to be for each candidate for soft attn
        alias_word_tensor = model_utils.select_alias_word_sent(
            entity_package.pos_in_sent, sent_emb, index=0)
        alias_mask = entity_package.alias_indices == -1

        # batch x M x num_types
        batch_type_pred = self.prediction(alias_word_tensor)
        batch_type_weights = self.type_softmax(batch_type_pred)
        # batch x M x emb_size
        batch_type_embs = torch.matmul(batch_type_weights,
                                       self.type_embedding.unsqueeze(0))
        # mask out unk alias embeddings
        batch_type_embs[alias_mask] = 0
        batch_type_embs = batch_type_embs.unsqueeze(2).expand(
            batch, M, K, self.emb_size)

        res = DottedDict(tensor=batch_type_embs,
                         pos_in_sent=entity_package.pos_in_sent,
                         alias_indices=entity_package.alias_indices,
                         mask=entity_package.mask,
                         normalize=True)
        entity_embs.append(res)
        return entity_embs, batch_type_pred
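The matmul above broadcasts the learned num_types x emb_size type table against the predicted per-mention type weights. A toy shape check under assumed sizes:

import torch

# Assumed toy sizes: batch = 2, M = 3 mentions, 5 types, emb_size = 8
batch, M, num_types, emb_size = 2, 3, 5, 8
type_weights = torch.softmax(torch.randn(batch, M, num_types), dim=-1)
type_embedding = torch.randn(num_types, emb_size)
# (batch x M x num_types) @ (1 x num_types x emb_size) -> batch x M x emb_size
type_embs = torch.matmul(type_weights, type_embedding.unsqueeze(0))
print(type_embs.shape)  # torch.Size([2, 3, 8])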
Example #3
    def forward(self, sent_emb, start_span_idx):
        """Model forward.

        Args:
            sent_emb: sentence embedding (B x N x L)
            start_span_idx: span index into sentence embedding (B x M)

        Returns: type embedding tensor (B x M x K x dim), type weight prediction (B x M x num_types)
        """
        batch, M = start_span_idx.shape
        alias_mask = start_span_idx == -1
        # Get alias tensor and expand to be for each candidate for soft attn
        alias_word_tensor = model_utils.select_alias_word_sent(start_span_idx, sent_emb)

        # batch x M x num_types
        batch_type_pred = self.prediction(alias_word_tensor)
        batch_type_weights = self.type_softmax(batch_type_pred)
        # batch x M x emb_size
        batch_type_embs = torch.matmul(
            batch_type_weights, self.type_embedding.unsqueeze(0)
        )
        # mask out unk alias embeddings
        batch_type_embs[alias_mask] = 0
        batch_type_embs = batch_type_embs.unsqueeze(2).expand(
            batch, M, self.K, self.emb_size
        )
        # normalize the output before being concatenated
        batch_type_embs = model_utils.normalize_matrix(batch_type_embs, dim=3)
        return batch_type_embs, batch_type_pred
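model_utils.normalize_matrix is not shown on this page; the sketch below assumes it L2-normalizes along the given dimension, which is what the "normalize the output" comment suggests.

import torch.nn.functional as F

def normalize_matrix_sketch(x, dim=3, eps=1e-12):
    # Assumption: L2-normalize the embeddings along `dim`
    return F.normalize(x, p=2, dim=dim, eps=eps)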
Example #4
    def test_select_alias_word(self):
        # batch = 3, M = 2, N = 5, H = 4
        sent_embedding = torch.randn(3, 5, 4)
        alias_pos_in_sent = torch.tensor([[2, 1], [1, -1], [0, -1]])
        res = model_utils.select_alias_word_sent(alias_pos_in_sent,
                                                 sent_embedding)
        res_gold = torch.zeros(3, 2, 4)
        res_gold[0][0] = sent_embedding[0][2]
        res_gold[0][1] = sent_embedding[0][1]
        res_gold[1][0] = sent_embedding[1][1]
        res_gold[2][0] = sent_embedding[2][0]

        assert torch.equal(res_gold, res)
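A hypothetical re-implementation of model_utils.select_alias_word_sent that is consistent with this test: gather the word vector at each alias position and return zeros for padded positions (-1). The real function also accepts an index argument for start/end span pairs in some of the examples above; that variant is omitted in this sketch.

import torch

def select_alias_word_sent_sketch(alias_pos_in_sent, sent_embedding):
    # alias_pos_in_sent: B x M, -1 marks padded aliases
    # sent_embedding:    B x N x H
    batch, M = alias_pos_in_sent.shape
    hidden = sent_embedding.shape[-1]
    alias_mask = alias_pos_in_sent == -1
    idx = alias_pos_in_sent.clamp(min=0).unsqueeze(-1).expand(batch, M, hidden)
    out = torch.gather(sent_embedding, 1, idx)             # B x M x H
    return out.masked_fill(alias_mask.unsqueeze(-1), 0.0)  # zero padded aliases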
Example #5
    def forward(self, sent_embedding, entity_embedding, entity_mask, alias_idx_pair_sent, slice_emb_alias, slice_emb_ent):
        batch_size, M, K, _ = entity_embedding.shape
        # Index is which word to select
        alias_word_tensor = model_utils.select_alias_word_sent(alias_idx_pair_sent, sent_embedding, index=0)

        # Slice emb is hidden_size x 1 -> batch x M x hidden_size
        # Add in slice alias embedding
        slice_emb_alias = slice_emb_alias.unsqueeze(0).unsqueeze(1).expand(batch_size, M, self.hidden_size)
        alias_word_tensor = alias_word_tensor + slice_emb_alias
        alias_word_tensor = alias_word_tensor.transpose(0,1)
        assert alias_word_tensor.shape[1] == batch_size

        # Sentence attn between alias vector and full sentence
        alias_word_tensor, alias_word_weights = self.sent_alias_attn(q=alias_word_tensor, x=sent_embedding.tensor.transpose(0,1),
            key_mask=sent_embedding.mask, attn_mask=None)

        # Add in slice entity embedding
        slice_emb_ent = slice_emb_ent.unsqueeze(0).unsqueeze(1).expand(batch_size, M, self.hidden_size)
        alias_word_tensor = alias_word_tensor.transpose(0,1) + slice_emb_ent
        alias_word_tensor = alias_word_tensor.transpose(0,1)
        entity_attn_tensor = entity_embedding.contiguous().view(batch_size, M*K, self.hidden_size).transpose(0,1)
        key_padding_mask_entities = entity_mask.contiguous().view(batch_size, M*K)
        # Each M alias should ONLY pay attention to its OWN candidates
        # M x (M*K)
        # TODO: move this to init
        # Must manually move this to the device as it's not part of a module
        entity_mask = torch.ones((M, K*M)).to(key_padding_mask_entities.device)
        for i in range(M):
            entity_mask[i, i*K:(i+1)*K] = 0.0
        entity_mask = entity_mask.masked_fill((entity_mask == 1), float(-1e9))
        alias_entity_attn_context, alias_entity_attn_weights = self.attention_module(
                q=alias_word_tensor,
                x=entity_attn_tensor,
                key_mask=key_padding_mask_entities,
                attn_mask=entity_mask
        )
        # Returns batch x M x hidden
        return alias_entity_attn_context.transpose(0,1), alias_word_weights
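The block mask built above is additive: each row is 0 over its mention's own K candidates and -1e9 elsewhere, so cross-mention attention weights vanish after the softmax. A toy construction with assumed sizes:

import torch

# Assumed toy sizes: M = 2 mentions, K = 3 candidates per mention
M, K = 2, 3
attn_mask = torch.full((M, M * K), -1e9)
for i in range(M):
    attn_mask[i, i * K:(i + 1) * K] = 0.0
# Row 0 is zero over columns 0..2 (its own candidates) and -1e9 elsewhere;
# adding this to the attention logits before softmax blocks cross-mention attention.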
Example #6
    def forward(
        self,
        sent_embedding,
        sent_embedding_mask,
        entity_embedding,
        entity_embedding_mask,
        start_span_idx,
        end_span_idx,
        batch_on_the_fly_data,
    ):
        """Model forward.

        Args:
            sent_embedding: sentence embedding (B x N x L)
            sent_embedding_mask: sentence embedding mask (B x N)
            entity_embedding: entity embedding (B x M x K x H)
            entity_embedding_mask: entity embedding mask (B x M x K)
            start_span_idx: start mention index into sentence (B x M)
            end_span_idx: end mention index into sentence (B x M)
            batch_on_the_fly_data: batch on the fly dictionary with values (B x (M*K) x (M*K)) of KG adjacency matrices

        Returns: Dict of Dict of intermediate output layer scores (will be empty for this model),
                 Output entity embeddings (B x M x K x H),
                 Candidate scores (B x M x K)
        """
        out = {DISAMBIG: {}}
        context_mat_dict = {}

        batch_size, M, K, emb_dim = entity_embedding.shape
        alias_start_idx_sent = start_span_idx
        alias_end_idx_sent = end_span_idx
        assert (
            emb_dim == self.hidden_size
        ), "BERT NED requires the learned entity embedding dim to be the same as the hidden size"
        assert alias_start_idx_sent.shape == alias_end_idx_sent.shape

        # Get alias words from sent embedding then cat and proj
        alias_start_word_tensor = model_utils.select_alias_word_sent(
            alias_start_idx_sent, sent_embedding)
        alias_end_word_tensor = model_utils.select_alias_word_sent(
            alias_end_idx_sent, sent_embedding)
        alias_pair_word_tensor = torch.cat(
            [alias_start_word_tensor, alias_end_word_tensor], dim=-1)
        alias_emb = (
            self.span_proj(alias_pair_word_tensor).unsqueeze(2).expand(
                batch_size, M, self.K, self.hidden_size))
        alias_emb = (alias_emb.contiguous().reshape(
            (batch_size * M * self.K), self.hidden_size).unsqueeze(1))

        # entity_embedding_mask: if I don't have 30 candidates, use a mask to fill the rest of the
        # matrix for empty candidates
        entity_embedding_zeroed = torch.where(
            entity_embedding_mask.unsqueeze(-1),
            torch.zeros_like(entity_embedding),
            entity_embedding,
        )
        entity_embedding_tensor = (
            entity_embedding_zeroed.contiguous().reshape(
                (batch_size * M * self.K), self.hidden_size).unsqueeze(-1))

        # Performs a batch-wise dot product across the dim=0 dimension
        score = (torch.bmm(alias_emb,
                           entity_embedding_tensor).unsqueeze(-1).reshape(
                               batch_size, M, self.K))
        context_mat_dict[DISAMBIG] = entity_embedding_tensor.reshape(
            batch_size, M, self.K, self.hidden_size)
        return {
            "intermed_scores": out,
            "ent_embs": context_mat_dict,
            "final_scores": score,
        }
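The bmm above is a per-candidate dot product between the projected span representation and the (masked) entity embedding. An equivalent formulation with einsum, shown as a sketch with assumed toy shapes:

import torch

# Assumed toy sizes: B = 2, M = 3 mentions, K = 4 candidates, H = 8
B, M, K, H = 2, 3, 4, 8
alias_proj = torch.randn(B, M, H)     # projected span representation
entity_emb = torch.randn(B, M, K, H)  # zeroed where the candidate is padding
score = torch.einsum("bmh,bmkh->bmk", alias_proj, entity_emb)  # B x M x K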
Example #7
    def forward(
        self,
        sent_embedding,
        sent_embedding_mask,
        entity_embedding,
        entity_embedding_mask,
        start_span_idx,
        end_span_idx,
        batch_on_the_fly_data,
    ):
        """Model forward.

        Args:
            sent_embedding: sentence embedding (B x N x L)
            sent_embedding_mask: sentence embedding mask (B x N)
            entity_embedding: entity embedding (B x M x K x H)
            entity_embedding_mask: entity embedding mask (B x M x K)
            start_span_idx: start mention index into sentence (B x M)
            end_span_idx: end mention index into sentence (B x M)
            batch_on_the_fly_data: batch on the fly dictionary with values (B x (M*K) x (M*K)) of KG adjacency matrices

        Returns: Dict of Dict of intermediate layer candidate scores (B x M x K),
                         Dict of all output entity embeddings from each KG matrix (B x M x K x H)
        """
        batch_size = sent_embedding.shape[0]
        out = {DISAMBIG: {}}

        # Create KG bias matrices for each kg bias key
        kg_bias_norms = {}
        for key in self.kg_bias_keys:
            kg_bias_norms[key] = (batch_on_the_fly_data[key].float().reshape(
                batch_size, self.M * self.K, self.M * self.K))

        # get mention embedding
        # average words in mention; batch x M x dim
        mention_tensor_start = model_utils.select_alias_word_sent(
            start_span_idx, sent_embedding)
        mention_tensor_end = model_utils.select_alias_word_sent(
            end_span_idx, sent_embedding)
        mention_tensor = (mention_tensor_start + mention_tensor_end) / 2

        # reshape for alias attention where each mention attends to its K candidates
        # query = batch*M x 1 x dim, key = value = batch*M x K x dim
        # softmax(QK^T) -> batch*M x 1 x K
        # softmax(QK^T)V -> batch*M x 1 x dim
        mention_tensor = mention_tensor.reshape(batch_size * self.M, 1,
                                                self.hidden_size).transpose(
                                                    0, 1)

        # get sentence embedding; move batch to middle
        sent_tensor = sent_embedding.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
        entity_embedding = entity_embedding.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_embedding_mask.contiguous().view(
            batch_size, self.M * self.K)
        key_padding_mask_entities_mention = entity_embedding_mask.contiguous(
        ).view(batch_size * self.M, self.K)
        # Mask of aliases; key_padding_mask_entities_mention of True means mask.
        # We want to find aliases with all masked entities
        key_padding_mask_mentions = (torch.sum(
            ~key_padding_mask_entities_mention, dim=-1) == 0)
        # Unmask these aliases to avoid nan in attention
        key_padding_mask_entities_mention[key_padding_mask_mentions] = False
        # Iterate through stages
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # As we are adding a residual in the attention modules, we can make embs empty
            embs = []
            context_mat_dict = {}
            key_tensor_mention = (query_tensor.transpose(
                0,
                1).contiguous().reshape(batch_size, self.M, self.K,
                                        self.hidden_size).reshape(
                                            batch_size * self.M, self.K,
                                            self.hidden_size).transpose(0, 1))
            # ============================================================================
            # Phrase module: compute attention between entities and words
            # ============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding_mask,
                    attn_mask=None,
                )
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            # ============================================================================
            # Co-occurrence module: compute self attention over entities
            # ============================================================================
            # Move entity mask to device
            # TODO: move to device in init?
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask,
                )
            # Mask out MxK of single aliases, alias_indices is batch x M, mask is true when single alias
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = ((
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1))
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask,
                torch.zeros_like(entity_attn_context),
                entity_attn_context,
            )

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            # ============================================================================
            # Mention module: compute attention between entities and mentions
            # ============================================================================
            # output is 1 x batch*M x dim
            (
                mention_entity_attn_context,
                mention_entity_attn_weights,
            ) = self.attention_modules[f"stage_{stage_index}_mention_entity"](
                q=mention_tensor,
                x=key_tensor_mention,
                key_mask=key_padding_mask_entities_mention,
                attn_mask=None,
            )
            # key_padding_mask_mentions mentions have all padded candidates,
            # meaning their rows in the context matrix are all nan
            mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(
                0)] = 0
            mention_entity_attn_context = (mention_entity_attn_context.expand(
                self.K, batch_size * self.M,
                self.hidden_size).transpose(0, 1).reshape(
                    batch_size, self.M * self.K,
                    self.hidden_size).transpose(0, 1))
            # Add embeddings to be merged in the output
            embs.append(mention_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights

            # Combine module output
            context_matrix_nokg = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)
            context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
            # ============================================================================
            # KG module: add in KG connectivity bias
            # ============================================================================
            for key in self.kg_bias_keys:
                context_matrix_kg = torch.bmm(
                    kg_bias_norms[key],
                    context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
                context_matrix_kg = (context_matrix_nokg +
                                     context_matrix_kg) / 2
                context_mat_dict[
                    f"context_matrix_{key}"] = context_matrix_kg.transpose(
                        0, 1).reshape(batch_size, self.M, self.K,
                                      self.hidden_size)

            if stage_index < self.num_model_stages - 1:
                score = model_utils.max_score_context_matrix(
                    context_mat_dict,
                    self.predict_layers[DISAMBIG][
                        bootleg.utils.model_utils.get_stage_head_name(
                            stage_index)],
                )
                out[DISAMBIG][
                    f"{bootleg.utils.model_utils.get_stage_head_name(stage_index)}"] = score

            # This will take the average of the context matrices that do not end in the key "_nokg";
            # if there are no kg bias terms, it will select the context_matrix_nokg
            # (as its key, in this setting, will not end in _nokg)
            query_tensor = (model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg").reshape(
                    batch_size, self.M * self.K,
                    self.hidden_size).transpose(0, 1))
        return {
            "intermed_scores": out,
            "ent_embs": context_mat_dict,
            "final_scores": None,
        }
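The KG module mixes each context matrix with a KG-weighted neighborhood average: the bmm with the row-normalized (M*K) x (M*K) adjacency pools the contexts of connected candidates, and the result is averaged with the original matrix. A toy sketch with assumed sizes (the random softmax stands in for the normalized adjacency):

import torch

# Assumed toy sizes: B = 2, M*K = 6 candidate slots, H = 8
B, MK, H = 2, 6, 8
kg_bias_norm = torch.softmax(torch.randn(B, MK, MK), dim=-1)  # stand-in for the normalized adjacency
context_nokg = torch.randn(MK, B, H)                          # (M*K) x B x H, as in the stage loop above
context_kg = torch.bmm(kg_bias_norm, context_nokg.transpose(0, 1)).transpose(0, 1)
context_kg = (context_nokg + context_kg) / 2                  # average with the no-KG context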
Example #8
    def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
                batch_prepped, batch_on_the_fly_data):
        batch_size = sent_embedding.tensor.shape[0]
        out = {DISAMBIG: {}}

        # Create KG bias matrices for each kg bias key
        kg_bias_norms = {}
        for key in self.kg_bias_keys:
            kg_bias = batch_on_the_fly_data[key].float().to(
                sent_embedding.tensor.device).reshape(batch_size,
                                                      self.M * self.K,
                                                      self.M * self.K)
            kg_bias_diag = kg_bias + self.kg_bias_weights[key] * torch.eye(
                self.M * self.K).repeat(batch_size, 1, 1).view(
                    batch_size, self.M * self.K, self.M * self.K).to(
                        kg_bias.device)
            kg_bias_norm = self.kg_softmax(
                kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))
            kg_bias_norms[key] = kg_bias_norm

        # get mention embedding
        # average words in mention; batch x M x dim
        mention_tensor_start = model_utils.select_alias_word_sent(
            alias_idx_pair_sent, sent_embedding, index=0)
        mention_tensor_end = model_utils.select_alias_word_sent(
            alias_idx_pair_sent, sent_embedding, index=1)
        mention_tensor = (mention_tensor_start + mention_tensor_end) / 2

        # reshape for alias attention where each mention attends to its K candidates
        # query = batch*M x 1 x dim, key = value = batch*M x K x dim
        # softmax(QK^T) -> batch*M x 1 x K
        # softmax(QK^T)V -> batch*M x 1 x dim
        mention_tensor = mention_tensor.reshape(batch_size * self.M, 1,
                                                self.hidden_size).transpose(
                                                    0, 1)

        # get sentence embedding; move batch to middle
        sent_tensor = sent_embedding.tensor.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
        entity_mask = entity_embedding.mask
        entity_embedding = entity_embedding.tensor.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_mask.contiguous().view(
            batch_size, self.M * self.K)
        key_padding_mask_entities_mention = entity_mask.contiguous().view(
            batch_size * self.M, self.K)
        # Mask of aliases; key_padding_mask_entities_mention of True means mask. We want to find aliases with all masked entities
        key_padding_mask_mentions = torch.sum(
            ~key_padding_mask_entities_mention, dim=-1) == 0
        # Unmask these aliases to avoid nan in attention
        key_padding_mask_entities_mention[key_padding_mask_mentions] = False
        # Iterate through stages
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # As we are adding a residual in the attention modules, we can make embs empty
            embs = []
            context_mat_dict = {}
            key_tensor_mention = query_tensor.transpose(0,1).contiguous().reshape(batch_size, self.M, self.K, self.hidden_size)\
                .reshape(batch_size*self.M, self.K, self.hidden_size).transpose(0,1)
            #============================================================================
            # Phrase module: compute attention between entities and words
            #============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding.mask,
                    attn_mask=None)
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            #============================================================================
            # Co-occurrence module: compute self attention over entities
            #============================================================================
            # Move entity mask to device
            # TODO: move to device in init?
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask)
            # Mask out MxK of single aliases, alias_indices is batch x M, mask is true when single alias
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = (
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1)
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask, torch.zeros_like(entity_attn_context),
                entity_attn_context)

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            #============================================================================
            # Mention module: compute attention between entities and mentions
            #============================================================================
            # output is 1 x batch*M x dim
            mention_entity_attn_context, mention_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_mention_entity"](
                    q=mention_tensor,
                    x=key_tensor_mention,
                    key_mask=key_padding_mask_entities_mention,
                    attn_mask=None)
            # key_padding_mask_mentions mentions have all padded candidates, meaning their rows in the context matrix are all nan
            mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(
                0)] = 0
            mention_entity_attn_context = mention_entity_attn_context.expand(self.K, batch_size*self.M, self.hidden_size).transpose(0,1).\
                reshape(batch_size, self.M*self.K, self.hidden_size).transpose(0,1)
            # Add embeddings to be merged in the output
            embs.append(mention_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights

            # Combine module output
            context_matrix_nokg = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)
            context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
            #============================================================================
            # KG module: add in KG connectivity bias
            #============================================================================
            for key in self.kg_bias_keys:
                context_matrix_kg = torch.bmm(
                    kg_bias_norms[key],
                    context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
                context_matrix_kg = (context_matrix_nokg +
                                     context_matrix_kg) / 2
                context_mat_dict[
                    f"context_matrix_{key}"] = context_matrix_kg.transpose(
                        0, 1).reshape(batch_size, self.M, self.K,
                                      self.hidden_size)

            if stage_index < self.num_model_stages - 1:
                score = model_utils.max_score_context_matrix(
                    context_mat_dict, self.predict_layers[DISAMBIG][
                        train_utils.get_stage_head_name(stage_index)])
                out[DISAMBIG][
                    f"{train_utils.get_stage_head_name(stage_index)}"] = score

            # This will take the average of the context matrices that do not end in the key "_nokg"; if there are no kg bias
            # terms, it will select the context_matrix_nokg (as its key, in this setting, will not end in _nokg)
            query_tensor = model_utils.generate_final_context_matrix(context_mat_dict, ending_key_to_exclude="_nokg")\
                .reshape(batch_size, self.M*self.K, self.hidden_size).transpose(0,1)
        return context_mat_dict, out
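model_utils.generate_final_context_matrix is not shown on this page; per the comment above, it averages the context matrices whose keys do not end in the excluded suffix (in the no-KG configuration the single no-KG key does not carry that suffix, so it is selected on its own). A hypothetical sketch of that behavior:

import torch

def generate_final_context_matrix_sketch(context_mat_dict, ending_key_to_exclude="_nokg"):
    # Average all context matrices whose key does not end in the excluded suffix
    kept = [v for k, v in context_mat_dict.items() if not k.endswith(ending_key_to_exclude)]
    return torch.stack(kept, dim=0).mean(dim=0)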