def _build_network_on_replica(model, mode, inputs=None, targets=None):
  """Build an updated model on replicas.

  We create a new Keras model while sharing the variables from the old graph.
  Building a new sub-graph is required since the original keras model creates
  placeholders for the input and the output that are not accessible until we
  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.

  The sharing of weights and layers between the old and the new model
  guarantees that we're using Strategy variables and any updates on either
  model are reflected correctly in callbacks and loop iterations.

  We need to make sure we share the optimizers between the old and the new
  model as well so that optimizer state is not lost if the user is running
  fit multiple times.

  Args:
    model: Model to be replicated across replicas.
    mode: Which of fit/eval/predict is building the distributed network.
    inputs: Input variables to be passed to the model.
    targets: Target tensor to be passed to model.compile.

  Returns:
    A new model with shared layers with the old model.
  """
  # Need to do imports here since we run into a circular dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # We rely on the internal methods to avoid having share_weights in the
  # public API.
  if isinstance(model, sequential.Sequential):
    updated_model = models._clone_sequential_model(
        model, input_tensors=inputs, share_weights=True)
  else:
    updated_model = models._clone_functional_model(
        model, input_tensors=inputs, share_weights=True)

  # Recast all low precision outputs back to float32 since we only cast the
  # inputs to bfloat16 and not targets. This is done so that we can preserve
  # precision when calculating the loss value.
  def _upcast_low_precision_outputs(output):
    if output.dtype == dtypes.bfloat16:
      return math_ops.cast(output, dtypes.float32)
    else:
      return output
  updated_model.outputs = [_upcast_low_precision_outputs(o)
                           for o in updated_model.outputs]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)

  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(updated_model)
  else:
    updated_model.compile(
        model.optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=targets)
  return updated_model
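
# A minimal sketch of how this helper might be invoked from inside a
# distribution-strategy step function, for illustration only: the real call
# sites are the distributed `fit`/`evaluate`/`predict` loops in this package,
# and `step_fn`, `ctx`, `iterator_inputs`, and `iterator_targets` are
# hypothetical names.
#
#   def step_fn(ctx, inputs):
#     iterator_inputs, iterator_targets = inputs  # from iterator.get_next()
#     updated_model = _build_network_on_replica(
#         model, ModeKeys.TRAIN,
#         inputs=iterator_inputs, targets=iterator_targets)
#     ...
#
# Because the clone shares weights, layers, and the optimizer with `model`,
# any gradient updates applied to the per-replica clone are reflected on the
# original model across callbacks and loop iterations.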