Example #1
    def __init__(
        self,
        restore_path: Optional[str] = None,
        trainable: Optional[Union[str, Callable, Type[Trainable],
                                  BaseTrainer]] = None,
        param_space: Optional[Dict[str, Any]] = None,
        tune_config: Optional[TuneConfig] = None,
        run_config: Optional[RunConfig] = None,
        _tuner_kwargs: Optional[Dict] = None,
    ):
        # Restored from Tuner checkpoint.
        if restore_path:
            trainable_ckpt = os.path.join(restore_path, _TRAINABLE_PKL)
            with open(trainable_ckpt, "rb") as fp:
                trainable = pickle.load(fp)

            tuner_ckpt = os.path.join(restore_path, _TUNER_PKL)
            with open(tuner_ckpt, "rb") as fp:
                tuner = pickle.load(fp)
                self.__dict__.update(tuner.__dict__)

            self._is_restored = True
            self._trainable = trainable
            self._experiment_checkpoint_dir = restore_path
            return

        # Start fresh
        if not trainable:
            raise TuneError("You need to provide a trainable to tune.")

        # If no run config was passed to Tuner directly, use the one from the Trainer,
        # if available
        if not run_config and isinstance(trainable, BaseTrainer):
            run_config = trainable.run_config

        self._is_restored = False
        self._trainable = trainable
        self._tune_config = tune_config or TuneConfig()
        self._run_config = run_config or RunConfig()
        self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
        self._experiment_checkpoint_dir = self._setup_create_experiment_checkpoint_dir(
            self._run_config)

        # Not used for restored Tuner.
        self._param_space = param_space or {}

        # This has to happen before `tune.run()` is kicked off. Tune currently
        # does not exit gracefully when run in Ray client mode: if a crash
        # occurs, it exits immediately without a chance to checkpoint the
        # tuner and trainable. Writing them out here ensures there is always
        # something to restore from.
        tuner_ckpt = os.path.join(self._experiment_checkpoint_dir, _TUNER_PKL)
        with open(tuner_ckpt, "wb") as fp:
            pickle.dump(self, fp)

        trainable_ckpt = os.path.join(self._experiment_checkpoint_dir,
                                      _TRAINABLE_PKL)
        with open(trainable_ckpt, "wb") as fp:
            pickle.dump(self._trainable, fp)
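The constructor above saves the tuner by pickling the whole object and restores it via `self.__dict__.update`. A minimal, self-contained sketch of that save/restore pattern, using a hypothetical `MyState` class and stdlib pickle for illustration:

import pickle

class MyState:
    def __init__(self, value: int = 0):
        self.value = value

    def save(self, path: str) -> None:
        # Persist the whole instance so restore() can adopt its state.
        with open(path, "wb") as fp:
            pickle.dump(self, fp)

    @classmethod
    def restore(cls, path: str) -> "MyState":
        obj = cls()
        with open(path, "rb") as fp:
            saved = pickle.load(fp)
        # Same trick as Tuner.__init__: adopt the saved instance's state.
        obj.__dict__.update(saved.__dict__)
        return obj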
Example #2
def load_checkpoint(
    checkpoint: Checkpoint,
    env: Optional[EnvType] = None,
) -> Tuple[Policy, Optional[Preprocessor]]:
    """Load a Checkpoint from ``RLTrainer``.

    Args:
        checkpoint: The checkpoint to load the policy and
            preprocessor from. It is expected to be from the result of a
            ``RLTrainer`` run.
        env: Optional environment to instantiate the trainer with. If not given,
            it is parsed from the saved trainer configuration instead.

    Returns:
        The policy and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
        config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

        if not os.path.exists(trainer_class_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_TRAINER_CLASS_FILE}` file."
            )

        if not os.path.exists(config_path):
            raise ValueError(
                f"RLPredictor only works with checkpoints created by "
                f"RLTrainer. The checkpoint you specified is missing the "
                f"`{RL_CONFIG_FILE}` file."
            )

        with open(trainer_class_path, "rb") as fp:
            trainer_cls = cpickle.load(fp)

        with open(config_path, "rb") as fp:
            config = cpickle.load(fp)

        checkpoint_data_path = None
        for file in os.listdir(checkpoint_path):
            if file.startswith("checkpoint") and not file.endswith(".tune_metadata"):
                checkpoint_data_path = os.path.join(checkpoint_path, file)

        if not checkpoint_data_path:
            raise ValueError(
                f"Could not find checkpoint data in RLlib checkpoint. "
                f"Found files: {list(os.listdir(checkpoint_path))}"
            )

        preprocessor = load_preprocessor_from_dir(checkpoint_path)

        config.get("evaluation_config", {}).pop("in_evaluation", None)
        trainer = trainer_cls(config=config, env=env)
        trainer.restore(checkpoint_data_path)

        policy = trainer.get_policy()
        return policy, preprocessor
Example #3
    def get_policy(self, env: Optional[EnvType] = None) -> Policy:
        """Retrieve the policy stored in this checkpoint.

        Args:
            env: Optional environment to instantiate the trainer with. If not given,
                it is parsed from the saved trainer configuration.

        Returns:
            The policy stored in this checkpoint.
        """
        with self.as_directory() as checkpoint_path:
            trainer_class_path = os.path.join(checkpoint_path,
                                              RL_TRAINER_CLASS_FILE)
            config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

            if not os.path.exists(trainer_class_path):
                raise ValueError(
                    f"RLPredictor only works with checkpoints created by "
                    f"RLTrainer. The checkpoint you specified is missing the "
                    f"`{RL_TRAINER_CLASS_FILE}` file.")

            if not os.path.exists(config_path):
                raise ValueError(
                    f"RLPredictor only works with checkpoints created by "
                    f"RLTrainer. The checkpoint you specified is missing the "
                    f"`{RL_CONFIG_FILE}` file.")

            with open(trainer_class_path, "rb") as fp:
                trainer_cls = cpickle.load(fp)

            with open(config_path, "rb") as fp:
                config = cpickle.load(fp)

            checkpoint_data_path = None
            for file in os.listdir(checkpoint_path):
                if file.startswith(
                        "checkpoint") and not file.endswith(".tune_metadata"):
                    checkpoint_data_path = os.path.join(checkpoint_path, file)

            if not checkpoint_data_path:
                raise ValueError(
                    f"Could not find checkpoint data in RLlib checkpoint. "
                    f"Found files: {list(os.listdir(checkpoint_path))}")

            config.get("evaluation_config", {}).pop("in_evaluation", None)
            trainer = trainer_cls(config=config, env=env)
            trainer.restore(checkpoint_data_path)

            return trainer.get_policy()
Example #4
def load_from_checkpoint(
    checkpoint: Checkpoint,
) -> Tuple[RandomForestClassifier, Optional[Preprocessor]]:
    path = checkpoint.to_directory()
    estimator_path = os.path.join(path, MODEL_KEY)
    with open(estimator_path, "rb") as f:
        estimator = cpickle.load(f)
    preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
    if os.path.exists(preprocessor_path):
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None

    return estimator, preprocessor
Example #5
def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict:
    """Utility function to load a checkpoint Dict from a path."""
    checkpoint_path = Path(checkpoint_to_load).expanduser()
    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.")
    with checkpoint_path.open("rb") as f:
        return cloudpickle.load(f)
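A checkpoint readable by this helper is produced with the matching dump call. A small round-trip sketch (file name and payload are illustrative):

from pathlib import Path
import ray.cloudpickle as cloudpickle

state = {"iter": 3}  # example payload
path = Path("~/my_checkpoint.pkl").expanduser()
with path.open("wb") as f:
    cloudpickle.dump(state, f)
assert load_checkpoint_from_path(path)["iter"] == 3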
Example #6
def test_retry(ray_start_2_cpus):
    def train_func():
        ckpt = sgd.load_checkpoint()
        restored = bool(ckpt)  # Does a previous checkpoint exist?
        itr = 0
        if ckpt:
            itr = ckpt["iter"] + 1

        for i in range(itr, 4):
            if i == 2 and not restored:
                raise Exception("try to fail me")
            sgd.save_checkpoint(iter=i)
            sgd.report(test=i, training_iteration=i)

    trainer = Trainer(TestConfig())
    TestTrainable = trainer.to_tune_trainable(train_func)

    analysis = tune.run(TestTrainable, max_failures=3)
    last_ckpt = analysis.trials[0].checkpoint.value
    checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["iter"] == 3
    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 4
Example #7
def test_reuse_checkpoint(ray_start_2_cpus):
    def train_func(config):
        itr = 0
        ckpt = sgd.load_checkpoint()
        if ckpt is not None:
            itr = ckpt["iter"] + 1

        for i in range(itr, config["max_iter"]):
            sgd.save_checkpoint(iter=i)
            sgd.report(test=i, training_iteration=i)

    trainer = Trainer(TestConfig())
    TestTrainable = trainer.to_tune_trainable(train_func)

    [trial] = tune.run(TestTrainable, config={"max_iter": 5}).trials
    last_ckpt = trial.checkpoint.value
    checkpoint_file = os.path.join(last_ckpt, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["iter"] == 4
    analysis = tune.run(TestTrainable,
                        config={"max_iter": 10},
                        restore=last_ckpt)
    trial_dfs = list(analysis.trial_dataframes.values())
    assert len(trial_dfs[0]["training_iteration"]) == 5
Example #8
    def to_dict(self) -> dict:
        """Return checkpoint data as dictionary.

        Returns:
            dict: Dictionary containing checkpoint data.
        """
        if self._data_dict:
            # If the checkpoint data is already a dict, return
            return self._data_dict
        elif self._obj_ref:
            # If the checkpoint data is an object reference, resolve
            return ray.get(self._obj_ref)
        elif self._local_path or self._uri:
            # Else, checkpoint is either on FS or external storage
            with self.as_directory() as local_path:
                checkpoint_data_path = os.path.join(
                    local_path, _DICT_CHECKPOINT_FILE_NAME)
                if os.path.exists(checkpoint_data_path):
                    # If we are restoring a dict checkpoint, load the dict
                    # from the checkpoint file.
                    with open(checkpoint_data_path, "rb") as f:
                        checkpoint_data = pickle.load(f)
                else:
                    data = _pack(local_path)

                    checkpoint_data = {
                        _FS_CHECKPOINT_KEY: data,
                    }
                return checkpoint_data
        else:
            raise RuntimeError(f"Empty data for checkpoint {self}")
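The `_pack` helper above bundles the checkpoint directory into a single bytes payload. A minimal stand-in under that assumption (an in-memory tar archive; not necessarily Ray's exact implementation):

import io
import tarfile

def _pack_sketch(path: str) -> bytes:
    # Tar the directory into an in-memory buffer and return the raw bytes.
    stream = io.BytesIO()
    with tarfile.open(fileobj=stream, mode="w") as tar:
        tar.add(path, arcname="")
    return stream.getvalue()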
Example #9
def restore_policy_from_checkpoint(
        policy_class: type,
        env_creator: Callable[[Dict[str, Any]], gym.Env],
        checkpoint_path: str,
        config: Dict[str, Any]) -> Policy:
    """ TODO: Write documentation
    """
    # Load checkpoint policy state
    with open(checkpoint_path, "rb") as checkpoint_dump:
        checkpoint_state = pickle.load(checkpoint_dump)
        worker_dump = checkpoint_state['worker']
        worker_state = pickle.loads(worker_dump)
        policy_state = worker_state['state']['default_policy']

    # Initiate temporary environment to get observation and action spaces
    env = env_creator(config.get("env_config", {}))

    # Get preprocessed observation space
    preprocessor_class = get_preprocessor(env.observation_space)
    preprocessor = preprocessor_class(env.observation_space)
    observation_space = preprocessor.observation_space

    # Instantiate policy and load checkpoint state
    policy = policy_class(observation_space, env.action_space, config)
    policy.set_state(policy_state)

    return policy
Example #10
    def restore(self, checkpoint_path: str):
        with open(checkpoint_path, "rb") as f:
            save_object = cloudpickle.load(f)
        numpy_random_state = save_object.pop("_random_state_seed_to_set", None)
        self.__dict__.update(save_object)
        if numpy_random_state is not None:
            np.random.set_state(numpy_random_state)
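For the RNG handling above to work, the save side presumably captures the global NumPy RNG state under the same key. A hedged sketch of that counterpart (only the key name is taken from the snippet above; the method itself is an assumption):

    def save(self, checkpoint_path: str):
        save_object = dict(self.__dict__)
        # Record the global NumPy RNG state so restore() can reproduce it.
        save_object["_random_state_seed_to_set"] = np.random.get_state()
        with open(checkpoint_path, "wb") as f:
            cloudpickle.dump(save_object, f)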
Example #11
    def restore(self, checkpoint_path):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``_restore()`` instead to restore state.
        This method restores additional metadata saved with the checkpoint.
        """
        # Maybe sync from cloud
        if self.uses_cloud_checkpointing:
            self.storage_client.sync_down(self.remote_checkpoint_dir,
                                          self.logdir)
            self.storage_client.wait()

        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info("Restored on %s from checkpoint: %s",
                    self.get_current_ip(), checkpoint_path)
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)
Example #12
    def restore(self, checkpoint_path: str):
        with open(checkpoint_path, "rb") as f:
            save_object = pickle.load(f)
        if not isinstance(save_object, dict):
            # Backwards compatibility: older checkpoints stored the raw
            # optimizer object rather than a state dict.
            # Deprecate: 1.8
            self.optimizer = save_object
            return
        self.__dict__.update(save_object)
Example #13
def _find_newest_ckpt(dirpath: str, pattern: str):
    """Returns path to most recently modified checkpoint."""
    full_paths = glob.glob(os.path.join(dirpath, pattern))
    if not full_paths:
        return
    most_recent_checkpoint = max(full_paths)
    with open(most_recent_checkpoint, "rb") as f:
        search_alg_state = cloudpickle.load(f)
    return search_alg_state
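Note that `max(full_paths)` compares paths lexicographically, so this only yields the newest checkpoint when file names sort in creation order, e.g. with zero-padded counters:

names = ["ckpt-00000002.pkl", "ckpt-00000007.pkl", "ckpt-00000010.pkl"]
assert max(names) == "ckpt-00000010.pkl"
# Without padding, lexicographic order breaks down: "ckpt-10" < "ckpt-7".
assert max(["ckpt-10.pkl", "ckpt-7.pkl"]) == "ckpt-7.pkl"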
Example #14
    def to_dict(self) -> dict:
        """Return checkpoint data as dictionary.

        Returns:
            dict: Dictionary containing checkpoint data.
        """
        if self._data_dict:
            # If the checkpoint data is already a dict, return
            return self._data_dict
        elif self._obj_ref:
            # If the checkpoint data is an object reference, resolve
            return ray.get(self._obj_ref)
        elif self._local_path or self._uri:
            # Else, checkpoint is either on FS or external storage
            with self.as_directory() as local_path:
                checkpoint_data_path = os.path.join(
                    local_path, _DICT_CHECKPOINT_FILE_NAME)
                if os.path.exists(checkpoint_data_path):
                    # If we are restoring a dict checkpoint, load the dict
                    # from the checkpoint file.
                    with open(checkpoint_data_path, "rb") as f:
                        checkpoint_data = pickle.load(f)
                else:
                    files = [
                        f for f in os.listdir(local_path)
                        if os.path.isfile(os.path.join(local_path, f))
                        and f.endswith(_METADATA_CHECKPOINT_SUFFIX)
                    ]
                    metadata = {}
                    for file in files:
                        with open(os.path.join(local_path, file), "rb") as f:
                            key = file[:-len(_METADATA_CHECKPOINT_SUFFIX)]
                            value = pickle.load(f)
                            metadata[key] = value

                    data = _pack(local_path)

                    checkpoint_data = {
                        _FS_CHECKPOINT_KEY: data,
                    }
                    checkpoint_data.update(metadata)
                return checkpoint_data
        else:
            raise RuntimeError(f"Empty data for checkpoint {self}")
Example #15
def load_preprocessor_from_dir(
    parent_dir: os.PathLike, ) -> Optional["Preprocessor"]:
    """Loads preprocessor from directory, if file exists."""
    parent_dir = Path(parent_dir)
    preprocessor_path = parent_dir.joinpath(PREPROCESSOR_KEY)
    if preprocessor_path.exists():
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None
    return preprocessor
Example #16
    def load_checkpoint_metadata(checkpoint_path: str) -> Optional[Dict]:
        metadata_path = os.path.join(checkpoint_path, ".tune_metadata")
        if not os.path.exists(metadata_path):
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
            metadatas = glob.glob(f"{checkpoint_dir}/**/.tune_metadata", recursive=True)
            if not metadatas:
                return None
            metadata_path = metadatas[0]

        with open(metadata_path, "rb") as f:
            return pickle.load(f)
Example #17
    def _restore(self, checkpoint):
        """Loads a checkpoint created from `save`.

        Args:
            checkpoint (str): file path to pickled checkpoint file.

        """
        if self.pickled:
            with open(checkpoint, "rb") as f:
                self.estimator = cpickle.load(f)
        else:
            warnings.warn("No estimator restored")
Example #18
    def _restore(self, checkpoint):
        """Loads a checkpoint created from `save`.

        Args:
            checkpoint (str): file path to pickled checkpoint file.

        """
        try:
            with open(checkpoint, "rb") as f:
                self.estimator_list = cpickle.load(f)
        except Exception:
            warnings.warn("No estimator restored", category=RuntimeWarning)
Example #19
def load_from_checkpoint(
    checkpoint: Checkpoint, ) -> Tuple[xgb.Booster, Optional[Preprocessor]]:
    checkpoint_path = checkpoint.to_directory()
    xgb_model = xgb.Booster()
    xgb_model.load_model(os.path.join(checkpoint_path, MODEL_KEY))
    preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY)
    if os.path.exists(preprocessor_path):
        with open(preprocessor_path, "rb") as f:
            preprocessor = cpickle.load(f)
    else:
        preprocessor = None

    return xgb_model, preprocessor
Example #20
    def get_checkpoints_paths(logdir):
        """Finds the checkpoints within a specific folder.

        Returns a pandas DataFrame of training iterations and checkpoint
        paths within a specific folder.

        Raises:
            FileNotFoundError if the directory is not found.
        """
        marker_paths = glob.glob(
            os.path.join(glob.escape(logdir), "checkpoint_*/.is_checkpoint")
        )
        iter_chkpt_pairs = []
        for marker_path in marker_paths:
            chkpt_dir = os.path.dirname(marker_path)

            # Skip temporary checkpoints
            if os.path.basename(chkpt_dir).startswith("checkpoint_tmp"):
                continue

            metadata_file = glob.glob(
                os.path.join(glob.escape(chkpt_dir), "*.tune_metadata")
            )
            # glob.glob: filenames starting with a dot are special cases
            # that are not matched by '*' and '?' patterns.
            metadata_file += glob.glob(
                os.path.join(glob.escape(chkpt_dir), ".tune_metadata")
            )
            metadata_file = list(set(metadata_file))  # avoid duplication
            if len(metadata_file) != 1:
                raise ValueError(
                    "{} has zero or more than one tune_metadata.".format(chkpt_dir)
                )

            metadata_file = metadata_file[0]

            try:
                with open(metadata_file, "rb") as f:
                    metadata = pickle.load(f)
            except Exception as e:
                logger.warning(f"Could not read metadata from checkpoint: {e}")
                metadata = {}

            chkpt_path = metadata_file[: -len(".tune_metadata")]
            chkpt_iter = metadata.get("iteration", -1)
            iter_chkpt_pairs.append([chkpt_iter, chkpt_path])

        chkpt_df = pd.DataFrame(
            iter_chkpt_pairs, columns=["training_iteration", "chkpt_path"]
        )
        return chkpt_df
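A typical use of the returned DataFrame is picking the most recent checkpoint. A short usage sketch, treating the function above as a standalone helper (the log directory is hypothetical):

import os

logdir = os.path.expanduser("~/ray_results/exp/MyTrainable_abc")
chkpt_df = get_checkpoints_paths(logdir)
if not chkpt_df.empty:
    newest = chkpt_df.loc[chkpt_df["training_iteration"].idxmax(), "chkpt_path"]
    print(newest)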
Example #21
def test_tune_checkpoint(ray_start_2_cpus):
    def train_func():
        for i in range(10):
            train.report(test=i)
        train.save_checkpoint(hello="world")

    trainer = Trainer(TestConfig(), num_workers=1)
    TestTrainable = trainer.to_tune_trainable(train_func)

    [trial] = tune.run(TestTrainable).trials
    checkpoint_file = os.path.join(trial.checkpoint.value, TUNE_CHECKPOINT_FILE_NAME)
    assert os.path.exists(checkpoint_file)
    with open(checkpoint_file, "rb") as f:
        checkpoint = cloudpickle.load(f)
        assert checkpoint["hello"] == "world"
Example #22
    def _load_checkpoint(
            self, checkpoint_to_load: Optional[Union[Dict, str, Path]]
    ) -> Optional[Dict]:
        """Load the checkpoint dictionary from the input dict or path."""
        if checkpoint_to_load is None:
            return None
        if isinstance(checkpoint_to_load, dict):
            return checkpoint_to_load
        else:
            # Load checkpoint from path.
            checkpoint_path = Path(checkpoint_to_load).expanduser()
            if not checkpoint_path.exists():
                raise ValueError(f"Checkpoint path {checkpoint_path} "
                                 f"does not exist.")
            with checkpoint_path.open("rb") as f:
                return cloudpickle.load(f)
Example #23
    def to_dict(self) -> dict:
        """Return checkpoint data as dictionary.

        Returns:
            dict: Dictionary containing checkpoint data.
        """
        if self._data_dict:
            # If the checkpoint data is already a dict, return
            return self._data_dict
        elif self._obj_ref:
            # If the checkpoint data is an object reference, resolve
            return ray.get(self._obj_ref)
        elif self._local_path or self._uri:
            # Else, checkpoint is either on FS or external storage
            cleanup = False

            local_path = self._local_path
            if not local_path:
                # Checkpoint does not exist on local path. Save
                # in temporary directory, but clean up later
                local_path = self.to_directory()
                cleanup = True

            checkpoint_data_path = os.path.join(local_path,
                                                _DICT_CHECKPOINT_FILE_NAME)
            if os.path.exists(checkpoint_data_path):
                # If we are restoring a dict checkpoint, load the dict
                # from the checkpoint file.
                with open(checkpoint_data_path, "rb") as f:
                    checkpoint_data = pickle.load(f)
            else:
                data = _pack(local_path)

                checkpoint_data = {
                    _FS_CHECKPOINT_KEY: data,
                }

            if cleanup:
                shutil.rmtree(local_path)

            return checkpoint_data
        else:
            raise RuntimeError(f"Empty data for checkpoint {self}")
Example #24
def load_checkpoint(
    checkpoint: Checkpoint, ) -> Tuple[BaseEstimator, Optional[Preprocessor]]:
    """Load a Checkpoint from ``SklearnTrainer``.

    Args:
        checkpoint: The checkpoint to load the estimator and
            preprocessor from. It is expected to be from the result of a
            ``SklearnTrainer`` run.

    Returns:
        The estimator and AIR preprocessor contained within.
    """
    with checkpoint.as_directory() as checkpoint_path:
        estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
        with open(estimator_path, "rb") as f:
            estimator = cpickle.load(f)
        preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return estimator, preprocessor
Example #25
    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "XGBoostPredictor":
        """Instantiate the predictor from a Checkpoint.

        The checkpoint is expected to be a result of ``XGBoostTrainer``.

        Args:
            checkpoint (Checkpoint): The checkpoint to load the model and
                preprocessor from. It is expected to be from the result of a
                ``XGBoostTrainer`` run.

        """
        with checkpoint.as_directory() as path:
            bst = xgboost.Booster()
            bst.load_model(os.path.join(path, MODEL_KEY))
            preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
            if os.path.exists(preprocessor_path):
                with open(preprocessor_path, "rb") as f:
                    preprocessor = cpickle.load(f)
            else:
                preprocessor = None
        return XGBoostPredictor(model=bst, preprocessor=preprocessor)
Example #26
    @classmethod
    def from_checkpoint(cls, checkpoint: Checkpoint) -> "LightGBMPredictor":
        """Instantiate the predictor from a Checkpoint.

        The checkpoint is expected to be a result of ``LightGBMTrainer``.

        Args:
            checkpoint (Checkpoint): The checkpoint to load the model and
                preprocessor from. It is expected to be from the result of a
                ``LightGBMTrainer`` run.

        """
        path = checkpoint.to_directory()
        bst = lightgbm.Booster(model_file=os.path.join(path, MODEL_KEY))
        preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
        if os.path.exists(preprocessor_path):
            with open(preprocessor_path, "rb") as f:
                preprocessor = cpickle.load(f)
        else:
            preprocessor = None
        shutil.rmtree(path)
        return LightGBMPredictor(model=bst, preprocessor=preprocessor)
Example #27
def load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> Optional[dict]:
    """Loads and returns the state of the most recently modified checkpoint.

    Assumes files are saved with an ordered name, most likely by
    :obj:`atomic_save`.

    Args:
        dirpath (str): Directory in which to look for the checkpoint file.
        ckpt_pattern (str): File name pattern to match to find checkpoint
            files.

    Returns:
        (dict) Deserialized state dict.
    """
    import ray.cloudpickle as cloudpickle
    full_paths = glob.glob(os.path.join(dirpath, ckpt_pattern))
    if not full_paths:
        return
    most_recent_checkpoint = max(full_paths)
    with open(most_recent_checkpoint, "rb") as f:
        checkpoint_state = cloudpickle.load(f)
    return checkpoint_state
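The ordered, safely written files this function expects are typically produced by writing to a temporary file and renaming it into place, so a reader never observes a half-written checkpoint. A sketch of that pattern (the function name and layout are assumptions, not Ray's actual `atomic_save`):

import os
import tempfile
import ray.cloudpickle as cloudpickle

def atomic_save_sketch(state: dict, dirpath: str, file_name: str) -> None:
    # Write to a temp file in the target directory, then atomically rename;
    # os.replace() is atomic when source and target are on the same filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=dirpath)
    try:
        with os.fdopen(fd, "wb") as f:
            cloudpickle.dump(state, f)
        os.replace(tmp_path, os.path.join(dirpath, file_name))
    except Exception:
        os.remove(tmp_path)
        raise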
Example #28
    def restore(self, checkpoint_path: str):
        with open(checkpoint_path, "rb") as f:
            trials_object = pickle.load(f)
        self.optimizer = trials_object
Example #29
    },
    preprocessor=preprocessor,
)
result = trainer.fit()
# __trainer_end__

# __checkpoint_start__
import os
import ray.cloudpickle as cpickle
from ray.air.constants import PREPROCESSOR_KEY

checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_path:
    path = os.path.join(checkpoint_path, PREPROCESSOR_KEY)
    with open(path, "rb") as f:
        preprocessor = cpickle.load(f)
    print(preprocessor)
# MinMaxScaler(columns=['x'], stats={'min(x)': 0, 'max(x)': 30})
# __checkpoint_end__

# __predictor_start__
from ray.train.batch_predictor import BatchPredictor
from ray.train.xgboost import XGBoostPredictor

test_dataset = ray.data.from_items([{"x": x} for x in range(2, 32, 3)])

batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)
predicted_probabilities = batch_predictor.predict(test_dataset)
predicted_probabilities.show()
# {'predictions': 0.09843720495700836}
# {'predictions': 5.604666709899902}
Example #30
    def restore(self, checkpoint_path):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match with the return from ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`, or
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
        """
        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path)
            self.storage_client.sync_down(
                os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
                os.path.join(self.logdir, rel_checkpoint_dir),
            )
            self.storage_client.wait_or_retry()

        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info("Restored on %s from checkpoint: %s",
                    self.get_current_ip(), checkpoint_path)
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)