# Example no. 1
def _make_eager_execution_function(model, mode):
    """Makes function to run one step of distributed model eager execution.

    Args:
        model: Model to distribute. Its `_distribution_strategy` attribute
            supplies the strategy used to build the per-replica functions.
        mode: Execution mode forwarded to `model._make_execution_function`.
            A loss tensor is unwrapped unless it equals `ModeKeys.PREDICT`.

    Returns:
        The object returned by `backend.function` for the inputs and outputs
        unwrapped across all replicas.
    """
    def _per_replica_function(model):
        f = model._make_execution_function(mode)
        return (f.inputs, f.outputs)

    # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
    # the global one.
    strategy = model._distribution_strategy
    global_graph = backend.get_graph()

    with global_graph.as_default(), strategy.scope():
        # First we gather the relevant portions of the model across all replicas.
        # `backend._scratch_graph(global_graph)` signals to Keras that it should not
        # lift to a separate graph when creating the per-replica functions.
        with backend._scratch_graph(global_graph):
            # Create the execution ops on each of the devices when we call
            # `_per_replica_function`. The function to run must be passed
            # explicitly; `args` only carries its arguments.
            grouped = strategy.extended.call_for_each_replica(
                _per_replica_function,
                args=(get_distributed_model(model, mode), ))
            grouped_inputs, grouped_outputs = grouped

            # Unwrap all the per device values returned from `call_for_each_replica`.
            # Unwrapping per device values gives you a list of values that can be
            # used to construct a new train function that is composed of
            # inputs/outputs on all the devices over which the model is distributed.
            (all_inputs, all_outputs, _, _) = unwrap_values(
                strategy,
                grouped_inputs,
                grouped_outputs,
                with_loss_tensor=(mode != ModeKeys.PREDICT))

        # Finally, a joint Keras function is created; this one will be created in
        # a separate FuncGraph.
        # NOTE(review): the original call was cut off after the opening
        # parenthesis. It is closed here over the unwrapped inputs/outputs;
        # confirm whether further arguments (e.g. a `name`) were intended.
        return backend.function(all_inputs, all_outputs)
def _make_eager_execution_function(model, mode):
  """Makes function to run one step of distributed model eager execution."""
  def _per_device_function(model):
    f = model._make_execution_function(mode)
    return (f.inputs, f.outputs)

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  strategy = model._distribution_strategy
  global_graph = K.get_graph()

  with global_graph.as_default(), strategy.scope():
    # First we gather the relevant portions of the model across all replicas.
    # `K._scratch_graph(global_graph)` signals to Keras that it should not
    # lift to a separate graph when creating the per-replica functions.
    with K._scratch_graph(global_graph):
      # Create train ops on each of the devices when we call
      # `_per_device_fit_function`.
      grouped = strategy.extended.call_for_each_replica(
          _per_device_function, args=(get_distributed_model(model, mode),))
      grouped_inputs, grouped_outputs = grouped

      # Unwrap all the per device values returned from `call_for_each_replica`.
      # Unwrapping per device values gives you a list of values that can be
      # used to construct a new train function that is composed of
      # inputs/outputs on all the devices over which the model is distributed.
      (all_inputs, all_outputs, _, _) = unwrap_values(
          with_loss_tensor=(mode != ModeKeys.PREDICT))

    # Finally, a joint Keras function is created; this one will be created in
    # a separate FuncGraph.
    return K.function(