def forward(self, elems, context, mask):
    """Computes temperature-scaled soft attention over elems conditioned on context
    and returns the attention-weighted sum (B x dim)."""
    batch, _, emb_dim = elems.shape
    elems_norm = model_utils.normalize_matrix(self.emb_linear(elems), dim=-1)
    context_norm = model_utils.normalize_matrix(
        self.context_linear(context), dim=-1)
    scores = torch.bmm(elems_norm, context_norm.unsqueeze(-1)).squeeze(-1)
    # mask is true where we have valid elems, false for invalid (i.e. padded) elems
    probs = eval_utils.masked_softmax(
        pred=scores / self.tau, mask=mask, dim=-1).unsqueeze(-1)
    # probs = eval_utils.masked_gumbel(pred=scores, mask=mask, tau=self.tau, dim=-1).unsqueeze(-1)
    out = (elems * probs).sum(1)
    assert list(out.shape) == [batch, emb_dim]
    return out
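# A minimal standalone sketch (an assumption, not the library's eval_utils.masked_softmax)
# of the temperature-scaled, masked soft-attention pooling used above: scores over padded
# elements are pushed to -inf so they get ~0 probability, and the remaining probabilities
# weight-average the element embeddings.
import torch

def masked_soft_attention_pool(elems, scores, mask, tau=1.0):
    # elems: B x N x D, scores: B x N, mask: True for valid elements
    masked_scores = (scores / tau).masked_fill(~mask, float("-inf"))
    probs = torch.softmax(masked_scores, dim=-1).unsqueeze(-1)  # B x N x 1
    return (elems * probs).sum(1)  # B x D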
def forward(self, sent_emb, start_span_idx):
    """Model forward.

    Args:
        sent_emb: sentence embedding (B x N x L)
        start_span_idx: span index into sentence embedding (B x M)

    Returns:
        type embedding tensor (B x M x K x dim), type weight prediction (B x M x num_types)
    """
    batch, M = start_span_idx.shape
    alias_mask = start_span_idx == -1
    # Get alias tensor and expand to be for each candidate for soft attn
    alias_word_tensor = model_utils.select_alias_word_sent(start_span_idx, sent_emb)
    # batch x M x num_types
    batch_type_pred = self.prediction(alias_word_tensor)
    batch_type_weights = self.type_softmax(batch_type_pred)
    # batch x M x emb_size
    batch_type_embs = torch.matmul(
        batch_type_weights, self.type_embedding.unsqueeze(0)
    )
    # Mask out unk alias embeddings
    batch_type_embs[alias_mask] = 0
    batch_type_embs = batch_type_embs.unsqueeze(2).expand(
        batch, M, self.K, self.emb_size
    )
    # Normalize the output before it is concatenated
    batch_type_embs = model_utils.normalize_matrix(batch_type_embs, dim=3)
    return batch_type_embs, batch_type_pred
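# Minimal sketch (assumed shapes, hypothetical names; not the module itself) of the soft
# type selection above: a softmax over per-mention type logits produces weights that mix
# the rows of a (num_types x emb_size) type-embedding table into one embedding per mention.
import torch

def soft_type_embedding(type_logits, type_embedding_table):
    # type_logits: B x M x num_types, type_embedding_table: num_types x emb_size
    weights = torch.softmax(type_logits, dim=-1)        # B x M x num_types
    return torch.matmul(weights, type_embedding_table)  # B x M x emb_size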
def forward(self, sent_embedding, alias_idx_pair_sent, entity_embedding, entity_mask):
    batch_size = sent_embedding.tensor.shape[0]
    # Create list of all entity tensors
    alias_list = []
    alias_indices = None
    for embedding in entity_embedding:
        # Entity shape: batch_size x M x K x embedding_dim
        assert embedding.tensor.shape[0] == batch_size
        assert embedding.tensor.shape[1] == self.M
        assert embedding.tensor.shape[2] == self.K
        emb = embedding.tensor
        if alias_indices is not None:
            assert torch.equal(
                alias_indices, embedding.alias_indices
            ), "Alias indices should not be different between embeddings in embCombiner"
        alias_indices = embedding.alias_indices
        # Normalize input embeddings
        if embedding.normalize:
            emb = model_utils.normalize_matrix(emb, dim=3)
            assert not torch.isnan(emb).any()
            assert not torch.isinf(emb).any()
        alias_list.append(emb)
    alias_tensor = self.linear_layers['project_embedding'](alias_list)
    alias_tensor_first = self.position_enc['alias'](
        alias_tensor,
        alias_idx_pair_sent[0].transpose(0, 1).repeat(self.K, 1, 1).transpose(0, 2).long())
    alias_tensor_last = self.position_enc['alias'](
        alias_tensor,
        alias_idx_pair_sent[1].transpose(0, 1).repeat(self.K, 1, 1).transpose(0, 2).long())
    alias_tensor = self.position_enc['alias_position_cat'](
        [alias_tensor_first, alias_tensor_last])
    proj_ent_embedding = DottedDict(
        tensor=alias_tensor,
        # Position of entities in sentence
        pos_in_sent=alias_idx_pair_sent,
        # Indices of aliases
        alias_indices=alias_indices,
        # All entity embeddings currently share the same mask
        mask=embedding.mask,
        # Do not normalize this embedding if normalize is called
        normalize=False,
        dim=alias_tensor.shape[-1])
    return sent_embedding, proj_ent_embedding
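# Minimal sketch (assumed B x M span indices; hypothetical helper name) of the index
# expansion used above before positional encoding: each alias start/end index is repeated
# K times so every candidate of an alias shares the same sentence position.
import torch

def expand_span_idx_to_candidates(span_idx, K):
    # span_idx: B x M  ->  M x B  ->  K x M x B  ->  B x M x K
    return span_idx.transpose(0, 1).repeat(K, 1, 1).transpose(0, 2).long()

# e.g. expand_span_idx_to_candidates(torch.tensor([[2, 5]]), K=3).shape == (1, 2, 3)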
def normalize_and_dropout_emb(self, embedding: torch.Tensor) -> torch.Tensor:
    """Applies 1D or 2D dropout and, if enabled, normalization to the embedding.

    Args:
        embedding: embedding tensor

    Returns:
        adjusted embedding
    """
    if self.dropout1d_perc > 0:
        embedding = model_utils.emb_1d_dropout(self.training, self.dropout1d_perc, embedding)
    elif self.dropout2d_perc > 0:
        embedding = model_utils.emb_2d_dropout(self.training, self.dropout2d_perc, embedding)
    # We enforce that self.normalize is instantiated inside each subclass
    if self.normalize is True:
        embedding = model_utils.normalize_matrix(embedding, dim=-1)
    return embedding
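# Minimal standalone sketch of the adjust-embedding pattern above, under two assumptions:
# emb_1d_dropout behaves like standard feature-level dropout applied only at train time,
# and normalize_matrix is an L2 normalization along the given dimension.
import torch
import torch.nn.functional as F

def normalize_and_dropout_sketch(embedding, training, dropout_perc, normalize):
    if dropout_perc > 0:
        embedding = F.dropout(embedding, p=dropout_perc, training=training)
    if normalize:
        embedding = F.normalize(embedding, p=2, dim=-1)
    return embedding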
def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding, batch_prepped,
            batch_on_the_fly_data):
    batch_size = sent_embedding.tensor.shape[0]
    out = {DISAMBIG: {}}

    # Prepare inputs for attention modules
    # Get the KG metadata for the KG module - this is 0/1 indicating if a pair is connected
    if REL_INDICES_KEY in batch_on_the_fly_data:
        kg_bias = batch_on_the_fly_data[REL_INDICES_KEY].float().to(
            sent_embedding.tensor.device).reshape(
                batch_size, self.M * self.K, self.M * self.K)
    else:
        kg_bias = torch.zeros(batch_size, self.M * self.K, self.M * self.K).to(
            sent_embedding.tensor.device)
    kg_bias_diag = kg_bias + self.kg_weight * torch.eye(self.M * self.K).repeat(
        batch_size, 1, 1).view(batch_size, self.M * self.K, self.M * self.K).to(kg_bias.device)
    kg_bias_norm = self.softmax(
        kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))

    sent_tensor = sent_embedding.tensor.transpose(0, 1)

    # Resize the alias embeddings and the entity mask from B x M x K x D to B x (M*K) x D
    entity_mask = entity_embedding.mask
    entity_embedding = entity_embedding.tensor.contiguous().view(
        batch_size, self.M * self.K, self.hidden_size)
    entity_embedding = entity_embedding.transpose(0, 1)  # reshape for attention
    key_padding_mask_entities = entity_mask.contiguous().view(
        batch_size, self.M * self.K)

    # Iterate through stages
    query_tensor = entity_embedding
    for stage_index in range(self.num_model_stages):
        # As we are adding a residual in the attention modules, we can make embs empty
        embs = []

        # ============================================================================
        # Phrase module: compute attention between entities and words
        # ============================================================================
        word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_entity_word"](
                q=query_tensor,
                x=sent_tensor,
                key_mask=sent_embedding.mask,
                attn_mask=None)
        # Add embeddings to be merged in the output
        embs.append(word_entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

        # ============================================================================
        # Co-occurrence module: compute self attention over entities
        # ============================================================================
        # Move the entity-to-entity attention mask to the right device
        self.e2e_entity_mask = self.e2e_entity_mask.to(key_padding_mask_entities.device)
        entity_attn_context, entity_attn_weights = self.attention_modules[
            f"stage_{stage_index}_self_entity"](
                x=query_tensor,
                key_mask=key_padding_mask_entities,
                attn_mask=self.e2e_entity_mask)
        # Mask out the M x K candidates of single aliases; alias_indices is batch x M,
        # mask is true when there is a single alias
        non_null_aliases = (self.K - key_padding_mask_entities.reshape(
            batch_size, self.M, self.K).sum(-1)) != 0
        entity_attn_post_mask = (non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
            batch_size, self.K * self.M).transpose(0, 1)
        entity_attn_post_mask = entity_attn_post_mask.unsqueeze(-1).expand_as(
            entity_attn_context)
        entity_attn_context = torch.where(entity_attn_post_mask,
                                          torch.zeros_like(entity_attn_context),
                                          entity_attn_context)
        # Add embeddings to be merged in the output
        embs.append(entity_attn_context)
        # Save the attention weights
        self.attention_weights[
            f"stage_{stage_index}_self_entity"] = entity_attn_weights

        context_matrix = self.combine_modules[f"stage_{stage_index}_combine"](embs)

        # ============================================================================
        # KG module: add in KG connectivity bias
        # ============================================================================
        context_matrix_kg = torch.bmm(
            kg_bias_norm, context_matrix.transpose(0, 1)).transpose(0, 1)
        if stage_index < self.num_model_stages - 1:
            pred = self.predict_layers[DISAMBIG][
                train_utils.get_stage_head_name(stage_index)](context_matrix)
            pred = pred.transpose(0, 1).squeeze(2).reshape(batch_size, self.M, self.K)
            pred_kg = self.predict_layers[DISAMBIG][
                train_utils.get_stage_head_name(stage_index)](
                    (context_matrix + context_matrix_kg) / 2)
            pred_kg = pred_kg.transpose(0, 1).squeeze(2).reshape(batch_size, self.M, self.K)
            out[DISAMBIG][f"{train_utils.get_stage_head_name(stage_index)}"] = torch.max(
                torch.cat([pred.unsqueeze(3), pred_kg.unsqueeze(3)], dim=-1), dim=-1)[0]
            pred_logit = eval_utils.masked_class_logsoftmax(pred=pred, mask=~entity_mask)
            context_norm = model_utils.normalize_matrix(
                (context_matrix + context_matrix_kg) / 2, dim=2)
            assert not torch.isnan(context_norm).any()
            assert not torch.isinf(context_norm).any()
            # Add predictions so model can learn from previous predictions
            context_matrix = context_norm + pred_logit.contiguous().view(
                batch_size, self.M * self.K, 1).transpose(0, 1)
        else:
            context_matrix_nokg = context_matrix
            context_matrix = (context_matrix + context_matrix_kg) / 2
            context_matrix_main = context_matrix
        query_tensor = context_matrix

    context_matrix_nokg = context_matrix_nokg.transpose(0, 1).reshape(
        batch_size, self.M, self.K, self.hidden_size)
    context_matrix_main = context_matrix_main.transpose(0, 1).reshape(
        batch_size, self.M, self.K, self.hidden_size)

    # context_mat_dict is the contextualized entity embeddings, out is the predictions
    context_mat_dict = {
        "context_matrix_nokg": context_matrix_nokg,
        MAIN_CONTEXT_MATRIX: context_matrix_main
    }
    return context_mat_dict, out
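# Minimal standalone sketch (hypothetical helper name, simplified B x (M*K) x H layout rather
# than the transposed layout above) of the KG bias step: a (M*K x M*K) 0/1 connectivity matrix
# gets weighted self-loops, zero entries are masked to -1e9, a row softmax turns it into mixing
# weights, and a batched matmul averages each candidate with its KG neighbors' representations.
import torch

def kg_bias_mix(kg_adj, context, kg_weight=1.0):
    # kg_adj: B x (M*K) x (M*K) connectivity, context: B x (M*K) x H
    n = kg_adj.shape[-1]
    biased = kg_adj + kg_weight * torch.eye(n, device=kg_adj.device).unsqueeze(0)
    norm = torch.softmax(biased.masked_fill(biased == 0, -1e9), dim=-1)
    return torch.bmm(norm, context)  # B x (M*K) x H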