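# The module-level imports below are a sketch of the dependencies the function
# relies on; they are assumed to live near the top of this file and are listed
# here only to keep the excerpt self-contained.
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

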
def _build_network_on_replica(model, mode, inputs=None, targets=None):
  """Build an updated model on replicas.

  We create a new Keras model while sharing the variables from the old graph.
  Building a new sub-graph is required since the original Keras model creates
  placeholders for the input and the output that are not accessible until we
  call `iterator.get_next()` inside the step_fn for `fit`/`evaluate`/`predict`.

  Sharing the weights and layers between the old and the new model guarantees
  that we're using Strategy variables and that any updates on either model are
  reflected correctly in callbacks and loop iterations.

  We also need to share the optimizer between the old and the new model so
  that optimizer state is not lost if the user runs `fit` multiple times.

  Args:
    model: Model to be replicated across replicas.
    mode: Which of fit/eval/predict is building the distributed network.
    inputs: Input variables to be passed to the model.
    targets: Target tensor to be passed to `model.compile`.

  Returns:
    A new model that shares layers and weights with the old model.
  """
  # Need to do imports here since we run into a circular dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # We rely on the internal clone methods to avoid exposing `share_weights` in
  # the public API.
  if isinstance(model, sequential.Sequential):
    updated_model = models._clone_sequential_model(model, input_tensors=inputs,
                                                   share_weights=True)
  else:
    updated_model = models._clone_functional_model(model, input_tensors=inputs,
                                                   share_weights=True)

  # Recast all low-precision outputs back to float32 since we only cast the
  # inputs to bfloat16 and not the targets. This is done so that we can
  # preserve precision when calculating the loss value.
  def _upcast_low_precision_outputs(output):
    if output.dtype == dtypes.bfloat16:
      return math_ops.cast(output, dtypes.float32)
    else:
      return output
  updated_model.outputs = [_upcast_low_precision_outputs(o)
                           for o in updated_model.outputs]

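  # `targets` can arrive as a (possibly nested) tuple from the dataset
  # structure; `compile` expects `target_tensors` as a flat list, so flatten
  # it here.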
  if isinstance(targets, tuple):
    targets = nest.flatten(targets)

  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(updated_model)
  else:
    updated_model.compile(
        model.optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=targets)
  return updated_model
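

# A minimal sketch of how the helper above is typically driven: the clone is
# built once per replica, inside the strategy's scope, via
# `strategy.extended.call_for_each_replica`, so the shared variables are
# created as Strategy (mirrored/per-replica) variables. The wrapper name
# `_build_distributed_network` is an assumption for illustration, not part of
# the function above.
def _build_distributed_network(model, strategy, mode, inputs=None,
                               targets=None):
  """Sketch: create per-replica clones of `model` that share its weights."""
  with strategy.scope():
    return strategy.extended.call_for_each_replica(
        _build_network_on_replica,
        args=(model, mode, inputs, targets))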