def testCheckpointReuseObjectWithoutTraining(self):
    """Test that repeated save/restore never reuses same checkpoint dir."""

    def train(config, checkpoint_dir=None):
        if checkpoint_dir:
            # Exactly one checkpoint file must be visible after restore.
            count = sum(
                "checkpoint-" in path for path in os.listdir(checkpoint_dir))
            assert count == 1, os.listdir(checkpoint_dir)

        for step in range(20):
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                path = os.path.join(checkpoint_dir,
                                    "checkpoint-{}".format(step))
                open(path, "a").close()
            tune.report(test=step)

    wrapped = wrap_function(train)
    new_trainable = wrapped(logger_creator=self.logger_creator)
    # Train two iterations, then snapshot to an in-memory object.
    for _ in range(2):
        result = new_trainable.train()
    checkpoint = new_trainable.save_to_object()
    new_trainable.stop()

    # Restore without training in between; the inner `count == 1` assert
    # verifies that each restore lands in a fresh checkpoint directory.
    new_trainable2 = wrapped(logger_creator=self.logger_creator)
    new_trainable2.restore_from_object(checkpoint)
    new_trainable2.stop()

    new_trainable2 = wrapped(logger_creator=self.logger_creator)
    new_trainable2.restore_from_object(checkpoint)
    result = new_trainable2.train()
    new_trainable2.stop()
    # Two iterations before the save plus one after restore.
    self.assertEqual(result[TRAINING_ITERATION], 3)
def testReuseNullCheckpoint(self):
    """A checkpoint from a never-checkpointing function stays reusable."""

    def train(config, checkpoint_dir=None):
        # The function never writes a checkpoint, so restore must never
        # pass a directory in.
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    # Create checkpoint
    wrapped = wrap_function(train)
    checkpoint = None
    new_trainable = wrapped(logger_creator=self.logger_creator)
    new_trainable.train()
    checkpoint = new_trainable.save()
    new_trainable.stop()

    # Use the checkpoint a couple of times
    for _ in range(3):
        new_trainable = wrapped(logger_creator=self.logger_creator)
        new_trainable.restore(checkpoint)
        new_trainable.stop()

    # Make sure the result is still good
    new_trainable = wrapped(logger_creator=self.logger_creator)
    new_trainable.restore(checkpoint)
    result = new_trainable.train()
    checkpoint = new_trainable.save()
    new_trainable.stop()
    self.assertEqual(result[TRAINING_ITERATION], 1)
def testFunctionRecurringSave(self):
    """This tests that save and restore are commutative."""

    def train_fn(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            # Checkpoint on every third step.
            if step % 3 == 0:
                with tune.checkpoint_dir(step=step) as checkpoint_dir:
                    ckpt_path = os.path.join(checkpoint_dir, "checkpoint")
                    with open(ckpt_path, "w") as f:
                        json.dump({"step": step}, f)
            tune.report(test=step)

    trainable_cls = wrap_function(train_fn)

    trainable_a = trainable_cls(logger_creator=self.logger_creator)
    trainable_a.train()
    # Round-trip through an in-memory checkpoint, then save to disk.
    memory_checkpoint = trainable_a.save_to_object()
    trainable_a.restore_from_object(memory_checkpoint)
    disk_checkpoint = trainable_a.save()
    trainable_a.stop()

    # A fresh trainable must be able to resume from the disk checkpoint.
    trainable_b = trainable_cls(logger_creator=self.logger_creator)
    trainable_b.restore(disk_checkpoint)
    trainable_b.train()
    trainable_b.stop()
def testFunctionImmediateSave(self):
    """This tests that save and restore are commutative."""

    def train_fn(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            # Drop an (empty) checkpoint marker on every single step.
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                print(checkpoint_dir)
                marker = os.path.join(checkpoint_dir,
                                      "checkpoint-{}".format(step))
                open(marker, "w").close()
            tune.report(test=step)

    trainable_cls = wrap_function(train_fn)

    first = trainable_cls(logger_creator=self.logger_creator)
    first.train()
    first.train()
    checkpoint_obj = first.save_to_object()
    first.stop()

    second = trainable_cls(logger_creator=self.logger_creator)
    second.restore_from_object(checkpoint_obj)
    # Saving immediately after restore (no training step) must work.
    checkpoint_obj = second.save_to_object()
    second.train()
    result = second.train()
    # Exactly one temporary checkpoint dir exists while running ...
    assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
    second.stop()
    # ... and it is cleaned up on stop.
    assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
    assert result[TRAINING_ITERATION] == 4
def create_resettable_function(num_resets: defaultdict):
    """Build a Trainable class that counts ``reset_config`` calls.

    Increments ``num_resets[trial_id]`` every time the returned trainable
    is reset with a new config.
    """

    def inner_fn(config, checkpoint_dir=None):
        # Resume the step counter from a checkpoint if one was provided.
        if checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp:
                step = pickle.load(fp)
        else:
            step = 0

        # Run two steps total, checkpointing the counter on each step.
        while step < 2:
            step += 1
            with tune.checkpoint_dir(step) as checkpoint_dir:
                with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp:
                    pickle.dump(step, fp)
            tune.report(done=step >= 2, iter=step, id=config["id"])

    wrapped_cls = wrap_function(inner_fn)

    class ResetCountTrainable(wrapped_cls):
        def reset_config(self, new_config):
            num_resets[self.trial_id] += 1
            return super().reset_config(new_config)

    return ResetCountTrainable
def setup(self, config):
    """Launch one remote function-trainable per worker and initialize the
    distributed process group across them.

    Args:
        config: Trial config forwarded verbatim to every worker.
    """
    self._finished = False
    num_workers = self._num_workers
    logdir = self.logdir
    assert self._function

    func_trainable = wrap_function(self.__class__._function)

    remote_trainable = ray.remote(func_trainable)
    remote_trainable = remote_trainable.options(
        **self.get_remote_worker_options())

    address = setup_address()
    # Bind `rank` as a default argument so each lambda captures the value
    # of its own iteration instead of the shared loop variable
    # (late-binding closure pitfall).
    self.workers = [
        remote_trainable.remote(
            config=config,
            logger_creator=lambda cfg, rank=rank: logger_creator(
                cfg, logdir, rank))
        for rank in range(num_workers)
    ]

    pgroup_params = self.default_process_group_parameters()
    from functools import partial
    setup_on_worker = partial(
        setup_process_group,
        url=address,
        world_size=num_workers,
        **pgroup_params)
    # Same default-argument binding for `rank` here.
    ray.get([
        w.execute.remote(
            lambda _, rank=rank: setup_on_worker(world_rank=rank))
        for rank, w in enumerate(self.workers)
    ])
def setup(self, config: Dict):
    """Start the remote workers and set up the process group from the
    addresses collected on each worker.

    Args:
        config: Trial config forwarded verbatim to every worker.
    """
    self._finished = False
    num_workers = self._num_workers
    assert self._function

    func_trainable = wrap_function(self.__class__._function)

    remote_trainable = ray.remote(func_trainable)
    remote_option, self._placement_group =\
        PlacementGroupUtil.get_remote_worker_options(
            self._num_workers, self._num_cpus_per_worker,
            self._num_gpus_per_worker, self._num_workers_per_host,
            self._timeout_s)
    remote_trainable = \
        remote_trainable.options(**remote_option)

    self.workers = [
        remote_trainable.remote(config=config, )
        for _ in range(num_workers)
    ]

    # Collect each worker's address for process-group setup.
    addresses = [
        ray.get(worker.execute.remote(lambda _: setup_address()))
        for worker in self.workers
    ]

    from functools import partial
    setup_on_worker = partial(
        setup_process_group, worker_addresses=addresses)
    # Bind `index` as a default argument so each lambda captures the value
    # of its own iteration instead of the shared loop variable
    # (late-binding closure pitfall).
    ray.get([
        w.execute.remote(
            lambda _, index=index: setup_on_worker(index=index))
        for index, w in enumerate(self.workers)
    ])
def setup(self, config: Dict):
    """Start the Horovod executor and launch workers for the wrapped
    training function."""
    trainable = wrap_function(self.__class__._function)

    def _make_settings():
        return RayExecutor.create_settings(self._timeout_s,
                                           self._ssh_identity_file,
                                           self._ssh_str)

    # We use a filelock here to ensure that the file-writing
    # process is safe across different trainables.
    if self._ssh_identity_file:
        with FileLock(self._ssh_identity_file + ".lock"):
            settings = _make_settings()
    else:
        settings = _make_settings()

    self.executor = RayExecutor(
        settings,
        cpus_per_slot=self._num_cpus_per_slot,
        use_gpu=self._use_gpu,
        num_hosts=self._num_hosts,
        num_slots=self._num_slots)

    new_config = DistributedTrainable.build_config(self, config)

    # We can't put `self` in the lambda closure, so we
    # resolve the variable ahead of time.
    logdir_ = str(self.logdir)

    # Starts the workers as specified by the resources above.
    self.executor.start(
        executable_cls=trainable,
        executable_kwargs={
            "config": new_config,
            "logger_creator": lambda cfg: logger_creator(cfg, logdir_)
        })
def wrap_function_patched(function):
    """Monkey-patch FunctionRunner remote trainable.

    Attaches ``save_all_states`` and ``get_sched_hints`` hooks to the
    class produced by ``wrap_function``.
    """
    patched_trainable = wrap_function(function)
    patched_trainable.save_all_states = types.MethodType(
        save_all_states_remote, patched_trainable)
    patched_trainable.get_sched_hints = types.MethodType(
        get_sched_hints_remote, patched_trainable)
    return patched_trainable
def as_trainable(self) -> Type[Trainable]:
    """Convert self to a ``tune.Trainable`` class."""
    # Captured by the closures below; `self` itself must not leak into them.
    base_config = self._param_dict
    trainer_cls = self.__class__
    scaling_config = self.scaling_config

    def train_func(config, checkpoint_dir=None):
        # config already contains merged values.
        # Instantiate new Trainer in Trainable.
        trainer = trainer_cls(**config)

        if checkpoint_dir:
            trainer.resume_from_checkpoint = Checkpoint.from_directory(
                checkpoint_dir
            )

        trainer.setup()
        trainer.preprocess_datasets()
        trainer.training_loop()

    # Change the name of the training function to match the name of the Trainer
    # class. This will mean the Tune trial name will match the name of Trainer on
    # stdout messages and the results directory.
    train_func.__name__ = trainer_cls.__name__

    trainable_cls = wrap_function(train_func)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # Create a new config by merging the dicts.
            # run_config is not a tunable hyperparameter so it does not need to be
            # merged.
            run_config = base_config.pop("run_config", None)
            self._merged_config = merge_dicts(base_config, self.config)
            self._merged_config["run_config"] = run_config

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # We ignore the config passed by Tune and instead use the merged
            # config which includes the initial Trainer args.
            super()._trainable_func(
                self._merged_config, reporter, checkpoint_dir
            )

        @classmethod
        def default_resource_request(cls, config):
            # Per-trial "scaling_config" overrides the Trainer-level one.
            updated_scaling_config = config.get("scaling_config", scaling_config)
            scaling_config_dataclass = (
                trainer_cls._validate_and_get_scaling_config_data_class(
                    updated_scaling_config
                )
            )
            return scaling_config_dataclass.as_placement_group_factory()

    return TrainTrainable
def _create_tune_trainable( train_func, dataset, backend_config, num_workers, use_gpu, resources_per_worker ): """Creates a Tune Trainable class for Train training. This function populates class attributes and methods. """ # TODO(matt): Move dataset to Ray object store, like tune.with_parameters. def tune_function(config, checkpoint_dir=None): trainer = Trainer( backend=backend_config, num_workers=num_workers, use_gpu=use_gpu, resources_per_worker=resources_per_worker, ) trainer.start() if checkpoint_dir is not None: checkpoint_path = os.path.join(checkpoint_dir, TUNE_CHECKPOINT_FILE_NAME) else: checkpoint_path = None iterator = trainer.run_iterator( train_func, config, dataset=dataset, checkpoint=checkpoint_path ) for results in iterator: first_worker_results = results[0] tune.report(**first_worker_results) trainer.shutdown() trainable_cls = wrap_function(tune_function) class TrainTrainable(trainable_cls): """Add default resources to the Trainable.""" @classmethod def default_resource_request(cls, config: Dict) -> PlacementGroupFactory: trainer_bundle = [{"CPU": 1}] worker_resources = {"CPU": 1, "GPU": int(use_gpu)} worker_resources_extra = ( {} if resources_per_worker is None else resources_per_worker ) worker_bundles = [ {**worker_resources, **worker_resources_extra} for _ in range(num_workers) ] bundles = trainer_bundle + worker_bundles return PlacementGroupFactory(bundles, strategy="PACK") return TrainTrainable
def setup(self, config: Dict):
    """Launch one remote trainable per worker, initialize the distributed
    process group (rank-0 worker hosts the address), and enable
    distributed mode on every worker.

    Args:
        config: Trial config; merged via ``DistributedTrainable.build_config``
            before being forwarded to workers.
    """
    self._finished = False
    num_workers = self._num_workers
    logdir = self.logdir
    assert self._function

    func_trainable = wrap_function(self.__class__._function)

    remote_trainable = ray.remote(func_trainable)
    (
        remote_option,
        self._placement_group,
    ) = PlacementGroupUtil.get_remote_worker_options(
        self._num_workers,
        self._num_cpus_per_worker,
        self._num_gpus_per_worker,
        self._num_workers_per_host,
        self._timeout_s,
    )
    remote_trainable = remote_trainable.options(**remote_option)
    new_config = DistributedTrainable.build_config(self, config)

    # Bind `rank` as a default argument so each closure captures the value
    # of its own iteration instead of the shared loop variable
    # (late-binding closure pitfall).
    self.workers = [
        remote_trainable.remote(
            config=new_config,
            logger_creator=lambda cfg, rank=rank: logger_creator(
                cfg, logdir, rank
            ),
        )
        for rank in range(num_workers)
    ]

    # Address has to be IP of rank 0 worker's node.
    address = ray.get(self.workers[0].execute.remote(lambda _: setup_address()))

    pgroup_params = self.default_process_group_parameters()
    from functools import partial

    setup_on_worker = partial(
        setup_process_group, url=address, world_size=num_workers, **pgroup_params
    )
    ray.get(
        [
            # Same default-argument binding for `rank` here.
            w.execute.remote(lambda _, rank=rank: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ]
    )

    ray.get(
        [
            w.execute.remote(lambda _: enable_distributed_trainable())
            for w in self.workers
        ]
    )
def testMultipleNullMemoryCheckpoints(self):
    """Repeated in-memory save/restore of a non-checkpointing function."""

    def train_fn(config, checkpoint_dir=None):
        # The function never checkpoints, so restore should pass nothing.
        assert not checkpoint_dir
        for step in range(10):
            tune.report(test=step)

    trainable_cls = wrap_function(train_fn)
    checkpoint = None
    for _ in range(5):
        trainable = trainable_cls(logger_creator=self.logger_creator)
        if checkpoint:
            trainable.restore_from_object(checkpoint)
        result = trainable.train()
        checkpoint = trainable.save_to_object()
        trainable.stop()
    assert result[TRAINING_ITERATION] == 1
def as_trainable(self) -> Type[Trainable]:
    """Convert self to a ``tune.Trainable`` class."""
    # Captured by the closures below; `self` itself must not leak into them.
    base_config = self._param_dict
    trainer_cls = self.__class__
    scaling_config = self.scaling_config

    def train_func(config, checkpoint_dir=None):
        # config already contains merged values.
        # Instantiate new Trainer in Trainable.
        trainer = trainer_cls(**config)

        if checkpoint_dir:
            trainer.resume_from_checkpoint = Checkpoint.from_directory(
                checkpoint_dir)

        trainer.setup()
        trainer.preprocess_datasets()
        trainer.training_loop()

    trainable_cls = wrap_function(train_func)

    class TrainTrainable(trainable_cls):
        """Add default resources to the Trainable."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # Create a new config by merging the dicts.
            self._merged_config = merge_dicts(base_config, self.config)

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # We ignore the config passed by Tune and instead use the merged
            # config which includes the initial Trainer args.
            super()._trainable_func(self._merged_config, reporter,
                                    checkpoint_dir)

        @classmethod
        def default_resource_request(cls, config):
            # Per-trial "scaling_config" overrides the Trainer-level one.
            updated_scaling_config = config.get("scaling_config",
                                                scaling_config)
            scaling_config_dataclass = ScalingConfigDataClass(
                **updated_scaling_config)
            return scaling_config_dataclass.as_placement_group_factory()

    return TrainTrainable
def testFunctionNoCheckpointing(self):
    """Restoring a non-checkpointing function yields a fresh iteration count."""

    def train(config, checkpoint_dir=None):
        for i in range(10):
            tune.report(test=i)

    wrapped = wrap_function(train)

    new_trainable = wrapped()
    result = new_trainable.train()
    checkpoint = new_trainable.save()
    new_trainable.stop()

    new_trainable2 = wrapped()
    new_trainable2.restore(checkpoint)
    result = new_trainable2.train()
    # `assertEquals` is a deprecated alias; use `assertEqual`.
    self.assertEqual(result[TRAINING_ITERATION], 1)
    checkpoint = new_trainable2.save()
    new_trainable2.stop()
def testFunctionNoCheckpointing(self):
    """Save/restore of a function that never writes its own checkpoints."""

    def train_fn(config, checkpoint_dir=None):
        if checkpoint_dir:
            assert os.path.exists(checkpoint_dir)
        for step in range(10):
            tune.report(test=step)

    trainable_cls = wrap_function(train_fn)

    first = trainable_cls(logger_creator=self.logger_creator)
    result = first.train()
    checkpoint = first.save()
    first.stop()

    second = trainable_cls(logger_creator=self.logger_creator)
    second.restore(checkpoint)
    result = second.train()
    self.assertEqual(result[TRAINING_ITERATION], 1)
    checkpoint = second.save()
    second.stop()
def testMlFlowMixinConfig(self):
    """Validate how the ``mlflow_mixin`` decorator handles trial configs."""
    clear_env_vars()
    trial_config = {"par1": 4, "par2": 9.}

    @mlflow_mixin
    def train_fn(config):
        return 1

    train_fn.__mixins__ = (MLFlowTrainableMixin, )

    # No MLFlow config passed in.
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(trial_config)

    trial_config.update({"mlflow": {}})
    # No tracking uri or experiment_id/name passed in.
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(trial_config)

    # Invalid experiment-id
    trial_config["mlflow"].update({"experiment_id": "500"})
    # Still no tracking uri, so instantiation must still raise.
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(trial_config)

    trial_config["mlflow"].update({
        "tracking_uri": "test_tracking_uri",
        "experiment_name": "existing_experiment"
    })
    wrapped = wrap_function(train_fn)(trial_config)
    client = wrapped._mlflow
    self.assertEqual(client.tracking_uri, "test_tracking_uri")
    self.assertTupleEqual(client.active_run.run_id, (0, 0))

    # Reuse the same (mock) client for subsequent instantiations.
    with patch("ray.tune.integration.mlflow._import_mlflow",
               lambda: client):
        train_fn.__mixins__ = (MLFlowTrainableMixin, )
        wrapped = wrap_function(train_fn)(trial_config)
        client = wrapped._mlflow
        self.assertTupleEqual(client.active_run.run_id, (0, 1))

        # Set to experiment that does not already exist.
        # New experiment should be created.
        trial_config["mlflow"]["experiment_name"] = "new_experiment"
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)
def testWandbDecoratorConfig(self):
    """Check API-key resolution and config forwarding of ``wandb_mixin``."""
    config = {"par1": 4, "par2": 9.12345678}
    trial = Trial(config, 0, "trial_0", "trainable",
                  PlacementGroupFactory([{
                      "CPU": 1
                  }]))
    trial_info = TrialInfo(trial)

    @wandb_mixin
    def train_fn(config):
        return 1

    train_fn.__mixins__ = (_MockWandbTrainableMixin, )

    config[TRIAL_INFO] = trial_info

    # Start from a clean environment so key resolution order is tested.
    if WANDB_ENV_VAR in os.environ:
        del os.environ[WANDB_ENV_VAR]

    # Needs at least a project
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(config)

    # No API key
    config["wandb"] = {"project": "test_project"}
    with self.assertRaises(ValueError):
        wrapped = wrap_function(train_fn)(config)

    # API Key in config
    config["wandb"] = {"project": "test_project", "api_key": "1234"}
    wrapped = wrap_function(train_fn)(config)
    self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")

    del os.environ[WANDB_ENV_VAR]

    # API Key file
    with tempfile.NamedTemporaryFile("wt") as fp:
        fp.write("5678")
        fp.flush()

        config["wandb"] = {
            "project": "test_project",
            "api_key_file": fp.name
        }
        wrapped = wrap_function(train_fn)(config)
        self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")

    del os.environ[WANDB_ENV_VAR]

    # API Key in env
    os.environ[WANDB_ENV_VAR] = "9012"
    config["wandb"] = {"project": "test_project"}
    wrapped = wrap_function(train_fn)(config)

    # From now on, the API key is in the env variable.

    # Default configuration
    config["wandb"] = {"project": "test_project"}
    config[TRIAL_INFO] = trial_info

    wrapped = wrap_function(train_fn)(config)
    self.assertEqual(wrapped.wandb.kwargs["project"], "test_project")
    self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id)
    self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name)
def durable(trainable: Union[str, Type[Trainable], Callable]):
    """Convert trainable into a durable trainable.

    Durable trainables are used to upload trial results and checkpoints
    to cloud storage, like e.g. AWS S3.

    This function can be used to convert your trainable, i.e. your trainable
    classes, functions, or string identifiers, to a durable trainable.

    To make durable trainables work, you should pass a valid
    :class:`SyncConfig <ray.tune.SyncConfig>` object to `tune.run()`.

    Example:

    .. code-block:: python

        from ray import tune

        analysis = tune.run(
            tune.durable("PPO"),
            config={"env": "CartPole-v0"},
            checkpoint_freq=1,
            sync_config=tune.SyncConfig(
                sync_to_driver=False,
                upload_dir="s3://your-s3-bucket/durable-ppo/",
            ))

    You can also convert your trainable functions:

    .. code-block:: python

        tune.run(
            tune.durable(your_training_fn),
            # ...
        )

    And your class functions:

    .. code-block:: python

        tune.run(
            tune.durable(YourTrainableClass),
            # ...
        )


    Args:
        trainable (str|Type[Trainable]|Callable): Trainable. Can be a
            string identifier, a trainable class, or a trainable function.

    Returns:
        A durable trainable class wrapped around your trainable.

    """
    if isinstance(trainable, str):
        trainable_cls = get_trainable_cls(trainable)
    else:
        trainable_cls = trainable

    if not inspect.isclass(trainable_cls):
        # Function API
        return wrap_function(trainable_cls, durable=True)

    if not issubclass(trainable_cls, Trainable):
        # Report the offending class itself, not its metaclass:
        # `type(trainable_cls)` would almost always print
        # `<class 'type'>`, which is useless for debugging.
        raise ValueError(
            "You can only use `durable()` with valid trainables. The class "
            "you passed does not inherit from `Trainable`. Please make sure "
            f"it does. Got: {trainable_cls}")

    # else: Class API
    class _WrappedDurableTrainable(DurableTrainable, trainable_cls):
        _name = trainable_cls.__name__ if hasattr(trainable_cls, "__name__") \
            else "durable_trainable"

    return _WrappedDurableTrainable