Ejemplo n.º 1
0
def wrap_function(train_func, warn=True):
    """Build a ``FunctionRunner`` subclass that drives ``train_func``.

    The generated class calls ``train_func`` with ``config``. Depending on the
    detected signature, it also passes either ``checkpoint_dir=...`` or the
    reporter. Any classes in an optional ``train_func.__mixins__`` tuple are
    placed ahead of ``FunctionRunner`` in the generated class's bases.

    Args:
        train_func: User training function to wrap.
        warn: When False, the "checkpointing is disabled" warning is not
            emitted. The ``log_once`` key is still consumed.

    Returns:
        The generated ``ImplicitFunc`` class (not an instance).

    Raises:
        ValueError: If none of ``detect_checkpoint_function``,
            ``detect_config_single`` or ``detect_reporter`` accepts the
            signature of ``train_func``.
    """
    bases = (FunctionRunner, )
    if hasattr(train_func, "__mixins__"):
        # Mixins come first so they take precedence in the MRO.
        bases = train_func.__mixins__ + bases

    arg_names = inspect.getfullargspec(train_func).args
    takes_checkpoint = detect_checkpoint_function(train_func)
    takes_config_only = detect_config_single(train_func)
    takes_reporter = detect_reporter(train_func)

    if not (takes_checkpoint or takes_config_only or takes_reporter):
        # use_reporter is hidden: the reporter-style signature is not
        # mentioned in this message.
        message = ("Unknown argument found in the Trainable function. "
                   "The function args must include a 'config' positional "
                   "parameter. Any other args must be 'checkpoint_dir'. "
                   "Found: {}".format(arg_names))
        raise ValueError(message)

    # log_once() is evaluated before `warn`, so the once-only key is consumed
    # even when the warning itself is suppressed.
    if (takes_config_only and not takes_checkpoint
            and log_once("tune_function_checkpoint") and warn):
        logger.warning(
            "Function checkpointing is disabled. This may result in "
            "unexpected behavior when using checkpointing features or "
            "certain schedulers. To enable, set the train function "
            "arguments to be `func(config, checkpoint_dir=None)`.")

    class ImplicitFunc(*bases):
        _name = getattr(train_func, "__name__", "func")

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # A checkpoint_dir-style signature takes precedence over a
            # reporter-style one; otherwise only config is passed.
            if takes_checkpoint:
                result = train_func(config, checkpoint_dir=checkpoint_dir)
            elif takes_reporter:
                result = train_func(config, reporter)
            else:
                result = train_func(config)

            # Once train_func has returned, notify the main event loop of the
            # last result without double logging. The RESULT_DUPLICATE
            # keyword signals this (see tune/trial_runner.py).
            reporter(**{RESULT_DUPLICATE: True})
            return result

    return ImplicitFunc
Ejemplo n.º 2
0
def wrap_function(
    train_func: Callable[[Any], Any], warn: bool = True, name: Optional[str] = None
):
    """Wrap a plain training function into a ``FunctionTrainable`` subclass.

    Args:
        train_func: The user's training function. Depending on which of
            ``detect_checkpoint_function`` / ``detect_config_single`` /
            ``detect_reporter`` matches it, it is called with ``config``
            alone, with ``config`` and ``checkpoint_dir=...``, or with
            ``config`` and the reporter. If it defines a ``__mixins__``
            tuple, those classes are placed before ``FunctionTrainable`` in
            the generated class's bases. It may return or yield a number
            (reported as ``_metric``) or a dict (reported as keyword
            arguments).
        warn: If False, suppress the "checkpointing is disabled" warning.
            The ``log_once`` key is still consumed.
        name: Value for the generated class's ``_name``, which is also its
            ``repr``. Defaults to ``train_func.__name__``, or ``"func"`` when
            the callable has no ``__name__``.

    Returns:
        The generated ``ImplicitFunc`` class (not an instance).

    Raises:
        ValueError: If none of the ``detect_*`` helpers recognises the
            signature of ``train_func``.
    """
    inherit_from = (FunctionTrainable,)

    if hasattr(train_func, "__mixins__"):
        # Mixins come first so they take precedence in the MRO.
        inherit_from = train_func.__mixins__ + inherit_from

    func_args = inspect.getfullargspec(train_func).args
    use_checkpoint = detect_checkpoint_function(train_func)
    use_config_single = detect_config_single(train_func)
    use_reporter = detect_reporter(train_func)

    if not any([use_checkpoint, use_config_single, use_reporter]):
        # use_reporter is hidden: the reporter-style signature is not
        # mentioned in this message.
        raise ValueError(
            "Unknown argument found in the Trainable function. "
            "The function args must include a 'config' positional "
            "parameter. Any other args must be 'checkpoint_dir'. "
            "Found: {}".format(func_args)
        )

    if use_config_single and not use_checkpoint:
        # log_once() is evaluated before `warn`, so the once-only key is
        # consumed even when the warning is suppressed.
        if log_once("tune_function_checkpoint") and warn:
            logger.warning(
                "Function checkpointing is disabled. This may result in "
                "unexpected behavior when using checkpointing features or "
                "certain schedulers. To enable, set the train function "
                "arguments to be `func(config, checkpoint_dir=None)`."
            )

    class ImplicitFunc(*inherit_from):
        _name = name or (
            train_func.__name__ if hasattr(train_func, "__name__") else "func"
        )

        def __repr__(self):
            return self._name

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # Bind the arguments that match the detected signature. A
            # checkpoint_dir-style signature takes precedence over a
            # reporter-style one.
            if not use_checkpoint and not use_reporter:
                fn = partial(train_func, config)
            elif use_checkpoint:
                fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
            else:
                fn = partial(train_func, config, reporter)

            def handle_output(output):
                # Forward one returned/yielded value to the reporter.
                # Numbers are checked before the truthiness test, so a
                # legitimate metric of 0 / 0.0 is reported instead of being
                # silently dropped.
                if isinstance(output, Number):
                    reporter(_metric=output)
                elif not output:
                    # None, an empty dict, or another empty value: there is
                    # nothing to report.
                    return
                elif isinstance(output, dict):
                    reporter(**output)
                else:
                    raise ValueError(
                        "Invalid return or yield value. Either return/yield "
                        "a single number or a dictionary object in your "
                        "trainable function."
                    )

            # For generator functions, every yielded value is reported, and
            # the last one becomes the return value of this method.
            output = None
            if inspect.isgeneratorfunction(train_func):
                for output in fn():
                    handle_output(output)
            else:
                output = fn()
                handle_output(output)

            # If train_func returns, we need to notify the main event loop
            # of the last result while avoiding double logging. This is done
            # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
            reporter(**{RESULT_DUPLICATE: True})
            return output

    return ImplicitFunc
Ejemplo n.º 3
0
def wrap_function(train_func: Callable[[Any], Any],
                  warn: bool = True,
                  name: Optional[str] = None) -> Type["FunctionTrainable"]:
    """Wrap a plain training function into a ``FunctionTrainable`` subclass.

    Args:
        train_func: The user's training function. Depending on which of
            ``detect_checkpoint_function`` / ``detect_config_single`` /
            ``detect_reporter`` matches it, it is called with ``config``
            alone, with ``config`` and ``checkpoint_dir=...``, or with
            ``config`` and the reporter. If it defines a ``__mixins__``
            tuple, those classes are placed before ``FunctionTrainable`` in
            the generated class's bases. It may return or yield a number
            (reported as ``_metric``) or a dict (reported as keyword
            arguments).
        warn: If False, suppress both the "checkpointing is disabled"
            warning and the ``checkpoint_dir`` deprecation warning. Their
            ``log_once`` keys are still consumed.
        name: Value for the generated class's ``_name``, which is also its
            ``repr``. Defaults to ``train_func.__name__``, or ``"func"`` when
            the callable has no ``__name__``.

    Returns:
        The generated ``ImplicitFunc`` class (not an instance).

    Raises:
        ValueError: If none of the ``detect_*`` helpers recognises the
            signature of ``train_func``.
    """
    inherit_from = (FunctionTrainable, )

    if hasattr(train_func, "__mixins__"):
        # Mixins come first so they take precedence in the MRO.
        inherit_from = train_func.__mixins__ + inherit_from

    func_args = inspect.getfullargspec(train_func).args
    use_checkpoint = detect_checkpoint_function(train_func)
    use_config_single = detect_config_single(train_func)
    use_reporter = detect_reporter(train_func)

    if not any([use_checkpoint, use_config_single, use_reporter]):
        # use_reporter is hidden: the reporter-style signature is not
        # mentioned in this message.
        raise ValueError(
            "Unknown argument found in the Trainable function. "
            "The function args must include a 'config' positional "
            "parameter. Any other args must be 'checkpoint_dir'. "
            "Found: {}".format(func_args))

    if use_config_single and not use_checkpoint:
        # log_once() is evaluated before `warn`, so the once-only key is
        # consumed even when the warning is suppressed.
        # NOTE(review): this message still recommends `checkpoint_dir`,
        # which the deprecation warning below discourages. Consider
        # pointing users at the session API instead.
        if log_once("tune_function_checkpoint") and warn:
            logger.warning(
                "Function checkpointing is disabled. This may result in "
                "unexpected behavior when using checkpointing features or "
                "certain schedulers. To enable, set the train function "
                "arguments to be `func(config, checkpoint_dir=None)`.")

    if use_checkpoint:
        if log_once("tune_checkpoint_dir_deprecation") and warn:
            # Force this DeprecationWarning to be shown even if the active
            # warning filters would otherwise ignore it.
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warning_msg = (
                    "`checkpoint_dir` in `func(config, checkpoint_dir)` is "
                    "being deprecated. "
                    "To save and load checkpoint in trainable functions, "
                    "please use the `ray.air.session` API:\n\n"
                    "from ray.air import session\n\n"
                    "def train(config):\n"
                    "    # ...\n"
                    '    session.report({"metric": metric}, checkpoint=checkpoint)\n\n'
                    "For more information please see "
                    "https://docs.ray.io/en/master/ray-air/key-concepts.html#session\n"
                )
                warnings.warn(
                    warning_msg,
                    DeprecationWarning,
                )

    class ImplicitFunc(*inherit_from):
        _name = name or (train_func.__name__ if hasattr(
            train_func, "__name__") else "func")

        def __repr__(self):
            return self._name

        def _trainable_func(self, config, reporter, checkpoint_dir):
            # Bind the arguments that match the detected signature. A
            # checkpoint_dir-style signature takes precedence over a
            # reporter-style one.
            if not use_checkpoint and not use_reporter:
                fn = partial(train_func, config)
            elif use_checkpoint:
                fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
            else:
                fn = partial(train_func, config, reporter)

            def handle_output(output):
                # Forward one returned/yielded value to the reporter.
                # Numbers are checked before the truthiness test, so a
                # legitimate metric of 0 / 0.0 is reported instead of being
                # silently dropped.
                if isinstance(output, Number):
                    reporter(_metric=output)
                elif not output:
                    # None, an empty dict, or another empty value: there is
                    # nothing to report.
                    return
                elif isinstance(output, dict):
                    reporter(**output)
                else:
                    raise ValueError(
                        "Invalid return or yield value. Either return/yield "
                        "a single number or a dictionary object in your "
                        "trainable function.")

            # For generator functions, every yielded value is reported, and
            # the last one becomes the return value of this method.
            output = None
            if inspect.isgeneratorfunction(train_func):
                for output in fn():
                    handle_output(output)
            else:
                output = fn()
                handle_output(output)

            # If train_func returns, we need to notify the main event loop
            # of the last result while avoiding double logging. This is done
            # with the keyword RESULT_DUPLICATE -- see tune/trial_runner.py.
            reporter(**{RESULT_DUPLICATE: True})
            return output

    return ImplicitFunc