예제 #1
0
def export_outputs_for_mode(mode,
                            serving_export_outputs=None,
                            predictions=None,
                            loss=None,
                            metrics=None):
    """Builds the `ExportOutput` dict that matches the given mode.

    The result can be handed straight to the `build_all_signature_defs`
    helper as its `export_outputs` argument, which turns it into a
    SignatureDef map.

    Args:
      mode: A `ModeKeys` value selecting which kind of output to build.
      serving_export_outputs: Output signatures to export to `SavedModel` for
        use at serving time; a dict or None. Only consulted in PREDICT mode.
      predictions: A dict of Tensors or a single Tensor holding the model's
        predictions. In PREDICT mode it is only used when
        `serving_export_outputs` is not set (that choice is delegated to
        `_get_export_outputs`); TRAIN and EVAL always pass it through.
      loss: A dict of Tensors or a single Tensor holding the computed loss.
      metrics: A dict of (metric_value, update_op) tuples, or a single such
        tuple. metric_value must be a Tensor; update_op must be a Tensor or
        an Op.

    Returns:
      A dict mapping a single key to a `tf.estimator.export.ExportOutput`
      object. The key is the SignatureDef key expected for the mode.

    Raises:
      ValueError: if no appropriate ExportOutput exists for `mode`.
    """
    # TODO(b/113185250): move all model export helper functions into an util file.
    if mode == ModeKeys.PREDICT:
        return _get_export_outputs(serving_export_outputs, predictions)

    # TRAIN and EVAL take identical constructor arguments; only the
    # ExportOutput subclass differs, so pick the class first.
    if mode == ModeKeys.TRAIN:
        output_cls = export_output_lib.TrainOutput
    elif mode == ModeKeys.EVAL:
        output_cls = export_output_lib.EvalOutput
    else:
        raise ValueError(
            'Export output type not found for mode: {}'.format(mode))

    export_output = output_cls(loss=loss,
                               predictions=predictions,
                               metrics=metrics)
    return {mode: export_output}
예제 #2
0
    def test_eval_signature_def(self):
        """Checks the SignatureDef produced by an `EvalOutput`.

        The output is built from a loss dict and a predictions dict with no
        metrics. The resulting signature must expose the loss and prediction
        outputs under their prefixed keys, contain no `metrics/value` output,
        and list the receiver's `features` tensor among its inputs.
        """
        loss = {"my_loss": constant_op.constant([0])}
        predictions = {u"output1": constant_op.constant(["foo"])}

        # Name the arguments so it is explicit that metrics are omitted.
        outputter = export_output_lib.EvalOutput(loss=loss,
                                                 predictions=predictions,
                                                 metrics=None)

        receiver = {
            u"features": constant_op.constant(100, shape=(100, 2)),
            "labels": constant_op.constant(100, shape=(100, 1)),
        }
        sig_def = outputter.as_signature_def(receiver)

        # assertIn/assertNotIn report the member and container on failure,
        # unlike assertTrue(x in y), which only reports "False is not true".
        self.assertIn("loss/my_loss", sig_def.outputs)
        self.assertNotIn("metrics/value", sig_def.outputs)
        self.assertIn("predictions/output1", sig_def.outputs)
        self.assertIn("features", sig_def.inputs)