def soft_attn_merge(self, entity_package, batch_type_ids, batch_type_emb, sent_emb):
    """For each candidate, use a weighted average of the type embs based on type emb
    similarity to the contextualized mention embedding."""
    batch, M, K, num_types, dim = batch_type_emb.shape
    # We don't want to compute probabilities over padded types.
    # When there are no types, we'll just get an average over unk types (i.e. the unk type).
    mask = (batch_type_ids < (self.num_types_with_pad_and_unk - 1)).reshape(
        batch * M * K, num_types)
    # Get alias tensor and expand to be for each candidate for soft attn
    alias_word_tensor = model_utils.select_alias_word_sent(
        entity_package.pos_in_sent, sent_emb, index=0)
    _, _, sent_dim = alias_word_tensor.shape
    alias_word_tensor = alias_word_tensor.unsqueeze(2).expand(batch, M, K, sent_dim)
    # Reshape for soft attn
    batch_type_emb = batch_type_emb.contiguous().reshape(batch * M * K, num_types, dim)
    alias_word_tensor = alias_word_tensor.contiguous().reshape(batch * M * K, sent_dim)
    # Get soft attn
    batch_type_emb = self.soft_attn(batch_type_emb, alias_word_tensor, mask=mask)
    # Convert batch back to original shape
    batch_type_emb = batch_type_emb.reshape(batch, M, K, dim)
    return batch_type_emb
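# --- Illustrative sketch (not part of the original module) ---
# A minimal version of what a soft-attention merge like `self.soft_attn` above might
# compute: score each type embedding against the (already projected) mention embedding,
# softmax over the unmasked types, and return the weighted average. The real module may
# additionally project the query/keys; names and shapes here are assumptions.
import torch
import torch.nn.functional as F

def soft_attn_sketch(type_emb, mention_emb, mask):
    # type_emb: (B*M*K, num_types, dim); mention_emb: (B*M*K, dim); mask: True = keep type
    scores = torch.einsum("bnd,bd->bn", type_emb, mention_emb)  # dot-product similarity
    scores = scores.masked_fill(~mask, float(-1e9))             # ignore padded types
    weights = F.softmax(scores, dim=-1)                         # (B*M*K, num_types)
    return torch.einsum("bn,bnd->bd", weights, type_emb)        # weighted average, (B*M*K, dim)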
def forward(self, sent_emb, entity_package, entity_embs):
    batch, M, K = entity_package.tensor.shape
    # Get alias tensor and expand to be for each candidate for soft attn
    alias_word_tensor = model_utils.select_alias_word_sent(
        entity_package.pos_in_sent, sent_emb, index=0)
    alias_mask = entity_package.alias_indices == -1
    # batch x M x num_types
    batch_type_pred = self.prediction(alias_word_tensor)
    batch_type_weights = self.type_softmax(batch_type_pred)
    # batch x M x emb_size
    batch_type_embs = torch.matmul(batch_type_weights,
                                   self.type_embedding.unsqueeze(0))
    # Mask out unk alias embeddings
    batch_type_embs[alias_mask] = 0
    batch_type_embs = batch_type_embs.unsqueeze(2).expand(batch, M, K, self.emb_size)
    res = DottedDict(tensor=batch_type_embs,
                     pos_in_sent=entity_package.pos_in_sent,
                     alias_indices=entity_package.alias_indices,
                     mask=entity_package.mask,
                     normalize=True)
    entity_embs.append(res)
    return entity_embs, batch_type_pred
def forward(self, sent_emb, start_span_idx):
    """Model forward.

    Args:
        sent_emb: sentence embedding (B x N x L)
        start_span_idx: span index into sentence embedding (B x M)

    Returns:
        type embedding tensor (B x M x K x dim),
        type weight prediction (B x M x num_types)
    """
    batch, M = start_span_idx.shape
    alias_mask = start_span_idx == -1
    # Get alias tensor and expand to be for each candidate for soft attn
    alias_word_tensor = model_utils.select_alias_word_sent(start_span_idx, sent_emb)
    # batch x M x num_types
    batch_type_pred = self.prediction(alias_word_tensor)
    batch_type_weights = self.type_softmax(batch_type_pred)
    # batch x M x emb_size
    batch_type_embs = torch.matmul(
        batch_type_weights, self.type_embedding.unsqueeze(0)
    )
    # Mask out unk alias embeddings
    batch_type_embs[alias_mask] = 0
    batch_type_embs = batch_type_embs.unsqueeze(2).expand(
        batch, M, self.K, self.emb_size
    )
    # Normalize the output before it is concatenated with other embeddings
    batch_type_embs = model_utils.normalize_matrix(batch_type_embs, dim=3)
    return batch_type_embs, batch_type_pred
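# --- Shape check (illustrative only; sizes are made up) ---
# The matmul in the forward above is a weighted average over the type embedding table:
# (B x M x num_types) @ (1 x num_types x emb_size) -> (B x M x emb_size).
import torch

B, M, num_types, emb_size = 2, 3, 5, 8
weights = torch.softmax(torch.randn(B, M, num_types), dim=-1)  # stands in for batch_type_weights
type_table = torch.randn(num_types, emb_size)                  # stands in for self.type_embedding
out = torch.matmul(weights, type_table.unsqueeze(0))
assert out.shape == (B, M, emb_size)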
def test_select_alias_word(self):
    # batch = 3, M = 2, N = 5, H = 4
    sent_embedding = torch.randn(3, 5, 4)
    alias_pos_in_sent = torch.tensor([[2, 1], [1, -1], [0, -1]])
    res = model_utils.select_alias_word_sent(alias_pos_in_sent, sent_embedding)

    res_gold = torch.zeros(3, 2, 4)
    res_gold[0][0] = sent_embedding[0][2]
    res_gold[0][1] = sent_embedding[0][1]
    res_gold[1][0] = sent_embedding[1][1]
    res_gold[2][0] = sent_embedding[2][0]
    assert torch.equal(res_gold, res)
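# --- Hypothetical reference implementation (assumption, for illustration) ---
# One way select_alias_word_sent could behave consistently with the test above: gather the
# word embedding at each alias position and zero out rows where the position is -1 (padding).
# This is not necessarily how model_utils implements it.
import torch

def select_alias_word_sent_sketch(alias_pos_in_sent, sent_embedding):
    batch, N, H = sent_embedding.shape
    _, M = alias_pos_in_sent.shape
    safe_pos = alias_pos_in_sent.clamp(min=0)                   # make -1 positions safe to gather
    idx = safe_pos.unsqueeze(-1).expand(batch, M, H)            # B x M x H index tensor
    selected = torch.gather(sent_embedding, dim=1, index=idx)   # pick the word embeddings
    return selected.masked_fill((alias_pos_in_sent == -1).unsqueeze(-1), 0.0)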
def forward(self, sent_embedding, entity_embedding, entity_mask,
            alias_idx_pair_sent, slice_emb_alias, slice_emb_ent):
    batch_size, M, K, _ = entity_embedding.shape
    # Index is which word of the alias span to select
    alias_word_tensor = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=0)
    # Slice emb is hidden_size x 1 -> batch x M x hidden_size
    # Add in slice alias embedding
    slice_emb_alias = slice_emb_alias.unsqueeze(0).unsqueeze(1).expand(
        batch_size, M, self.hidden_size)
    alias_word_tensor = alias_word_tensor + slice_emb_alias
    alias_word_tensor = alias_word_tensor.transpose(0, 1)
    assert alias_word_tensor.shape[1] == batch_size
    # Sentence attn between alias vector and full sentence
    alias_word_tensor, alias_word_weights = self.sent_alias_attn(
        q=alias_word_tensor,
        x=sent_embedding.tensor.transpose(0, 1),
        key_mask=sent_embedding.mask,
        attn_mask=None)
    # Add in slice entity embedding
    slice_emb_ent = slice_emb_ent.unsqueeze(0).unsqueeze(1).expand(
        batch_size, M, self.hidden_size)
    alias_word_tensor = alias_word_tensor.transpose(0, 1) + slice_emb_ent
    alias_word_tensor = alias_word_tensor.transpose(0, 1)
    entity_attn_tensor = entity_embedding.contiguous().view(
        batch_size, M * K, self.hidden_size).transpose(0, 1)
    key_padding_mask_entities = entity_mask.contiguous().view(batch_size, M * K)
    # Each of the M aliases should ONLY pay attention to its OWN candidates
    entity_mask = torch.ones((M, K * M)).to(key_padding_mask_entities.device)  # M x (M*K)
    # TODO: move this to init
    for i in range(M):
        entity_mask[i, i * K:(i + 1) * K] = 0.0
    # Must manually move this to the device as it's not part of a module
    entity_mask = entity_mask.masked_fill((entity_mask == 1), float(-1e9))
    alias_entity_attn_context, alias_entity_attn_weights = self.attention_module(
        q=alias_word_tensor,
        x=entity_attn_tensor,
        key_mask=key_padding_mask_entities,
        attn_mask=entity_mask)
    # Returns batch x M x hidden
    return alias_entity_attn_context.transpose(0, 1), alias_word_weights
def forward(
    self,
    sent_embedding,
    sent_embedding_mask,
    entity_embedding,
    entity_embedding_mask,
    start_span_idx,
    end_span_idx,
    batch_on_the_fly_data,
):
    """Model forward.

    Args:
        sent_embedding: sentence embedding (B x N x L)
        sent_embedding_mask: sentence embedding mask (B x N)
        entity_embedding: entity embedding (B x M x K x H)
        entity_embedding_mask: entity embedding mask (B x M x K)
        start_span_idx: start mention index into sentence (B x M)
        end_span_idx: end mention index into sentence (B x M)
        batch_on_the_fly_data: batch on the fly dictionary with values
            (B x (M*K) x (M*K)) of KG adjacency matrices

    Returns:
        Dict of Dict of intermediate output layer scores (will be empty for this model),
        Output entity embeddings (B x M x K x H),
        Candidate scores (B x M x K)
    """
    out = {DISAMBIG: {}}
    context_mat_dict = {}
    batch_size, M, K, emb_dim = entity_embedding.shape
    alias_start_idx_sent = start_span_idx
    alias_end_idx_sent = end_span_idx
    assert (
        emb_dim == self.hidden_size
    ), "BERT NED requires the learned entity embedding dim be the same as the hidden size"
    assert alias_start_idx_sent.shape == alias_end_idx_sent.shape
    # Get alias words from sent embedding, then concatenate and project
    alias_start_word_tensor = model_utils.select_alias_word_sent(
        alias_start_idx_sent, sent_embedding)
    alias_end_word_tensor = model_utils.select_alias_word_sent(
        alias_end_idx_sent, sent_embedding)
    alias_pair_word_tensor = torch.cat(
        [alias_start_word_tensor, alias_end_word_tensor], dim=-1)
    alias_emb = (
        self.span_proj(alias_pair_word_tensor)
        .unsqueeze(2)
        .expand(batch_size, M, self.K, self.hidden_size)
    )
    alias_emb = (
        alias_emb.contiguous()
        .reshape((batch_size * M * self.K), self.hidden_size)
        .unsqueeze(1)
    )
    # entity_embedding_mask: if there are fewer than K candidates, use the mask to zero
    # out the rows for the empty candidates
    entity_embedding_zeroed = torch.where(
        entity_embedding_mask.unsqueeze(-1),
        torch.zeros_like(entity_embedding),
        entity_embedding,
    )
    entity_embedding_tensor = (
        entity_embedding_zeroed.contiguous()
        .reshape((batch_size * M * self.K), self.hidden_size)
        .unsqueeze(-1)
    )
    # Performs a batch-wise dot product across the dim=0 dimension
    score = (
        torch.bmm(alias_emb, entity_embedding_tensor)
        .unsqueeze(-1)
        .reshape(batch_size, M, self.K)
    )
    context_mat_dict[DISAMBIG] = entity_embedding_tensor.reshape(
        batch_size, M, self.K, self.hidden_size)
    return {
        "intermed_scores": out,
        "ent_embs": context_mat_dict,
        "final_scores": score,
    }
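# --- Illustration (not part of the model) ---
# The bmm above computes one dot product per (mention, candidate) pair. With the alias
# embedding reshaped to (B*M*K, 1, H) and the entity embedding to (B*M*K, H, 1),
# torch.bmm gives (B*M*K, 1, 1), which equals the elementwise-product sum along H.
import torch

BMK, H = 6, 4
q = torch.randn(BMK, 1, H)
e = torch.randn(BMK, H, 1)
score_bmm = torch.bmm(q, e).reshape(BMK)
score_sum = (q.squeeze(1) * e.squeeze(-1)).sum(-1)
assert torch.allclose(score_bmm, score_sum)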
def forward(
    self,
    sent_embedding,
    sent_embedding_mask,
    entity_embedding,
    entity_embedding_mask,
    start_span_idx,
    end_span_idx,
    batch_on_the_fly_data,
):
    """Model forward.

    Args:
        sent_embedding: sentence embedding (B x N x L)
        sent_embedding_mask: sentence embedding mask (B x N)
        entity_embedding: entity embedding (B x M x K x H)
        entity_embedding_mask: entity embedding mask (B x M x K)
        start_span_idx: start mention index into sentence (B x M)
        end_span_idx: end mention index into sentence (B x M)
        batch_on_the_fly_data: batch on the fly dictionary with values
            (B x (M*K) x (M*K)) of KG adjacency matrices

    Returns:
        Dict of Dict of intermediate layer candidate scores (B x M x K),
        Dict of all output entity embeddings from each KG matrix (B x M x K x H)
    """
    batch_size = sent_embedding.shape[0]
    out = {DISAMBIG: {}}
    # Create KG bias matrices for each kg bias key
    kg_bias_norms = {}
    for key in self.kg_bias_keys:
        kg_bias_norms[key] = (
            batch_on_the_fly_data[key]
            .float()
            .reshape(batch_size, self.M * self.K, self.M * self.K)
        )
    # get mention embedding
    # average words in mention; batch x M x dim
    mention_tensor_start = model_utils.select_alias_word_sent(
        start_span_idx, sent_embedding)
    mention_tensor_end = model_utils.select_alias_word_sent(
        end_span_idx, sent_embedding)
    mention_tensor = (mention_tensor_start + mention_tensor_end) / 2
    # Reshape for alias attention where each mention attends to its K candidates
    # query = batch*M x 1 x dim, key = value = batch*M x K x dim
    # softmax(QK^T) -> batch*M x 1 x K
    # softmax(QK^T)V -> batch*M x 1 x dim
    mention_tensor = mention_tensor.reshape(
        batch_size * self.M, 1, self.hidden_size).transpose(0, 1)
    # get sentence embedding; move batch to middle
    sent_tensor = sent_embedding.transpose(0, 1)
    # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
    entity_embedding = entity_embedding.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size)
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_embedding_mask.contiguous().view(
        batch_size, self.M * self.K)
    key_padding_mask_entities_mention = entity_embedding_mask.contiguous().view(
        batch_size * self.M, self.K)
    # Mask of aliases; key_padding_mask_entities_mention of True means mask.
    # We want to find aliases with all masked entities
    key_padding_mask_mentions = (
        torch.sum(~key_padding_mask_entities_mention, dim=-1) == 0)
    # Unmask these aliases to avoid nan in attention
    key_padding_mask_entities_mention[key_padding_mask_mentions] = False
    # Iterate through stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we are adding a residual in the attention modules, we can make embs empty
        embs = []
        context_mat_dict = {}
        key_tensor_mention = (
            query_tensor.transpose(0, 1)
            .contiguous()
            .reshape(batch_size, self.M, self.K, self.hidden_size)
            .reshape(batch_size * self.M, self.K, self.hidden_size)
            .transpose(0, 1)
        )
        # ============================================================================
        # Phrase module: compute attention between entities and words
        # ============================================================================
        word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_entity_word"](
                q=query_tensor,
                x=sent_tensor,
                key_mask=sent_embedding_mask,
                attn_mask=None,
            )
        # Add embeddings to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"] = word_entity_attn_weights
        # ============================================================================
        # Co-occurrence module: compute self attention over entities
        # ============================================================================
        # Move entity mask to device
        # TODO: move to device in init?
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device)
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"](
                x=query_tensor,
                key_mask=key_padding_mask_entities,
                attn_mask=self.e2e_entity_mask,
            )
        # Mask out the M x K block of single aliases; alias_indices is batch x M and the
        # mask is True when there is a single alias
        non_null_aliases = (
            self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
        entity_attn_post_mask = (
            (non_null_aliases.sum(1) == 1)
            .unsqueeze(1)
            .expand(batch_size, self.K * self.M)
            .transpose(0, 1)
        )
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context)
        entity_attn_context = torch.where(
            entity_attn_post_mask,
            torch.zeros_like(entity_attn_context),
            entity_attn_context,
        )
        # Add embeddings to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"] = entity_attn_weights
        # ============================================================================
        # Mention module: compute attention between entities and mentions
        # ============================================================================
        # Output is 1 x batch*M x dim
        (
            mention_entity_attn_context,
            mention_entity_attn_weights,
        ) = self.attention_modules[f"stage_{stage_index}_mention_entity"](
            q=mention_tensor,
            x=key_tensor_mention,
            key_mask=key_padding_mask_entities_mention,
            attn_mask=None,
        )
        # key_padding_mask_mentions mentions have all padded candidates,
        # meaning their rows in the context matrix are all nan
        mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(0)] = 0
        mention_entity_attn_context = (
            mention_entity_attn_context.expand(
                self.K, batch_size * self.M, self.hidden_size)
            .transpose(0, 1)
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
        # Add embeddings to be merged in the output
        embs.append(mention_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights
        # Combine module output
        context_matrix_nokg = self.combine_modules[
            f"stage_{stage_index}_combine"](embs)
        context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose(
            0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
        # ============================================================================
        # KG module: add in KG connectivity bias
        # ============================================================================
        for key in self.kg_bias_keys:
            context_matrix_kg = torch.bmm(
                kg_bias_norms[key],
                context_matrix_nokg.transpose(0, 1)).transpose(0, 1)
            context_matrix_kg = (context_matrix_nokg + context_matrix_kg) / 2
            context_mat_dict[f"context_matrix_{key}"] = context_matrix_kg.transpose(
                0, 1).reshape(batch_size, self.M, self.K, self.hidden_size)
        if stage_index < self.num_model_stages - 1:
            score = model_utils.max_score_context_matrix(
                context_mat_dict,
                self.predict_layers[DISAMBIG][
                    bootleg.utils.model_utils.get_stage_head_name(stage_index)],
            )
            out[DISAMBIG][
                f"{bootleg.utils.model_utils.get_stage_head_name(stage_index)}"] = score
        # This takes the average of the context matrices whose keys do not end in "_nokg";
        # if there are no KG bias terms, it selects context_matrix_nokg
        # (as its key, in this setting, will not end in "_nokg")
        query_tensor = (
            model_utils.generate_final_context_matrix(
                context_mat_dict, ending_key_to_exclude="_nokg")
            .reshape(batch_size, self.M * self.K, self.hidden_size)
            .transpose(0, 1)
        )
    return {
        "intermed_scores": out,
        "ent_embs": context_mat_dict,
        "final_scores": None,
    }
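# --- Illustrative sketch of the KG module above (shapes only; values are random) ---
# The normalized KG adjacency matrix (B x (M*K) x (M*K)) mixes the candidate
# representations of connected entities, and the result is averaged with the no-KG
# context matrix. The stage loop keeps the context in (M*K) x B x H layout, so the bmm
# is wrapped in transposes.
import torch

B, MK, H = 2, 6, 8
kg_bias_norm = torch.softmax(torch.randn(B, MK, MK), dim=-1)   # stands in for kg_bias_norms[key]
context_nokg = torch.randn(MK, B, H)                           # stands in for context_matrix_nokg
context_kg = torch.bmm(kg_bias_norm, context_nokg.transpose(0, 1)).transpose(0, 1)
context_combined = (context_nokg + context_kg) / 2
assert context_combined.shape == (MK, B, H)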
def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
            batch_prepped, batch_on_the_fly_data):
    batch_size = sent_embedding.tensor.shape[0]
    out = {DISAMBIG: {}}
    # Create KG bias matrices for each kg bias key
    kg_bias_norms = {}
    for key in self.kg_bias_keys:
        kg_bias = batch_on_the_fly_data[key].float().to(
            sent_embedding.tensor.device).reshape(
                batch_size, self.M * self.K, self.M * self.K)
        kg_bias_diag = kg_bias + self.kg_bias_weights[key] * torch.eye(
            self.M * self.K).repeat(batch_size, 1, 1).view(
                batch_size, self.M * self.K, self.M * self.K).to(kg_bias.device)
        kg_bias_norm = self.kg_softmax(
            kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))
        kg_bias_norms[key] = kg_bias_norm
    # get mention embedding
    # average words in mention; batch x M x dim
    mention_tensor_start = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=0)
    mention_tensor_end = model_utils.select_alias_word_sent(
        alias_idx_pair_sent, sent_embedding, index=1)
    mention_tensor = (mention_tensor_start + mention_tensor_end) / 2
    # Reshape for alias attention where each mention attends to its K candidates
    # query = batch*M x 1 x dim, key = value = batch*M x K x dim
    # softmax(QK^T) -> batch*M x 1 x K
    # softmax(QK^T)V -> batch*M x 1 x dim
    mention_tensor = mention_tensor.reshape(
        batch_size * self.M, 1, self.hidden_size).transpose(0, 1)
    # get sentence embedding; move batch to middle
    sent_tensor = sent_embedding.tensor.transpose(0, 1)
    # Resize the alias embeddings and the entity mask from B x M x K x D -> B x (M*K) x D
    entity_mask = entity_embedding.mask
    entity_embedding = entity_embedding.tensor.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size)
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_mask.contiguous().view(
        batch_size, self.M * self.K)
    key_padding_mask_entities_mention = entity_mask.contiguous().view(
        batch_size * self.M, self.K)
    # Mask of aliases; key_padding_mask_entities_mention of True means mask.
    # We want to find aliases with all masked entities
    key_padding_mask_mentions = torch.sum(
        ~key_padding_mask_entities_mention, dim=-1) == 0
    # Unmask these aliases to avoid nan in attention
    key_padding_mask_entities_mention[key_padding_mask_mentions] = False
    # Iterate through stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we are adding a residual in the attention modules, we can make embs empty
        embs = []
        context_mat_dict = {}
        key_tensor_mention = query_tensor.transpose(0, 1).contiguous().reshape(
            batch_size, self.M, self.K, self.hidden_size).reshape(
                batch_size * self.M, self.K, self.hidden_size).transpose(0, 1)
        # ============================================================================
        # Phrase module: compute attention between entities and words
        # ============================================================================
        word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_entity_word"](
                q=query_tensor,
                x=sent_tensor,
                key_mask=sent_embedding.mask,
                attn_mask=None)
        # Add embeddings to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"] = word_entity_attn_weights
        # ============================================================================
        # Co-occurrence module: compute self attention over entities
        # ============================================================================
        # Move entity mask to device
        # TODO: move to device in init?
        self.e2e_entity_mask = self.e2e_entity_mask.to(
            key_padding_mask_entities.device)
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"](
                x=query_tensor,
                key_mask=key_padding_mask_entities,
                attn_mask=self.e2e_entity_mask)
        # Mask out the M x K block of single aliases; alias_indices is batch x M and the
        # mask is True when there is a single alias
        non_null_aliases = (self.K - key_padding_mask_entities.reshape(
            batch_size, self.M, self.K).sum(-1)) != 0
        entity_attn_post_mask = (
            non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                batch_size, self.K * self.M).transpose(0, 1)
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context)
        entity_attn_context = torch.where(
            entity_attn_post_mask,
            torch.zeros_like(entity_attn_context),
            entity_attn_context)
        # Add embeddings to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"] = entity_attn_weights
        # ============================================================================
        # Mention module: compute attention between entities and mentions
        # ============================================================================
        # Output is 1 x batch*M x dim
        mention_entity_attn_context, mention_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_mention_entity"](
                q=mention_tensor,
                x=key_tensor_mention,
                key_mask=key_padding_mask_entities_mention,
                attn_mask=None)
        # key_padding_mask_mentions mentions have all padded candidates,
        # meaning their rows in the context matrix are all nan
        mention_entity_attn_context[key_padding_mask_mentions.unsqueeze(0)] = 0
        mention_entity_attn_context = mention_entity_attn_context.expand(
            self.K, batch_size * self.M, self.hidden_size).transpose(0, 1).reshape(
                batch_size, self.M * self.K, self.hidden_size).transpose(0, 1)
        # Add embeddings to be merged in the output
        embs.append(mention_entity_attn_context)
        # Save the attention weights
f"stage_{stage_index}_mention_entity"] = mention_entity_attn_weights # Combine module output context_matrix_nokg = self.combine_modules[ f"stage_{stage_index}_combine"](embs) context_mat_dict[self.no_kg_key] = context_matrix_nokg.transpose( 0, 1).reshape(batch_size, self.M, self.K, self.hidden_size) #============================================================================ # KG module: add in KG connectivity bias #============================================================================ for key in self.kg_bias_keys: context_matrix_kg = torch.bmm( kg_bias_norms[key], context_matrix_nokg.transpose(0, 1)).transpose(0, 1) context_matrix_kg = (context_matrix_nokg + context_matrix_kg) / 2 context_mat_dict[ f"context_matrix_{key}"] = context_matrix_kg.transpose( 0, 1).reshape(batch_size, self.M, self.K, self.hidden_size) if stage_index < self.num_model_stages - 1: score = model_utils.max_score_context_matrix( context_mat_dict, self.predict_layers[DISAMBIG][ train_utils.get_stage_head_name(stage_index)]) out[DISAMBIG][ f"{train_utils.get_stage_head_name(stage_index)}"] = score # This will take the average of the context matrices that do not end in the key "_nokg"; if there are not kg bias terms, it will # select the context_matrix_nokg (as it's key, in this setting, will not end in _nokg) query_tensor = model_utils.generate_final_context_matrix(context_mat_dict, ending_key_to_exclude="_nokg")\ .reshape(batch_size, self.M*self.K, self.hidden_size).transpose(0,1) return context_mat_dict, out