def main():
    """Evaluate every saved checkpoint on the dev set and log its loss to TensorBoard.

    The run configuration is hard-coded in a ``SimpleNamespace``. Every entry of
    ``args.directory`` whose name contains ``'checkpoint'`` is treated as a
    checkpoint directory; entries are processed in increasing order of the integer
    after the last ``'-'`` in their name. For each one the config, tokenizer and
    model are loaded, the 'quest_gen' dev set is built, and the value returned by
    ``evaluate`` is written under the ``dev_loss`` tag with that integer as the step.

    Raises:
        ImportError: if ``args.fp16`` is set but apex cannot be imported.
        ValueError: if a matching directory entry has no integer suffix.
    """
    args = SimpleNamespace(
        debug=False,
        data_dir='data',
        train_file=None,
        directory='question_generator_gpt2-medium',
        model_name='gpt2-medium',
        do_lower_case=True,
        max_seq_length=448,
        overwrite_cache=False,
        cache_dir=None,
        dynamic_batching=True,
        dev_batch_size=4,
        fp16=True,
        fp16_opt_level='O2',
        no_cuda=False,
        seed=42,
    )

    # Set device
    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.DEBUG if args.debug else logging.INFO
    )

    tb_writer = SummaryWriter(os.path.join(args.directory, 'TB_writer_q-loss_evaluation'))
    try:
        # Set seed
        set_seed(args)

        # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum
        # if args.fp16 is set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
        # Note that running `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid.
        if args.fp16:
            try:
                import apex
                from apex import amp
                apex.amp.register_half_function(torch, "einsum")
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                ) from err

        # Evaluate checkpoints
        def sort_fn(x):
            # Checkpoint entries are ordered by the integer after the last '-'.
            return int(x.split('-')[-1])

        checkpoints = sorted(
            (x for x in os.listdir(args.directory) if 'checkpoint' in x),
            key=sort_fn,
        )
        cache_dir = args.cache_dir if args.cache_dir else None

        for checkpoint in checkpoints:
            logging.info(f'Evaluating checkpint : {checkpoint}')
            # TensorBoard steps are integers; convert instead of passing the raw string.
            global_step = sort_fn(checkpoint)
            checkpoint_path = os.path.join(args.directory, checkpoint)

            # Load pretrained model and tokenizer
            config = GPT2Config.from_pretrained(
                checkpoint_path,
                cache_dir=cache_dir,
            )
            tokenizer = GPT2Tokenizer.from_pretrained(
                checkpoint_path,
                do_lower_case=args.do_lower_case,
                cache_dir=cache_dir,
            )
            # Force pre-tokenized, truncated encoding for every later encode call.
            tokenizer.encode = partial(tokenizer.encode, is_pretokenized=True, truncation=True)
            tokenizer.encode_plus = partial(tokenizer.encode_plus, is_pretokenized=True, truncation=True)
            model = GPT2LMHeadModel.from_pretrained(
                checkpoint_path,
                config=config,
                cache_dir=cache_dir,
            )
            model.to(args.device)
            if args.fp16:
                model = amp.initialize(model, opt_level=args.fp16_opt_level)

            # Load data
            dev_dataset = load_and_cache_examples(args, tokenizer, 'quest_gen', evaluate=True, gpt=True)
            dev_dataset = preprocess_dataset(dev_dataset, tokenizer)

            # Evaluation
            loss = evaluate(args, dev_dataset, model)
            tb_writer.add_scalar('dev_loss', loss, global_step)
    finally:
        # Flush pending events and release the event file even if a checkpoint fails.
        tb_writer.close()
def main():
    """Fine-tune a GPT-2 LM-head model on the 'quest_gen' data and save the result.

    Steps, in order: parse the command line, validate/create ``args.output_dir``,
    pick the device, configure logging, seed, load config/tokenizer/model from
    ``args.model_path`` (adding the ``question:`` / ``:question`` tokens and using
    the EOS token as pad and sep token), build the train and dev datasets, call
    ``train``, then write model, tokenizer and arguments to ``args.output_dir``.

    Raises:
        ValueError: if the output directory exists, is not empty and
            ``args.overwrite_output_dir`` is not set.
        ImportError: if ``args.fp16`` is set but apex cannot be imported.
    """
    args = get_parser().parse_args()

    if not args.model_name:
        # No explicit model name given: reuse the model path.
        args.model_name = args.model_path

    stride_limit = args.max_seq_length - args.max_query_length
    if args.doc_stride >= stride_limit:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    out_dir_in_use = os.path.exists(args.output_dir) and os.listdir(args.output_dir)
    if out_dir_in_use and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Set device
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    args.device = torch.device("cuda" if use_cuda else "cpu")

    # Setup logging
    log_level = logging.INFO
    if args.debug:
        log_level = logging.DEBUG
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=log_level,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    cache_dir = args.cache_dir or None
    gpt2_config = GPT2Config.from_pretrained(
        args.config_name or args.model_path,
        cache_dir=cache_dir,
    )
    tokenizer = GPT2Tokenizer.from_pretrained(
        args.tokenizer_name or args.model_path,
        do_lower_case=args.do_lower_case,
        cache_dir=cache_dir,
    )
    tokenizer.add_tokens(['question:', ':question'])
    # The EOS token doubles as both the padding and the separator token.
    for special_name in ("pad_token", "sep_token"):
        setattr(tokenizer, special_name, tokenizer.eos_token)
    # Force pre-tokenized, truncated encoding for every later encode call.
    pretokenized_kwargs = dict(is_pretokenized=True, truncation=True)
    tokenizer.encode = partial(tokenizer.encode, **pretokenized_kwargs)
    tokenizer.encode_plus = partial(tokenizer.encode_plus, **pretokenized_kwargs)

    model = GPT2LMHeadModel.from_pretrained(
        args.model_path,
        config=gpt2_config,
        cache_dir=cache_dir,
    )
    # The embedding matrix must grow to cover the tokens added above.
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum
    # if args.fp16 is set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
    # Note that running `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    def _build_split(is_eval):
        # Load one split of the 'quest_gen' data and run it through preprocessing.
        raw_split = load_and_cache_examples(args, tokenizer, 'quest_gen', evaluate=is_eval, gpt=True)
        return preprocess_dataset(raw_split, tokenizer)

    # Training
    train_dataset = _build_split(False)
    dev_dataset = _build_split(True)
    train(args, train_dataset, dev_dataset, model, tokenizer)
    logging.info('Finished training !')

    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`.
    # Good practice: save your training arguments together with the trained model.
    logger.info("Saving final model checkpoint to %s", args.output_dir)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(args.output_dir)
    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
def evaluate(args, model, tokenizer, prefix=""):
    """Run the model over the evaluation set(s) and dump predictions and results.

    For each evaluation task (two when ``args.task_name == "mnli"``, otherwise one)
    the dataset is loaded, the model is run batch by batch without gradients, and
    the predictions are written one per line to
    ``./smp2020ewect-bert-wwm-baseline<prefix>.txt``. An ``eval_results.txt`` file
    is written under ``<eval_output_dir>/<prefix>/``.

    Args:
        args: run configuration; ``args.eval_batch_size`` is set as a side effect.
        model: model called as ``model(**inputs)``; its first two outputs are taken
            as loss and logits. Wrapped in ``DataParallel`` when ``args.n_gpu > 1``.
        tokenizer: forwarded to ``load_and_cache_examples``.
        prefix: suffix for the prediction file name and sub-directory for the
            results file.

    Returns:
        dict: the accumulated results. Metric computation is currently disabled,
        so this holds the placeholder entries ``{"acc:": 0.0, "f1:": "0.0"}``.

    Raises:
        ZeroDivisionError: if the evaluation dataloader yields no batches.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    is_mnli = args.task_name == "mnli"
    eval_task_names = ("mnli", "mnli-mm") if is_mnli else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if is_mnli else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir, exist_ok=True)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_dataset)
        else:
            eval_sampler = DistributedSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # multi-gpu eval: wrap only once, otherwise the second task iteration
        # would wrap an already-wrapped model.
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info(" Num examples = %d", len(eval_dataset))
        logger.info(" Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        # Per-batch arrays are collected and concatenated once after the loop.
        pred_chunks = []
        label_chunks = []
        model.eval()
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3]
                }
                if args.model_type != 'distilbert':
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            pred_chunks.append(logits.detach().cpu().numpy())
            label_chunks.append(inputs['labels'].detach().cpu().numpy())

        eval_loss = eval_loss / nb_eval_steps
        preds = np.concatenate(pred_chunks, axis=0)
        out_label_ids = np.concatenate(label_chunks, axis=0)
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)

        # NOTE: metric computation is disabled; placeholder values are reported.
        result = {"acc:": 0.0, "f1:": "0.0"}
        with open('./smp2020ewect-bert-wwm-baseline' + str(prefix) + ".txt", 'w', encoding='utf-8') as f:
            _preds = preds.tolist()
            for p in _preds:
                f.write(str(p) + '\n')
        #result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        # The prefix sub-directory is not guaranteed to exist yet.
        os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results
def main():
    """Train and/or evaluate a sequence classifier on the SMP2020 EWECT data.

    Reads the run configuration from ``get_parser()`` (its return value is used
    directly as the argument namespace), sets up the device, logging and seed,
    loads config/tokenizer/model for ``args.model_type``, optionally trains and
    saves the model, and optionally evaluates one or all saved checkpoints.

    Returns:
        dict: evaluation results keyed by ``"<metric>_<global_step>"``; empty when
        evaluation is not requested or this is not the main process.

    Raises:
        ValueError: if training is requested into an existing non-empty output
            directory and ``args.overwrite_output_dir`` is not set.
    """
    args = get_parser()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda the GPU count must be 0; otherwise a CPU model would be
        # wrapped in DataParallel on a multi-GPU machine.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    # args.task_name = args.task_name.lower()
    # if args.task_name not in processors:
    #     raise ValueError("Task not found: %s" % (args.task_name))
    processor = SMP2020_ewect()  # processors[args.task_name]()
    args.output_mode = "classification"  # output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=cache_dir)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=cache_dir)
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config,
        cache_dir=cache_dir)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        # Take care of distributed/parallel training
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            # Every directory below output_dir that contains a weights file is a checkpoint.
            checkpoints = [
                os.path.dirname(c)
                for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))
            ]
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""

            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            # Suffix each metric name with the step so checkpoints do not overwrite each other.
            result = {k + '_{}'.format(global_step): v for k, v in result.items()}
            results.update(result)

    return results
def _str_to_bool(value):
    """Convert a command-line string such as ``true``/``False``/``1``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}.".format(value))


def main():
    """Parse the command line, then train and/or evaluate a QA model.

    After parsing, the function validates the output directory, sets up the
    device (single process or distributed), logging and seed, loads the config
    and tokenizer for ``args.model_type``, obtains model and optimizer from
    ``load_model``, builds the requested datasets, saves the arguments to
    ``<output_dir>/training_args.bin``, trains when ``--do_train`` is given and
    runs the evaluator when ``--do_eval`` is given.

    Returns:
        argparse.Namespace: the parsed arguments, extended with ``n_gpu`` and
        ``device``.

    Raises:
        ImportError: if ``--fp16`` is given but apex cannot be imported.
        ValueError: if training is requested into an existing non-empty output
            directory and ``--overwrite_output_dir`` is not given.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument('--threshold', type=float, default=None,
                        help='Optionally provide a threshold to use for deciding whether to predict actual span'
                             ' or leave question unanswered. If one is not provided, the "best" one is chosen'
                             ' based on the ground truth in the query file. ')
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=192,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=200,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--predict_batch_size", default=50, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=4e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=16,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--adam_beta1", default=0.9, type=float, help="Beta_1 for Adam optimizer.")
    parser.add_argument("--adam_beta2", default=0.999, type=float, help='Beta_2 for Adam optimizer')
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=20.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for"
                             " Bert Adam. E.g., 0.1=10%% of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The number (each) of start/end logits to track per feature when scoring.")
    parser.add_argument("--n_to_predict", default=20, type=int,
                        help="The number of predictions to save for each "
                             "example for analysis / ensembling (aka the `k` for the top k "
                             "predictions setting). NOTE: these additional predictions are not"
                             " used for evaluation...only the top answer is used for evaluation")
    parser.add_argument(
        "--max_answer_length",
        default=250,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument('--negative_sampling_prob_when_has_answer',
                        type=float, default=0.1,
                        help="Only used when preparing training features, not for decoding. "
                             " This ratio will be used when the example has a short answer, but"
                             " the span does not. Specifically we will keep NO_ANSWER spans"
                             " with probability ${negative_sampling_prob_when_has_answer}")
    # The help text names this flag itself (it previously pointed at the has_answer flag).
    parser.add_argument('--negative_sampling_prob_when_no_answer',
                        type=float, default=0.15,
                        help="Only used when preparing training features, not for decoding. "
                             "This ratio will be used when the example has NO short answer. "
                             "Specifically we will keep spans from this example "
                             "with probability ${negative_sampling_prob_when_no_answer}")
    parser.add_argument('--add_doc_title_to_passage', action='store_true',
                        help="Whether to add document title to each passage in feature generation")
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--threads", type=int, default=1,
                        help="multiple threads for converting example to features")
    parser.add_argument("--dataset", type=str, help="dataset: LenovoQA, TechQA, UnixQA")
    # `type=bool` would turn every non-empty string (including "False") into True,
    # so the value is parsed explicitly.
    parser.add_argument("--use_cached_file", type=_str_to_bool, default=False,
                        help="you may change stride length so default as False")

    args = parser.parse_args()

    if args.fp16:
        try:
            from apex import amp
            # HACK: dirty trick to import amp once so that it's accessible outside this method
            globals()['amp'] = amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=cache_dir,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=cache_dir,
    )

    logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum
    # if args.fp16 is set. Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
    # Note that running `--fp16_opt_level="O2"` will remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, "einsum")
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err

    model, optimizer = load_model(args, model_class, config)

    if args.do_train:
        train_dataset = load_and_cache_examples(args, dataset=args.dataset, tokenizer=tokenizer,
                                                evaluate=False, output_examples=False)

    if args.do_eval:
        test_dataset, test_features, gold_dict = load_and_cache_examples(
            args, dataset=args.dataset, tokenizer=tokenizer, evaluate=True, output_examples=True)
        # Bind everything except the model so the evaluator can be called as evaluator(model=...).
        model_evaluator = partial(
            predict_output,
            device=args.device,
            dataset=args.dataset,
            model_type=args.model_type,
            eval_features=test_features,
            eval_dataset=test_dataset,
            predict_batch_size=args.predict_batch_size,
            output_dir=args.output_dir)
    else:
        model_evaluator = None

    # Create output directory if needed
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)

    # Good practice: save your training arguments together with the trained model
    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Training
    if args.do_train:
        model = train(args, train_dataset, model, optimizer, tokenizer, model_evaluator)

    if model_evaluator is not None:
        model_evaluator(model=model)

    return args