Example #1
    def forward(self, elems, context, mask):
        batch, _, emb_dim = elems.shape
        elems_norm = model_utils.normalize_matrix(self.emb_linear(elems),
                                                  dim=-1)
        context_norm = model_utils.normalize_matrix(
            self.context_linear(context), dim=-1)
        scores = torch.bmm(elems_norm, context_norm.unsqueeze(-1)).squeeze(-1)
        # mask is True where we have valid elems, False for invalid (i.e., padded) elems
        probs = eval_utils.masked_softmax(pred=scores / self.tau,
                                          mask=mask,
                                          dim=-1).unsqueeze(-1)
        # probs = eval_utils.masked_gumbel(pred=scores, mask=mask, tau=self.tau, dim=-1).unsqueeze(-1)
        out = (elems * probs).sum(1)
        assert list(out.shape) == [batch, emb_dim]
        return out
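
Note: the model_utils and eval_utils helpers are not shown on this page. A minimal sketch of a masked softmax consistent with the call above, assuming mask is True for valid elements (the function name and behavior here are assumptions, not the library's actual implementation):

    import torch

    def masked_softmax_sketch(pred, mask, dim=-1):
        # Hypothetical stand-in for eval_utils.masked_softmax: fill padded
        # positions with a large negative value so they receive ~0 probability.
        return torch.softmax(pred.masked_fill(~mask, -1e9), dim=dim)

    scores = torch.randn(2, 5)                          # batch x num_elems
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]]).bool()       # True = valid elem
    probs = masked_softmax_sketch(scores / 0.5, mask)   # tau = 0.5 temperature
    assert torch.allclose(probs.sum(-1), torch.ones(2))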
Example #2
    def forward(self, sent_emb, start_span_idx):
        """Model forward.

        Args:
            sent_emb: sentence embedding (B x N x L)
            start_span_idx: span index into sentence embedding (B x M)

        Returns: type embedding tensor (B x M x K x dim), type weight prediction (B x M x num_types)
        """
        batch, M = start_span_idx.shape
        alias_mask = start_span_idx == -1
        # Get alias tensor and expand to be for each candidate for soft attn
        alias_word_tensor = model_utils.select_alias_word_sent(start_span_idx, sent_emb)

        # batch x M x num_types
        batch_type_pred = self.prediction(alias_word_tensor)
        batch_type_weights = self.type_softmax(batch_type_pred)
        # batch x M x emb_size
        batch_type_embs = torch.matmul(
            batch_type_weights, self.type_embedding.unsqueeze(0)
        )
        # mask out unk alias embeddings
        batch_type_embs[alias_mask] = 0
        batch_type_embs = batch_type_embs.unsqueeze(2).expand(
            batch, M, self.K, self.emb_size
        )
        # normalize the output before being concatenated
        batch_type_embs = model_utils.normalize_matrix(batch_type_embs, dim=3)
        return batch_type_embs, batch_type_pred
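
The core of Example #2 is the soft type lookup: the softmax weights over types select a convex combination of rows of the type embedding matrix. A self-contained sketch of just that step, with shapes matching the comments above (the tensors here are made up for illustration):

    import torch

    batch, M, num_types, emb_size = 2, 3, 5, 8
    type_embedding = torch.randn(num_types, emb_size)    # learned type matrix
    batch_type_pred = torch.randn(batch, M, num_types)   # per-mention type logits
    batch_type_weights = torch.softmax(batch_type_pred, dim=-1)
    # (B x M x num_types) @ (1 x num_types x emb_size) -> B x M x emb_size
    batch_type_embs = torch.matmul(batch_type_weights, type_embedding.unsqueeze(0))
    assert batch_type_embs.shape == (batch, M, emb_size)

Each mention's type embedding is therefore a probability-weighted average of all type vectors, which keeps the lookup differentiable end to end.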
Example #3
    def forward(self, sent_embedding, alias_idx_pair_sent, entity_embedding,
                entity_mask):
        batch_size = sent_embedding.tensor.shape[0]
        # Create list of all entity tensors
        alias_list = []
        alias_indices = None
        for embedding in entity_embedding:
            # Entity shape: batch_size x M x K x embedding_dim
            assert (embedding.tensor.shape[0] == batch_size)
            assert (embedding.tensor.shape[1] == self.M)
            assert (embedding.tensor.shape[2] == self.K)
            emb = embedding.tensor
            if alias_indices is not None:
                assert torch.equal(
                    alias_indices, embedding.alias_indices
                ), "Alias indices should not be different between embeddings in embCombiner"
            alias_indices = embedding.alias_indices
            # Normalize input embeddings
            if embedding.normalize:
                emb = model_utils.normalize_matrix(emb, dim=3)
                assert not torch.isnan(emb).any()
                assert not torch.isinf(emb).any()
            alias_list.append(emb)
        alias_tensor = self.linear_layers['project_embedding'](alias_list)
        alias_tensor_first = self.position_enc['alias'](
            alias_tensor,
            alias_idx_pair_sent[0].transpose(0, 1).repeat(self.K, 1, 1)
            .transpose(0, 2).long())
        alias_tensor_last = self.position_enc['alias'](
            alias_tensor,
            alias_idx_pair_sent[1].transpose(0, 1).repeat(self.K, 1, 1)
            .transpose(0, 2).long())
        alias_tensor = self.position_enc['alias_position_cat'](
            [alias_tensor_first, alias_tensor_last])

        proj_ent_embedding = DottedDict(
            tensor=alias_tensor,
            # Position of entities in sentence
            pos_in_sent=alias_idx_pair_sent,
            # Indexes of aliases
            alias_indices=alias_indices,
            # All entity embeddings have the same mask currently
            mask=embedding.mask,
            # Flag so that downstream normalization is not applied to this embedding again
            normalize=False,
            dim=alias_tensor.shape[-1])
        return sent_embedding, proj_ent_embedding
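
Several of these examples call model_utils.normalize_matrix(x, dim=...). A plausible sketch, assuming it performs L2 normalization along the given dimension (an assumption based on the call sites, not the actual implementation):

    import torch
    import torch.nn.functional as F

    def normalize_matrix_sketch(emb, dim=-1, eps=1e-12):
        # Hypothetical stand-in for model_utils.normalize_matrix: rescale the
        # vectors along `dim` to unit L2 norm, guarding against zero vectors.
        return F.normalize(emb, p=2, dim=dim, eps=eps)

    emb = torch.randn(2, 4, 3, 16)                      # B x M x K x D
    normed = normalize_matrix_sketch(emb, dim=3)
    assert torch.allclose(normed.norm(dim=3), torch.ones(2, 4, 3), atol=1e-5)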
Example #4
    def normalize_and_dropout_emb(self,
                                  embedding: torch.Tensor) -> torch.Tensor:
        """Whether to normalize and dropout embedding.

        Args:
            embedding: embedding

        Returns: adjusted embedding
        """
        if self.dropout1d_perc > 0:
            embedding = model_utils.emb_1d_dropout(self.training,
                                                   self.dropout1d_perc,
                                                   embedding)
        elif self.dropout2d_perc > 0:
            embedding = model_utils.emb_2d_dropout(self.training,
                                                   self.dropout2d_perc,
                                                   embedding)
        # We require self.normalize to be set by each subclass
        if self.normalize is True:
            embedding = model_utils.normalize_matrix(embedding, dim=-1)
        return embedding
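
A sketch of the two dropout variants this helper chooses between, assuming emb_1d_dropout zeroes individual values and emb_2d_dropout zeroes entire embedding vectors (behavior inferred from the names; not the library's code):

    import torch
    import torch.nn.functional as F

    def emb_1d_dropout_sketch(training, perc, emb):
        # Elementwise dropout: each scalar in the embedding is zeroed independently.
        return F.dropout(emb, p=perc, training=training)

    def emb_2d_dropout_sketch(training, perc, emb):
        # Vector-wise dropout: drop whole embedding rows at once, rescaling survivors.
        if not training:
            return emb
        keep = (torch.rand(*emb.shape[:-1], 1, device=emb.device) >= perc).float()
        return emb * keep / (1.0 - perc)

    emb = torch.randn(4, 6, 32)
    out_1d = emb_1d_dropout_sketch(True, 0.1, emb)
    out_2d = emb_2d_dropout_sketch(True, 0.1, emb)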
Example #5
    def forward(self, alias_idx_pair_sent, sent_embedding, entity_embedding,
                batch_prepped, batch_on_the_fly_data):
        batch_size = sent_embedding.tensor.shape[0]
        out = {DISAMBIG: {}}

        # Prepare inputs for attention modules

        # Get the KG metadata for the KG module - a 0/1 indicator of whether each candidate pair is connected
        if REL_INDICES_KEY in batch_on_the_fly_data:
            kg_bias = batch_on_the_fly_data[REL_INDICES_KEY].float().to(
                sent_embedding.tensor.device).reshape(batch_size,
                                                      self.M * self.K,
                                                      self.M * self.K)
        else:
            kg_bias = torch.zeros(batch_size, self.M * self.K, self.M *
                                  self.K).to(sent_embedding.tensor.device)
        kg_bias_diag = kg_bias + self.kg_weight * torch.eye(
            self.M * self.K).repeat(batch_size, 1, 1).view(
                batch_size, self.M * self.K, self.M * self.K).to(
                    kg_bias.device)
        kg_bias_norm = self.softmax(
            kg_bias_diag.masked_fill((kg_bias_diag == 0), float(-1e9)))

        sent_tensor = sent_embedding.tensor.transpose(0, 1)

        # Resize the alias embeddings and the entity mask from B x M x K x D to B x (M*K) x D
        entity_mask = entity_embedding.mask
        entity_embedding = entity_embedding.tensor.contiguous().view(
            batch_size, self.M * self.K, self.hidden_size)
        entity_embedding = entity_embedding.transpose(
            0, 1)  # reshape for attention
        key_padding_mask_entities = entity_mask.contiguous().view(
            batch_size, self.M * self.K)

        # Iterate through stages
        query_tensor = entity_embedding
        for stage_index in range(self.num_model_stages):
            # Since the attention modules add a residual connection, embs can start out empty
            embs = []
            #============================================================================
            # Phrase module: compute attention between entities and words
            #============================================================================
            word_entity_attn_context, word_entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_entity_word"](
                    q=query_tensor,
                    x=sent_tensor,
                    key_mask=sent_embedding.mask,
                    attn_mask=None)
            # Add embeddings to be merged in the output
            embs.append(word_entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_entity_word"] = word_entity_attn_weights

            #============================================================================
            # Co-occurrence module: compute self attention over entities
            #============================================================================
            # Move entity mask over
            self.e2e_entity_mask = self.e2e_entity_mask.to(
                key_padding_mask_entities.device)

            entity_attn_context, entity_attn_weights = self.attention_modules[
                f"stage_{stage_index}_self_entity"](
                    x=query_tensor,
                    key_mask=key_padding_mask_entities,
                    attn_mask=self.e2e_entity_mask)

            # Zero out the entity self-attention context for sentences with only a
            # single non-null alias (there is nothing else to attend to); the mask
            # built below is True exactly in that case
            non_null_aliases = (self.K - key_padding_mask_entities.reshape(
                batch_size, self.M, self.K).sum(-1)) != 0
            entity_attn_post_mask = (
                non_null_aliases.sum(1) == 1).unsqueeze(1).expand(
                    batch_size, self.K * self.M).transpose(0, 1)
            entity_attn_post_mask = entity_attn_post_mask.unsqueeze(
                -1).expand_as(entity_attn_context)
            entity_attn_context = torch.where(
                entity_attn_post_mask, torch.zeros_like(entity_attn_context),
                entity_attn_context)

            # Add embeddings to be merged in the output
            embs.append(entity_attn_context)
            # Save the attention weights
            self.attention_weights[
                f"stage_{stage_index}_self_entity"] = entity_attn_weights

            context_matrix = self.combine_modules[
                f"stage_{stage_index}_combine"](embs)

            #============================================================================
            # KG module: add in KG connectivity bias
            #============================================================================
            context_matrix_kg = torch.bmm(kg_bias_norm,
                                          context_matrix.transpose(
                                              0, 1)).transpose(0, 1)

            if stage_index < self.num_model_stages - 1:
                pred = self.predict_layers[DISAMBIG][
                    train_utils.get_stage_head_name(stage_index)](
                        context_matrix)
                pred = pred.transpose(0, 1).squeeze(2).reshape(
                    batch_size, self.M, self.K)
                pred_kg = self.predict_layers[DISAMBIG][
                    train_utils.get_stage_head_name(stage_index)](
                        (context_matrix + context_matrix_kg) / 2)
                pred_kg = pred_kg.transpose(0, 1).squeeze(2).reshape(
                    batch_size, self.M, self.K)
                out[DISAMBIG][
                    f"{train_utils.get_stage_head_name(stage_index)}"] = torch.max(
                        torch.cat([pred.unsqueeze(3),
                                   pred_kg.unsqueeze(3)],
                                  dim=-1),
                        dim=-1)[0]
                pred_logit = eval_utils.masked_class_logsoftmax(
                    pred=pred, mask=~entity_mask)
                context_norm = model_utils.normalize_matrix(
                    (context_matrix + context_matrix_kg) / 2, dim=2)
                assert not torch.isnan(context_norm).any()
                assert not torch.isinf(context_norm).any()
                # Add predictions so model can learn from previous predictions
                context_matrix = context_norm + pred_logit.contiguous().view(
                    batch_size, self.M * self.K, 1).transpose(0, 1)
            else:
                context_matrix_nokg = context_matrix
                context_matrix = (context_matrix + context_matrix_kg) / 2
                context_matrix_main = context_matrix

            query_tensor = context_matrix

        context_matrix_nokg = context_matrix_nokg.transpose(0, 1).reshape(
            batch_size, self.M, self.K, self.hidden_size)
        context_matrix_main = context_matrix_main.transpose(0, 1).reshape(
            batch_size, self.M, self.K, self.hidden_size)

        # context_mat_dict holds the contextualized entity embeddings; out holds the disambiguation predictions
        context_mat_dict = {
            "context_matrix_nokg": context_matrix_nokg,
            MAIN_CONTEXT_MATRIX: context_matrix_main
        }
        return context_mat_dict, out
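
The KG-bias preprocessing at the top of Example #5 reduces to a small, self-contained computation: add a weighted identity to the 0/1 adjacency so every candidate keeps some attention mass on itself, then softmax each row while pushing disconnected (zero) pairs toward zero probability. A standalone sketch of that step with made-up shapes (kg_weight and the tensors are illustrative):

    import torch

    batch_size, MK, kg_weight = 2, 6, 0.5
    # 0/1 connectivity between the M*K candidate slots, as in batch_on_the_fly_data
    kg_bias = (torch.rand(batch_size, MK, MK) > 0.7).float()
    # Weighted identity keeps some attention mass on each candidate itself
    kg_bias_diag = kg_bias + kg_weight * torch.eye(MK).repeat(batch_size, 1, 1)
    # Row-wise softmax; zero (disconnected) entries are filled with -1e9 first
    kg_bias_norm = torch.softmax(
        kg_bias_diag.masked_fill(kg_bias_diag == 0, -1e9), dim=-1)
    assert torch.allclose(kg_bias_norm.sum(-1), torch.ones(batch_size, MK))

In the example, this normalized matrix is then multiplied into each stage's context matrix with torch.bmm, mixing every candidate's representation with those of its KG neighbors before averaging with the original context matrix.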