Пример #1
0
    def compile_model(
        context: keras.TFKerasContext,
        compile_args: inspect.BoundArguments,
        env: det.EnvContext,
        hvd_config: horovod.HorovodContext,
    ) -> None:
        """Compile ``context.model`` using the arguments bound from a ``compile`` call.

        Steps, in order:

        1. ``context.model`` is replaced by the result of
           ``keras._get_multi_gpu_model_if_using_native_parallel``.
        2. If an ``optimizer`` argument was bound, it is replaced by the result
           of ``context._process_optimizer_from_compile``.
        3. ``context.model.compile`` is invoked with the bound arguments. When
           Horovod is in use and TensorFlow is >= 2.0.0 and < 2.2.0,
           ``experimental_run_tf_function=False`` is forced.

        Args:
            context: Object holding the model to compile in ``context.model``.
            compile_args: Arguments bound against the ``compile`` signature.
                ``compile_args.arguments`` is mutated in place when an
                optimizer is present.
            env: Forwarded unchanged to the multi-GPU helper.
            hvd_config: Horovod settings; only ``hvd_config.use`` is read here
                (the object is also forwarded to the multi-GPU helper).
        """
        context.model = keras._get_multi_gpu_model_if_using_native_parallel(
            pre_compiled_model=context.model,
            env=env,
            hvd_config=hvd_config,
        )

        if "optimizer" in compile_args.arguments:
            # For backwards compatibility we check if an optimizer is passed as
            # part of the compile call. If `wrap_optimizer()` is used, we will
            # ignore this optimizer.
            compile_args.arguments["optimizer"] = context._process_optimizer_from_compile(
                compile_args.arguments["optimizer"]
            )

        force_legacy_gradients = hvd_config.use and (
            version.parse("2.0.0") <= version.parse(tf.__version__) < version.parse("2.2.0")
        )

        compile_kwargs = dict(compile_args.kwargs)
        if force_legacy_gradients:
            logging.info(
                "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
                "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
            )
            # Set the key in the dict instead of passing it as an additional
            # keyword: if the caller already supplied
            # `experimental_run_tf_function`, `**kwargs` plus an explicit
            # keyword would raise "got multiple values for keyword argument".
            compile_kwargs["experimental_run_tf_function"] = False

        context.model.compile(*compile_args.args, **compile_kwargs)
Пример #2
0
    def compile_model(
        context: keras.TFKerasContext,
        compile_args: inspect.BoundArguments,
        env: det.EnvContext,
        hvd_config: horovod.HorovodContext,
    ) -> None:
        """Compile ``context.model`` using the arguments bound from a ``compile`` call.

        If an ``optimizer`` argument was bound, it is first replaced by the
        result of ``context._process_optimizer_from_compile``. The model is then
        compiled with the bound arguments; when Horovod is in use and
        TensorFlow is >= 2.0.0 and < 2.2.0, ``experimental_run_tf_function=False``
        is forced.

        Args:
            context: Object holding the model to compile in ``context.model``;
                the model must not be ``None``.
            compile_args: Arguments bound against the ``compile`` signature.
                ``compile_args.arguments`` is mutated in place when an
                optimizer is present.
            env: Not used by this function.
            hvd_config: Horovod settings; only ``hvd_config.use`` is read.
        """
        if "optimizer" in compile_args.arguments:
            # For backwards compatibility we check if an optimizer is passed as
            # part of the compile call. If `wrap_optimizer()` is used, we will
            # ignore this optimizer.
            compile_args.arguments["optimizer"] = context._process_optimizer_from_compile(
                compile_args.arguments["optimizer"]
            )

        # context.model is Optional[Model]. This assert signals to mypy it can't
        # be none because we check that in `from_trial`.
        assert context.model is not None

        force_legacy_gradients = hvd_config.use and (
            version.parse("2.0.0") <= version.parse(tf.__version__) < version.parse("2.2.0")
        )

        compile_kwargs = dict(compile_args.kwargs)
        if force_legacy_gradients:
            logging.info(
                "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
                "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
            )
            # Set the key in the dict instead of passing it as an additional
            # keyword: if the caller already supplied
            # `experimental_run_tf_function`, `**kwargs` plus an explicit
            # keyword would raise "got multiple values for keyword argument".
            compile_kwargs["experimental_run_tf_function"] = False

        context.model.compile(*compile_args.args, **compile_kwargs)