Example #1
0
    def train_val_save_every(self):
        """Train for the scheduled epochs, checkpointing/evaluating at chosen steps.

        Single-use entry point: guarded by `single_use_check` so it can only be
        called once per instance. Training can terminate early when the
        configured max step count is reached. After the loop, a final
        `eval_save` runs, and (optionally) the best-scoring weights are
        restored onto the model.

        Returns:
            dict with "best_val_state" (best validation state recorded, or
            None) and "val_state_history" (list of recorded validation states).
        """
        assert not self.single_use_check
        self.single_use_check = True

        for _ in maybe_trange(
                int(self.train_schedule.num_train_epochs), desc="Epoch", verbose=self.verbose):
            # Dataloader is rebuilt at the start of every epoch.
            train_dataloader = self.runner.get_train_dataloader(self.train_examples)
            for _ in self.runner.run_train_epoch_context(
                    train_dataloader=train_dataloader,
                    train_global_state=self.train_global_state,
                    verbose=self.verbose):
                self.inject_at_step()

                # Checkpoint at steps selected by the caller-provided predicate.
                if self.should_save_func(self.train_global_state):
                    save_model_with_metadata(
                        model=self.model,
                        metadata={},
                        output_dir=self.output_dir,
                        file_name=f"model__{self.train_global_state.global_step}.p",
                    )
                if self.should_eval_func(self.train_global_state):
                    self.eval_save()

                # Stop once the configured max step count is reached
                # (None / -1 mean "no limit").
                if self.train_schedule.max_steps is not None and \
                        self.train_schedule.max_steps != -1 and \
                        self.train_global_state.global_step >= self.train_schedule.max_steps:
                    self.full_break = True

                # NOTE(review): this appears redundant with the explicit check
                # above — presumably compare_steps_max_steps encodes the same
                # condition; confirm and collapse to a single check.
                if compare_steps_max_steps(
                        step=self.train_global_state.global_step,
                        max_steps=self.train_schedule.max_steps):
                    self.full_break = True

                if self.full_break:
                    break

            # Propagate the inner-loop break out of the epoch loop as well.
            if self.full_break:
                break

            self.inject_at_epoch()

        # End of training eval
        self.eval_save()

        if self.load_best_model and self.best_state_dict is not None:
            if self.verbose:
                print("Loading Best")
            # Restore the saved best weights onto the training device.
            self.model.load_state_dict(copy_state_dict(
                state_dict=self.best_state_dict,
                target_device=self.device,
            ))

        return {
            "best_val_state": self.best_val_state,
            "val_state_history": self.val_state_history,
        }
Example #2
0
 def run_train(self, train_examples, uda_task_data, verbose=True):
     """Run the full training loop over all scheduled epochs.

     Converts the raw examples into a dataset once up front, then calls
     run_train_epoch once per epoch, threading a single TrainGlobalState
     through every call.
     """
     global_state = TrainGlobalState()
     dataset_with_metadata = self.convert_examples_to_dataset(
         examples=train_examples,
         verbose=verbose,
     )
     num_epochs = int(self.train_schedule.num_train_epochs)
     for _ in maybe_trange(num_epochs, desc="Epoch", verbose=verbose):
         self.run_train_epoch(
             train_dataset_with_metadata=dataset_with_metadata,
             uda_task_data=uda_task_data,
             train_global_state=global_state,
             verbose=verbose,
         )
Example #3
0
 def run_train_val(self, train_examples, val_examples, verbose=True):
     """Train for the scheduled epochs, validating after each one.

     Returns an OrderedDict mapping epoch index -> validation result dict,
     where each result has its "logits" entry removed and its "metrics"
     converted to a plain dict via asdict().
     """
     results_by_epoch = col.OrderedDict()
     global_state = TrainGlobalState()
     num_epochs = int(self.train_schedule.num_train_epochs)
     for epoch_i in maybe_trange(num_epochs, desc="Epoch", verbose=verbose):
         global_state.epoch = epoch_i
         # Dataloader is rebuilt each epoch.
         dataloader = self.get_train_dataloader(train_examples)
         self.run_train_epoch(dataloader, global_state)
         val_result = self.run_val(val_examples)
         del val_result["logits"]
         val_result["metrics"] = val_result["metrics"].asdict()
         results_by_epoch[epoch_i] = val_result
     return results_by_epoch
Example #4
0
    def run_train(self, train_examples, verbose=True):
        """Train for all scheduled epochs, logging validation metrics per epoch.

        Validation examples come from self.task; after each epoch the metrics
        are written to the "val_metric" log channel and flushed.
        """
        dataloader = self.get_train_dataloader(train_examples)
        global_state = TrainGlobalState()
        num_epochs = int(self.train_schedule.num_train_epochs)

        for epoch_i in maybe_trange(num_epochs, desc="Epoch", verbose=verbose):
            global_state.epoch = epoch_i
            self.run_train_epoch(dataloader, global_state)
            val_results = self.run_val(val_examples=self.task.get_val_examples())
            entry = {
                "epoch": global_state.epoch,
                "metric": val_results["metrics"].asdict(),
            }
            self.log_writer.write_entry("val_metric", entry)
            self.log_writer.flush()
Example #5
0
 def run_train_val(self, train_examples, val_examples, verbose=True):
     """Train for the scheduled epochs, running validation after each epoch.

     Returns an OrderedDict mapping epoch index -> validation result dict
     ("logits" removed, "metrics" converted to a plain dict via asdict()).
     """
     dataset_with_metadata = self.convert_examples_to_dataset(
         examples=train_examples,
         verbose=verbose,
     )
     results_by_epoch = OrderedDict()
     global_state = TrainGlobalState()
     num_epochs = int(self.train_schedule.num_train_epochs)
     for epoch_i in maybe_trange(num_epochs, desc="Epoch", verbose=verbose):
         self.run_train_epoch(dataset_with_metadata,
                              global_state,
                              verbose=verbose)
         val_result = self.run_val(val_examples)
         del val_result["logits"]
         val_result["metrics"] = val_result["metrics"].asdict()
         results_by_epoch[epoch_i] = val_result
     return results_by_epoch
Example #6
0
    def run_train(self, task_data, verbose=True):
        """Run UDA training for all scheduled epochs.

        Per epoch: rebuilds the unsupervised dataloaders, optionally logs the
        sampled unsupervised indices / augmentation set, trains on the
        (sup, unsup_orig, unsup_aug) dataloader triplet, then validates on the
        task's val examples and logs the metrics to "val_metric".
        """
        train_global_state = TrainGlobalState()
        sup_dataloader = self.get_sup_dataloader(
            task_data=task_data,
            verbose=verbose,
        )

        for epoch_i in maybe_trange(int(self.train_schedule.num_train_epochs),
                                    desc="Epoch",
                                    verbose=verbose):
            train_global_state.epoch = epoch_i
            # Unsupervised dataloaders are re-created every epoch.
            unsup_dataloaders = self.get_unsup_dataloaders(
                sup_dataloader=sup_dataloader,
                task_data=task_data,
            )
            if self.uda_params.use_unsup:
                self.log_writer.write_entry(
                    "misc", {
                        "unsup_indices": [
                            int(x) for x in
                            unsup_dataloaders.metadata["unsup_indices"]
                        ],
                        "unsup_aug_set": [
                            int(x) for x in
                            unsup_dataloaders.metadata["unsup_aug_set"]
                        ],
                    })
                self.log_writer.flush()
            dataloader_triplet = self.form_dataloader_triplet(
                sup_dataloader=sup_dataloader,
                unsup_orig_loader=unsup_dataloaders.unsup_orig,
                unsup_aug_loader=unsup_dataloaders.unsup_aug,
            )
            self.run_train_epoch(dataloader_triplet,
                                 train_global_state,
                                 verbose=verbose)
            results = self.run_val(val_examples=self.task.get_val_examples())
            self.log_writer.write_entry(
                "val_metric", {
                    "epoch": train_global_state.epoch,
                    "metric": results["metrics"].asdict(),
                })
            # BUGFIX: flush after the val_metric entry so per-epoch metrics are
            # persisted even if training is interrupted; this matches the
            # sibling run_train implementation, which flushes after the same
            # write.
            self.log_writer.flush()
Example #7
0
def train_val_save_every(runner: LLPRunner,
                         train_examples: list,
                         val_examples: list,
                         should_save_func,
                         should_eval_func,
                         output_dir,
                         verbose: bool = True,
                         save_best_model: bool = True,
                         load_best_model: bool = True,
                         log_writer: BaseZLogger = PRINT_LOGGER):
    """Train `runner`, periodically checkpointing and validating.

    Args:
        runner: LLPRunner supplying dataset conversion, the per-epoch training
            context, and run_val.
        train_examples: examples converted once into the training dataset.
        val_examples: examples passed to runner.run_val at eval steps.
        should_save_func: predicate on TrainGlobalState deciding when to write
            a step checkpoint.
        should_eval_func: predicate on TrainGlobalState deciding when to
            validate.
        output_dir: destination for step checkpoints and "best_model.p".
        verbose: forwarded to the progress bar; also gates the "Loading Best"
            print.
        save_best_model: write "best_model.p" whenever a new best score is
            observed.
        load_best_model: after training, restore the best state dict (if any)
            onto runner.model.
        log_writer: receives "train_val" / "train_val_best" entries.

    Returns:
        dict with "best_val_state" (or None) and "val_state_history".
    """
    train_global_state = TrainGlobalState()
    best_val_state = None
    best_state_dict = None
    full_break = False
    val_state_history = []

    train_dataset_with_metadata = runner.convert_examples_to_dataset(
        examples=train_examples,
        verbose=verbose,
    )

    for _ in maybe_trange(int(runner.train_schedule.num_train_epochs),
                          desc="Epoch",
                          verbose=verbose):
        for _ in runner.run_train_epoch_context(
                train_dataset_with_metadata=train_dataset_with_metadata,
                train_global_state=train_global_state,
                verbose=verbose):
            if should_save_func(train_global_state):
                metarunner.save_model_with_metadata(
                    model=runner.model,
                    metadata={},
                    output_dir=output_dir,
                    file_name=f"model__{train_global_state.global_step}.p",
                )
            if should_eval_func(train_global_state):
                val_result = runner.run_val(val_examples)
                val_state = metarunner.ValState(
                    score=val_result["metrics"].major,
                    train_global_state=train_global_state.new(),
                )
                log_writer.write_entry("train_val", val_state.asdict())
                log_writer.flush()
                if best_val_state is None or val_state.score > best_val_state.score:
                    best_val_state = val_state.new()
                    log_writer.write_entry("train_val_best",
                                           best_val_state.asdict())
                    log_writer.flush()
                    if save_best_model:
                        metarunner.save_model_with_metadata(
                            model=runner.model,
                            metadata={
                                "val_state": best_val_state.asdict(),
                            },
                            output_dir=output_dir,
                            file_name="best_model.p",
                        )
                    # Keep the best weights on CPU so holding them does not
                    # consume accelerator memory.
                    best_state_dict = metarunner.copy_state_dict(
                        state_dict=runner.model.state_dict(),
                        target_device=CPU_DEVICE,
                    )
                val_state_history.append(val_state)
            # BUGFIX: guard against max_steps being None — without the guard,
            # `global_step >= None` raises TypeError in Python 3. None / -1
            # both mean "no step limit".
            if runner.train_schedule.max_steps is not None and \
                    runner.train_schedule.max_steps != -1 and \
                    train_global_state.global_step >= runner.train_schedule.max_steps:
                full_break = True

            if metarunner.compare_steps_max_steps(
                    step=train_global_state.global_step,
                    max_steps=runner.train_schedule.max_steps):
                full_break = True

            if full_break:
                break

        # Propagate the step-loop break out of the epoch loop as well.
        if full_break:
            break

    if load_best_model and best_state_dict is not None:
        if verbose:
            print("Loading Best")
        runner.model.load_state_dict(
            metarunner.copy_state_dict(
                state_dict=best_state_dict,
                target_device=runner.device,
            ))

    return {
        "best_val_state": best_val_state,
        "val_state_history": val_state_history,
    }
Example #8
0
def train_val_save_every(runner: UDARunner,
                         task_data: dict,
                         val_examples: list,
                         should_save_func,
                         should_eval_func,
                         output_dir,
                         verbose: bool = True,
                         save_best_model: bool = True,
                         load_best_model: bool = True,
                         log_writer: BaseZLogger = PRINT_LOGGER):
    # HACK: from nlpr.shared.metarunner # todo: refactor
    """Train the UDA `runner`, periodically checkpointing and validating.

    Per epoch: rebuilds the unsupervised dataloaders, optionally logs the
    sampled unsupervised indices / augmentation set, then iterates the
    training context over the (sup, unsup_orig, unsup_aug) triplet, saving and
    evaluating at steps selected by the caller-provided predicates.

    Args:
        runner: UDARunner supplying dataloaders, the per-epoch training
            context, and run_val.
        task_data: task data used to build supervised/unsupervised dataloaders.
        val_examples: examples passed to runner.run_val at eval steps.
        should_save_func: predicate on TrainGlobalState deciding when to write
            a step checkpoint.
        should_eval_func: predicate on TrainGlobalState deciding when to
            validate.
        output_dir: destination for step checkpoints and "best_model.p".
        verbose: forwarded to progress bars; also gates the "Loading Best"
            print.
        save_best_model: write "best_model.p" whenever a new best score is
            observed.
        load_best_model: after training, restore the best state dict (if any)
            onto runner.model.
        log_writer: receives "train_val" / "train_val_best" entries.

    Returns:
        dict with "best_val_state" (or None) and "val_state_history".
    """
    train_global_state = TrainGlobalState()
    best_val_state = None
    best_state_dict = None
    full_break = False
    val_state_history = []

    sup_dataloader = runner.get_sup_dataloader(
        task_data=task_data,
        verbose=verbose,
    )

    for _ in maybe_trange(int(runner.train_schedule.num_train_epochs),
                          desc="Epoch",
                          verbose=verbose):
        # Unsupervised dataloaders are re-created every epoch.
        unsup_dataloaders = runner.get_unsup_dataloaders(
            sup_dataloader=sup_dataloader,
            task_data=task_data,
        )
        if runner.uda_params.use_unsup:
            runner.log_writer.write_entry(
                "misc", {
                    "unsup_indices": [
                        int(x)
                        for x in unsup_dataloaders.metadata["unsup_indices"]
                    ],
                    "unsup_aug_set": [
                        int(x)
                        for x in unsup_dataloaders.metadata["unsup_aug_set"]
                    ],
                })
            runner.log_writer.flush()
        dataloader_triplet = runner.form_dataloader_triplet(
            sup_dataloader=sup_dataloader,
            unsup_orig_loader=unsup_dataloaders.unsup_orig,
            unsup_aug_loader=unsup_dataloaders.unsup_aug,
        )
        for _ in runner.run_train_epoch_context(
                dataloader_triplet=dataloader_triplet,
                train_global_state=train_global_state,
                verbose=verbose):
            if should_save_func(train_global_state):
                metarunner.save_model_with_metadata(
                    model=runner.model,
                    metadata={},
                    output_dir=output_dir,
                    file_name=f"model__{train_global_state.global_step}.p",
                )
            if should_eval_func(train_global_state):
                val_result = runner.run_val(val_examples)
                val_state = metarunner.ValState(
                    score=val_result["metrics"].major,
                    train_global_state=train_global_state.new(),
                )
                log_writer.write_entry("train_val", val_state.asdict())
                log_writer.flush()
                if best_val_state is None or val_state.score > best_val_state.score:
                    best_val_state = val_state.new()
                    log_writer.write_entry("train_val_best",
                                           best_val_state.asdict())
                    log_writer.flush()
                    if save_best_model:
                        metarunner.save_model_with_metadata(
                            model=runner.model,
                            metadata={
                                # BUGFIX: was `as_dict()`, which is not the
                                # method name used everywhere else for this
                                # object (`asdict()`, as in the log entries
                                # above) and would raise AttributeError when
                                # saving the best model.
                                "val_state": best_val_state.asdict(),
                            },
                            output_dir=output_dir,
                            file_name="best_model.p",
                        )
                    # Keep the best weights on CPU so holding them does not
                    # consume accelerator memory.
                    best_state_dict = copy_state_dict(
                        state_dict=runner.model.state_dict(),
                        target_device=CPU_DEVICE,
                    )
                val_state_history.append(val_state)
            # BUGFIX: guard against max_steps being None — without the guard,
            # `global_step >= None` raises TypeError in Python 3. None / -1
            # both mean "no step limit".
            if runner.train_schedule.max_steps is not None and \
                    runner.train_schedule.max_steps != -1 and \
                    train_global_state.global_step >= runner.train_schedule.max_steps:
                full_break = True

            if metarunner.compare_steps_max_steps(
                    step=train_global_state.global_step,
                    max_steps=runner.train_schedule.max_steps):
                full_break = True

            if full_break:
                break

        # Propagate the step-loop break out of the epoch loop as well.
        if full_break:
            break

    if load_best_model and best_state_dict is not None:
        if verbose:
            print("Loading Best")
        runner.model.load_state_dict(
            copy_state_dict(
                state_dict=best_state_dict,
                target_device=runner.device,
            ))

    return {
        "best_val_state": best_val_state,
        "val_state_history": val_state_history,
    }