コード例 #1
0
    def testProjection(self):
        """Projects a TFXIO onto its three tensor representations and reads it.

        The schema is given three tensor representations: a dense and a
        varlen view of "string_feature", and a varlen view of
        "float_feature". The test checks that:
          * the original and the projected TFXIO both report exactly those
            representations;
          * the projected Arrow schema holds only the two underlying columns;
          * a single batch read through the projected Beam source validates,
            matches that Arrow schema, and converts to one tensor per
            projected representation name.
        """
        projected_names = ("dense_string", "varlen_string", "varlen_float")

        # Text-format TensorRepresentation protos, keyed by representation name.
        text_protos = {
            "dense_string": """dense_tensor {
             column_name: "string_feature"
             shape { dim { size: 2 } }
             default_value { bytes_value: "zzz" }
           }""",
            "varlen_string": """varlen_sparse_tensor {
             column_name: "string_feature"
           }""",
            "varlen_float": """varlen_sparse_tensor {
             column_name: "float_feature"
           }""",
        }
        tensor_representations = {
            name: text_format.Parse(text, schema_pb2.TensorRepresentation())
            for name, text in text_protos.items()
        }

        schema = schema_pb2.Schema()
        schema.CopyFrom(_SCHEMA)
        group = schema.tensor_representation_group[""]
        group.CopyFrom(
            schema_pb2.TensorRepresentationGroup(
                tensor_representation=tensor_representations))

        tfxio = self._MakeTFXIO(schema)
        self.assertEqual(tensor_representations, tfxio.TensorRepresentations())

        projected_tfxio = tfxio.Project(list(projected_names))
        self.assertEqual(tensor_representations,
                         projected_tfxio.TensorRepresentations())

        # Only the columns referenced by the projected representations remain.
        expected_arrow_schema = pa.schema([
            pa.field("float_feature", pa.list_(pa.float32())),
            pa.field("string_feature", pa.list_(pa.binary())),
        ])
        self.assertTrue(
            projected_tfxio.ArrowSchema().equals(expected_arrow_schema))

        def _AssertFn(batches):
            """Checks the single RecordBatch produced by the Beam source."""
            self.assertLen(batches, 1)
            batch = batches[0]
            self.ValidateRecordBatch(batch)
            arrow_schema = projected_tfxio.ArrowSchema()
            self.assertTrue(
                batch.schema.equals(arrow_schema),
                "actual: {}; expected: {}".format(batch.schema, arrow_schema))
            tensors = projected_tfxio.TensorAdapter().ToBatchTensors(batch)
            self.assertLen(tensors, len(projected_names))
            for name in projected_names:
                self.assertIn(name, tensors)

        with beam.Pipeline() as pipeline:
            # batch_size covers every example, so only one batch is generated.
            source = projected_tfxio.BeamSource(batch_size=len(_EXAMPLES))
            record_batches = pipeline | source
            beam_testing_util.assert_that(record_batches, _AssertFn)
コード例 #2
0
    def testExplicitTensorRepresentations(self):
        """Checks that schema-provided tensor representations are used as-is.

        A single dense representation named "my_feature" over the
        "string_feature" column is written into the schema's "" tensor
        representation group. The TFXIO built from that schema must report
        exactly that mapping from TensorRepresentations().
        """
        my_feature = text_format.Parse(
            """
            dense_tensor {
             column_name: "string_feature"
             shape { dim { size: 2 } }
             default_value { bytes_value: "zzz" }
           }""", schema_pb2.TensorRepresentation())
        expected = {"my_feature": my_feature}

        schema = schema_pb2.Schema()
        schema.CopyFrom(_SCHEMA)
        group = schema.tensor_representation_group[""]
        group.CopyFrom(
            schema_pb2.TensorRepresentationGroup(tensor_representation=expected))

        tfxio = self._MakeTFXIO(schema)
        actual = tfxio.TensorRepresentations()
        self.assertEqual(expected, actual)
コード例 #3
0
    def testExplicitTensorRepresentations(self):
        """Checks that tensor representations explicitly given in the schema win.

        The schema's "" tensor representation group is populated with one
        dense representation, "my_feature", over the "string_feature" column.
        A TFXIO built over _COLUMN_NAMES with that schema must report exactly
        this mapping from TensorRepresentations().
        """
        my_feature = text_format.Parse(
            """
            dense_tensor {
             column_name: "string_feature"
             shape { dim { size: 1 } }
             default_value { bytes_value: "abc" }
           }""", schema_pb2.TensorRepresentation())
        expected = {"my_feature": my_feature}

        schema = schema_pb2.Schema()
        schema.CopyFrom(_SCHEMA)
        group = schema.tensor_representation_group[""]
        group.CopyFrom(
            schema_pb2.TensorRepresentationGroup(tensor_representation=expected))

        tfxio = self._MakeTFXIO(_COLUMN_NAMES, schema=schema)
        actual = tfxio.TensorRepresentations()
        self.assertEqual(expected, actual)
コード例 #4
0
    def testExplicitTensorRepresentations(self):
        """Checks that ParquetTFXIO keeps tensor representations from the schema.

        One dense representation, "my_feature", over the "string_feature"
        column is placed in the schema's "" tensor representation group. A
        ParquetTFXIO reading self._example_file with that schema must report
        exactly this mapping from TensorRepresentations().
        """
        my_feature = text_format.Parse(
            """
          dense_tensor {
           column_name: "string_feature"
           shape { dim { size: 1 } }
           default_value { bytes_value: "abc" }
         }""", schema_pb2.TensorRepresentation())
        expected = {"my_feature": my_feature}

        schema = schema_pb2.Schema()
        schema.CopyFrom(_SCHEMA)
        group = schema.tensor_representation_group[""]
        group.CopyFrom(
            schema_pb2.TensorRepresentationGroup(tensor_representation=expected))

        parquet_tfxio = ParquetTFXIO(
            file_pattern=self._example_file,
            column_names=_COLUMN_NAMES,
            schema=schema,
            telemetry_descriptors=_TELEMETRY_DESCRIPTORS)
        actual = parquet_tfxio.TensorRepresentations()
        self.assertEqual(expected, actual)