def _clone_keras_tensor(kt):
    """Create an identical keras_tensor based on the input.

    We use keras_tensor_to_placeholder and keras_tensor_from_tensor to make
    sure inferred shapes are not lost during the copy.

    Args:
        kt: the input KerasTensor.

    Returns:
        An identical copy of the input KerasTensor.
    """
    # Create a scratch graph since we don't intend to use the placeholders.
    with backend._scratch_graph() as scratch_graph:
        with scratch_graph.as_default():
            placeholder = keras_tensor.keras_tensor_to_placeholder(kt)
            return keras_tensor.keras_tensor_from_tensor(placeholder)
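# Illustration (not part of the library): a minimal usage sketch of
# `_clone_keras_tensor`, assuming a TF 2.x functional-API `tf.keras.Input`
# yields a KerasTensor as `keras_tensor` above expects. The helper name
# `_example_clone_keras_tensor` is hypothetical and exists only for this
# sketch.
def _example_clone_keras_tensor():  # pragma: no cover - usage sketch only
    import tensorflow as tf

    kt = tf.keras.Input(shape=(None, 4))  # time dimension left unknown
    clone = _clone_keras_tensor(kt)
    # The copy keeps the inferred (partially unknown) shape but is a
    # distinct KerasTensor object.
    assert clone.shape.as_list() == kt.shape.as_list()
    assert clone is not kt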
def _make_eager_execution_function(model, mode):
    """Makes function to run one step of distributed model eager execution."""

    def _per_replica_function(model):
        f = model._make_execution_function(mode)
        return (f.inputs, f.outputs)

    # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of
    # using the global one.
    strategy = model._distribution_strategy
    global_graph = backend.get_graph()

    with global_graph.as_default(), strategy.scope():
        # First we gather the relevant portions of the model across all
        # replicas. `backend._scratch_graph(global_graph)` signals to Keras
        # that it should not lift to a separate graph when creating the
        # per-replica functions.
        with backend._scratch_graph(global_graph):
            # Create train ops on each of the devices when we call
            # `_per_replica_fit_function`.
            grouped = strategy.extended.call_for_each_replica(
                _per_replica_function,
                args=(get_distributed_model(model, mode),),
            )
            grouped_inputs, grouped_outputs = grouped

            # Unwrap all the per-device values returned from
            # `call_for_each_replica`. Unwrapping per-device values gives
            # you a list of values that can be used to construct a new
            # train function that is composed of inputs/outputs on all the
            # devices over which the model is distributed.
            (all_inputs, all_outputs, _, _) = unwrap_values(
                strategy,
                grouped_inputs,
                grouped_outputs,
                with_loss_tensor=(mode != ModeKeys.PREDICT),
            )

        # Finally, a joint Keras function is created; this one will be
        # created in a separate FuncGraph.
        return backend.function(
            all_inputs,
            all_outputs,
            name="eager_distributed_{}_function".format(mode),
        )
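# Illustration (not part of the library): a minimal sketch of the
# per-replica-call-then-unwrap pattern that `_make_eager_execution_function`
# relies on, written against the public `tf.distribute` API
# (`Strategy.run` in place of the internal `call_for_each_replica`, and
# `experimental_local_results` standing in for `unwrap_values`). All names
# in this sketch are hypothetical.
def _example_per_replica_unwrap():  # pragma: no cover - usage sketch only
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    def per_replica_fn():
        # Each replica produces its own tensor, tagged with its replica id.
        ctx = tf.distribute.get_replica_context()
        return tf.cast(ctx.replica_id_in_sync_group, tf.float32)

    per_replica_values = strategy.run(per_replica_fn)
    # Unwrapping a PerReplica value yields a flat tuple of per-device
    # tensors, one per replica the computation ran on.
    return strategy.experimental_local_results(per_replica_values)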