def testProjection(self):
    """Tests projecting a TFXIO onto all of its tensor representations.

    The schema declares three tensor representations: "dense_string" and
    "varlen_string" (both over column "string_feature") and "varlen_float"
    (over column "float_feature"). After projecting onto all three names, the
    projected TFXIO must keep the same representations, expose an Arrow schema
    containing exactly the two underlying columns, and produce a single record
    batch that its tensor adapter converts into all three tensors.
    """
    schema = schema_pb2.Schema()
    schema.CopyFrom(_SCHEMA)
    tensor_representations = {
        "dense_string": text_format.Parse(
            """dense_tensor { column_name: "string_feature" shape { dim { size: 2 } } default_value { bytes_value: "zzz" } }""",
            schema_pb2.TensorRepresentation()),
        "varlen_string": text_format.Parse(
            """varlen_sparse_tensor { column_name: "string_feature" }""",
            schema_pb2.TensorRepresentation()),
        "varlen_float": text_format.Parse(
            """varlen_sparse_tensor { column_name: "float_feature" }""",
            schema_pb2.TensorRepresentation()),
    }
    schema.tensor_representation_group[""].CopyFrom(
        schema_pb2.TensorRepresentationGroup(
            tensor_representation=tensor_representations))

    tfxio = self._MakeTFXIO(schema)
    self.assertEqual(tensor_representations, tfxio.TensorRepresentations())

    projected_tfxio = tfxio.Project(
        ["dense_string", "varlen_string", "varlen_float"])
    self.assertEqual(tensor_representations,
                     projected_tfxio.TensorRepresentations())

    expected_arrow_schema = pa.schema([
        pa.field("float_feature", pa.list_(pa.float32())),
        pa.field("string_feature", pa.list_(pa.binary())),
    ])
    actual_arrow_schema = projected_tfxio.ArrowSchema()
    # Schema.equals returns a bare bool, so put both schemas in the failure
    # message (as _AssertFn below does); otherwise a mismatch only reports
    # "False is not true".
    self.assertTrue(
        actual_arrow_schema.equals(expected_arrow_schema),
        "actual: {}; expected: {}".format(actual_arrow_schema,
                                          expected_arrow_schema))

    def _AssertFn(record_batch_list):
        # Checks the record batches produced by the Beam source.
        self.assertLen(record_batch_list, 1)
        record_batch = record_batch_list[0]
        self.ValidateRecordBatch(record_batch)
        expected_schema = projected_tfxio.ArrowSchema()
        self.assertTrue(
            record_batch.schema.equals(expected_schema),
            "actual: {}; expected: {}".format(record_batch.schema,
                                              expected_schema))
        tensor_adapter = projected_tfxio.TensorAdapter()
        dict_of_tensors = tensor_adapter.ToBatchTensors(record_batch)
        self.assertLen(dict_of_tensors, 3)
        self.assertIn("dense_string", dict_of_tensors)
        self.assertIn("varlen_string", dict_of_tensors)
        self.assertIn("varlen_float", dict_of_tensors)

    with beam.Pipeline() as p:
        # Setting the batch_size to make sure only one batch is generated.
        record_batch_pcoll = p | projected_tfxio.BeamSource(
            batch_size=len(_EXAMPLES))
        beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
def testExplicitTensorRepresentations(self):
    """Checks that tensor representations declared in the schema are reported.

    The schema's tensor representation group keyed by "" holds one dense
    representation, "my_feature", over column "string_feature". The TFXIO
    returned by `_MakeTFXIO` for that schema must report exactly this mapping.
    """
    expected = {
        "my_feature": text_format.Parse(
            ' dense_tensor { column_name: "string_feature" '
            'shape { dim { size: 2 } } '
            'default_value { bytes_value: "zzz" } }',
            schema_pb2.TensorRepresentation())
    }
    group = schema_pb2.TensorRepresentationGroup(tensor_representation=expected)

    schema = schema_pb2.Schema()
    schema.CopyFrom(_SCHEMA)
    schema.tensor_representation_group[""].CopyFrom(group)

    self.assertEqual(expected, self._MakeTFXIO(schema).TensorRepresentations())
def testExplicitTensorRepresentations(self):
    """Tests when the tensor representation is explicitly provided in the schema.

    A single dense representation named "my_feature" over column
    "string_feature" is stored in the schema's tensor representation group
    keyed by the empty string. The TFXIO built from `_COLUMN_NAMES` and that
    schema must report exactly this mapping from TensorRepresentations().
    """
    schema = schema_pb2.Schema()
    schema.CopyFrom(_SCHEMA)
    tensor_representations = {
        "my_feature": text_format.Parse(
            """ dense_tensor { column_name: "string_feature" shape { dim { size: 1 } } default_value { bytes_value: "abc" } }""",
            schema_pb2.TensorRepresentation())
    }
    schema.tensor_representation_group[""].CopyFrom(
        schema_pb2.TensorRepresentationGroup(
            tensor_representation=tensor_representations))

    tfxio = self._MakeTFXIO(_COLUMN_NAMES, schema=schema)
    self.assertEqual(tensor_representations, tfxio.TensorRepresentations())
def testExplicitTensorRepresentations(self):
    """Tests when the tensor representation is explicitly provided in the schema.

    A single dense representation named "my_feature" over column
    "string_feature" is stored in the schema's tensor representation group
    keyed by the empty string. A ParquetTFXIO over `self._example_file`
    constructed with that schema must report exactly this mapping from
    TensorRepresentations().
    """
    schema = schema_pb2.Schema()
    schema.CopyFrom(_SCHEMA)
    tensor_representations = {
        "my_feature": text_format.Parse(
            """ dense_tensor { column_name: "string_feature" shape { dim { size: 1 } } default_value { bytes_value: "abc" } }""",
            schema_pb2.TensorRepresentation())
    }
    schema.tensor_representation_group[""].CopyFrom(
        schema_pb2.TensorRepresentationGroup(
            tensor_representation=tensor_representations))

    tfxio = ParquetTFXIO(
        file_pattern=self._example_file,
        column_names=_COLUMN_NAMES,
        schema=schema,
        telemetry_descriptors=_TELEMETRY_DESCRIPTORS)
    self.assertEqual(tensor_representations, tfxio.TensorRepresentations())