Example #1
This variant augments the encoder with external knowledge: a pretrained LM embeds the paragraph, ConceptNet triples are encoded separately, a CRF tracks per-sentence entity states, a location predictor scores candidate locations, and an auxiliary attention loss supervises the attention over the triples.
    def forward(self, token_ids: torch.Tensor, entity_mask: torch.IntTensor,
                verb_mask: torch.IntTensor, loc_mask: torch.IntTensor,
                gold_loc_seq: torch.IntTensor, gold_state_seq: torch.IntTensor,
                num_cands: torch.IntTensor, sentence_mask: torch.IntTensor,
                cpnet_triples: List, state_rel_labels: torch.IntTensor,
                loc_rel_labels: torch.IntTensor):
        """
        Args:
            token_ids: size (batch, max_tokens)
            entity_mask / verb_mask / sentence_mask: size (batch, max_sents, max_tokens)
            loc_mask: size (batch, max_cands, max_sents + 1, max_tokens), +1 for location 0
            gold_loc_seq: size (batch, max_sents + 1), +1 for the initial location (location 0)
            gold_state_seq: size (batch, max_sents)
            state_rel_labels: size (batch, max_sents, max_cpnet)
            loc_rel_labels: size (batch, max_sents, max_cpnet)
            num_cands: size (batch,)
        """
        assert entity_mask.size(-2) == verb_mask.size(-2) == loc_mask.size(-2) - 1\
               == gold_state_seq.size(-1) == gold_loc_seq.size(-1) - 1
        assert entity_mask.size(-1) == verb_mask.size(-1) == loc_mask.size(-1)
        batch_size = entity_mask.size(0)
        max_tokens = entity_mask.size(-1)
        max_sents = gold_state_seq.size(-1)
        max_cands = loc_mask.size(-3)

        attention_mask = (token_ids != self.plm_tokenizer.pad_token_id).to(
            torch.int)
        plm_outputs = self.embed_encoder(token_ids,
                                         attention_mask=attention_mask)
        # hidden states at the last layer, (batch, max_tokens, plm_hidden_size)
        embeddings = plm_outputs[0]

        # (batch, max_tokens, 2*hidden_size)
        token_rep, _ = self.TokenEncoder(embeddings)
        token_rep = self.Dropout(token_rep)
        assert token_rep.size() == (batch_size, max_tokens,
                                    2 * self.hidden_size)

        # encode the ConceptNet triples once; both the state tracker and the
        # location predictor attend over cpnet_rep
        cpnet_rep = self.CpnetEncoder(cpnet_triples,
                                      tokenizer=self.plm_tokenizer,
                                      encoder=self.cpnet_encoder)

        # state change prediction
        # size (batch, max_sents, NUM_STATES)
        tag_logits, state_attn_probs = self.StateTracker(
            encoder_out=token_rep,
            entity_mask=entity_mask,
            verb_mask=verb_mask,
            sentence_mask=sentence_mask,
            cpnet_triples=cpnet_triples,
            cpnet_rep=cpnet_rep)
        # mask the padded positions so they don't count in the loss
        tag_mask = (gold_state_seq != PAD_STATE)
        log_likelihood = self.CRFLayer(emissions=tag_logits,
                                       tags=gold_state_seq.long(),
                                       mask=tag_mask,
                                       reduction='token_mean')

        state_loss = -log_likelihood  # State classification loss is negative log likelihood
        pred_state_seq = self.CRFLayer.decode(emissions=tag_logits,
                                              mask=tag_mask)
        assert len(pred_state_seq) == batch_size
        correct_state_pred, total_state_pred = compute_state_accuracy(
            pred=pred_state_seq,
            gold=gold_state_seq.tolist(),
            pad_value=PAD_STATE)

        # location prediction
        # size (batch, max_cands, max_sents + 1)
        empty_mask = torch.zeros((batch_size, 1, max_tokens), dtype=torch.int)
        if self.use_cuda:
            empty_mask = empty_mask.cuda()
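        # prepend an all-zero sentence row so entity_mask has max_sents + 1
        # rows, matching loc_mask's extra slot for location 0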
        entity_mask = torch.cat([empty_mask, entity_mask], dim=1)
        loc_logits, loc_attn_probs = self.LocationPredictor(
            encoder_out=token_rep,
            entity_mask=entity_mask,
            loc_mask=loc_mask,
            sentence_mask=sentence_mask,
            cpnet_triples=cpnet_triples,
            cpnet_rep=cpnet_rep)
        # size (batch, max_sents + 1, max_cands)
        loc_logits = loc_logits.transpose(-1, -2)
        masked_loc_logits = self.mask_loc_logits(
            loc_logits=loc_logits,
            num_cands=num_cands)  # (batch, max_sents + 1, max_cands + 1)
        masked_gold_loc_seq = self.mask_undefined_loc(
            gold_loc_seq=gold_loc_seq,
            mask_value=PAD_LOC)  # (batch, max_sents + 1)
        loc_loss = self.CrossEntropy(
            input=masked_loc_logits.view(batch_size * (max_sents + 1),
                                         max_cands + 1),
            target=masked_gold_loc_seq.view(
                batch_size * (max_sents + 1)).long())
        correct_loc_pred, total_loc_pred = compute_loc_accuracy(
            logits=masked_loc_logits,
            gold=masked_gold_loc_seq,
            pad_value=PAD_LOC)

        if loc_attn_probs is not None:
            loc_attn_probs = self.get_gold_attn_probs(loc_attn_probs,
                                                      gold_loc_seq)
        attn_loss, total_attn_pred = self.get_attn_loss(
            state_attn_probs, loc_attn_probs, state_rel_labels, loc_rel_labels)

        if self.is_test:  # inference
            pred_loc_seq = get_pred_loc(loc_logits=masked_loc_logits,
                                        gold_loc_seq=gold_loc_seq)
            return pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred, correct_loc_pred, total_loc_pred

        return state_loss, loc_loss, attn_loss, correct_state_pred, total_state_pred, \
               correct_loc_pred, total_loc_pred, total_attn_pred
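
Both examples use the same CRF tagging pattern for state tracking: the state loss is the negative log-likelihood of the gold tag sequence, and prediction is Viterbi decoding. Below is a minimal, self-contained sketch of that pattern, assuming the pytorch-crf package (torchcrf); the values of NUM_STATES and PAD_STATE are illustrative, the real ones come from the surrounding codebase.

    import torch
    from torchcrf import CRF  # pip install pytorch-crf

    NUM_STATES = 4  # illustrative tag-set size
    PAD_STATE = 0   # illustrative padding tag (must be a valid tag index)

    crf = CRF(NUM_STATES, batch_first=True)

    batch_size, max_sents = 2, 5
    tag_logits = torch.randn(batch_size, max_sents, NUM_STATES)  # emissions
    gold_state_seq = torch.tensor([[1, 2, 3, PAD_STATE, PAD_STATE],
                                   [2, 1, 1, 3, 2]])

    # mask the padded sentences so they don't count in the loss
    tag_mask = (gold_state_seq != PAD_STATE)

    # state loss = negative log-likelihood; 'token_mean' averages over
    # unmasked tags, so the scale is independent of batch size and padding
    log_likelihood = crf(emissions=tag_logits, tags=gold_state_seq.long(),
                         mask=tag_mask, reduction='token_mean')
    state_loss = -log_likelihood

    # Viterbi decoding returns one tag list per batch element,
    # truncated to its unmasked length
    pred_state_seq = crf.decode(emissions=tag_logits, mask=tag_mask)
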
Example #2
This variant drops the ConceptNet components: paragraph embeddings built from character-level input feed the same CRF-based state tracker and candidate-location predictor.
    def forward(self, char_paragraph: torch.Tensor,
                entity_mask: torch.IntTensor, verb_mask: torch.IntTensor,
                loc_mask: torch.IntTensor, gold_loc_seq: torch.IntTensor,
                gold_state_seq: torch.IntTensor, num_cands: torch.IntTensor):
        """
        Args:
            gold_loc_seq: size (batch, max_sents)
            gold_state_seq: size (batch, max_sents)
            num_cands: size(batch,)
        """
        assert entity_mask.size(-2) == verb_mask.size(-2) == loc_mask.size(
            -2) == gold_state_seq.size(-1) == gold_loc_seq.size(-1)
        assert entity_mask.size(-1) == verb_mask.size(-1) == loc_mask.size(
            -1) == char_paragraph.size(-2)
        batch_size = char_paragraph.size(0)
        max_tokens = char_paragraph.size(1)
        max_sents = gold_state_seq.size(-1)
        max_cands = loc_mask.size(-3)

        embeddings = self.EmbeddingLayer(
            char_paragraph, verb_mask)  # (batch, max_tokens, embed_size)
        # (batch, max_tokens, 2*hidden_size)
        token_rep, _ = self.TokenEncoder(embeddings)
        token_rep = self.Dropout(token_rep)
        assert token_rep.size() == (batch_size, max_tokens,
                                    2 * self.hidden_size)

        # state change prediction
        # size (batch, max_sents, NUM_STATES)
        tag_logits = self.StateTracker(encoder_out=token_rep,
                                       entity_mask=entity_mask,
                                       verb_mask=verb_mask)
        # mask the padded positions so they don't count in the loss
        tag_mask = (gold_state_seq != PAD_STATE)
        log_likelihood = self.CRFLayer(emissions=tag_logits,
                                       tags=gold_state_seq.long(),
                                       mask=tag_mask,
                                       reduction='token_mean')

        state_loss = -log_likelihood  # State classification loss is negative log likelihood
        pred_state_seq = self.CRFLayer.decode(emissions=tag_logits,
                                              mask=tag_mask)
        assert len(pred_state_seq) == batch_size
        correct_state_pred, total_state_pred = compute_state_accuracy(
            pred=pred_state_seq,
            gold=gold_state_seq.tolist(),
            pad_value=PAD_STATE)

        # location prediction
        # size (batch, max_cands, max_sents)
        loc_logits = self.LocationPredictor(encoder_out=token_rep,
                                            entity_mask=entity_mask,
                                            loc_mask=loc_mask)
        # size (batch, max_sents, max_cands)
        loc_logits = loc_logits.transpose(-1, -2)
        masked_loc_logits = self.mask_loc_logits(
            loc_logits=loc_logits,
            num_cands=num_cands)  # (batch, max_sents, max_cands)
        masked_gold_loc_seq = self.mask_undefined_loc(
            gold_loc_seq=gold_loc_seq,
            mask_value=PAD_LOC)  # (batch, max_sents)
        loc_loss = self.CrossEntropy(
            input=masked_loc_logits.view(batch_size * max_sents, max_cands),
            target=masked_gold_loc_seq.view(batch_size * max_sents).long())
        correct_loc_pred, total_loc_pred = compute_loc_accuracy(
            logits=masked_loc_logits,
            gold=masked_gold_loc_seq,
            pad_value=PAD_LOC)
        # assert total_loc_pred > 0

        if self.is_test:  # inference
            pred_loc_seq = get_pred_loc(loc_logits=masked_loc_logits,
                                        gold_loc_seq=gold_loc_seq)
            return pred_state_seq, pred_loc_seq, correct_state_pred, total_state_pred, correct_loc_pred, total_loc_pred

        return state_loss, loc_loss, correct_state_pred, total_state_pred, correct_loc_pred, total_loc_pred
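
Neither example shows mask_loc_logits itself. Below is a hypothetical sketch of what such a helper could look like, following Example #2's shapes: candidates at index num_cands[b] or beyond get a large negative logit, so both the softmax inside the cross-entropy and the argmax effectively ignore them. (Example #1's variant evidently also appends an extra column, judging by its max_cands + 1 view.)

    import torch

    def mask_loc_logits(loc_logits: torch.Tensor,
                        num_cands: torch.Tensor) -> torch.Tensor:
        """Hypothetical helper. loc_logits: (batch, max_sents, max_cands);
        num_cands: (batch,) number of valid candidates per instance."""
        max_cands = loc_logits.size(-1)
        cand_idx = torch.arange(max_cands, device=loc_logits.device)
        # invalid[b, j] is True when candidate j doesn't exist for item b
        invalid = cand_idx.unsqueeze(0) >= num_cands.unsqueeze(-1)
        # a large negative logit drives these candidates' probability to ~0
        return loc_logits.masked_fill(invalid.unsqueeze(1), -1e9)

    # usage: batch of 2, 5 sentences, 7 candidate slots, 3 and 7 valid
    masked = mask_loc_logits(torch.randn(2, 5, 7), torch.tensor([3, 7]))
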