class BertSimpleClassifier(nn.Module):
    """Classifier over mean-pooled BERT token states of a document and a prompt.

    Every sentence is encoded by BERT, the un-padded token states of a
    document are concatenated, passed through ``positional_encoding`` and
    averaged.  The prompt is pooled the same way, broadcast over the batch,
    concatenated with the document vector and fed to a linear layer followed
    by ``tanh`` and ``log_softmax``.
    """

    def __init__(self, bert_pretrained_weights, num_class):
        """
        :param bert_pretrained_weights: forwarded to ``BertModel.from_pretrained``
        :param num_class: number of output classes of the final linear layer
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)
        self.positional_encoding = PositionalEncoding(input_dim=768)
        # Input is [document vector (768) ; prompt vector (768)].
        self.linear_layer = nn.Linear(768 * 2, num_class)
        self.dropout_layer = nn.Dropout(0.5)
        self.criterion = nn.NLLLoss(reduction='sum')

        nn.init.uniform_(self.linear_layer.weight.data, -0.1, 0.1)
        nn.init.zeros_(self.linear_layer.bias.data)

    def _pool_sentences(self, sent_states, sent_count, sent_len):
        """Mean-pool the un-padded token states of the first ``sent_count`` sentences.

        :param sent_states: [sent count, max sent len, hid size]
        :param sent_count: number of leading sentences to use
        :param sent_len: per-sentence token counts, indexable by sentence
        :return: [1, hid size]
        """
        tokens = [sent_states[j, :sent_len[j], :] for j in range(sent_count)]
        # [1, total tokens, hid size]
        sequence = torch.cat(tokens, dim=0).unsqueeze(0)
        sequence = self.positional_encoding.forward(sequence)
        return torch.mean(sequence, dim=1)

    def forward(self, inputs, mask, sent_counts, sent_lens,
                prompt_inputs, prompt_mask, prompt_sent_counts,
                prompt_sent_lens, label=None):
        """
        :param inputs: [batch size, max sent count, max sent len]
        :param mask: [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param prompt_inputs: prompt token ids; flattened to
            [-1, max sent len] before encoding
        :param prompt_mask: attention mask matching ``prompt_inputs``
        :param prompt_sent_counts: number of prompt sentences; used directly
            as a ``range`` bound, so it must be a single integer-like value
        :param prompt_sent_lens: only ``prompt_sent_lens[0]`` is read, i.e. a
            single prompt shared by the whole batch is assumed -- TODO confirm
        :param label: [batch size], optional
        :return: dict with ``loss`` (summed NLL, or None when ``label`` is
            None) and ``prediction`` (arg-max class index per document)
        """
        batch_size, max_sent_count, max_sent_length = inputs.shape[:3]

        # Encode every sentence independently: [batch * sent count, sent len, hid]
        token_states = self.bert(
            input_ids=inputs.view(-1, inputs.shape[-1]),
            attention_mask=mask.view(-1, mask.shape[-1]))[0]
        token_states = token_states.view(
            batch_size, max_sent_count, max_sent_length, -1)
        token_states = self.dropout_layer(token_states)

        prompt_states = self.bert(
            input_ids=prompt_inputs.view(-1, prompt_inputs.shape[-1]),
            attention_mask=prompt_mask.view(-1, prompt_mask.shape[-1]))[0]
        prompt_states = self.dropout_layer(prompt_states)

        # [batch size, hid size]
        doc_feature = torch.cat(
            [self._pool_sentences(token_states[i], sent_counts[i], sent_lens[i])
             for i in range(batch_size)],
            dim=0)

        # [1, hid size], broadcast to every document below.
        prompt_vec = self._pool_sentences(
            prompt_states, prompt_sent_counts, prompt_sent_lens[0])
        prompt_feature = prompt_vec.expand_as(doc_feature)

        feature = torch.cat([doc_feature, prompt_feature], dim=-1)
        log_probs = torch.log_softmax(
            torch.tanh(self.linear_layer(feature)), dim=-1)

        loss = None
        if label is not None:
            loss = self.criterion(
                input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                target=label.contiguous().view(-1))

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}
class MixBertRecurrentAttentionRegressor(nn.Module):
    """Score regressor mixing three feature groups.

    * a prompt-conditioned attention summary of the document's BERT token states,
    * a projection of hand-crafted (``manual``) features,
    * the final state of an RNN run over per-sentence mean vectors.

    The concatenated features go through a single-output linear layer.  The
    training target is the label rescaled with ``min_score`` / ``max_score``;
    the returned prediction is mapped back to the original score range.
    """

    def __init__(self, bert_pretrained_weights):
        """
        :param bert_pretrained_weights: forwarded to ``BertModel.from_pretrained``
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)
        self.positional_encoding = PositionalEncoding(input_dim=768)
        # Input is [doc attention vector (768) ; manual (5) ; segment RNN (300)].
        self.linear_layer = nn.Linear(768 + 5 + 300, 1)
        self.dropout_layer = nn.Dropout(0.6)
        self.criterion = nn.MSELoss(reduction='sum')
        self.manual_feature_layer = nn.Linear(27, 5)
        self.prompt_global_attention = GlobalAttention(hid_dim=768,
                                                       key_size=768)
        self.prompt_doc_attention = BahdanauAttention(hid_dim=768,
                                                      key_size=768,
                                                      query_size=768)
        self.segment_encoder = RNNEncoder(embedding_dim=768,
                                          hid_dim=150,
                                          num_layers=1,
                                          dropout_rate=0.5)

        nn.init.uniform_(self.linear_layer.weight.data, -0.1, 0.1)
        nn.init.zeros_(self.linear_layer.bias.data)

    def _encode(self, token_ids, attention_mask):
        """Run BERT on every sentence and restore the [batch, sent, token, hid] layout.

        :param token_ids: [batch size, max sent count, max sent len]
        :param attention_mask: same shape as ``token_ids``
        :return: [batch size, max sent count, max sent len, hid size], after dropout
        """
        batch_size, sent_count, sent_length = token_ids.shape[:3]
        states = self.bert(
            input_ids=token_ids.view(-1, token_ids.shape[-1]),
            attention_mask=attention_mask.view(-1, attention_mask.shape[-1]))[0]
        states = states.view(batch_size, sent_count, sent_length, -1)
        return self.dropout_layer(states)

    def _concat_sentences(self, sent_states, sent_count, sent_len):
        """Concatenate the un-padded tokens of the first ``sent_count`` sentences.

        :param sent_states: [max sent count, max sent len, hid size]
        :return: [1, total tokens, hid size] after ``positional_encoding``
        """
        tokens = [sent_states[j, :sent_len[j], :] for j in range(sent_count)]
        sequence = torch.cat(tokens, dim=0).unsqueeze(0)
        return self.positional_encoding.forward(sequence)

    def _pad_and_stack(self, sequences):
        """Zero-pad ``[1, len, hid]`` tensors to a common length and stack them.

        :return: tuple of the stacked [n, max len, hid size] tensor and the
            mask built by ``get_mask_from_sequence_lengths`` from the
            original lengths
        """
        lengths = [seq.shape[1] for seq in sequences]
        max_len = max(lengths)
        padded = []
        for seq in sequences:
            if seq.shape[1] < max_len:
                # Pad tuple is (left, right, top, bottom) over the last two
                # dims: only the token axis is extended, at its end.
                pad = nn.ConstantPad2d((0, 0, 0, max_len - seq.shape[1]), 0)
                seq = pad(seq)
            padded.append(seq)
        stacked = torch.cat(padded, 0)
        seq_mask = get_mask_from_sequence_lengths(
            torch.tensor(lengths), max_length=max_len).to(stacked.device)
        return stacked, seq_mask

    def forward(self, inputs, mask, sent_counts, sent_lens,
                prompt_inputs, prompt_mask, prompt_sent_counts,
                prompt_sent_lens, min_score, max_score, manual_feature,
                label=None):
        """
        :param inputs: [batch size, max sent count, max sent len]
        :param mask: [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param prompt_inputs: [batch size, max sent count, max sent len]
        :param prompt_mask: [batch size, max sent count, max sent len]
        :param prompt_sent_counts: [batch size]
        :param prompt_sent_lens: [batch size, max sent count]
        :param min_score: lowest possible score per example
        :param max_score: highest possible score per example
        :param manual_feature: hand-crafted features; the last dimension must
            be 27 (input size of ``manual_feature_layer``), presumably
            [batch size, 27]
        :param label: [batch size], optional
        :return: dict with ``loss`` (summed squared error against the
            min/max-normalised label, or None when ``label`` is None) and
            ``prediction`` (model output rescaled to the score range)
        """
        batch_size = inputs.shape[0]
        max_sent_count = inputs.shape[1]

        token_states = self._encode(inputs, mask)
        prompt_states = self._encode(prompt_inputs, prompt_mask)

        # Token-level document sequences: [batch size, max doc len, hid size]
        docs, docs_mask = self._pad_and_stack(
            [self._concat_sentences(token_states[i], sent_counts[i], sent_lens[i])
             for i in range(batch_size)])

        # Token-level prompt sequences: [batch size, max prompt len, hid size]
        prompt_docs, prompt_attention_mask = self._pad_and_stack(
            [self._concat_sentences(prompt_states[i], prompt_sent_counts[i],
                                    prompt_sent_lens[i])
             for i in range(batch_size)])

        # [batch size, max prompt len] weights -> [batch size, hid size]
        prompt_vec_weights = self.prompt_global_attention(
            prompt_docs, prompt_attention_mask)
        prompt_vec = torch.bmm(prompt_vec_weights.unsqueeze(1),
                               prompt_docs).squeeze(1)

        # The prompt summary queries the document tokens.
        doc_weights = self.prompt_doc_attention(query=prompt_vec,
                                                key=docs,
                                                mask=docs_mask)
        doc_vec = torch.bmm(doc_weights.unsqueeze(1), docs).squeeze(1)
        doc_feature = self.dropout_layer(torch.tanh(doc_vec))

        manual_feature = torch.tanh(
            self.manual_feature_layer(self.dropout_layer(manual_feature)))

        # Sentence-level view: one mean vector per sentence, zero-padded to
        # max_sent_count -> [batch size, max sent count, hid size]
        doc_segments = []
        for i in range(batch_size):
            sent_len = sent_lens[i]
            segment = [
                torch.mean(token_states[i, j, :sent_len[j], :], dim=0).unsqueeze(0)
                for j in range(sent_counts[i])
            ]
            missing = max_sent_count - len(segment)
            segment.extend(
                [torch.zeros_like(segment[0]) for _ in range(missing)])
            doc_segments.append(torch.cat(segment, dim=0).unsqueeze(0))
        doc_segments = torch.cat(doc_segments, dim=0)

        # The segment encoder is fed documents in descending sentence count.
        sorted_index = sorted(range(len(sent_counts)),
                              key=lambda i: sent_counts[i],
                              reverse=True)
        sorted_states = self.segment_encoder(
            doc_segments[sorted_index],
            sent_counts[sorted_index])['final_hidden_states']

        # Undo the sort with an out-of-place gather.  Assigning
        # ``x[sorted_index] = x`` reads from and writes to the same storage,
        # which can overwrite rows before they are read (and is rejected by
        # PyTorch versions that check for memory overlap).
        restore_index = sorted(range(len(sorted_index)),
                               key=sorted_index.__getitem__)
        final_hidden_states = sorted_states[restore_index]
        final_hidden_states = torch.tanh(final_hidden_states)
        final_hidden_states = self.dropout_layer(final_hidden_states)

        feature = torch.cat(
            [doc_feature, manual_feature, final_hidden_states], dim=-1)
        grade = self.linear_layer(feature)

        loss = None
        if label is not None:
            # Train against the label rescaled to the [min_score, max_score] range.
            low = min_score.type_as(grade)
            high = max_score.type_as(grade)
            label = (label.type_as(grade) - low) / (high - low)
            loss = self.criterion(
                input=grade.contiguous().view(-1),
                target=label.type_as(grade).contiguous().view(-1))

        # NOTE(review): ``grade`` is [batch size, 1]; if min_score/max_score
        # are 1-D [batch size] tensors this expression broadcasts to
        # [batch size, batch size] -- confirm the score shapes with the caller.
        prediction = grade * (max_score.type_as(grade) - min_score.type_as(
            grade)) + min_score.type_as(grade)
        return {'loss': loss, 'prediction': prediction}
class BertGlobalAttentionClassifier(nn.Module):
    """Classifier over attention-pooled BERT token states plus manual features.

    Document tokens and prompt tokens are each summarised with their own
    ``GlobalAttention`` module.  The two summaries are concatenated with a
    5-dimensional projection of the hand-crafted features and classified by
    a linear layer followed by ``log_softmax``.
    """

    def __init__(self, bert_pretrained_weights, num_class):
        """
        :param bert_pretrained_weights: forwarded to ``BertModel.from_pretrained``
        :param num_class: number of output classes of the final linear layer
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)
        self.positional_encoding = PositionalEncoding(input_dim=768)
        # Input is [document (768) ; prompt (768) ; manual features (5)].
        self.linear_layer = nn.Linear(768 * 2 + 5, num_class)
        self.manual_feature_layer = nn.Linear(27, 5)
        self.dropout_layer = nn.Dropout(0.5)
        self.criterion = nn.NLLLoss(reduction='mean')
        self.prompt_global_attention = GlobalAttention(hid_dim=768,
                                                       key_size=768)
        self.doc_global_attention = GlobalAttention(hid_dim=768, key_size=768)

        nn.init.uniform_(self.linear_layer.weight.data, -0.1, 0.1)
        nn.init.zeros_(self.linear_layer.bias.data)

    def _concat_sentences(self, sent_states, sent_count, sent_len):
        """Concatenate the un-padded tokens of the first ``sent_count`` sentences.

        :param sent_states: [sent count, max sent len, hid size]
        :return: [1, total tokens, hid size] after ``positional_encoding``
        """
        tokens = [sent_states[j, :sent_len[j], :] for j in range(sent_count)]
        sequence = torch.cat(tokens, dim=0).unsqueeze(0)
        return self.positional_encoding.forward(sequence)

    def forward(self, inputs, mask, sent_counts, sent_lens,
                prompt_inputs, prompt_mask, prompt_sent_counts,
                prompt_sent_lens, manual_feature, label=None):
        """
        :param inputs: [batch size, max sent count, max sent len]
        :param mask: [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param prompt_inputs: prompt token ids; flattened to
            [-1, max sent len] before encoding
        :param prompt_mask: attention mask matching ``prompt_inputs``
        :param prompt_sent_counts: number of prompt sentences; used directly
            as a ``range`` bound, so it must be a single integer-like value
        :param prompt_sent_lens: only ``prompt_sent_lens[0]`` is read, i.e. a
            single prompt shared by the whole batch is assumed -- TODO confirm
        :param manual_feature: hand-crafted features; the last dimension must
            be 27 (input size of ``manual_feature_layer``), presumably
            [batch size, 27]
        :param label: [batch size], optional
        :return: dict with ``loss`` (mean NLL, or None when ``label`` is
            None) and ``prediction`` (arg-max class index per document)
        """
        batch_size, max_sent_count, max_sent_length = inputs.shape[:3]

        # Encode every sentence independently: [batch * sent count, sent len, hid]
        token_states = self.bert(
            input_ids=inputs.view(-1, inputs.shape[-1]),
            attention_mask=mask.view(-1, mask.shape[-1]))[0]
        token_states = token_states.view(
            batch_size, max_sent_count, max_sent_length, -1)

        prompt_states = self.bert(
            input_ids=prompt_inputs.view(-1, prompt_inputs.shape[-1]),
            attention_mask=prompt_mask.view(-1, prompt_mask.shape[-1]))[0]

        # One [1, doc len, hid] token sequence per document.
        doc_sequences = [
            self._concat_sentences(token_states[i], sent_counts[i], sent_lens[i])
            for i in range(batch_size)
        ]
        lens = [seq.shape[1] for seq in doc_sequences]
        batch_max_len = max(lens)
        padded = []
        for seq in doc_sequences:
            if seq.shape[1] < batch_max_len:
                # Pad tuple is (left, right, top, bottom) over the last two
                # dims: only the token axis is extended, at its end.
                pad = nn.ConstantPad2d(
                    (0, 0, 0, batch_max_len - seq.shape[1]), 0)
                seq = pad(seq)
            padded.append(seq)
        # [batch size, max doc len, hid size]
        docs = torch.cat(padded, 0)
        docs_mask = get_mask_from_sequence_lengths(
            torch.tensor(lens), max_length=batch_max_len).to(docs.device)

        # [1, prompt len, hid size]
        prompt_vec = self._concat_sentences(
            prompt_states, prompt_sent_counts, prompt_sent_lens[0])
        prompt_len = prompt_vec.shape[1]
        prompt_attention_mask = get_mask_from_sequence_lengths(
            torch.tensor([prompt_len]),
            max_length=prompt_len).to(prompt_vec.device)

        # [1, prompt len] weights -> [1, hid size]
        prompt_vec_weights = self.prompt_global_attention(
            prompt_vec, prompt_attention_mask)
        prompt_vec = torch.bmm(prompt_vec_weights.unsqueeze(1),
                               prompt_vec).squeeze(1)

        # [batch size, max doc len] weights -> [batch size, hid size]
        doc_weights = self.doc_global_attention(docs, docs_mask)
        doc_vec = torch.bmm(doc_weights.unsqueeze(1), docs).squeeze(1)

        doc_feature = self.dropout_layer(torch.tanh(doc_vec))
        prompt_feature = self.dropout_layer(
            torch.tanh(prompt_vec.expand_as(doc_feature)))
        # linear_layer expects 768 * 2 + 5 inputs: without the projected
        # manual features the concatenation is 5 columns short and the
        # matrix multiply fails.
        manual_feature = torch.tanh(
            self.manual_feature_layer(self.dropout_layer(manual_feature)))

        feature = torch.cat(
            [doc_feature, prompt_feature, manual_feature], dim=-1)
        log_probs = torch.log_softmax(self.linear_layer(feature), dim=-1)

        loss = None
        if label is not None:
            loss = self.criterion(
                input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                target=label.contiguous().view(-1))

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}
class BertClassifier(nn.Module):
    """Classifier over mean-pooled BERT token states of a document and a prompt.

    The un-padded token states of each document are concatenated, passed
    through ``positional_encoding`` and averaged; the prompt is pooled the
    same way and broadcast over the batch.  Both vectors go through ``tanh``
    and dropout before a linear layer and ``log_softmax``.
    """

    def __init__(self, bert_pretrained_weights, num_class, kernel_size,
                 kernel_nums):
        """
        :param bert_pretrained_weights: forwarded to ``BertModel.from_pretrained``
        :param num_class: number of output classes of the final linear layer
        :param kernel_size: not used by this class; kept so existing callers
            keep working
        :param kernel_nums: not used by this class; kept so existing callers
            keep working
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_pretrained_weights)
        self.positional_encoding = PositionalEncoding(input_dim=768)
        # Input is [document vector (768) ; prompt vector (768)].
        self.linear_layer = nn.Linear(768 * 2, num_class)
        self.dropout_layer = nn.Dropout(0.5)
        self.criterion = nn.NLLLoss(reduction='mean')

    def _pool_sentences(self, sent_states, sent_count, sent_len):
        """Mean-pool the un-padded token states of the first ``sent_count`` sentences.

        :param sent_states: [sent count, max sent len, hid size]
        :param sent_count: number of leading sentences to use
        :param sent_len: per-sentence token counts, indexable by sentence
        :return: [1, hid size]
        """
        tokens = [sent_states[j, :sent_len[j], :] for j in range(sent_count)]
        # [1, total tokens, hid size]
        sequence = torch.cat(tokens, dim=0).unsqueeze(0)
        sequence = self.positional_encoding.forward(sequence)
        return torch.mean(sequence, dim=1)

    def forward(self, inputs, mask, sent_counts, sent_lens,
                prompt_inputs, prompt_mask, prompt_sent_counts,
                prompt_sent_lens, label=None):
        """
        :param inputs: [batch size, max sent count, max sent len]
        :param mask: [batch size, max sent count, max sent len]
        :param sent_counts: [batch size]
        :param sent_lens: [batch size, max sent count]
        :param prompt_inputs: prompt token ids; flattened to
            [-1, max sent len] before encoding
        :param prompt_mask: attention mask matching ``prompt_inputs``
        :param prompt_sent_counts: number of prompt sentences; used directly
            as a ``range`` bound, so it must be a single integer-like value
        :param prompt_sent_lens: only ``prompt_sent_lens[0]`` is read, i.e. a
            single prompt shared by the whole batch is assumed -- TODO confirm
        :param label: [batch size], optional
        :return: dict with ``loss`` (mean NLL, or None when ``label`` is
            None) and ``prediction`` (arg-max class index per document)
        """
        batch_size, max_sent_count, max_sent_length = inputs.shape[:3]

        # Encode every sentence independently: [batch * sent count, sent len, hid]
        token_states = self.bert(
            input_ids=inputs.view(-1, inputs.shape[-1]),
            attention_mask=mask.view(-1, mask.shape[-1]))[0]
        token_states = token_states.view(
            batch_size, max_sent_count, max_sent_length, -1)

        prompt_states = self.bert(
            input_ids=prompt_inputs.view(-1, prompt_inputs.shape[-1]),
            attention_mask=prompt_mask.view(-1, prompt_mask.shape[-1]))[0]

        # [batch size, hid size]
        docs = torch.cat(
            [self._pool_sentences(token_states[i], sent_counts[i], sent_lens[i])
             for i in range(batch_size)],
            0)

        # [1, hid size], broadcast to every document below.
        prompt_vec = self._pool_sentences(
            prompt_states, prompt_sent_counts, prompt_sent_lens[0])

        doc_feature = self.dropout_layer(torch.tanh(docs))
        prompt_feature = self.dropout_layer(
            torch.tanh(prompt_vec.expand_as(doc_feature)))

        feature = torch.cat([doc_feature, prompt_feature], dim=-1)
        log_probs = torch.log_softmax(self.linear_layer(feature), dim=-1)

        loss = None
        if label is not None:
            loss = self.criterion(
                input=log_probs.contiguous().view(-1, log_probs.shape[-1]),
                target=label.contiguous().view(-1))

        prediction = torch.max(log_probs, dim=1)[1]
        return {'loss': loss, 'prediction': prediction}