def compile_model(
    context: keras.TFKerasContext,
    compile_args: inspect.BoundArguments,
    env: det.EnvContext,
    hvd_config: horovod.HorovodContext,
) -> None:
    """Prepare and compile the user's Keras model stored on ``context.model``.

    ``context.model`` is first replaced by the result of
    ``keras._get_multi_gpu_model_if_using_native_parallel`` and then compiled
    with the arguments captured from the user's ``model.compile()`` call.

    Args:
        context: Trial context; its ``model`` attribute is replaced and then
            compiled in place.
        compile_args: Arguments bound from the user's ``model.compile()`` call.
            If it contains an ``optimizer`` entry, that entry is rewritten in
            place by ``context._process_optimizer_from_compile``.
        env: Environment context, forwarded to the multi-GPU helper.
        hvd_config: Horovod configuration, forwarded to the multi-GPU helper;
            ``hvd_config.use`` also enables the TF 2.0/2.1 workaround below.
    """
    context.model = keras._get_multi_gpu_model_if_using_native_parallel(
        pre_compiled_model=context.model,
        env=env,
        hvd_config=hvd_config,
    )

    if "optimizer" in compile_args.arguments:
        # For backwards compatibility we check if an optimizer is passed as part
        # of the compile call. If `wrap_optimizer()` is used, we will ignore
        # this optimizer.
        compile_args.arguments["optimizer"] = context._process_optimizer_from_compile(
            compile_args.arguments["optimizer"]
        )

    if hvd_config.use and version.parse("2.0.0") <= version.parse(
        tf.__version__
    ) < version.parse("2.2.0"):
        logging.info(
            "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
            "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
        )
        # Merge the forced flag into a copy of the user's kwargs instead of
        # passing it as a separate keyword: if the user already supplied
        # `experimental_run_tf_function`, passing it twice would raise
        # "TypeError: got multiple values for keyword argument".
        compile_kwargs = dict(compile_args.kwargs)
        if compile_kwargs.get("experimental_run_tf_function", False) is not False:
            logging.warning(
                "Overriding the user-supplied `experimental_run_tf_function` with False; "
                "it is required when using Horovod with TensorFlow 2.0 or 2.1."
            )
        compile_kwargs["experimental_run_tf_function"] = False
        context.model.compile(*compile_args.args, **compile_kwargs)
    else:
        context.model.compile(*compile_args.args, **compile_args.kwargs)
def compile_model(
    context: keras.TFKerasContext,
    compile_args: inspect.BoundArguments,
    env: det.EnvContext,
    hvd_config: horovod.HorovodContext,
) -> None:
    """Compile the user's Keras model stored on ``context.model``.

    The model is compiled with the arguments captured from the user's
    ``model.compile()`` call, after the optimizer argument (if any) has been
    processed by the context.

    Args:
        context: Trial context whose ``model`` attribute is compiled in place.
        compile_args: Arguments bound from the user's ``model.compile()`` call.
            If it contains an ``optimizer`` entry, that entry is rewritten in
            place by ``context._process_optimizer_from_compile``.
        env: Environment context (not read by this function).
        hvd_config: Horovod configuration; ``hvd_config.use`` enables the
            TF 2.0/2.1 workaround below.
    """
    if "optimizer" in compile_args.arguments:
        # For backwards compatibility we check if an optimizer is passed as part
        # of the compile call. If `wrap_optimizer()` is used, we will ignore
        # this optimizer.
        compile_args.arguments["optimizer"] = context._process_optimizer_from_compile(
            compile_args.arguments["optimizer"]
        )

    # context.model is Optional[Model]. This assert signals to mypy it can't
    # be none because we check that in `from_trial`.
    assert context.model is not None

    if hvd_config.use and version.parse("2.0.0") <= version.parse(
        tf.__version__
    ) < version.parse("2.2.0"):
        logging.info(
            "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
            "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
        )
        # Merge the forced flag into a copy of the user's kwargs instead of
        # passing it as a separate keyword: if the user already supplied
        # `experimental_run_tf_function`, passing it twice would raise
        # "TypeError: got multiple values for keyword argument".
        compile_kwargs = dict(compile_args.kwargs)
        if compile_kwargs.get("experimental_run_tf_function", False) is not False:
            logging.warning(
                "Overriding the user-supplied `experimental_run_tf_function` with False; "
                "it is required when using Horovod with TensorFlow 2.0 or 2.1."
            )
        compile_kwargs["experimental_run_tf_function"] = False
        context.model.compile(*compile_args.args, **compile_kwargs)
    else:
        context.model.compile(*compile_args.args, **compile_args.kwargs)