Example #1
0
    def testCheckpointReuseObjectWithoutTraining(self):
        """Repeated save/restore cycles must never reuse a checkpoint dir."""

        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                # Exactly one checkpoint marker file should be present.
                entries = os.listdir(checkpoint_dir)
                num_markers = sum("checkpoint-" in entry for entry in entries)
                assert num_markers == 1, os.listdir(checkpoint_dir)

            for iteration in range(20):
                with tune.checkpoint_dir(step=iteration) as checkpoint_dir:
                    marker = os.path.join(
                        checkpoint_dir, "checkpoint-{}".format(iteration))
                    open(marker, "a").close()
                tune.report(test=iteration)

        wrapped = wrap_function(train)
        trainable = wrapped(logger_creator=self.logger_creator)
        for _ in range(2):
            result = trainable.train()
        checkpoint = trainable.save_to_object()
        trainable.stop()

        # Restoring without training must leave the checkpoint intact.
        restored = wrapped(logger_creator=self.logger_creator)
        restored.restore_from_object(checkpoint)
        restored.stop()

        # A second restore followed by one training step continues at iter 3.
        restored = wrapped(logger_creator=self.logger_creator)
        restored.restore_from_object(checkpoint)
        result = restored.train()
        restored.stop()
        self.assertTrue(result[TRAINING_ITERATION] == 3)
Example #2
0
    def testReuseNullCheckpoint(self):
        """A checkpoint from a non-checkpointing function can be restored
        repeatedly without corrupting subsequent training."""

        def train(config, checkpoint_dir=None):
            # The function never checkpoints, so restore should pass ``None``.
            assert not checkpoint_dir
            for step in range(10):
                tune.report(test=step)

        # Create checkpoint. (Removed a dead ``checkpoint = None`` that was
        # immediately overwritten below.)
        wrapped = wrap_function(train)
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()

        # Use the checkpoint a couple of times
        for _ in range(3):
            new_trainable = wrapped(logger_creator=self.logger_creator)
            new_trainable.restore(checkpoint)
            new_trainable.stop()

        # Make sure the result is still good
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.restore(checkpoint)
        result = new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()
        # ``assertEqual`` gives a better failure message than
        # ``assertTrue(x == y)``.
        self.assertEqual(result[TRAINING_ITERATION], 1)
Example #3
0
    def testFunctionRecurringSave(self):
        """Save and restore should be commutative with periodic checkpoints."""

        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                assert os.path.exists(checkpoint_dir)
            for step in range(10):
                # Persist the current step every third iteration.
                if step % 3 == 0:
                    with tune.checkpoint_dir(step=step) as checkpoint_dir:
                        target = os.path.join(checkpoint_dir, "checkpoint")
                        with open(target, "w") as f:
                            f.write(json.dumps({"step": step}))
                tune.report(test=step)

        wrapped = wrap_function(train)

        # Save to object, restore from it, then save to disk.
        first = wrapped(logger_creator=self.logger_creator)
        first.train()
        checkpoint_obj = first.save_to_object()
        first.restore_from_object(checkpoint_obj)
        checkpoint = first.save()

        first.stop()

        # A fresh trainable must be able to resume from the disk checkpoint.
        second = wrapped(logger_creator=self.logger_creator)
        second.restore(checkpoint)
        second.train()
        second.stop()
Example #4
0
    def testFunctionImmediateSave(self):
        """Save and restore should be commutative with per-step checkpoints."""

        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                assert os.path.exists(checkpoint_dir)
            for step in range(10):
                with tune.checkpoint_dir(step=step) as checkpoint_dir:
                    print(checkpoint_dir)
                    marker = os.path.join(
                        checkpoint_dir, "checkpoint-{}".format(step))
                    open(marker, "w").close()
                tune.report(test=step)

        wrapped = wrap_function(train)
        trainable_a = wrapped(logger_creator=self.logger_creator)
        trainable_a.train()
        trainable_a.train()
        checkpoint_obj = trainable_a.save_to_object()
        trainable_a.stop()

        # Restore, immediately re-save, then continue training.
        trainable_b = wrapped(logger_creator=self.logger_creator)
        trainable_b.restore_from_object(checkpoint_obj)
        checkpoint_obj = trainable_b.save_to_object()
        trainable_b.train()
        result = trainable_b.train()
        # While running, exactly one temporary checkpoint dir exists...
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
        trainable_b.stop()
        # ...and stopping cleans it up.
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
        assert result[TRAINING_ITERATION] == 4
Example #5
0
def create_resettable_function(num_resets: defaultdict):
    """Build a trainable class that counts ``reset_config`` calls.

    Every reset increments ``num_resets[trial_id]``, letting tests assert
    on actor reuse behavior.
    """

    def trainable(config, checkpoint_dir=None):
        # Resume from the pickled step counter when a checkpoint is given.
        step = 0
        if checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp:
                step = pickle.load(fp)

        while step < 2:
            step += 1
            with tune.checkpoint_dir(step) as checkpoint_dir:
                with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp:
                    pickle.dump(step, fp)
            tune.report(
                done=step >= 2,
                iter=step,
                id=config["id"],
            )

    trainable = wrap_function(trainable)

    class ResetCountTrainable(trainable):
        def reset_config(self, new_config):
            # Count resets per trial so tests can verify actor reuse.
            num_resets[self.trial_id] += 1
            return super().reset_config(new_config)

    return ResetCountTrainable
Example #6
0
    def setup(self, config):
        """Start the distributed workers and wire up the process group.

        Spawns ``self._num_workers`` remote copies of the wrapped training
        function, then initializes the process group across them.
        """
        self._finished = False
        num_workers = self._num_workers
        # Resolve instance attributes ahead of time so the lambdas below do
        # not capture ``self``.
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        remote_trainable = remote_trainable.options(
            **self.get_remote_worker_options())

        address = setup_address()
        self.workers = [
            remote_trainable.remote(
                config=config,
                # Bind ``rank`` as a default argument so each lambda keeps
                # its own rank instead of sharing the loop variable
                # (late-binding closure pitfall).
                logger_creator=lambda cfg, rank=rank: logger_creator(
                    cfg, logdir, rank))
            for rank in range(num_workers)
        ]

        pgroup_params = self.default_process_group_parameters()
        from functools import partial
        setup_on_worker = partial(setup_process_group,
                                  url=address,
                                  world_size=num_workers,
                                  **pgroup_params)
        ray.get([
            # Same default-argument binding for ``rank`` here.
            w.execute.remote(
                lambda _, rank=rank: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ])
Example #7
0
    def setup(self, config: Dict):
        """Create the remote workers and set up the process group.

        Workers are placed via a placement group; each worker reports its
        address, and all addresses are shared before forming the group.
        """
        self._finished = False
        num_workers = self._num_workers
        assert self._function

        func_trainable = wrap_function(self.__class__._function)
        remote_trainable = ray.remote(func_trainable)
        # Parenthesized unpacking instead of backslash continuations.
        (
            remote_option,
            self._placement_group,
        ) = PlacementGroupUtil.get_remote_worker_options(
            self._num_workers,
            self._num_cpus_per_worker,
            self._num_gpus_per_worker,
            self._num_workers_per_host,
            self._timeout_s,
        )
        remote_trainable = remote_trainable.options(**remote_option)
        self.workers = [
            remote_trainable.remote(config=config)
            for _ in range(num_workers)
        ]

        # Every worker reports its own address.
        addresses = [
            ray.get(worker.execute.remote(lambda _: setup_address()))
            for worker in self.workers
        ]

        from functools import partial
        setup_on_worker = partial(
            setup_process_group, worker_addresses=addresses)
        ray.get([
            # Bind ``index`` as a default argument to avoid the late-binding
            # closure pitfall (all lambdas sharing the final loop value).
            w.execute.remote(
                lambda _, index=index: setup_on_worker(index=index))
            for index, w in enumerate(self.workers)
        ])
Example #8
0
    def setup(self, config: Dict):
        """Start the Horovod executor and launch the wrapped trainable."""
        trainable = wrap_function(self.__class__._function)

        # We use a filelock here to ensure that the file-writing
        # process is safe across different trainables. When no SSH identity
        # file is configured, no locking is needed; ``nullcontext`` lets us
        # avoid duplicating the ``create_settings`` call in both branches.
        from contextlib import nullcontext
        lock = (FileLock(self._ssh_identity_file + ".lock")
                if self._ssh_identity_file else nullcontext())
        with lock:
            settings = RayExecutor.create_settings(self._timeout_s,
                                                   self._ssh_identity_file,
                                                   self._ssh_str)

        self.executor = RayExecutor(settings,
                                    cpus_per_slot=self._num_cpus_per_slot,
                                    use_gpu=self._use_gpu,
                                    num_hosts=self._num_hosts,
                                    num_slots=self._num_slots)

        new_config = DistributedTrainable.build_config(self, config)

        # We can't put `self` in the lambda closure, so we
        # resolve the variable ahead of time.
        logdir_ = str(self.logdir)

        # Starts the workers as specified by the resources above.
        self.executor.start(executable_cls=trainable,
                            executable_kwargs={
                                "config": new_config,
                                "logger_creator":
                                lambda cfg: logger_creator(cfg, logdir_)
                            })
Example #9
0
def wrap_function_patched(function):
    """Wrap ``function`` as a Tune trainable and monkey-patch it with the
    remote scheduler helpers (``save_all_states`` / ``get_sched_hints``)."""
    patched = wrap_function(function)
    patched.save_all_states = types.MethodType(
        save_all_states_remote, patched)
    patched.get_sched_hints = types.MethodType(
        get_sched_hints_remote, patched)
    return patched
Example #10
0
    def as_trainable(self) -> Type[Trainable]:
        """Convert self to a ``tune.Trainable`` class."""

        # Resolve attributes up front; the nested function and class below
        # capture these locals instead of ``self``.
        base_config = self._param_dict
        trainer_cls = self.__class__
        scaling_config = self.scaling_config

        def train_func(config, checkpoint_dir=None):
            # config already contains merged values.
            # Instantiate new Trainer in Trainable.
            trainer = trainer_cls(**config)

            if checkpoint_dir:
                trainer.resume_from_checkpoint = Checkpoint.from_directory(
                    checkpoint_dir
                )

            trainer.setup()
            trainer.preprocess_datasets()
            trainer.training_loop()

        # Change the name of the training function to match the name of the Trainer
        # class. This will mean the Tune trial name will match the name of Trainer on
        # stdout messages and the results directory.
        train_func.__name__ = trainer_cls.__name__

        trainable_cls = wrap_function(train_func)

        class TrainTrainable(trainable_cls):
            """Add default resources to the Trainable."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

                # Create a new config by merging the dicts.
                # run_config is not a tunable hyperparameter so it does not need to be
                # merged.
                run_config = base_config.pop("run_config", None)
                self._merged_config = merge_dicts(base_config, self.config)
                self._merged_config["run_config"] = run_config

            def _trainable_func(self, config, reporter, checkpoint_dir):
                # We ignore the config passed by Tune and instead use the merged
                # config which includes the initial Trainer args.
                super()._trainable_func(self._merged_config, reporter, checkpoint_dir)

            @classmethod
            def default_resource_request(cls, config):
                # A per-trial "scaling_config" overrides the Trainer default.
                updated_scaling_config = config.get("scaling_config", scaling_config)
                scaling_config_dataclass = (
                    trainer_cls._validate_and_get_scaling_config_data_class(
                        updated_scaling_config
                    )
                )
                return scaling_config_dataclass.as_placement_group_factory()

        return TrainTrainable
Example #11
0
def _create_tune_trainable(
    train_func, dataset, backend_config, num_workers, use_gpu, resources_per_worker
):
    """Creates a Tune Trainable class for Train training.

    This function populates class attributes and methods.
    """

    # TODO(matt): Move dataset to Ray object store, like tune.with_parameters.
    def tune_function(config, checkpoint_dir=None):
        # Run one Train job per Tune trial and stream its results to Tune.
        trainer = Trainer(
            backend=backend_config,
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker=resources_per_worker,
        )

        trainer.start()

        if checkpoint_dir is not None:
            checkpoint_path = os.path.join(checkpoint_dir, TUNE_CHECKPOINT_FILE_NAME)
        else:
            checkpoint_path = None

        iterator = trainer.run_iterator(
            train_func, config, dataset=dataset, checkpoint=checkpoint_path
        )

        for results in iterator:
            # Only the first worker's results are reported to Tune.
            first_worker_results = results[0]

            tune.report(**first_worker_results)

        trainer.shutdown()

    trainable_cls = wrap_function(tune_function)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        @classmethod
        def default_resource_request(cls, config: Dict) -> PlacementGroupFactory:
            # One CPU bundle for the Trainer itself plus one bundle per
            # worker (CPU/GPU plus any extra per-worker resources).
            trainer_bundle = [{"CPU": 1}]
            worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
            worker_resources_extra = (
                {} if resources_per_worker is None else resources_per_worker
            )
            worker_bundles = [
                {**worker_resources, **worker_resources_extra}
                for _ in range(num_workers)
            ]
            bundles = trainer_bundle + worker_bundles
            return PlacementGroupFactory(bundles, strategy="PACK")

    return TrainTrainable
Exemple #12
0
    def setup(self, config: Dict):
        """Spawn remote workers, form the process group, and enable
        distributed training on each worker."""
        self._finished = False
        num_workers = self._num_workers
        # Resolve ``self.logdir`` ahead of time so the lambdas shipped to
        # remote workers do not capture ``self``.
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        (
            remote_option,
            self._placement_group,
        ) = PlacementGroupUtil.get_remote_worker_options(
            self._num_workers,
            self._num_cpus_per_worker,
            self._num_gpus_per_worker,
            self._num_workers_per_host,
            self._timeout_s,
        )
        remote_trainable = remote_trainable.options(**remote_option)
        new_config = DistributedTrainable.build_config(self, config)

        self.workers = [
            remote_trainable.remote(
                config=new_config,
                # Bind ``rank`` as a default argument so each worker's logger
                # gets its own rank (avoids the late-binding closure pitfall).
                logger_creator=lambda cfg, rank=rank: logger_creator(
                    cfg, logdir, rank
                ),
            )
            for rank in range(num_workers)
        ]

        # Address has to be IP of rank 0 worker's node.
        address = ray.get(self.workers[0].execute.remote(lambda _: setup_address()))

        pgroup_params = self.default_process_group_parameters()
        from functools import partial

        setup_on_worker = partial(
            setup_process_group, url=address, world_size=num_workers, **pgroup_params
        )
        ray.get(
            [
                # Same default-argument binding for ``rank`` here.
                w.execute.remote(
                    lambda _, rank=rank: setup_on_worker(world_rank=rank)
                )
                for rank, w in enumerate(self.workers)
            ]
        )

        # ``rank`` was unused here; iterate the workers directly.
        ray.get(
            [
                w.execute.remote(lambda _: enable_distributed_trainable())
                for w in self.workers
            ]
        )
    def testMultipleNullMemoryCheckpoints(self):
        """In-memory checkpoints of a non-checkpointing function stay null."""

        def train(config, checkpoint_dir=None):
            # The function never checkpoints, so restores pass ``None``.
            assert not checkpoint_dir
            for step in range(10):
                tune.report(test=step)

        wrapped = wrap_function(train)
        checkpoint = None
        for _ in range(5):
            trainable = wrapped(logger_creator=self.logger_creator)
            if checkpoint:
                trainable.restore_from_object(checkpoint)
            result = trainable.train()
            checkpoint = trainable.save_to_object()
            trainable.stop()
        assert result[TRAINING_ITERATION] == 1
Example #14
0
    def as_trainable(self) -> Type[Trainable]:
        """Convert self to a ``tune.Trainable`` class."""

        # Resolve attributes up front; the nested function and class below
        # capture these locals instead of ``self``.
        base_config = self._param_dict
        trainer_cls = self.__class__
        scaling_config = self.scaling_config

        def train_func(config, checkpoint_dir=None):
            # config already contains merged values.
            # Instantiate new Trainer in Trainable.
            trainer = trainer_cls(**config)

            if checkpoint_dir:
                trainer.resume_from_checkpoint = Checkpoint.from_directory(
                    checkpoint_dir)

            trainer.setup()
            trainer.preprocess_datasets()
            trainer.training_loop()

        trainable_cls = wrap_function(train_func)

        class TrainTrainable(trainable_cls):
            """Add default resources to the Trainable."""
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

                # Create a new config by merging the dicts.
                self._merged_config = merge_dicts(base_config, self.config)

            def _trainable_func(self, config, reporter, checkpoint_dir):
                # We ignore the config passed by Tune and instead use the merged
                # config which includes the initial Trainer args.
                super()._trainable_func(self._merged_config, reporter,
                                        checkpoint_dir)

            @classmethod
            def default_resource_request(cls, config):
                # A per-trial "scaling_config" overrides the Trainer default.
                updated_scaling_config = config.get("scaling_config",
                                                    scaling_config)
                scaling_config_dataclass = ScalingConfigDataClass(
                    **updated_scaling_config)
                return scaling_config_dataclass.as_placement_group_factory()

        return TrainTrainable
Example #15
0
    def testFunctionNoCheckpointing(self):
        """Restoring a checkpoint from a non-checkpointing function resets
        the iteration counter to one."""

        def train(config, checkpoint_dir=None):
            for i in range(10):
                tune.report(test=i)

        wrapped = wrap_function(train)

        new_trainable = wrapped()
        result = new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()

        new_trainable2 = wrapped()
        new_trainable2.restore(checkpoint)
        result = new_trainable2.train()
        # ``assertEquals`` is a deprecated alias; use ``assertEqual``.
        self.assertEqual(result[TRAINING_ITERATION], 1)
        checkpoint = new_trainable2.save()
        new_trainable2.stop()
    def testFunctionNoCheckpointing(self):
        """Training after restore restarts iteration counting at one."""

        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                assert os.path.exists(checkpoint_dir)
            for step in range(10):
                tune.report(test=step)

        wrapped = wrap_function(train)

        first = wrapped(logger_creator=self.logger_creator)
        result = first.train()
        checkpoint = first.save()
        first.stop()

        second = wrapped(logger_creator=self.logger_creator)
        second.restore(checkpoint)
        result = second.train()
        self.assertEqual(result[TRAINING_ITERATION], 1)
        checkpoint = second.save()
        second.stop()
Example #17
0
    def testMlFlowMixinConfig(self):
        """MLFlow mixin config validation: missing or invalid configs raise
        ``ValueError``; a valid config initializes the client."""
        clear_env_vars()
        trial_config = {"par1": 4, "par2": 9.}

        @mlflow_mixin
        def train_fn(config):
            return 1

        # Attach the (mock) mixin class the decorator would normally set.
        train_fn.__mixins__ = (MLFlowTrainableMixin, )

        # No MLFlow config passed in.
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)

        trial_config.update({"mlflow": {}})
        # No tracking uri or experiment_id/name passed in.
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)

        # Invalid experiment-id
        trial_config["mlflow"].update({"experiment_id": "500"})
        # No tracking uri or experiment_id/name passed in.
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)

        # A valid tracking uri plus an existing experiment name succeeds.
        trial_config["mlflow"].update({
            "tracking_uri": "test_tracking_uri",
            "experiment_name": "existing_experiment"
        })
        wrapped = wrap_function(train_fn)(trial_config)
        client = wrapped._mlflow
        self.assertEqual(client.tracking_uri, "test_tracking_uri")
        self.assertTupleEqual(client.active_run.run_id, (0, 0))

        with patch("ray.tune.integration.mlflow._import_mlflow",
                   lambda: client):
            train_fn.__mixins__ = (MLFlowTrainableMixin, )
            wrapped = wrap_function(train_fn)(trial_config)
            client = wrapped._mlflow
            # Second run under the same experiment gets run index 1.
            self.assertTupleEqual(client.active_run.run_id, (0, 1))

            # Set to experiment that does not already exist.
            # New experiment should be created.
            # NOTE(review): with the patched client this raises instead --
            # confirm whether creation is expected to fail here.
            trial_config["mlflow"]["experiment_name"] = "new_experiment"
            with self.assertRaises(ValueError):
                wrapped = wrap_function(train_fn)(trial_config)
Example #18
0
    def testWandbDecoratorConfig(self):
        """W&B mixin resolves the API key from config, key file, or the
        environment, and forwards trial metadata to wandb."""
        config = {"par1": 4, "par2": 9.12345678}
        trial = Trial(config, 0, "trial_0", "trainable",
                      PlacementGroupFactory([{
                          "CPU": 1
                      }]))
        trial_info = TrialInfo(trial)

        @wandb_mixin
        def train_fn(config):
            return 1

        # Swap in the mock mixin so no real wandb calls are made.
        train_fn.__mixins__ = (_MockWandbTrainableMixin, )

        config[TRIAL_INFO] = trial_info

        # Start from a clean environment (no pre-set API key).
        if WANDB_ENV_VAR in os.environ:
            del os.environ[WANDB_ENV_VAR]

        # Needs at least a project
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(config)

        # No API key
        config["wandb"] = {"project": "test_project"}
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(config)

        # API Key in config
        config["wandb"] = {"project": "test_project", "api_key": "1234"}
        wrapped = wrap_function(train_fn)(config)
        self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")

        del os.environ[WANDB_ENV_VAR]

        # API Key file
        with tempfile.NamedTemporaryFile("wt") as fp:
            fp.write("5678")
            fp.flush()

            config["wandb"] = {
                "project": "test_project",
                "api_key_file": fp.name
            }

            wrapped = wrap_function(train_fn)(config)
            self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")

        del os.environ[WANDB_ENV_VAR]

        # API Key in env
        os.environ[WANDB_ENV_VAR] = "9012"
        config["wandb"] = {"project": "test_project"}
        wrapped = wrap_function(train_fn)(config)

        # From now on, the API key is in the env variable.

        # Default configuration
        config["wandb"] = {"project": "test_project"}
        config[TRIAL_INFO] = trial_info

        wrapped = wrap_function(train_fn)(config)
        # Trial metadata is forwarded to the wandb init kwargs.
        self.assertEqual(wrapped.wandb.kwargs["project"], "test_project")
        self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id)
        self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name)
Example #19
0
def durable(trainable: Union[str, Type[Trainable], Callable]):
    """Convert trainable into a durable trainable.

    Durable trainables are used to upload trial results and checkpoints
    to cloud storage, like e.g. AWS S3.

    This function can be used to convert your trainable, i.e. your trainable
    classes, functions, or string identifiers, to a durable trainable.

    To make durable trainables work, you should pass a valid
    :class:`SyncConfig <ray.tune.SyncConfig>` object to `tune.run()`.

    Example:

    .. code-block:: python

        from ray import tune

        analysis = tune.run(
            tune.durable("PPO"),
            config={"env": "CartPole-v0"},
            checkpoint_freq=1,
            sync_config=tune.SyncConfig(
                sync_to_driver=False,
                upload_dir="s3://your-s3-bucket/durable-ppo/",
            ))

    You can also convert your trainable functions:

    .. code-block:: python

        tune.run(
            tune.durable(your_training_fn),
            # ...
        )

    And your class functions:

    .. code-block:: python

        tune.run(
            tune.durable(YourTrainableClass),
            # ...
        )


    Args:
        trainable (str|Type[Trainable]|Callable): Trainable. Can be a
            string identifier, a trainable class, or a trainable function.

    Returns:
        A durable trainable class wrapped around your trainable.

    """
    if isinstance(trainable, str):
        trainable_cls = get_trainable_cls(trainable)
    else:
        trainable_cls = trainable

    if not inspect.isclass(trainable_cls):
        # Function API
        return wrap_function(trainable_cls, durable=True)

    if not issubclass(trainable_cls, Trainable):
        # Report the offending class itself: ``type(trainable_cls)`` would
        # only print the metaclass (usually ``type``), which is unhelpful.
        raise ValueError(
            "You can only use `durable()` with valid trainables. The class "
            "you passed does not inherit from `Trainable`. Please make sure "
            f"it does. Got: {trainable_cls}")

    # else: Class API
    class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
        _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
            else "durable_trainable"

    return _WrappedDurableTrainable