Example #1
    def save(
        self,
        trial: Trial,
        storage: str = _TuneCheckpoint.PERSISTENT,
        result: Optional[Dict] = None,
    ) -> _TuneCheckpoint:
        """Saves the trial's state to a checkpoint asynchronously.

        Args:
            trial: The trial to be saved.
            storage: Where to store the checkpoint. Defaults to
                PERSISTENT.
            result: The state of this trial as a dictionary to be saved.
                If result is None, the trial's last result will be used.

        Returns:
            Checkpoint object, or None if an Exception occurs.
        """
        logger.debug(f"saving trial {trial}")
        result = result or trial.last_result
        with self._change_working_directory(trial):
            if storage == _TuneCheckpoint.MEMORY:
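                # In-memory checkpoint: serialize the trainable's state into
                # the object store and attach it to the trial immediately.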
                value = trial.runner.save_to_object.remote()
                checkpoint = _TuneCheckpoint(storage, value, result)
                trial.on_checkpoint(checkpoint)
            else:
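                # Persistent checkpoint: save to disk on the worker and track
                # the in-flight future so the event loop can resolve it later.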
                value = trial.runner.save.remote()
                checkpoint = _TuneCheckpoint(storage, value, result)
                trial.saving_to = checkpoint
                self._futures[value] = (ExecutorEventType.SAVING_RESULT, trial)
        return checkpoint
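Both branches build a _TuneCheckpoint from the save handle and the result, but only the persistent branch defers completion: the remote save future is parked in self._futures so the driver's event loop can pick it up once the worker finishes writing. Below is a minimal, Ray-free sketch of that bookkeeping pattern; SimpleCheckpoint and slow_save are made-up stand-ins, and a thread pool plays the role of the remote call.

import concurrent.futures
from typing import Dict, Tuple


class SimpleCheckpoint:
    """Stand-in for _TuneCheckpoint: records where and what was saved."""

    PERSISTENT = "persistent"

    def __init__(self, storage: str, value, result: Dict):
        self.storage = storage
        self.value = value
        self.result = result


def slow_save(trial_name: str) -> str:
    """Pretend to persist trial state and return a checkpoint path."""
    return f"/tmp/{trial_name}/checkpoint_000001"


pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
futures: Dict[concurrent.futures.Future, Tuple[str, str]] = {}

# Kick off the save without blocking, as in the PERSISTENT branch above:
# the future is recorded so a later loop iteration can resolve it.
future = pool.submit(slow_save, "trial_0")
futures[future] = ("SAVING_RESULT", "trial_0")

# Later, the driver loop resolves outstanding futures and attaches the
# finished checkpoint to the corresponding trial.
for done in concurrent.futures.as_completed(list(futures)):
    event_type, trial_name = futures.pop(done)
    checkpoint = SimpleCheckpoint(SimpleCheckpoint.PERSISTENT, done.result(), {})
    print(event_type, trial_name, checkpoint.value)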
Example #2
        def write_checkpoint(trial: Trial, index: int):
            checkpoint_dir = TrainableUtil.make_checkpoint_dir(trial.logdir,
                                                               index=index)
            result = {"training_iteration": index}
            with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
                json.dump(result, f)

            tune_cp = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT,
                                      checkpoint_dir, result)
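            # Mirror what the executor's save() does: mark the checkpoint as
            # in flight, then record it as completed on the trial.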
            trial.saving_to = tune_cp
            trial.on_checkpoint(tune_cp)

            return checkpoint_dir
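The helper above only needs a per-index checkpoint directory plus a small JSON result file; trial.saving_to and trial.on_checkpoint then register that directory with the trial, mirroring the executor's save() path. Here is a standalone sketch of the same on-disk layout: make_checkpoint_dir is a simplified stand-in written for this example, and the checkpoint_<index> directory naming is an assumption rather than a guaranteed Tune convention.

import json
import os
import tempfile


def make_checkpoint_dir(logdir: str, index: int) -> str:
    """Create and return logdir/checkpoint_<index> (simplified stand-in
    for TrainableUtil.make_checkpoint_dir)."""
    path = os.path.join(logdir, f"checkpoint_{index:06d}")
    os.makedirs(path, exist_ok=True)
    return path


logdir = tempfile.mkdtemp()
checkpoint_dir = make_checkpoint_dir(logdir, index=3)

# The checkpoint payload is just a small JSON result file, as in the test above.
with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
    json.dump({"training_iteration": 3}, f)

print(checkpoint_dir)  # e.g. .../checkpoint_000003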