Example #1
    def neg_log_likelihood_loss(self, gaz_list, reverse_gaz_list, word_inputs,
                                word_seq_lengths, batch_label, mask):
        """
        get the neg_log_likelihood_loss
        Args:
            gaz_list: the batch data's gaz, for every chinese char
            reverse_gaz_list: the reverse list
            word_inputs: word input ids, [batch_size, seq_len]
            word_seq_lengths: [batch_size]
            batch_label: [batch_size, seq_len]
            mask: [batch_size, seq_len]
        """
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)
        lengths = list(map(int, word_seq_lengths))

        ## get batch gaz ids
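        ## note: batch_gaz_ids for the forward direction is built from reverse_gaz_list,
        ## while the backward direction below is built from gaz_list and then flipped
        ## with reverse_padded_sequence to align with the reversed word embeddings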
        batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(
            reverse_gaz_list, batch_size, seq_len, self.gpu, type=self.type)

        reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(
            gaz_list, batch_size, seq_len, self.gpu, type=self.type)
        reverse_batch_gaz_ids = reverse_padded_sequence(
            reverse_batch_gaz_ids, lengths)
        reverse_batch_gaz_length = reverse_padded_sequence(
            reverse_batch_gaz_length, lengths)
        reverse_batch_gaz_mask = reverse_padded_sequence(
            reverse_batch_gaz_mask, lengths)

        ## word embedding
        word_embs = self.word_embedding(word_inputs)
        reverse_word_embs = reverse_padded_sequence(word_embs, lengths)

        ## gaz embedding
        gaz_embs = self.gaz_embed(
            (batch_gaz_ids, batch_gaz_length, batch_gaz_mask))
        reverse_gaz_embs = self.gaz_embed(
            (reverse_batch_gaz_ids, reverse_batch_gaz_length,
             reverse_batch_gaz_mask))

        ## lstm
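        ## self.lstm (defined elsewhere) appears to take the forward and backward
        ## input streams as a tuple, i.e. a custom bidirectional LSTM wrapper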
        forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1)
        backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs),
                                    dim=-1)

        lstm_outs, _ = self.lstm((forward_inputs, backward_inputs),
                                 word_seq_lengths)

        ## hidden2tag
        outs = self.hidden2tag(lstm_outs)

        ## crf and loss
        loss = self.crf.neg_log_likelihood_loss(outs, mask, batch_label)
        _, tag_seq = self.crf._viterbi_decode(outs, mask)

        return loss, tag_seq
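
Both examples rely on the helpers get_batch_gaz and reverse_padded_sequence, which are defined elsewhere in the repository. The behaviour assumed of reverse_padded_sequence in this code is that it flips each sequence along the time dimension within its true length while leaving padding in place; the minimal sketch below illustrates that assumed behaviour and is not the repository's implementation.

import torch

def reverse_padded_sequence_sketch(inputs, lengths):
    # Sketch of the assumed semantics of reverse_padded_sequence:
    # flip each padded sequence along dim 1 within its own length,
    # so padded positions keep their place.
    #   inputs:  [batch_size, seq_len, ...] tensor
    #   lengths: iterable of true sequence lengths, one per batch element
    reversed_inputs = inputs.clone()
    for i, length in enumerate(lengths):
        # reverse only the valid prefix of sequence i
        reversed_inputs[i, :length] = torch.flip(inputs[i, :length], dims=[0])
    return reversed_inputs

With word_embs of shape [batch_size, seq_len, emb_dim] and lengths taken from word_seq_lengths, this would produce the reverse_word_embs that feed the backward direction of the LSTM.
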
Example #2
    def forward(self, gaz_list, reverse_gaz_list, word_inputs,
                word_seq_lengths, mask):
        """
        Args:
            gaz_list: the forward gaz_list
            reverse_gaz_list: the backward gaz list
            word_inputs: word ids
            word_seq_lengths: each sentence length
            mask: sentence mask
        """
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)
        lengths = list(map(int, word_seq_lengths))

        ## get batch gaz ids
        batch_gaz_ids, batch_gaz_length, batch_gaz_mask = get_batch_gaz(
            reverse_gaz_list, batch_size, seq_len, self.gpu, type=self.type)
        reverse_batch_gaz_ids, reverse_batch_gaz_length, reverse_batch_gaz_mask = get_batch_gaz(
            gaz_list, batch_size, seq_len, self.gpu, type=self.type)
        reverse_batch_gaz_ids = reverse_padded_sequence(
            reverse_batch_gaz_ids, lengths)
        reverse_batch_gaz_length = reverse_padded_sequence(
            reverse_batch_gaz_length, lengths)
        reverse_batch_gaz_mask = reverse_padded_sequence(
            reverse_batch_gaz_mask, lengths)

        ## word embedding
        word_embs = self.word_embedding(word_inputs)
        reverse_word_embs = reverse_padded_sequence(word_embs, lengths)

        ## gaz embedding
        gaz_embs = self.gaz_embed(
            (batch_gaz_ids, batch_gaz_length, batch_gaz_mask))
        reverse_gaz_embs = self.gaz_embed(
            (reverse_batch_gaz_ids, reverse_batch_gaz_length,
             reverse_batch_gaz_mask))

        ## lstm
        forward_inputs = torch.cat((word_embs, gaz_embs), dim=-1)
        backward_inputs = torch.cat((reverse_word_embs, reverse_gaz_embs),
                                    dim=-1)

        lstm_outs, _ = self.lstm((forward_inputs, backward_inputs),
                                 word_seq_lengths)

        ## hidden2tag
        outs = self.hidden2tag(lstm_outs)

        ## crf decoding
        _, tag_seq = self.crf._viterbi_decode(outs, mask)

        return tag_seq
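
A hedged sketch of how the two methods above might be used in a training and decoding loop. The model class name GazBiLSTMCRF, its constructor argument, the optimizer settings, and the batch variables are illustrative assumptions, not taken from the original code.

import torch

# hypothetical setup; the class name and constructor argument are assumptions
model = GazBiLSTMCRF(data)
optimizer = torch.optim.SGD(model.parameters(), lr=0.015)

# training step: neg_log_likelihood_loss also returns the Viterbi path for monitoring
model.train()
loss, tag_seq = model.neg_log_likelihood_loss(
    gaz_list, reverse_gaz_list, word_inputs, word_seq_lengths,
    batch_label, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# decoding step: forward() only runs Viterbi decoding and returns the tag sequence
model.eval()
with torch.no_grad():
    pred_tag_seq = model(gaz_list, reverse_gaz_list, word_inputs,
                         word_seq_lengths, mask)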