def train(train_dataset, dev_dataset):
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size // args.gradient_accumulation_steps,
                                  shuffle=True,
                                  num_workers=2)

    nonlocal global_step
    n_sample = len(train_dataloader)  # number of batches per epoch
    early_stopping = EarlyStopping(args.patience, logger=logger)

    # Loss functions
    classified_loss = torch.nn.CrossEntropyLoss().to(device)
    adversarial_loss = torch.nn.BCELoss().to(device)

    # Optimizer
    optimizer = AdamW(model.parameters(), args.lr)

    train_loss = []
    if dev_dataset:
        valid_loss = []
        valid_ind_class_acc = []
    iteration = 0

    for i in range(args.n_epoch):
        model.train()
        total_loss = 0
        for sample in tqdm.tqdm(train_dataloader):
            sample = (t.to(device) for t in sample)  # move every tensor in the batch to the target device
            token, mask, type_ids, y = sample
            batch = len(token)

            f_vector, discriminator_output, classification_output = model(
                token, mask, type_ids, return_feature=True)
            discriminator_output = discriminator_output.squeeze()

            if args.BCE:
                # binary objective: in-domain (y != 0) vs. out-of-domain
                loss = adversarial_loss(discriminator_output, (y != 0.0).float())
            else:
                # multi-class classification objective
                loss = classified_loss(discriminator_output, y.long())
            total_loss += loss.item()

            # scale the loss so the accumulated gradient matches the full-batch average
            loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # backprop done above; update parameters once every gradient_accumulation_steps batches
            if (global_step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1

        logger.info('[Epoch {}] Train: train_loss: {}'.format(i, total_loss / n_sample))
        logger.info('-' * 30)
        train_loss.append(total_loss / n_sample)
        iteration += 1

        if dev_dataset:
            logger.info('#################### eval result at step {} ####################'.format(global_step))
            eval_result = eval(dev_dataset)

            valid_loss.append(eval_result['loss'])
            valid_ind_class_acc.append(eval_result['ind_class_acc'])

            # early_stopping returns:
            #  1 -> metric improved, save the model
            #  0 -> no improvement, keep training
            # -1 -> no improvement and patience exceeded, stop early
            signal = early_stopping(-eval_result['eer'])
            if signal == -1:
                break
            elif signal == 1:
                save_model(model, path=config['model_save_path'], model_name='bert')

            logger.info(eval_result)
            logger.info('valid_eer: {}'.format(eval_result['eer']))
            logger.info('valid_oos_ind_precision: {}'.format(eval_result['oos_ind_precision']))
            logger.info('valid_oos_ind_recall: {}'.format(eval_result['oos_ind_recall']))
            logger.info('valid_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score']))
            logger.info('valid_auc: {}'.format(eval_result['auc']))
            logger.info('valid_fpr95: {}'.format(
                ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score'])))

    from utils.visualization import draw_curve
    draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
    if dev_dataset:
        draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
        draw_curve(valid_ind_class_acc, iteration, 'valid_ind_class_accuracy', args.output_dir)

    # if patience never triggers within n_epoch, save the final model
    if args.patience >= args.n_epoch:
        save_model(model, path=config['model_save_path'], model_name='bert')

    freeze_data['train_loss'] = train_loss
    if dev_dataset:
        freeze_data['valid_loss'] = valid_loss
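# --- Example: EarlyStopping helper ---
# The training loops in this section rely on an EarlyStopping object that
# returns 1 (metric improved, save the model), 0 (no improvement, keep going),
# or -1 (patience exhausted, stop). This is a minimal sketch inferred from the
# call sites, not the actual utils implementation. Callers pass metrics where
# higher is better (accuracy, -eer), so the comparison assumes higher-is-better.
class EarlyStopping:
    def __init__(self, patience, logger=None):
        self.patience = patience   # evals to wait after the last improvement
        self.logger = logger
        self.best_score = None     # best (higher-is-better) metric seen so far
        self.counter = 0           # evals since the last improvement

    def __call__(self, score):
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
            return 1               # improved: caller saves the model
        self.counter += 1
        if self.logger:
            self.logger.info('EarlyStopping counter: {}/{}'.format(self.counter, self.patience))
        return -1 if self.counter >= self.patience else 0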
def train(train_dataset, dev_dataset):
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size // args.gradient_accumulation_steps,
                                  shuffle=True,
                                  num_workers=2)

    nonlocal global_step
    n_sample = len(train_dataloader)  # number of batches per epoch
    early_stopping = EarlyStopping(args.patience, logger=logger)

    # Loss function
    classified_loss = torch.nn.CrossEntropyLoss().to(device)

    # Optimizer
    optimizer = AdamW(model.parameters(), args.lr)

    train_loss = []
    if dev_dataset:
        valid_loss = []
        valid_ind_class_acc = []
    iteration = 0

    for i in range(args.n_epoch):
        model.train()
        total_loss = 0
        for sample in tqdm.tqdm(train_dataloader):
            sample = (t.to(device) for t in sample)  # move every tensor in the batch to the target device
            token, mask, type_ids, y = sample
            batch = len(token)

            logits = model(token, mask, type_ids)
            loss = classified_loss(logits, y.long())
            total_loss += loss.item()

            # scale the loss so the accumulated gradient matches the full-batch average
            loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # backprop done above; update parameters once every gradient_accumulation_steps batches
            if (global_step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1

        logger.info('[Epoch {}] Train: train_loss: {}'.format(i, total_loss / n_sample))
        logger.info('-' * 30)
        train_loss.append(total_loss / n_sample)
        iteration += 1

        if dev_dataset:
            logger.info('#################### eval result at step {} ####################'.format(global_step))
            eval_result = eval(dev_dataset)

            valid_loss.append(eval_result['loss'])
            valid_ind_class_acc.append(eval_result['ind_class_acc'])

            # early_stopping returns:
            #  1 -> metric improved, save the model
            #  0 -> no improvement, keep training
            # -1 -> no improvement and patience exceeded, stop early
            signal = early_stopping(eval_result['accuracy'])
            if signal == -1:
                break
            elif signal == 1:
                save_model(model, path=config['model_save_path'], model_name='bert')

            # logger.info(eval_result)

    from utils.visualization import draw_curve
    draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
    if dev_dataset:
        draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
        draw_curve(valid_ind_class_acc, iteration, 'valid_ind_class_accuracy', args.output_dir)

    # if patience never triggers within n_epoch, save the final model
    if args.patience >= args.n_epoch:
        save_model(model, path=config['model_save_path'], model_name='bert')

    freeze_data['train_loss'] = train_loss
    if dev_dataset:
        freeze_data['valid_loss'] = valid_loss
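# --- Example: gradient accumulation ---
# Both variants above divide the loss by args.gradient_accumulation_steps and
# call optimizer.step() only every N batches, so N micro-batches accumulate
# into one update whose gradient equals that of the full effective batch.
# Self-contained toy sketch; the model, data, and sizes are made up for
# illustration and are not part of this repository.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
accumulation_steps = 4                           # effective batch = 4 micro-batches

for step in range(16):
    x = torch.randn(8, 10)                       # micro-batch of 8 samples
    y = torch.randint(0, 2, (8,))
    loss = loss_fn(model(x), y)
    (loss / accumulation_steps).backward()       # scale so the grads average out
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                         # one update per effective batch
        optimizer.zero_grad()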
def train(train_dataset, dev_dataset):
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size,
                                  shuffle=True,
                                  num_workers=2)

    global best_dev
    nonlocal global_step
    n_sample = len(train_dataloader)  # number of batches per epoch
    early_stopping = EarlyStopping(args.patience, logger=logger)

    # Loss function
    adversarial_loss = torch.nn.BCELoss().to(device)

    # Optimizers: one for the pos detector head, one for the BERT encoder E
    optimizer_pos = torch.optim.Adam(pos.parameters(), lr=args.pos_lr)
    optimizer_E = AdamW(E.parameters(), args.bert_lr)

    valid_detection_loss = []
    valid_oos_ind_precision = []
    valid_oos_ind_recall = []
    valid_oos_ind_f_score = []

    train_loss = []
    iteration = 0

    for i in range(args.n_epoch):
        logger.info('***********************************')
        logger.info('epoch: {}'.format(i))

        # Put both sub-models in training mode
        pos.train()
        E.train()

        total_loss = 0
        for sample in tqdm(train_dataloader):
            sample = (t.to(device) for t in sample)  # move every tensor in the batch to the target device
            token, mask, type_ids, pos1, pos2, pos_mask, y = sample
            batch = len(token)

            optimizer_E.zero_grad()
            optimizer_pos.zero_grad()

            sequence_output, pooled_output = E(token, mask, type_ids)
            real_feature = pooled_output

            out = pos(pos1, pos2, real_feature)
            loss = adversarial_loss(out, y.float())
            loss.backward()
            total_loss += loss.item()  # accumulate a plain float, not a tensor

            # only update the encoder when fine-tuning is enabled;
            # the detector head is updated every step
            if args.fine_tune:
                optimizer_E.step()
            optimizer_pos.step()

        logger.info('[Epoch {}] Train: loss: {}'.format(i, total_loss / n_sample))
        logger.info('-' * 75)
        train_loss.append(total_loss / n_sample)
        iteration += 1

        if dev_dataset:
            logger.info('#################### eval result at step {} ####################'.format(global_step))
            eval_result = eval(dev_dataset)

            valid_detection_loss.append(eval_result['detection_loss'])
            valid_oos_ind_precision.append(eval_result['oos_ind_precision'])
            valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
            valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

            # early_stopping returns:
            #  1 -> metric improved, save the model
            #  0 -> no improvement, keep training
            # -1 -> no improvement and patience exceeded, stop early
            signal = early_stopping(-eval_result['eer'])
            if signal == -1:
                break
            # elif signal == 0:
            #     pass
            # elif signal == 1:
            #     save_gan_model(D, G, config['gan_save_path'])
            #     if args.fine_tune:
            #         save_model(E, path=config['bert_save_path'], model_name='bert')

            logger.info(eval_result)
            logger.info('valid_eer: {}'.format(eval_result['eer']))
            logger.info('valid_oos_ind_precision: {}'.format(eval_result['oos_ind_precision']))
            logger.info('valid_oos_ind_recall: {}'.format(eval_result['oos_ind_recall']))
            logger.info('valid_oos_ind_f_score: {}'.format(eval_result['oos_ind_f_score']))
            logger.info('valid_auc: {}'.format(eval_result['auc']))
            logger.info('valid_fpr95: {}'.format(
                ErrorRateAt95Recall(eval_result['all_binary_y'], eval_result['y_score'])))

    # early_stopping tracked -eer, so negate back to recover the best dev EER
    best_dev = -early_stopping.best_score

    # Plot the training loss curve
    from utils.visualization import draw_curve
    draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
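# --- Example: FPR at 95% recall ---
# The validation logging above reports 'valid_fpr95' via
# ErrorRateAt95Recall(all_binary_y, y_score): the false positive rate at the
# score threshold where recall on the positive class reaches 95%. This is a
# sketch inferred from the call sites; the argument order (labels, scores)
# matches them, but the real utils implementation may differ.
import numpy as np

def error_rate_at_95_recall(labels, scores):
    labels = np.asarray(labels)   # binary ground truth, 1 = positive
    scores = np.asarray(scores)   # higher score = more likely positive
    order = np.argsort(-scores)   # sort by descending score
    labels = labels[order]

    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    tp = np.cumsum(labels)        # true positives as the threshold is lowered
    fp = np.cumsum(1 - labels)    # false positives as the threshold is lowered

    recall = tp / n_pos
    # first cut-off at which recall reaches 95% (clamped if it never does)
    idx = min(np.searchsorted(recall, 0.95), len(recall) - 1)
    return fp[idx] / n_neg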