def __init__(
    self,
    restore_path: Optional[str] = None,
    trainable: Optional[Union[str, Callable, Type[Trainable], BaseTrainer]] = None,
    param_space: Optional[Dict[str, Any]] = None,
    tune_config: Optional[TuneConfig] = None,
    run_config: Optional[RunConfig] = None,
    _tuner_kwargs: Optional[Dict] = None,
):
    # Restored from Tuner checkpoint.
    if restore_path:
        trainable_ckpt = os.path.join(restore_path, _TRAINABLE_PKL)
        with open(trainable_ckpt, "rb") as fp:
            trainable = pickle.load(fp)

        tuner_ckpt = os.path.join(restore_path, _TUNER_PKL)
        with open(tuner_ckpt, "rb") as fp:
            tuner = pickle.load(fp)
            self.__dict__.update(tuner.__dict__)

        self._is_restored = True
        self._trainable = trainable
        self._experiment_checkpoint_dir = restore_path
        return

    # Start from fresh
    if not trainable:
        raise TuneError("You need to provide a trainable to tune.")

    # If no run config was passed to Tuner directly, use the one from the
    # Trainer, if available.
    if not run_config and isinstance(trainable, BaseTrainer):
        run_config = trainable.run_config

    self._is_restored = False
    self._trainable = trainable
    self._tune_config = tune_config or TuneConfig()
    self._run_config = run_config or RunConfig()
    self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
    self._experiment_checkpoint_dir = self._setup_create_experiment_checkpoint_dir(
        self._run_config
    )

    # Not used for restored Tuner.
    self._param_space = param_space or {}

    # This needs to happen before `tune.run()` is kicked off. Tune currently
    # does not exit gracefully when run in Ray client mode -- if a crash
    # happens, it exits immediately without allowing the tuner and trainable
    # to be checkpointed. Writing them out here ensures there is something
    # to restore from.
    tuner_ckpt = os.path.join(self._experiment_checkpoint_dir, _TUNER_PKL)
    with open(tuner_ckpt, "wb") as fp:
        pickle.dump(self, fp)

    trainable_ckpt = os.path.join(self._experiment_checkpoint_dir, _TRAINABLE_PKL)
    with open(trainable_ckpt, "wb") as fp:
        pickle.dump(self._trainable, fp)
def load_checkpoint(
    checkpoint: Checkpoint,
    env: Optional[EnvType] = None,
) -> Tuple[Policy, Optional[Preprocessor]]:
    """Load a Checkpoint from ``RLTrainer``.

    Args:
        checkpoint: The checkpoint to load the policy and preprocessor from.
            It is expected to be from the result of an ``RLTrainer`` run.
        env: Optional environment to instantiate the trainer with. If not
            given, it is parsed from the saved trainer configuration instead.

    Returns:
        The policy and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
        config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

        if not os.path.exists(trainer_class_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_TRAINER_CLASS_FILE}` file."
            )

        if not os.path.exists(config_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_CONFIG_FILE}` file."
            )

        with open(trainer_class_path, "rb") as fp:
            trainer_cls = cpickle.load(fp)

        with open(config_path, "rb") as fp:
            config = cpickle.load(fp)

        checkpoint_data_path = None
        for file in os.listdir(checkpoint_path):
            if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
                checkpoint_data_path = os.path.join(checkpoint_path, file)

        if not checkpoint_data_path:
            raise ValueError(
                f"Could not find checkpoint data in RLlib checkpoint. "
                f"Found files: {list(os.listdir(checkpoint_path))}"
            )

        preprocessor = load_preprocessor_from_dir(checkpoint_path)

        config.get("evaluation_config", {}).pop("in_evaluation", None)
        trainer = trainer_cls(config=config, env=env)
        trainer.restore(checkpoint_data_path)

        policy = trainer.get_policy()
        return policy, preprocessor
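# Hedged usage sketch for the loader above; ``result`` is assumed to come
# from an ``RLTrainer.fit()`` run and ``obs`` from an environment reset, so
# this does not run standalone:
#
#     policy, preprocessor = load_checkpoint(result.checkpoint)
#     if preprocessor:
#         obs = preprocessor.transform_batch(obs)
#     action = policy.compute_single_action(obs)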
def get_policy(self, env: Optional[EnvType] = None) -> Policy:
    """Retrieve the policy stored in this checkpoint.

    Args:
        env: Optional environment to instantiate the trainer with. If not
            given, it is parsed from the saved trainer configuration.

    Returns:
        The policy stored in this checkpoint.
    """
    with self.as_directory() as checkpoint_path:
        trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
        config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

        if not os.path.exists(trainer_class_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_TRAINER_CLASS_FILE}` file."
            )

        if not os.path.exists(config_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_CONFIG_FILE}` file."
            )

        with open(trainer_class_path, "rb") as fp:
            trainer_cls = cpickle.load(fp)

        with open(config_path, "rb") as fp:
            config = cpickle.load(fp)

        checkpoint_data_path = None
        for file in os.listdir(checkpoint_path):
            if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
                checkpoint_data_path = os.path.join(checkpoint_path, file)

        if not checkpoint_data_path:
            raise ValueError(
                f"Could not find checkpoint data in RLlib checkpoint. "
                f"Found files: {list(os.listdir(checkpoint_path))}"
            )

        config.get("evaluation_config", {}).pop("in_evaluation", None)
        trainer = trainer_cls(config=config, env=env)
        trainer.restore(checkpoint_data_path)
        return trainer.get_policy()
def load_from_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[RandomForestClassifier, Optional[Preprocessor]]:
    # Note: `to_directory()` materializes the checkpoint and leaves the
    # directory on disk, unlike the `as_directory()` context manager used
    # elsewhere, which cleans up temporary copies automatically.
    path = checkpoint.to_directory()

    estimator_path = os.path.join(path, MODEL_KEY)
    with open(estimator_path, "rb") as f:
        estimator = cpickle.load(f)

    preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
    if os.path.exists(preprocessor_path):
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None

    return estimator, preprocessor
def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict:
    """Utility function to load a checkpoint dict from a path."""
    checkpoint_path = Path(checkpoint_to_load).expanduser()
    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.")
    with checkpoint_path.open("rb") as f:
        return cloudpickle.load(f)
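# A minimal, runnable round-trip for the utility above; the temp path and
# dict contents are synthetic stand-ins.
import os
import tempfile

import ray.cloudpickle as cloudpickle

ckpt_file = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
with open(ckpt_file, "wb") as f:
    cloudpickle.dump({"iter": 3}, f)

assert load_checkpoint_from_path(ckpt_file)["iter"] == 3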
def test_retry(ray_start_2_cpus):
    def train_func():
        ckpt = sgd.load_checkpoint()
        restored = bool(ckpt)  # Does a previous checkpoint exist?
        itr = 0
        if ckpt:
            itr = ckpt["iter"] + 1

        for i in range(itr, 4):
            if i == 2 and not restored:
                raise Exception("try to fail me")
            sgd.save_checkpoint(iter=i)
            sgd.report(test=i, training_iteration=i)

    trainer = Trainer(TestConfig())
    TestTrainable = trainer.to_tune_trainable(train_func)

    analysis = tune.run(TestTrainable, max_failures=3)
    last_ckpt = analysis.trials[0].checkpoint.value
    checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["iter"] == 3

    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 4
def test_reuse_checkpoint(ray_start_2_cpus):
    def train_func(config):
        itr = 0
        ckpt = sgd.load_checkpoint()
        if ckpt is not None:
            itr = ckpt["iter"] + 1

        for i in range(itr, config["max_iter"]):
            sgd.save_checkpoint(iter=i)
            sgd.report(test=i, training_iteration=i)

    trainer = Trainer(TestConfig())
    TestTrainable = trainer.to_tune_trainable(train_func)

    [trial] = tune.run(TestTrainable, config={"max_iter": 5}).trials
    last_ckpt = trial.checkpoint.value
    checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["iter"] == 4

    analysis = tune.run(TestTrainable, config={"max_iter": 10}, restore=last_ckpt)
    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 5
def to_dict(self) -> dict:
    """Return checkpoint data as dictionary.

    Returns:
        dict: Dictionary containing checkpoint data.
    """
    if self._data_dict:
        # If the checkpoint data is already a dict, return
        return self._data_dict
    elif self._obj_ref:
        # If the checkpoint data is an object reference, resolve
        return ray.get(self._obj_ref)
    elif self._local_path or self._uri:
        # Else, checkpoint is either on FS or external storage
        with self.as_directory() as local_path:
            checkpoint_data_path = os.path.join(
                local_path, _DICT_CHECKPOINT_FILE_NAME
            )
            if os.path.exists(checkpoint_data_path):
                # If we are restoring a dict checkpoint, load the dict
                # from the checkpoint file.
                with open(checkpoint_data_path, "rb") as f:
                    checkpoint_data = pickle.load(f)
            else:
                data = _pack(local_path)
                checkpoint_data = {
                    _FS_CHECKPOINT_KEY: data,
                }
            return checkpoint_data
    else:
        raise RuntimeError(f"Empty data for checkpoint {self}")
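# A runnable sketch of the directory branch of ``to_dict()`` above, assuming
# ``ray.air`` is importable; the directory contents are synthetic.
import os
import tempfile

from ray.air.checkpoint import Checkpoint

src = tempfile.mkdtemp()
with open(os.path.join(src, "weights.txt"), "w") as f:
    f.write("fake-weights")

state = Checkpoint.from_directory(src).to_dict()
# The packed file tree ends up under the _FS_CHECKPOINT_KEY entry.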
def restore_policy_from_checkpoint(
    policy_class: type,
    env_creator: Callable[[Dict[str, Any]], gym.Env],
    checkpoint_path: str,
    config: Dict[str, Any],
) -> Policy:
    """Restore a policy of a given class from an RLlib trainer checkpoint.

    Args:
        policy_class: The policy class to instantiate.
        env_creator: Factory building the environment from the
            ``env_config`` entry of ``config``; used only to derive the
            observation and action spaces.
        checkpoint_path: Path to the pickled trainer checkpoint.
        config: Policy/trainer configuration.

    Returns:
        The policy with its state restored from the checkpoint.
    """
    # Load checkpoint policy state
    with open(checkpoint_path, "rb") as checkpoint_dump:
        checkpoint_state = pickle.load(checkpoint_dump)
        worker_dump = checkpoint_state["worker"]
        worker_state = pickle.loads(worker_dump)
        policy_state = worker_state["state"]["default_policy"]

    # Instantiate a temporary environment to get observation and action spaces
    env = env_creator(config.get("env_config", {}))

    # Get preprocessed observation space
    preprocessor_class = get_preprocessor(env.observation_space)
    preprocessor = preprocessor_class(env.observation_space)
    observation_space = preprocessor.observation_space

    # Instantiate policy and load checkpoint state
    policy = policy_class(observation_space, env.action_space, config)
    policy.set_state(policy_state)

    return policy
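# A small runnable sketch of the preprocessing step used above:
# ``get_preprocessor`` maps an observation space to the preprocessor class
# whose ``observation_space`` is what the policy is built against. Assumes
# ``gym`` and ``ray[rllib]`` are installed.
import gym
from ray.rllib.models.preprocessors import get_preprocessor

env = gym.make("CartPole-v1")
prep = get_preprocessor(env.observation_space)(env.observation_space)
print(prep.observation_space)  # the (possibly flattened) space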
def restore(self, checkpoint_path: str):
    with open(checkpoint_path, "rb") as f:
        save_object = cloudpickle.load(f)

    numpy_random_state = save_object.pop("_random_state_seed_to_set", None)
    self.__dict__.update(save_object)

    if numpy_random_state is not None:
        np.random.set_state(numpy_random_state)
def restore(self, checkpoint_path):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``_restore()`` instead to restore state.
    This method restores additional metadata saved with the checkpoint.
    """
    # Maybe sync from cloud
    if self.uses_cloud_checkpointing:
        self.storage_client.sync_down(self.remote_checkpoint_dir, self.logdir)
        self.storage_client.wait()

    # Ensure TrialCheckpoints are converted
    if isinstance(checkpoint_path, TrialCheckpoint):
        checkpoint_path = checkpoint_path.local_path

    with open(checkpoint_path + ".tune_metadata", "rb") as f:
        metadata = pickle.load(f)
    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]
    saved_as_dict = metadata["saved_as_dict"]
    if saved_as_dict:
        with open(checkpoint_path, "rb") as loaded_state:
            checkpoint_dict = pickle.load(loaded_state)
        checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
        self.load_checkpoint(checkpoint_dict)
    else:
        self.load_checkpoint(checkpoint_path)
    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True
    logger.info(
        "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)
def restore(self, checkpoint_path: str):
    with open(checkpoint_path, "rb") as f:
        save_object = pickle.load(f)

    if not isinstance(save_object, dict):
        # Backwards compatibility: older checkpoints stored the raw
        # optimizer object instead of a state dict.
        # Deprecate: 1.8
        self.optimizer = save_object
        return

    self.__dict__.update(save_object)
def _find_newest_ckpt(dirpath: str, pattern: str):
    """Loads the state from the newest checkpoint in ``dirpath``.

    "Newest" is decided by ``max()`` over the matching file names, i.e. by
    lexicographic order, not by modification time.
    """
    full_paths = glob.glob(os.path.join(dirpath, pattern))
    if not full_paths:
        return None
    most_recent_checkpoint = max(full_paths)
    with open(most_recent_checkpoint, "rb") as f:
        search_alg_state = cloudpickle.load(f)
    return search_alg_state
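# Because "newest" is lexicographic, the scheme only works when file names
# sort in creation order (e.g. zero-padded indices). A hypothetical
# illustration:
assert max(["state_0009.pkl", "state_0010.pkl"]) == "state_0010.pkl"
# Without zero-padding, lexicographic order diverges from numeric order:
assert max(["state_9.pkl", "state_10.pkl"]) == "state_9.pkl"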
def to_dict(self) -> dict:
    """Return checkpoint data as dictionary.

    Returns:
        dict: Dictionary containing checkpoint data.
    """
    if self._data_dict:
        # If the checkpoint data is already a dict, return
        return self._data_dict
    elif self._obj_ref:
        # If the checkpoint data is an object reference, resolve
        return ray.get(self._obj_ref)
    elif self._local_path or self._uri:
        # Else, checkpoint is either on FS or external storage
        with self.as_directory() as local_path:
            checkpoint_data_path = os.path.join(
                local_path, _DICT_CHECKPOINT_FILE_NAME
            )
            if os.path.exists(checkpoint_data_path):
                # If we are restoring a dict checkpoint, load the dict
                # from the checkpoint file.
                with open(checkpoint_data_path, "rb") as f:
                    checkpoint_data = pickle.load(f)
            else:
                files = [
                    f
                    for f in os.listdir(local_path)
                    if os.path.isfile(os.path.join(local_path, f))
                    and f.endswith(_METADATA_CHECKPOINT_SUFFIX)
                ]
                metadata = {}
                for file in files:
                    with open(os.path.join(local_path, file), "rb") as f:
                        key = file[: -len(_METADATA_CHECKPOINT_SUFFIX)]
                        value = pickle.load(f)
                        metadata[key] = value

                data = _pack(local_path)

                checkpoint_data = {
                    _FS_CHECKPOINT_KEY: data,
                }
                checkpoint_data.update(metadata)

            return checkpoint_data
    else:
        raise RuntimeError(f"Empty data for checkpoint {self}")
def load_preprocessor_from_dir(
    parent_dir: os.PathLike,
) -> Optional["Preprocessor"]:
    """Loads preprocessor from directory, if file exists."""
    parent_dir = Path(parent_dir)
    preprocessor_path = parent_dir.joinpath(PREPROCESSOR_KEY)
    if preprocessor_path.exists():
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None
    return preprocessor
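# A runnable sketch of the helper above; the pickled object is a synthetic
# stand-in, since any cpickle-serializable object round-trips the same way.
import os
import tempfile

import ray.cloudpickle as cpickle
from ray.air.constants import PREPROCESSOR_KEY

empty_dir = tempfile.mkdtemp()
assert load_preprocessor_from_dir(empty_dir) is None  # no file present

with open(os.path.join(empty_dir, PREPROCESSOR_KEY), "wb") as f:
    cpickle.dump({"stand_in": "preprocessor"}, f)
assert load_preprocessor_from_dir(empty_dir) == {"stand_in": "preprocessor"}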
def load_checkpoint_metadata(checkpoint_path: str) -> Optional[Dict]:
    metadata_path = os.path.join(checkpoint_path, ".tune_metadata")
    if not os.path.exists(metadata_path):
        # Fall back to searching the whole checkpoint directory tree.
        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        metadatas = glob.glob(f"{checkpoint_dir}/**/.tune_metadata", recursive=True)
        if not metadatas:
            return None
        metadata_path = metadatas[0]

    with open(metadata_path, "rb") as f:
        return pickle.load(f)
def _restore(self, checkpoint):
    """Loads a checkpoint created from `save`.

    Args:
        checkpoint (str): file path to pickled checkpoint file.
    """
    if self.pickled:
        with open(checkpoint, "rb") as f:
            self.estimator = cpickle.load(f)
    else:
        warnings.warn("No estimator restored")
def _restore(self, checkpoint):
    """Loads a checkpoint created from `save`.

    Args:
        checkpoint (str): file path to pickled checkpoint file.
    """
    try:
        with open(checkpoint, "rb") as f:
            self.estimator_list = cpickle.load(f)
    except Exception:
        warnings.warn("No estimator restored", category=RuntimeWarning)
def load_from_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[xgb.Booster, Optional[Preprocessor]]:
    checkpoint_path = checkpoint.to_directory()

    xgb_model = xgb.Booster()
    xgb_model.load_model(os.path.join(checkpoint_path, MODEL_KEY))

    preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY)
    if os.path.exists(preprocessor_path):
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None

    return xgb_model, preprocessor
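# Hedged counterpart sketch: what writes MODEL_KEY into such a checkpoint
# directory. The training data here is synthetic, and MODEL_KEY is assumed
# to be the same file-name constant the loader above reads.
import os
import tempfile

import numpy as np
import xgboost as xgb

ckpt_dir = tempfile.mkdtemp()
X, y = np.random.rand(16, 4), np.random.randint(2, size=16)
bst = xgb.train(
    {"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2
)
bst.save_model(os.path.join(ckpt_dir, MODEL_KEY))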
def get_checkpoints_paths(logdir):
    """Finds the checkpoints within a specific folder.

    Returns a pandas DataFrame of training iterations and checkpoint
    paths within a specific folder.

    Raises:
        FileNotFoundError: If the directory is not found.
    """
    marker_paths = glob.glob(
        os.path.join(glob.escape(logdir), "checkpoint_*/.is_checkpoint")
    )
    iter_chkpt_pairs = []
    for marker_path in marker_paths:
        chkpt_dir = os.path.dirname(marker_path)

        # Skip temporary checkpoints
        if os.path.basename(chkpt_dir).startswith("checkpoint_tmp"):
            continue

        metadata_file = glob.glob(
            os.path.join(glob.escape(chkpt_dir), "*.tune_metadata")
        )
        # glob.glob: filenames starting with a dot are special cases
        # that are not matched by '*' and '?' patterns.
        metadata_file += glob.glob(
            os.path.join(glob.escape(chkpt_dir), ".tune_metadata")
        )
        metadata_file = list(set(metadata_file))  # avoid duplication
        if len(metadata_file) != 1:
            raise ValueError(
                f"{chkpt_dir} must contain exactly one `.tune_metadata` file."
            )
        metadata_file = metadata_file[0]

        try:
            with open(metadata_file, "rb") as f:
                metadata = pickle.load(f)
        except Exception as e:
            logger.warning(f"Could not read metadata from checkpoint: {e}")
            metadata = {}

        chkpt_path = metadata_file[: -len(".tune_metadata")]
        chkpt_iter = metadata.get("iteration", -1)
        iter_chkpt_pairs.append([chkpt_iter, chkpt_path])

    chkpt_df = pd.DataFrame(
        iter_chkpt_pairs, columns=["training_iteration", "chkpt_path"]
    )
    return chkpt_df
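# A runnable, synthetic illustration of the layout get_checkpoints_paths
# expects (file names and metadata values are hypothetical):
#
#   logdir/checkpoint_000001/.is_checkpoint
#   logdir/checkpoint_000001/checkpoint.tune_metadata
import os
import pickle
import tempfile

logdir = tempfile.mkdtemp()
chkpt_dir = os.path.join(logdir, "checkpoint_000001")
os.makedirs(chkpt_dir)
open(os.path.join(chkpt_dir, ".is_checkpoint"), "w").close()
with open(os.path.join(chkpt_dir, "checkpoint.tune_metadata"), "wb") as f:
    pickle.dump({"iteration": 1}, f)

print(get_checkpoints_paths(logdir))
# One row: training_iteration=1, chkpt_path=<logdir>/checkpoint_000001/checkpoint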
def test_tune_checkpoint(ray_start_2_cpus):
    def train_func():
        for i in range(10):
            train.report(test=i)
        train.save_checkpoint(hello="world")

    trainer = Trainer(TestConfig(), num_workers=1)
    TestTrainable = trainer.to_tune_trainable(train_func)

    [trial] = tune.run(TestTrainable).trials
    checkpoint_file = os.path.join(trial.checkpoint.value, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["hello"] == "world"
def _load_checkpoint(
    self, checkpoint_to_load: Optional[Union[Dict, str, Path]]
) -> Optional[Dict]:
    """Load the checkpoint dictionary from the input dict or path."""
    if checkpoint_to_load is None:
        return None
    if isinstance(checkpoint_to_load, Dict):
        return checkpoint_to_load
    else:
        # Load checkpoint from path.
        checkpoint_path = Path(checkpoint_to_load).expanduser()
        if not checkpoint_path.exists():
            raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.")
        with checkpoint_path.open("rb") as f:
            return cloudpickle.load(f)
def to_dict(self) -> dict:
    """Return checkpoint data as dictionary.

    Returns:
        dict: Dictionary containing checkpoint data.
    """
    if self._data_dict:
        # If the checkpoint data is already a dict, return
        return self._data_dict
    elif self._obj_ref:
        # If the checkpoint data is an object reference, resolve
        return ray.get(self._obj_ref)
    elif self._local_path or self._uri:
        # Else, checkpoint is either on FS or external storage
        cleanup = False

        local_path = self._local_path
        if not local_path:
            # Checkpoint does not exist on local path. Save
            # in temporary directory, but clean up later
            local_path = self.to_directory()
            cleanup = True

        checkpoint_data_path = os.path.join(local_path, _DICT_CHECKPOINT_FILE_NAME)
        if os.path.exists(checkpoint_data_path):
            # If we are restoring a dict checkpoint, load the dict
            # from the checkpoint file.
            with open(checkpoint_data_path, "rb") as f:
                checkpoint_data = pickle.load(f)
        else:
            data = _pack(local_path)
            checkpoint_data = {
                _FS_CHECKPOINT_KEY: data,
            }

        if cleanup:
            shutil.rmtree(local_path)

        return checkpoint_data
    else:
        raise RuntimeError(f"Empty data for checkpoint {self}")
def load_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[BaseEstimator, Optional[Preprocessor]]:
    """Load a Checkpoint from ``SklearnTrainer``.

    Args:
        checkpoint: The checkpoint to load the estimator and preprocessor
            from. It is expected to be from the result of a
            ``SklearnTrainer`` run.

    Returns:
        The estimator and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
        with open(estimator_path, "rb") as f:
            estimator = cpickle.load(f)
        preprocessor = load_preprocessor_from_dir(checkpoint_path)
        return estimator, preprocessor
def from_checkpoint(cls, checkpoint: Checkpoint) -> "XGBoostPredictor":
    """Instantiate the predictor from a Checkpoint.

    The checkpoint is expected to be a result of ``XGBoostTrainer``.

    Args:
        checkpoint (Checkpoint): The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of an
            ``XGBoostTrainer`` run.
    """
    with checkpoint.as_directory() as path:
        bst = xgboost.Booster()
        bst.load_model(os.path.join(path, MODEL_KEY))

        preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
        if os.path.exists(preprocessor_path):
            with open(preprocessor_path, "rb") as f:
                preprocessor = cpickle.load(f)
        else:
            preprocessor = None

    return XGBoostPredictor(model=bst, preprocessor=preprocessor)
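# Hedged usage sketch; ``result`` is assumed to come from an
# ``XGBoostTrainer.fit()`` run and ``test_df`` is a hypothetical pandas
# DataFrame, so this does not run standalone:
#
#     predictor = XGBoostPredictor.from_checkpoint(result.checkpoint)
#     predictions = predictor.predict(test_df)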
def from_checkpoint(cls, checkpoint: Checkpoint) -> "LightGBMPredictor":
    """Instantiate the predictor from a Checkpoint.

    The checkpoint is expected to be a result of ``LightGBMTrainer``.

    Args:
        checkpoint (Checkpoint): The checkpoint to load the model and
            preprocessor from. It is expected to be from the result of a
            ``LightGBMTrainer`` run.
    """
    path = checkpoint.to_directory()
    bst = lightgbm.Booster(model_file=os.path.join(path, MODEL_KEY))

    preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
    if os.path.exists(preprocessor_path):
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None

    # Clean up the materialized checkpoint directory.
    shutil.rmtree(path)
    return LightGBMPredictor(model=bst, preprocessor=preprocessor)
def load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> dict:
    """Returns the state of the most recently saved checkpoint.

    Assumes files are saved with an ordered name, most likely by
    :obj:`atomic_save`.

    Args:
        dirpath (str): Directory in which to look for the checkpoint file.
        ckpt_pattern (str): File name pattern to match to find checkpoint
            files.

    Returns:
        (dict) Deserialized state dict.
    """
    import ray.cloudpickle as cloudpickle

    full_paths = glob.glob(os.path.join(dirpath, ckpt_pattern))
    if not full_paths:
        return None
    most_recent_checkpoint = max(full_paths)
    with open(most_recent_checkpoint, "rb") as f:
        checkpoint_state = cloudpickle.load(f)
    return checkpoint_state
def restore(self, checkpoint_path: str):
    with open(checkpoint_path, "rb") as f:  # avoid shadowing builtin `input`
        trials_object = pickle.load(f)
    self.optimizer = trials_object
    },
    preprocessor=preprocessor,
)
result = trainer.fit()
# __trainer_end__

# __checkpoint_start__
import os

import ray.cloudpickle as cpickle
from ray.air.constants import PREPROCESSOR_KEY

checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_path:
    path = os.path.join(checkpoint_path, PREPROCESSOR_KEY)
    with open(path, "rb") as f:
        preprocessor = cpickle.load(f)

print(preprocessor)
# MinMaxScaler(columns=['x'], stats={'min(x)': 0, 'max(x)': 30})
# __checkpoint_end__

# __predictor_start__
from ray.train.batch_predictor import BatchPredictor
from ray.train.xgboost import XGBoostPredictor

test_dataset = ray.data.from_items([{"x": x} for x in range(2, 32, 3)])

batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)

predicted_probabilities = batch_predictor.predict(test_dataset)
predicted_probabilities.show()
# {'predictions': 0.09843720495700836}
# {'predictions': 5.604666709899902}
def restore(self, checkpoint_path):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``load_checkpoint()`` instead to restore
    state. This method restores additional metadata saved with the
    checkpoint.

    `checkpoint_path` should match the return value of ``save()``.
    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`
    or `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
    `self.logdir` should generally correspond to `checkpoint_path`, for
    example, `~/ray_results/exp/MyTrainable_abc`.
    `self.remote_checkpoint_dir`, in this case, is something like
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
    """
    if self.uses_cloud_checkpointing:
        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path
        )
        self.storage_client.sync_down(
            os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
            os.path.join(self.logdir, rel_checkpoint_dir),
        )
        self.storage_client.wait_or_retry()

    # Ensure TrialCheckpoints are converted
    if isinstance(checkpoint_path, TrialCheckpoint):
        checkpoint_path = checkpoint_path.local_path

    with open(checkpoint_path + ".tune_metadata", "rb") as f:
        metadata = pickle.load(f)
    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]
    saved_as_dict = metadata["saved_as_dict"]
    if saved_as_dict:
        with open(checkpoint_path, "rb") as loaded_state:
            checkpoint_dict = pickle.load(loaded_state)
        checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
        self.load_checkpoint(checkpoint_dict)
    else:
        self.load_checkpoint(checkpoint_path)
    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True
    logger.info(
        "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)
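# The metadata read above lives in a sibling file: for a checkpoint at
# <dir>/checkpoint_00000/checkpoint, Tune writes
# <dir>/checkpoint_00000/checkpoint.tune_metadata next to it. A synthetic,
# runnable illustration of that layout (field values are hypothetical):
import os
import pickle
import tempfile

chkpt = os.path.join(tempfile.mkdtemp(), "checkpoint")
with open(chkpt, "wb") as f:
    pickle.dump({"model_state": None}, f)
with open(chkpt + ".tune_metadata", "wb") as f:
    pickle.dump({"iteration": 7, "saved_as_dict": True}, f)

with open(chkpt + ".tune_metadata", "rb") as f:
    print(pickle.load(f)["iteration"])  # 7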