def to_air_checkpoint(self) -> Optional[Checkpoint]:
    """Convert this object's ``dir_or_data`` payload into an AIR ``Checkpoint``.

    The payload is taken from ``self.dir_or_data`` and handled by type:

    * falsy value -> ``None`` is returned (no checkpoint);
    * ``ray.ObjectRef`` -> fetched with ``ray.get`` first, then handled below;
    * ``str`` -> resolved with ``TrainableUtil.find_checkpoint_dir`` and
      loaded via ``Checkpoint.from_directory``;
    * ``bytes`` -> unpacked into a temporary directory with
      ``TrainableUtil.create_from_pickle`` and re-loaded as an in-memory
      dict-backed checkpoint;
    * ``dict`` -> wrapped with ``Checkpoint.from_dict``.

    Returns:
        The converted ``Checkpoint``, or ``None`` if ``dir_or_data`` is falsy.

    Raises:
        RuntimeError: if the (resolved) payload is none of the types above.
    """
    from ray.tune.trainable.util import TrainableUtil

    payload = self.dir_or_data
    if not payload:
        return None

    # A payload held in the object store has to be fetched before it can be
    # inspected. Note that the falsy check above ran on the reference itself,
    # not on the fetched value.
    if isinstance(payload, ray.ObjectRef):
        payload = ray.get(payload)

    if isinstance(payload, str):
        resolved_dir = TrainableUtil.find_checkpoint_dir(payload)
        return Checkpoint.from_directory(resolved_dir)

    if isinstance(payload, bytes):
        with tempfile.TemporaryDirectory() as scratch_dir:
            TrainableUtil.create_from_pickle(payload, scratch_dir)
            on_disk = Checkpoint.from_directory(scratch_dir)
            # Copy the contents into a dict-backed checkpoint so the result
            # does not point at scratch_dir, which is deleted when this
            # `with` block exits.
            return Checkpoint.from_dict(on_disk.to_dict())

    if isinstance(payload, dict):
        return Checkpoint.from_dict(payload)

    raise RuntimeError(
        f"Unknown checkpoint data type: {type(payload)}")
def restore_from_object(self, obj):
    """Restores training state from a checkpoint object.

    These checkpoints are returned from calls to save_to_object().

    The object is unpacked with ``TrainableUtil.create_from_pickle`` into a
    fresh temporary directory created under ``self.logdir``. The resulting
    checkpoint path is passed to ``self.restore``. The temporary directory is
    always removed afterwards, including when unpacking or restoring raises.

    NOTE(review): ``create_from_pickle`` presumably deserializes ``obj`` with
    pickle. If so, only pass objects from trusted sources. Confirm against
    ``TrainableUtil``.

    Args:
        obj: Checkpoint object as produced by ``save_to_object()``.
    """
    tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
    try:
        checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir)
        self.restore(checkpoint_path)
    finally:
        # Previously the directory was only removed on success, so a failed
        # restore leaked the unpacked checkpoint under self.logdir.
        shutil.rmtree(tmpdir)