# Imports required by the training routines in this section (added so the
# code is self-contained).
import itertools
import random
import sys
import time
from copy import deepcopy

import torch.nn as nn
from tqdm import tqdm

import utils  # project-local helpers: adjust_learning_rate, to_scalar, ...


def train(self, crf2train_dataloader, crf2dev_dataloader, dev_dataset_loader,
          epoch_list, args):
    start_time = time.time()
    for epoch_idx in epoch_list:
        args.start_epoch = epoch_idx
        curr_start_time = time.time()

        # Sample one CRF head (and hence one corpus group) uniformly at
        # random for this epoch.
        crf_no = random.randint(0, len(self.crf2corpus) - 1)
        cur_dataset = crf2train_dataloader[crf_no]
        epoch_loss = self.train_epoch(cur_dataset, crf_no, self.crit_ner,
                                      self.optimizer, args)

        # Main evaluation: the combined dev set in N21, or the single dev set in N2N.
        corpus_name = [args.dev_file[i].split("/")[-2] for i in self.crf2corpus[crf_no]]
        print(args.dispatch, "Dev Corpus: ", corpus_name)
        dev_f1, dev_pre, dev_rec, dev_acc = self.eval_epoch(
            crf2dev_dataloader[crf_no], crf_no, args)

        if_add_patience = True
        if dev_f1 > self.best_f1[crf_no] and not args.combine:
            print("Prev Best F1: {:.4f} Curr Best F1: {:.4f}".format(
                self.best_f1[crf_no], dev_f1))
            self.best_epoch_idx = epoch_idx
            self.patience_count = 0
            self.best_f1[crf_no] = dev_f1
            self.best_pre[crf_no] = dev_pre
            self.best_rec[crf_no] = dev_rec
            self.best_state_dict = deepcopy(self.ner_model.state_dict())

            checkpoint_name = args.checkpoint + "/" + args.dispatch + "_"
            if args.dispatch in ["N2K", "N2N"]:
                checkpoint_name += args.train_file[self.crf2corpus[crf_no][0]].split("/")[-2] + "_"
            checkpoint_name += "{:.4f}_{:.4f}_{:.4f}_{:d}".format(
                dev_f1, dev_pre, dev_rec, epoch_idx)
            print("NOW SAVING, ", checkpoint_name)
            print()
            self.drop_check_point(checkpoint_name, args)
            self.best_checkpoint_name = checkpoint_name
            if_add_patience = False
        else:
            if args.dispatch == "N2N" or not args.stop_on_single:
                self.patience_count += 1
        self.track_list.append({'loss': epoch_loss, 'dev_f1': dev_f1, 'dev_acc': dev_acc})

        # Always save a checkpoint after the final epoch.
        if epoch_idx == args.epoch - 1:
            last_checkpoint_name = args.checkpoint + "/" + args.dispatch + "_"
            if args.dispatch in ["N2K", "N2N"]:
                last_checkpoint_name += args.train_file[self.crf2corpus[crf_no][0]].split("/")[-2] + "_"
            last_checkpoint_name += "LAST_" + "{:.4f}_{:.4f}_{:.4f}_{:d}".format(
                dev_f1, dev_pre, dev_rec, epoch_idx)
            print("NOW SAVING LAST, ", last_checkpoint_name)
            self.drop_check_point(last_checkpoint_name, args)
            print()
            if args.combine:
                self.best_state_dict = deepcopy(self.ner_model.state_dict())

        # Save a per-corpus checkpoint whenever a single corpus improves.
        if args.dispatch in ["N21", "N2K"]:
            for cid in self.crf2corpus[crf_no]:
                print(args.dev_file[cid])
                cid_f1, cid_pre, cid_rec, cid_acc = self.eval_epoch(
                    dev_dataset_loader[cid], crf_no, args)
                if cid_f1 > self.corpus_best_vec[cid][0] and not args.combine:
                    print("Prev Best F1: {:.4f} Curr Best F1: {:.4f}".format(
                        self.corpus_best_vec[cid][0], cid_f1))
                    self.corpus_best_vec[cid] = [cid_f1, cid_pre, cid_rec]
                    if args.stop_on_single:
                        self.patience_count = 0
                    checkpoint_name = args.checkpoint + "/" + args.dispatch + "_"
                    checkpoint_name += args.dev_file[cid].split("/")[-2] + "_"
                    checkpoint_name += "{:.4f}_{:.4f}_{:.4f}_{:d}".format(
                        cid_f1, cid_pre, cid_rec, epoch_idx)
                    print("NOW SAVING, ", checkpoint_name)
                    self.drop_check_point(checkpoint_name, args)
                    print()
                    self.corpus_best_checkpoint_name[cid] = checkpoint_name
                    if_add_patience = False
            if if_add_patience and args.stop_on_single:
                self.patience_count += 1

        operating_time = time.time() - start_time
        h = operating_time // 3600
        m = (operating_time - 3600 * h) // 60
        s = operating_time - 3600 * h - 60 * m
        print("Epoch: [{:d}/{:d}]\t Patience: {:d}\t Current: {:.2f}\t Total: {:2d}:{:2d}:{:.2f}\n".format(
            args.start_epoch, args.epoch - 1, self.patience_count,
            time.time() - curr_start_time, int(h), int(m), s))

        # Early stopping once patience is exhausted, but never before
        # least_iters epochs have run.
        if self.patience_count >= args.patience and args.start_epoch >= args.least_iters:
            break

        # Update the learning rate: plateau scheduler if enabled, otherwise
        # inverse decay lr_t = lr / (1 + (t + 1) * lr_decay).
        if self.plateau:
            self.scheduler.step(dev_f1)
        else:
            utils.adjust_learning_rate(
                self.optimizer,
                args.lr / (1 + (args.start_epoch + 1) * args.lr_decay))

    # Report how often each CRF head was sampled during training.
    print("Sample Frequency")
    for crf, corpus_idx in self.crf2corpus.items():
        corpus_name = [args.train_file[i].split("/")[-2] for i in corpus_idx]
        print(crf, corpus_name, self.sample_cnter[crf])
    print()
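
# `utils.adjust_learning_rate` is called above but not defined in this
# section. Below is a minimal sketch of its assumed behavior (overwriting the
# learning rate of every parameter group of a torch.optim optimizer); the
# project's actual helper may differ.
def adjust_learning_rate(optimizer, lr):
    """Assumed behavior: set `lr` on all parameter groups of `optimizer`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr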
# Fragment of the joint CRF + semi-CRF (SCRF) training loop. The block below
# is the per-batch step; the enclosing loop headers are not part of this
# fragment. `packer`, `model`, `optimizer`, and `evaluator` are module-level
# objects here, unlike the class-based trainers above and below.
f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, SCRF_labels, mask_SCRF_labels, cnn_features = \
    packer.repack(f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, len_v,
                  SCRF_labels, mask_SCRF_labels, cnn_features, test=False)
optimizer.zero_grad()
loss = model(f_f, f_p, b_f, b_p, w_f, cnn_features, tg_v, mask_v,
             mask_v.long().sum(0), SCRF_labels, mask_SCRF_labels, onlycrf=False)
epoch_loss += utils.to_scalar(loss)
loss.backward()
# clip_grad_norm was renamed to the in-place clip_grad_norm_ in PyTorch 0.4.
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step()

# End-of-epoch bookkeeping: average the loss, decay the learning rate, and
# score the CRF, SCRF, and joint decodings on the dev set.
epoch_loss /= tot_length
print('epoch_loss: ', epoch_loss)
utils.adjust_learning_rate(optimizer,
                           args.lr / (1 + (args.start_epoch + 1) * args.lr_decay))

dev_f1_crf, dev_pre_crf, dev_rec_crf, dev_acc_crf, \
    dev_f1_scrf, dev_pre_scrf, dev_rec_scrf, dev_acc_scrf, \
    dev_f1_jnt, dev_pre_jnt, dev_rec_jnt, dev_acc_jnt = \
    evaluator.calc_score(model, dev_dataset_loader)

# When the joint dev F1 improves, reset early stopping and record the
# corresponding test scores.
if dev_f1_jnt > best_dev_f1_jnt:
    early_stop_epochs = 0
    test_f1_crf, test_pre_crf, test_rec_crf, test_acc_crf, \
        test_f1_scrf, test_pre_scrf, test_rec_scrf, test_acc_scrf, \
        test_f1_jnt, test_pre_jnt, test_rec_jnt, test_acc_jnt = \
        evaluator.calc_score(model, test_dataset_loader)
    best_test_f1_crf = test_f1_crf
    best_test_f1_scrf = test_f1_scrf
    best_dev_f1_jnt = dev_f1_jnt
    best_test_f1_jnt = test_f1_jnt
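
# `utils.to_scalar` is used above to accumulate the loss as a Python number
# but is not defined in this section. A minimal sketch, assuming it simply
# extracts the single value held by a loss tensor (on PyTorch >= 0.4 this is
# what `tensor.item()` does):
def to_scalar(var):
    """Assumed behavior: return the one Python number stored in `var`."""
    return var.item()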
def train(self, data, *args, **kwargs):
    tot_length = sum(len(t) for t in self.dataset_loader)
    loss_list = []
    acc_list = []
    # Track the best dev scores separately for each training file.
    best_f1 = [float('-inf')] * self.file_num
    best_pre = [float('-inf')] * self.file_num
    best_rec = [float('-inf')] * self.file_num
    start_time = time.time()
    epoch_list = range(self.args.start_epoch, self.args.start_epoch + self.args.epoch)
    patience_count = 0

    for epoch_idx in epoch_list:
        self.args.start_epoch = epoch_idx  # keep args in sync for checkpoint resume
        sample_num = 1
        epoch_loss = 0
        self.ner_model.train()
        for sample_id in tqdm(range(sample_num), mininterval=2,
                              desc=' - Tot it %d (epoch %d)' % (tot_length, self.args.start_epoch),
                              leave=False, file=sys.stdout):
            # Pick one of the training files at random and run over its batches.
            self.file_no = random.randint(0, self.file_num - 1)
            cur_dataset = self.dataset_loader[self.file_no]
            for f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, len_v in itertools.chain.from_iterable(cur_dataset):
                f_f, f_p, b_f, b_p, w_f, tg_v, mask_v = self.packer.repack_vb(
                    f_f, f_p, b_f, b_p, w_f, tg_v, mask_v, len_v)
                self.ner_model.zero_grad()
                scores = self.ner_model(f_f, f_p, b_f, b_p, w_f, self.file_no)
                loss = self.crit_ner(scores, tg_v, mask_v)
                epoch_loss += utils.to_scalar(loss)
                if self.args.co_train:
                    # Co-train the character-level language model: predict the
                    # next word forward and the previous word backward.
                    cf_p = f_p[0:-1, :].contiguous()
                    cb_p = b_p[1:, :].contiguous()
                    cf_y = w_f[1:, :].contiguous()
                    cb_y = w_f[0:-1, :].contiguous()
                    cfs, _ = self.ner_model.word_pre_train_forward(f_f, cf_p)
                    loss = loss + self.args.lambda0 * self.crit_lm(cfs, cf_y.view(-1))
                    cbs, _ = self.ner_model.word_pre_train_backward(b_f, cb_p)
                    loss = loss + self.args.lambda0 * self.crit_lm(cbs, cb_y.view(-1))
                loss.backward()
                # clip_grad_norm_ is the in-place spelling (PyTorch >= 0.4).
                nn.utils.clip_grad_norm_(self.ner_model.parameters(), self.args.clip_grad)
                self.optimizer.step()
        epoch_loss /= tot_length

        # Update lr with the same inverse decay schedule as above.
        utils.adjust_learning_rate(
            self.optimizer,
            self.args.lr / (1 + (self.args.start_epoch + 1) * self.args.lr_decay))

        # Evaluate, track, and save a checkpoint when the dev F1 improves.
        if 'f' in self.args.eva_matrix:
            dev_f1, dev_pre, dev_rec, dev_acc = self.evaluate(
                None, None, self.dev_dataset_loader[self.file_no], self.file_no)
            loss_list.append(epoch_loss)
            acc_list.append(dev_acc)
            improved = dev_f1 > best_f1[self.file_no]
            if improved:
                patience_count = 0
                best_f1[self.file_no] = dev_f1
                best_pre[self.file_no] = dev_pre
                best_rec[self.file_no] = dev_rec
            else:
                patience_count += 1
            self.track_list.append({'loss': epoch_loss, 'dev_f1': dev_f1, 'dev_acc': dev_acc})
            print('(loss: %.4f, epoch: %d, dataset: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f)'
                  % (epoch_loss, self.args.start_epoch, self.file_no, dev_f1, dev_pre, dev_rec))
            if improved:
                try:
                    self.save_model(None)
                except Exception as inst:
                    print(inst)

        print('epoch: %d in %d, took %.2f s'
              % (self.args.start_epoch, self.args.epoch, time.time() - start_time))
        if patience_count >= self.args.patience and self.args.start_epoch >= self.args.least_iters:
            break
    return loss_list, acc_list
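
# All three training loops share the same inverse learning-rate decay,
# lr_t = lr_0 / (1 + (t + 1) * lr_decay). A standalone illustration; the
# numeric values below are made up for demonstration only.
if __name__ == '__main__':
    base_lr, lr_decay = 0.015, 0.05
    for t in range(3):
        print('epoch %d: lr = %.6f' % (t, base_lr / (1 + (t + 1) * lr_decay)))
    # epoch 0: lr = 0.014286
    # epoch 1: lr = 0.013636
    # epoch 2: lr = 0.013043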