def main(cl_arguments): """ Train a model for multitask-training.""" cl_args = handle_arguments(cl_arguments) args = config.params_from_file(cl_args.config_file, cl_args.overrides) train_type = args.get('train_type', "SamplingMultiTaskTrainer") if train_type != "SamplingMultiTaskTrainer": print("\n\n\n", train_type, "\n\n\n") # Check for deprecated arg names check_arg_name(args) args, seed = initial_setup(args, cl_args) # Load tasks log.info("Loading tasks...") start_time = time.time() pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args) tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name) log.info("\tFinished loading tasks in %.3fs", time.time() - start_time) log.info("\t Tasks: {}".format([task.name for task in tasks])) # Build model log.info("Building model...") start_time = time.time() model = build_model(args, vocab, word_embs, tasks) log.info("Finished building model in %.3fs", time.time() - start_time) # Start Tensorboard if requested if cl_args.tensorboard: tb_logdir = os.path.join(args.run_dir, "tensorboard") _run_background_tensorboard(tb_logdir, cl_args.tensorboard_port) check_configurations(args, pretrain_tasks, target_tasks) if args.do_pretrain: # Train on pretrain tasks log.info("Training...") stop_metric = pretrain_tasks[0].val_metric if len( pretrain_tasks) == 1 else "macro_avg" should_decrease = (pretrain_tasks[0].val_metric_decreases if len(pretrain_tasks) == 1 else False) trainer, _, opt_params, schd_params = build_trainer( args, [], model, args.run_dir, should_decrease, phase="pretrain", train_type=train_type) to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad] _ = trainer.train( pretrain_tasks, stop_metric, args.batch_size, args.weighting_method, args.scaling_method, to_train, opt_params, schd_params, args.load_model, phase="pretrain", ) # For checkpointing logic if not args.do_target_task_training: strict = True else: strict = False if args.do_target_task_training: # Train on target tasks pre_target_train_path = setup_target_task_training( args, target_tasks, model, strict) target_tasks_to_train = copy.deepcopy(target_tasks) # Check for previous target train checkpoints task_to_restore, _, _ = check_for_previous_checkpoints( args.run_dir, target_tasks_to_train, "target_train", args.load_model) if task_to_restore is not None: # If there is a task to restore from, target train only on target tasks # including and following that task. last_task_index = [task.name for task in target_tasks_to_train ].index(task_to_restore) target_tasks_to_train = target_tasks_to_train[last_task_index:] for task in target_tasks_to_train: # Skip tasks that should not be trained on. if task.eval_only_task: continue params_to_train = load_model_for_target_train_run( args, pre_target_train_path, model, strict, task) trainer, _, opt_params, schd_params = build_trainer( args, [task.name], model, args.run_dir, task.val_metric_decreases, phase="target_train", train_type=train_type) _ = trainer.train( tasks=[task], stop_metric=task.val_metric, batch_size=args.batch_size, weighting_method=args.weighting_method, scaling_method=args.scaling_method, train_params=params_to_train, optimizer_params=opt_params, scheduler_params=schd_params, load_model=(task.name == task_to_restore), phase="target_train", ) if args.do_full_eval: log.info("Evaluating...") splits_to_write = evaluate.parse_write_preds_arg(args.write_preds) # Evaluate on target_tasks. for task in target_tasks: # Find the task-specific best checkpoint to evaluate on. 
task_to_use = model._get_task_params(task.name).get( "use_classifier", task.name) ckpt_path = get_best_checkpoint_path(args, "eval", task_to_use) assert ckpt_path is not None load_model_state(model, ckpt_path, args.cuda, skip_task_models=[], strict=strict) evaluate_and_write(args, model, [task], splits_to_write) if args.delete_checkpoints_when_done and not args.keep_all_checkpoints: log.info("Deleting all checkpoints.") delete_all_checkpoints(args.run_dir) log.info("Done!")
def main(cl_arguments): """ Train a model for multitask-training.""" cl_args = handle_arguments(cl_arguments) args = config.params_from_file(cl_args.config_file, cl_args.overrides) # Check for deprecated arg names check_arg_name(args) args, seed = initial_setup(args, cl_args) #Store the run description, if any if FLAGS.description: with open(Path(args.run_dir, 'description.txt'), 'w') as f: f.write(FLAGS.description) # Load tasks log.info("Loading tasks...") start_time = time.time() # cuda_device = parse_cuda_list_arg(args.cuda) cuda_device = FLAGS.device_idxs pretrain_tasks, target_tasks, vocab, word_embs = build_tasks( args, cuda_device) tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name) log.info("\tFinished loading tasks in %.3fs", time.time() - start_time) log.info("\t Tasks: {}".format([task.name for task in tasks])) # Build model log.info("Building model...") start_time = time.time() model = build_model(args, vocab, word_embs, tasks, cuda_device) log.info("Finished building model in %.3fs", time.time() - start_time) # Start Tensorboard if requested if cl_args.tensorboard: tb_logdir = os.path.join(args.run_dir, "tensorboard") _run_background_tensorboard(tb_logdir, cl_args.tensorboard_port) check_configurations(args, pretrain_tasks, target_tasks) if args.do_pretrain: # Train on pretrain tasks log.info("Training...") stop_metric = pretrain_tasks[0].val_metric if len( pretrain_tasks) == 1 else "macro_avg" should_decrease = (pretrain_tasks[0].val_metric_decreases if len(pretrain_tasks) == 1 else False) trainer, _, opt_params, schd_params = build_trainer(args, cuda_device, [], model, args.run_dir, should_decrease, phase="pretrain") to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad] _ = trainer.train( pretrain_tasks, stop_metric, args.batch_size, args.weighting_method, args.scaling_method, to_train, opt_params, schd_params, args.load_model, phase="pretrain", ) # For checkpointing logic if not args.do_target_task_training: strict = True else: strict = False if args.do_target_task_training: # Train on target tasks pre_target_train_path = setup_target_task_training( args, target_tasks, model, strict) target_tasks_to_train = copy.deepcopy(target_tasks) # Check for previous target train checkpoints task_to_restore, _, _ = check_for_previous_checkpoints( args.run_dir, target_tasks_to_train, "target_train", args.load_model) if task_to_restore is not None: # If there is a task to restore from, target train only on target tasks # including and following that task. last_task_index = [task.name for task in target_tasks_to_train ].index(task_to_restore) target_tasks_to_train = target_tasks_to_train[last_task_index:] for task in target_tasks_to_train: # Skip tasks that should not be trained on. 
if task.eval_only_task: continue params_to_train = load_model_for_target_train_run( args, pre_target_train_path, model, strict, task, cuda_device) trainer, _, opt_params, schd_params = build_trainer( args, cuda_device, [task.name], model, args.run_dir, task.val_metric_decreases, phase="target_train", ) _ = trainer.train( tasks=[task], stop_metric=task.val_metric, batch_size=args.batch_size, weighting_method=args.weighting_method, scaling_method=args.scaling_method, train_params=params_to_train, optimizer_params=opt_params, scheduler_params=schd_params, load_model=(task.name == task_to_restore), phase="target_train", ) if args.do_full_eval: log.info("Evaluating...") splits_to_write = evaluate.parse_write_preds_arg(args.write_preds) results_dict = {'run_name': [args.run_name]} # Evaluate on target_tasks. for task in target_tasks: # Find the task-specific best checkpoint to evaluate on. task_params = get_model_attribute(model, "_get_task_params", cuda_device) task_to_use = task_params(task.name).get("use_classifier", task.name) ckpt_path = get_best_checkpoint_path(args, "eval", task_to_use) assert ckpt_path is not None load_model_state(model, ckpt_path, cuda_device, skip_task_models=[], strict=strict) current_tasks_val_results = evaluate_and_write( args, model, [task], splits_to_write, cuda_device) results_dict = {**results_dict, **current_tasks_val_results} tabular_results_csv = os.path.join(SMALL_SHARED_SERVER_DIR, "tabular_results.csv") existing_results_df = pd.read_csv(tabular_results_csv, index_col=False) new_results_df = pd.DataFrame.from_dict(results_dict) updated_results_df = new_results_df.append(existing_results_df, sort=False) with open(tabular_results_csv, 'w') as f: log.info(f"Prepending results to {tabular_results_csv}.") updated_results_df.to_csv(f, header=True, index=False) if args.delete_checkpoints_when_done and not args.keep_all_checkpoints: log.info("Deleting all checkpoints.") delete_all_checkpoints(args.run_dir) log.info("Done!")
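

# Usage sketch (an assumption, not part of the original module): how an entry point of
# this shape is typically invoked. main() receives the raw command-line arguments, which
# handle_arguments() parses into cl_args (e.g. config_file, overrides, tensorboard).
# The local `sys` import keeps the sketch self-contained.
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])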