def load_checkpoint(f: io.IOBase, overwrite_config=None):
    state = torch.load(f, map_location=lambda storage, loc: storage)
    distributed.force_print(f"Loaded checkpoint...", flush=True)
    if SERIALIZE_VERSION_KEY not in state:
        return load_v1(state)
    else:
        return LOADER_VERSION_MAP[state[SERIALIZE_VERSION_KEY]](
            state, overwrite_config
        )
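# A minimal sketch, not the library's actual map: LOADER_VERSION_MAP is assumed
# here to be a plain dict pairing each serialization version with its loader.
# Only load_v3 appears in this excerpt; the version number 3 and any other
# entries are assumptions for illustration.
LOADER_VERSION_MAP = {
    3: load_v3,
    # older formats would register their own load_vN loaders here
}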
def load_v3(state, overwrite_config=None):
    saved_config = pytext_config_from_json(state[CONFIG_JSON])
    if overwrite_config:
        config = overwrite_config
        distributed.force_print(f"Use config from current task", flush=True)
    else:
        config = saved_config
        distributed.force_print(f"Use config saved in snapshot", flush=True)

    model_state = state[MODEL_STATE]
    training_state = state[TRAINING_STATE]
    if training_state and training_state.tensorizers:
        tensorizers = training_state.tensorizers
    else:
        tensorizers = state[TENSORIZERS]

    # importing at file level generates circular import/dependency failures,
    # which needs refactoring later to fix
    from .task import create_task

    task = create_task(
        config.task,
        metadata=state[DATA_STATE],
        model_state=model_state,
        tensorizers=tensorizers,
    )
    # TODO: T53664090 @stevenliu save & load state_dict() of optimizer and scheduler
    if training_state:
        if training_state.model is None and task.model:
            training_state.model = task.model
        if training_state.optimizer and task.trainer.optimizer:
            """
            https://pytorch.org/tutorials/beginner/saving_loading_models.html
            Unpickling an optimizer object from a checkpoint can yield a
            parameter copy that differs from the model parameters, especially
            in mixed precision training, where the optimizer's param_groups
            maintain a master weights copy instead of the model parameters.
            The suggested loading mechanism is

                model = TheModelClass(*args, **kwargs)
                optimizer = TheOptimizerClass(model.parameters(), *args, **kwargs)
                checkpoint = torch.load(PATH)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            """
            optimizer = task.trainer.optimizer
            optimizer.load_state_dict(training_state.optimizer.state_dict())
            training_state.optimizer = optimizer
    return task, config, training_state
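# Standalone illustration of the loading pattern quoted in the docstring above:
# construct a fresh optimizer over the live model's parameters and restore only
# its state_dict, rather than unpickling a whole optimizer object. The model
# and optimizer choices here are arbitrary examples, not PyText defaults.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}

restored_model = torch.nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.1)
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])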
def load(self, load_path: str, overwrite_config=None):
    """
    Loads a checkpoint from disk.

    Args:
        load_path (str): the file path of the checkpoint to load

    Returns:
        task (Task), config (PyTextConfig) and training_state (TrainingState)
    """
    if not (load_path and os.path.isfile(load_path)):
        raise ValueError(f"Invalid snapshot path: {load_path}")
    distributed.force_print(f"Loading model from {load_path}...", flush=True)
    with open(load_path, "rb") as checkpoint_f:
        return load_checkpoint(checkpoint_f, overwrite_config)
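# Hypothetical usage sketch: `checkpoint_manager` stands in for an instance of
# the class that defines load() above, and the snapshot path is a placeholder.
# The returned triple can then seed train_from_state() to resume training.
task, config, training_state = checkpoint_manager.load(
    "/checkpoints/latest_snapshot.pt"
)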
def from_config(
    cls,
    config: Config,
    unused_metadata=None,
    model_state=None,
    tensorizers=None,
    rank=0,
    world_size=1,
):
    distributed.force_print(f"Creating task: {cls.__name__}...", flush=True)
    tensorizers, data = cls._init_tensorizers(config, tensorizers, rank, world_size)
    # Initialized tensorizers can be used to create the model
    model = cls._init_model(config.model, tensorizers, model_state)
    # This is the only place right now that the task actually cares about which
    # features and tensors are being used. This is a strong tie between
    # the implementation of the model and the metric reporter.
    metric_reporter = cls.create_metric_reporter(config, tensorizers)
    trainer = create_trainer(config.trainer, model)
    return cls(data, model, metric_reporter, trainer)
def train_from_state(
    self,
    state: TrainingState,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model from a given training state; the state is modified
    in place. This function iterates over the epochs specified in the config,
    and for each epoch:
        1. Trains the model on training data, aggregates and reports training results
        2. Adjusts the learning rate if a scheduler is specified
        3. Evaluates the model on evaluation data
        4. Calculates metrics based on evaluation results and selects the best model

    Args:
        state (TrainingState): contains stateful information needed to
            restore a training job
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        metric_reporter (MetricReporter): computes metrics based on training
            output and reports results to console, file, etc.
        train_config (PyTextConfig): training config

    Returns:
        model, best_metric: the trained model together with the best metric
    """
    training_data = self.set_up_training(state, training_data)
    model = state.model
    rank = state.rank
    trainable_params = sum(
        p.numel() for p in state.model.parameters() if p.requires_grad
    )
    print(f"Num trainable parameters: {trainable_params}")

    while self.continue_training(state):
        state.epoch += 1
        state.epochs_since_last_improvement += 1
        lrs = learning_rates(state.optimizer)
        distributed.force_print(
            f"\nWorker {state.rank} starting epoch {state.epoch}", flush=True
        )
        print(f"Learning rate(s): {', '.join(map(str, lrs))}")

        with timing.time("train epoch"):
            state.stage = Stage.TRAIN
            state.model.train()
            distributed.force_print(f"start training epoch {state.epoch}", flush=True)
            epoch_data = training_data
            if self.config.num_batches_per_epoch:
                # We want to limit the number of batches in the epoch;
                # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                # In this case we set the training data iterator to cycle earlier
                # in the training process, so when it reaches the end it will
                # loop back to the beginning.
                epoch_data = itertools.islice(
                    epoch_data, self.config.num_batches_per_epoch
                )
            self.run_epoch(state, epoch_data, metric_reporter)

        if not self.config.do_eval:
            continue

        with timing.time("eval epoch"):
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            distributed.force_print(
                f"start evaluating epoch {state.epoch}", flush=True
            )
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)

        # Step the learning rate scheduler(s)
        assert eval_metric is not None
        state.scheduler.step_epoch(
            metrics=metric_reporter.get_model_select_metric(eval_metric),
            epoch=state.epoch,
        )

        # Did we train a better model?
        better_model = metric_reporter.compare_metric(
            eval_metric, state.best_model_metric
        )
        if better_model:
            self.update_best_model(state, train_config, eval_metric)
        if better_model or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    if self.optimizer.finalize():
        state.stage = Stage.EVAL
        model.eval(Stage.EVAL)
        distributed.force_print(f"start evaluating finalized state", flush=True)
        with torch.no_grad():
            eval_metric = self.run_epoch(state, eval_data, metric_reporter)
        better_model = metric_reporter.compare_metric(
            eval_metric, state.best_model_metric
        )
        if better_model:
            self.update_best_model(state, train_config, eval_metric)
        if better_model or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    # Only bother loading the best model for the master worker
    if rank == 0 and state.best_model_state is not None:
        self.load_best_model(state)

    return state.model, state.best_model_metric
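# Standalone illustration of the num_batches_per_epoch trick used in
# train_from_state: wrap the data in an endless cycle, then slice a fixed
# number of batches per "epoch". The names below are illustrative, not the
# trainer's internals.
import itertools

batches = ["batch_0", "batch_1", "batch_2"]
endless_data = itertools.cycle(batches)

num_batches_per_epoch = 2
for epoch in range(3):
    epoch_data = itertools.islice(endless_data, num_batches_per_epoch)
    print(f"epoch {epoch}:", list(epoch_data))
# epoch 0: ['batch_0', 'batch_1']
# epoch 1: ['batch_2', 'batch_0']
# epoch 2: ['batch_1', 'batch_2']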