Example #1
0
    def save(self, checkpoint_dir=None) -> str:
        """Persist the current model state as a checkpoint.

        Subclasses should override ``save_checkpoint()`` rather than this
        method; ``save()`` wraps it and writes additional trainable metadata
        alongside the checkpoint data.

        If a remote checkpoint dir is given, this will also sync up to
        remote storage.

        Args:
            checkpoint_dir (str): Optional dir to place the checkpoint.

        Returns:
            str: path that points to xxx.pkl file.

        Note the return path should match up with what is expected of
        `restore()`.
        """
        # Resolve the target directory, falling back to the trainable's
        # logdir, and index it by the current training iteration.
        target_dir = TrainableUtil.make_checkpoint_dir(
            checkpoint_dir or self.logdir, index=self.iteration)

        raw_checkpoint = self.save_checkpoint(target_dir)
        state = self.get_state()
        result_path = TrainableUtil.process_checkpoint(
            raw_checkpoint,
            parent_dir=target_dir,
            trainable_state=state)

        # Hook for subclasses to touch up the checkpoint dir before upload.
        self._postprocess_checkpoint(target_dir)

        # Maybe sync to cloud
        self._maybe_save_to_cloud(target_dir)

        return result_path
Example #2
0
    def save(self, checkpoint_dir=None):
        """Persist the current model state as a checkpoint.

        Subclasses should override ``save_checkpoint()`` rather than this
        method; ``save()`` wraps it and writes additional trainable metadata
        alongside the checkpoint data.

        If a remote checkpoint dir is given, this will also sync up to
        remote storage.

        Args:
            checkpoint_dir (str): Optional dir to place the checkpoint.

        Returns:
            str: Checkpoint path or prefix that may be passed to restore().
        """
        # Resolve the target directory, falling back to the trainable's
        # logdir, and index it by the current training iteration.
        target_dir = TrainableUtil.make_checkpoint_dir(
            checkpoint_dir or self.logdir, index=self.iteration)

        raw_checkpoint = self.save_checkpoint(target_dir)
        state = self.get_state()
        result_path = TrainableUtil.process_checkpoint(
            raw_checkpoint,
            parent_dir=target_dir,
            trainable_state=state)

        # Maybe sync to cloud
        self._maybe_save_to_cloud()

        return result_path
Example #3
0
        def write_checkpoint(trial: Trial, index: int):
            """Write a minimal persistent checkpoint for ``trial`` at ``index``.

            Dumps a one-key result dict to ``cp.json`` inside a fresh
            checkpoint dir, registers the checkpoint on the trial, and
            returns the dir path.
            """
            cp_dir = TrainableUtil.make_checkpoint_dir(
                trial.logdir, index=index)

            metrics = {"training_iteration": index}
            with open(os.path.join(cp_dir, "cp.json"), "w") as f:
                json.dump(metrics, f)

            checkpoint = _TuneCheckpoint(
                _TuneCheckpoint.PERSISTENT, cp_dir, metrics)
            trial.saving_to = checkpoint
            trial.on_checkpoint(checkpoint)

            return cp_dir
Example #4
0
        def write_checkpoint(trial: Trial, index: int):
            """Write a minimal persistent checkpoint for ``trial`` at ``index``.

            Dumps a one-key result dict to ``cp.json`` inside a fresh
            checkpoint dir, attaches a ``_TrackedCheckpoint`` to the trial's
            ``saving_to`` slot, and returns the dir path.
            """
            cp_dir = TrainableUtil.make_checkpoint_dir(trial.logdir, index=index)

            metrics = {"training_iteration": index}
            with open(os.path.join(cp_dir, "cp.json"), "w") as f:
                json.dump(metrics, f)

            trial.saving_to = _TrackedCheckpoint(
                dir_or_data=cp_dir,
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=metrics,
            )

            return cp_dir
Example #5
0
 def setUp(self):
     """Create a per-test checkpoint directory under the user temp dir."""
     # Base path: <user-temp>/tune/MyTrainable123
     self.checkpoint_dir = os.path.join(
         ray._private.utils.get_user_temp_dir(), "tune", "MyTrainable123")
     # Rebind to the dir created by make_checkpoint_dir; "0" is presumably
     # the checkpoint index, passed as a string here — confirm against
     # TrainableUtil.make_checkpoint_dir's signature.
     self.checkpoint_dir = TrainableUtil.make_checkpoint_dir(
         self.checkpoint_dir, "0")