def get_config(self):
    return DebertaConfig(
        vocab_size=self.vocab_size,
        hidden_size=self.hidden_size,
        num_hidden_layers=self.num_hidden_layers,
        num_attention_heads=self.num_attention_heads,
        intermediate_size=self.intermediate_size,
        hidden_act=self.hidden_act,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        type_vocab_size=self.type_vocab_size,
        initializer_range=self.initializer_range,
        relative_attention=self.relative_attention,
        position_biased_input=self.position_biased_input,
        pos_att_type=self.pos_att_type,
    )
def prepare_config_and_inputs(self):
    input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    input_mask = None
    if self.use_input_mask:
        input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

    token_type_ids = None
    if self.use_token_type_ids:
        token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

    sequence_labels = None
    token_labels = None
    choice_labels = None
    if self.use_labels:
        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
        token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
        choice_labels = ids_tensor([self.batch_size], self.num_choices)

    # Reuse get_config() above instead of duplicating the DebertaConfig construction.
    config = self.get_config()

    return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
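# `ids_tensor` used above is the transformers test helper that draws a random
# integer tensor. A minimal stand-in for running these methods outside the
# transformers test suite could look like this (a sketch: the real helper also
# accepts optional `rng`/`name` arguments):
import torch

def ids_tensor(shape, vocab_size):
    """Return a random long tensor of `shape` with values in [0, vocab_size)."""
    return torch.randint(low=0, high=int(vocab_size), size=tuple(shape), dtype=torch.long)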
import numpy as np
import torch
import transformers
from tqdm import tqdm
from transformers import DebertaConfig, DebertaForSequenceClassification, DebertaTokenizer

import loadData


def main():
    device_ids = [0]
    init_lr = 1e-5
    max_epochs = 10
    max_length = 512
    batch_size = 1
    gradient_accu = 32 // batch_size  # accumulate gradients up to an effective batch size of 32
    num_label = 2
    train_mode = False
    prev_acc = 0.
    max_acc = 0.

    # Set num_labels on the config instance and pass it to the model; the
    # original patched the DebertaConfig *class* attribute, which clobbers the
    # num_labels property for every config.
    config = DebertaConfig.from_pretrained('microsoft/deberta-large', num_labels=num_label)
    tknzr = DebertaTokenizer.from_pretrained('microsoft/deberta-large')

    train_data, test_data = loadData.load_data()
    train_data = train_data + loadData.load_data_aug()
    train_input_ids, train_mask_ids, train_segment_ids, train_label_ids = get_features(train_data, max_length, tknzr)
    test_input_ids, test_mask_ids, test_segment_ids, test_label_ids = get_features(test_data, max_length, tknzr)
    # print(all_input_ids.shape)

    all_input_ids = torch.cat(train_input_ids, dim=0).long()
    all_input_mask_ids = torch.cat(train_mask_ids, dim=0).long()
    all_segment_ids = torch.cat(train_segment_ids, dim=0).long()
    all_label_ids = torch.Tensor(train_label_ids).long()
    train_dataloader = create_dataloader(all_input_ids, all_input_mask_ids, all_segment_ids, all_label_ids,
                                         batch_size=batch_size, train=True)

    all_input_ids = torch.cat(test_input_ids, dim=0).long()
    all_input_mask_ids = torch.cat(test_mask_ids, dim=0).long()
    all_segment_ids = torch.cat(test_segment_ids, dim=0).long()
    all_label_ids = torch.Tensor(test_label_ids).long()
    test_dataloader = create_dataloader(all_input_ids, all_input_mask_ids, all_segment_ids, all_label_ids,
                                        batch_size=batch_size, train=False)

    model = DebertaForSequenceClassification.from_pretrained('microsoft/deberta-large',
                                                             config=config).cuda(device_ids[0])
    model = torch.nn.DataParallel(model, device_ids=device_ids)

    # transformers.AdamW is deprecated in recent releases; torch.optim.AdamW is a drop-in replacement.
    optimizer = transformers.AdamW(model.parameters(), lr=init_lr, eps=1e-8)
    optimizer.zero_grad()
    # scheduler = transformers.get_constant_schedule_with_warmup(
    #     optimizer, len(train_dataloader) // (batch_size * gradient_accu))
    # scheduler = transformers.get_linear_schedule_with_warmup(
    #     optimizer, len(train_dataloader) // (batch_size * gradient_accu),
    #     (len(train_dataloader) * max_epochs * 2) // (batch_size * gradient_accu), last_epoch=-1)

    if not train_mode:
        max_epochs = 1
        model.load_state_dict(torch.load("../model/model-deberta-1231.ckpt"))
    # Open the logit dump in both modes; the eval loop below always writes to it
    # (the original only opened it when train_mode was False, so training-time
    # evaluation raised a NameError).
    foutput = open("answer-deberta-large-test.txt", "w")

    global_step = 0
    for epoch in range(max_epochs):
        model.train()
        if train_mode:
            loss_avg = 0.
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                global_step += 1
                batch = [t.cuda() for t in batch]
                input_id, input_mask, segment_id, label_id = batch
                # return_dict=False keeps the (loss, logits) tuple output this
                # script (and DataParallel's gather) expects.
                loss, _ = model(input_ids=input_id, token_type_ids=segment_id,
                                attention_mask=input_mask, labels=label_id, return_dict=False)
                loss = torch.sum(loss)  # sum the per-GPU losses returned under DataParallel
                loss_avg += loss.item()
                loss = loss / (batch_size * gradient_accu)  # normalize by the effective batch size (32)
                loss.backward()
                if global_step % gradient_accu == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    # if epoch == 0:
                    #     scheduler.step()
            print(loss_avg / len(train_dataloader))

        model.eval()
        final_acc = 0.
        num_test_sample = 0
        tot = [0, 0]      # per-class sample counts, indexed by the true label
        correct = [0, 0]  # per-class correct predictions, indexed by the true label
        countloop = 0
        for input_id, input_mask, segment_id, label_id in test_dataloader:
            countloop += 1
            input_id = input_id.cuda()
            input_mask = input_mask.cuda()
            segment_id = segment_id.cuda()
            label_id = label_id.cuda()
            with torch.no_grad():
                loss, logit = model(input_ids=input_id, token_type_ids=segment_id,
                                    attention_mask=input_mask, labels=label_id, return_dict=False)
            logit = logit.detach().cpu().numpy()
            print(logit[0][0], logit[0][1], file=foutput)  # dump the two class logits (batch_size is 1)
            # print(logit)
            label_id = label_id.to('cpu').numpy()
            acc = np.sum(np.argmax(logit, axis=1) == label_id)
            pred = np.argmax(logit, axis=1)
            for i in range(label_id.shape[0]):
                tot[label_id[i]] += 1
                if pred[i] == label_id[i]:
                    correct[label_id[i]] += 1
            final_acc += acc
            num_test_sample += input_id.size(0)

        print("epoch:", epoch)
        print("final acc:", final_acc / num_test_sample)
        if train_mode and final_acc / num_test_sample > max_acc:
            max_acc = final_acc / num_test_sample
            print("save...")
            torch.save(model.state_dict(), "../model/model-deberta-1231.ckpt")
            print("finish")
        print("Max acc:", max_acc)
        '''
        if final_acc / num_test_sample <= prev_acc:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.8
        '''
        prev_acc = final_acc / num_test_sample

        # Confusion-matrix counts, treating label 1 as the positive class.
        tp = correct[1]           # actual 1, predicted 1
        tn = correct[0]           # actual 0, predicted 0
        fn = tot[1] - correct[1]  # actual 1, predicted 0 (the original swapped fn and fp)
        fp = tot[0] - correct[0]  # actual 0, predicted 1
        rec = tp / (tp + fn + 1e-5)
        pre = tp / (tp + fp + 1e-5)
        print("recall:{0}, precision:{1}".format(rec, pre))
        print("f:", 2 * pre * rec / (pre + rec + 1e-5))  # epsilon guards against a 0/0 F-score
        print("acc:", (tp + tn) / (tp + tn + fp + fn))
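# `get_features` and `create_dataloader` are not shown in this snippet. If they
# are unavailable, the sketches below give one plausible implementation that is
# consistent with how main() consumes them; the details (in particular the
# assumed (text_a, text_b, label) layout of each example) are assumptions, not
# the repo's actual code.
from torch.utils.data import DataLoader, TensorDataset


def get_features(data, max_length, tknzr):
    input_ids, mask_ids, segment_ids, label_ids = [], [], [], []
    for text_a, text_b, label in data:  # hypothetical example layout
        enc = tknzr(text_a, text_b, max_length=max_length,
                    padding="max_length", truncation=True, return_tensors="pt")
        input_ids.append(enc["input_ids"])
        mask_ids.append(enc["attention_mask"])
        # Fall back to zeros if the tokenizer does not emit token_type_ids.
        segment_ids.append(enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])))
        label_ids.append(label)
    return input_ids, mask_ids, segment_ids, label_ids


def create_dataloader(input_ids, input_mask_ids, segment_ids, label_ids, batch_size, train):
    dataset = TensorDataset(input_ids, input_mask_ids, segment_ids, label_ids)
    # Shuffle only during training so the test-time logit dump stays in dataset order.
    return DataLoader(dataset, batch_size=batch_size, shuffle=train)


if __name__ == "__main__":
    main()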