Ejemplo n.º 1
0
    def assertFeatureTest(self, fdict, test, feature, shape, dtype):
        """Run a single encode/decode test case against `feature`.

        When `test.raise_cls` is set, the round trip is expected to raise that
        exception with a message matching `test.raise_msg`. Otherwise the
        serialized form (if `test.expected_serialized` is given), the decoded
        dtype, the decoded shape and the decoded value are checked in turn.

        Args:
          fdict: Feature container handed to `features_encode_decode`.
          test: Test case; `value`, `expected`, `expected_serialized`,
            `raise_cls` and `raise_msg` are read from it.
          feature: Feature under test; provides `encode_example`, `dtype` and
            `shape`.
          shape: Not used by this method.
          dtype: Not used by this method.

        Raises:
          ValueError: If `test.raise_cls` is set but `test.raise_msg` is empty.
        """
        example = {"inner": test.value}

        # Failure case: only the raised exception is checked.
        if test.raise_cls is not None:
            with self._subTest("raise"):
                if not test.raise_msg:
                    raise ValueError(
                        "test.raise_msg should be set with {} for test {}".
                        format(test.raise_cls, type(feature)))
                with self.assertRaisesWithPredicateMatch(
                        test.raise_cls, test.raise_msg):
                    features_encode_decode(fdict, example)
            return

        # Serialization alone, without the decoding step.
        if test.expected_serialized is not None:
            with self._subTest("out_serialize"):
                self.assertEqual(
                    test.expected_serialized,
                    feature.encode_example(test.value),
                )

        def _dtype_of(item):
            return item.dtype

        def _check_shape(pair):
            # pair is (decoded item, expected shape).
            return pair[0].shape.assert_is_compatible_with(pair[1])

        # Types and shapes of the decoded output, requested as tensors.
        with self._subTest("out"):
            decoded = features_encode_decode(fdict, example, as_tensor=True)
            decoded = decoded["inner"]
            with self._subTest("dtype"):
                self.assertEqual(utils.map_nested(_dtype_of, decoded),
                                 feature.dtype)
            with self._subTest("shape"):
                # A shape such as (None, 3) has to match (5, 3), so each
                # element is checked for compatibility rather than equality.
                shape_pairs = utils.zip_nested(decoded, feature.shape)
                utils.map_nested(_check_shape, shape_pairs)

        # Full round trip: serialization followed by decoding.
        with self._subTest("out_value"):
            actual = features_encode_decode(fdict, example)["inner"]
            if not isinstance(actual, dict):
                self.assertAllEqual(test.expected, actual)
            else:
                # assertAllEqual does not handle dictionaries well, so the
                # leaves are compared one by one.
                def _check_value(pair):
                    return self.assertAllEqual(pair[0], pair[1])

                value_pairs = utils.zip_nested(
                    test.expected,
                    actual,
                    dict_only=True,
                )
                utils.map_nested(_check_value, value_pairs, dict_only=True)
Ejemplo n.º 2
0
    def assertAllEqualNested(self, d1, d2, *, atol: Optional[float] = None):
        """Assert that `d1` and `d2` are equal, descending into nested dicts.

        Dicts are compared leaf by leaf and datasets element by element; any
        other value is compared directly.

        Args:
          d1: First element to compare.
          d2: Second element to compare.
          atol: Absolute tolerance. When truthy, leaves are compared with
            `assertAllClose`; when `None` or `0`, `assertAllEqual` performs an
            exact comparison.
        """
        if isinstance(d1, dict):
            # assertAllEqual does not handle dictionaries well, so every leaf
            # is compared on its own. The recursion covers leaves that are
            # themselves datasets.
            def _compare_pair(pair):
                return self.assertAllEqualNested(pair[0], pair[1], atol=atol)

            paired = utils.zip_nested(d1, d2, dict_only=True)
            utils.map_nested(_compare_pair, paired, dict_only=True)
            return

        if isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
            # Both the length and the elements must agree. More than one level
            # of nested datasets is not supported at the moment.
            self.assertEqual(len(d1), len(d2))
            for left, right in zip(d1, d2):
                self.assertAllEqualNested(left, right, atol=atol)
            return

        if atol:
            self.assertAllClose(d1, d2, atol=atol)
        else:
            self.assertAllEqual(d1, d2)
Ejemplo n.º 3
0
    def assertFeatureTest(self, fdict, test, feature, shape, dtype):
        """Test that encode=>decoding of a value works correctly.

        When `test.raise_cls` is set, the round trip is expected to raise that
        exception with a message matching `test.raise_msg`. Otherwise the
        serialized form (if `test.expected_serialized` is given), the decoded
        dtype, the decoded shape and the decoded value are checked.

        Args:
          fdict: Feature container handed to `features_encode_decode`.
          test: Test case; `value`, `expected`, `expected_serialized`,
            `decoders`, `dtype`, `shape`, `raise_cls` and `raise_msg` are read
            from it.
          feature: Feature under test; provides `encode_example` and the
            fallback `dtype` / `shape`.
          shape: Not used by this method (`test.shape` / `feature.shape` are
            used instead).
          dtype: Not used by this method (`test.dtype` / `feature.dtype` are
            used instead).

        Raises:
          ValueError: If `test.raise_cls` is set but `test.raise_msg` is empty.
        """
        # test feature.encode_example can be pickled and unpickled for beam.
        dill.loads(dill.dumps(feature.encode_example))

        input_value = {'inner': test.value}
        # The decoders are wrapped under the same 'inner' key as the input so
        # that both branches below hand the same structure to
        # `features_encode_decode`.
        decoders = {'inner': test.decoders}

        if test.raise_cls is not None:
            with self._subTest('raise'):
                if not test.raise_msg:
                    raise ValueError(
                        'test.raise_msg should be set with {} for test {}'.
                        format(test.raise_cls, type(feature)))
                with self.assertRaisesWithPredicateMatch(
                        test.raise_cls, test.raise_msg):
                    features_encode_decode(fdict,
                                           input_value,
                                           decoders=decoders)
        else:
            # Test the serialization only
            if test.expected_serialized is not None:
                with self._subTest('out_serialize'):
                    self.assertEqual(
                        test.expected_serialized,
                        feature.encode_example(test.value),
                    )

            # Test serialization + decoding from disk
            with self._subTest('out'):
                out_tensor, out_numpy = features_encode_decode(
                    fdict,
                    input_value,
                    decoders=decoders,
                )
                out_tensor = out_tensor['inner']
                out_numpy = out_numpy['inner']

                # Assert the returned type match the expected one
                with self._subTest('dtype'):
                    out_dtypes = tf.nest.map_structure(lambda s: s.dtype,
                                                       out_tensor)
                    # A falsy `test.dtype` falls back to the feature's dtype.
                    self.assertEqual(out_dtypes, test.dtype or feature.dtype)
                with self._subTest('shape'):
                    # For shape, because (None, 3) match with (5, 3), we use
                    # tf.TensorShape.assert_is_compatible_with on each of the
                    # elements
                    expected_shape = feature.shape if test.shape is None else test.shape
                    out_shapes = utils.zip_nested(out_tensor, expected_shape)
                    utils.map_nested(
                        lambda x: x[0].shape.assert_is_compatible_with(x[1]),
                        out_shapes)

                # Assert value
                with self._subTest('out_value'):
                    # Eventually construct the tf.RaggedTensor
                    expected = tf.nest.map_structure(
                        lambda t: t.build()
                        if isinstance(t, RaggedConstant) else t, test.expected)
                    self.assertAllEqualNested(out_numpy, expected)
Ejemplo n.º 4
0
 def assertAllEqualNested(self, d1, d2):
     """Assert that `d1` equals `d2`, comparing dicts leaf by leaf.

     Args:
       d1: First element to compare; its type selects the comparison path.
       d2: Second element to compare.
     """
     if not isinstance(d1, dict):
         self.assertAllEqual(d1, d2)
         return

     # assertAllEqual does not handle dictionaries well, so each pair of
     # leaves is asserted individually.
     def _assert_pair_equal(pair):
         return self.assertAllEqual(pair[0], pair[1])

     paired = utils.zip_nested(d1, d2, dict_only=True)
     utils.map_nested(_assert_pair_equal, paired, dict_only=True)
Ejemplo n.º 5
0
    def assertFeature(self, specs, serialized_info, tests):
        """Check the TFRecordExampleAdapter encoding against `tests`.

        An adapter is built from `specs`. Its feature specs are compared with
        `serialized_info`, then every test case is serialized, parsed back and
        checked for dtype, shape and value.

        Args:
          specs: Specs used to build the adapter; each leaf exposes `dtype`
            and `shape`.
          serialized_info: Expected result of the parser's
            `_build_feature_specs()`.
          tests: Iterable of test cases; `value`, `expected`,
            `expected_serialized`, `raise_cls` and `raise_msg` are read from
            each of them.
        """
        adapter = file_format_adapter.TFRecordExampleAdapter(specs)

        with self._subTest("serialized_info"):
            self.assertEqual(serialized_info,
                             adapter._parser._build_feature_specs())

        def _dtype_of(item):
            return item.dtype

        def _check_shape(pair):
            # pair is (parsed item, matching spec).
            return pair[0].shape.assert_is_compatible_with(pair[1].shape)

        for index, test in enumerate(tests):
            with self._subTest(str(index)):

                # Failure case: serialization itself has to raise.
                if test.raise_cls is not None:
                    with self.assertRaisesWithPredicateMatch(
                            test.raise_cls, test.raise_msg):
                        adapter._serializer.serialize_example(test.value)
                    continue

                payload = adapter._serializer.serialize_example(test.value)

                # Compare the serialized bytes with the expected proto.
                if test.expected_serialized is not None:
                    actual_proto = tf.train.Example()
                    actual_proto.ParseFromString(payload)
                    wanted_features = tf.train.Features(
                        feature=test.expected_serialized)
                    wanted_proto = tf.train.Example(features=wanted_features)
                    self.assertEqual(wanted_proto, actual_proto)

                parsed = _parse_example(payload,
                                        adapter._parser.parse_example)

                with self._subTest("dtype"):
                    parsed_dtypes = utils.map_nested(_dtype_of, parsed)
                    spec_dtypes = utils.map_nested(_dtype_of, specs)
                    self.assertEqual(parsed_dtypes, spec_dtypes)
                with self._subTest("shape"):
                    # A shape such as (None, 3) has to match (5, 3), so each
                    # element is checked for compatibility rather than
                    # equality.
                    utils.map_nested(_check_shape,
                                     utils.zip_nested(parsed, specs))

                self.assertAllEqualNested(dataset_utils.as_numpy(parsed),
                                          test.expected)
Ejemplo n.º 6
0
 def assertAllEqualNested(self, d1, d2):
     """Assert that `d1` equals `d2`, descending into dicts and datasets.

     Args:
       d1: First element to compare; its type selects the comparison path.
       d2: Second element to compare.
     """
     if isinstance(d1, dict):
         # assertAllEqual does not handle dictionaries well, so every leaf is
         # compared on its own. The recursion covers leaves that are
         # themselves datasets.
         def _compare_pair(pair):
             return self.assertAllEqualNested(pair[0], pair[1])

         paired = utils.zip_nested(d1, d2, dict_only=True)
         utils.map_nested(_compare_pair, paired, dict_only=True)
         return

     if isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
         # Both the length and the elements must agree. More than one level
         # of nested datasets is not supported at the moment.
         self.assertEqual(len(d1), len(d2))
         for left, right in zip(d1, d2):
             self.assertAllEqualNested(left, right)
         return

     self.assertAllEqual(d1, d2)
Ejemplo n.º 7
0
    def _process_exp(self, exp):
        """Run every check described by one feature expectation.

        The feature's static shape/dtype and (optionally) its serialized
        features are checked first. Each entry of `exp.tests` is then sent
        through `features_encode_decode` inside its own subtest.

        Args:
          exp: Expectation object; `feature`, `shape`, `dtype`,
            `serialized_features`, `name` and `tests` are read from it. Each
            test provides `value`, `expected`, `expected_serialized`,
            `raise_cls` and `raise_msg`.

        Raises:
          ValueError: If a test sets `raise_cls` but leaves `raise_msg` empty.
        """

        # Check the shape/dtype
        with self._subTest("shape"):
            self.assertEqual(exp.feature.shape, exp.shape)
        with self._subTest("dtype"):
            self.assertEqual(exp.feature.dtype, exp.dtype)

        # Check the serialized features
        if exp.serialized_features is not None:
            with self._subTest("serialized_features"):
                self.assertEqual(
                    exp.serialized_features,
                    exp.feature.get_serialized_features(),
                )

        # Create the feature dict
        fdict = features.FeaturesDict({exp.name: exp.feature})
        for i, test in enumerate(exp.tests):
            with self._subTest(str(i)):
                input_value = {exp.name: test.value}

                if test.raise_cls is not None:
                    with self._subTest("raise"):
                        if not test.raise_msg:
                            raise ValueError(
                                "test.raise_msg should be set with {} for test {}"
                                .format(test.raise_cls, exp.name))
                        with self.assertRaisesWithPredicateMatch(
                                test.raise_cls, test.raise_msg):
                            features_encode_decode(fdict, input_value)
                else:
                    # Test the serialization only
                    if test.expected_serialized is not None:
                        with self._subTest("out_serialize"):
                            self.assertEqual(
                                test.expected_serialized,
                                exp.feature.encode_sample(test.value),
                            )

                    # Assert the returned type match the expected one
                    with self._subTest("out_extract"):
                        out = features_encode_decode(fdict,
                                                     input_value,
                                                     as_tensor=True)
                        out = out[exp.name]
                    with self._subTest("out_dtype"):
                        out_dtypes = utils.map_nested(lambda s: s.dtype, out)
                        self.assertEqual(out_dtypes, exp.feature.dtype)
                    with self._subTest("out_shape"):
                        # For shape, because (None, 3) match with (5, 3), we
                        # use tf.TensorShape.assert_is_compatible_with on each
                        # of the elements
                        out_shapes = utils.zip_nested(out, exp.feature.shape)
                        utils.map_nested(
                            lambda x: x[0].shape.assert_is_compatible_with(x[
                                1]), out_shapes)

                    # Test serialization + decoding from disk
                    with self._subTest("out_value"):
                        decoded_samples = features_encode_decode(
                            fdict, input_value)
                        self.assertAllEqual(test.expected,
                                            decoded_samples[exp.name])