Example #1
def load_checkpoint(f: io.IOBase, overwrite_config=None):
    # Map all tensors to CPU so the checkpoint loads regardless of the device
    # it was saved from.
    state = torch.load(f, map_location=lambda storage, loc: storage)
    distributed.force_print(f"Loaded checkpoint...", flush=True)
    # Checkpoints written before versioning was introduced have no version key;
    # fall back to the v1 loader. Otherwise dispatch on the saved version.
    if SERIALIZE_VERSION_KEY not in state:
        return load_v1(state)
    else:
        return LOADER_VERSION_MAP[state[SERIALIZE_VERSION_KEY]](
            state, overwrite_config)
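The loader above forces all tensors onto CPU via map_location and then dispatches on a version key stored in the pickled state. A minimal standalone sketch of the same pattern, with illustrative names that are not PyText's actual constants:

import io

import torch

DEMO_VERSION_KEY = "serialize_version"  # illustrative stand-in for SERIALIZE_VERSION_KEY

def _demo_load_v2(state, overwrite_config=None):
    return state["payload"]

DEMO_LOADER_MAP = {2: _demo_load_v2}  # illustrative stand-in for LOADER_VERSION_MAP

buf = io.BytesIO()
torch.save({DEMO_VERSION_KEY: 2, "payload": {"weights": torch.zeros(3)}}, buf)
buf.seek(0)

# Load onto CPU regardless of the device the tensors were saved from
state = torch.load(buf, map_location=lambda storage, loc: storage)
print(DEMO_LOADER_MAP[state[DEMO_VERSION_KEY]](state))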
Example #2
def load_v3(state, overwrite_config=None):
    saved_config = pytext_config_from_json(state[CONFIG_JSON])
    if overwrite_config:
        config = overwrite_config
        distributed.force_print(f"Use config from current task", flush=True)
    else:
        config = saved_config
        distributed.force_print(f"Use config saved in snapshot", flush=True)

    model_state = state[MODEL_STATE]
    training_state = state[TRAINING_STATE]
    if training_state and training_state.tensorizers:
        tensorizers = training_state.tensorizers
    else:
        tensorizers = state[TENSORIZERS]
    # Importing at module level causes circular import/dependency failures;
    # this needs a refactor later to fix.
    from .task import create_task

    task = create_task(
        config.task,
        metadata=state[DATA_STATE],
        model_state=model_state,
        tensorizers=tensorizers,
    )

    # TODO: T53664090 @stevenliu save & load state_dict() of optimizer and scheduler
    if training_state:
        if training_state.model is None and task.model:
            training_state.model = task.model
        if training_state.optimizer and task.trainer.optimizer:
            """
            https://pytorch.org/tutorials/beginner/saving_loading_models.html
            Unpickling optimizer object from checkpoint could result in a
            different parameter copy from model parameters. Especially in
            mixied precision training, which optimizer param_groups maintains
            master weights copy instead of the model parameters.

            The suggested loading mechanism is

            model = TheModelClass(*args, **kwargs)
            optimizer = TheOptimizerClass(model.parameters(), *args, **kwargs)

            checkpoint = torch.load(PATH)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            """

            optimizer = task.trainer.optimizer
            optimizer.load_state_dict(training_state.optimizer.state_dict())
            training_state.optimizer = optimizer

    return task, config, training_state
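The docstring above follows the loading pattern recommended by the PyTorch serialization tutorial: construct a fresh optimizer against the newly created model's parameters and restore only its state_dict, instead of reusing the unpickled optimizer object. A minimal self-contained sketch with a dummy model (all names here are illustrative):

import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Pretend this is the checkpoint written during training
buf = io.BytesIO()
torch.save(
    {"model_state_dict": model.state_dict(),
     "optimizer_state_dict": optimizer.state_dict()},
    buf,
)
buf.seek(0)

# At load time: build a fresh model and optimizer, then restore their
# state_dicts, so the optimizer tracks the new model's parameter tensors
# rather than stale unpickled copies.
restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.1)
checkpoint = torch.load(buf, map_location=lambda storage, loc: storage)
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])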
Example #3
    def load(self, load_path: str, overwrite_config=None):
        """
        Loads a checkpoint from disk.
        Args:
            load_path (str): the file path of the checkpoint to load
            overwrite_config: if provided, used instead of the config saved in the checkpoint
        Returns: task (Task), config (PyTextConfig) and training_state (TrainingState)
        """
        if not (load_path and os.path.isfile(load_path)):
            raise ValueError(f"Invalid snapshot path {load_path}")
        distributed.force_print(f"Loading model from {load_path}...",
                                flush=True)
        with open(load_path, "rb") as checkpoint_f:
            return load_checkpoint(checkpoint_f, overwrite_config)
Example #4
    def from_config(
        cls,
        config: Config,
        unused_metadata=None,
        model_state=None,
        tensorizers=None,
        rank=0,
        world_size=1,
    ):
        distributed.force_print(f"Creating task: {cls.__name__}...",
                                flush=True)
        tensorizers, data = cls._init_tensorizers(config, tensorizers, rank,
                                                  world_size)

        # Initialized tensorizers can be used to create the model
        model = cls._init_model(config.model, tensorizers, model_state)

        # This is the only place right now that the task actually cares about which
        # features and tensors are being used. This is a strong tie between
        # the implementation of the model and the metric reporter.
        metric_reporter = cls.create_metric_reporter(config, tensorizers)
        trainer = create_trainer(config.trainer, model)
        return cls(data, model, metric_reporter, trainer)
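from_config above is a classmethod factory: it builds the collaborating components in dependency order (tensorizers and data, then the model, then the metric reporter and trainer) and hands them to the constructor. A generic, self-contained sketch of that pattern with illustrative names:

from dataclasses import dataclass

@dataclass
class DemoTask:
    data: list
    model: dict
    metric_reporter: str

    @classmethod
    def from_config(cls, config: dict):
        # Build components in dependency order, then delegate to __init__
        data = list(range(config["num_examples"]))
        model = {"hidden_dim": config["hidden_dim"]}
        metric_reporter = f"reporter_for_{config['task_name']}"
        return cls(data=data, model=model, metric_reporter=metric_reporter)

task = DemoTask.from_config(
    {"num_examples": 4, "hidden_dim": 8, "task_name": "demo"})
print(task)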
Example #5
    def train_from_state(
        self,
        state: TrainingState,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model from a given training state will be modified.
        This function iterates epochs specified in config, and for each epoch do:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            training_state (TrainingState): contrains stateful information to be
            able to restore a training job
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
                output and report results to console, file.. etc
            train_config (PyTextConfig): training config

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        training_data = self.set_up_training(state, training_data)
        model = state.model
        rank = state.rank
        trainable_params = sum(p.numel() for p in state.model.parameters()
                               if p.requires_grad)
        print(f"Num trainable parameters: {trainable_params}")

        while self.continue_training(state):
            state.epoch += 1
            state.epochs_since_last_improvement += 1
            lrs = learning_rates(state.optimizer)
            distributed.force_print(
                f"\nWorker {state.rank} starting epoch {state.epoch}",
                flush=True)
            print(f"Learning rate(s): {', '.join(map(str, lrs))}")

            with timing.time("train epoch"):
                state.stage = Stage.TRAIN
                state.model.train()
                distributed.force_print(f"start training epoch {state.epoch}",
                                        flush=True)
                epoch_data = training_data
                if self.config.num_batches_per_epoch:
                    # We want to limit the number of batches in the epoch;
                    # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                    # In this case we set the training data iterator to cycle earlier
                    # in the training process, so when it reaches the end it will
                    # loop back to the beginning.
                    epoch_data = itertools.islice(
                        epoch_data, self.config.num_batches_per_epoch)
                self.run_epoch(state, epoch_data, metric_reporter)

            if not self.config.do_eval:
                continue

            with timing.time("eval epoch"):
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                distributed.force_print(
                    f"start evaluating epoch {state.epoch}", flush=True)
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)

            # Step the learning rate scheduler(s)
            assert eval_metric is not None
            state.scheduler.step_epoch(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=state.epoch,
            )

            # Did we train a better model?
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)

        if self.optimizer.finalize():
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            distributed.force_print(f"start evaluating finalized state",
                                    flush=True)
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)
        # Only bother loading the best model for master worker
        if rank == 0 and state.best_model_state is not None:
            self.load_best_model(state)

        return state.model, state.best_model_metric
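The num_batches_per_epoch handling in the example above relies on itertools.islice to cap an iterator that may cycle indefinitely. A minimal standalone illustration:

import itertools

# An endlessly cycling data source, standing in for a cycling training iterator
batches = itertools.cycle(["batch_a", "batch_b", "batch_c"])

# Take only the first 5 batches for this "epoch", as with num_batches_per_epoch
epoch_data = itertools.islice(batches, 5)
print(list(epoch_data))  # ['batch_a', 'batch_b', 'batch_c', 'batch_a', 'batch_b']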