Exemple #1
0
    def testConvertDefaultSignatureRegressionToSignatureDef(self):
        # Checks that a default regression signature is converted into a
        # SignatureDef carrying the regression method name plus exactly one
        # input and one output entry.

        # Assemble the regression signature from its two tensor bindings.
        inputs_binding = manifest_pb2.TensorBinding(
            tensor_name=signature_constants.REGRESS_INPUTS)
        outputs_binding = manifest_pb2.TensorBinding(
            tensor_name=signature_constants.REGRESS_OUTPUTS)
        regression = manifest_pb2.RegressionSignature()
        regression.input.CopyFrom(inputs_binding)
        regression.output.CopyFrom(outputs_binding)

        # Install it as the default signature and run the conversion.
        signatures = manifest_pb2.Signatures()
        signatures.default_signature.regression_signature.CopyFrom(regression)
        converted = bundle_shim._convert_default_signature_to_signature_def(
            signatures)

        # Method name and entry counts are checked before indexing the maps.
        self.assertEqual(converted.method_name,
                         signature_constants.REGRESS_METHOD_NAME)
        self.assertEqual(len(converted.inputs), 1)
        self.assertEqual(len(converted.outputs), 1)

        # Each entry is keyed by, and named after, its tensor name constant.
        expected_input = meta_graph_pb2.TensorInfo(
            name=signature_constants.REGRESS_INPUTS)
        self.assertProtoEquals(
            converted.inputs[signature_constants.REGRESS_INPUTS],
            expected_input)
        expected_output = meta_graph_pb2.TensorInfo(
            name=signature_constants.REGRESS_OUTPUTS)
        self.assertProtoEquals(
            converted.outputs[signature_constants.REGRESS_OUTPUTS],
            expected_output)
  def testConvertDefaultSignatureRegressionToSignatureDef(self):
    # Converts a Signatures proto whose default signature is a regression
    # signature and validates the resulting SignatureDef.
    input_key = signature_constants.REGRESS_INPUTS
    output_key = signature_constants.REGRESS_OUTPUTS

    # Build the regression signature bound to the input/output tensor names.
    regression = manifest_pb2.RegressionSignature()
    regression.input.CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=input_key))
    regression.output.CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=output_key))

    signatures = manifest_pb2.Signatures()
    signatures.default_signature.regression_signature.CopyFrom(regression)

    converted = bundle_shim._convert_default_signature_to_signature_def(
        signatures)

    # Validate regression signature correctly copied over.
    self.assertEqual(converted.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(len(converted.inputs), 1)
    self.assertEqual(len(converted.outputs), 1)
    # Input first, then output: each entry must be a TensorInfo named after
    # the key it is stored under.
    for tensor_map, key in ((converted.inputs, input_key),
                            (converted.outputs, output_key)):
      self.assertProtoEquals(tensor_map[key],
                             meta_graph_pb2.TensorInfo(name=key))
 def testConvertDefaultSignatureBadTypeToSignatureDef(self):
   # Expects the conversion to raise RuntimeError when the default signature
   # holds a (default-constructed) generic signature.
   signatures_proto = manifest_pb2.Signatures()
   generic_signature = manifest_pb2.GenericSignature()
   signatures_proto.default_signature.generic_signature.CopyFrom(
       generic_signature)
   # Only the raised exception matters: neither the context manager nor the
   # call's return value is inspected, so nothing is bound here.
   with self.assertRaises(RuntimeError):
     bundle_shim._convert_default_signature_to_signature_def(
         signatures_proto)
Exemple #4
0
 def testConvertDefaultSignatureBadTypeToSignatureDef(self):
     # Expects the conversion to raise RuntimeError when the default signature
     # holds a (default-constructed) generic signature.
     signatures_proto = manifest_pb2.Signatures()
     generic_signature = manifest_pb2.GenericSignature()
     signatures_proto.default_signature.generic_signature.CopyFrom(
         generic_signature)
     # Only the raised exception matters: neither the context manager nor the
     # call's return value is inspected, so nothing is bound here.
     with self.assertRaises(RuntimeError):
         bundle_shim._convert_default_signature_to_signature_def(
             signatures_proto)
 def testConvertDefaultSignatureGenericToSignatureDef(self):
     # Expects the conversion to return None when the default signature holds
     # a (default-constructed) generic signature.
     signatures_proto = manifest_pb2.Signatures()
     generic_signature = manifest_pb2.GenericSignature()
     signatures_proto.default_signature.generic_signature.CopyFrom(
         generic_signature)
     signature_def = bundle_shim._convert_default_signature_to_signature_def(
         signatures_proto)
     # assertIsNone replaces the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertIsNone(signature_def)
 def testConvertDefaultSignatureGenericToSignatureDef(self):
   # Expects the conversion to return None when the default signature holds
   # a (default-constructed) generic signature.
   signatures_proto = manifest_pb2.Signatures()
   generic_signature = manifest_pb2.GenericSignature()
   signatures_proto.default_signature.generic_signature.CopyFrom(
       generic_signature)
   signature_def = bundle_shim._convert_default_signature_to_signature_def(
       signatures_proto)
   # assertIsNone replaces the deprecated assertEquals alias, which was
   # removed from unittest in Python 3.12.
   self.assertIsNone(signature_def)
Exemple #7
0
    def testConvertDefaultSignatureClassificationToSignatureDef(self):
        # Checks that a default classification signature is converted into a
        # SignatureDef with the classify method name, one input, and two
        # outputs (scores and classes).
        signatures_proto = manifest_pb2.Signatures()
        classification_signature = manifest_pb2.ClassificationSignature()
        classification_signature.input.CopyFrom(
            manifest_pb2.TensorBinding(
                tensor_name=signature_constants.CLASSIFY_INPUTS))
        classification_signature.classes.CopyFrom(
            manifest_pb2.TensorBinding(
                tensor_name=signature_constants.CLASSIFY_OUTPUT_CLASSES))
        classification_signature.scores.CopyFrom(
            manifest_pb2.TensorBinding(
                tensor_name=signature_constants.CLASSIFY_OUTPUT_SCORES))
        # Install the signature as the default once; a repeated CopyFrom of
        # the same message would only rewrite identical contents.
        signatures_proto.default_signature.classification_signature.CopyFrom(
            classification_signature)

        signature_def = bundle_shim._convert_default_signature_to_signature_def(
            signatures_proto)

        # Validate classification signature correctly copied over.
        self.assertEqual(signature_def.method_name,
                         signature_constants.CLASSIFY_METHOD_NAME)
        self.assertEqual(len(signature_def.inputs), 1)
        self.assertEqual(len(signature_def.outputs), 2)
        self.assertProtoEquals(
            signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
            meta_graph_pb2.TensorInfo(
                name=signature_constants.CLASSIFY_INPUTS))
        self.assertProtoEquals(
            signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
            meta_graph_pb2.TensorInfo(
                name=signature_constants.CLASSIFY_OUTPUT_SCORES))
        self.assertProtoEquals(
            signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
            meta_graph_pb2.TensorInfo(
                name=signature_constants.CLASSIFY_OUTPUT_CLASSES))
  def testConvertDefaultSignatureClassificationToSignatureDef(self):
    # Checks that a default classification signature is converted into a
    # SignatureDef with the classify method name, one input, and two outputs
    # (scores and classes).
    signatures_proto = manifest_pb2.Signatures()
    classification_signature = manifest_pb2.ClassificationSignature()
    classification_signature.input.CopyFrom(
        manifest_pb2.TensorBinding(
            tensor_name=signature_constants.CLASSIFY_INPUTS))
    classification_signature.classes.CopyFrom(
        manifest_pb2.TensorBinding(
            tensor_name=signature_constants.CLASSIFY_OUTPUT_CLASSES))
    classification_signature.scores.CopyFrom(
        manifest_pb2.TensorBinding(
            tensor_name=signature_constants.CLASSIFY_OUTPUT_SCORES))
    # Install the signature as the default once; a repeated CopyFrom of the
    # same message would only rewrite identical contents.
    signatures_proto.default_signature.classification_signature.CopyFrom(
        classification_signature)

    signature_def = bundle_shim._convert_default_signature_to_signature_def(
        signatures_proto)

    # Validate classification signature correctly copied over.
    self.assertEqual(signature_def.method_name,
                     signature_constants.CLASSIFY_METHOD_NAME)
    self.assertEqual(len(signature_def.inputs), 1)
    self.assertEqual(len(signature_def.outputs), 2)
    self.assertProtoEquals(
        signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
        meta_graph_pb2.TensorInfo(name=signature_constants.CLASSIFY_INPUTS))
    self.assertProtoEquals(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
        meta_graph_pb2.TensorInfo(
            name=signature_constants.CLASSIFY_OUTPUT_SCORES))
    self.assertProtoEquals(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
        meta_graph_pb2.TensorInfo(
            name=signature_constants.CLASSIFY_OUTPUT_CLASSES))