def test_replace_none():
    """Check that replace_none keeps a non-None value and substitutes the default for None."""
    # A non-None value must be returned unchanged.
    kept = py_logic.replace_none(1, default=2)
    assert kept == 1

    # None must be replaced by the supplied default.
    substituted = py_logic.replace_none(None, default=2)
    assert substituted == 2
def run_simple(args: RunConfiguration, with_continue: bool = False):
    """Run the end-to-end "simple" pipeline described by ``args``.

    The steps mirror the section comments in the body:

    1. Write (``args.create_config``) or locate a task config for every task
       named in ``args.train_tasks`` / ``args.val_tasks`` / ``args.test_tasks``.
    2. Model download (currently disabled; see the commented-out block).
    3. Tokenize and cache every requested task/phase pair whose cache
       directory does not exist yet.
    4. Build the jiant task-container config and write it to
       ``<exp_dir>/run_configs/<run_name>_config.json``.
    5. Call ``runscript.run_loop`` and then write ``args`` as JSON to
       ``<exp_dir>/runs/<run_name>/simple_run_config.json``.

    Side effect: in step 3, if a requested task/phase has no cache directory
    and the task config has no entry for that phase under ``"paths"``, the
    task name is removed *in place* from the corresponding ``args`` task list.
    Steps 4 and 5 read those (possibly shortened) lists.

    Args:
        args: Simple-API run configuration.
        with_continue: When true, and ``args.save_checkpoint_every_steps`` is
            set, and ``checkpoint.p`` exists in the run output directory, the
            run configuration stored in that checkpoint is used and the
            checkpoint is handed to ``run_loop``. Otherwise a fresh run
            configuration is built from ``args``.
    """
    hf_config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_path)

    model_cache_path = replace_none(
        args.model_cache_path, default=os.path.join(args.exp_dir, "models")
    )
    # Directories that several steps below refer to.
    task_config_base_path = os.path.join(args.data_dir, "configs")
    task_cache_base_path = os.path.join(args.exp_dir, "cache", hf_config.model_type)
    model_dir = os.path.join(model_cache_path, hf_config.model_type, "model")

    with distributed.only_first_process(local_rank=args.local_rank):
        # === Step 1: Write task configs based on templates === #
        full_task_name_list = sorted(set(args.train_tasks + args.val_tasks + args.test_tasks))
        if args.create_config:
            task_config_path_dict = create_and_write_task_configs(
                task_name_list=full_task_name_list,
                data_dir=args.data_dir,
                task_config_base_path=task_config_base_path,
            )
        else:
            # Configs are expected to exist already at the conventional location.
            task_config_path_dict = {
                task_name: os.path.join(task_config_base_path, f"{task_name}_config.json")
                for task_name in full_task_name_list
            }

        # === Step 2: Download models === #
        # NOTE: this step is disabled; the original code is kept for reference.
        # if not os.path.exists(os.path.join(model_cache_path, hf_config.model_type)):
        #     print("Downloading model")
        #     export_model.export_model(
        #         hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
        #         output_base_path=os.path.join(model_cache_path, hf_config.model_type),
        #     )

        # === Step 3: Tokenize and cache === #
        # The values are the very lists held by ``args``: removing a task below
        # also changes what steps 4 and 5 see.
        phase_task_dict = {
            "train": args.train_tasks,
            "val": args.val_tasks,
            "test": args.test_tasks,
        }
        for task_name in full_task_name_list:
            task_config = None  # read lazily, at most once per task
            phases_to_do = []
            for phase, phase_task_list in phase_task_dict.items():
                if task_name not in phase_task_list:
                    continue
                if os.path.exists(os.path.join(task_cache_base_path, task_name, phase)):
                    # Already tokenized for this phase.
                    continue
                if task_config is None:
                    task_config = read_json(task_config_path_dict[task_name])
                if phase in task_config["paths"]:
                    phases_to_do.append(phase)
                else:
                    # The task has no data for this phase: drop it from that phase.
                    phase_task_list.remove(task_name)
            if not phases_to_do:
                continue
            print(f"Tokenizing Task '{task_name}' for phases '{','.join(phases_to_do)}'")
            tokenize_and_cache.main(
                tokenize_and_cache.RunConfiguration(
                    task_config_path=task_config_path_dict[task_name],
                    hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
                    output_dir=os.path.join(task_cache_base_path, task_name),
                    phases=phases_to_do,
                    # TODO: Need a strategy for task-specific max_seq_length issues (issue #1176)
                    max_seq_length=args.max_seq_length,
                    smart_truncate=True,
                    do_iter=True,
                )
            )

    # === Step 4: Generate jiant_task_container_config === #
    # We'll do this with a configurator. Creating a jiant_task_config has a surprising
    # number of moving parts.
    jiant_task_container_config = configurator.SimpleAPIMultiTaskConfigurator(
        task_config_base_path=task_config_base_path,
        task_cache_base_path=task_cache_base_path,
        train_task_name_list=args.train_tasks,
        val_task_name_list=args.val_tasks,
        test_task_name_list=args.test_tasks,
        train_batch_size=args.train_batch_size,
        eval_batch_multiplier=2,
        epochs=args.num_train_epochs,
        num_gpus=torch.cuda.device_count(),
        train_examples_cap=args.train_examples_cap,
    ).create_config()
    run_configs_dir = os.path.join(args.exp_dir, "run_configs")
    os.makedirs(run_configs_dir, exist_ok=True)
    jiant_task_container_config_path = os.path.join(
        run_configs_dir, f"{args.run_name}_config.json"
    )
    py_io.write_json(jiant_task_container_config, path=jiant_task_container_config_path)

    # === Step 5: Train/Eval! === #
    if args.model_weights_path:
        model_load_mode = "partial"
        model_weights_path = args.model_weights_path
    else:
        # From Transformers
        if any(task_name.startswith("mlm_") for task_name in full_task_name_list):
            model_load_mode = "from_transformers_with_mlm"
        else:
            model_load_mode = "from_transformers"
        model_weights_path = os.path.join(model_dir, "model.p")
    run_output_dir = os.path.join(args.exp_dir, "runs", args.run_name)
    checkpoint_path = os.path.join(run_output_dir, "checkpoint.p")

    if (
        args.save_checkpoint_every_steps
        and os.path.exists(checkpoint_path)
        and with_continue
    ):
        print("Resuming")
        # NOTE(review): torch.load unpickles the file; only resume from run
        # directories whose contents are trusted.
        checkpoint = torch.load(checkpoint_path)
        run_args = runscript.RunConfiguration.from_dict(checkpoint["metadata"]["args"])
    else:
        print("Running from start")
        run_args = runscript.RunConfiguration(
            # === Required parameters === #
            jiant_task_container_config_path=jiant_task_container_config_path,
            output_dir=run_output_dir,
            # === Model parameters === #
            hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
            model_path=model_weights_path,
            model_config_path=os.path.join(model_dir, "config.json"),
            model_load_mode=model_load_mode,
            # === Running Setup === #
            do_train=bool(args.train_tasks),
            do_val=bool(args.val_tasks),
            do_save=args.do_save,
            do_save_best=args.do_save_best,
            do_save_last=args.do_save_last,
            write_val_preds=args.write_val_preds,
            write_test_preds=args.write_test_preds,
            eval_every_steps=args.eval_every_steps,
            save_every_steps=args.save_every_steps,
            save_checkpoint_every_steps=args.save_checkpoint_every_steps,
            no_improvements_for_n_evals=args.no_improvements_for_n_evals,
            keep_checkpoint_when_done=args.keep_checkpoint_when_done,
            force_overwrite=args.force_overwrite,
            seed=args.seed,
            # === Training Learning Parameters === #
            learning_rate=args.learning_rate,
            adam_epsilon=args.adam_epsilon,
            max_grad_norm=args.max_grad_norm,
            optimizer_type=args.optimizer_type,
            # === Specialized config === #
            no_cuda=args.no_cuda,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            local_rank=args.local_rank,
            server_ip=args.server_ip,
            server_port=args.server_port,
        )
        checkpoint = None

    runscript.run_loop(args=run_args, checkpoint=checkpoint)
    py_io.write_file(args.to_json(), os.path.join(run_output_dir, "simple_run_config.json"))
def dry_run(args: RunConfiguration):
    """Print, without executing anything, shell commands for each pipeline step.

    One command (or group of commands) is printed per step: writing task
    configs, exporting the model, tokenizing/caching each task, generating
    the task-container config, and finally launching ``runscript.py run``.

    For the final command, every attribute of ``runscript.RunConfiguration``
    that ``args`` also has, and whose value differs from the attribute's
    default, is appended as ``--name value`` (or just ``--name`` when the
    attribute's argparse metadata declares ``action="store_true"``).

    NOTE(review): this function reads ``args.model_type`` and places task
    configs under ``<exp_dir>/task_configs``, while ``run_simple`` in this file
    uses ``args.hf_pretrained_model_name_or_path`` and ``<data_dir>/configs``.
    Confirm that the printed commands are still in sync with the real scripts.

    Args:
        args: Simple-API run configuration.
    """
    model_cache_path = replace_none(
        args.model_cache_path, default=os.path.join(args.exp_dir, "models")
    )

    print("\n# === Step 1: Write task configs based on templates === #")
    full_task_name_list = sorted(set(args.train_tasks + args.val_tasks + args.test_tasks))
    task_config_base_path = os.path.join(args.exp_dir, "task_configs")
    for task_name in full_task_name_list:
        task_config_path = os.path.join(task_config_base_path, f"{task_name}_config.json")
        print(
            f"""
python jiant/proj/main/write_task_configs.py \\
    --task_name {task_name} \\
    --task_data_dir {os.path.join(args.data_dir, task_name)} \\
    --task_config_path {task_config_path}
""".strip()
        )

    print("\n# === Step 2: Download models === #")
    print(
        f"""
python jiant/proj/main/export_model.py \\
    --model_type {args.model_type} \\
    --output_base_path {os.path.join(model_cache_path, args.model_type)}
""".strip()
    )

    print("\n# === Step 3: Tokenize and cache === #")
    phase_task_dict = {
        "train": args.train_tasks,
        "val": args.val_tasks,
        "test": args.test_tasks,
    }
    task_cache_base_path = os.path.join(args.exp_dir, "cache")
    for task_name in full_task_name_list:
        # Phases in which this task was requested, in train/val/test order.
        phases_to_do = [
            phase
            for phase, phase_task_list in phase_task_dict.items()
            if task_name in phase_task_list
        ]
        task_config_path = os.path.join(task_config_base_path, f"{task_name}_config.json")
        tokenizer_path = os.path.join(model_cache_path, args.model_type, "tokenizer")
        print(
            f"""
python jiant/proj/main/tokenize_and_cache.py \\
    --task_config_path {task_config_path} \\
    --model_type {args.model_type} \\
    --model_tokenizer_path {tokenizer_path} \\
    --output_dir {os.path.join(task_cache_base_path, task_name)} \\
    --phases {",".join(phases_to_do)} \\
    --max_seq_length {args.max_seq_length} \\
    --smart_truncate \\
    --do_iter
""".strip()
        )

    print("\n# === Step 4: Generate jiant_task_container_config === #")
    run_config_path = os.path.join(args.exp_dir, "run_configs", f"{args.run_name}_config.json")
    s = f"""
python jiant/proj/main/scripts/configurator.py \\
    SimpleAPIMultiTaskConfigurator \\
    {run_config_path} \\
    --task_config_base_path {task_config_base_path} \\
    --task_cache_base_path {task_cache_base_path} \\
    --train_task_name_list {",".join(args.train_tasks)} \\
    --val_task_name_list {",".join(args.val_tasks)} \\
    --test_task_name_list {",".join(args.test_tasks)} \\
    --train_batch_size {args.train_batch_size} \\
    --eval_batch_multiplier 2 \\
    --epochs {args.num_train_epochs} \\
    --num_gpus {torch.cuda.device_count()}
""".strip()
    if args.train_examples_cap:
        s += f" \\\n --train_examples_cap {args.train_examples_cap}"
    print(s.strip())

    print("\n# === Step 5: Train/Eval! === #")
    if args.model_weights_path:
        model_load_mode = "partial"
        model_weights_path = args.model_weights_path
    else:
        # From Transformers
        if any(task_name.startswith("mlm_") for task_name in full_task_name_list):
            model_load_mode = "from_transformers_with_mlm"
        else:
            model_load_mode = "from_transformers"
        model_weights_path = os.path.join(
            model_cache_path, args.model_type, "model", f"{args.model_type}.p"
        )
    model_config_path = os.path.join(
        model_cache_path, args.model_type, "model", f"{args.model_type}.json"
    )
    model_tokenizer_path = os.path.join(model_cache_path, args.model_type, "tokenizer")
    # Each option and its value are kept on one physical line so that no stray
    # backslash can end up between them in the printed command.
    s = f"""
python jiant/proj/main/runscript.py \\
    run \\
    --jiant_task_container_config_path {run_config_path} \\
    --output_dir {os.path.join(args.exp_dir, "runs", args.run_name)} \\
    --model_type {args.model_type} \\
    --model_path {model_weights_path} \\
    --model_config_path {model_config_path} \\
    --model_tokenizer_path {model_tokenizer_path} \\
    --model_load_mode {model_load_mode}
""".strip()
    if args.train_tasks:
        s += " \\\n --do_train"
    if args.val_tasks:
        s += " \\\n --do_val"

    # Options already spelled out in the command template above.
    covered_attrs = {
        "jiant_task_container_config_path",
        "output_dir",
        "model_type",
        "model_path",
        "model_config_path",
        "model_tokenizer_path",
        "model_load_mode",
    }
    for attr in runscript.RunConfiguration.__attrs_attrs__:
        if attr.name in covered_attrs or not hasattr(args, attr.name):
            continue
        args_attr = getattr(args, attr.name)
        # Values equal to the default need not be passed explicitly; this also
        # covers the case where both the default and the value are None.
        if attr.default == args_attr:
            continue
        argparse_kwargs = attr.metadata.get("argparse_kwargs", {})
        if argparse_kwargs.get("action") == "store_true":
            # Boolean switch: the flag alone, no value.
            s += f" \\\n --{attr.name}"
        else:
            s += f" \\\n --{attr.name} {args_attr}"
    print(s.strip())