def main():
    parser = HfArgumentParser((EvalArguments, ))
    args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)

    valid_dataset = torch.load(args.valid_file_path)
    collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=args.model_type,
        mode="inference"
    )
    loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=collator)

    predictions = get_predictions(
        model=model,
        tokenizer=tokenizer,
        data_loader=loader,
        num_beams=args.num_beams,
        max_length=args.max_decoding_length
    )

    with open(args.output_path, 'w') as f:
        f.write("\n".join(predictions))

    logging.info(f"Output saved at {args.output_path}")
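# `get_predictions` is imported from elsewhere in the repo. The sketch below is
# an illustrative implementation only, under assumed batch keys ("input_ids",
# "attention_mask") produced by the inference-mode collator; it is not the
# repo's actual code.
import torch

def get_predictions_sketch(model, tokenizer, data_loader, num_beams, max_length, device="cpu"):
    model.to(device)
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in data_loader:
            # Beam-search decode each batch and detokenize the results
            outputs = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                num_beams=num_beams,
                max_length=max_length,
            )
            predictions.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    return predictions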
def main():
    args_json = BaseArguments()
    args_json = args_json.parse_args()
    parser = HfArgumentParser((EvalArguments, TrainingArguments))
    args, training_args = parser.parse_json_file(json_file=args_json.config_file)

    if args.model_type == "bart":
        prefix = "outputs/bart-large/checkpoint-"
        times = 10
        per_count = 1000
    elif args.model_type == "t5":
        prefix = "outputs/t5-large/checkpoint-"
        times = 5
        per_count = 500
    else:
        raise ValueError(f"Unsupported model type: {args.model_type}")

    for i in range(times):
        count = (i + 1) * per_count
        args.model_name_or_path = prefix + str(count)

        tokenizer_cls = MODEL_TYPE_TO_TOKENIZER[args.model_type]
        tokenizer = tokenizer_cls.from_pretrained(
            args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
        device = training_args.device
        model.to(device)

        eval_dataset = torch.load("data/" + args.eval_file_path)
        collator = T2TDataCollator(
            tokenizer=tokenizer,
            model_type=args.model_type,
            mode="inference"
        )
        loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=collator
        )

        predictions = get_predictions(
            model=model,
            tokenizer=tokenizer,
            data_loader=loader,
            device=device,
            num_beams=args.num_beams,
            max_length=args.max_decoding_length
        )

        args.output_path = os.path.join(args.model_name_or_path, "generated_text")
        with open(args.output_path, 'wb') as f:
            pickle.dump(predictions, f)

        logging.info(f"Output saved at {args.output_path}")
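# Illustrative usage: the loop above pickles one prediction list per checkpoint,
# e.g. outputs/bart-large/checkpoint-1000/generated_text. Reading one back:
import pickle

with open("outputs/bart-large/checkpoint-1000/generated_text", "rb") as f:
    checkpoint_predictions = pickle.load(f)
print(checkpoint_predictions[:3])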
def main(args_file=None):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if (len(sys.argv) == 2 and sys.argv[1].endswith(".json")) or args_file is not None:
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        args_file_path = os.path.abspath(sys.argv[1]) if args_file is None else args_file
        model_args, data_args, training_args = parser.parse_json_file(json_file=args_file_path)
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    assert model_args.model_type in list(MODEL_TYPE_TO_TOKENIZER.keys()), "model type should be 't5' or 'bart'"

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Set project name
    os.environ["WANDB_PROJECT"] = "question_generation_french"
    os.environ["WANDB_WATCH"] = "all"

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer_cls = MODEL_TYPE_TO_TOKENIZER[model_args.model_type]
    tokenizer = tokenizer_cls.from_pretrained(
        model_args.tokenizer_name_or_path if model_args.tokenizer_name_or_path else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model.resize_token_embeddings(len(tokenizer))

    # Get datasets
    logger.info('loading dataset')
    train_dataset = torch.load(data_args.train_file_path) if training_args.do_train else None
    valid_dataset = torch.load(data_args.valid_file_path) if training_args.do_eval else None
    logger.info('finished loading dataset')

    # Initialize data_collator
    data_collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=model_args.model_type,
        mode="training",
        using_tpu=training_args.tpu_num_cores is not None
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=data_collator,
        prediction_loss_only=True,
    )

    # disable wandb console logs
    logging.getLogger('wandb.run_manager').setLevel(logging.WARNING)

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        if not os.path.exists(training_args.output_dir):
            os.makedirs(training_args.output_dir)
        save_model_dir = os.path.join(training_args.output_dir, 'checkpoint-last')
        os.makedirs(save_model_dir, exist_ok=True)  # don't crash if a previous run already created it
        model.save_pretrained(save_model_dir)
        model.save_pretrained(training_args.output_dir)
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(eval_output.keys()):
                logger.info("  %s = %s", key, str(eval_output[key]))
                writer.write("%s = %s\n" % (key, str(eval_output[key])))

        results.update(eval_output)

    return results
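# `T2TDataCollator` is defined elsewhere in the repo. A minimal sketch of the
# interface used above, assuming the torch.load'ed datasets hold pre-tokenized
# tensors under keys "source_ids", "target_ids", and "attention_mask" (these
# key names are assumptions, not the repo's actual code):
import torch

class T2TDataCollatorSketch:
    def __init__(self, tokenizer, model_type="t5", mode="training", using_tpu=False):
        self.tokenizer = tokenizer
        self.model_type = model_type
        self.mode = mode
        self.using_tpu = using_tpu

    def __call__(self, batch):
        # Stack per-example tensors into batch tensors
        features = {
            "input_ids": torch.stack([ex["source_ids"] for ex in batch]),
            "attention_mask": torch.stack([ex["attention_mask"] for ex in batch]),
        }
        if self.mode == "training":
            labels = torch.stack([ex["target_ids"] for ex in batch])
            # Mask pad tokens out of the cross-entropy loss
            labels[labels == self.tokenizer.pad_token_id] = -100
            features["labels"] = labels
        return features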
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    args = BaseArguments()
    args = args.parse_args()
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_json_file(json_file=args.config_file)

    assert model_args.model_type in list(MODEL_TYPE_TO_TOKENIZER.keys()), \
        "model type should be 't5', 'bart', or 'gpt2'"

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    tokenizer_cls = MODEL_TYPE_TO_TOKENIZER[model_args.model_type]
    tokenizer = tokenizer_cls.from_pretrained(
        model_args.tokenizer_name_or_path if model_args.tokenizer_name_or_path else model_args.model_name_or_path,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
    )

    # Extend the embedding matrix to cover the special tokens
    model.resize_token_embeddings(len(tokenizer))

    if model_args.freeze_embeds:
        logger.info("freezing embeddings of the model")
        freeze_embeds(model)
        assert_not_all_frozen(model)

    # Get datasets
    logger.info('loading dataset')
    train_dataset = torch.load('data/' + data_args.train_file_path) if training_args.do_train else None
    eval_dataset = torch.load('data/' + data_args.eval_file_path) if training_args.do_eval else None
    logger.info('finished loading dataset')

    # Initialize data_collator
    data_collator = T2TDataCollator(
        tokenizer=tokenizer,
        model_type=model_args.model_type,
        mode="training",
        using_tpu=training_args.tpu_num_cores is not None
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        prediction_loss_only=True,
        label_smoothing=model_args.label_smoothing
    )

    # disable wandb console logs
    logging.getLogger('wandb.run_manager').setLevel(logging.WARNING)

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(eval_output.keys()):
                logger.info("  %s = %s", key, str(eval_output[key]))
                writer.write("%s = %s\n" % (key, str(eval_output[key])))

        results.update(eval_output)

    logger.info('Results {}'.format(results))
    return results
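# `freeze_embeds` and `assert_not_all_frozen` are imported from elsewhere in
# the repo. A minimal sketch of what they plausibly do for T5/BART-style
# seq2seq models (the attribute layout is an assumption, not the repo's code):
def freeze_params(module):
    for param in module.parameters():
        param.requires_grad = False

def freeze_embeds(model):
    if hasattr(model, "model"):
        # BART-style: embeddings live under model.model
        freeze_params(model.model.shared)
        for part in (model.model.encoder, model.model.decoder):
            freeze_params(part.embed_tokens)
    else:
        # T5-style: shared embedding sits on the top-level module
        freeze_params(model.shared)
        for part in (model.encoder, model.decoder):
            freeze_params(part.embed_tokens)

def assert_not_all_frozen(model):
    # Guard against accidentally freezing every parameter
    n_trainable = sum(p.requires_grad for p in model.parameters())
    assert n_trainable > 0, "all model parameters are frozen"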