def TensorRepresentations(self) -> tensor_adapter.TensorRepresentations:
  """Returns the TensorRepresentations derived from `self._schema`.

  The result of `GetTensorRepresentationsFromSchema` is used when it is not
  None; otherwise the representations are produced by
  `InferTensorRepresentationsFromSchema` on the same schema.
  """
  schema_reps = tensor_rep_util.GetTensorRepresentationsFromSchema(
      self._schema)
  if schema_reps is not None:
    return schema_reps
  # Nothing obtained from the schema directly: fall back to inference.
  return tensor_rep_util.InferTensorRepresentationsFromSchema(self._schema)
def testInferTensorRepresentationsFromSchema(
    self, ascii_proto, expected, generate_legacy_feature_spec=False,
    schema_is_mixed=False):
  """Checks inferred TensorRepresentations against the expected protos.

  Args:
    ascii_proto: text-format `schema_pb2.Schema` to run inference on.
    expected: dict mapping keys to text-format
      `schema_pb2.TensorRepresentation` protos that inference must produce.
    generate_legacy_feature_spec: value assigned to the schema's
      `generate_legacy_feature_spec` field; it is only applied when
      `_IS_LEGACY_SCHEMA` is true, and the case is skipped if it is requested
      while `_IS_LEGACY_SCHEMA` is false.
    schema_is_mixed: if True, only `InferTensorRepresentationsFromMixedSchema`
      is checked; `InferTensorRepresentationsFromSchema` is not called.
  """
  if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
    # skipTest() raises unittest.SkipTest itself, ending the test here.
    self.skipTest('This test exercises legacy inference logic, but the '
                  'schema is not legacy schema.')
  schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
  if _IS_LEGACY_SCHEMA:
    schema.generate_legacy_feature_spec = generate_legacy_feature_spec
  expected_protos = {
      k: text_format.Parse(pbtxt, schema_pb2.TensorRepresentation())
      for k, pbtxt in expected.items()
  }
  if not schema_is_mixed:
    self.assertEqual(
        expected_protos,
        tensor_representation_util.InferTensorRepresentationsFromSchema(
            schema))
  # The mixed-schema entry point is checked for every case.
  self.assertEqual(
      expected_protos,
      tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
          schema))
def testInferTensorRepresentationsFromSchemaInvalidSchema(
    self, ascii_proto, error_msg, generate_legacy_feature_spec=False):
  """Checks that inference on an invalid schema raises ValueError.

  Args:
    ascii_proto: text-format `schema_pb2.Schema` expected to be rejected.
    error_msg: regex that the raised ValueError's message must match.
    generate_legacy_feature_spec: value assigned to the schema's
      `generate_legacy_feature_spec` field; it is only applied when
      `_IS_LEGACY_SCHEMA` is true, and the case is skipped if it is requested
      while `_IS_LEGACY_SCHEMA` is false.
  """
  if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
    # Report the case as skipped rather than letting it silently pass; this
    # raises unittest.SkipTest and ends the test here.
    self.skipTest('This test exercises legacy inference logic, but the '
                  'schema is not legacy schema.')
  schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
  if _IS_LEGACY_SCHEMA:
    schema.generate_legacy_feature_spec = generate_legacy_feature_spec
  with self.assertRaisesRegex(ValueError, error_msg):
    tensor_representation_util.InferTensorRepresentationsFromSchema(schema)
def testInferTensorRepresentationsFromSchemaInvalidSchema(
    self, ascii_proto, error_msg, generate_legacy_feature_spec=False):
  """Checks that both inference entry points reject an invalid schema.

  Args:
    ascii_proto: text-format `schema_pb2.Schema` expected to be rejected.
    error_msg: regex that each raised ValueError's message must match.
    generate_legacy_feature_spec: value assigned to the schema's
      `generate_legacy_feature_spec` field; it is only applied when
      `_IS_LEGACY_SCHEMA` is true, and the case is skipped if it is requested
      while `_IS_LEGACY_SCHEMA` is false.
  """
  if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
    # skipTest() raises unittest.SkipTest itself, ending the test here.
    self.skipTest('This test exercises legacy inference logic, but the '
                  'schema is not legacy schema.')
  schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
  if _IS_LEGACY_SCHEMA:
    schema.generate_legacy_feature_spec = generate_legacy_feature_spec
  with self.assertRaisesRegex(ValueError, error_msg):
    tensor_representation_util.InferTensorRepresentationsFromSchema(schema)
  with self.assertRaisesRegex(ValueError, error_msg):
    tensor_representation_util.InferTensorRepresentationsFromMixedSchema(
        schema)
def testInferTensorRepresentationsFromSchema(
    self, ascii_proto, expected, generate_legacy_feature_spec=False):
  """Checks inferred TensorRepresentations against the expected protos.

  Args:
    ascii_proto: text-format `schema_pb2.Schema` to run inference on.
    expected: dict mapping keys to text-format
      `schema_pb2.TensorRepresentation` protos that inference must produce.
    generate_legacy_feature_spec: value assigned to the schema's
      `generate_legacy_feature_spec` field; it is only applied when
      `_IS_LEGACY_SCHEMA` is true, and the case is skipped if it is requested
      while `_IS_LEGACY_SCHEMA` is false.
  """
  # Skip a case that tests legacy logic when the schema is not the legacy
  # schema. skipTest() raises unittest.SkipTest so the case is reported as
  # skipped instead of silently passing.
  if not _IS_LEGACY_SCHEMA and generate_legacy_feature_spec:
    self.skipTest('This test exercises legacy inference logic, but the '
                  'schema is not legacy schema.')
  schema = text_format.Parse(ascii_proto, schema_pb2.Schema())
  if _IS_LEGACY_SCHEMA:
    schema.generate_legacy_feature_spec = generate_legacy_feature_spec
  expected_protos = {
      k: text_format.Parse(pbtxt, schema_pb2.TensorRepresentation())
      for k, pbtxt in expected.items()
  }
  self.assertEqual(
      expected_protos,
      tensor_representation_util.InferTensorRepresentationsFromSchema(
          schema))