def run_loop(args: RunConfiguration, checkpoint=None, tasks=None):
    is_resumed = checkpoint is not None
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    print(quick_init_out.n_gpu)
    with quick_init_out.log_writer.log_context():
        jiant_task_container = container_setup.create_jiant_task_container_from_json(
            jiant_task_container_config_path=args.jiant_task_container_config_path,
            verbose=True,
        )
        runner = setup_runner(
            args=args,
            jiant_task_container=jiant_task_container,
            quick_init_out=quick_init_out,
            verbose=True,
        )
        if is_resumed:
            runner.load_state(checkpoint["runner_state"])
            del checkpoint["runner_state"]
        checkpoint_saver = jiant_runner.CheckpointSaver(
            metadata={"args": args.to_dict()},
            save_path=os.path.join(args.output_dir, "checkpoint.p"),
        )
        if args.do_val:
            # Evaluate once before training to record baseline metrics.
            print("Running pre-training evaluation")
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        if args.do_train:
            metarunner = jiant_metarunner.JiantMetarunner(
                runner=runner,
                save_every_steps=args.save_every_steps,
                eval_every_steps=args.eval_every_steps,
                save_checkpoint_every_steps=args.save_checkpoint_every_steps,
                no_improvements_for_n_evals=args.no_improvements_for_n_evals,
                checkpoint_saver=checkpoint_saver,
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=True,
                save_last_model=args.do_save or args.do_save_last,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            )
            if is_resumed:
                metarunner.load_state(checkpoint["metarunner_state"])
                del checkpoint["metarunner_state"]
            metarunner.run_train_loop()
            # Post-training perturbation pass over the provided tasks.
            runner.run_perturb(tasks, metarunner)
        if args.do_val:
            val_results_dict = runner.run_val(
                task_name_list=runner.jiant_task_container.task_run_config.val_task_list,
                return_preds=args.write_val_preds,
            )
            jiant_evaluate.write_val_results(
                val_results_dict=val_results_dict,
                metrics_aggregator=runner.jiant_task_container.metrics_aggregator,
                output_dir=args.output_dir,
                verbose=True,
            )
            if args.write_val_preds:
                jiant_evaluate.write_preds(
                    eval_results_dict=val_results_dict,
                    path=os.path.join(args.output_dir, "val_preds.p"),
                )
        else:
            assert not args.write_val_preds
        if args.write_test_preds:
            test_results_dict = runner.run_test(
                task_name_list=runner.jiant_task_container.task_run_config.test_task_list,
            )
            jiant_evaluate.write_preds(
                eval_results_dict=test_results_dict,
                path=os.path.join(args.output_dir, "test_preds.p"),
            )
        print("Evaluation done")

    if (
        not args.keep_checkpoint_when_done
        and args.save_checkpoint_every_steps
        and os.path.exists(os.path.join(args.output_dir, "checkpoint.p"))
    ):
        os.remove(os.path.join(args.output_dir, "checkpoint.p"))

    py_io.write_file("DONE", os.path.join(args.output_dir, "done_file"))
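
# A minimal usage sketch for run_loop (hedged: the paths and field values below
# are illustrative placeholders, not values shipped with this module):
#
#     run_args = RunConfiguration(
#         jiant_task_container_config_path="/path/to/run_config.json",
#         output_dir="/path/to/output",
#         ...  # model and training parameters required by RunConfiguration
#     )
#     run_loop(args=run_args, checkpoint=None, tasks=None)
#
# Passing checkpoint=torch.load(".../checkpoint.p") restores the runner and
# metarunner state saved by CheckpointSaver; `tasks` is forwarded to
# runner.run_perturb after the training loop finishes.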


def run_simple(args: RunConfiguration, with_continue: bool = False):
    hf_config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_path)

    model_cache_path = replace_none(
        args.model_cache_path, default=os.path.join(args.exp_dir, "models")
    )

    with distributed.only_first_process(local_rank=args.local_rank):
        # === Step 1: Write task configs based on templates === #
        full_task_name_list = sorted(list(set(args.train_tasks + args.val_tasks + args.test_tasks)))
        task_config_path_dict = {}
        if args.create_config:
            task_config_path_dict = create_and_write_task_configs(
                task_name_list=full_task_name_list,
                data_dir=args.data_dir,
                task_config_base_path=os.path.join(args.data_dir, "configs"),
            )
        else:
            for task_name in full_task_name_list:
                task_config_path_dict[task_name] = os.path.join(
                    args.data_dir, "configs", f"{task_name}_config.json"
                )

        # === Step 2: Download models === #
        # (Disabled in this version.)
        # if not os.path.exists(os.path.join(model_cache_path, hf_config.model_type)):
        #     print("Downloading model")
        #     export_model.export_model(
        #         hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
        #         output_base_path=os.path.join(model_cache_path, hf_config.model_type),
        #     )

        # === Step 3: Tokenize and cache === #
        phase_task_dict = {
            "train": args.train_tasks,
            "val": args.val_tasks,
            "test": args.test_tasks,
        }
        for task_name in full_task_name_list:
            phases_to_do = []
            for phase, phase_task_list in phase_task_dict.items():
                if task_name in phase_task_list and not os.path.exists(
                    os.path.join(args.exp_dir, "cache", hf_config.model_type, task_name, phase)
                ):
                    config = read_json(task_config_path_dict[task_name])
                    if phase in config["paths"]:
                        phases_to_do.append(phase)
                    else:
                        phase_task_list.remove(task_name)
            if not phases_to_do:
                continue
            print(f"Tokenizing Task '{task_name}' for phases '{','.join(phases_to_do)}'")
            tokenize_and_cache.main(
                tokenize_and_cache.RunConfiguration(
                    task_config_path=task_config_path_dict[task_name],
                    hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
                    output_dir=os.path.join(args.exp_dir, "cache", hf_config.model_type, task_name),
                    phases=phases_to_do,
                    # TODO: Need a strategy for task-specific max_seq_length issues (issue #1176)
                    max_seq_length=args.max_seq_length,
                    smart_truncate=True,
                    do_iter=True,
                )
            )

        # === Step 4: Generate jiant_task_container_config === #
        # We'll do this with a configurator. Creating a jiant_task_config has a surprising
        # number of moving parts.
        jiant_task_container_config = configurator.SimpleAPIMultiTaskConfigurator(
            task_config_base_path=os.path.join(args.data_dir, "configs"),
            task_cache_base_path=os.path.join(args.exp_dir, "cache", hf_config.model_type),
            train_task_name_list=args.train_tasks,
            val_task_name_list=args.val_tasks,
            test_task_name_list=args.test_tasks,
            train_batch_size=args.train_batch_size,
            eval_batch_multiplier=2,
            epochs=args.num_train_epochs,
            num_gpus=torch.cuda.device_count(),
            train_examples_cap=args.train_examples_cap,
        ).create_config()
        os.makedirs(os.path.join(args.exp_dir, "run_configs"), exist_ok=True)
        jiant_task_container_config_path = os.path.join(
            args.exp_dir, "run_configs", f"{args.run_name}_config.json"
        )
        py_io.write_json(jiant_task_container_config, path=jiant_task_container_config_path)

    # === Step 5: Train/Eval! === #
    if args.model_weights_path:
        model_load_mode = "partial"
        model_weights_path = args.model_weights_path
    else:
        # From Transformers
        if any(task_name.startswith("mlm_") for task_name in full_task_name_list):
            model_load_mode = "from_transformers_with_mlm"
        else:
            model_load_mode = "from_transformers"
        model_weights_path = os.path.join(
            model_cache_path, hf_config.model_type, "model", "model.p"
        )
    run_output_dir = os.path.join(args.exp_dir, "runs", args.run_name)

    if (
        args.save_checkpoint_every_steps
        and os.path.exists(os.path.join(run_output_dir, "checkpoint.p"))
        and with_continue
    ):
        print("Resuming")
        checkpoint = torch.load(os.path.join(run_output_dir, "checkpoint.p"))
        run_args = runscript.RunConfiguration.from_dict(checkpoint["metadata"]["args"])
    else:
        print("Running from start")
        run_args = runscript.RunConfiguration(
            # === Required parameters === #
            jiant_task_container_config_path=jiant_task_container_config_path,
            output_dir=run_output_dir,
            # === Model parameters === #
            hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
            model_path=model_weights_path,
            model_config_path=os.path.join(
                model_cache_path, hf_config.model_type, "model", "config.json",
            ),
            model_load_mode=model_load_mode,
            # === Running Setup === #
            do_train=bool(args.train_tasks),
            do_val=bool(args.val_tasks),
            do_save=args.do_save,
            do_save_best=args.do_save_best,
            do_save_last=args.do_save_last,
            write_val_preds=args.write_val_preds,
            write_test_preds=args.write_test_preds,
            eval_every_steps=args.eval_every_steps,
            save_every_steps=args.save_every_steps,
            save_checkpoint_every_steps=args.save_checkpoint_every_steps,
            no_improvements_for_n_evals=args.no_improvements_for_n_evals,
            keep_checkpoint_when_done=args.keep_checkpoint_when_done,
            force_overwrite=args.force_overwrite,
            seed=args.seed,
            # === Training Learning Parameters === #
            learning_rate=args.learning_rate,
            adam_epsilon=args.adam_epsilon,
            max_grad_norm=args.max_grad_norm,
            optimizer_type=args.optimizer_type,
            # === Specialized config === #
            no_cuda=args.no_cuda,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            local_rank=args.local_rank,
            server_ip=args.server_ip,
            server_port=args.server_port,
        )
        checkpoint = None

    runscript.run_loop(args=run_args, checkpoint=checkpoint)
    py_io.write_file(args.to_json(), os.path.join(run_output_dir, "simple_run_config.json"))
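
# A minimal usage sketch for the simple API (hedged: task names and paths are
# illustrative; the fields shown mirror the attributes run_simple reads above,
# e.g. run_name, exp_dir, data_dir, train_tasks, val_tasks, test_tasks):
#
#     args = RunConfiguration(
#         run_name="rte_run",
#         exp_dir="/path/to/exp",
#         data_dir="/path/to/exp/tasks",
#         hf_pretrained_model_name_or_path="roberta-base",
#         train_tasks=["rte"],
#         val_tasks=["rte"],
#         test_tasks=[],
#         train_batch_size=16,
#         num_train_epochs=3,
#     )
#     # Resumes from runs/<run_name>/checkpoint.p when one exists and
#     # save_checkpoint_every_steps is set; otherwise starts fresh.
#     run_simple(args, with_continue=True)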


def write_done(output_dir):
    py_io.write_file("DONE", os.path.join(output_dir, "DONE"))