def forward(self, input_ids=None, token_type_ids=None, attention_mask=None,
            document_mask=None, labels=None, input_embeddings=None):
    # sentence-level BERT encoding
    _, output, embeddings = self.bert(input_ids, token_type_ids, attention_mask,
                                      output_all_encoded_layers=False,
                                      input_embeddings=input_embeddings)
    output = self.dropout(output)

    # transform from sentence level to document level
    length = document_mask.sum(dim=1).long()
    max_len = length.max()
    output = output.view(-1, max_len, self.config.hidden_size)

    # document-level RNN processing
    if self.rnn is not None:
        output, hx, rev_order, mask = utils.prepare_rnn_seq(
            output, length, hx=None, masks=document_mask, batch_first=True)
        output, hn = self.rnn(output, hx=hx)
        output, hn = utils.recover_rnn_seq(output, rev_order, hx=hn, batch_first=True)
        # apply dropout to the output of the RNN
        output = self.dropout_other(output)

    if self.dense is not None:
        # [batch, length, tag_space]
        output = self.dropout_other(F.elu(self.dense(output)))

    # final output layer
    if not self.use_crf:
        # no CRF: plain softmax classifier, [batch, length, num_labels]
        output = self.dense_softmax(output)
        if labels is None:
            _, preds = torch.max(output, dim=2)
            return preds, None, embeddings
        else:
            # per-token cross-entropy, masked and averaged over valid tokens
            loss = (F.cross_entropy(output.view(-1, output.size(-1)),
                                    labels.view(-1), reduction='none')
                    * document_mask.view(-1)).sum() / document_mask.sum()
            return loss, None, embeddings
    else:
        # CRF processing
        if labels is not None:
            loss, logits = self.crf.loss(output, labels, mask=document_mask)
            return loss.mean(), logits, embeddings
        else:
            seq_pred, logits = self.crf.decode(output, mask=document_mask,
                                               leading_symbolic=0)
            return seq_pred, logits, embeddings
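
# --- Illustration only: a minimal, self-contained sketch of the masked token-level
# cross-entropy used in the non-CRF branch of forward() above. The shapes, names,
# and values here are assumptions made for the example; it is not part of the model.
def _demo_masked_token_loss():
    import torch
    import torch.nn.functional as F

    batch, max_len, num_labels = 2, 5, 4
    logits = torch.randn(batch, max_len, num_labels)          # [batch, length, num_labels]
    labels = torch.randint(0, num_labels, (batch, max_len))   # [batch, length]
    # 1 marks a real token, 0 marks padding
    document_mask = torch.tensor([[1., 1., 1., 0., 0.],
                                  [1., 1., 1., 1., 1.]])

    # per-token loss, zeroed on padding, averaged over the number of real tokens
    per_token = F.cross_entropy(logits.view(-1, num_labels),
                                labels.view(-1), reduction='none')
    return (per_token * document_mask.view(-1)).sum() / document_mask.sum()
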
def _get_rnn_output(self, input_word, input_char, main_task, mask, hx=None):
    length = mask.data.sum(dim=1).long()

    if self.use_elmo:
        # [batch, length, word_dim] from the second ELMo representation layer
        input = self.elmo(input_word)
        input = input['elmo_representations'][1]
    else:
        # word embeddings: [batch, length, word_dim], e.g. torch.Size([128, 20, 50])
        word = self.word_embedd(input_word)

        # character embeddings: [batch, length, char_length, char_dim]
        char = self.char_embedd(input_char)
        char_size = char.size()
        # first reshape to [batch * length, char_length, char_dim],
        # then transpose to [batch * length, char_dim, char_length]
        char = char.view(char_size[0] * char_size[1],
                         char_size[2], char_size[3]).transpose(1, 2)
        # CNN over characters: [batch * length, char_filters, char_length],
        # then max-pooling over character positions: [batch * length, char_filters]
        char, _ = self.conv1d(char).max(dim=2)
        # reshape to [batch, length, char_filters]
        char = torch.tanh(char).view(char_size[0], char_size[1], -1)

        # apply dropout to the word and character embeddings
        word = self.dropout_in(word)
        char = self.dropout_in(char)
        # concatenate word and char features: [batch, length, word_dim + char_filters]
        input = torch.cat([word, char], dim=2)

    # apply dropout before the RNN
    input = self.dropout_rnn_in(input)
    # prepare packed sequence
    seq_input, hx, rev_order, mask, _ = utils.prepare_rnn_seq(
        input, length, hx=hx, masks=mask, batch_first=True)

    # task-specific RNN: rnn_2 for the main task, rnn_1 otherwise
    if main_task:
        seq_output, hn = self.rnn_2(seq_input, hx=hx)
    else:
        seq_output, hn = self.rnn_1(seq_input, hx=hx)
    output, hn = utils.recover_rnn_seq(seq_output, rev_order, hx=hn, batch_first=True)
    output = self.dropout_out(output)

    if self.use_lm:
        # split the bidirectional output into forward and backward language-model states
        output_size = output.size()
        lm = output.view(output_size[0], output_size[1], 2, -1)
        lm_fw = lm[:, :, 0]
        lm_bw = lm[:, :, 1]
        return output, hn, mask, length, lm_fw, lm_bw
    else:
        return output, hn, mask, length
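
# --- Illustration only: a minimal sketch of the character-CNN feature extractor
# used in _get_rnn_output() above (embed characters, run a Conv1d over the character
# dimension, max-pool over character positions, apply tanh). The vocabulary size,
# embedding sizes, filter count, and kernel size are assumptions for the example.
def _demo_char_cnn_features():
    import torch
    import torch.nn as nn

    batch, length, char_length = 2, 20, 24
    char_vocab, char_dim, char_filters = 100, 30, 50
    char_embedd = nn.Embedding(char_vocab, char_dim)
    conv1d = nn.Conv1d(char_dim, char_filters, kernel_size=3, padding=1)

    input_char = torch.randint(0, char_vocab, (batch, length, char_length))
    char = char_embedd(input_char)                    # [batch, length, char_length, char_dim]
    char_size = char.size()
    # reshape and transpose to [batch * length, char_dim, char_length] for Conv1d
    char = char.view(char_size[0] * char_size[1],
                     char_size[2], char_size[3]).transpose(1, 2)
    # convolve, then max-pool over character positions -> [batch * length, char_filters]
    char, _ = conv1d(char).max(dim=2)
    # [batch, length, char_filters]
    return torch.tanh(char).view(char_size[0], char_size[1], -1)
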