Code example #1
File: registry.py Project: anke522/ray-1
def register_trainable(name, trainable):
    """Register a trainable function or class.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
    """

    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, FunctionType):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.warning(
            "Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    if not issubclass(trainable, Trainable):
        raise TypeError("Second argument must be convertable to Trainable",
                        trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable)
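
A minimal usage sketch for the version above, assuming the old Tune function API; my_train_fn, the metric name, and the config values are illustrative, not from the source:

from ray import tune
from ray.tune.registry import register_trainable

# Hypothetical function trainable; registration wraps it into a
# Trainable class automatically, as the docstring describes.
def my_train_fn(config, reporter):
    reporter(mean_accuracy=config["lr"] * 10)

register_trainable("my_trainable", my_train_fn)

# The registered name can then be used wherever a trainable is expected.
tune.run("my_trainable", config={"lr": 0.01})
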
Code example #2
File: registry.py Project: ray-project/ray
def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        trainable: Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
    """

    from ray.tune.trainable import wrap_function
    from ray.tune.trainable import Trainable

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, FunctionType) or isinstance(trainable, partial):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable, warn=warn)
    elif callable(trainable):
        logger.info("Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable, warn=warn)

    if not issubclass(trainable, Trainable):
        raise TypeError("Second argument must be convertable to Trainable", trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable)
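
Unlike the older variant in example #1, this version also accepts functools.partial objects. A sketch of registering a pre-bound function; the names and values are illustrative:

from functools import partial

from ray.tune.registry import register_trainable

def train_fn(config, multiplier=1.0):
    print(config["lr"] * multiplier)

# A partial takes the same path as a plain function and is
# wrapped into a Trainable class during registration.
register_trainable("scaled_trainable", partial(train_fn, multiplier=0.5))
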
Code example #3
File: test_integration_mlflow.py Project: r4b3rt/ray
    def testMlFlowMixinConfig(self):
        clear_env_vars()
        trial_config = {"par1": 4, "par2": 9.0}

        @mlflow_mixin
        def train_fn(config):
            return 1

        train_fn.__mixins__ = (MLflowTrainableMixin, )

        # No MLflow config passed in.
        with self.assertRaises(ValueError):
            wrap_function(train_fn)(trial_config)

        trial_config.update({"mlflow": {}})
        # No tracking uri or experiment_id/name passed in.
        with self.assertRaises(ValueError):
            wrap_function(train_fn)(trial_config)

        # Invalid experiment-id
        trial_config["mlflow"].update({"experiment_id": "500"})
        # No tracking uri or experiment_id/name passed in.
        with self.assertRaises(ValueError):
            wrap_function(train_fn)(trial_config)

        # Set to experiment that does not already exist.
        # New experiment should be created.
        trial_config["mlflow"]["experiment_name"] = "new_experiment"
        with self.assertRaises(ValueError):
            wrap_function(train_fn)(trial_config)
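
For contrast with the failing cases above, and continuing with the test's trial_config and train_fn, a configuration of the rough shape the mixin accepts; a sketch only, since the exact keys checked vary across Ray versions:

# Hypothetical happy path: a reachable tracking URI plus an
# experiment name lets the wrapped trainable construct without raising.
trial_config["mlflow"] = {
    "tracking_uri": "file:///tmp/mlruns",
    "experiment_name": "new_experiment",
}
wrap_function(train_fn)(trial_config)
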
Code example #4
File: test_function_api.py Project: ray-project/ray
    def testCheckpointReuseObject(self):
        """Test that repeated save/restore never reuses same checkpoint dir."""
        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                count = sum("checkpoint-" in path
                            for path in os.listdir(checkpoint_dir))
                assert count == 1, os.listdir(checkpoint_dir)

            for step in range(20):
                with tune.checkpoint_dir(step=step) as checkpoint_dir:
                    path = os.path.join(checkpoint_dir,
                                        "checkpoint-{}".format(step))
                    open(path, "a").close()
                tune.report(test=step)

        wrapped = wrap_function(train)
        checkpoint = None
        for i in range(5):
            new_trainable = wrapped(logger_creator=self.logger_creator)
            if checkpoint:
                new_trainable.restore_from_object(checkpoint)
            for _ in range(2):
                result = new_trainable.train()
            checkpoint = new_trainable.save_to_object()
            new_trainable.stop()
        self.assertTrue(result[TRAINING_ITERATION] == 10)
Code example #5
File: test_function_api.py Project: ray-project/ray
    def testFunctionImmediateSave(self):
        """This tests that save and restore are commutative."""
        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                assert os.path.exists(checkpoint_dir)
            for step in range(10):
                with tune.checkpoint_dir(step=step) as checkpoint_dir:
                    print(checkpoint_dir)
                    path = os.path.join(checkpoint_dir,
                                        "checkpoint-{}".format(step))
                    open(path, "w").close()
                tune.report(test=step)

        wrapped = wrap_function(train)
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.train()
        new_trainable.train()
        checkpoint_obj = new_trainable.save_to_object()
        new_trainable.stop()

        new_trainable2 = wrapped(logger_creator=self.logger_creator)
        new_trainable2.restore_from_object(checkpoint_obj)
        checkpoint_obj = new_trainable2.save_to_object()
        new_trainable2.train()
        result = new_trainable2.train()
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
        new_trainable2.stop()
        assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
        assert result[TRAINING_ITERATION] == 4
Code example #6
File: test_function_api.py Project: ray-project/ray
    def testFunctionRecurringSave(self):
        """This tests that save and restore are commutative."""
        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                assert os.path.exists(checkpoint_dir)
            for step in range(10):
                if step % 3 == 0:
                    with tune.checkpoint_dir(step=step) as checkpoint_dir:
                        path = os.path.join(checkpoint_dir, "checkpoint")
                        with open(path, "w") as f:
                            f.write(json.dumps({"step": step}))
                tune.report(test=step)

        wrapped = wrap_function(train)

        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.train()
        checkpoint_obj = new_trainable.save_to_object()
        new_trainable.restore_from_object(checkpoint_obj)
        checkpoint = new_trainable.save()

        new_trainable.stop()

        new_trainable2 = wrapped(logger_creator=self.logger_creator)
        new_trainable2.restore(checkpoint)
        new_trainable2.train()
        new_trainable2.stop()
Code example #7
File: test_function_api.py Project: ray-project/ray
    def testReuseNullCheckpoint(self):
        def train(config, checkpoint_dir=None):
            assert not checkpoint_dir
            for step in range(10):
                tune.report(test=step)

        # Create checkpoint
        wrapped = wrap_function(train)
        checkpoint = None
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()

        # Use the checkpoint a couple of times
        for i in range(3):
            new_trainable = wrapped(logger_creator=self.logger_creator)
            new_trainable.restore(checkpoint)
            new_trainable.stop()

        # Make sure the result is still good
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.restore(checkpoint)
        result = new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()
        self.assertTrue(result[TRAINING_ITERATION] == 1)
Code example #8
File: test_trainable.py Project: parasj/ray
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
    trainable_cls = wrap_function(fn_trainable)
    trainable = ray.remote(trainable_cls).remote()
    ray.get(trainable.train.remote())

    saving_future = trainable.save_to_object.remote()

    # Check for errors
    ray.get(saving_future)

    restoring_future = trainable.restore_from_object.remote(saving_future)

    ray.get(restoring_future)
Code example #9
File: trainer.py Project: ray-project/ray
def _create_tune_trainable(
    train_func, dataset, backend_config, num_workers, use_gpu, resources_per_worker
):
    """Creates a Tune Trainable class for Train training.

    This function populates class attributes and methods.
    """

    # TODO(matt): Move dataset to Ray object store, like tune.with_parameters.
    def tune_function(config, checkpoint_dir=None):
        trainer = Trainer(
            backend=backend_config,
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker=resources_per_worker,
        )

        trainer.start()

        iterator = trainer.run_iterator(
            train_func, config, dataset=dataset, checkpoint=checkpoint_dir
        )

        for results in iterator:
            first_worker_results = results[0]

            tune.report(**first_worker_results)

        trainer.shutdown()

    trainable_cls = wrap_function(tune_function)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        @classmethod
        def default_resource_request(cls, config: Dict) -> PlacementGroupFactory:
            trainer_bundle = [{"CPU": 1}]
            worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
            worker_resources_extra = (
                {} if resources_per_worker is None else resources_per_worker
            )
            worker_bundles = [
                {**worker_resources, **worker_resources_extra}
                for _ in range(num_workers)
            ]
            bundles = trainer_bundle + worker_bundles
            return PlacementGroupFactory(bundles, strategy="PACK")

    return TrainTrainable
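
A sketch of how the returned class might be used; my_train_func and backend_config are assumed to be defined elsewhere, and none of these names come from the source:

trainable_cls = _create_tune_trainable(
    train_func=my_train_func,
    dataset=None,
    backend_config=backend_config,
    num_workers=2,
    use_gpu=False,
    resources_per_worker=None,
)
# Tune sizes each trial's placement group via default_resource_request:
# one CPU bundle for the trainer plus one bundle per worker.
tune.run(trainable_cls, config={"lr": 0.01})
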
Code example #10
File: registry.py Project: adgirish/ray
def register_trainable(name, trainable):
    """Register a trainable function or class.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
    """

    if isinstance(trainable, FunctionType):
        trainable = wrap_function(trainable)
    if not issubclass(trainable, Trainable):
        raise TypeError(
            "Second argument must be convertible to Trainable", trainable)
    _default_registry.register(TRAINABLE_CLASS, name, trainable)
Code example #11
File: test_function_api.py Project: ray-project/ray
    def testMultipleNullMemoryCheckpoints(self):
        def train(config, checkpoint_dir=None):
            assert not checkpoint_dir
            for step in range(10):
                tune.report(test=step)

        wrapped = wrap_function(train)
        checkpoint = None
        for i in range(5):
            new_trainable = wrapped(logger_creator=self.logger_creator)
            if checkpoint:
                new_trainable.restore_from_object(checkpoint)
            result = new_trainable.train()
            checkpoint = new_trainable.save_to_object()
            new_trainable.stop()
        assert result[TRAINING_ITERATION] == 1
Code example #12
File: test_function_api.py Project: ray-project/ray
    def testFunctionNoCheckpointing(self):
        def train(config, checkpoint_dir=None):
            if checkpoint_dir:
                assert os.path.exists(checkpoint_dir)
            for step in range(10):
                tune.report(test=step)

        wrapped = wrap_function(train)

        new_trainable = wrapped(logger_creator=self.logger_creator)
        result = new_trainable.train()
        checkpoint = new_trainable.save()
        new_trainable.stop()

        new_trainable2 = wrapped(logger_creator=self.logger_creator)
        new_trainable2.restore(checkpoint)
        result = new_trainable2.train()
        self.assertEqual(result[TRAINING_ITERATION], 1)
        checkpoint = new_trainable2.save()
        new_trainable2.stop()
Code example #13
File: test_integration_wandb.py Project: parasj/ray
    def testWandbDecoratorConfig(self):
        config = {"par1": 4, "par2": 9.12345678}
        trial = Trial(
            config,
            0,
            "trial_0",
            "trainable",
            PlacementGroupFactory([{
                "CPU": 1
            }]),
            "/tmp",
        )
        trial_info = _TrialInfo(trial)

        @wandb_mixin
        def train_fn(config):
            return 1

        train_fn.__mixins__ = (_MockWandbTrainableMixin, )

        config[TRIAL_INFO] = trial_info

        if WANDB_ENV_VAR in os.environ:
            del os.environ[WANDB_ENV_VAR]

        # Needs at least a project
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(config)

        # No API key
        config["wandb"] = {"project": "test_project"}
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(config)

        # API Key in config
        config["wandb"] = {"project": "test_project", "api_key": "1234"}
        wrapped = wrap_function(train_fn)(config)
        self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")

        del os.environ[WANDB_ENV_VAR]

        # API Key file
        with tempfile.NamedTemporaryFile("wt") as fp:
            fp.write("5678")
            fp.flush()

            config["wandb"] = {
                "project": "test_project",
                "api_key_file": fp.name
            }

            wrapped = wrap_function(train_fn)(config)
            self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")

        del os.environ[WANDB_ENV_VAR]

        # API Key in env
        os.environ[WANDB_ENV_VAR] = "9012"
        config["wandb"] = {"project": "test_project"}
        wrapped = wrap_function(train_fn)(config)

        # From now on, the API key is in the env variable.

        # Default configuration
        config["wandb"] = {"project": "test_project"}
        config[TRIAL_INFO] = trial_info

        wrapped = wrap_function(train_fn)(config)
        self.assertEqual(wrapped.wandb.kwargs["project"], "test_project")
        self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id)
        self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name)
Code example #14
File: base_trainer.py Project: ray-project/ray
    def as_trainable(self) -> Type[Trainable]:
        """Convert self to a ``tune.Trainable`` class."""

        base_config = self._param_dict
        trainer_cls = self.__class__
        scaling_config = self.scaling_config

        def train_func(config, checkpoint_dir=None):
            # config already contains merged values.
            # Instantiate new Trainer in Trainable.
            trainer = trainer_cls(**config)

            if checkpoint_dir:
                trainer.resume_from_checkpoint = Checkpoint.from_directory(
                    checkpoint_dir
                )

            trainer.setup()
            trainer.preprocess_datasets()
            trainer.training_loop()

        # Change the name of the training function to match the name of the Trainer
        # class. This will mean the Tune trial name will match the name of Trainer on
        # stdout messages and the results directory.
        train_func.__name__ = trainer_cls.__name__

        trainable_cls = wrap_function(train_func)

        class TrainTrainable(trainable_cls):
            """Add default resources to the Trainable."""

            # Workaround for actor name not being logged correctly
            # if __repr__ is not directly defined in a class.
            def __repr__(self):
                return super().__repr__()

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

                # Create a new config by merging the dicts.
                # run_config is not a tunable hyperparameter so it does not need to be
                # merged.
                run_config = base_config.pop("run_config", None)
                self._merged_config = merge_dicts(base_config, self.config)
                self._merged_config["run_config"] = run_config

            def _trainable_func(self, config, reporter, checkpoint_dir):
                # We ignore the config passed by Tune and instead use the merged
                # config which includes the initial Trainer args.
                super()._trainable_func(self._merged_config, reporter, checkpoint_dir)

            @classmethod
            def default_resource_request(cls, config):
                updated_scaling_config = config.get("scaling_config", scaling_config)
                scaling_config_dataclass = (
                    trainer_cls._validate_and_get_scaling_config_data_class(
                        updated_scaling_config
                    )
                )
                return scaling_config_dataclass.as_placement_group_factory()

        return TrainTrainable
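
A sketch of the intended call pattern, assuming a concrete trainer subclass; TorchTrainer, train_loop, scaling_config, and the config keys are illustrative:

from ray import tune
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_loop,  # hypothetical user training function
    scaling_config=scaling_config,
)

# The returned class carries the Trainer's merged config and default
# resources, so Tune can run it like any other Trainable.
tune.run(trainer.as_trainable(), config={"train_loop_config": {"lr": 0.01}})
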
Code example #15
File: base_trainer.py Project: parasj/ray
    def as_trainable(self) -> Type[Trainable]:
        """Convert self to a ``tune.Trainable`` class."""

        base_config = self._param_dict
        trainer_cls = self.__class__
        scaling_config = self.scaling_config

        def train_func(config, checkpoint_dir=None):
            # config already contains merged values.
            # Instantiate new Trainer in Trainable.
            trainer = trainer_cls(**config)

            if checkpoint_dir:
                trainer.resume_from_checkpoint = Checkpoint.from_directory(
                    checkpoint_dir)

            trainer.setup()
            trainer.preprocess_datasets()
            trainer.training_loop()

        # Change the name of the training function to match the name of the Trainer
        # class. This will mean the Tune trial name will match the name of Trainer on
        # stdout messages and the results directory.
        train_func.__name__ = trainer_cls.__name__

        trainable_cls = wrap_function(train_func, warn=False)

        class TrainTrainable(trainable_cls):
            """Add default resources to the Trainable."""

            _handles_checkpoint_freq = trainer_cls._handles_checkpoint_freq
            _handles_checkpoint_at_end = trainer_cls._handles_checkpoint_at_end

            # Workaround for actor name not being logged correctly
            # if __repr__ is not directly defined in a class.
            def __repr__(self):
                return super().__repr__()

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

                # Create a new config by merging the dicts.
                # run_config is not a tunable hyperparameter so it does not need to be
                # merged.
                run_config = base_config.pop("run_config", None)
                self._merged_config = merge_dicts(base_config, self.config)
                self._merged_config["run_config"] = run_config
                merged_scaling_config = self._merged_config.get(
                    "scaling_config")
                if isinstance(merged_scaling_config, dict):
                    merged_scaling_config = ScalingConfig(
                        **merged_scaling_config)
                self._merged_config[
                    "scaling_config"] = self._reconcile_scaling_config_with_trial_resources(
                        merged_scaling_config)

            def _reconcile_scaling_config_with_trial_resources(
                    self, scaling_config: ScalingConfig) -> ScalingConfig:
                """
                ResourceChangingScheduler workaround.

                Ensures that the scaling config matches trial resources.

                This should be replaced with RCS returning a ScalingConfig
                in the future.
                """

                trial_resources = self.trial_resources
                # This will be false if the resources are default
                if not isinstance(trial_resources, PlacementGroupFactory):
                    return scaling_config

                if scaling_config:
                    scaling_config = trainer_cls._validate_scaling_config(
                        scaling_config)
                scaling_config_from_trial_resources = (
                    ScalingConfig.from_placement_group_factory(trial_resources)
                )

                # This check should always pass if ResourceChangingScheduler is not
                # used.
                if scaling_config_from_trial_resources != scaling_config:
                    scaling_config = trainer_cls._validate_scaling_config(
                        scaling_config_from_trial_resources)
                return scaling_config

            def _trainable_func(self, config, reporter, checkpoint_dir):
                # We ignore the config passed by Tune and instead use the merged
                # config which includes the initial Trainer args.
                super()._trainable_func(self._merged_config, reporter,
                                        checkpoint_dir)

            @classmethod
            def default_resource_request(cls, config):
                # `config["scaling_config"] is a dataclass when passed via the
                # `scaling_config` argument in `Trainer` and is a dict when passed
                # via the `scaling_config` key of `param_spec`.

                # Conversion logic must be duplicated in `TrainTrainable.__init__`
                # because this is a class method.
                updated_scaling_config = config.get("scaling_config",
                                                    scaling_config)
                if isinstance(updated_scaling_config, dict):
                    updated_scaling_config = ScalingConfig(
                        **updated_scaling_config)
                validated_scaling_config = trainer_cls._validate_scaling_config(
                    updated_scaling_config)
                return validated_scaling_config.as_placement_group_factory()

        return TrainTrainable
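
The reconciliation above leans on ScalingConfig converting to and from a PlacementGroupFactory losslessly. A sketch of that roundtrip, assuming the ray.air config classes of this era:

from ray.air.config import ScalingConfig

sc = ScalingConfig(num_workers=2, use_gpu=False)
pgf = sc.as_placement_group_factory()

# Expected to recover an equivalent config when no scheduler has changed
# the trial's resources; the final check in
# _reconcile_scaling_config_with_trial_resources relies on this.
print(ScalingConfig.from_placement_group_factory(pgf) == sc)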