def get_tune_kwargs(config):
    """
    Build and return the kwargs needed to run `tune.run` for a given config.

    :param config:
        - ray_trainable: the ray.tune.Trainable; defaults to
                         RemoteProcessTrainable or DistributedTrainable,
                         depending on whether the experiment class is distributed
            - stop_condition: If the trainable has this attribute, it will be used
                              to decide which config parameter dictates the stop
                              training_iteration.
        - sigopt_config: (optional) used for running experiments with SigOpt and
                         the SigOpt trainables
        - restore: whether to restore from the latest checkpoint; defaults to False
        - local_dir: needed with 'restore'; identifies the parent directory of
                     experiment results.
        - name: needed with 'restore'; local_dir/name identifies the path to
                the experiment checkpoints.

    :return: dict containing only arguments accepted by `tune.run`, with the
             trainable bound to the first parameter slot and "stop" guaranteed
             to be a dict.
    :raises TypeError: if the resolved trainable is not a subclass of
                       `tune.Trainable`.
    """

    # Introspect `tune.run` to learn the names of its accepted arguments.
    kwargs_names = tune.run.__code__.co_varnames[:tune.run.__code__.
                                                 co_argcount]

    # Pick the default trainable: SigOpt variants when a sigopt_config is
    # present, distributed variants when the experiment class aggregates
    # across processes.
    distributed = issubclass(config.get("experiment_class"),
                             interfaces.DistributedAggregation)
    if "sigopt_config" in config:
        default_trainable = (trainables.SigOptDistributedTrainable
                             if distributed else
                             trainables.SigOptRemoteProcessTrainable)
    else:
        default_trainable = (trainables.DistributedTrainable if distributed
                             else trainables.RemoteProcessTrainable)

    ray_trainable = config.get("ray_trainable", default_trainable)
    # Validate explicitly instead of with `assert`, which is stripped when
    # Python runs with the -O flag.
    if not (isinstance(ray_trainable, type)
            and issubclass(ray_trainable, Trainable)):
        raise TypeError(
            "ray_trainable must be a subclass of tune.Trainable")

    # Pair argument names with the trainable plus `tune.run`'s defaults; only
    # the first parameter (the trainable itself) has no default value.
    kwargs = dict(zip(kwargs_names, [ray_trainable, *tune.run.__defaults__]))

    # Check if restoring experiment from last known checkpoint
    if config.pop("restore", False):
        result_dir = os.path.join(config["local_dir"], config["name"])
        config["restore_checkpoint_file"] = get_last_checkpoint(result_dir)

    # Overlay the user config onto the defaults.
    kwargs.update(config)
    kwargs["config"] = config

    # Make sure to only select `tune.run` function arguments
    kwargs = {key: value for key, value in kwargs.items()
              if key in kwargs_names}

    # Normalize the stop condition so it is always a (possibly empty) dict.
    kwargs["stop"] = kwargs.get("stop") or dict()

    return kwargs
# Example 2
def run(config):
    """
    Connect to ray, build the `tune.run` kwargs from `config`, run the
    experiment, and shut ray down.

    :param config: experiment config dict. Recognized keys include:
        - redis_address: ray address, overridden by the REDIS_ADDRESS env var
        - local_mode: run ray in local mode; defaults to False
        - sigopt_config: if present, forces the SigOptImagenetTrainable
        - imagenet_trainable: the trainable class; defaults to
                              ImagenetTrainable (ignored with sigopt_config)
        - restore: whether to restore from the latest checkpoint
        - local_dir / name: identify the checkpoint directory when restoring
    :return: the result of `tune.run`.
    :raises TypeError: if imagenet_trainable is not a subclass of
                       ImagenetTrainable.
    """
    # Connect to ray
    address = os.environ.get("REDIS_ADDRESS", config.get("redis_address"))
    ray.init(address=address, local_mode=config.get("local_mode", False))

    # Register serializer and deserializer - needed when logging arrays and tensors.
    register_torch_serializers()

    # Introspect `tune.run` to learn the names of its accepted arguments.
    kwargs_names = tune.run.__code__.co_varnames[:tune.run.__code__.
                                                 co_argcount]

    # Resolve the trainable first; both branches previously duplicated the
    # zip with `tune.run.__defaults__`, so build the kwargs dict just once.
    if "sigopt_config" in config:
        imagenet_trainable = SigOptImagenetTrainable
    else:
        imagenet_trainable = config.get("imagenet_trainable",
                                        ImagenetTrainable)
        # Validate explicitly instead of with `assert`, which is stripped
        # when Python runs with the -O flag.
        if not (isinstance(imagenet_trainable, type)
                and issubclass(imagenet_trainable, ImagenetTrainable)):
            raise TypeError(
                "imagenet_trainable must be a subclass of ImagenetTrainable")

    # Only the first `tune.run` parameter (the trainable) has no default.
    kwargs = dict(
        zip(kwargs_names, [imagenet_trainable, *tune.run.__defaults__]))

    # Check if restoring experiment from last known checkpoint
    if config.pop("restore", False):
        result_dir = os.path.join(config["local_dir"], config["name"])
        config["restore_checkpoint_file"] = get_last_checkpoint(result_dir)

    # Overlay the user config onto the defaults.
    kwargs.update(config)
    kwargs["config"] = config

    # Make sure to only select `tune.run` function arguments
    kwargs = {key: value for key, value in kwargs.items()
              if key in kwargs_names}

    # Queue trials until the cluster scales up
    kwargs.update(queue_trials=True)

    pprint(kwargs)
    try:
        result = tune.run(**kwargs)
    finally:
        # Always release the ray connection, even if tune.run raises.
        ray.shutdown()
    return result