Beispiel #1
0
def main(args):
    """Run a single-task experiment configured by ``args``.

    Builds the task, the model, the optimizer and a ``SimpleTaskRunner``, then
    performs whichever of training, checkpoint saving, validation and test
    prediction are switched on through ``args.do_train`` / ``args.do_save`` /
    ``args.do_val`` / ``args.do_test``. All of this happens inside the log
    writer's ``log_context``.

    Args:
        args: configuration namespace supplying the task/model paths, the
            optimisation hyperparameters and the ``do_*`` switches read below.
    """
    init_out = initialization.quick_init(args=args, verbose=True)
    log_writer = init_out.log_writer
    with log_writer.log_context():
        task = tasks.create_task_from_config_path(
            config_path=args.task_config_path, verbose=True)

        # Model construction, weight loading and device placement all happen
        # inside the only_first_process context.
        with distributed.only_first_process(local_rank=args.local_rank):
            class_spec = model_resolution.resolve_model_setup_classes(
                model_type=args.model_type, task_type=task.TASK_TYPE)
            model_wrapper = model_setup.simple_model_setup(
                model_type=args.model_type,
                model_class_spec=class_spec,
                config_path=args.model_config_path,
                tokenizer_path=args.model_tokenizer_path,
                task=task,
            )
            model_setup.simple_load_model_path(
                model=model_wrapper.model,
                model_load_mode=args.model_load_mode,
                model_path=args.model_path,
                verbose=True,
            )
            model_wrapper.model.to(init_out.device)

        # Training data (optionally subsampled) determines the schedule length.
        train_examples, _ = train_setup.maybe_subsample_train(
            train_examples=task.get_train_examples(),
            train_examples_number=args.train_examples_number,
            train_examples_fraction=args.train_examples_fraction,
        )
        schedule = train_setup.get_train_schedule(
            num_train_examples=len(train_examples),
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_gpu_train_batch_size=args.train_batch_size,
            n_gpu=init_out.n_gpu,
        )
        log_writer.write_entry(
            "text", f"t_total: {schedule.t_total}", do_print=True)

        criterion = train_setup.resolve_loss_function(task_type=task.TASK_TYPE)
        optim_sched = model_setup.create_optimizer(
            model=model_wrapper.model,
            learning_rate=args.learning_rate,
            t_total=schedule.t_total,
            warmup_steps=args.warmup_steps,
            warmup_proportion=args.warmup_proportion,
            optimizer_type=args.optimizer_type,
            verbose=True,
        )
        model_setup.special_model_setup(
            model_wrapper=model_wrapper,
            optimizer_scheduler=optim_sched,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            n_gpu=init_out.n_gpu,
            local_rank=args.local_rank,
        )

        feat_spec = model_resolution.build_featurization_spec(
            model_type=args.model_type, max_seq_length=args.max_seq_length)
        runner_params = simple_runner.RunnerParameters(
            feat_spec=feat_spec,
            local_rank=args.local_rank,
            n_gpu=init_out.n_gpu,
            fp16=args.fp16,
            learning_rate=args.learning_rate,
            eval_batch_size=args.eval_batch_size,
            max_grad_norm=args.max_grad_norm,
        )
        runner = simple_runner.SimpleTaskRunner(
            task=task,
            model_wrapper=model_wrapper,
            optimizer_scheduler=optim_sched,
            loss_criterion=criterion,
            device=init_out.device,
            rparams=runner_params,
            train_schedule=schedule,
            log_writer=log_writer,
        )

        if args.do_train:
            # Quick and dirty: only the first `partial_eval_number` validation
            # examples are handed to the in-training evaluation.
            partial_val = task.get_val_examples()[:args.partial_eval_number]
            trainer = metarunner.MetaRunner(
                runner=runner,
                train_examples=train_examples,
                val_examples=partial_val,
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=log_writer,
            )
            trainer.train_val_save_every()

        if args.do_save:
            weights = model_wrapper.model.state_dict()
            torch.save(weights, os.path.join(args.output_dir, "model.p"))

        if args.do_val:
            # Final validation uses the full validation set.
            val_results = runner.run_val(task.get_val_examples())
            evaluate.write_val_results(
                results=val_results,
                output_dir=args.output_dir,
                verbose=True,
            )

        if args.do_test:
            test_logits = runner.run_test(task.get_test_examples())
            evaluate.write_preds(
                logits=test_logits,
                output_path=os.path.join(args.output_dir, "test_preds.csv"),
            )
Beispiel #2
0
def main(args):
    """Run a mean-teacher experiment configured by ``args``.

    Loads the supervised and unsupervised data, builds the student model and
    a teacher created from it, wires both into a ``MeanTeacherRunner`` and then
    performs the phases enabled by ``args.do_train`` / ``args.do_save`` /
    ``args.do_val`` / ``args.do_test`` inside the log writer's ``log_context``.

    Args:
        args: configuration namespace supplying the paths, hyperparameters,
            mean-teacher settings and ``do_*`` switches read below.
    """
    init_out = initialization.quick_init(args=args, verbose=True)
    log_writer = init_out.log_writer

    with distributed.only_first_process(local_rank=args.local_rank):
        task, task_data = unsup_load_data.load_sup_and_unsup_data(
            task_config_path=args.task_config_path,
            unsup_task_config_path=args.unsup_task_config_path,
        )

        # Student model: build, load weights, move to the device.
        class_spec = model_resolution.resolve_model_setup_classes(
            model_type=args.model_type, task_type=task.TASK_TYPE)
        model_wrapper = model_setup.simple_model_setup(
            model_type=args.model_type,
            model_class_spec=class_spec,
            config_path=args.model_config_path,
            tokenizer_path=args.model_tokenizer_path,
            task=task,
        )
        model_setup.simple_load_model_path(
            model=model_wrapper.model,
            model_load_mode=args.model_load_mode,
            model_path=args.model_path,
        )
        model_wrapper.model.to(init_out.device)

        # Teacher not set up for special setup (e.g. fp16) #todo
        teacher_wrapper = mean_teacher_runner.create_teacher(model_wrapper)

    # The schedule length is derived from the supervised training split only.
    schedule = train_setup.get_train_schedule(
        num_train_examples=len(task_data["sup"]["train"]),
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_gpu_train_batch_size=args.train_batch_size,
        n_gpu=init_out.n_gpu,
    )
    print("t_total", schedule.t_total)

    criterion = train_setup.resolve_loss_function(task_type=task.TASK_TYPE)
    optim_sched = model_setup.create_optimizer(
        model=model_wrapper.model,
        learning_rate=args.learning_rate,
        t_total=schedule.t_total,
        warmup_steps=args.warmup_steps,
        warmup_proportion=args.warmup_proportion,
        optimizer_type=args.optimizer_type,
        verbose=True,
    )
    model_setup.special_model_setup(
        model_wrapper=model_wrapper,
        optimizer_scheduler=optim_sched,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=init_out.n_gpu,
        local_rank=args.local_rank,
    )

    feat_spec = model_resolution.build_featurization_spec(
        model_type=args.model_type, max_seq_length=args.max_seq_length)
    runner_params = simple_runner.RunnerParameters(
        feat_spec=feat_spec,
        local_rank=args.local_rank,
        n_gpu=init_out.n_gpu,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        eval_batch_size=args.eval_batch_size,
        max_grad_norm=args.max_grad_norm,
    )
    # The ramp-up length is given as a fraction of the total number of steps;
    # unsupervised data is used whenever the ratio is non-zero.
    mean_teacher_params = mean_teacher_runner.MeanTeacherParameters(
        alpha=args.mt_alpha,
        consistency_type=args.consistency_type,
        consistency_weight=args.consistency_weight,
        consistency_ramp_up_steps=int(
            args.consistency_ramp_up_fraction * schedule.t_total),
        use_unsup=args.unsup_ratio != 0,
        unsup_ratio=args.unsup_ratio,
    )
    runner = mean_teacher_runner.MeanTeacherRunner(
        task=task,
        model_wrapper=model_wrapper,
        teacher_model_wrapper=teacher_wrapper,
        optimizer_scheduler=optim_sched,
        loss_criterion=criterion,
        device=init_out.device,
        rparams=runner_params,
        mt_params=mean_teacher_params,
        train_schedule=schedule,
        log_writer=log_writer,
    )

    with log_writer.log_context():
        if args.do_train:
            # Quick and dirty: in-training evaluation only sees the first
            # `partial_eval_number` validation examples.
            partial_val = task.get_val_examples()[:args.partial_eval_number]
            mean_teacher_runner.train_val_save_every(
                runner=runner,
                task_data=task_data,
                val_examples=partial_val,
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=log_writer,
            )

        if args.do_save:
            weights = model_wrapper.model.state_dict()
            torch.save(weights, os.path.join(args.output_dir, "model.p"))

        if args.do_val:
            # Final validation uses the full validation set.
            val_results = runner.run_val(task.get_val_examples())
            evaluate.write_val_results(
                results=val_results,
                output_dir=args.output_dir,
                verbose=True,
            )

        if args.do_test:
            test_logits = runner.run_test(task.get_test_examples())
            evaluate.write_preds(
                logits=test_logits,
                output_path=os.path.join(args.output_dir, "test_preds.csv"),
            )
Beispiel #3
0
 def create(self, num_train_examples, log_writer=zlogv1.PRINT_LOGGER):
     """Build a ``SimpleTaskRunner`` from the settings stored on ``self``.

     Args:
         num_train_examples: number of training examples, used to compute
             the training schedule.
         log_writer: log writer handed to the runner (defaults to
             ``zlogv1.PRINT_LOGGER``).

     Returns:
         The constructed ``simple_runner.SimpleTaskRunner``.
     """
     schedule = train_setup.get_train_schedule(
         num_train_examples=num_train_examples,
         max_steps=self.max_steps,
         num_train_epochs=self.num_train_epochs,
         gradient_accumulation_steps=self.gradient_accumulation_steps,
         per_gpu_train_batch_size=self.train_batch_size,
         n_gpu=self.n_gpu,
     )

     # Model construction, weight loading and device placement happen inside
     # the only_first_process context.
     with distributed.only_first_process(local_rank=self.local_rank):
         class_spec = model_resolution.resolve_model_setup_classes(
             model_type=self.model_type, task_type=self.task.TASK_TYPE)
         wrapper = model_setup.simple_model_setup(
             model_type=self.model_type,
             model_class_spec=class_spec,
             config_path=self.model_config_path,
             tokenizer_path=self.model_tokenizer_path,
             task=self.task,
         )
         model_setup.simple_load_model_path(
             model=wrapper.model,
             model_load_mode=self.model_load_mode,
             model_path=self.model_path,
         )
         wrapper.model.to(self.device)

     optim_sched = model_setup.create_optimizer(
         model=wrapper.model,
         learning_rate=self.learning_rate,
         t_total=schedule.t_total,
         warmup_steps=self.warmup_steps,
         warmup_proportion=self.warmup_proportion,
         optimizer_type=self.optimizer_type,
         verbose=self.verbose,
     )
     model_setup.special_model_setup(
         model_wrapper=wrapper,
         optimizer_scheduler=optim_sched,
         fp16=self.fp16,
         fp16_opt_level=self.fp16_opt_level,
         n_gpu=self.n_gpu,
         local_rank=self.local_rank,
     )

     criterion = train_setup.resolve_loss_function(
         task_type=self.task.TASK_TYPE)
     feat_spec = model_resolution.build_featurization_spec(
         model_type=self.model_type, max_seq_length=self.max_seq_length)
     runner_params = simple_runner.RunnerParameters(
         feat_spec=feat_spec,
         local_rank=self.local_rank,
         n_gpu=self.n_gpu,
         fp16=self.fp16,
         learning_rate=self.learning_rate,
         eval_batch_size=self.eval_batch_size,
         max_grad_norm=self.max_grad_norm,
     )
     return simple_runner.SimpleTaskRunner(
         task=self.task,
         model_wrapper=wrapper,
         optimizer_scheduler=optim_sched,
         loss_criterion=criterion,
         device=self.device,
         rparams=runner_params,
         train_schedule=schedule,
         log_writer=log_writer,
     )
def main(args):
    """Run a multi-adapter fine-tuning experiment configured by ``args``.

    Builds the base model, injects multi-adapter layers initialised from the
    weights at ``args.adapter_weights_path``, creates an optimizer over the
    tunable parameters only, and then performs the phases enabled by
    ``args.do_train`` / ``args.do_save`` / ``args.do_val`` / ``args.do_test``.

    The four phases are independent of each other: saving, validation and
    test prediction run whenever their own flag is set, whether or not
    training was requested (matching the other run scripts).

    Args:
        args: configuration namespace supplying the task/model paths, the
            adapter options, the optimisation hyperparameters and the
            ``do_*`` switches read below.
    """
    quick_init_out = initialization.quick_init(args=args, verbose=False)
    task = tasks.create_task_from_config_path(
        config_path=args.task_config_path,
        verbose=True,
    )
    with distributed.only_first_process(local_rank=args.local_rank):
        model_class_spec = model_resolution.resolve_model_setup_classes(
            model_type=args.model_type,
            task_type=task.TASK_TYPE,
        )
        model_wrapper = model_setup.simple_model_setup(
            model_type=args.model_type,
            model_class_spec=model_class_spec,
            config_path=args.model_config_path,
            tokenizer_path=args.model_tokenizer_path,
            task=task,
        )
    model_setup.simple_load_model_path(
        model=model_wrapper.model,
        model_load_mode=args.model_load_mode,
        model_path=args.model_path,
    )

    # Load the pre-trained adapter weights, drop the excluded adapters
    # (comma-separated names in args.adapter_exclude) and isolate the
    # per-adapter weights for this model type.
    adapter_weights_dict = multi_adapters.load_adapter_weights_dict_path(
        args.adapter_weights_path)
    multi_adapters.exclude_adapters(
        adapter_weights_dict, exclude_list=args.adapter_exclude.split(","))
    adapter_weights_dict = multi_adapters.isolate_adapter_weights_dict(
        adapter_weights_dict=adapter_weights_dict,
        model_type=args.model_type,
    )

    # Insert the multi-adapter layers into the model and fill them with the
    # loaded weights before moving the model to the target device.
    sub_module_name_list = list(adapter_weights_dict.keys())
    modified_layers = multi_adapters.add_multi_adapters(
        model=model_wrapper.model,
        sub_module_name_list=sub_module_name_list,
        adapter_config=adapters.AdapterConfig(),
        include_base=args.adapter_include_base,
        include_flex=args.adapter_include_flex,
        num_weight_sets=args.adapter_num_weight_sets,
        use_optimized=args.adapter_use_optimized,
    )
    multi_adapters.load_multi_adapter_weights(
        model=model_wrapper.model,
        modified_layers=modified_layers,
        adapter_weights_dict=adapter_weights_dict,
    )
    model_wrapper.model.to(quick_init_out.device)
    tunable_parameters = multi_adapters.get_tunable_parameters(
        model=model_wrapper.model,
        modified_layers=modified_layers,
        ft_mode=args.adapter_ft_mode,
    )

    train_examples = task.get_train_examples()
    train_examples, _ = train_setup.maybe_subsample_train(
        train_examples=train_examples,
        train_examples_number=args.train_examples_number,
        train_examples_fraction=args.train_examples_fraction,
    )
    num_train_examples = len(train_examples)

    train_schedule = train_setup.get_train_schedule(
        num_train_examples=num_train_examples,
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_gpu_train_batch_size=args.train_batch_size,
        n_gpu=quick_init_out.n_gpu,
    )
    loss_criterion = train_setup.resolve_loss_function(
        task_type=task.TASK_TYPE)
    # The optimizer is built from the tunable parameters only, not from the
    # whole model.
    optimizer_scheduler = model_setup.create_optimizer_from_params(
        named_parameters=tunable_parameters,
        learning_rate=args.learning_rate,
        t_total=train_schedule.t_total,
        warmup_steps=args.warmup_steps,
        warmup_proportion=args.warmup_proportion,
        optimizer_type=args.optimizer_type,
        verbose=True,
    )
    model_setup.special_model_setup(
        model_wrapper=model_wrapper,
        optimizer_scheduler=optimizer_scheduler,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=quick_init_out.n_gpu,
        local_rank=args.local_rank,
    )
    rparams = simple_runner.RunnerParameters(
        feat_spec=model_resolution.build_featurization_spec(
            model_type=args.model_type,
            max_seq_length=args.max_seq_length,
        ),
        local_rank=args.local_rank,
        n_gpu=quick_init_out.n_gpu,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        eval_batch_size=args.eval_batch_size,
        max_grad_norm=args.max_grad_norm,
    )
    runner = simple_runner.SimpleTaskRunner(
        task=task,
        model_wrapper=model_wrapper,
        optimizer_scheduler=optimizer_scheduler,
        loss_criterion=loss_criterion,
        device=quick_init_out.device,
        rparams=rparams,
        train_schedule=train_schedule,
        log_writer=quick_init_out.log_writer,
    )

    if args.do_train:
        val_examples = task.get_val_examples()
        adapters_runner.AdapterMetaRunner(
            runner=runner,
            train_examples=train_examples,
            # Quick and dirty: in-training evaluation only sees the first
            # `partial_eval_number` validation examples.
            val_examples=val_examples[:args.partial_eval_number],
            should_save_func=metarunner.get_should_save_func(
                args.save_every_steps),
            should_eval_func=metarunner.get_should_eval_func(
                args.eval_every_steps),
            output_dir=args.output_dir,
            verbose=True,
            save_best_model=args.do_save,
            load_best_model=True,
            log_writer=quick_init_out.log_writer,
            modified_layers=modified_layers,
        ).train_val_save_every()

    # The phases below were previously nested under `if args.do_train:`, so a
    # validation-only or test-only invocation silently did nothing. They are
    # independent of training.
    if args.do_save:
        torch.save(model_wrapper.model.state_dict(),
                   os.path.join(args.output_dir, "model.p"))

    if args.do_val:
        val_examples = task.get_val_examples()
        results = runner.run_val(val_examples)
        evaluate.write_val_results(
            results=results,
            output_dir=args.output_dir,
            verbose=True,
        )

    if args.do_test:
        test_examples = task.get_test_examples()
        logits = runner.run_test(test_examples)
        evaluate.write_preds(
            logits=logits,
            output_path=os.path.join(args.output_dir, "test_preds.csv"),
        )
Beispiel #5
0
def main(args):
    """Run a multitask experiment configured by ``args``.

    Builds one task per entry of ``args.task_name_ls``, a shared multitask
    model and a ``MultiTaskRunner``, then performs training, checkpoint saving
    and validation according to ``args.do_train`` / ``args.do_save`` /
    ``args.do_val``. Everything runs inside the log writer's ``log_context``.

    Args:
        args: configuration namespace supplying the multitask config, model
            paths, optimisation hyperparameters and ``do_*`` switches.

    Raises:
        NotImplementedError: if ``args.do_test`` is set (raised after the
            other enabled phases have run).
    """
    init_out = initialization.quick_init(args=args, verbose=True)
    log_writer = init_out.log_writer
    with log_writer.log_context():
        task_dict = create_task_dict(
            multitask_config_path=args.multitask_config_path,
            task_name_ls=args.task_name_ls,
        )

        with distributed.only_first_process(local_rank=args.local_rank):
            model_wrapper = multitask_model_setup.setup_multitask_ptt_model(
                model_type=args.model_type,
                config_path=args.model_config_path,
                tokenizer_path=args.model_tokenizer_path,
                task_dict=task_dict,
            )
            # Weights are loaded through the sub-model of the first task only.
            first_task_name = list(task_dict.keys())[0]
            model_setup.simple_load_model_path(
                model=model_wrapper.model.model_dict[first_task_name],
                model_load_mode=args.model_load_mode,
                model_path=args.model_path,
                verbose=True,
            )
            model_wrapper.model.to(init_out.device)

        # Per-task training sets, each optionally subsampled.
        train_examples_dict = {}
        for name, single_task in task_dict.items():
            subsampled, _ = train_setup.maybe_subsample_train(
                train_examples=single_task.get_train_examples(),
                train_examples_number=args.train_examples_number,
                train_examples_fraction=args.train_examples_fraction,
            )
            train_examples_dict[name] = subsampled

        # TODO: Tweak the schedule
        schedule = train_setup.get_train_schedule(
            num_train_examples=sum(map(len, train_examples_dict.values())),
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_gpu_train_batch_size=args.train_batch_size,
            n_gpu=init_out.n_gpu,
        )
        log_writer.write_entry(
            "text", f"t_total: {schedule.t_total}", do_print=True)

        # One loss function per task, resolved from the task's type.
        criterion_dict = {}
        for name, single_task in task_dict.items():
            criterion_dict[name] = train_setup.resolve_loss_function(
                task_type=single_task.TASK_TYPE)

        optim_sched = model_setup.create_optimizer(
            model=model_wrapper.model,
            learning_rate=args.learning_rate,
            t_total=schedule.t_total,
            warmup_steps=args.warmup_steps,
            warmup_proportion=args.warmup_proportion,
            optimizer_type=args.optimizer_type,
            verbose=True,
        )
        model_setup.special_model_setup(
            model_wrapper=model_wrapper,
            optimizer_scheduler=optim_sched,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            n_gpu=init_out.n_gpu,
            local_rank=args.local_rank,
        )

        feat_spec = model_resolution.build_featurization_spec(
            model_type=args.model_type, max_seq_length=args.max_seq_length)
        runner_params = simple_runner.RunnerParameters(
            feat_spec=feat_spec,
            local_rank=args.local_rank,
            n_gpu=init_out.n_gpu,
            fp16=args.fp16,
            learning_rate=args.learning_rate,
            eval_batch_size=args.eval_batch_size,
            max_grad_norm=args.max_grad_norm,
        )
        runner = multitask_runner.MultiTaskRunner(
            task_dict=task_dict,
            model_wrapper=model_wrapper,
            optimizer_scheduler=optim_sched,
            loss_criterion_dict=criterion_dict,
            device=init_out.device,
            rparams=runner_params,
            train_schedule=schedule,
            log_writer=log_writer,
        )

        def load_partial_val_examples():
            # Quick and dirty: the first `partial_eval_number` validation
            # examples of every task, freshly loaded on each call.
            return {
                name: single_task.get_val_examples()[:args.partial_eval_number]
                for name, single_task in task_dict.items()
            }

        if args.do_train:
            trainer = metarunner.MetaRunner(
                runner=runner,
                train_examples=train_examples_dict,
                val_examples=load_partial_val_examples(),
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=log_writer,
            )
            trainer.train_val_save_every()

        if args.do_save:
            weights = model_wrapper.model.state_dict()
            torch.save(weights, os.path.join(args.output_dir, "model.p"))

        if args.do_val:
            # Unlike the single-task scripts, the final validation here also
            # uses the truncated validation sets.
            val_results = runner.run_val(load_partial_val_examples())
            evaluate.write_metrics(
                results=val_results,
                output_path=os.path.join(args.output_dir, "val_metrics.json"),
                verbose=True,
            )

        if args.do_test:
            raise NotImplementedError()
def main(args):
    """Run a single-task experiment whose training set can be extended with
    extra training files.

    Loads the model weights directly from the state dict at
    ``args.model_path``, appends the training examples found at each path in
    ``args.extra_train_paths`` to the task's own training examples, and then
    performs the phases enabled by ``args.do_train`` / ``args.do_save`` /
    ``args.do_val`` / ``args.do_test`` inside the log writer's
    ``log_context``.

    Args:
        args: configuration namespace supplying the task/model paths, the
            optional ``extra_train_paths`` (``None`` or empty means no extra
            data), the optimisation hyperparameters and the ``do_*`` switches.
    """
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    task = tasks.create_task_from_config_path(
        config_path=args.task_config_path)
    # Echo the data paths configured for this task.
    for phase in ["train", "val", "test"]:
        if phase in task.path_dict:
            print(task.path_dict[phase])

    with distributed.only_first_process(local_rank=args.local_rank):
        # load the model
        model_class_spec = model_resolution.resolve_model_setup_classes(
            model_type=args.model_type,
            task_type=task.TASK_TYPE,
        )
        model_wrapper = model_setup.simple_model_setup(
            model_type=args.model_type,
            model_class_spec=model_class_spec,
            config_path=args.model_config_path,
            tokenizer_path=args.model_tokenizer_path,
            task=task,
        )
        # Deserialize onto the CPU: without map_location, torch.load restores
        # tensors onto the device they were saved from, which fails on
        # CPU-only hosts and piles every process onto the saving GPU. The
        # model is moved to the target device right below.
        state_dict = torch.load(args.model_path, map_location="cpu")
        model_setup.safe_load_model(model=model_wrapper.model,
                                    state_dict=state_dict)
        model_wrapper.model.to(quick_init_out.device)

    train_examples = task.get_train_examples()
    # `extra_train_paths` may be None when the option was not supplied.
    for train_path in args.extra_train_paths or []:
        print(f"Extra: {train_path}")
        # Same task class and name, pointed at the extra training file only.
        extra_task = task.__class__(task.name, {"train": train_path})
        train_examples += extra_task.get_train_examples()
    num_train_examples = len(train_examples)

    train_schedule = train_setup.get_train_schedule(
        num_train_examples=num_train_examples,
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_gpu_train_batch_size=args.train_batch_size,
        n_gpu=quick_init_out.n_gpu,
    )
    print("t_total", train_schedule.t_total)
    loss_criterion = train_setup.resolve_loss_function(
        task_type=task.TASK_TYPE)
    optimizer_scheduler = model_setup.create_optimizer(
        model=model_wrapper.model,
        learning_rate=args.learning_rate,
        t_total=train_schedule.t_total,
        warmup_steps=args.warmup_steps,
        warmup_proportion=args.warmup_proportion,
        optimizer_type=args.optimizer_type,
        verbose=True,
    )
    model_setup.special_model_setup(
        model_wrapper=model_wrapper,
        optimizer_scheduler=optimizer_scheduler,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        n_gpu=quick_init_out.n_gpu,
        local_rank=args.local_rank,
    )
    rparams = simple_runner.RunnerParameters(
        feat_spec=model_resolution.build_featurization_spec(
            model_type=args.model_type,
            max_seq_length=args.max_seq_length,
        ),
        local_rank=args.local_rank,
        n_gpu=quick_init_out.n_gpu,
        fp16=args.fp16,
        learning_rate=args.learning_rate,
        eval_batch_size=args.eval_batch_size,
        max_grad_norm=args.max_grad_norm,
    )
    runner = simple_runner.SimpleTaskRunner(
        task=task,
        model_wrapper=model_wrapper,
        optimizer_scheduler=optimizer_scheduler,
        loss_criterion=loss_criterion,
        device=quick_init_out.device,
        rparams=rparams,
        train_schedule=train_schedule,
        log_writer=quick_init_out.log_writer,
    )

    with quick_init_out.log_writer.log_context():
        if args.do_train:
            runner.run_train(train_examples)

        if args.do_save:
            torch.save(model_wrapper.model.state_dict(),
                       os.path.join(args.output_dir, "model.p"))

        if args.do_val:
            val_examples = task.get_val_examples()
            results = runner.run_val(val_examples)
            evaluate.write_val_results(
                results=results,
                output_dir=args.output_dir,
                verbose=True,
            )

        if args.do_test:
            test_examples = task.get_test_examples()
            logits = runner.run_test(test_examples)
            evaluate.write_preds(
                logits=logits,
                output_path=os.path.join(args.output_dir, "test_preds.csv"),
            )