def register_trainable(name, trainable):
    """Register a trainable function or class.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.

    Raises:
        TypeError: If ``trainable`` is not a ``Trainable`` subclass after
            the optional function-to-class conversion.
    """
    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, FunctionType):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.warning(
            "Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    # issubclass() raises its own opaque TypeError when handed a non-class
    # (e.g. a string or an int), so check for a class first to make sure the
    # caller always gets the descriptive error below.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable",
                        trainable)

    _global_registry.register(TRAINABLE_CLASS, name, trainable)
def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray
    process in the cluster.

    Args:
        name: Name to register.
        trainable: Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
        warn: Forwarded unchanged to ``wrap_function`` when a function or
            other callable has to be converted into a class.

    Raises:
        TypeError: If ``trainable`` is not a ``Trainable`` subclass after
            the optional function-to-class conversion.
    """
    from ray.tune.trainable import wrap_function
    from ray.tune.trainable import Trainable

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, (FunctionType, partial)):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable, warn=warn)
    elif callable(trainable):
        logger.info("Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable, warn=warn)

    # issubclass() raises its own opaque TypeError when handed a non-class
    # (e.g. a string or an int), so check for a class first to make sure the
    # caller always gets the descriptive error below.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable", trainable)

    _global_registry.register(TRAINABLE_CLASS, name, trainable)
def testMlFlowMixinConfig(self):
    """Wrapping an MLflow-mixin function must raise ValueError for each
    of the incomplete or invalid ``mlflow`` configs tried below."""
    clear_env_vars()
    trial_config = {"par1": 4, "par2": 9.0}

    @mlflow_mixin
    def train_fn(config):
        return 1

    train_fn.__mixins__ = (MLflowTrainableMixin, )

    def expect_value_error():
        # Both the wrapping and the instantiation run inside the
        # assertion context, reading the current state of trial_config.
        with self.assertRaises(ValueError):
            wrap_function(train_fn)(trial_config)

    # No MLflow config passed in.
    expect_value_error()

    # Empty MLflow config: no tracking uri or experiment_id/name.
    trial_config["mlflow"] = {}
    expect_value_error()

    # Invalid experiment-id.
    trial_config["mlflow"]["experiment_id"] = "500"
    expect_value_error()

    # Add the name of an experiment that does not already exist.
    # NOTE(review): the earlier comment here said a new experiment should
    # be created, yet the test expects ValueError -- confirm the intent.
    trial_config["mlflow"]["experiment_name"] = "new_experiment"
    expect_value_error()
def testCheckpointReuseObject(self):
    """Test that repeated save/restore never reuses same checkpoint dir."""

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            # A restored directory must hold exactly one checkpoint file.
            matching = [
                entry for entry in os.listdir(checkpoint_dir)
                if "checkpoint-" in entry
            ]
            count = len(matching)
            assert count == 1, os.listdir(checkpoint_dir)

        for step in range(20):
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                marker = os.path.join(checkpoint_dir,
                                      "checkpoint-{}".format(step))
                open(marker, "a").close()
            tune.report(test=step)

    wrapped = wrap_function(train)

    checkpoint = None
    # Five generations of trainables, each restored from the previous
    # generation's in-memory checkpoint and trained for two iterations.
    for _ in range(5):
        runner = wrapped(logger_creator=self.logger_creator)
        if checkpoint:
            runner.restore_from_object(checkpoint)
        runner.train()
        result = runner.train()
        checkpoint = runner.save_to_object()
        runner.stop()

    self.assertTrue(result[TRAINING_ITERATION] == 10)
def testFunctionImmediateSave(self):
    """Save to an object right after restoring from one, then keep training.

    Also checks how many entries containing "tmp" exist in the log dir
    before and after the second trainable is stopped.
    """

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                print(checkpoint_dir)
                marker = os.path.join(checkpoint_dir,
                                      "checkpoint-{}".format(step))
                open(marker, "w").close()
            tune.report(test=step)

    def tmp_entry_count():
        # Number of entries in the log dir whose name contains "tmp".
        return sum("tmp" in path for path in os.listdir(self.logdir))

    wrapped = wrap_function(train)

    # First trainable: two iterations, then an in-memory checkpoint.
    first = wrapped(logger_creator=self.logger_creator)
    first.train()
    first.train()
    checkpoint_obj = first.save_to_object()
    first.stop()

    # Second trainable: restore, save again immediately, then train twice.
    second = wrapped(logger_creator=self.logger_creator)
    second.restore_from_object(checkpoint_obj)
    checkpoint_obj = second.save_to_object()
    second.train()
    result = second.train()

    assert tmp_entry_count() == 1
    second.stop()
    assert tmp_entry_count() == 0
    assert result[TRAINING_ITERATION] == 4
def testFunctionRecurringSave(self):
    """Round-trip a checkpoint through memory and disk.

    The first trainable saves to an object, restores from it and saves to
    disk; a second trainable restores from that disk checkpoint and trains.
    """

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            # Only every third step writes a checkpoint.
            writes_checkpoint = step % 3 == 0
            if writes_checkpoint:
                with tune.checkpoint_dir(step=step) as checkpoint_dir:
                    target = os.path.join(checkpoint_dir, "checkpoint")
                    with open(target, "w") as handle:
                        handle.write(json.dumps({"step": step}))
            tune.report(test=step)

    wrapped = wrap_function(train)

    first = wrapped(logger_creator=self.logger_creator)
    first.train()
    in_memory = first.save_to_object()
    first.restore_from_object(in_memory)
    checkpoint = first.save()
    first.stop()

    second = wrapped(logger_creator=self.logger_creator)
    second.restore(checkpoint)
    second.train()
    second.stop()
def testReuseNullCheckpoint(self):
    """Restore several trainables from one checkpoint saved by a function
    that never writes a checkpoint directory."""

    def train(config, checkpoint_dir=None):
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    wrapped = wrap_function(train)

    def fresh_trainable():
        return wrapped(logger_creator=self.logger_creator)

    # Create checkpoint
    producer = fresh_trainable()
    producer.train()
    checkpoint = producer.save()
    producer.stop()

    # Use the checkpoint a couple of times
    for _ in range(3):
        consumer = fresh_trainable()
        consumer.restore(checkpoint)
        consumer.stop()

    # Make sure the result is still good
    final = fresh_trainable()
    final.restore(checkpoint)
    result = final.train()
    checkpoint = final.save()
    final.stop()
    self.assertTrue(result[TRAINING_ITERATION] == 1)
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
    """Run train, save_to_object and restore_from_object on a remote
    actor built from a wrapped function trainable."""
    actor_cls = ray.remote(wrap_function(fn_trainable))
    actor = actor_cls.remote()

    ray.get(actor.train.remote())

    # Resolve the save future first so any save error surfaces here.
    saved_ref = actor.save_to_object.remote()
    ray.get(saved_ref)

    # The future itself (not its resolved value) is passed to the restore.
    ray.get(actor.restore_from_object.remote(saved_ref))
def _create_tune_trainable(
    train_func, dataset, backend_config, num_workers, use_gpu, resources_per_worker
):
    """Creates a Tune Trainable class for Train training.

    This function populates class attributes and methods.

    Args:
        train_func: Training function handed to ``Trainer.run_iterator``.
        dataset: Passed as ``dataset`` to ``Trainer.run_iterator``.
        backend_config: Passed as ``backend`` to the ``Trainer``.
        num_workers: Number of workers; also the number of worker bundles
            in the default resource request.
        use_gpu: Whether each worker requests a GPU.
        resources_per_worker: Optional dict of extra per-worker resources;
            its entries override the default per-worker CPU/GPU values.

    Returns:
        A subclass of the wrapped function trainable that also defines
        ``default_resource_request``.
    """

    # TODO(matt): Move dataset to Ray object store, like tune.with_parameters.
    def tune_function(config, checkpoint_dir=None):
        trainer = Trainer(
            backend=backend_config,
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker=resources_per_worker,
        )

        trainer.start()

        # Shut the trainer down even if iteration or reporting raises, so a
        # failing trial does not leave the started trainer running.
        try:
            iterator = trainer.run_iterator(
                train_func, config, dataset=dataset, checkpoint=checkpoint_dir
            )

            for results in iterator:
                # Only the first worker's results are reported to Tune.
                first_worker_results = results[0]
                tune.report(**first_worker_results)
        finally:
            trainer.shutdown()

    trainable_cls = wrap_function(tune_function)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        @classmethod
        def default_resource_request(cls, config: Dict) -> PlacementGroupFactory:
            # One bundle for the trainer itself, then one bundle per worker.
            trainer_bundle = [{"CPU": 1}]
            worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
            worker_resources_extra = (
                {} if resources_per_worker is None else resources_per_worker
            )
            worker_bundles = [
                {**worker_resources, **worker_resources_extra}
                for _ in range(num_workers)
            ]
            bundles = trainer_bundle + worker_bundles
            return PlacementGroupFactory(bundles, strategy="PACK")

    return TrainTrainable
def register_trainable(name, trainable):
    """Register a trainable function or class.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.

    Raises:
        TypeError: If ``trainable`` is not a ``Trainable`` subclass after
            the optional function-to-class conversion.
    """
    if isinstance(trainable, FunctionType):
        trainable = wrap_function(trainable)

    # issubclass() raises its own opaque TypeError when handed a non-class
    # (e.g. a callable object or a string), so check for a class first to
    # make sure the caller always gets the descriptive error below.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError(
            "Second argument must be convertable to Trainable", trainable)

    _default_registry.register(TRAINABLE_CLASS, name, trainable)
def register_trainable(name, trainable):
    """Register a trainable function or class.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.

    Raises:
        TypeError: If ``trainable`` is not a ``Trainable`` subclass after
            the optional function-to-class conversion.
    """
    if isinstance(trainable, FunctionType):
        trainable = wrap_function(trainable)

    # issubclass() raises its own opaque TypeError when handed a non-class
    # (e.g. a callable object or a string), so check for a class first to
    # make sure the caller always gets the descriptive error below.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable",
                        trainable)

    _default_registry.register(TRAINABLE_CLASS, name, trainable)
def testMultipleNullMemoryCheckpoints(self):
    """Chain in-memory checkpoints across five trainables whose function
    never writes a checkpoint directory."""

    def train(config, checkpoint_dir=None):
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    wrapped = wrap_function(train)

    checkpoint = None
    for _ in range(5):
        runner = wrapped(logger_creator=self.logger_creator)
        # The first round has nothing to restore from.
        if checkpoint:
            runner.restore_from_object(checkpoint)
        result = runner.train()
        checkpoint = runner.save_to_object()
        runner.stop()

    assert result[TRAINING_ITERATION] == 1
def testFunctionNoCheckpointing(self):
    """Save and restore a function trainable that never calls
    ``tune.checkpoint_dir``."""

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            tune.report(test=step)

    wrapped = wrap_function(train)

    # First trainable: one iteration, then save.
    first = wrapped(logger_creator=self.logger_creator)
    first.train()
    checkpoint = first.save()
    first.stop()

    # Second trainable: restore from that checkpoint and train once.
    second = wrapped(logger_creator=self.logger_creator)
    second.restore(checkpoint)
    result = second.train()
    self.assertEqual(result[TRAINING_ITERATION], 1)

    checkpoint = second.save()
    second.stop()
def testWandbDecoratorConfig(self):
    """Exercise the wandb mixin with a missing project, a missing API key,
    and an API key supplied via config, key file and environment."""
    config = {"par1": 4, "par2": 9.12345678}
    resources = PlacementGroupFactory([{"CPU": 1}])
    trial = Trial(config, 0, "trial_0", "trainable", resources, "/tmp")
    trial_info = _TrialInfo(trial)

    @wandb_mixin
    def train_fn(config):
        return 1

    train_fn.__mixins__ = (_MockWandbTrainableMixin, )

    def instantiate():
        # Wrap and build a trainable from the current state of ``config``.
        return wrap_function(train_fn)(config)

    config[TRIAL_INFO] = trial_info

    # Start from a clean environment.
    os.environ.pop(WANDB_ENV_VAR, None)

    # Needs at least a project
    with self.assertRaises(ValueError):
        wrapped = instantiate()

    # No API key
    config["wandb"] = {"project": "test_project"}
    with self.assertRaises(ValueError):
        wrapped = instantiate()

    # API Key in config
    config["wandb"] = {"project": "test_project", "api_key": "1234"}
    wrapped = instantiate()
    self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")

    del os.environ[WANDB_ENV_VAR]

    # API Key file
    with tempfile.NamedTemporaryFile("wt") as key_file:
        key_file.write("5678")
        key_file.flush()

        config["wandb"] = {
            "project": "test_project",
            "api_key_file": key_file.name
        }

        wrapped = instantiate()
        self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")

    del os.environ[WANDB_ENV_VAR]

    # API Key in env
    os.environ[WANDB_ENV_VAR] = "9012"
    config["wandb"] = {"project": "test_project"}
    wrapped = instantiate()

    # From now on, the API key is in the env variable.

    # Default configuration
    config["wandb"] = {"project": "test_project"}
    config[TRIAL_INFO] = trial_info

    wrapped = instantiate()
    wandb_kwargs = wrapped.wandb.kwargs
    self.assertEqual(wandb_kwargs["project"], "test_project")
    self.assertEqual(wandb_kwargs["id"], trial.trial_id)
    self.assertEqual(wandb_kwargs["name"], trial.trial_name)
def as_trainable(self) -> Type[Trainable]:
    """Convert self to a ``tune.Trainable`` class.

    Returns:
        A subclass of the wrapped training function. It merges this
        trainer's constructor arguments with the config given to each
        trainable instance, and derives its default resource request from
        the scaling config.
    """

    base_config = self._param_dict
    trainer_cls = self.__class__
    scaling_config = self.scaling_config

    def train_func(config, checkpoint_dir=None):
        # config already contains merged values.
        # Instantiate new Trainer in Trainable.
        trainer = trainer_cls(**config)

        if checkpoint_dir:
            trainer.resume_from_checkpoint = Checkpoint.from_directory(
                checkpoint_dir
            )

        trainer.setup()
        trainer.preprocess_datasets()
        trainer.training_loop()

    # Change the name of the training function to match the name of the Trainer
    # class. This will mean the Tune trial name will match the name of Trainer on
    # stdout messages and the results directory.
    train_func.__name__ = trainer_cls.__name__

    trainable_cls = wrap_function(train_func)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        # Workaround for actor name not being logged correctly
        # if __repr__ is not directly defined in a class.
        def __repr__(self):
            return super().__repr__()

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # Create a new config by merging the dicts.
            # run_config is not a tunable hyperparameter so it does not need to be
            # merged.
            # Work on a filtered copy instead of popping from ``base_config``:
            # that dict is the trainer's own ``_param_dict`` and is shared by
            # every instance of this class, so popping would drop
            # ``run_config`` for the trainer and for any later instance.
            run_config = base_config.get("run_config")
            mergeable_config = {
                key: value
                for key, value in base_config.items()
                if key != "run_config"
            }
            self._merged_config = merge_dicts(mergeable_config, self.config)
            self._merged_config["run_config"] = run_config

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # We ignore the config passed by Tune and instead use the merged
            # config which includes the initial Trainer args.
            super()._trainable_func(self._merged_config, reporter, checkpoint_dir)

        @classmethod
        def default_resource_request(cls, config):
            # Prefer a scaling config supplied in ``config``; otherwise fall
            # back to the one the trainer was constructed with.
            updated_scaling_config = config.get("scaling_config", scaling_config)
            scaling_config_dataclass = (
                trainer_cls._validate_and_get_scaling_config_data_class(
                    updated_scaling_config
                )
            )
            return scaling_config_dataclass.as_placement_group_factory()

    return TrainTrainable
def as_trainable(self) -> Type[Trainable]:
    """Convert self to a ``tune.Trainable`` class.

    Returns:
        A subclass of the wrapped training function. It merges this
        trainer's constructor arguments with the config given to each
        trainable instance, reconciles the scaling config with the trial
        resources, and derives its default resource request from the
        scaling config.
    """

    base_config = self._param_dict
    trainer_cls = self.__class__
    scaling_config = self.scaling_config

    def train_func(config, checkpoint_dir=None):
        # config already contains merged values.
        # Instantiate new Trainer in Trainable.
        trainer = trainer_cls(**config)

        if checkpoint_dir:
            trainer.resume_from_checkpoint = Checkpoint.from_directory(
                checkpoint_dir)

        trainer.setup()
        trainer.preprocess_datasets()
        trainer.training_loop()

    # Change the name of the training function to match the name of the Trainer
    # class. This will mean the Tune trial name will match the name of Trainer on
    # stdout messages and the results directory.
    train_func.__name__ = trainer_cls.__name__

    trainable_cls = wrap_function(train_func, warn=False)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        _handles_checkpoint_freq = trainer_cls._handles_checkpoint_freq
        _handles_checkpoint_at_end = trainer_cls._handles_checkpoint_at_end

        # Workaround for actor name not being logged correctly
        # if __repr__ is not directly defined in a class.
        def __repr__(self):
            return super().__repr__()

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # Create a new config by merging the dicts.
            # run_config is not a tunable hyperparameter so it does not need to be
            # merged.
            # Work on a filtered copy instead of popping from ``base_config``:
            # that dict is the trainer's own ``_param_dict`` and is shared by
            # every instance of this class, so popping would drop
            # ``run_config`` for the trainer and for any later instance.
            run_config = base_config.get("run_config")
            mergeable_config = {
                key: value
                for key, value in base_config.items()
                if key != "run_config"
            }
            self._merged_config = merge_dicts(mergeable_config, self.config)
            self._merged_config["run_config"] = run_config

            # The scaling config may arrive as a plain dict; normalize it
            # to a ScalingConfig before reconciling.
            merged_scaling_config = self._merged_config.get(
                "scaling_config")
            if isinstance(merged_scaling_config, dict):
                merged_scaling_config = ScalingConfig(
                    **merged_scaling_config)
            self._merged_config[
                "scaling_config"] = self._reconcile_scaling_config_with_trial_resources(
                    merged_scaling_config)

        def _reconcile_scaling_config_with_trial_resources(
                self, scaling_config: ScalingConfig) -> ScalingConfig:
            """
            ResourceChangingScheduler workaround.

            Ensures that the scaling config matches trial resources.

            This should be replaced with RCS returning a ScalingConfig
            in the future.
            """

            trial_resources = self.trial_resources
            # This will be false if the resources are default
            if not isinstance(trial_resources, PlacementGroupFactory):
                return scaling_config

            if scaling_config:
                scaling_config = trainer_cls._validate_scaling_config(
                    scaling_config)
            scaling_config_from_trial_resources = (
                ScalingConfig.from_placement_group_factory(trial_resources)
            )

            # This check should always pass if ResourceChangingScheduler is not
            # used.
            if scaling_config_from_trial_resources != scaling_config:
                scaling_config = trainer_cls._validate_scaling_config(
                    scaling_config_from_trial_resources)
            return scaling_config

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # We ignore the config passed by Tune and instead use the merged
            # config which includes the initial Trainer args.
            super()._trainable_func(self._merged_config, reporter,
                                    checkpoint_dir)

        @classmethod
        def default_resource_request(cls, config):
            # `config["scaling_config"] is a dataclass when passed via the
            # `scaling_config` argument in `Trainer` and is a dict when passed
            # via the `scaling_config` key of `param_spec`.

            # Conversion logic must be duplicated in `TrainTrainable.__init__`
            # because this is a class method.
            updated_scaling_config = config.get("scaling_config",
                                                scaling_config)
            if isinstance(updated_scaling_config, dict):
                updated_scaling_config = ScalingConfig(
                    **updated_scaling_config)
            validated_scaling_config = trainer_cls._validate_scaling_config(
                updated_scaling_config)
            return validated_scaling_config.as_placement_group_factory()

    return TrainTrainable