예제 #1
0
 def TensorRepresentations(self) -> tensor_adapter.TensorRepresentations:
     """Returns the TensorRepresentations for `self._schema`.

     The result of `GetTensorRepresentationsFromSchema` is preferred; only
     when that lookup returns None are the representations produced by
     `InferTensorRepresentationsFromSchema` instead.
     """
     schema = self._schema
     explicit_reps = tensor_rep_util.GetTensorRepresentationsFromSchema(schema)
     if explicit_reps is not None:
         return explicit_reps
     # The lookup found nothing, so fall back to inference from the schema.
     return tensor_rep_util.InferTensorRepresentationsFromSchema(schema)
 def testInferTensorRepresentationsFromSchema(
     self,
     ascii_proto,
     expected,
     generate_legacy_feature_spec=False,
     schema_is_mixed=False):
   """Checks that inferred TensorRepresentations match `expected`.

   Args:
     ascii_proto: text-format `schema_pb2.Schema` to parse.
     expected: dict whose values are text-format
       `schema_pb2.TensorRepresentation` protos; it is compared (after
       parsing the values) against the inference result.
     generate_legacy_feature_spec: written to
       `schema.generate_legacy_feature_spec` when `_IS_LEGACY_SCHEMA` is set.
       If True while `_IS_LEGACY_SCHEMA` is False, the test is skipped.
     schema_is_mixed: if True, only `InferTensorRepresentationsFromMixedSchema`
       is checked; otherwise both inference functions are checked.
   """
   if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
     # skipTest raises unittest.SkipTest itself; an explicit `raise` is dead code.
     self.skipTest('This test exersizes legacy inference logic, but the '
                   'schema is not legacy schema.')
   schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
   if _IS_LEGACY_SCHEMA:
     schema.generate_legacy_feature_spec = generate_legacy_feature_spec
   expected_protos = {
       k: text_format.Parse(pbtxt, schema_pb2.TensorRepresentation())
       for k, pbtxt in expected.items()
   }
   if not schema_is_mixed:
     self.assertEqual(
         expected_protos,
         tensor_representation_util.InferTensorRepresentationsFromSchema(
             schema))
   # The mixed-schema entry point is expected to agree in every case.
   self.assertEqual(
       expected_protos,
       tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
           schema))
예제 #3
0
 def testInferTensorRepresentationsFromSchemaInvalidSchema(
     self, ascii_proto, error_msg, generate_legacy_feature_spec=False):
   """Checks that inference on an invalid schema raises ValueError.

   Args:
     ascii_proto: text-format `schema_pb2.Schema` to parse.
     error_msg: regex that the raised ValueError's message must match.
     generate_legacy_feature_spec: written to
       `schema.generate_legacy_feature_spec` when `_IS_LEGACY_SCHEMA` is set.
       If True while `_IS_LEGACY_SCHEMA` is False, the test is skipped.
   """
   if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
     # Use skipTest so the runner reports the case as skipped; printing and
     # returning made it count as a pass.
     self.skipTest('This test exercises legacy inference logic, but the '
                   'schema is not a legacy schema.')
   schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
   if _IS_LEGACY_SCHEMA:
     schema.generate_legacy_feature_spec = generate_legacy_feature_spec
   with self.assertRaisesRegex(ValueError, error_msg):
     tensor_representation_util.InferTensorRepresentationsFromSchema(schema)
 def testInferTensorRepresentationsFromSchemaInvalidSchema(
     self, ascii_proto, error_msg, generate_legacy_feature_spec=False):
   """Checks that both inference entry points reject an invalid schema.

   Args:
     ascii_proto: text-format `schema_pb2.Schema` to parse.
     error_msg: regex that each raised ValueError's message must match.
     generate_legacy_feature_spec: written to
       `schema.generate_legacy_feature_spec` when `_IS_LEGACY_SCHEMA` is set.
       If True while `_IS_LEGACY_SCHEMA` is False, the test is skipped.
   """
   if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
     # skipTest raises unittest.SkipTest itself; an explicit `raise` is dead code.
     self.skipTest('This test exersizes legacy inference logic, but the '
                   'schema is not legacy schema.')
   schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
   if _IS_LEGACY_SCHEMA:
     schema.generate_legacy_feature_spec = generate_legacy_feature_spec
   with self.assertRaisesRegex(ValueError, error_msg):
     tensor_representation_util.InferTensorRepresentationsFromSchema(schema)
   with self.assertRaisesRegex(ValueError, error_msg):
     tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
         schema)
 def testInferTensorRepresentationsFromSchema(
         self, ascii_proto, expected, generate_legacy_feature_spec=False):
     """Checks that inferred TensorRepresentations match `expected`.

     Args:
       ascii_proto: text-format `schema_pb2.Schema` to parse.
       expected: dict whose values are text-format
         `schema_pb2.TensorRepresentation` protos; it is compared (after
         parsing the values) against the inference result.
       generate_legacy_feature_spec: written to
         `schema.generate_legacy_feature_spec` when `_IS_LEGACY_SCHEMA` is
         set. If True while `_IS_LEGACY_SCHEMA` is False, the test is skipped.
     """
     # Skip a test if it's testing legacy logic but the schema is not the
     # legacy schema. skipTest makes the runner report it as skipped rather
     # than as a silent pass.
     if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
         self.skipTest('This test exercises legacy inference logic, but the '
                       'schema is not a legacy schema.')
     schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
     if _IS_LEGACY_SCHEMA:
         schema.generate_legacy_feature_spec = generate_legacy_feature_spec
     expected_protos = {
         k: text_format.Parse(pbtxt, schema_pb2.TensorRepresentation())
         for k, pbtxt in expected.items()
     }
     self.assertEqual(
         expected_protos,
         tensor_representation_util.InferTensorRepresentationsFromSchema(
             schema))