Ejemplo n.º 1
0
    def body_wrapper(*inputs):
        """Wrapper around `body` that handles infeed queues and control deps.

        Appends any infeed dequeue outputs to the loop-carried `inputs`, calls
        `body`, checks that it returned zero or more Tensors followed by zero or
        more Operations (with Tensor dtypes equal to `input_types`), and makes
        the returned Operations plus the dequeue ops control dependencies of the
        returned Tensors so every loop iteration is forced to run them.

        Args:
          *inputs: Loop-carried values for the current iteration. For arity-0
            loops this holds only a dummy value, which is discarded here.

        Returns:
          The Tensors to carry into the next iteration. A single dummy scalar
          (wrapped in a list) is used when `body` returned no Tensors. The value
          may be rewritten by Tensor Tracer when it is enabled.

        Raises:
          ValueError: If an infeed queue is used without a TPU shard context, or
            if `body` returns a Tensor after an Operation.
          TypeError: If the dtypes of the returned Tensors differ from
            `input_types`.
        """
        inputs = list(inputs)

        # Discards the dummy output added for arity-0 loops.
        if input_arity == 0:
            inputs = []

        # Runs `body` with the dequeue_ops appended.
        if infeed_queue:
            number_of_shards = tpu_function.get_tpu_context().number_of_shards
            if number_of_shards is None:
                raise ValueError(
                    "Can't build training loop with infeed when there is "
                    "no tpu_shard_context. Are you building a loop or "
                    "graph directly rather than from inside tpu.rewrite, "
                    "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
            infeed_queue.set_number_of_shards(number_of_shards)
            dequeue_ops = list(infeed_queue.generate_dequeue_op())
        else:
            dequeue_ops = []
        outputs = body(*(inputs + dequeue_ops))

        # If the computation only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]

        # Separates the returned Operations and Tensors, enforcing that all
        # Tensors come before all Operations. The ordering is checked with
        # isinstance rather than comparing lists with `!=`: list comparison
        # falls back to element `__eq__`, which Tensors may overload (e.g.
        # element-wise equality under TF2 behavior), so a misordered output
        # could otherwise surface as an unrelated error instead of this one.
        output_tensors = []
        output_operations = []
        for o in outputs:
            if isinstance(o, ops.Operation):
                output_operations.append(o)
            elif output_operations:
                raise ValueError(
                    "TPU training loop body must return zero or more Tensor values "
                    "followed by zero or more Operations.")
            else:
                output_tensors.append(o)

        output_types = [t.dtype for t in output_tensors]
        if input_types != output_types:
            raise TypeError(
                "Mismatch between input types and output types for training loop "
                "body: {} vs {}".format(input_types, output_types))

        # Add the dequeue operations to output_operations to ensure they are run
        # by the loop, even if the programmer's loop body does not use them.
        output_operations += dequeue_ops

        # Add a dummy output, if needed. It is a one-element list (not a bare
        # Tensor) because `control_flow_ops.tuple` below expects a list of
        # tensors and iterates over its argument.
        if not output_tensors:
            output_tensors = [array_ops.constant(0)]

        if output_operations:
            # TODO(phawkins): in principle this is too restrictive since it serializes
            # the training loop steps. In practice it does not matter since this loop
            # will be compiled by XLA.
            output_tensors = control_flow_ops.tuple(
                output_tensors, control_inputs=output_operations)

        if tensor_tracer.TensorTracer.is_enabled():
            num_replicas = tpu_function.get_tpu_context().number_of_shards
            if num_replicas is None:
                num_replicas = 1
            tt = tensor_tracer.TensorTracer()
            output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                          output_tensors, None, num_replicas)
        return output_tensors
Ejemplo n.º 2
0
    def __new__(cls,
                mode,
                predictions=None,
                loss=None,
                train_op=None,
                eval_metric_ops=None,
                export_outputs=None,
                training_chief_hooks=None,
                training_hooks=None,
                scaffold=None,
                evaluation_hooks=None,
                prediction_hooks=None):
        """Creates a validated `EstimatorSpec` instance.

        Which arguments are mandatory depends on `mode`:

        * `ModeKeys.TRAIN`: `loss` and `train_op` are required.
        * `ModeKeys.EVAL`: `loss` is required.
        * `ModeKeys.PREDICT`: `predictions` is required.

        A model_fn may populate every argument regardless of mode; an
        `Estimator` then ignores the ones that do not apply (for instance,
        `train_op` is ignored in eval and infer modes):

        ```python
        def my_model_fn(features, labels, mode):
          predictions = ...
          loss = ...
          train_op = ...
          return tf.estimator.EstimatorSpec(
              mode=mode,
              predictions=predictions,
              loss=loss,
              train_op=train_op)
        ```

        Alternatively, a model_fn may populate only the arguments relevant to
        the given mode:

        ```python
        def my_model_fn(features, labels, mode):
          if (mode == tf.estimator.ModeKeys.TRAIN or
              mode == tf.estimator.ModeKeys.EVAL):
            loss = ...
          else:
            loss = None
          if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = ...
          else:
            train_op = None
          if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = ...
          else:
            predictions = None

          return tf.estimator.EstimatorSpec(
              mode=mode,
              predictions=predictions,
              loss=loss,
              train_op=train_op)
        ```

        Args:
          mode: A `ModeKeys` value: training, evaluation or prediction.
          predictions: Predictions `Tensor` or dict of `Tensor`.
          loss: Training loss `Tensor`; must be a scalar or have shape `[1]`.
          train_op: Op for the training step.
          eval_metric_ops: Dict of metric results keyed by name. Each value is
            either (1) an instance of the `Metric` class, or (2) the result of
            calling a metric function, i.e. a `(metric_tensor, update_op)`
            tuple. `metric_tensor` should be evaluable without affecting state
            (typically a pure computation based on variables); for example, it
            should neither trigger `update_op` nor require any input fetching.
          export_outputs: Output signatures exported to `SavedModel` and used
            during serving, as a dict `{name: output}` where:
            * name: An arbitrary name for this output.
            * output: An `ExportOutput` object such as `ClassificationOutput`,
              `RegressionOutput`, or `PredictOutput`. Single-headed models need
              only one entry. Multi-headed models should have one entry per
              head, one of which must be named
              `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.
              If no entry is provided, a default `PredictOutput` mapping to
              `predictions` is created.
          training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects run
            on the chief worker during training.
          training_hooks: Iterable of `tf.train.SessionRunHook` objects run on
            all workers during training.
          scaffold: A `tf.train.Scaffold` object used to configure
            initialization, the saver, and more during training.
          evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects run
            during evaluation.
          prediction_hooks: Iterable of `tf.train.SessionRunHook` objects run
            during prediction.

        Returns:
          A validated `EstimatorSpec` object.

        Raises:
          ValueError: If validation fails.
          TypeError: If any of the arguments is not of the expected type.
        """
        # Validate (and normalize) the core arguments. The call order is kept
        # stable so that, for invalid input, the same error is reported first.
        train_op = _validate_estimator_spec_train_op(train_op, mode)
        loss = _validate_estimator_spec_loss(loss, mode)
        predictions = _validate_estimator_spec_predictions(predictions, mode)
        export_outputs = _validate_estimator_spec_export_outputs(
            export_outputs, predictions, mode)

        # Every hook collection goes through the same validator; the pairs are
        # listed in the order they have always been validated.
        validated_hooks = {
            field_name: _validate_estimator_spec_hooks(hook_value)
            for field_name, hook_value in (
                ('training_hooks', training_hooks),
                ('evaluation_hooks', evaluation_hooks),
                ('prediction_hooks', prediction_hooks),
                ('training_chief_hooks', training_chief_hooks),
            )
        }

        eval_metric_ops = _validate_eval_metric_ops(eval_metric_ops)
        scaffold = _validate_scaffold(scaffold)

        # Tensor Tracer is disabled by default, in which case this is a no-op.
        # When environment flags enable it, loss and train_op select the
        # execution path to trace, and loss is replaced by the tracer's result
        # (a `tf.identity` of loss that forces the tracing ops to run).
        tracer_cls = tensor_tracer.TensorTracer
        if tracer_cls.is_enabled() and train_op is not None:
            loss = tracer_cls().trace_cpu(
                tf.compat.v1.get_default_graph(), loss, train_op)

        return super(EstimatorSpec, cls).__new__(
            cls,
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs=export_outputs,
            scaffold=scaffold,
            **validated_hooks)