Example #1
    def forward(self, out_dict, context_matrix_dict, final_score=None):
        """Model forward. Must pass the self.training bool forward to loss
        function.

        Args:
            out_dict: Dict of intermediate scores (B x M x K)
            context_matrix_dict: Dict of output embedding matrices, e.g., from KG modules (B x M x K x H)
            final_score: Final output scores (B x M x K) (default None)

        Returns: Dict of Dict with final scores added (B x M x K), final output embedding (B x M x K x H),
                tensor of is_training Bool
        """
        score = model_utils.max_score_context_matrix(
            context_matrix_dict, self.prediction_head
        )
        out_dict[DISAMBIG][FINAL_LOSS] = score
        if "context_matrix_main" not in context_matrix_dict:
            context_matrix_dict[
                "context_matrix_main"
            ] = model_utils.generate_final_context_matrix(
                context_matrix_dict, ending_key_to_exclude="_nokg"
            )
        final_entity_embs = context_matrix_dict["context_matrix_main"]
        # Convert the self.training bool into a tensor so it can be gathered with the loss outputs under DataParallel (DP)
        return {
            "final_scores": out_dict,
            "ent_embs": final_entity_embs,
            "training": (
                torch.tensor([1], device=final_entity_embs.device) * self.training
            ).bool(),
        }
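
The `training` entry above exists because a plain Python bool cannot be gathered across DataParallel replicas, so it is promoted to a one-element tensor on the same device as the outputs. A minimal standalone sketch of that trick, assuming only PyTorch (all names below are illustrative):

import torch

def as_gatherable_flag(flag: bool, device: torch.device) -> torch.Tensor:
    # Promote a Python bool to a one-element bool tensor on the given device
    # so nn.DataParallel can gather it alongside the other model outputs.
    return (torch.tensor([1], device=device) * flag).bool()

embs = torch.randn(2, 3, 4)                    # stand-in for final_entity_embs
print(as_gatherable_flag(True, embs.device))   # tensor([True])
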
Example #2
    def forward(self, context_matrix_dict, alias_idx_pair_sent, entity_pack,
                sent_emb):
        out = {DISAMBIG: {}}
        score = model_utils.max_score_context_matrix(context_matrix_dict,
                                                     self.prediction_head)
        out[DISAMBIG][FINAL_LOSS] = score
        if "context_matrix_main" not in context_matrix_dict:
            context_matrix_dict[
                "context_matrix_main"] = model_utils.generate_final_context_matrix(
                    context_matrix_dict, ending_key_to_exclude="_nokg")
        return out, context_matrix_dict["context_matrix_main"]
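
Both Example #1 and Example #2 lazily build "context_matrix_main" from the per-KG context matrices. Per the comments in the later examples, model_utils.generate_final_context_matrix averages the matrices whose keys do not end in "_nokg". A hypothetical stand-in with that behavior (not the library implementation; shapes and key names are illustrative):

import torch

def average_context_matrices(context_matrix_dict, ending_key_to_exclude="_nokg"):
    # Average all context matrices whose key does not end with the excluded suffix.
    kept = [v for k, v in context_matrix_dict.items()
            if not k.endswith(ending_key_to_exclude)]
    return torch.stack(kept, dim=0).mean(0)

mats = {
    "context_matrix_nokg": torch.randn(2, 3, 5, 8),     # excluded by suffix
    "context_matrix_kg_adj": torch.randn(2, 3, 5, 8),   # hypothetical KG bias key
    "context_matrix_kg_rel": torch.randn(2, 3, 5, 8),   # hypothetical KG bias key
}
print(average_context_matrices(mats).shape)  # torch.Size([2, 3, 5, 8])
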
Example #3
    def forward(
        self,
        sent_embedding,
        sent_embedding_mask,
        entity_embedding,
        entity_embedding_mask,
        start_span_idx,
        end_span_idx,
        batch_on_the_fly_data,
    ):
        """Model forward.

        Args:
            sent_embedding: sentence embedding (B x N x L)
            sent_embedding_mask: sentence embedding mask (B x N)
            entity_embedding: entity embedding (B x M x K x H)
            entity_embedding_mask: entity embedding mask (B x M x K)
            start_span_idx: start mention index into sentence (B x M)
            end_span_idx: end mention index into sentence (B x M)
            batch_on_the_fly_data: batch on the fly dictionary with values (B x (M*K) x (M*K)) of KG adjacency matrices

        Returns: Dict of Dict of intermediate layer candidate scores (B x M x K),
                         Dict of all output entity embeddings from each KG matrix (B x M x K x H)
        """
        batch_size = sent_embedding.shape[0]
        out = {DISAMBIG: {}}

        # Create KG bias matrices for each kg bias key
        kg_bias_norms = {}
        for key in self.kg_bias_keys:
            kg_bias_norms[key] = (batch_on_the_fly_data[key].float().reshape(
                batch_size, self.M * self.K, self.M * self.K))

        # get mention embedding
        # average words in mention; batch x M x dim
        mention_tensor_start = model_utils.select_alias_word_sent(
            start_span_idx, sent_embedding)
        mention_tensor_end = model_utils.select_alias_word_sent(
            end_span_idx, sent_embedding)
        mention_tensor = (mention_tensor_start + mention_tensor_end) / 2

        # reshape for alias attention where each mention attends to its K candidates
        # query = batch*M x 1 x dim, key = value = batch*M x K x dim
        # softmax(QK^T) -> batch*M x 1 x K
        # softmax(QK^T)V -> batch*M x 1 x dim
        mention_tensor = mention_tensor.reshape(batch_size * self.M, 1,
                                                self.hidden_size).transpose(
                                                    0, 1)

        # get sentence embedding; move batch to middle
        sent_tensor = sent_embedding.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
        entity_embedding = entity_embedding.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_embedding_mask.contiguous().view(
            batch_size, self.M * self.K)
        key_padding_mask_entities_mention = entity_embedding_mask.contiguous(
        ).view(batch_size * self.M, self.K)
        # Mask of aliases; True in key_padding_mask_entities_mention means masked.
        # We want to find aliases whose candidate entities are all masked
        key_padding_mask_mentions = (torch.sum(
            ~key_padding_mask_entities_mention, dim=-1) == 0)
        # Unmask these aliases to avoid nan in attention
        key_padding_mask_entities_mention[key_padding_mask_mentions] = False
        # Iterate through stages
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # As we are adding a residual in the attention modules, we can make embs empty
            embs = []
            context_mat_dict = {}
            key_tensor_mention = (query_tensor.transpose(
                0,
                1).contiguous().reshape(batch_size, self.M, self.K,
                                        self.hidden_size).reshape(
                                            batch_size * self.M, self.K,
                                            self.hidden_size).transpose(0, 1))
            # ============================================================================
            # Phrase module: compute attention between entities and words
            # ============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding_mask,
                    attn_mask=None,
                )
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            # ============================================================================
            # Co-occurrence module: compute self attention over entities
            # ============================================================================
            # Move entity mask to device
            # TODO: move to device in init?
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask,
                )
            # Zero out the M x K entity self-attention output when a sentence has a single non-null alias; non_null_aliases is batch x M and the post-mask is True in that case
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = ((
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1))
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask,
                torch.zeros_like(entity_attn_context),
                entity_attn_context,
            )

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            # ============================================================================
            # Mention module: compute attention between entities and mentions
            # ============================================================================
            # output is 1 x batch*M x dim
            (
                mention_entity_attn_context,
                mention_entity_attn_weights,
            ) = self.attention_modules[f"stage_{stage_index}_mention_entity"](
                q=mention_tensor,
                x=key_tensor_mention,
                key_mask=key_padding_mask_entities_mention,
                attn_mask=None,
            )
            # Mentions flagged in key_padding_mask_mentions have only padded candidates,
            # so their rows in the context matrix would be all NaN; zero them out
            mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(
                0)] = 0
            mention_entity_attn_context = (mention_entity_attn_context.expand(
                self.K, batch_size * self.M,
                self.hidden_size).transpose(0, 1).reshape(
                    batch_size, self.M * self.K,
                    self.hidden_size).transpose(0, 1))
            # Add embeddings to be merged in the output
            embs.append(mention_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights

            # Combine module output
            context_matrix_nokg = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)
            context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
            # ============================================================================
            # KG module: add in KG connectivity bias
            # ============================================================================
            for key in self.kg_bias_keys:
                context_matrix_kg = torch.bmm(
                    kg_bias_norms[key],
                    context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
                context_matrix_kg = (context_matrix_nokg +
                                     context_matrix_kg) / 2
                context_mat_dict[
                    f"context_matrix_{key}"] = context_matrix_kg.transpose(
                        0, 1).reshape(batch_size, self.M, self.K,
                                      self.hidden_size)

            if stage_index < self.num_model_stages - 1:
                score = model_utils.max_score_context_matrix(
                    context_mat_dict,
                    self.predict_layers[DISAMBIG][
                        bootleg.utils.model_utils.get_stage_head_name(
                            stage_index)],
                )
                out[DISAMBIG][
                    f"{bootleg.utils.model_utils.get_stage_head_name(stage_index)}"] = score

            # This will take the average of the context matrices whose keys do not end in "_nokg";
            # if there are no kg bias terms, it will select the context_matrix_nokg
            # (as its key, in that setting, will not end in _nokg)
            query_tensor = (model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg").reshape(
                    batch_size, self.M * self.K,
                    self.hidden_size).transpose(0, 1))
        return {
            "intermed_scores": out,
            "ent_embs": context_mat_dict,
            "final_scores": None,
        }
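
The mention-module comments in Example #3 describe one attention step in which each mention attends over its K candidates (query = batch*M x 1 x dim, key = value = batch*M x K x dim). A shape-only sketch of that pattern with plain tensor ops; the model itself keeps tensors seq-first and uses its own attention modules, and all sizes below are illustrative:

import torch
import torch.nn.functional as F

B, M, K, H = 2, 3, 5, 8
mention = torch.randn(B, M, H)               # averaged start/end word embedding per mention
candidates = torch.randn(B, M, K, H)         # K candidate entity embeddings per mention

q = mention.reshape(B * M, 1, H)             # batch*M x 1 x H
kv = candidates.reshape(B * M, K, H)         # batch*M x K x H

attn = F.softmax(q @ kv.transpose(1, 2) / H ** 0.5, dim=-1)   # batch*M x 1 x K
context = attn @ kv                          # batch*M x 1 x H
print(context.shape)                         # torch.Size([6, 1, 8])
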
Example #4
    def forward(
        self,
        sent_embedding,
        sent_embedding_mask,
        entity_embedding,
        entity_embedding_mask,
        start_span_idx,
        end_span_idx,
        batch_on_the_fly_data,
    ):
        """Model forward.

        Args:
            sent_embedding: sentence embedding (B x N x L)
            sent_embedding_mask: sentence embedding mask (B x N)
            entity_embedding: entity embedding (B x M x K x H)
            entity_embedding_mask: entity embedding mask (B x M x K)
            start_span_idx: start mention index into sentence (B x M)
            end_span_idx: end mention index into sentence (B x M)
            batch_on_the_fly_data: batch on the fly dictionary with values (B x (M*K) x (M*K)) of KG adjacency matrices

        Returns: Dict of Dict of intermediate layer candidate scores (B x M x K),
                         Dict of all output entity embeddings from each KG matrix (B x M x K x H)
        """
        batch_size = sent_embedding.shape[0]
        out = {DISAMBIG: {}}

        # Create KG bias matrices for each kg bias key
        kg_bias_norms = {}
        for key in self.kg_bias_keys:
            bias_weight = getattr(self, key)  # self.kg_bias_weights[key]
            kg_bias = (batch_on_the_fly_data[key].float().to(
                sent_embedding.device).reshape(batch_size, self.M * self.K,
                                               self.M * self.K))
            kg_bias_diag = kg_bias + bias_weight * torch.eye(
                self.M * self.K).repeat(batch_size, 1, 1).view(
                    batch_size, self.M * self.K, self.M * self.K).to(
                        kg_bias.device)
            kg_bias_norm = self.kg_softmax(
                kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))
            kg_bias_norms[key] = kg_bias_norm
        sent_tensor = sent_embedding.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
        entity_embedding = entity_embedding.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_embedding_mask.contiguous().view(
            batch_size, self.M * self.K)

        # Iterate through stages
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # As we are adding a residual in the attention modules, we can make embs empty
            embs = []
            context_mat_dict = {}
            # ============================================================================
            # Phrase module: compute attention between entities and words
            # ============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding_mask,
                    attn_mask=None,
                )
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            # ============================================================================
            # Co-occurrence module: compute self attention over entities
            # ============================================================================
            # Move entity mask to device
            # TODO: move to device in init?
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask,
                )
            # Zero out the M x K entity self-attention output when a sentence has a single non-null alias; non_null_aliases is batch x M and the post-mask is True in that case
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = ((
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1))
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask,
                torch.zeros_like(entity_attn_context),
                entity_attn_context,
            )

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            # Combine module output
            context_matrix_nokg = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)
            context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
            # ============================================================================
            # KG module: add in KG connectivity bias
            # ============================================================================
            for key in self.kg_bias_keys:
                context_matrix_kg = torch.bmm(
                    kg_bias_norms[key],
                    context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
                context_matrix_kg = (context_matrix_nokg +
                                     context_matrix_kg) / 2
                context_mat_dict[
                    f"context_matrix_{key}"] = context_matrix_kg.transpose(
                        0, 1).reshape(batch_size, self.M, self.K,
                                      self.hidden_size)

            if stage_index < self.num_model_stages - 1:
                score = model_utils.max_score_context_matrix(
                    context_mat_dict,
                    self.predict_layers[DISAMBIG][
                        bootleg.utils.model_utils.get_stage_head_name(
                            stage_index)],
                )
                out[DISAMBIG][
                    f"{bootleg.utils.model_utils.get_stage_head_name(stage_index)}"] = score

            # This will take the average of the context matrices whose keys do not end in "_nokg";
            # if there are no kg bias terms, it will select the context_matrix_nokg
            # (as its key, in that setting, will not end in _nokg)
            query_tensor = (model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg").reshape(
                    batch_size, self.M * self.K,
                    self.hidden_size).transpose(0, 1))
        return {
            "intermed_scores": out,
            "ent_embs": context_mat_dict,
            "final_scores": None,
        }
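
The KG module above turns each (M*K) x (M*K) adjacency matrix into a row-normalized bias: a learned self-loop weight is added on the diagonal, zero entries (no KG connection) are pushed to -1e9 before the softmax, and the normalized bias mixes candidate representations via a batched matmul before averaging with the original matrix. A standalone batch-first sketch of that computation (the model itself keeps the context matrix seq-first and transposes around the bmm; sizes and names below are illustrative):

import torch

B, MK, H = 2, 6, 8
adj = (torch.rand(B, MK, MK) > 0.7).float()   # stand-in for batch_on_the_fly_data[key]
context = torch.randn(B, MK, H)               # stand-in for context_matrix_nokg, batch-first
self_loop_weight = torch.tensor(2.0)          # stand-in for the learned KG bias weight

bias_diag = adj + self_loop_weight * torch.eye(MK).expand(B, MK, MK)
bias_norm = torch.softmax(bias_diag.masked_fill(bias_diag == 0, float(-1e9)), dim=-1)

# Mix candidates through their KG neighbors, then average with the unbiased matrix
context_kg = 0.5 * (context + torch.bmm(bias_norm, context))
print(context_kg.shape)                       # torch.Size([2, 6, 8])
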
Example #5
    def forward(self, context_matrix_dict, alias_idx_pair_sent, entity_pack,
                sent_emb):
        out = {DISAMBIG: {}, INDICATOR: {}}
        indicator_outputs = {}
        expert_slice_repr = {}
        predictor_outputs = {}

        if "context_matrix_main" not in context_matrix_dict:
            context_matrix_dict[
                "context_matrix_main"] = model_utils.generate_final_context_matrix(
                    context_matrix_dict)

        context_matrix = context_matrix_dict["context_matrix_main"]

        batch_size, M, K, H = context_matrix.shape
        assert M == self.M
        assert K == self.K
        assert H == self.hidden_size
        for i, slice_head in enumerate(self.train_heads):
            # Generate slice expert representation per head
            # context_matrix is batch x M x K x H
            expert_slice_repr[slice_head] = self.transform_modules[slice_head](
                context_matrix)
            # Pass the expert slice representation through the shared prediction layer
            # Predictor_outputs is batch x M x K
            predictor_outputs[slice_head] = self.shared_slice_pred_head(
                expert_slice_repr[slice_head]).squeeze(-1)
            # Get an alias_matrix output (batch x M x H)
            # TODO: remove extra inputs
            alias_matrix, alias_word_weights = self.ind_alias_mha[slice_head](
                sent_embedding=sent_emb,
                entity_embedding=context_matrix,
                entity_mask=entity_pack.mask,
                alias_idx_pair_sent=alias_idx_pair_sent,
                slice_emb_alias=self.slice_emb_ind_alias(
                    torch.tensor(i, device=context_matrix.device)),
                slice_emb_ent=self.slice_emb_ind_ent(
                    torch.tensor(i, device=context_matrix.device)))
            # Get indicator head outputs; size batch x M x 2 per head
            indicator_outputs[slice_head] = self.indicator_heads[slice_head](
                alias_matrix)

        # Generate predictions via log-softmax + taking the "positive" class label
        # Output size is batch x M x num_slices
        indicator_preds = torch.cat(
            [
                F.log_softmax(indicator_outputs[slice_head],
                              dim=-1)[:, :, 1].unsqueeze(-1)
                for slice_head in self.train_heads
            ],
            dim=-1,
        )
        assert not torch.isnan(indicator_preds).any()
        assert not torch.isinf(indicator_preds).any()
        # Compute the "confidence"
        # Output size should be batch x M x K x num_slices
        predictor_confidences = torch.cat(
            [
                eval_utils.masked_class_logsoftmax(
                    pred=predictor_outputs[slice_head],
                    mask=~entity_pack.mask).unsqueeze(-1)
                for slice_head in self.train_heads
            ],
            dim=-1,
        )
        assert not torch.isnan(predictor_confidences).any()
        assert not torch.isinf(predictor_confidences).any()
        if self.use_ind_attn:
            prod = indicator_preds  # * margin_confidence
            prod[prod < 0.1] = -100000.0 * self.temperature
            attention_weights = F.softmax(prod / self.temperature, dim=-1)
        else:
            # Take margin confidence over K values to generate confidences of batch x M x num_slices
            vals, indices = torch.topk(predictor_confidences, k=2, dim=2)
            margin_confidence = (vals[:, :, 0, :] -
                                 vals[:, :, 1, :]) / vals.sum(2)
            assert list(margin_confidence.shape) == [
                batch_size, self.M, len(self.train_heads)
            ]
            attention_weights = F.softmax(
                (indicator_preds + margin_confidence) / self.temperature,
                dim=-1)

        assert not torch.isnan(attention_weights).any()
        assert not torch.isinf(attention_weights).any()

        # attention_weights is batch x M x num_slices
        # slice_representations is batch_size x M x K x num_slices x feat_dim
        slice_representations = torch.stack(
            [expert_slice_repr[slice_head] for slice_head in self.train_heads],
            dim=3)

        # attention_weights is expanded to batch_size x M x K x num_slices x H to match slice_representations
        attention_weights = attention_weights.unsqueeze(2).unsqueeze(
            -1).expand_as(slice_representations)
        # Reweight representations with weighted sum across slices
        reweighted_rep = torch.sum(attention_weights * slice_representations,
                                   dim=3)
        assert reweighted_rep.shape == context_matrix.shape
        # Pass through the final prediction layer
        for slice_head in self.train_heads:
            out[DISAMBIG][train_utils.get_slice_head_pred_name(
                slice_head)] = predictor_outputs[slice_head]
        for slice_head in self.train_heads:
            out[INDICATOR][train_utils.get_slice_head_ind_name(
                slice_head)] = indicator_outputs[slice_head]
        # Used for debugging
        if self.remove_final_loss:
            out[DISAMBIG][FINAL_LOSS] = out[DISAMBIG][
                train_utils.get_slice_head_pred_name(BASE_SLICE)]
        else:
            out[DISAMBIG][FINAL_LOSS] = self.final_pred_head(
                reweighted_rep).squeeze(-1)
        return out, reweighted_rep
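
When use_ind_attn is off, Example #5 forms the per-mention slice attention from a margin confidence: the gap between the two best candidate scores, normalized by their sum, added to the indicator predictions and pushed through a temperature softmax. A small sketch of that reduction (shapes and names below are illustrative):

import torch
import torch.nn.functional as F

B, M, K, S = 2, 3, 5, 4                          # S = number of slice heads
predictor_confidences = torch.rand(B, M, K, S)   # stand-in for the per-candidate confidences
indicator_preds = torch.rand(B, M, S)            # stand-in for indicator log-softmax outputs
temperature = 1.0

vals, _ = torch.topk(predictor_confidences, k=2, dim=2)     # B x M x 2 x S
margin_confidence = (vals[:, :, 0, :] - vals[:, :, 1, :]) / vals.sum(2)

attention_weights = F.softmax(
    (indicator_preds + margin_confidence) / temperature, dim=-1)
print(attention_weights.shape)                   # torch.Size([2, 3, 4])
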
Example #6
    def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
                batch_prepped, batch_on_the_fly_data):
        batch_size = sent_embedding.tensor.shape[0]
        out = {DISAMBIG: {}}

        # Create KG bias matrices for each kg bias key
        kg_bias_norms = {}
        for key in self.kg_bias_keys:
            kg_bias = batch_on_the_fly_data[key].float().to(
                sent_embedding.tensor.device).reshape(batch_size,
                                                      self.M * self.K,
                                                      self.M * self.K)
            kg_bias_diag = kg_bias + self.kg_bias_weights[key] * torch.eye(
                self.M * self.K).repeat(batch_size, 1, 1).view(
                    batch_size, self.M * self.K, self.M * self.K).to(
                        kg_bias.device)
            kg_bias_norm = self.kg_softmax(
                kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))
            kg_bias_norms[key] = kg_bias_norm

        # get mention embedding
        # average words in mention; batch x M x dim
        mention_tensor_start = model_utils.select_alias_word_sent(
            alias_idx_pair_sent, sent_embedding, index=0)
        mention_tensor_end = model_utils.select_alias_word_sent(
            alias_idx_pair_sent, sent_embedding, index=1)
        mention_tensor = (mention_tensor_start + mention_tensor_end) / 2

        # reshape for alias attention where each mention attends to its K candidates
        # query = batch*M x 1 x dim, key = value = batch*M x K x dim
        # softmax(QK^T) -> batch*M x 1 x K
        # softmax(QK^T)V -> batch*M x 1 x dim
        mention_tensor = mention_tensor.reshape(batch_size * self.M, 1,
                                                self.hidden_size).transpose(
                                                    0, 1)

        # get sentence embedding; move batch to middle
        sent_tensor = sent_embedding.tensor.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
        entity_mask = entity_embedding.mask
        entity_embedding = entity_embedding.tensor.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_mask.contiguous().view(
            batch_size, self.M * self.K)
        key_padding_mask_entities_mention = entity_mask.contiguous().view(
            batch_size * self.M, self.K)
        # Mask of aliases; True in key_padding_mask_entities_mention means masked. We want to find aliases whose candidate entities are all masked
        key_padding_mask_mentions = torch.sum(
            ~key_padding_mask_entities_mention, dim=-1) == 0
        # Unmask these aliases to avoid nan in attention
        key_padding_mask_entities_mention[key_padding_mask_mentions] = False
        # Iterate through stages
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # As we are adding a residual in the attention modules, we can make embs empty
            embs = []
            context_mat_dict = {}
            key_tensor_mention = query_tensor.transpose(0,1).contiguous().reshape(batch_size, self.M, self.K, self.hidden_size)\
                .reshape(batch_size*self.M, self.K, self.hidden_size).transpose(0,1)
            #============================================================================
            # Phrase module: compute attention between entities and words
            #============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding.mask,
                    attn_mask=None)
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            #============================================================================
            # Co-occurrence module: compute self attention over entities
            #============================================================================
            # Move entity mask to device
            # TODO: move to device in init?
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask)
            # Zero out the M x K entity self-attention output when a sentence has a single non-null alias; non_null_aliases is batch x M and the post-mask is True in that case
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = (
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1)
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask, torch.zeros_like(entity_attn_context),
                entity_attn_context)

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            #============================================================================
            # Mention module: compute attention between entities and mentions
            #============================================================================
            # output is 1 x batch*M x dim
            mention_entity_attn_context, mention_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_mention_entity"](
                    q=mention_tensor,
                    x=key_tensor_mention,
                    key_mask=key_padding_mask_entities_mention,
                    attn_mask=None)
            # Mentions flagged in key_padding_mask_mentions have only padded candidates, so their rows in the context matrix would be all NaN; zero them out
            mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(
                0)] = 0
            mention_entity_attn_context = mention_entity_attn_context.expand(self.K, batch_size*self.M, self.hidden_size).transpose(0,1).\
                reshape(batch_size, self.M*self.K, self.hidden_size).transpose(0,1)
            # Add embeddings to be merged in the output
            embs.append(mention_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights

            # Combine module output
            context_matrix_nokg = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)
            context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
            #============================================================================
            # KG module: add in KG connectivity bias
            #============================================================================
            for key in self.kg_bias_keys:
                context_matrix_kg = torch.bmm(
                    kg_bias_norms[key],
                    context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
                context_matrix_kg = (context_matrix_nokg +
                                     context_matrix_kg) / 2
                context_mat_dict[
                    f"context_matrix_{key}"] = context_matrix_kg.transpose(
                        0, 1).reshape(batch_size, self.M, self.K,
                                      self.hidden_size)

            if stage_index < self.num_model_stages - 1:
                score = model_utils.max_score_context_matrix(
                    context_mat_dict, self.predict_layers[DISAMBIG][
                        train_utils.get_stage_head_name(stage_index)])
                out[DISAMBIG][
                    f"{train_utils.get_stage_head_name(stage_index)}"] = score

            # This will take the average of the context matrices whose keys do not end in "_nokg"; if there are no kg bias terms, it will
            # select the context_matrix_nokg (as its key, in that setting, will not end in _nokg)
            query_tensor = model_utils.generate_final_context_matrix(context_mat_dict, ending_key_to_exclude="_nokg")\
                .reshape(batch_size, self.M*self.K, self.hidden_size).transpose(0,1)
        return context_mat_dict, out
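
One small but important detail in Examples #3 and #6 is the NaN guard before the mention attention: mentions whose K candidates are all padded would give an all-masked softmax row, so those rows are unmasked up front and their attention output is zeroed afterwards. A tiny sketch of just the unmasking step (values are illustrative):

import torch

# True = padded candidate; the first mention has no real candidates at all
key_padding_mask = torch.tensor([[True, True, True],
                                 [False, True, True]])
all_padded_mentions = torch.sum(~key_padding_mask, dim=-1) == 0   # tensor([ True, False])
# Unmask fully padded mentions so the softmax over their candidates stays finite;
# their attention output is zeroed out later using all_padded_mentions.
key_padding_mask[all_padded_mentions] = False
print(key_padding_mask)
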