def from_saved_model(export_dir,
                     signature_def_key=None,
                     signature_def=None,
                     tags=None,
                     graph=None):
    """Constructs a `Predictor` from a `SavedModel` on disk.

    Thin convenience wrapper: every argument is forwarded unchanged to
    `saved_model_predictor.SavedModelPredictor`.

    Args:
      export_dir: a path to a directory containing a `SavedModel`.
      signature_def_key: Optional string specifying the signature to use. If
        `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
        `signature_def_key` and `signature_def` should be specified.
      signature_def: A `SignatureDef` proto specifying the inputs and outputs
        for prediction. Only one of `signature_def_key` and `signature_def`
        should be specified.
      tags: Optional. Tags that will be used to retrieve the correct
        `SignatureDef`. Defaults to `DEFAULT_TAGS`.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done.

    Returns:
      An initialized `Predictor`.

    Raises:
      ValueError: More than one of `signature_def_key` and `signature_def` is
        specified.
    """
    return saved_model_predictor.SavedModelPredictor(
        export_dir,
        signature_def_key=signature_def_key,
        signature_def=signature_def,
        tags=tags,
        graph=graph)
    def testSpecifiedSignatureKey(self):
        """Test prediction with a specified signature key.

        For each (signature key, op) pair in `KEYS_AND_OPS`, builds a predictor
        for that key, checks the fetched output tensor's name matches the key,
        and checks the prediction equals `op(x, y)` to 3 decimal places.
        """
        np.random.seed(1234)
        for signature_def_key, op in KEYS_AND_OPS:
            x = np.random.rand()
            y = np.random.rand()
            expected_output = op(x, y)

            predictor = saved_model_predictor.SavedModelPredictor(
                export_dir=self._export_dir,
                signature_def_key=signature_def_key)

            # The key is used as a regex pattern; presumably the keys contain
            # no regex metacharacters -- confirm against KEYS_AND_OPS.
            output_tensor_name = predictor.fetch_tensors['outputs'].name
            self.assertRegexpMatches(output_tensor_name,
                                     signature_def_key,
                                     msg='Unexpected fetch tensor.')

            output = predictor({'x': x, 'y': y})['outputs']
            self.assertAlmostEqual(
                expected_output,
                output,
                places=3,
                msg='Failed for signature "{}". '
                'Got output {} for x = {} and y = {}'.format(
                    signature_def_key, output, x, y))
    def testSpecifiedTensors(self):
        """Test prediction with specified input and output `Tensor` names.

        Instead of a signature, the predictor is given explicit tensor names:
        the shared `inputs/x:0` / `inputs/y:0` inputs and, per key in
        `KEYS_AND_OPS`, the `outputs/<key>:0` output.
        """
        np.random.seed(987)
        # Identical for every key, so build it once outside the loop.
        input_names = {'x': 'inputs/x:0', 'y': 'inputs/y:0'}
        for key, op in KEYS_AND_OPS:
            x = np.random.rand()
            y = np.random.rand()
            expected_output = op(x, y)
            output_names = {key: 'outputs/{}:0'.format(key)}
            predictor = saved_model_predictor.SavedModelPredictor(
                export_dir=self._export_dir,
                input_names=input_names,
                output_names=output_names)

            output_tensor_name = predictor.fetch_tensors[key].name
            self.assertRegexpMatches(output_tensor_name,
                                     key,
                                     msg='Unexpected fetch tensor.')

            output = predictor({'x': x, 'y': y})[key]
            self.assertAlmostEqual(
                expected_output,
                output,
                places=3,
                msg='Failed for output "{}". '
                'Got output {} for x = {} and y = {}'.format(
                    key, output, x, y))
 def testDefault(self):
     """With no signature selected, the predictor's 'outputs' should be x + y."""
     np.random.seed(1111)
     x, y = np.random.rand(), np.random.rand()
     default_predictor = saved_model_predictor.SavedModelPredictor(
         export_dir=self._export_dir)
     feed = {'x': x, 'y': y}
     prediction = default_predictor(feed)['outputs']
     self.assertAlmostEqual(prediction, x + y, places=3)
    def testSpecifiedSignature(self):
        """Test prediction with a specified signature definition.

        For each (key, op) pair in `KEYS_AND_OPS`, builds a `SignatureDef`
        mapping the shared `inputs/x:0` / `inputs/y:0` tensors to the
        `outputs/<key>:0` tensor, and checks the prediction equals `op(x, y)`.
        """
        np.random.seed(4444)
        # The input TensorInfos are the same for every key; build them once.
        # `build_signature_def` copies each TensorInfo into the new proto, so
        # reusing these messages across iterations is safe.
        inputs = {
            name: meta_graph_pb2.TensorInfo(
                name='inputs/{}:0'.format(name),
                dtype=types_pb2.DT_FLOAT,
                tensor_shape=tensor_shape_pb2.TensorShapeProto())
            for name in ('x', 'y')
        }
        for key, op in KEYS_AND_OPS:
            x = np.random.rand()
            y = np.random.rand()
            expected_output = op(x, y)

            outputs = {
                key:
                meta_graph_pb2.TensorInfo(
                    name='outputs/{}:0'.format(key),
                    dtype=types_pb2.DT_FLOAT,
                    tensor_shape=tensor_shape_pb2.TensorShapeProto())
            }
            signature_def = signature_def_utils.build_signature_def(
                inputs=inputs,
                outputs=outputs,
                method_name='tensorflow/serving/regress')
            predictor = saved_model_predictor.SavedModelPredictor(
                export_dir=self._export_dir, signature_def=signature_def)

            output_tensor_name = predictor.fetch_tensors[key].name
            self.assertRegexpMatches(output_tensor_name,
                                     key,
                                     msg='Unexpected fetch tensor.')

            output = predictor({'x': x, 'y': y})[key]
            self.assertAlmostEqual(
                expected_output,
                output,
                places=3,
                msg='Failed for signature "{}". '
                'Got output {} for x = {} and y = {}'.format(
                    key, output, x, y))
Example No. 6
0
def from_saved_model(export_dir,
                     signature_def_key=None,
                     signature_def=None,
                     input_names=None,
                     output_names=None,
                     tags=None,
                     graph=None,
                     config=None):
    """Constructs a `Predictor` from a `SavedModel` on disk.

    Thin convenience wrapper: every argument is forwarded unchanged to
    `saved_model_predictor.SavedModelPredictor`.

    Args:
      export_dir: a path to a directory containing a `SavedModel`.
      signature_def_key: Optional string specifying the signature to use. If
        `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
        `signature_def_key` and `signature_def` should be specified.
      signature_def: A `SignatureDef` proto specifying the inputs and outputs
        for prediction. Only one of `signature_def_key` and `signature_def`
        should be specified.
      input_names: Optional dictionary mapping strings to the `Tensor`s in the
        `SavedModel` that represent the input, identified by tensor name
        (e.g. `'inputs/x:0'`). The keys can be any string of the user's
        choosing.
      output_names: Optional dictionary mapping strings to the `Tensor`s in
        the `SavedModel` that represent the output, identified by tensor name.
        The keys can be any string of the user's choosing.
      tags: Optional. Tags that will be used to retrieve the correct
        `SignatureDef`. Defaults to `DEFAULT_TAGS`.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      An initialized `Predictor`.

    Raises:
      ValueError: More than one of `signature_def_key` and `signature_def` is
        specified.
    """
    return saved_model_predictor.SavedModelPredictor(
        export_dir,
        signature_def_key=signature_def_key,
        signature_def=signature_def,
        input_names=input_names,
        output_names=output_names,
        tags=tags,
        graph=graph,
        config=config)
 def testSpecifiedGraph(self):
     """A `Graph` passed at construction is the one the predictor reports."""
     expected_graph = ops.Graph()
     built = saved_model_predictor.SavedModelPredictor(
         export_dir=self._export_dir, graph=expected_graph)
     self.assertEqual(built.graph, expected_graph)
 def testBadTagsFail(self):
     """Test that predictor construction fails for tags absent from the model."""
     bad_tags_regex = r'.* could not be found in SavedModel'
     # Pass a genuine tuple of tags: parentheses alone do not make a tuple,
     # so `('zomg, bad, tags')` would be one plain string.
     bad_tags = ('zomg', 'bad', 'tags')
     with self.assertRaisesRegexp(RuntimeError, bad_tags_regex):
         _ = saved_model_predictor.SavedModelPredictor(
             export_dir=self._export_dir, tags=bad_tags)