Example #1
0
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        """Build a ``Checkpoint`` from the payload held in ``self.dir_or_data``.

        A ``ray.ObjectRef`` payload is resolved with ``ray.get`` first. The
        resolved value is then handled according to its type:

        * ``str``: passed to ``TrainableUtil.find_checkpoint_dir`` and the
          result is loaded with ``Checkpoint.from_directory``.
        * ``bytes``: unpacked into a temporary directory with
          ``TrainableUtil.create_from_pickle`` and converted to a dict-backed
          checkpoint.
        * ``dict``: wrapped with ``Checkpoint.from_dict``.

        Returns:
            ``None`` if ``self.dir_or_data`` is falsy, otherwise the resulting
            ``Checkpoint``.

        Raises:
            RuntimeError: If the (resolved) payload has any other type.
        """
        data = self.dir_or_data
        if not data:
            return None

        if isinstance(data, ray.ObjectRef):
            data = ray.get(data)

        if isinstance(data, str):
            return Checkpoint.from_directory(
                TrainableUtil.find_checkpoint_dir(data))

        if isinstance(data, bytes):
            with tempfile.TemporaryDirectory() as scratch_dir:
                TrainableUtil.create_from_pickle(data, scratch_dir)
                # Round-trip through a dict so the checkpoint no longer
                # depends on the scratch directory, which is deleted when
                # this ``with`` block exits.
                on_disk = Checkpoint.from_directory(scratch_dir)
                return Checkpoint.from_dict(on_disk.to_dict())

        if isinstance(data, dict):
            return Checkpoint.from_dict(data)

        raise RuntimeError(
            f"Unknown checkpoint data type: {type(data)}")
Example #2
0
    def restore_from_object(self, obj):
        """Restores training state from a checkpoint object.

        These checkpoints are returned from calls to save_to_object().

        The object is unpacked with ``TrainableUtil.create_from_pickle`` into
        a temporary directory created under ``self.logdir``, and the
        resulting path is handed to ``self.restore``. The temporary directory
        is removed afterwards, including when unpacking or restoring raises.

        Args:
            obj: Checkpoint object to restore from.
        """
        # The first positional argument of mkdtemp is the name *suffix*.
        tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
        try:
            checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir)
            self.restore(checkpoint_path)
        finally:
            # Clean up on failure as well, so a failed restore does not leave
            # a stray directory behind in the log dir.
            shutil.rmtree(tmpdir)
Example #3
0
 def save_checkpoint(self, checkpoint_dir: str) -> str:
     """Write the first worker's saved state into ``checkpoint_dir``.

     Calls ``save_to_object`` remotely on ``self.workers[0]``, fetches the
     result with ``ray.get``, and passes it to
     ``TrainableUtil.create_from_pickle`` together with ``checkpoint_dir``.

     Args:
         checkpoint_dir: Directory handed to ``create_from_pickle``.

     Returns:
         The value returned by ``TrainableUtil.create_from_pickle``.
     """
     # TODO: optimize if colocated
     first_worker = self.workers[0]
     pickled_state = ray.get(first_worker.save_to_object.remote())
     return TrainableUtil.create_from_pickle(pickled_state, checkpoint_dir)
Example #4
0
 def save_checkpoint(self, checkpoint_dir: str) -> str:
     """Write a worker's saved state into ``checkpoint_dir``.

     Uses ``self.executor.execute_single`` to invoke ``save_to_object`` on a
     worker, then passes the returned object to
     ``TrainableUtil.create_from_pickle`` together with ``checkpoint_dir``.

     Args:
         checkpoint_dir: Directory handed to ``create_from_pickle``.

     Returns:
         The value returned by ``TrainableUtil.create_from_pickle``.
     """
     # TODO: optimize if colocated
     pickled_state = self.executor.execute_single(
         lambda worker: worker.save_to_object())
     return TrainableUtil.create_from_pickle(pickled_state, checkpoint_dir)