def _step_fn(ctx, inputs):
        """A step fn that returns update ops."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        # When the input feature is a dictionary of tensors, the dictionary is
        # flattened to an array and passed as a model input. This results in an
        # input mismatch when the model input layer names are not sorted in
        # alphabetical order, since `nest.flatten()` sorts dictionary elements
        # by key. Hence, transform the input tensors into an array and order it
        # along `model._feed_input_names`.
        if isinstance(inputs, dict):
            inputs = [
                inputs[input_name] for input_name in model._feed_input_names
            ]
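        # Illustration (hypothetical tensors, not part of the original flow):
        # `tf.nest.flatten({"b": tb, "a": ta})` yields `[ta, tb]` because dicts
        # are flattened in sorted-key order, so a model whose
        # `_feed_input_names` is `["b", "a"]` would receive its inputs swapped
        # without the explicit reordering above.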

        _build_model(strategy, model, mode, inputs, targets)

        (
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        ) = strategy.extended.call_for_each_replica(
            _per_replica_execution_function,
            args=(dist_utils.get_distributed_model(model, mode), mode),
        )
        (
            all_inputs,
            all_outputs,
            all_updates,
            all_session_args,
        ) = dist_utils.unwrap_values(
            strategy,
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        )
        combined_fn = backend.function(all_inputs,
                                       all_outputs,
                                       updates=all_updates,
                                       name="distributed_" + str(mode) +
                                       "_function",
                                       **all_session_args)

        for label, output in zip(output_labels, combined_fn.outputs):
            if label == "loss":
                reduce_op = tf.distribute.ReduceOp.SUM
            else:
                # We reduce all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                reduce_op = tf.distribute.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These
        # should be handled appropriately.
        return combined_fn.updates_op
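
For context, the loss/metric reduction policy above (SUM for the loss, MEAN for other
metrics) can be sketched with the modern public tf.distribute API. This is a minimal
sketch only; the `model`, `optimizer`, and `compute_loss` names are assumptions and do
not appear in the original snippet.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        features, targets = inputs
        with tf.GradientTape() as tape:
            outputs = model(features, training=True)   # assumed model
            loss = compute_loss(targets, outputs)       # assumed loss fn
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    per_replica_loss = strategy.run(step_fn, args=(dist_inputs,))
    # The loss is combined with SUM, mirroring the reduce-op choice in
    # `_step_fn`; other metrics would use tf.distribute.ReduceOp.MEAN.
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)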
Example #2
  def _predict_step_fn(inputs):
    """A fn that returns output of single prediction step."""

    (tf.distribute.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with tf.control_dependencies([updates]):
      return [tf.identity(out) for out in outputs]
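
A hedged sketch of how such a per-replica predict step could be driven end to end with
the public tf.distribute API; `model` and `dataset` below are assumptions, not part of
the snippet above.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def distributed_predict_step(dist_inputs):
    # Run the forward pass on every replica.
    return strategy.run(lambda x: model(x, training=False), args=(dist_inputs,))

dist_dataset = strategy.experimental_distribute_dataset(dataset)  # assumed dataset
for batch in dist_dataset:
    per_replica_outputs = distributed_predict_step(batch)
    # Unwrap PerReplica values into a list of local tensors, analogous to
    # dist_utils.unwrap_values in the examples above.
    local_outputs = strategy.experimental_local_results(per_replica_outputs)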
Example #3
  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (tf.distribute.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with tf.control_dependencies([updates]):
      return [tf.identity(out) for out in outputs]
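
The `merge_call` used in the test and predict steps switches from per-replica code into
cross-replica context so that one-time work (building the distributed model) runs once
rather than once per replica. A minimal, self-contained sketch of that mechanism
follows; the function names and values are illustrative only.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def _cross_replica_setup(strategy_ctx, tag):
    # Runs once in cross-replica context; merge_call passes the
    # tf.distribute.Strategy object as the first argument.
    tf.print("cross-replica setup for:", tag)

@tf.function
def replica_step(x):
    # Per-replica code hands off to cross-replica context, mirroring the
    # `_build_model` merge_call in the examples above.
    tf.distribute.get_replica_context().merge_call(
        _cross_replica_setup, args=("test",))
    return x * 2.0

result = strategy.run(replica_step, args=(tf.constant(1.0),))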