def testConvertNamedSignatureToSignatureDef(self):
  """Converts named predict input/output generic signatures to a SignatureDef.

  Builds a Signatures proto whose PREDICT_INPUTS and PREDICT_OUTPUTS named
  signatures each hold a single-entry GenericSignature. It then checks that
  bundle_shim merges them into one SignatureDef with the predict method name
  and one input and one output TensorInfo.
  """
  signatures_proto = manifest_pb2.Signatures()

  # Each entry: (named-signature key, generic map key, bound tensor name).
  named_bindings = (
      (signature_constants.PREDICT_INPUTS, "input_key", "input"),
      (signature_constants.PREDICT_OUTPUTS, "output_key", "output"),
  )
  for signature_key, map_key, tensor_name in named_bindings:
    generic = manifest_pb2.GenericSignature()
    generic.map[map_key].CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=tensor_name))
    signatures_proto.named_signatures[
        signature_key].generic_signature.CopyFrom(generic)

  converted = bundle_shim._convert_named_signatures_to_signature_def(
      signatures_proto)

  self.assertEqual(converted.method_name,
                   signature_constants.PREDICT_METHOD_NAME)
  self.assertEqual(len(converted.inputs), 1)
  self.assertEqual(len(converted.outputs), 1)
  self.assertProtoEquals(converted.inputs["input_key"],
                         meta_graph_pb2.TensorInfo(name="input"))
  self.assertProtoEquals(converted.outputs["output_key"],
                         meta_graph_pb2.TensorInfo(name="output"))
def testConvertDefaultSignatureRegressionToSignatureDef(self):
  """Converts a default RegressionSignature to a regress SignatureDef.

  The regression input and output bindings use the REGRESS_INPUTS and
  REGRESS_OUTPUTS constants as tensor names. The converted SignatureDef must
  use the regress method name and expose exactly one input and one output
  keyed by those same constants.
  """
  input_name = signature_constants.REGRESS_INPUTS
  output_name = signature_constants.REGRESS_OUTPUTS

  regression = manifest_pb2.RegressionSignature()
  regression.input.CopyFrom(manifest_pb2.TensorBinding(tensor_name=input_name))
  regression.output.CopyFrom(
      manifest_pb2.TensorBinding(tensor_name=output_name))

  signatures_proto = manifest_pb2.Signatures()
  signatures_proto.default_signature.regression_signature.CopyFrom(regression)

  converted = bundle_shim._convert_default_signature_to_signature_def(
      signatures_proto)

  # The regression signature must be carried over unchanged.
  self.assertEqual(converted.method_name,
                   signature_constants.REGRESS_METHOD_NAME)
  self.assertEqual(len(converted.inputs), 1)
  self.assertEqual(len(converted.outputs), 1)
  self.assertProtoEquals(converted.inputs[input_name],
                         meta_graph_pb2.TensorInfo(name=input_name))
  self.assertProtoEquals(converted.outputs[output_name],
                         meta_graph_pb2.TensorInfo(name=output_name))
def testConvertDefaultSignatureClassificationToSignatureDef(self):
  """Converts a default ClassificationSignature to a classify SignatureDef.

  The classification signature binds one input (CLASSIFY_INPUTS) and two
  outputs (CLASSIFY_OUTPUT_CLASSES and CLASSIFY_OUTPUT_SCORES). The converted
  SignatureDef must use the classify method name and expose those tensors
  keyed by the same constants.
  """
  signatures_proto = manifest_pb2.Signatures()
  classification_signature = manifest_pb2.ClassificationSignature()
  classification_signature.input.CopyFrom(
      manifest_pb2.TensorBinding(
          tensor_name=signature_constants.CLASSIFY_INPUTS))
  classification_signature.classes.CopyFrom(
      manifest_pb2.TensorBinding(
          tensor_name=signature_constants.CLASSIFY_OUTPUT_CLASSES))
  classification_signature.scores.CopyFrom(
      manifest_pb2.TensorBinding(
          tensor_name=signature_constants.CLASSIFY_OUTPUT_SCORES))
  # A single copy is sufficient; the former duplicate CopyFrom was a no-op.
  signatures_proto.default_signature.classification_signature.CopyFrom(
      classification_signature)

  signature_def = bundle_shim._convert_default_signature_to_signature_def(
      signatures_proto)

  # Validate classification signature correctly copied over.
  self.assertEqual(signature_def.method_name,
                   signature_constants.CLASSIFY_METHOD_NAME)
  self.assertEqual(len(signature_def.inputs), 1)
  self.assertEqual(len(signature_def.outputs), 2)
  self.assertProtoEquals(
      signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
      meta_graph_pb2.TensorInfo(
          name=signature_constants.CLASSIFY_INPUTS))
  self.assertProtoEquals(
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
      meta_graph_pb2.TensorInfo(
          name=signature_constants.CLASSIFY_OUTPUT_SCORES))
  self.assertProtoEquals(
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
      meta_graph_pb2.TensorInfo(
          name=signature_constants.CLASSIFY_OUTPUT_CLASSES))