예제 #1
0
    def save(
        self,
        config: PyTextConfig,
        model: Model,
        meta: Optional[CommonMetadata],
        tensorizers: Dict[str, Tensorizer],
        training_state: Optional[TrainingState] = None,
        identifier: Optional[str] = None,
    ) -> str:
        """
        Save a checkpoint made of config, model, meta, tensorizers and
        training_state.

        When ``identifier`` is truthy, the file is written to the path from
        ``self.generate_checkpoint_path(config, identifier)`` and recorded in
        ``self._saved_paths``. Otherwise (None or empty string), this saves the
        post-training snapshot to ``config.save_snapshot_path`` and records it
        in ``self._post_training_snapshot_path``.

        Args:
            config: training config; also supplies ``save_snapshot_path``.
            model: model passed to ``save_checkpoint``.
            meta: optional metadata passed to ``save_checkpoint``.
            tensorizers: tensorizers passed to ``save_checkpoint``.
            training_state: optional training state passed to
                ``save_checkpoint``.
            identifier: identifies a during-training checkpoint; falsy means
                "save the post-training snapshot".

        Returns:
            The path that was written.
        """
        is_checkpoint = bool(identifier)
        if is_checkpoint:
            # saving during-training checkpoints
            saved_path = self.generate_checkpoint_path(config, identifier)
            print("Saving checkpoint to ", saved_path)
        else:
            # saving post-training snapshot if no identifier given
            saved_path = config.save_snapshot_path
            print(f"Saving pytorch model to: {saved_path}")

        saved_folder = os.path.dirname(saved_path)
        # dirname() returns "" for a bare filename (i.e. the current directory):
        # there is nothing to create, and an empty path is not a valid mkdirs
        # target.
        if saved_folder and not PathManager.exists(saved_folder):
            PathManager.mkdirs(saved_folder)
            print(f"created {saved_folder}")

        with PathManager.open(saved_path, "wb") as checkpoint_f:
            save_checkpoint(
                checkpoint_f, config, model, meta, tensorizers, training_state
            )

        # Record the path only after the file has been written and closed
        # successfully, so a failed save never leaves a dangling entry.
        if is_checkpoint:
            self._saved_paths.append(saved_path)
        else:
            self._post_training_snapshot_path = saved_path
        return saved_path
예제 #2
0
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    identifier: Optional[str] = None,
) -> str:
    """
    Save all stateful information of a training task to a file: the original
    config (as JSON), model state dict, metadata, tensorizers, serialization
    version and, if given, the training state.

    Args:
        config (PyTextConfig): contains all raw parameter/hyper-parameters
            for the training task; also supplies ``save_snapshot_path``.
        model (Model): model in training; its ``state_dict()`` is saved.
        meta (CommonMetadata): optional metadata stored with the model.
        tensorizers (Dict[str, Tensorizer]): tensorizers stored with the model.
        training_state (TrainingState): stateful information during training;
            pass it when training is not completed.
        identifier (str): identifies a checkpoint within a training job. If
            non-empty, the save path is
            ``generate_checkpoint_path(config, identifier)`` and the state is
            written with ``_CHECKPOINT_MANAGER.save_checkpoint``. If None or
            empty, the state is written to ``config.save_snapshot_path`` with
            ``_CHECKPOINT_MANAGER.save_snapshot``, consistent with the
            post-training snapshot.

    Returns:
        str: the path the state was saved to.
    """
    # A single predicate selects both the target path and the save routine,
    # so the two can never disagree (e.g. identifier="" previously wrote to
    # the snapshot path via save_checkpoint).
    is_checkpoint = bool(identifier)
    if is_checkpoint:
        # saving during-training checkpoints
        saved_path = generate_checkpoint_path(config, identifier)
    else:
        # saving post-training snapshot if no identifier given
        saved_path = config.save_snapshot_path
        print(f"Saving pytorch model to: {saved_path}")

    saved_folder = os.path.dirname(saved_path)
    # dirname() returns "" for a bare filename (current directory): nothing
    # to create, and an empty path is not a valid mkdirs target.
    if saved_folder and not PathManager.exists(saved_folder):
        PathManager.mkdirs(saved_folder)
        print(f"created {saved_folder}")

    # Currently torch.save() has error pickling certain models when not saving
    # by model.state_dict(), thus currently overriding the model in
    # training_state with None, and putting it back after saving
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state is not None:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        if is_checkpoint:
            _CHECKPOINT_MANAGER.save_checkpoint(state, saved_path)
        else:
            _CHECKPOINT_MANAGER.save_snapshot(state, saved_path)
    finally:
        # Always restore the caller's training_state, even if saving failed.
        if training_state is not None:
            training_state.model = model_in_training_state
    return saved_path