def test_eval_signature_def(self):
  """Checks the SignatureDef built from an `EvalOutput` with no metrics.

  Builds an `EvalOutput` from one named loss tensor and one named prediction
  tensor (metrics=None), converts it with `as_signature_def`, and verifies
  which keys appear in the resulting signature's outputs and inputs.
  """
  loss = {"my_loss": constant_op.constant([0])}
  predictions = {u"output1": constant_op.constant(["foo"])}
  outputter = export_output_lib.EvalOutput(loss, predictions, None)

  receiver = {
      u"features": constant_op.constant(100, shape=(100, 2)),
      "labels": constant_op.constant(100, shape=(100, 1)),
  }
  sig_def = outputter.as_signature_def(receiver)

  # assertIn/assertNotIn include the container in the failure message,
  # unlike assertTrue(x in y), which only reports "False is not true".
  self.assertIn("loss/my_loss", sig_def.outputs)
  # metrics=None was passed, so no metric output is expected.
  self.assertNotIn("metrics/value", sig_def.outputs)
  self.assertIn("predictions/output1", sig_def.outputs)
  self.assertIn("features", sig_def.inputs)
def export_outputs_for_mode(
    mode, serving_export_outputs=None, predictions=None, loss=None,
    metrics=None):
  """Util function for constructing an `ExportOutput` dict given a mode.

  The returned dict can be directly passed to `build_all_signature_defs` helper
  function as the `export_outputs` argument, used for generating a SignatureDef
  map.

  Args:
    mode: A `ModeKeys` specifying the mode.
    serving_export_outputs: Describes the output signatures to be exported to
      `SavedModel` and used during serving. Should be a dict or None. Only
      consulted in PREDICT mode.
    predictions: A dict of Tensors or single Tensor representing model
      predictions. In PREDICT mode this argument is only used if
      `serving_export_outputs` is not set; in TRAIN and EVAL modes it is
      always forwarded to the constructed `ExportOutput`.
    loss: A dict of Tensors or single Tensor representing calculated loss.
      Only used in TRAIN and EVAL modes.
    metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
      metric_value must be a Tensor, and update_op must be a Tensor or Op.
      Only used in TRAIN and EVAL modes.

  Returns:
    Dictionary mapping a key to a `tf.estimator.export.ExportOutput` object.
    The key is the expected SignatureDef key for the mode. For TRAIN and EVAL
    the dict has a single entry keyed by `mode` itself; for PREDICT the result
    of `_get_export_outputs` is returned unchanged.

  Raises:
    ValueError: if an appropriate ExportOutput cannot be found for the mode.
  """
  # TODO(b/113185250): move all model export helper functions into an util file.
  if mode == ModeKeys.PREDICT:
    return _get_export_outputs(serving_export_outputs, predictions)
  if mode == ModeKeys.TRAIN:
    # The mode value doubles as the SignatureDef key for this output.
    return {
        mode: export_output_lib.TrainOutput(
            loss=loss, predictions=predictions, metrics=metrics)
    }
  if mode == ModeKeys.EVAL:
    return {
        mode: export_output_lib.EvalOutput(
            loss=loss, predictions=predictions, metrics=metrics)
    }
  raise ValueError(
      'Export output type not found for mode: {}'.format(mode))