Example #1
    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                document_mask=None,
                labels=None,
                input_embeddings=None):
        _, output, embeddings = self.bert(input_ids,
                                          token_type_ids,
                                          attention_mask,
                                          output_all_encoded_layers=False,
                                          input_embeddings=input_embeddings)
        output = self.dropout(output)

        # reshape sentence-level outputs into document-level sequences
        length = document_mask.sum(dim=1).long()
        max_len = length.max()
        output = output.view(-1, max_len, self.config.hidden_size)

        # document level RNN processing
        if self.rnn is not None:
            output, hx, rev_order, mask = utils.prepare_rnn_seq(
                output, length, hx=None, masks=document_mask, batch_first=True)
            output, hn = self.rnn(output, hx=hx)
            output, hn = utils.recover_rnn_seq(output,
                                               rev_order,
                                               hx=hn,
                                               batch_first=True)

        # apply dropout to the RNN output
        output = self.dropout_other(output)
        if self.dense is not None:
            # [batch, length, tag_space]
            output = self.dropout_other(F.elu(self.dense(output)))

        # final output layer
        if not self.use_crf:
            # no CRF: token-level softmax with masked cross-entropy
            output = self.dense_softmax(output)  # [batch, length, num_labels]
            if labels is None:
                _, preds = torch.max(output, dim=2)
                return preds, None, embeddings
            else:
                return (F.cross_entropy(output.view(-1, output.size(-1)),
                                        labels.view(-1),
                                        reduction='none') *
                        document_mask.view(-1)
                        ).sum() / document_mask.sum(), None, embeddings
        else:
            # CRF processing
            if labels is not None:
                loss, logits = self.crf.loss(output,
                                             labels,
                                             mask=document_mask)
                return loss.mean(), logits, embeddings
            else:
                seq_pred, logits = self.crf.decode(output,
                                                   mask=document_mask,
                                                   leading_symbolic=0)
                return seq_pred, logits, embeddings
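
A minimal calling sketch for the forward pass above (the model instance, vocabulary size, number of labels, and all tensor shapes are assumptions for illustration, not taken from the snippet):

import torch

# Hypothetical batch: 2 documents, 8 sentences each, 32 tokens per sentence.
num_docs, max_sents, seq_len = 2, 8, 32
input_ids = torch.randint(0, 30000, (num_docs * max_sents, seq_len))
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)
document_mask = torch.ones(num_docs, max_sents)      # 1 = real sentence, 0 = padding
labels = torch.randint(0, 5, (num_docs, max_sents))  # one label per sentence

# model is assumed to be an instance of the class that defines forward() above
loss, logits, embeddings = model(input_ids,
                                 token_type_ids,
                                 attention_mask,
                                 document_mask=document_mask,
                                 labels=labels)
loss.backward()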
Example #2
    def _get_rnn_output(self,
                        input_word,
                        input_char,
                        main_task,
                        mask,
                        hx=None):
        length = mask.data.sum(dim=1).long()
        # [batch, length, word_dim]
        if self.use_elmo:
            input = self.elmo(input_word)
            input = input['elmo_representations'][1]
        else:
            # word embeddings: [batch, length, word_dim]
            # e.g. torch.Size([128, 20, 50])
            word = self.word_embedd(input_word)
            #  [batch, length, char_length, char_dim]
            #  torch.Size([128, 20, 24, 300])
            char = self.char_embedd(input_char)
            char_size = char.size()
            # first transform to [batch *length, char_length, char_dim]
            # then transpose to [batch * length, char_dim, char_length]
            char = char.view(char_size[0] * char_size[1], char_size[2],
                             char_size[3]).transpose(1, 2)
            # run the char CNN: [batch * length, char_filters, char_length]
            # then max-pool over the char dimension: [batch * length, char_filters]
            char, _ = self.conv1d(char).max(dim=2)
            # reshape to [batch, length, char_filters]
            char = torch.tanh(char).view(char_size[0], char_size[1], -1)
            # apply input dropout to the word and char embeddings
            word = self.dropout_in(word)
            char = self.dropout_in(char)
            # concatenate word and char [batch, length, word_dim+char_filter]
            input = torch.cat([word, char], dim=2)
        # apply dropout before feeding the RNN
        input = self.dropout_rnn_in(input)
        # prepare packed_sequence
        seq_input, hx, rev_order, mask, _ = utils.prepare_rnn_seq(
            input, length, hx=hx, masks=mask, batch_first=True)
        if main_task:
            seq_output, hn = self.rnn_2(seq_input, hx=hx)
        else:
            seq_output, hn = self.rnn_1(seq_input, hx=hx)
        output, hn = utils.recover_rnn_seq(seq_output,
                                           rev_order,
                                           hx=hn,
                                           batch_first=True)
        output = self.dropout_out(output)

        if self.use_lm:
            # split the bidirectional RNN output into its forward and backward
            # halves for the auxiliary language-model objective
            output_size = output.size()
            lm = output.view(output_size[0], output_size[1], 2, -1)
            lm_fw = lm[:, :, 0]
            lm_bw = lm[:, :, 1]
            return output, hn, mask, length, lm_fw, lm_bw
        else:
            return output, hn, mask, length
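
A minimal calling sketch for _get_rnn_output (the vocabulary sizes and shapes below are assumptions chosen to match the shape comments in the snippet):

import torch

batch, sent_len, char_len = 128, 20, 24
input_word = torch.randint(0, 10000, (batch, sent_len))         # word ids
input_char = torch.randint(0, 100, (batch, sent_len, char_len))  # char ids
mask = torch.ones(batch, sent_len)                                # 1 = real token

# model is assumed to be an instance of the class that defines _get_rnn_output()
outputs = model._get_rnn_output(input_word, input_char, main_task=True, mask=mask)
if model.use_lm:
    output, hn, mask, length, lm_fw, lm_bw = outputs
else:
    output, hn, mask, length = outputs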