def testUriCheckpointSerde(self):
    # URI checkpoints keep the same internal representation, pointing to
    # a remote location
    checkpoint = Checkpoint.from_uri("s3://some/bucket")

    self._testCheckpointSerde(
        checkpoint, *checkpoint.get_internal_representation()
    )
def get_best_checkpoint(
    self, trial: Trial, metric: Optional[str] = None, mode: Optional[str] = None
) -> Optional[Checkpoint]:
    """Gets the best persistent checkpoint of the provided trial.

    Any checkpoints with an associated metric value of ``nan`` will be
    filtered out.

    Args:
        trial: The log directory of a trial, or a trial instance.
        metric: Key of trial info to return, e.g. "mean_accuracy".
            "training_iteration" is used by default if no value was
            passed to ``self.default_metric``.
        mode: One of [min, max]. Defaults to ``self.default_mode``.

    Returns:
        :class:`Checkpoint <ray.ml.Checkpoint>` object.
    """
    metric = metric or self.default_metric or TRAINING_ITERATION
    mode = self._validate_mode(mode)

    checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

    # Filter out nan. Sorting nan values leads to undefined behavior.
    checkpoint_paths = [
        (path, metric) for path, metric in checkpoint_paths if not is_nan(metric)
    ]

    if not checkpoint_paths:
        logger.error(f"No checkpoints have been found for trial {trial}.")
        return None

    a = -1 if mode == "max" else 1
    best_path_metrics = sorted(checkpoint_paths, key=lambda x: a * x[1])

    best_path, best_metric = best_path_metrics[0]
    cloud_path = self._parse_cloud_path(best_path)

    if self._legacy_checkpoint:
        return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)

    if cloud_path:
        # Prefer the cloud path over the local path for downstream processing
        return Checkpoint.from_uri(cloud_path)
    elif os.path.exists(best_path):
        return Checkpoint.from_directory(best_path)
    else:
        logger.error(
            f"No checkpoint locations for {trial} available on "
            f"this node. To avoid this, you "
            f"should enable checkpoint synchronization with the "
            f"`sync_config` argument in Ray Tune. "
            f"The checkpoint may be available on a different node - "
            f"please check this location on worker nodes: {best_path}"
        )
        return None
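# --- Usage sketch (illustrative, not part of the method above) ---
# A minimal example of how `get_best_checkpoint` might be called, assuming
# `analysis` is the ExperimentAnalysis returned by a finished `tune.run(...)`
# and that trials reported a "mean_accuracy" metric (both are assumptions).
best_trial = analysis.get_best_trial(metric="mean_accuracy", mode="max")
best_checkpoint = analysis.get_best_checkpoint(
    best_trial, metric="mean_accuracy", mode="max"
)
if best_checkpoint is not None:
    # Materialize the checkpoint locally for downstream processing.
    local_checkpoint_dir = best_checkpoint.to_directory()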
def test_fs_delete_at_uri(self):
    """Test that the clear-bucket utility works."""
    checkpoint = self._prepare_fs_checkpoint()

    # Convert into a URI checkpoint
    location = checkpoint.to_uri(self.cloud_uri)

    delete_at_uri(location)

    checkpoint = Checkpoint.from_uri(location)
    with self.assertRaises(FileNotFoundError):
        checkpoint.to_directory()
def test_fs_checkpoint_uri(self):
    """Test conversion from an fs checkpoint to a cloud checkpoint and back."""
    checkpoint = self._prepare_fs_checkpoint()

    # Convert into a URI checkpoint
    location = checkpoint.to_uri(self.cloud_uri)
    self.assertIsInstance(location, str)
    self.assertIn("memory://", location)

    # Create from the URI
    checkpoint = Checkpoint.from_uri(location)
    self.assertTrue(checkpoint._uri)

    self._assert_fs_checkpoint(checkpoint)
def test_fs_checkpoint_uri(self):
    """Test conversion from an fs checkpoint to a cloud checkpoint and back."""
    checkpoint = self._prepare_fs_checkpoint()

    with patch("subprocess.check_call", self.mock_s3):
        # Convert into a URI checkpoint
        location = checkpoint.to_uri(self.cloud_uri)
        self.assertIsInstance(location, str)
        self.assertIn("s3://", location)

        # Create from the URI
        checkpoint = Checkpoint.from_uri(location)
        self.assertTrue(checkpoint._uri)

        self._assert_fs_checkpoint(checkpoint)
def test_fs_checkpoint_uri_pa(self):
    """Test conversion from an fs checkpoint to a cloud checkpoint and back."""
    checkpoint = self._prepare_fs_checkpoint()

    # Clean up the mock bucket
    delete_at_uri(self.cloud_uri_pa)
    _ensure_directory(self.cloud_uri_pa)

    # Convert into a URI checkpoint
    location = checkpoint.to_uri(self.cloud_uri_pa)
    self.assertIsInstance(location, str)
    self.assertIn("mock://", location)

    # Create from the URI
    checkpoint = Checkpoint.from_uri(location)
    self.assertTrue(checkpoint._uri)

    self._assert_fs_checkpoint(checkpoint)
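# --- Usage sketch (illustrative, not part of the tests above) ---
# The three tests above exercise the same `to_uri()`/`from_uri()` round trip
# over different storage backends. A minimal sketch of that pattern, assuming
# a writable bucket URI and a local model directory (both paths are made up):
from ray.ml import Checkpoint

checkpoint = Checkpoint.from_directory("/tmp/my_model_dir")
uri = checkpoint.to_uri("s3://my-bucket/checkpoints/run_001")

# On any node with access to the bucket, the same contents can be recovered.
restored = Checkpoint.from_uri(uri)
local_copy = restored.to_directory()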
def restore(
    self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None
):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to save().

    Subclasses should override ``load_checkpoint()`` instead to
    restore state. This method restores additional metadata saved with
    the checkpoint.

    `checkpoint_path` should match the return value of ``save()``.
    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`
    or `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
    `self.logdir` should generally correspond to `checkpoint_path`,
    for example, `~/ray_results/exp/MyTrainable_abc`.
    `self.remote_checkpoint_dir` in this case is something like
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

    Args:
        checkpoint_path: Path to restore checkpoint from. If this
            path does not exist on the local node, it will be fetched
            from external (cloud) storage if available, or restored
            from a remote node.
        checkpoint_node_ip: If given, try to restore the checkpoint from
            this node if it doesn't exist locally or on cloud storage.
    """
    # Ensure TrialCheckpoints are converted
    if isinstance(checkpoint_path, TrialCheckpoint):
        checkpoint_path = checkpoint_path.local_path

    if self.uses_cloud_checkpointing:
        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path
        )
        external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
        local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

        if self.storage_client:
            # Only keep for backwards compatibility
            self.storage_client.sync_down(external_uri, local_dir)
            self.storage_client.wait_or_retry()
        else:
            checkpoint = Checkpoint.from_uri(external_uri)
            retry_fn(
                lambda: checkpoint.to_directory(local_dir),
                subprocess.CalledProcessError,
                num_retries=3,
                sleep_time=1,
            )
    elif (
        # If a checkpoint source IP is given
        checkpoint_node_ip
        # And the checkpoint does not currently exist on the local node
        and not os.path.exists(checkpoint_path)
        # And the source IP is different to the current IP
        and checkpoint_node_ip != ray.util.get_node_ip_address()
    ):
        checkpoint = get_checkpoint_from_remote_node(
            checkpoint_path, checkpoint_node_ip
        )
        if checkpoint:
            checkpoint.to_directory(checkpoint_path)

    with open(checkpoint_path + ".tune_metadata", "rb") as f:
        metadata = pickle.load(f)

    self._experiment_id = metadata["experiment_id"]
    self._iteration = metadata["iteration"]
    self._timesteps_total = metadata["timesteps_total"]
    self._time_total = metadata["time_total"]
    self._episodes_total = metadata["episodes_total"]
    saved_as_dict = metadata["saved_as_dict"]

    if saved_as_dict:
        with open(checkpoint_path, "rb") as loaded_state:
            checkpoint_dict = pickle.load(loaded_state)
        checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
        self.load_checkpoint(checkpoint_dict)
    else:
        self.load_checkpoint(checkpoint_path)

    self._time_since_restore = 0.0
    self._timesteps_since_restore = 0
    self._iterations_since_restore = 0
    self._restored = True

    logger.info(
        "Restored on %s from checkpoint: %s",
        self.get_current_ip(),
        checkpoint_path,
    )
    state = {
        "_iteration": self._iteration,
        "_timesteps_total": self._timesteps_total,
        "_time_total": self._time_total,
        "_episodes_total": self._episodes_total,
    }
    logger.info("Current state after restoring: %s", state)
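# --- Usage sketch (illustrative, not part of the method above) ---
# A minimal save()/restore() round trip, assuming a custom Trainable
# subclass; `MyTrainable` and its checkpoint contents are hypothetical.
from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.counter = 0

    def step(self):
        self.counter += 1
        return {"mean_accuracy": self.counter}

    def save_checkpoint(self, checkpoint_dir):
        # Returning a dict makes Tune persist it as a pickled dict
        # (the `saved_as_dict` branch in `restore()` above).
        return {"counter": self.counter}

    def load_checkpoint(self, checkpoint):
        self.counter = checkpoint["counter"]


trainable = MyTrainable()
trainable.train()
checkpoint_path = trainable.save()

# Later, possibly in a fresh process, restore state from that path.
restored = MyTrainable()
restored.restore(checkpoint_path)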