def body_wrapper(*inputs):
  """Wrapper around `body` that handles infeed queues and control deps.

  This function is a closure: `input_arity`, `input_types`, `infeed_queue`
  and `body` are read from the enclosing scope.

  Args:
    *inputs: The loop-carried values for this iteration. When `input_arity`
      is 0 they hold only a dummy placeholder, which is discarded before
      `body` is called.

  Returns:
    The loop-carried outputs of one iteration. These are the Tensors returned
    by `body`, or a one-element list holding a dummy scalar when `body`
    returns no Tensors. When `body` returned Operations or an infeed queue is
    used, the Tensors are passed through `control_flow_ops.tuple` with those
    ops as control inputs. When Tensor Tracer is enabled, the result is
    whatever `TensorTracer.trace_tpu` returns for them.

  Raises:
    ValueError: If an infeed queue is used without a TPU shard context, or if
      `body` returns an Operation before a Tensor.
    TypeError: If the dtypes of the Tensors returned by `body` differ from
      `input_types`.
  """
  inputs = list(inputs)
  # Discards the dummy output added for arity-0 loops.
  if input_arity == 0:
    inputs = []

  # Runs `body` with the dequeue_ops appended.
  if infeed_queue:
    number_of_shards = tpu_function.get_tpu_context().number_of_shards
    if number_of_shards is None:
      raise ValueError(
          "Can't build training loop with infeed when there is "
          "no tpu_shard_context. Are you building a loop or "
          "graph directly rather than from inside tpu.rewrite, "
          "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
    infeed_queue.set_number_of_shards(number_of_shards)
    dequeue_ops = list(infeed_queue.generate_dequeue_op())
  else:
    dequeue_ops = []

  outputs = body(*(inputs + dequeue_ops))

  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, (list, tuple)):
    outputs = (outputs,)

  outputs = [
      o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
      for o in outputs
  ]

  # Separates the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  # All Tensors must precede all Operations, i.e. the first
  # len(output_tensors) entries must all be Tensors. This is checked by type
  # instead of comparing the lists with `!=`. List comparison can fall back to
  # `Tensor.__eq__`, which may be overloaded (e.g. elementwise under TF2
  # equality semantics) rather than being a plain identity check.
  if any(isinstance(o, ops.Operation)
         for o in outputs[:len(output_tensors)]):
    raise ValueError(
        "TPU training loop body must return zero or more Tensor values "
        "followed by zero or more Operations.")

  output_types = [tensor.dtype for tensor in output_tensors]
  if input_types != output_types:
    raise TypeError(
        "Mismatch between input types and output types for training loop "
        "body: {} vs {}".format(input_types, output_types))

  # Add the dequeue operations to output_operations to ensure they are run
  # by the loop, even if the programmer's loop body does not use them.
  output_operations += dequeue_ops

  # Add a dummy output, if needed. It is kept inside a list because
  # `control_flow_ops.tuple` below takes a list of Tensors, not a bare Tensor.
  if not output_tensors:
    output_tensors = [array_ops.constant(0)]

  if output_operations:
    # TODO(phawkins): in principle this is too restrictive since it serializes
    # the training loop steps. In practice it does not matter since this loop
    # will be compiled by XLA.
    output_tensors = control_flow_ops.tuple(
        output_tensors, control_inputs=output_operations)

  if tensor_tracer.TensorTracer.is_enabled():
    num_replicas = tpu_function.get_tpu_context().number_of_shards
    if num_replicas is None:
      num_replicas = 1
    tt = tensor_tracer.TensorTracer()
    output_tensors = tt.trace_tpu(ops.get_default_graph(), output_tensors,
                                  None, num_replicas)
  return output_tensors
def __new__(cls,
            mode,
            predictions=None,
            loss=None,
            train_op=None,
            eval_metric_ops=None,
            export_outputs=None,
            training_chief_hooks=None,
            training_hooks=None,
            scaffold=None,
            evaluation_hooks=None,
            prediction_hooks=None):
  """Creates a validated `EstimatorSpec` instance.

  Which arguments must be supplied depends on `mode`:

  * `ModeKeys.TRAIN`: `loss` and `train_op` are required.
  * `ModeKeys.EVAL`: `loss` is required.
  * `ModeKeys.PREDICT`: `predictions` is required.

  A `model_fn` may populate every argument regardless of mode. An `Estimator`
  then ignores the ones that do not apply; for example, `train_op` is ignored
  in eval and infer modes. Example:

  ```python
  def my_model_fn(features, labels, mode):
    predictions = ...
    loss = ...
    train_op = ...
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)
  ```

  Alternatively, a `model_fn` can populate only the arguments relevant to the
  given mode. Example:

  ```python
  def my_model_fn(features, labels, mode):
    if (mode == tf.estimator.ModeKeys.TRAIN or
        mode == tf.estimator.ModeKeys.EVAL):
      loss = ...
    else:
      loss = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = ...
    else:
      train_op = None
    if mode == tf.estimator.ModeKeys.PREDICT:
      predictions = ...
    else:
      predictions = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)
  ```

  Args:
    mode: A `ModeKeys` value specifying training, evaluation or prediction.
    predictions: Predictions `Tensor` or dict of `Tensor`.
    loss: Training loss `Tensor`. Must be either a scalar or of shape `[1]`.
    train_op: Op for the training step.
    eval_metric_ops: Dict of metric results keyed by name. Each value is
      either (1) an instance of the `Metric` class, or (2) the result of
      calling a metric function, namely a `(metric_tensor, update_op)` tuple.
      `metric_tensor` should be evaluable without any impact on state
      (typically it is a pure computation based on variables). For example,
      it should not trigger the `update_op` or require any input fetching.
    export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. A dict `{name: output}` where:
      * name: An arbitrary name for this output.
      * output: an `ExportOutput` object such as `ClassificationOutput`,
        `RegressionOutput`, or `PredictOutput`.
      Single-headed models only need to specify one entry in this dictionary.
      Multi-headed models should specify one entry for each head, one of
      which must be named using
      `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.
      If no entry is provided, a default `PredictOutput` mapping to
      `predictions` will be created.
    training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to
      run on the chief worker during training.
    training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on
      all workers during training.
    scaffold: A `tf.train.Scaffold` object that can be used to set
      initialization, saver, and more to be used in training.
    evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run
      during evaluation.
    prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run
      during predictions.

  Returns:
    A validated `EstimatorSpec` object.

  Raises:
    ValueError: If validation fails.
    TypeError: If any of the arguments is not the expected type.
  """
  validated_train_op = _validate_estimator_spec_train_op(train_op, mode)
  validated_loss = _validate_estimator_spec_loss(loss, mode)
  validated_predictions = _validate_estimator_spec_predictions(
      predictions, mode)
  # Export outputs default from the *validated* predictions.
  validated_export_outputs = _validate_estimator_spec_export_outputs(
      export_outputs, validated_predictions, mode)

  # Every hook collection goes through the same validator, in this order:
  # training, evaluation, prediction, then chief-only training hooks.
  (validated_training_hooks, validated_evaluation_hooks,
   validated_prediction_hooks, validated_chief_hooks) = (
       _validate_estimator_spec_hooks(hooks)
       for hooks in (training_hooks, evaluation_hooks, prediction_hooks,
                     training_chief_hooks))

  validated_metric_ops = _validate_eval_metric_ops(eval_metric_ops)
  validated_scaffold = _validate_scaffold(scaffold)

  # Tensor Tracer is off by default, in which case the loss passes through
  # untouched. When it is switched on via environment flags and there is a
  # train op, loss and train_op determine the traced execution path. The loss
  # is then replaced by a `tf.identity` of it that enforces running the
  # tracing ops.
  tracing_requested = (tensor_tracer.TensorTracer.is_enabled() and
                       validated_train_op is not None)
  if tracing_requested:
    tracer = tensor_tracer.TensorTracer()
    validated_loss = tracer.trace_cpu(tf.compat.v1.get_default_graph(),
                                      validated_loss, validated_train_op)

  return super(EstimatorSpec, cls).__new__(
      cls,
      mode=mode,
      predictions=validated_predictions,
      loss=validated_loss,
      train_op=validated_train_op,
      eval_metric_ops=validated_metric_ops,
      export_outputs=validated_export_outputs,
      training_chief_hooks=validated_chief_hooks,
      training_hooks=validated_training_hooks,
      scaffold=validated_scaffold,
      evaluation_hooks=validated_evaluation_hooks,
      prediction_hooks=validated_prediction_hooks)