Example #1
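One round of ensemble self-training: train num_models sub-runners, each on its own slice of the training set, collect their logits on the unlabeled examples, and assemble a new TrainingSet that adds the confidently pseudolabeled examples.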
    def n_training_step(self, step, training_set: TrainingSet):
        logits_ls = []
        runner_ls = []
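        # Train one sub-runner per ensemble member, each on its own
        # slice of the training set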
        for i in tqdm.trange(self.rparams.num_models):
            runner_train_examples = training_set.get_training_examples(i)
            sub_log_writer = self.get_sub_log_writer(f"runner{step}__{i}")
            with sub_log_writer.log_context():
                runner = self.runner_creator.create(
                    num_train_examples=len(runner_train_examples),
                    log_writer=sub_log_writer,
                )
                val_examples = self.task.get_val_examples()
                mrunner = metarunner.MetaRunner(
                    runner=runner,
                    train_examples=runner_train_examples,
                    # quick and dirty: evaluate on a truncated val set
                    val_examples=val_examples[
                        :self.meta_runner_parameters.partial_eval_number],
                    should_save_func=metarunner.get_should_save_func(
                        self.meta_runner_parameters.save_every_steps),
                    should_eval_func=metarunner.get_should_eval_func(
                        self.meta_runner_parameters.eval_every_steps),
                    output_dir=self.meta_runner_parameters.output_dir,
                    verbose=True,
                    save_best_model=self.meta_runner_parameters.do_save,
                    load_best_model=True,
                    log_writer=self.log_writer,
                )
                mrunner.train_val_save_every()
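                # Collect this runner's logits on the unlabeled pool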
                logits = runner.run_test(self.unlabeled_examples)
                runner_save_memory(runner)
            logits_ls.append(logits)
            runner_ls.append(runner)

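        # Stack per-runner logits: axis 1 indexes the sub-models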
        all_logits = np.stack(logits_ls, axis=1)
        self.log_writer.write_obj("sub_runner_logits", all_logits, {
            "step": step,
        })
        self.log_writer.flush()

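        # Keep only pseudolabels the ensemble is confident about,
        # optionally filtering by inter-model disagreement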
        chosen_examples, chosen_preds = get_n_training_pseudolabels(
            all_logits=all_logits,
            with_disagreement=self.rparams.with_disagreement,
            confidence_threshold=self.rparams.confidence_threshold,
        )

        new_training_set = TrainingSet(
            task=training_set.task,
            labeled_examples=training_set.labeled_examples,
            unlabeled_examples=training_set.unlabeled_examples,
            labeled_indices=training_set.labeled_indices,
            chosen_examples=chosen_examples,
            chosen_preds=chosen_preds,
        )
        return new_training_set, runner_ls
Example #2
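An end-to-end single-task training entry point: it initializes the task and model, builds the training schedule and optimizer, then optionally trains with periodic eval/saving via MetaRunner, saves the final model, runs validation, and writes test predictions.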
def main(args):
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    with quick_init_out.log_writer.log_context():
        task = tasks.create_task_from_config_path(
            config_path=args.task_config_path,
            verbose=True,
        )

        with distributed.only_first_process(local_rank=args.local_rank):
            # load the model
            model_class_spec = model_resolution.resolve_model_setup_classes(
                model_type=args.model_type,
                task_type=task.TASK_TYPE,
            )
            model_wrapper = model_setup.simple_model_setup(
                model_type=args.model_type,
                model_class_spec=model_class_spec,
                config_path=args.model_config_path,
                tokenizer_path=args.model_tokenizer_path,
                task=task,
            )
            model_setup.simple_load_model_path(
                model=model_wrapper.model,
                model_load_mode=args.model_load_mode,
                model_path=args.model_path,
                verbose=True,
            )
            model_wrapper.model.to(quick_init_out.device)

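        # Optionally subsample the training set by absolute count or fraction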
        train_examples = task.get_train_examples()
        train_examples, _ = train_setup.maybe_subsample_train(
            train_examples=train_examples,
            train_examples_number=args.train_examples_number,
            train_examples_fraction=args.train_examples_fraction,
        )
        num_train_examples = len(train_examples)

        train_schedule = train_setup.get_train_schedule(
            num_train_examples=num_train_examples,
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_gpu_train_batch_size=args.train_batch_size,
            n_gpu=quick_init_out.n_gpu,
        )
        quick_init_out.log_writer.write_entry(
            "text", f"t_total: {train_schedule.t_total}", do_print=True)
        loss_criterion = train_setup.resolve_loss_function(
            task_type=task.TASK_TYPE)
        optimizer_scheduler = model_setup.create_optimizer(
            model=model_wrapper.model,
            learning_rate=args.learning_rate,
            t_total=train_schedule.t_total,
            warmup_steps=args.warmup_steps,
            warmup_proportion=args.warmup_proportion,
            optimizer_type=args.optimizer_type,
            verbose=True,
        )
        model_setup.special_model_setup(
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            n_gpu=quick_init_out.n_gpu,
            local_rank=args.local_rank,
        )
        rparams = simple_runner.RunnerParameters(
            feat_spec=model_resolution.build_featurization_spec(
                model_type=args.model_type,
                max_seq_length=args.max_seq_length,
            ),
            local_rank=args.local_rank,
            n_gpu=quick_init_out.n_gpu,
            fp16=args.fp16,
            learning_rate=args.learning_rate,
            eval_batch_size=args.eval_batch_size,
            max_grad_norm=args.max_grad_norm,
        )
        runner = simple_runner.SimpleTaskRunner(
            task=task,
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            loss_criterion=loss_criterion,
            device=quick_init_out.device,
            rparams=rparams,
            train_schedule=train_schedule,
            log_writer=quick_init_out.log_writer,
        )

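        # Train with periodic eval and checkpointing via MetaRunner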
        if args.do_train:
            val_examples = task.get_val_examples()
            metarunner.MetaRunner(
                runner=runner,
                train_examples=train_examples,
                # quick and dirty: evaluate on a truncated val set
                val_examples=val_examples[:args.partial_eval_number],
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            ).train_val_save_every()

        if args.do_save:
            torch.save(model_wrapper.model.state_dict(),
                       os.path.join(args.output_dir, "model.p"))

        if args.do_val:
            val_examples = task.get_val_examples()
            results = runner.run_val(val_examples)
            evaluate.write_val_results(
                results=results,
                output_dir=args.output_dir,
                verbose=True,
            )

        if args.do_test:
            test_examples = task.get_test_examples()
            logits = runner.run_test(test_examples)
            evaluate.write_preds(
                logits=logits,
                output_path=os.path.join(args.output_dir, "test_preds.csv"),
            )
Example #3
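The multitask counterpart of Example #2: it builds a task dictionary and a multitask model, pools the per-task training examples into a single schedule, and trains with MultiTaskRunner; only train/save/val are implemented, and do_test raises NotImplementedError.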
def main(args):
    quick_init_out = initialization.quick_init(args=args, verbose=True)
    with quick_init_out.log_writer.log_context():
        task_dict = create_task_dict(
            multitask_config_path=args.multitask_config_path,
            task_name_ls=args.task_name_ls,
        )
        with distributed.only_first_process(local_rank=args.local_rank):
            # load the model
            model_wrapper = multitask_model_setup.setup_multitask_ptt_model(
                model_type=args.model_type,
                config_path=args.model_config_path,
                tokenizer_path=args.model_tokenizer_path,
                task_dict=task_dict,
            )
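            # Load pretrained weights into the first task's submodel;
            # the underlying encoder is presumably shared across tasks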
            model_setup.simple_load_model_path(
                model=model_wrapper.model.model_dict[
                    list(task_dict.keys())[0]],
                model_load_mode=args.model_load_mode,
                model_path=args.model_path,
                verbose=True,
            )
            model_wrapper.model.to(quick_init_out.device)

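        # Gather (and optionally subsample) training examples for each task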
        train_examples_dict = {}
        for task_name, task in task_dict.items():
            train_examples = task.get_train_examples()
            train_examples, _ = train_setup.maybe_subsample_train(
                train_examples=train_examples,
                train_examples_number=args.train_examples_number,
                train_examples_fraction=args.train_examples_fraction,
            )
            train_examples_dict[task_name] = train_examples

        # TODO: Tweak the schedule
        total_num_train_examples = sum(
            len(train_examples)
            for train_examples in train_examples_dict.values())

        train_schedule = train_setup.get_train_schedule(
            num_train_examples=total_num_train_examples,
            max_steps=args.max_steps,
            num_train_epochs=args.num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            per_gpu_train_batch_size=args.train_batch_size,
            n_gpu=quick_init_out.n_gpu,
        )
        quick_init_out.log_writer.write_entry(
            "text", f"t_total: {train_schedule.t_total}", do_print=True)
        loss_criterion_dict = {
            task_name: train_setup.resolve_loss_function(
                task_type=task.TASK_TYPE)
            for task_name, task in task_dict.items()
        }
        optimizer_scheduler = model_setup.create_optimizer(
            model=model_wrapper.model,
            learning_rate=args.learning_rate,
            t_total=train_schedule.t_total,
            warmup_steps=args.warmup_steps,
            warmup_proportion=args.warmup_proportion,
            optimizer_type=args.optimizer_type,
            verbose=True,
        )
        model_setup.special_model_setup(
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            n_gpu=quick_init_out.n_gpu,
            local_rank=args.local_rank,
        )
        rparams = simple_runner.RunnerParameters(
            feat_spec=model_resolution.build_featurization_spec(
                model_type=args.model_type,
                max_seq_length=args.max_seq_length,
            ),
            local_rank=args.local_rank,
            n_gpu=quick_init_out.n_gpu,
            fp16=args.fp16,
            learning_rate=args.learning_rate,
            eval_batch_size=args.eval_batch_size,
            max_grad_norm=args.max_grad_norm,
        )
        runner = multitask_runner.MultiTaskRunner(
            task_dict=task_dict,
            model_wrapper=model_wrapper,
            optimizer_scheduler=optimizer_scheduler,
            loss_criterion_dict=loss_criterion_dict,
            device=quick_init_out.device,
            rparams=rparams,
            train_schedule=train_schedule,
            log_writer=quick_init_out.log_writer,
        )

        if args.do_train:
            val_examples_dict = {
                task_name: task.get_val_examples()[:args.partial_eval_number]
                for task_name, task in task_dict.items()
            }
            metarunner.MetaRunner(
                runner=runner,
                train_examples=train_examples_dict,
                # quick and dirty: val sets truncated above
                val_examples=val_examples_dict,
                should_save_func=metarunner.get_should_save_func(
                    args.save_every_steps),
                should_eval_func=metarunner.get_should_eval_func(
                    args.eval_every_steps),
                output_dir=args.output_dir,
                verbose=True,
                save_best_model=args.do_save,
                load_best_model=True,
                log_writer=quick_init_out.log_writer,
            ).train_val_save_every()

        if args.do_save:
            torch.save(model_wrapper.model.state_dict(),
                       os.path.join(args.output_dir, "model.p"))

        if args.do_val:
            val_examples_dict = {
                task_name: task.get_val_examples()[:args.partial_eval_number]
                for task_name, task in task_dict.items()
            }
            results = runner.run_val(val_examples_dict)
            evaluate.write_metrics(
                results=results,
                output_path=os.path.join(args.output_dir, "val_metrics.json"),
                verbose=True,
            )

        if args.do_test:
            raise NotImplementedError()