def main(args):
    """Entry point: prepare tokenizer/datasets, then train and/or evaluate.

    Args:
        args: parsed command-line namespace. Must carry train_data_file,
            eval_data_file, and the do_train / do_eval flags, plus whatever
            the loaders and Trainer read from it.
    """
    init_logger()
    set_seed(args)

    tok = load_tokenizer(args)
    # Both splits are cached/loaded up front, in train-then-eval order.
    datasets = {
        "train": load_and_cache_examples(args, args.train_data_file, tok),
        "test": load_and_cache_examples(args, args.eval_data_file, tok),
    }

    trainer = Trainer(
        args,
        tok,
        train_dataset=datasets["train"],
        test_dataset=datasets["test"],
    )

    if args.do_train:
        trainer.train()
    if args.do_eval:
        # Reload best/saved weights before scoring.
        trainer.load_model()
        trainer.evaluate()
if __name__ == "__main__":
    # CLI for running prediction with a previously trained model directory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True,
                        choices=["rbert", "bert_em_cls", "bert_em_es", "bert_em_all"],
                        help="Model type")
    parser.add_argument("--model_dir", type=str, required=True,
                        help="Path to model directory")
    parser.add_argument("--input_file", type=str, required=True,
                        help="Path to input file")
    parser.add_argument("--output_file", type=str, required=True,
                        help="Path to output file (to store predicted labels)")
    parser.add_argument("--eval_batch_size", type=int, default=32,
                        help="Batch size for evaluation.")
    parser.add_argument("--no_cuda", action="store_true",
                        help="Whether to use GPU for evaluation.")
    parser.add_argument("--overwrite_cache", action="store_true",
                        help="Whether to overwrite cached feature file.")
    args = parser.parse_args()

    init_logger()
    logger.info("%s", args)

    # Fail fast: verify the model directory exists BEFORE loading config and
    # training args from it. (Previously this check ran last, so a missing
    # directory surfaced as an opaque from_pretrained/torch.load error instead
    # of the intended message.)
    if not os.path.exists(args.model_dir):
        raise Exception("Model doesn't exists! Train first!")

    config = BertConfig.from_pretrained(args.model_dir)
    # Restore the args used at training time, then override the eval-specific
    # settings supplied on this command line.
    train_args = torch.load(os.path.join(args.model_dir, "training_args.bin"))
    logger.info("Training args: %s", train_args)
    train_args.eval_batch_size = args.eval_batch_size
    train_args.overwrite_cache = args.overwrite_cache

    # For BERT-EM, we have to use GPU because we fix device="cuda" in the code
    args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"