def __init__(self, args, emb_sizes, sent_emb_size, entity_symbols, word_symbols):
    super(Bootleg, self).__init__(args, emb_sizes, sent_emb_size, entity_symbols,
                                  word_symbols)
    self.dropout = args.train_config.dropout
    # For each stage, instantiate a transformer block for the phrase (entity_word),
    # co-occurrence (self_entity), and mention (mention_entity) modules
    self.attention_modules = nn.ModuleDict()
    self.combine_modules = nn.ModuleDict()
    for i in range(self.num_model_stages):
        self.attention_modules[f"stage_{i}_entity_word"] = \
            AttnBlock(size=self.hidden_size,
                      ff_inner_size=args.model_config.ff_inner_size,
                      dropout=self.dropout,
                      num_heads=self.num_heads)
        self.attention_modules[f"stage_{i}_self_entity"] = \
            SelfAttnBlock(size=self.hidden_size,
                          ff_inner_size=args.model_config.ff_inner_size,
                          dropout=self.dropout,
                          num_heads=self.num_heads)
        # The forward pass indexes a mention (mention_entity) attention block per
        # stage, so it must be instantiated here as well
        self.attention_modules[f"stage_{i}_mention_entity"] = \
            AttnBlock(size=self.hidden_size,
                      ff_inner_size=args.model_config.ff_inner_size,
                      dropout=self.dropout,
                      num_heads=self.num_heads)
        self.combine_modules[f"stage_{i}_combine"] = NormAndSum(self.hidden_size)

    # For the KG bias module
    self.kg_bias_weights = nn.ParameterDict()
    for emb in args.data_config.ent_embeddings:
        if emb.load_class == KG_BIAS_LOAD_CLASS:
            self.kg_bias_weights[emb.key] = torch.nn.Parameter(torch.tensor(2.0))
    self.kg_bias_keys = sorted(list(self.kg_bias_weights.keys()))
    # If we have KG bias terms, we take the average of those context matrices when
    # generating the final context matrix to be returned. The no_kg_key is used for
    # the context matrix without KG bias terms added; a key ending in _nokg will not
    # be averaged into the final result.
    # If we do not have KG bias terms, we want the nokg context matrix to be the
    # final matrix; the MAIN_CONTEXT_MATRIX key allows for this.
    if len(self.kg_bias_keys) > 0:
        self.no_kg_key = "context_matrix_nokg"
    else:
        self.no_kg_key = MAIN_CONTEXT_MATRIX
    self.kg_softmax = nn.Softmax(dim=2)

    # Two things to note. First, the attn mask is a block-diagonal matrix that
    # prevents an alias from attending to its own K candidates in the attention
    # layer. This works because the original input is added to the output of this
    # attention, meaning an alias becomes its original embedding plus the
    # contributions of the other aliases in the sentence.
    # Second, the attn mask is added to the attention scores before the softmax
    # (added to Q K^T) -- the softmax turns e^(-1e9 + old_value) into zero. If we
    # set the mask to -inf instead, we can get NaNs in the loss when all entities
    # end up masked out (e.g. only one alias in the sentence).
    self.e2e_entity_mask = torch.zeros((self.K * self.M, self.K * self.M))
    for i in range(self.M):
        self.e2e_entity_mask[i * self.K:(i + 1) * self.K,
                             i * self.K:(i + 1) * self.K] = 1.0
    # Must manually move this to the device as it's not part of a module...we can
    # probably fix this
    self.e2e_entity_mask = self.e2e_entity_mask.masked_fill(
        (self.e2e_entity_mask == 1), float(-1e9))

    # Track attention weights
    self.attention_weights = {}

    # Prediction layers: each stage except the last gets a prediction layer
    # (the last stage's prediction head is added in the slice heads)
    disambig_task = nn.ModuleDict()
    for i in range(self.num_model_stages - 1):
        disambig_task[train_utils.get_stage_head_name(i)] = MLP(
            self.hidden_size, self.hidden_size, 1, self.num_fc_layers, self.dropout)
    self.predict_layers = {DISAMBIG: disambig_task}
    self.predict_layers = nn.ModuleDict(self.predict_layers)
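#============================================================================
# A minimal standalone sketch (not part of the model) of the block-diagonal
# e2e_entity_mask built above, shown for M=2 aliases and K=3 candidates. Rows and
# columns belonging to the same alias get -1e9, so after the additive mask and
# softmax each candidate only attends to candidates of *other* aliases.
#============================================================================
def _demo_e2e_entity_mask(M=2, K=3):
    import torch
    mask = torch.zeros((K * M, K * M))
    for i in range(M):
        mask[i * K:(i + 1) * K, i * K:(i + 1) * K] = 1.0
    # Each K x K diagonal block becomes -1e9; all off-diagonal blocks stay 0
    return mask.masked_fill(mask == 1, float(-1e9))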
def __init__(self, args, emb_sizes, sent_emb_size, entity_symbols, word_symbols):
    super(BootlegV1, self).__init__(args, emb_sizes, sent_emb_size, entity_symbols,
                                    word_symbols)
    self.dropout = args.train_config.dropout
    # For each stage, instantiate a transformer block for the phrase (entity_word)
    # and co-occurrence (self_entity) modules
    self.attention_modules = nn.ModuleDict()
    self.combine_modules = nn.ModuleDict()
    for i in range(self.num_model_stages):
        self.attention_modules[f"stage_{i}_entity_word"] = \
            AttnBlock(size=self.hidden_size,
                      ff_inner_size=args.model_config.ff_inner_size,
                      dropout=self.dropout,
                      num_heads=self.num_heads)
        self.attention_modules[f"stage_{i}_self_entity"] = \
            SelfAttnBlock(size=self.hidden_size,
                          ff_inner_size=args.model_config.ff_inner_size,
                          dropout=self.dropout,
                          num_heads=self.num_heads)
        self.combine_modules[f"stage_{i}_combine"] = NormAndSum(self.hidden_size)

    # For the KG module
    self.kg_weight = torch.nn.Parameter(torch.tensor(2.0))
    self.softmax = nn.Softmax(dim=2)

    # Two things to note. First, the attn mask is a block-diagonal matrix that
    # prevents an alias from attending to its own K candidates in the attention
    # layer. This works because the original input is added to the output of this
    # attention, meaning an alias becomes its original embedding plus the
    # contributions of the other aliases in the sentence.
    # Second, the attn mask is added to the attention scores before the softmax
    # (added to Q K^T) -- the softmax turns e^(-1e9 + old_value) into zero. If we
    # set the mask to -inf instead, we can get NaNs in the loss when all entities
    # end up masked out (e.g. only one alias in the sentence).
    self.e2e_entity_mask = torch.zeros((self.K * self.M, self.K * self.M))
    for i in range(self.M):
        self.e2e_entity_mask[i * self.K:(i + 1) * self.K,
                             i * self.K:(i + 1) * self.K] = 1.0
    # Must manually move this to the device as it's not part of a module...we can
    # probably fix this
    self.e2e_entity_mask = self.e2e_entity_mask.masked_fill(
        (self.e2e_entity_mask == 1), float(-1e9))

    # Track attention weights
    self.attention_weights = {}

    # Prediction layers: each stage except the last gets a prediction layer
    # (the last stage's prediction head is added in the slice heads)
    disambig_task = nn.ModuleDict()
    for i in range(self.num_model_stages - 1):
        disambig_task[train_utils.get_stage_head_name(i)] = MLP(
            self.hidden_size, self.hidden_size, 1, self.num_fc_layers, self.dropout)
    self.predict_layers = {DISAMBIG: disambig_task}
    self.predict_layers = nn.ModuleDict(self.predict_layers)
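#============================================================================
# A standalone sketch of the KG bias normalization used in forward() below, with
# assumed toy shapes (batch=1, M*K=4) and the initial kg_weight of 2.0. Connected
# candidate pairs carry a 1, each candidate gets the learned weight on itself,
# zeros are masked to -1e9, and a row softmax yields, per candidate, a
# distribution over the candidates it may borrow context from.
#============================================================================
def _demo_kg_bias_norm():
    import torch
    n, kg_weight = 4, 2.0
    kg_bias = torch.tensor([[[0., 1., 0., 0.],
                             [1., 0., 0., 0.],
                             [0., 0., 0., 0.],
                             [0., 0., 0., 0.]]])  # 1 x n x n; candidates 0 and 1 connected
    kg_bias_diag = kg_bias + kg_weight * torch.eye(n).unsqueeze(0)
    kg_bias_norm = torch.softmax(
        kg_bias_diag.masked_fill(kg_bias_diag == 0, float(-1e9)), dim=2)
    # Rows 0 and 1 split mass ~0.73/0.27 between self (weight 2.0) and neighbor
    # (weight 1.0); rows 2 and 3 put all mass on themselves
    return kg_bias_norm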
def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
            batch_prepped, batch_on_the_fly_data):
    batch_size = sent_embedding.tensor.shape[0]
    out = {DISAMBIG: {}}

    # Prepare inputs for the attention modules
    # Get the KG metadata for the KG module -- this is 0/1 indicating whether a
    # pair of candidates is connected
    if REL_INDICES_KEY in batch_on_the_fly_data:
        kg_bias = batch_on_the_fly_data[REL_INDICES_KEY].float().to(
            sent_embedding.tensor.device).reshape(batch_size, self.M * self.K,
                                                  self.M * self.K)
    else:
        kg_bias = torch.zeros(batch_size, self.M * self.K, self.M * self.K).to(
            sent_embedding.tensor.device)
    kg_bias_diag = kg_bias + self.kg_weight * torch.eye(self.M * self.K).repeat(
        batch_size, 1, 1).view(batch_size, self.M * self.K, self.M * self.K).to(
            kg_bias.device)
    kg_bias_norm = self.softmax(
        kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))

    sent_tensor = sent_embedding.tensor.transpose(0, 1)

    # Resize the alias embeddings and the entity mask from B x M x K x D to
    # B x (M*K) x D
    entity_mask = entity_embedding.mask
    entity_embedding = entity_embedding.tensor.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size)
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_mask.contiguous().view(
        batch_size, self.M * self.K)

    # Iterate through the stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we add a residual in the attention modules, we can start with embs empty
        embs = []
        #============================================================================
        # Phrase module: compute attention between entities and words
        #============================================================================
        word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_entity_word"](
                q=query_tensor,
                x=sent_tensor,
                key_mask=sent_embedding.mask,
                attn_mask=None)
        # Add the embedding to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

        #============================================================================
        # Co-occurrence module: compute self-attention over entities
        #============================================================================
        # Move the entity mask to the right device
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device)
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"](
                x=query_tensor,
                key_mask=key_padding_mask_entities,
                attn_mask=self.e2e_entity_mask)
        # Mask out the MxK block of single aliases; alias_indices is batch x M and
        # the mask is True when the sentence has a single alias
        non_null_aliases = (self.K - key_padding_mask_entities.reshape(
            batch_size, self.M, self.K).sum(-1)) != 0
        entity_attn_post_mask = (non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
            batch_size, self.K * self.M).transpose(0, 1)
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context)
        entity_attn_context = torch.where(entity_attn_post_mask,
                                          torch.zeros_like(entity_attn_context),
                                          entity_attn_context)
        # Add the embedding to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"] = entity_attn_weights

        context_matrix = self.combine_modules[f"stage_{stage_index}_combine"](embs)

        #============================================================================
        # KG module: add in the KG connectivity bias
        #============================================================================
        context_matrix_kg = torch.bmm(
            kg_bias_norm, context_matrix.transpose(0, 1)).transpose(0, 1)
        if stage_index < self.num_model_stages - 1:
            pred = self.predict_layers[DISAMBIG][
                train_utils.get_stage_head_name(stage_index)](context_matrix)
            pred = pred.transpose(0, 1).squeeze(2).reshape(
                batch_size, self.M, self.K)
            pred_kg = self.predict_layers[DISAMBIG][
                train_utils.get_stage_head_name(stage_index)](
                    (context_matrix + context_matrix_kg) / 2)
            pred_kg = pred_kg.transpose(0, 1).squeeze(2).reshape(
                batch_size, self.M, self.K)
            out[DISAMBIG][
                f"{train_utils.get_stage_head_name(stage_index)}"] = torch.max(
                    torch.cat([pred.unsqueeze(3), pred_kg.unsqueeze(3)], dim=-1),
                    dim=-1)[0]
            pred_logit = eval_utils.masked_class_logsoftmax(pred=pred,
                                                            mask=~entity_mask)
            context_norm = model_utils.normalize_matrix(
                (context_matrix + context_matrix_kg) / 2, dim=2)
            assert not torch.isnan(context_norm).any()
            assert not torch.isinf(context_norm).any()
            # Add the predictions so the model can learn from previous predictions
            context_matrix = context_norm + pred_logit.contiguous().view(
                batch_size, self.M * self.K, 1).transpose(0, 1)
        else:
            context_matrix_nokg = context_matrix
            context_matrix = (context_matrix + context_matrix_kg) / 2
            context_matrix_main = context_matrix
        query_tensor = context_matrix

    context_matrix_nokg = context_matrix_nokg.transpose(0, 1).reshape(
        batch_size, self.M, self.K, self.hidden_size)
    context_matrix_main = context_matrix_main.transpose(0, 1).reshape(
        batch_size, self.M, self.K, self.hidden_size)
    # context_mat_dict holds the contextualized entity embeddings; out holds the
    # predictions
    context_mat_dict = {
        "context_matrix_nokg": context_matrix_nokg,
        MAIN_CONTEXT_MATRIX: context_matrix_main
    }
    return context_mat_dict, out
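#============================================================================
# A standalone sketch (toy shapes: batch=1, M*K=4, hidden=2) of the KG module's
# bmm above. Multiplying the row-normalized bias into the context matrix replaces
# each candidate's vector with a convex combination of itself and its KG
# neighbors; in forward() the batch dim sits in the middle for the attention
# modules, hence the transposes around the bmm.
#============================================================================
def _demo_kg_context_mix():
    import torch
    kg_bias_norm = torch.tensor([[[0.73, 0.27, 0.00, 0.00],
                                  [0.27, 0.73, 0.00, 0.00],
                                  [0.00, 0.00, 1.00, 0.00],
                                  [0.00, 0.00, 0.00, 1.00]]])
    context_matrix = torch.arange(8, dtype=torch.float).reshape(1, 4, 2)
    # Row 0 becomes 0.73*c0 + 0.27*c1; rows 2 and 3 are unchanged
    return torch.bmm(kg_bias_norm, context_matrix)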
def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
            batch_prepped, batch_on_the_fly_data):
    batch_size = sent_embedding.tensor.shape[0]
    out = {DISAMBIG: {}}

    # Create a KG bias matrix for each KG bias key
    kg_bias_norms = {}
    for key in self.kg_bias_keys:
        kg_bias = batch_on_the_fly_data[key].float().to(
            sent_embedding.tensor.device).reshape(batch_size, self.M * self.K,
                                                  self.M * self.K)
        kg_bias_diag = kg_bias + self.kg_bias_weights[key] * torch.eye(
            self.M * self.K).repeat(batch_size, 1, 1).view(
                batch_size, self.M * self.K, self.M * self.K).to(kg_bias.device)
        kg_bias_norm = self.kg_softmax(
            kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))
        kg_bias_norms[key] = kg_bias_norm

    # Get the mention embedding: average the mention's start and end word
    # embeddings; batch x M x dim
    mention_tensor_start = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=0)
    mention_tensor_end = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=1)
    mention_tensor = (mention_tensor_start + mention_tensor_end) / 2

    # Reshape for the alias attention where each mention attends to its K candidates:
    # query = batch*M x 1 x dim, key = value = batch*M x K x dim
    # softmax(QK^T) -> batch*M x 1 x K
    # softmax(QK^T)V -> batch*M x 1 x dim
    mention_tensor = mention_tensor.reshape(batch_size * self.M, 1,
                                            self.hidden_size).transpose(0, 1)

    # Get the sentence embedding; move the batch to the middle
    sent_tensor = sent_embedding.tensor.transpose(0, 1)

    # Resize the alias embeddings and the entity mask from B x M x K x D to
    # B x (M*K) x D
    entity_mask = entity_embedding.mask
    entity_embedding = entity_embedding.tensor.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size)
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_mask.contiguous().view(
        batch_size, self.M * self.K)
    key_padding_mask_entities_mention = entity_mask.contiguous().view(
        batch_size * self.M, self.K)
    # Mask of aliases; in key_padding_mask_entities_mention, True means masked.
    # We want to find aliases where all entities are masked
    key_padding_mask_mentions = torch.sum(~key_padding_mask_entities_mention,
                                          dim=-1) == 0
    # Unmask these aliases to avoid NaNs in the attention
    key_padding_mask_entities_mention[key_padding_mask_mentions] = False

    # Iterate through the stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we add a residual in the attention modules, we can start with embs empty
        embs = []
        context_mat_dict = {}
        key_tensor_mention = query_tensor.transpose(0, 1).contiguous()\
            .reshape(batch_size, self.M, self.K, self.hidden_size)\
            .reshape(batch_size * self.M, self.K, self.hidden_size).transpose(0, 1)
        #============================================================================
        # Phrase module: compute attention between entities and words
        #============================================================================
        word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_entity_word"](
                q=query_tensor,
                x=sent_tensor,
                key_mask=sent_embedding.mask,
                attn_mask=None)
        # Add the embedding to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

        #============================================================================
        # Co-occurrence module: compute self-attention over entities
        #============================================================================
        # Move the entity mask to the right device
        # TODO: move to device in init?
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device)
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"](
                x=query_tensor,
                key_mask=key_padding_mask_entities,
                attn_mask=self.e2e_entity_mask)
        # Mask out the MxK block of single aliases; alias_indices is batch x M and
        # the mask is True when the sentence has a single alias
        non_null_aliases = (self.K - key_padding_mask_entities.reshape(
            batch_size, self.M, self.K).sum(-1)) != 0
        entity_attn_post_mask = (non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
            batch_size, self.K * self.M).transpose(0, 1)
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context)
        entity_attn_context = torch.where(entity_attn_post_mask,
                                          torch.zeros_like(entity_attn_context),
                                          entity_attn_context)
        # Add the embedding to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"] = entity_attn_weights

        #============================================================================
        # Mention module: compute attention between entities and mentions
        #============================================================================
        # Output is 1 x batch*M x dim
        mention_entity_attn_context, mention_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_mention_entity"](
                q=mention_tensor,
                x=key_tensor_mention,
                key_mask=key_padding_mask_entities_mention,
                attn_mask=None)
        # key_padding_mask_mentions mentions have all padded candidates, meaning
        # their rows in the context matrix would be all NaN; zero them out
        mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(0)] = 0
        mention_entity_attn_context = mention_entity_attn_context\
            .expand(self.K, batch_size * self.M, self.hidden_size).transpose(0, 1)\
            .reshape(batch_size, self.M * self.K, self.hidden_size).transpose(0, 1)
        # Add the embedding to be merged in the output
        embs.append(mention_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights

        # Combine the module outputs
        context_matrix_nokg = self.combine_modules[
            f"stage_{stage_index}_combine"](embs)
        context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
            0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)

        #============================================================================
        # KG module: add in the KG connectivity bias
        #============================================================================
        for key in self.kg_bias_keys:
            context_matrix_kg = torch.bmm(
                kg_bias_norms[key],
                context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
            context_matrix_kg = (context_matrix_nokg + context_matrix_kg) / 2
            context_mat_dict[f"context_matrix_{key}"] = context_matrix_kg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)

        if stage_index < self.num_model_stages - 1:
            score = model_utils.max_score_context_matrix(
                context_mat_dict,
                self.predict_layers[DISAMBIG][
                    train_utils.get_stage_head_name(stage_index)])
            out[DISAMBIG][f"{train_utils.get_stage_head_name(stage_index)}"] = score

        # This takes the average of the context matrices whose key does not end in
        # "_nokg"; if there are no KG bias terms, it will select the nokg context
        # matrix (as its key, in that setting, does not end in _nokg)
        query_tensor = model_utils.generate_final_context_matrix(
            context_mat_dict, ending_key_to_exclude="_nokg")\
            .reshape(batch_size, self.M * self.K, self.hidden_size).transpose(0, 1)
    return context_mat_dict, out
def forward(self, slice_indices, true_label, entity_indices, model_outs):
    batch_size, M, K = entity_indices.shape
    true_entity_idx = true_label[DISAMBIG][FINAL_LOSS]
    # Dummy disambig outputs for when we are only training an indicator model for
    # weak supervision and only have indicator predictions
    if DISAMBIG not in model_outs:
        model_outs[DISAMBIG] = {}
        for out_head_key in self.head_key_to_idx:
            model_outs[DISAMBIG][out_head_key] = torch.ones(batch_size, M, K).to(
                entity_indices.device)
    padded_entities = (true_entity_idx == -1)
    # We need to filter out cases where true_entity_idx is -1 because an alias that
    # occurs in multiple subsentences must not be double counted for a slice; it is
    # only predicted in the one subsentence where true_entity_idx is not -1
    slice_indices[padded_entities] = 0

    # Compute the count of head-correct predictions across eval slices
    head_values = true_entity_idx.new_full(
        true_entity_idx.size(), fill_value=int(not self.train_in_candidates))
    head_correct = head_values == true_entity_idx
    head_correct[padded_entities] = 0
    head_correct_slices = head_correct.unsqueeze(-1) * slice_indices
    total_head_correct_slices = head_correct_slices.sum((0, 1))
    self.alias_head_correct += total_head_correct_slices
    # Total number of mentions per slice
    total_count_slices = slice_indices.sum((0, 1))
    self.alias_count += total_count_slices

    #===============================================
    # Slicing model predictions (indicators)
    #===============================================
    for i, slice_head in enumerate(self.train_heads):
        out_head_key = train_utils.get_slice_head_pred_name(slice_head)
        self._update_pred_counts(out_head_key, model_outs, slice_indices,
                                 entity_indices, padded_entities, true_entity_idx,
                                 true_label)
        # Compute scores of the train head conditioned on whether the model
        # predicted in the slice or not
        buffer_idx = self.head_key_to_idx[out_head_key]
        # If INDICATOR is an output, take it. Otherwise take the ground truth, as
        # the model does not have that loss (e.g. the HPS model).
        # outs_ind is BxMx2
        if INDICATOR in model_outs:
            outs_ind = model_outs[INDICATOR][
                train_utils.get_slice_head_ind_name(slice_head)]
            outs_pred_in_slice = torch.argmax(outs_ind, -1).to(
                entity_indices.device)
        else:
            outs_pred_in_slice = torch.ones(
                entity_indices.shape[0],
                entity_indices.shape[1]).to(entity_indices.device).long()
        if out_head_key in model_outs[DISAMBIG]:
            # Compute the train_head over each eval slice
            _, pred_correct = self._get_topk_correct(
                model_preds=model_outs[DISAMBIG][out_head_key],
                slice_indices=slice_indices,
                true_entity_idx=true_entity_idx,
                padded_entities=padded_entities,
                entity_indices=entity_indices,
                topk_val=1)
            # Given that the model predicted in the slice: of the mentions it said
            # were in the slice that actually were, how many did it get correct?
            pred_correct_pred_in_slice = (
                pred_correct * outs_pred_in_slice.unsqueeze(-1) * slice_indices)
            total_pred_correct_pred_in_slice = pred_correct_pred_in_slice.sum(
                (0, 1))
            # How many did the model say were in the slice that were actually in
            # the slice
            total_pred_in_slice = outs_pred_in_slice.unsqueeze(-1) * slice_indices
            total_pred_in_slice = total_pred_in_slice.sum((0, 1))
            self.alias_pred_correct_pred_in_slice[
                buffer_idx] += total_pred_correct_pred_in_slice
            self.alias_pred_count_pred_in_slice[buffer_idx] += total_pred_in_slice

    #===============================================
    # Final prediction head
    #===============================================
    out_head_key = FINAL_LOSS
    self._update_pred_counts(out_head_key, model_outs, slice_indices,
                             entity_indices, padded_entities, true_entity_idx,
                             true_label)

    #===============================================
    # Stage heads
    #===============================================
    # Compute the topk for each model stage loss head
    for stage_idx in range(self.num_model_stages - 1):
        out_head_key = train_utils.get_stage_head_name(stage_idx)
        self._update_pred_counts(out_head_key, model_outs, slice_indices,
                                 entity_indices, padded_entities, true_entity_idx,
                                 true_label)
    return None
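#===============================================
# A standalone numeric sketch (toy tensors; shapes simplified from the real
# B x M x num_slices bookkeeping above) of the per-slice precision counts: of the
# mentions the indicator head placed in a slice that are actually in the slice,
# how many did the disambiguation head get right?
#===============================================
def _demo_pred_in_slice_counts():
    import torch
    slice_indices = torch.tensor([[[1, 0], [1, 1], [0, 1]]])  # B=1 x M=3 x 2 slices
    pred_correct = torch.tensor([[1, 0, 1]])                  # B x M
    outs_pred_in_slice = torch.tensor([[1, 1, 0]])            # indicator predictions
    correct_in_slice = (
        pred_correct * outs_pred_in_slice).unsqueeze(-1) * slice_indices
    numer = correct_in_slice.sum((0, 1))                      # tensor([1, 0])
    denom = (outs_pred_in_slice.unsqueeze(-1) * slice_indices).sum((0, 1))  # tensor([2, 1])
    return numer, denom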