Exemplo n.º 1
0
  def testConvertNamedSignatureToSignatureDef(self):
    """Converts generic named input/output signatures into a SignatureDef.

    Builds a `manifest_pb2.Signatures` proto whose PREDICT_INPUTS entry binds
    "input_key" -> "input" and whose PREDICT_OUTPUTS entry binds
    "output_key" -> "output", converts it with
    `bundle_shim._convert_named_signatures_to_signature_def`, and checks the
    predict method name plus the single input and output TensorInfo.
    """
    source_signatures = manifest_pb2.Signatures()
    # (named-signature key, generic binding key, bound tensor name)
    bindings = (
        (signature_constants.PREDICT_INPUTS, "input_key", "input"),
        (signature_constants.PREDICT_OUTPUTS, "output_key", "output"),
    )
    for named_key, binding_key, tensor in bindings:
      generic = manifest_pb2.GenericSignature()
      generic.map[binding_key].CopyFrom(
          manifest_pb2.TensorBinding(tensor_name=tensor))
      source_signatures.named_signatures[named_key].generic_signature.CopyFrom(
          generic)

    converted = bundle_shim._convert_named_signatures_to_signature_def(
        source_signatures)

    self.assertEqual(converted.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertEqual(len(converted.inputs), 1)
    self.assertEqual(len(converted.outputs), 1)
    self.assertProtoEquals(converted.inputs["input_key"],
                           meta_graph_pb2.TensorInfo(name="input"))
    self.assertProtoEquals(converted.outputs["output_key"],
                           meta_graph_pb2.TensorInfo(name="output"))
Exemplo n.º 2
0
    def testConvertNamedSignatureToSignatureDef(self):
        """Checks conversion of generic named signatures into a SignatureDef.

        The PREDICT_INPUTS and PREDICT_OUTPUTS entries each carry one generic
        binding. The converted SignatureDef must use the predict method name
        and expose exactly those bindings as TensorInfo entries.
        """

        def make_generic(binding_key, tensor_name):
            # One-entry GenericSignature mapping binding_key -> tensor_name.
            generic = manifest_pb2.GenericSignature()
            generic.map[binding_key].CopyFrom(
                manifest_pb2.TensorBinding(tensor_name=tensor_name))
            return generic

        input_signature = make_generic("input_key", "input")
        output_signature = make_generic("output_key", "output")

        proto = manifest_pb2.Signatures()
        named = proto.named_signatures
        named[signature_constants.PREDICT_INPUTS].generic_signature.CopyFrom(
            input_signature)
        named[signature_constants.PREDICT_OUTPUTS].generic_signature.CopyFrom(
            output_signature)

        result = bundle_shim._convert_named_signatures_to_signature_def(proto)

        self.assertEqual(result.method_name,
                         signature_constants.PREDICT_METHOD_NAME)
        self.assertEqual(len(result.inputs), 1)
        self.assertEqual(len(result.outputs), 1)
        self.assertProtoEquals(result.inputs["input_key"],
                               meta_graph_pb2.TensorInfo(name="input"))
        self.assertProtoEquals(result.outputs["output_key"],
                               meta_graph_pb2.TensorInfo(name="output"))
Exemplo n.º 3
0
 def testConvertNamedSignatureNonGenericToSignatureDef(self):
   """Checks that non-generic named signatures are rejected.

   A `manifest_pb2.Signatures` proto whose PREDICT_INPUTS entry holds a
   regression signature must make
   `bundle_shim._convert_named_signatures_to_signature_def` raise
   RuntimeError. A separate proto holding a classification signature in the
   same entry must do the same.
   """
   signatures_proto = manifest_pb2.Signatures()
   regression_signature = manifest_pb2.RegressionSignature()
   signatures_proto.named_signatures[
       signature_constants.PREDICT_INPUTS].regression_signature.CopyFrom(
           regression_signature)
   # Only the exception type is checked, so the context manager is not bound.
   with self.assertRaises(RuntimeError):
     _ = bundle_shim._convert_named_signatures_to_signature_def(
         signatures_proto)

   # Start from a fresh proto so the regression entry is not carried over.
   signatures_proto = manifest_pb2.Signatures()
   classification_signature = manifest_pb2.ClassificationSignature()
   signatures_proto.named_signatures[
       signature_constants.PREDICT_INPUTS].classification_signature.CopyFrom(
           classification_signature)
   with self.assertRaises(RuntimeError):
     _ = bundle_shim._convert_named_signatures_to_signature_def(
         signatures_proto)
 def testConvertNamedSignatureNonGenericToSignatureDef(self):
   """Checks that regression/classification named signatures are rejected.

   For each non-generic signature kind, a fresh `manifest_pb2.Signatures`
   proto gets that signature under PREDICT_INPUTS. Converting it with
   `bundle_shim._convert_named_signatures_to_signature_def` must raise
   RuntimeError.
   """
   # (oneof field on the named signature, proto class to place there)
   non_generic_cases = (
       ("regression_signature", manifest_pb2.RegressionSignature),
       ("classification_signature", manifest_pb2.ClassificationSignature),
   )
   for field_name, signature_cls in non_generic_cases:
     signatures_proto = manifest_pb2.Signatures()
     named_entry = signatures_proto.named_signatures[
         signature_constants.PREDICT_INPUTS]
     getattr(named_entry, field_name).CopyFrom(signature_cls())
     with self.assertRaises(RuntimeError):
       bundle_shim._convert_named_signatures_to_signature_def(
           signatures_proto)