torch.cuda.set_device(7)
# show the device currently in use
print('current device:', torch.cuda.current_device())

params = utils.Params()
# Set the random seed for reproducible experiments
random.seed(args.seed)
torch.manual_seed(args.seed)
params.seed = args.seed
# Set the logger
utils.set_logger()

# Create the input data pipeline
logging.info("Loading the dataset...")
dataloader = NERDataLoader(params)
test_loader = dataloader.get_dataloader(data_sign='test')
logging.info("- done.")

# Define the model
logging.info('Loading the model...')
model = ZenForSequenceClassification.from_pretrained(params.pretrain_model_dir,
                                                     num_labels=len(params.tags))
model.to(params.device)

# Reload weights from the saved file
utils.load_checkpoint(os.path.join(params.model_dir, args.restore_file + '.pth.tar'), model)
logging.info('- done.')

logging.info("Starting prediction...")
predict(model, test_loader, params)
logging.info('- done.')
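# For reference, a minimal sketch of what utils.load_checkpoint is assumed to do,
# based on how it is called here and on the 'state_dict'/'optim_dict' keys used by
# save_checkpoint later in this repo. This is an illustrative sketch, not the
# repository's actual implementation.
import os
import torch

def load_checkpoint(checkpoint, model, optimizer=None):
    """Load model (and optionally optimizer) state from a .pth.tar file."""
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    state = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optim_dict'])
    return state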
# show the device currently in use
print('current device:', torch.cuda.current_device())
# predict on the validation set or the test set
mode = args.mode

params = utils.Params()
# Set the random seed for reproducible experiments
random.seed(args.seed)
torch.manual_seed(args.seed)
params.seed = args.seed
# Set the logger
utils.set_logger()

# Create the input data pipeline
logging.info("Loading the dataset...")
dataloader = NERDataLoader(params)
val_loader, test_loader = dataloader.load_data(mode='test')
logging.info("- done.")

# Define the model
logging.info('Loading the model...')
config_path = os.path.join(params.params_path, 'bert_config.json')
config = RobertaConfig.from_json_file(config_path)
model = ElectraForTokenClassification(config, params=params)
model.to(params.device)

# Reload weights from the saved file
utils.load_checkpoint(os.path.join(params.model_dir, args.restore_file + '.pth.tar'), model)
logging.info('- done.')
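# A hedged sketch of the utils.set_logger helper these scripts rely on
# (assumption: it mirrors the common pattern of logging to the console and,
# optionally, to a file so the logging.info calls above are visible).
import logging

def set_logger(log_path=None):
    """Configure the root logger to print to the console and, if log_path
    is given, also append to a log file. Illustrative sketch only."""
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)
        if log_path is not None:
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            logger.addHandler(file_handler)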
# predict on the validation set or the test set
mode = args.mode

params = utils.Params()
# Set the random seed for reproducible experiments
random.seed(args.seed)
torch.manual_seed(args.seed)
params.seed = args.seed
# Set the logger
utils.set_logger()

# Define the model
logging.info('Loading the model...')
model = BertForTokenClassification.from_pretrained(params.bert_model_dir, params=params)
model.to(params.device)

# Reload weights from the saved file
utils.load_checkpoint(os.path.join(params.model_dir, args.restore_file + '.pth.tar'), model)
logging.info('- done.')

# Create the input data pipeline
logging.info("Loading the dataset...")
dataloader = NERDataLoader(params)
loader = dataloader.get_dataloader(data_sign=mode)
logging.info("- done.")

logging.info("Starting prediction...")
predict(model, loader, params, mode)
logging.info("- done.")
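# A hedged sketch of the predict helper called above. The batch layout
# (input_ids, attention_mask, labels), the assumption that the model's forward
# returns predicted label ids when no labels are supplied (e.g. via an internal
# CRF decode), and the use of params.tags as an index-to-tag list are all
# illustrative assumptions, not confirmed by the source.
import torch

def predict(model, loader, params, mode='test'):
    """Run inference over `loader` and collect predicted tag sequences."""
    model.eval()
    all_predictions = []
    with torch.no_grad():
        for input_ids, attention_mask, _ in loader:
            input_ids = input_ids.to(params.device)
            attention_mask = attention_mask.to(params.device)
            # assumption: forward returns per-token label ids at inference time
            batch_pred_ids = model(input_ids, attention_mask=attention_mask)
            for pred_ids, mask in zip(batch_pred_ids, attention_mask):
                length = int(mask.sum())
                all_predictions.append(
                    [params.tags[int(i)] for i in pred_ids[:length]])
    return all_predictions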
def train_and_evaluate(model, params, restore_file=None):
    """Train the model and evaluate every epoch."""
    # load args
    args = parser.parse_args()

    # Load training data and val data
    dataloader = NERDataLoader(params)
    train_loader = dataloader.get_dataloader(data_sign='train')
    val_loader = dataloader.get_dataloader(data_sign='val')
    # number of training steps per epoch
    params.train_steps = len(train_loader)

    # Prepare the optimizer for fine-tuning
    # collect the named model parameters
    param_optimizer = list(model.named_parameters())
    # pretrained model parameters
    param_pre = [(n, p) for n, p in param_optimizer if 'bert' in n]
    # middle-layer parameters (BiLSTM and dynamic layer weights)
    param_middle = [(n, p) for n, p in param_optimizer
                    if 'bilstm' in n or 'dym_weight' in n]
    # CRF parameters
    param_crf = [p for n, p in param_optimizer if 'crf' in n]
    # parameters excluded from weight decay
    no_decay = ['bias', 'LayerNorm', 'dym_weight', 'layer_norm']
    # group the parameters
    optimizer_grouped_parameters = [
        # pretrained model parameters, with weight decay
        {'params': [p for n, p in param_pre if not any(nd in n for nd in no_decay)],
         'weight_decay': params.weight_decay_rate, 'lr': params.fin_tuning_lr},
        # pretrained model parameters, without weight decay
        {'params': [p for n, p in param_pre if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': params.fin_tuning_lr},
        # middle-layer parameters, with weight decay
        {'params': [p for n, p in param_middle if not any(nd in n for nd in no_decay)],
         'weight_decay': params.weight_decay_rate, 'lr': params.middle_lr},
        # middle-layer parameters, without weight decay
        {'params': [p for n, p in param_middle if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': params.middle_lr},
        # CRF parameters get their own learning rate
        {'params': param_crf, 'weight_decay': 0.0, 'lr': params.crf_lr}
    ]
    num_train_optimization_steps = (len(train_loader) // params.gradient_accumulation_steps
                                    * args.epoch_num)
    optimizer = BertAdam(optimizer_grouped_parameters,
                         warmup=params.warmup_prop,
                         schedule="warmup_cosine",
                         t_total=num_train_optimization_steps,
                         max_grad_norm=params.clip_grad)

    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, args.restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        # load the checkpoint
        utils.load_checkpoint(restore_path, model, optimizer)

    # patience stage
    best_val_f1 = 0.0
    patience_counter = 0

    for epoch in range(1, args.epoch_num + 1):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch, args.epoch_num))

        # Train for one epoch on the training set
        train(model, train_loader, optimizer, params)

        # Evaluate for one epoch on the validation set
        # train_metrics = evaluate(model, train_loader, params, mark='Train',
        #                          verbose=True)  # Dict['loss', 'f1']
        val_metrics = evaluate(args, model, val_loader, params, mark='Val',
                               verbose=True)  # Dict['loss', 'f1']
        # F1 score on the validation set
        val_f1 = val_metrics['f1']
        # improvement in F1 over the best so far
        improve_f1 = val_f1 - best_val_f1

        # Save weights of the network
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        optimizer_to_save = optimizer
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model_to_save.state_dict(),
                               'optim_dict': optimizer_to_save.state_dict()},
                              is_best=improve_f1 > 0,
                              checkpoint=params.model_dir)
        params.save(params.params_path / 'params.json')

        # early stopping based on params.patience
        if improve_f1 > 0:
            logging.info("- Found new best F1")
            best_val_f1 = val_f1
            if improve_f1 < params.patience:
                patience_counter += 1
            else:
                patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping and logging best f1
        if (patience_counter > params.patience_num and epoch > params.min_epoch_num) \
                or epoch == args.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(best_val_f1))
            break
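# A minimal sketch of utils.save_checkpoint consistent with how it is called
# above (assumption: it follows the common pattern of always writing
# last.pth.tar and copying it to best.pth.tar when is_best is True; the file
# names are illustrative).
import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint):
    """Save `state` to <checkpoint>/last.pth.tar and, if this is the best
    model seen so far, also copy it to <checkpoint>/best.pth.tar."""
    os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))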
def train_and_evaluate(model, optimizer, scheduler, params, restore_file=None):
    """Train the model and evaluate every epoch."""
    # load args
    args = parser.parse_args()

    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, args.restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        # load the checkpoint
        utils.load_checkpoint(restore_path, model, optimizer)

    # Load training data and val data
    dataloader = NERDataLoader(params)
    train_loader = dataloader.get_dataloader(data_sign='train')
    val_loader = dataloader.get_dataloader(data_sign='val')

    # patience stage
    best_val_f1 = 0.0
    patience_counter = 0

    for epoch in range(1, args.epoch_num + 1):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch, args.epoch_num))
        # number of training steps per epoch
        params.train_steps = len(train_loader)

        # Train for one epoch on the training set
        train(model, train_loader, optimizer, params)

        # Evaluate for one epoch on the validation set
        val_metrics = evaluate(model, val_loader, params, args, mark='Val')  # Dict['loss', 'f1']
        # step the lr scheduler to decay the learning rate
        scheduler.step()

        # F1 score on the validation set
        val_f1 = val_metrics['f1']
        # improvement in F1 over the best so far
        improve_f1 = val_f1 - best_val_f1

        # Save weights of the network
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        optimizer_to_save = optimizer
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model_to_save.state_dict(),
                               'optim_dict': optimizer_to_save.state_dict()},
                              is_best=improve_f1 > 0,
                              checkpoint=params.model_dir)
        params.save(params.params_path / 'params.json')

        # early stopping based on params.patience
        if improve_f1 > 0:
            logging.info("- Found new best F1")
            best_val_f1 = val_f1
            if improve_f1 < params.patience:
                patience_counter += 1
            else:
                patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping and logging best f1
        if (patience_counter > params.patience_num and epoch > params.min_epoch_num) \
                or epoch == args.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(best_val_f1))
            break
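# A hedged usage sketch for this variant, where the optimizer and scheduler are
# built by the caller rather than inside train_and_evaluate. The choice of
# AdamW, StepLR, and the lr/step_size/gamma values below are illustrative
# assumptions, not the repository's actual configuration.
import torch
from torch.optim.lr_scheduler import StepLR

model = BertForTokenClassification.from_pretrained(params.bert_model_dir, params=params)
model.to(params.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5,
                              weight_decay=params.weight_decay_rate)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)  # decay the lr every 2 epochs
train_and_evaluate(model, optimizer, scheduler, params, restore_file=args.restore_file)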