import os
from pprint import pprint

import ray
from ray import tune
from ray.tune import Trainable

# Project-local helpers referenced below (`trainables`, `interfaces`,
# `get_last_checkpoint`, `register_torch_serializers`, `ImagenetTrainable`,
# `SigOptImagenetTrainable`) are assumed to be imported from elsewhere in
# this package.


def get_tune_kwargs(config):
    """
    Build and return the kwargs needed to run `tune.run` for a given config.

    :param config:
        - ray_trainable: the ray.tune.Trainable; defaults to
          RemoteProcessTrainable or DistributedTrainable, depending on whether
          the experiment class is distributed
        - stop_condition: if the trainable has this attribute, it will be used
          to decide which config parameter dictates the stop
          training_iteration
        - sigopt_config: (optional) used for running experiments with SigOpt
          and the SigOpt trainables
        - restore: whether to restore from the latest checkpoint; defaults to
          False
        - local_dir: needed with `restore`; identifies the parent directory
          of experiment results
        - name: needed with `restore`; local_dir/name identifies the path to
          the experiment checkpoints
    """
    # Build kwargs for the `tune.run` function using merged config and
    # command line dict.
    kwargs_names = tune.run.__code__.co_varnames[
        :tune.run.__code__.co_argcount]

    # Select the default trainable: distributed vs. remote-process, with
    # SigOpt variants when a SigOpt config is present.
    distributed = issubclass(config.get("experiment_class"),
                             interfaces.DistributedAggregation)
    if "sigopt_config" in config:
        default_trainable = (trainables.SigOptDistributedTrainable
                             if distributed
                             else trainables.SigOptRemoteProcessTrainable)
    else:
        default_trainable = (trainables.DistributedTrainable
                             if distributed
                             else trainables.RemoteProcessTrainable)
    ray_trainable = config.get("ray_trainable", default_trainable)
    assert issubclass(ray_trainable, Trainable)

    # Zip the kwargs names with the Ray trainable and the remaining defaults.
    kwargs = dict(zip(kwargs_names, [ray_trainable, *tune.run.__defaults__]))

    # Check if restoring the experiment from the last known checkpoint.
    if config.pop("restore", False):
        result_dir = os.path.join(config["local_dir"], config["name"])
        config["restore_checkpoint_file"] = get_last_checkpoint(result_dir)

    # Update `tune.run` kwargs with the config.
    kwargs.update(config)
    kwargs["config"] = config

    # Make sure to only select `tune.run` function arguments.
    kwargs = dict(filter(lambda x: x[0] in kwargs_names, kwargs.items()))

    # Normalize the stop condition to a dict.
    kwargs["stop"] = kwargs.get("stop", {}) or dict()

    return kwargs
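# A minimal usage sketch for `get_tune_kwargs`, assuming the project-local
# trainables above are importable. `MyExperiment`, the experiment name, and
# the result directory are hypothetical placeholders, not part of this
# module's API.
def _example_get_tune_kwargs():
    class MyExperiment:  # hypothetical, non-distributed experiment class
        pass

    config = dict(
        experiment_class=MyExperiment,
        name="my_experiment",  # hypothetical name
        local_dir=os.path.expanduser("~/ray_results"),
        stop=dict(training_iteration=10),
        num_samples=1,
    )
    kwargs = get_tune_kwargs(config)
    # Keys that `tune.run` does not accept (e.g. experiment_class) were
    # filtered out above, so this call cannot fail with an
    # unexpected-keyword TypeError; they remain available in kwargs["config"].
    return tune.run(**kwargs)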
def run(config):
    # Connect to Ray.
    address = os.environ.get("REDIS_ADDRESS", config.get("redis_address"))
    ray.init(address=address, local_mode=config.get("local_mode", False))

    # Register serializer and deserializer - needed when logging arrays and
    # tensors.
    register_torch_serializers()

    # Build kwargs for the `tune.run` function using merged config and
    # command line dict.
    kwargs_names = tune.run.__code__.co_varnames[
        :tune.run.__code__.co_argcount]
    if "sigopt_config" in config:
        kwargs = dict(zip(kwargs_names,
                          [SigOptImagenetTrainable, *tune.run.__defaults__]))
    else:
        imagenet_trainable = config.get("imagenet_trainable",
                                        ImagenetTrainable)
        assert issubclass(imagenet_trainable, ImagenetTrainable)
        kwargs = dict(zip(kwargs_names,
                          [imagenet_trainable, *tune.run.__defaults__]))

    # Check if restoring the experiment from the last known checkpoint.
    if config.pop("restore", False):
        result_dir = os.path.join(config["local_dir"], config["name"])
        config["restore_checkpoint_file"] = get_last_checkpoint(result_dir)

    # Update `tune.run` kwargs with the config.
    kwargs.update(config)
    kwargs["config"] = config

    # Make sure to only select `tune.run` function arguments.
    kwargs = dict(filter(lambda x: x[0] in kwargs_names, kwargs.items()))

    # Queue trials until the cluster scales up.
    kwargs.update(queue_trials=True)

    pprint(kwargs)
    result = tune.run(**kwargs)
    ray.shutdown()
    return result
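# A minimal sketch of driving `run` from the command line, assuming this
# module's project imports resolve. The experiment name, result directory,
# and stop criterion are hypothetical placeholder values.
if __name__ == "__main__":
    example_config = dict(
        name="imagenet_example",  # hypothetical name
        local_dir=os.path.expanduser("~/ray_results"),
        redis_address=None,  # None: start a new local Ray instance
        local_mode=False,
        stop=dict(training_iteration=90),
        restore=False,  # set True to resume from local_dir/name
    )
    run(example_config)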