def forward(self, out_dict, context_matrix_dict, final_score=None):
    """Model forward. Must pass the self.training bool forward to the loss function.

    Args:
        out_dict: Dict of intermediate scores (B x M x K)
        context_matrix_dict: Dict of output embedding matrices, e.g., from KG modules (B x M x K x H)
        final_score: Final output scores (B x M x K) (default None)

    Returns: Dict of Dict with final scores added (B x M x K),
        final output embedding (B x M x K x H), tensor of is_training bool
    """
    score = model_utils.max_score_context_matrix(
        context_matrix_dict, self.prediction_head
    )
    out_dict[DISAMBIG][FINAL_LOSS] = score
    if "context_matrix_main" not in context_matrix_dict:
        context_matrix_dict[
            "context_matrix_main"
        ] = model_utils.generate_final_context_matrix(
            context_matrix_dict, ending_key_to_exclude="_nokg"
        )
    final_entity_embs = context_matrix_dict["context_matrix_main"]
    # Must turn the self.training bool into a tensor so the output is gatherable for DataParallel
    return {
        "final_scores": out_dict,
        "ent_embs": final_entity_embs,
        "training": (
            torch.tensor([1], device=final_entity_embs.device) * self.training
        ).bool(),
    }
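# NOTE: nn.DataParallel can only gather tensor outputs across replicas, which is
# why `self.training` is promoted to a tensor above. A minimal, self-contained
# sketch of the same trick in plain PyTorch (the helper name is illustrative,
# not part of the codebase):
import torch

def bool_as_gatherable_tensor(flag: bool, device: torch.device) -> torch.Tensor:
    # Multiply a ones tensor by the Python bool, then cast back to bool dtype.
    return (torch.tensor([1], device=device) * flag).bool()

assert bool_as_gatherable_tensor(True, torch.device("cpu")).item() is True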
def forward(self, context_matrix_dict, alias_idx_pair_sent, entity_pack, sent_emb):
    out = {DISAMBIG: {}}
    score = model_utils.max_score_context_matrix(
        context_matrix_dict, self.prediction_head
    )
    out[DISAMBIG][FINAL_LOSS] = score
    if "context_matrix_main" not in context_matrix_dict:
        context_matrix_dict[
            "context_matrix_main"
        ] = model_utils.generate_final_context_matrix(
            context_matrix_dict, ending_key_to_exclude="_nokg"
        )
    return out, context_matrix_dict["context_matrix_main"]
def forward(
    self,
    sent_embedding,
    sent_embedding_mask,
    entity_embedding,
    entity_embedding_mask,
    start_span_idx,
    end_span_idx,
    batch_on_the_fly_data,
):
    """Model forward.

    Args:
        sent_embedding: sentence embedding (B x N x L)
        sent_embedding_mask: sentence embedding mask (B x N)
        entity_embedding: entity embedding (B x M x K x H)
        entity_embedding_mask: entity embedding mask (B x M x K)
        start_span_idx: start mention index into sentence (B x M)
        end_span_idx: end mention index into sentence (B x M)
        batch_on_the_fly_data: batch-on-the-fly dictionary with values (B x (M*K) x (M*K)) of KG adjacency matrices

    Returns: Dict of Dict of intermediate layer candidate scores (B x M x K),
        Dict of all output entity embeddings from each KG matrix (B x M x K x H)
    """
    batch_size = sent_embedding.shape[0]
    out = {DISAMBIG: {}}
    # Create KG bias matrices for each kg bias key
    kg_bias_norms = {}
    for key in self.kg_bias_keys:
        kg_bias_norms[key] = (
            batch_on_the_fly_data[key]
            .float()
            .reshape(batch_size, self.M * self.K, self.M * self.K)
        )
    # Get the mention embedding by averaging the start and end words of each mention; batch x M x dim
    mention_tensor_start = model_utils.select_alias_word_sent(
        start_span_idx, sent_embedding
    )
    mention_tensor_end = model_utils.select_alias_word_sent(
        end_span_idx, sent_embedding
    )
    mention_tensor = (mention_tensor_start + mention_tensor_end) / 2
    # Reshape for alias attention where each mention attends to its K candidates:
    # query = batch*M x 1 x dim, key = value = batch*M x K x dim
    # softmax(QK^T) -> batch*M x 1 x K
    # softmax(QK^T)V -> batch*M x 1 x dim
    mention_tensor = mention_tensor.reshape(
        batch_size * self.M, 1, self.hidden_size
    ).transpose(0, 1)
    # Get the sentence embedding; move batch to the middle
    sent_tensor = sent_embedding.transpose(0, 1)
    # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
    entity_embedding = entity_embedding.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size
    )
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_embedding_mask.contiguous().view(
        batch_size, self.M * self.K
    )
    key_padding_mask_entities_mention = entity_embedding_mask.contiguous().view(
        batch_size * self.M, self.K
    )
    # Mask of aliases; in key_padding_mask_entities_mention, True means mask.
    # We want to find aliases with all masked entities
    key_padding_mask_mentions = (
        torch.sum(~key_padding_mask_entities_mention, dim=-1) == 0
    )
    # Unmask these aliases to avoid NaNs in attention
    key_padding_mask_entities_mention[key_padding_mask_mentions] = False
    # Iterate through the stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we are adding a residual in the attention modules, we can start embs empty
        embs = []
        context_mat_dict = {}
        key_tensor_mention = (
            query_tensor.transpose(0, 1)
            .contiguous()
            .reshape(batch_size, self.M, self.K, self.hidden_size)
            .reshape(batch_size * self.M, self.K, self.hidden_size)
            .transpose(0, 1)
        )
        # ============================================================================
        # Phrase module: compute attention between entities and words
        # ============================================================================
        (
            word_entity_attn_context,
            word_entity_attn_weights,
        ) = self.attention_modules[f"stage_{stage_index}_entity_word"](
            q=query_tensor,
            x=sent_tensor,
            key_mask=sent_embedding_mask,
            attn_mask=None,
        )
        # Add embeddings to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"
        ] = word_entity_attn_weights
        # ============================================================================
        # Co-occurrence module: compute self attention over entities
        # ============================================================================
        # Move the entity mask to the device
        # TODO: move to device in init?
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device
        )
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"
        ](
            x=query_tensor,
            key_mask=key_padding_mask_entities,
            attn_mask=self.e2e_entity_mask,
        )
        # Mask out the M x K entities for sentences with a single alias;
        # the mask is True when a sentence has only one non-null alias
        non_null_aliases = (
            self.K
            - key_padding_mask_entities.reshape(batch_size, self.M, self.K).sum(-1)
        ) != 0
        entity_attn_post_mask = (
            (non_null_aliases.sum(1) == 1)
            .unsqueeze(1)
            .expand(batch_size, self.K * self.M)
            .transpose(0, 1)
        )
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context
        )
        entity_attn_context = torch.where(
            entity_attn_post_mask,
            torch.zeros_like(entity_attn_context),
            entity_attn_context,
        )
        # Add embeddings to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"
        ] = entity_attn_weights
        # ============================================================================
        # Mention module: compute attention between entities and mentions
        # ============================================================================
        # Output is 1 x batch*M x dim
        (
            mention_entity_attn_context,
            mention_entity_attn_weights,
        ) = self.attention_modules[f"stage_{stage_index}_mention_entity"](
            q=mention_tensor,
            x=key_tensor_mention,
            key_mask=key_padding_mask_entities_mention,
            attn_mask=None,
        )
        # key_padding_mask_mentions mentions have all padded candidates,
        # meaning their rows in the context matrix would be all NaN
        mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(0)] = 0
        mention_entity_attn_context = (
            mention_entity_attn_context.expand(
                self.K, batch_size * self.M, self.hidden_size
            )
            .transpose(0, 1)
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
        # Add embeddings to be merged in the output
        embs.append(mention_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_mention_entity"
        ] = mention_entity_attn_weights
        # Combine module output
        context_matrix_nokg = self.combine_modules[f"stage_{stage_index}_combine"](
            embs
        )
        context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
            0, 1
        ).reshape(batch_size, self.M, self.K, self.hidden_size)
        # ============================================================================
        # KG module: add in the KG connectivity bias
        # ============================================================================
        for key in self.kg_bias_keys:
            context_matrix_kg = torch.bmm(
                kg_bias_norms[key], context_matrix_nokg.transpose(0, 1)
            ).transpose(0, 1)
            context_matrix_kg = (context_matrix_nokg + context_matrix_kg) / 2
            context_mat_dict[
                f"context_matrix_{key}"
            ] = context_matrix_kg.transpose(0, 1).reshape(
                batch_size, self.M, self.K, self.hidden_size
            )
        if stage_index < self.num_model_stages - 1:
            score = model_utils.max_score_context_matrix(
                context_mat_dict,
                self.predict_layers[DISAMBIG][
                    bootleg.utils.model_utils.get_stage_head_name(stage_index)
                ],
            )
            out[DISAMBIG][
                f"{bootleg.utils.model_utils.get_stage_head_name(stage_index)}"
            ] = score
        # This takes the average of the context matrices whose keys do not end in "_nokg";
        # if there are no KG bias terms, it selects context_matrix_nokg
        # (as its key, in that setting, will not end in "_nokg")
        query_tensor = (
            model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg"
            )
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
    return {
        "intermed_scores": out,
        "ent_embs": context_mat_dict,
        "final_scores": None,
    }
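# To make the per-stage KG step above concrete, here is a toy, self-contained
# sketch of how a row-normalized adjacency bias reweights candidate embeddings
# via torch.bmm plus residual averaging. All sizes and names are illustrative;
# this is not the module's API.
import torch

B, MK, H = 2, 6, 4                       # toy sizes: batch, M*K candidates, hidden
context = torch.randn(MK, B, H)          # stage-internal layout: (M*K) x B x H
adj_norm = torch.softmax(torch.rand(B, MK, MK), dim=-1)  # stand-in bias matrix

# Each candidate becomes the average of itself and a KG-weighted mixture of the
# other candidates, mirroring `(context_matrix_nokg + context_matrix_kg) / 2`.
mixed = torch.bmm(adj_norm, context.transpose(0, 1)).transpose(0, 1)
context_kg = (context + mixed) / 2
assert context_kg.shape == context.shape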
def forward(
    self,
    sent_embedding,
    sent_embedding_mask,
    entity_embedding,
    entity_embedding_mask,
    start_span_idx,
    end_span_idx,
    batch_on_the_fly_data,
):
    """Model forward.

    Args:
        sent_embedding: sentence embedding (B x N x L)
        sent_embedding_mask: sentence embedding mask (B x N)
        entity_embedding: entity embedding (B x M x K x H)
        entity_embedding_mask: entity embedding mask (B x M x K)
        start_span_idx: start mention index into sentence (B x M)
        end_span_idx: end mention index into sentence (B x M)
        batch_on_the_fly_data: batch-on-the-fly dictionary with values (B x (M*K) x (M*K)) of KG adjacency matrices

    Returns: Dict of Dict of intermediate layer candidate scores (B x M x K),
        Dict of all output entity embeddings from each KG matrix (B x M x K x H)
    """
    batch_size = sent_embedding.shape[0]
    out = {DISAMBIG: {}}
    # Create KG bias matrices for each kg bias key
    kg_bias_norms = {}
    for key in self.kg_bias_keys:
        bias_weight = getattr(self, key)  # self.kg_bias_weights[key]
        kg_bias = (
            batch_on_the_fly_data[key]
            .float()
            .to(sent_embedding.device)
            .reshape(batch_size, self.M * self.K, self.M * self.K)
        )
        kg_bias_diag = kg_bias + bias_weight * torch.eye(self.M * self.K).repeat(
            batch_size, 1, 1
        ).view(batch_size, self.M * self.K, self.M * self.K).to(kg_bias.device)
        kg_bias_norm = self.kg_softmax(
            kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9))
        )
        kg_bias_norms[key] = kg_bias_norm
    sent_tensor = sent_embedding.transpose(0, 1)
    # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
    entity_embedding = entity_embedding.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size
    )
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_embedding_mask.contiguous().view(
        batch_size, self.M * self.K
    )
    # Iterate through the stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we are adding a residual in the attention modules, we can start embs empty
        embs = []
        context_mat_dict = {}
        # ============================================================================
        # Phrase module: compute attention between entities and words
        # ============================================================================
        (
            word_entity_attn_context,
            word_entity_attn_weights,
        ) = self.attention_modules[f"stage_{stage_index}_entity_word"](
            q=query_tensor,
            x=sent_tensor,
            key_mask=sent_embedding_mask,
            attn_mask=None,
        )
        # Add embeddings to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"
        ] = word_entity_attn_weights
        # ============================================================================
        # Co-occurrence module: compute self attention over entities
        # ============================================================================
        # Move the entity mask to the device
        # TODO: move to device in init?
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device
        )
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"
        ](
            x=query_tensor,
            key_mask=key_padding_mask_entities,
            attn_mask=self.e2e_entity_mask,
        )
        # Mask out the M x K entities for sentences with a single alias;
        # the mask is True when a sentence has only one non-null alias
        non_null_aliases = (
            self.K
            - key_padding_mask_entities.reshape(batch_size, self.M, self.K).sum(-1)
        ) != 0
        entity_attn_post_mask = (
            (non_null_aliases.sum(1) == 1)
            .unsqueeze(1)
            .expand(batch_size, self.K * self.M)
            .transpose(0, 1)
        )
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context
        )
        entity_attn_context = torch.where(
            entity_attn_post_mask,
            torch.zeros_like(entity_attn_context),
            entity_attn_context,
        )
        # Add embeddings to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"
        ] = entity_attn_weights
        # Combine module output
        context_matrix_nokg = self.combine_modules[f"stage_{stage_index}_combine"](
            embs
        )
        context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
            0, 1
        ).reshape(batch_size, self.M, self.K, self.hidden_size)
        # ============================================================================
        # KG module: add in the KG connectivity bias
        # ============================================================================
        for key in self.kg_bias_keys:
            context_matrix_kg = torch.bmm(
                kg_bias_norms[key], context_matrix_nokg.transpose(0, 1)
            ).transpose(0, 1)
            context_matrix_kg = (context_matrix_nokg + context_matrix_kg) / 2
            context_mat_dict[
                f"context_matrix_{key}"
            ] = context_matrix_kg.transpose(0, 1).reshape(
                batch_size, self.M, self.K, self.hidden_size
            )
        if stage_index < self.num_model_stages - 1:
            score = model_utils.max_score_context_matrix(
                context_mat_dict,
                self.predict_layers[DISAMBIG][
                    bootleg.utils.model_utils.get_stage_head_name(stage_index)
                ],
            )
            out[DISAMBIG][
                f"{bootleg.utils.model_utils.get_stage_head_name(stage_index)}"
            ] = score
        # This takes the average of the context matrices whose keys do not end in "_nokg";
        # if there are no KG bias terms, it selects context_matrix_nokg
        # (as its key, in that setting, will not end in "_nokg")
        query_tensor = (
            model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg"
            )
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
    return {
        "intermed_scores": out,
        "ent_embs": context_mat_dict,
        "final_scores": None,
    }
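# The bias normalization above adds a learned self-loop weight to the diagonal,
# masks zero (non-edge) entries to -1e9, and softmaxes each row. A toy sketch
# under those assumptions; the adjacency values and `bias_weight` are made up,
# and the softmax dim is assumed to be the last one (matching self.kg_softmax).
import torch

B, MK = 1, 4
adj = torch.tensor([[[0., 1., 0., 0.],
                     [1., 0., 1., 0.],
                     [0., 1., 0., 0.],
                     [0., 0., 0., 0.]]])   # toy KG adjacency, B x (M*K) x (M*K)
bias_weight = torch.tensor(2.0)            # stands in for the learned scalar

diag = adj + bias_weight * torch.eye(MK).repeat(B, 1, 1)
# Non-edges go to -1e9 so each row's softmax spreads mass only over true
# neighbors and the candidate itself (the self-loop keeps isolated rows sane).
norm = torch.softmax(diag.masked_fill(diag == 0, float(-1e9)), dim=-1)
assert torch.allclose(norm.sum(-1), torch.ones(B, MK))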
def forward(self, context_matrix_dict, alias_idx_pair_sent, entity_pack, sent_emb):
    out = {DISAMBIG: {}, INDICATOR: {}}
    indicator_outputs = {}
    expert_slice_repr = {}
    predictor_outputs = {}
    if "context_matrix_main" not in context_matrix_dict:
        context_matrix_dict[
            "context_matrix_main"
        ] = model_utils.generate_final_context_matrix(context_matrix_dict)
    context_matrix = context_matrix_dict["context_matrix_main"]
    batch_size, M, K, H = context_matrix.shape
    assert M == self.M
    assert K == self.K
    assert H == self.hidden_size
    for i, slice_head in enumerate(self.train_heads):
        # Generate the slice expert representation per head;
        # context_matrix is batch x M x K x H
        expert_slice_repr[slice_head] = self.transform_modules[slice_head](
            context_matrix
        )
        # Pass the expert slice representation through the shared prediction layer;
        # predictor_outputs is batch x M x K
        predictor_outputs[slice_head] = self.shared_slice_pred_head(
            expert_slice_repr[slice_head]
        ).squeeze(-1)
        # Get an alias_matrix output (batch x M x H)
        # TODO: remove extra inputs
        alias_matrix, alias_word_weights = self.ind_alias_mha[slice_head](
            sent_embedding=sent_emb,
            entity_embedding=context_matrix,
            entity_mask=entity_pack.mask,
            alias_idx_pair_sent=alias_idx_pair_sent,
            slice_emb_alias=self.slice_emb_ind_alias(
                torch.tensor(i, device=context_matrix.device)
            ),
            slice_emb_ent=self.slice_emb_ind_ent(
                torch.tensor(i, device=context_matrix.device)
            ),
        )
        # Get indicator head outputs; size batch x M x 2 per head
        indicator_outputs[slice_head] = self.indicator_heads[slice_head](
            alias_matrix
        )
    # Generate predictions via softmax + taking the "positive" class label
    # Output size is batch x M x num_slices
    indicator_preds = torch.cat(
        [
            F.log_softmax(indicator_outputs[slice_head], dim=-1)[:, :, 1].unsqueeze(-1)
            for slice_head in self.train_heads
        ],
        dim=-1,
    )
    assert not torch.isnan(indicator_preds).any()
    assert not torch.isinf(indicator_preds).any()
    # Compute the "confidence"
    # Output size should be batch x M x K x num_slices
    predictor_confidences = torch.cat(
        [
            eval_utils.masked_class_logsoftmax(
                pred=predictor_outputs[slice_head], mask=~entity_pack.mask
            ).unsqueeze(-1)
            for slice_head in self.train_heads
        ],
        dim=-1,
    )
    assert not torch.isnan(predictor_confidences).any()
    assert not torch.isinf(predictor_confidences).any()
    if self.use_ind_attn:
        prod = indicator_preds  # * margin_confidence
        prod[prod < 0.1] = -100000.0 * self.temperature
        attention_weights = F.softmax(prod / self.temperature, dim=-1)
    else:
        # Take the margin confidence over the K values to generate confidences of batch x M x num_slices
        vals, indices = torch.topk(predictor_confidences, k=2, dim=2)
        margin_confidence = (vals[:, :, 0, :] - vals[:, :, 1, :]) / vals.sum(2)
        assert list(margin_confidence.shape) == [
            batch_size,
            self.M,
            len(self.train_heads),
        ]
        attention_weights = F.softmax(
            (indicator_preds + margin_confidence) / self.temperature, dim=-1
        )
    assert not torch.isnan(attention_weights).any()
    assert not torch.isinf(attention_weights).any()
    # attention_weights is batch x M x num_slices
    # slice_representations is batch_size x M x K x num_slices x feat_dim
    slice_representations = torch.stack(
        [expert_slice_repr[slice_head] for slice_head in self.train_heads], dim=3
    )
    # attention_weights is expanded to batch_size x M x K x num_slices x H to match slice_representations
    attention_weights = attention_weights.unsqueeze(2).unsqueeze(-1).expand_as(
        slice_representations
    )
    # Reweight the representations with a weighted sum across slices
    reweighted_rep = torch.sum(attention_weights * slice_representations, dim=3)
    assert reweighted_rep.shape == context_matrix.shape
    # Pass through the final prediction layer
    for slice_head in self.train_heads:
        out[DISAMBIG][
            train_utils.get_slice_head_pred_name(slice_head)
        ] = predictor_outputs[slice_head]
    for slice_head in self.train_heads:
        out[INDICATOR][
            train_utils.get_slice_head_ind_name(slice_head)
        ] = indicator_outputs[slice_head]
    # Used for debugging
    if self.remove_final_loss:
        out[DISAMBIG][FINAL_LOSS] = out[DISAMBIG][
            train_utils.get_slice_head_pred_name(BASE_SLICE)
        ]
    else:
        out[DISAMBIG][FINAL_LOSS] = self.final_pred_head(reweighted_rep).squeeze(-1)
    return out, reweighted_rep
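# As a concrete illustration of the margin-confidence branch above: the gap
# between the top two candidate scores per slice acts as the slice head's
# confidence. A toy sketch with random inputs (all shapes are illustrative):
import torch

B, M, K, S = 2, 3, 5, 4                 # toy: batch, mentions, candidates, slices
conf = torch.randn(B, M, K, S)          # stands in for predictor_confidences

vals, _ = torch.topk(conf, k=2, dim=2)  # best and runner-up over the K candidates
# Normalized best-vs-runner-up gap; a large margin means the slice head is
# confident in its top candidate.
margin = (vals[:, :, 0, :] - vals[:, :, 1, :]) / vals.sum(2)
assert margin.shape == (B, M, S)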
def forward(
    self,
    alias_idx_pair_sent,
    sent_embedding,
    entity_embedding,
    batch_prepped,
    batch_on_the_fly_data,
):
    batch_size = sent_embedding.tensor.shape[0]
    out = {DISAMBIG: {}}
    # Create KG bias matrices for each kg bias key
    kg_bias_norms = {}
    for key in self.kg_bias_keys:
        kg_bias = (
            batch_on_the_fly_data[key]
            .float()
            .to(sent_embedding.tensor.device)
            .reshape(batch_size, self.M * self.K, self.M * self.K)
        )
        kg_bias_diag = kg_bias + self.kg_bias_weights[key] * torch.eye(
            self.M * self.K
        ).repeat(batch_size, 1, 1).view(
            batch_size, self.M * self.K, self.M * self.K
        ).to(kg_bias.device)
        kg_bias_norm = self.kg_softmax(
            kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9))
        )
        kg_bias_norms[key] = kg_bias_norm
    # Get the mention embedding by averaging the start and end words of each mention; batch x M x dim
    mention_tensor_start = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=0
    )
    mention_tensor_end = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=1
    )
    mention_tensor = (mention_tensor_start + mention_tensor_end) / 2
    # Reshape for alias attention where each mention attends to its K candidates:
    # query = batch*M x 1 x dim, key = value = batch*M x K x dim
    # softmax(QK^T) -> batch*M x 1 x K
    # softmax(QK^T)V -> batch*M x 1 x dim
    mention_tensor = mention_tensor.reshape(
        batch_size * self.M, 1, self.hidden_size
    ).transpose(0, 1)
    # Get the sentence embedding; move batch to the middle
    sent_tensor = sent_embedding.tensor.transpose(0, 1)
    # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
    entity_mask = entity_embedding.mask
    entity_embedding = entity_embedding.tensor.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size
    )
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_mask.contiguous().view(
        batch_size, self.M * self.K
    )
    key_padding_mask_entities_mention = entity_mask.contiguous().view(
        batch_size * self.M, self.K
    )
    # Mask of aliases; in key_padding_mask_entities_mention, True means mask.
    # We want to find aliases with all masked entities
    key_padding_mask_mentions = (
        torch.sum(~key_padding_mask_entities_mention, dim=-1) == 0
    )
    # Unmask these aliases to avoid NaNs in attention
    key_padding_mask_entities_mention[key_padding_mask_mentions] = False
    # Iterate through the stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we are adding a residual in the attention modules, we can start embs empty
        embs = []
        context_mat_dict = {}
        key_tensor_mention = (
            query_tensor.transpose(0, 1)
            .contiguous()
            .reshape(batch_size, self.M, self.K, self.hidden_size)
            .reshape(batch_size * self.M, self.K, self.hidden_size)
            .transpose(0, 1)
        )
        # ============================================================================
        # Phrase module: compute attention between entities and words
        # ============================================================================
        (
            word_entity_attn_context,
            word_entity_attn_weights,
        ) = self.attention_modules[f"stage_{stage_index}_entity_word"](
            q=query_tensor,
            x=sent_tensor,
            key_mask=sent_embedding.mask,
            attn_mask=None,
        )
        # Add embeddings to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"
        ] = word_entity_attn_weights
        # ============================================================================
        # Co-occurrence module: compute self attention over entities
        # ============================================================================
        # Move the entity mask to the device
        # TODO: move to device in init?
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device
        )
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"
        ](
            x=query_tensor,
            key_mask=key_padding_mask_entities,
            attn_mask=self.e2e_entity_mask,
        )
        # Mask out the M x K entities for sentences with a single alias;
        # the mask is True when a sentence has only one non-null alias
        non_null_aliases = (
            self.K
            - key_padding_mask_entities.reshape(batch_size, self.M, self.K).sum(-1)
        ) != 0
        entity_attn_post_mask = (
            (non_null_aliases.sum(1) == 1)
            .unsqueeze(1)
            .expand(batch_size, self.K * self.M)
            .transpose(0, 1)
        )
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context
        )
        entity_attn_context = torch.where(
            entity_attn_post_mask,
            torch.zeros_like(entity_attn_context),
            entity_attn_context,
        )
        # Add embeddings to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"
        ] = entity_attn_weights
        # ============================================================================
        # Mention module: compute attention between entities and mentions
        # ============================================================================
        # Output is 1 x batch*M x dim
        (
            mention_entity_attn_context,
            mention_entity_attn_weights,
        ) = self.attention_modules[f"stage_{stage_index}_mention_entity"](
            q=mention_tensor,
            x=key_tensor_mention,
            key_mask=key_padding_mask_entities_mention,
            attn_mask=None,
        )
        # key_padding_mask_mentions mentions have all padded candidates,
        # meaning their rows in the context matrix would be all NaN
        mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(0)] = 0
        mention_entity_attn_context = (
            mention_entity_attn_context.expand(
                self.K, batch_size * self.M, self.hidden_size
            )
            .transpose(0, 1)
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
        # Add embeddings to be merged in the output
        embs.append(mention_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_mention_entity"
        ] = mention_entity_attn_weights
        # Combine module output
        context_matrix_nokg = self.combine_modules[f"stage_{stage_index}_combine"](
            embs
        )
        context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
            0, 1
        ).reshape(batch_size, self.M, self.K, self.hidden_size)
        # ============================================================================
        # KG module: add in the KG connectivity bias
        # ============================================================================
        for key in self.kg_bias_keys:
            context_matrix_kg = torch.bmm(
                kg_bias_norms[key], context_matrix_nokg.transpose(0, 1)
            ).transpose(0, 1)
            context_matrix_kg = (context_matrix_nokg + context_matrix_kg) / 2
            context_mat_dict[
                f"context_matrix_{key}"
            ] = context_matrix_kg.transpose(0, 1).reshape(
                batch_size, self.M, self.K, self.hidden_size
            )
        if stage_index < self.num_model_stages - 1:
            score = model_utils.max_score_context_matrix(
                context_mat_dict,
                self.predict_layers[DISAMBIG][
                    train_utils.get_stage_head_name(stage_index)
                ],
            )
            out[DISAMBIG][
                f"{train_utils.get_stage_head_name(stage_index)}"
            ] = score
        # This takes the average of the context matrices whose keys do not end in "_nokg";
        # if there are no KG bias terms, it selects context_matrix_nokg
        # (as its key, in that setting, will not end in "_nokg")
        query_tensor = (
            model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg"
            )
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
    return context_mat_dict, out