  def test_check_lengths_strict_no_exception(self):
    # All features have exactly the expected lengths, so the strict check
    # should pass when the dataset is materialized.
    x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5]}]
    ds = create_default_dataset(x)
    task_feature_lengths = {"inputs": 5, "targets": 4}
    ds = feature_converters._check_lengths(
        ds, task_feature_lengths, strict=True, error_label="initial")
    list(ds.as_numpy_iterator())
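
  # `create_default_dataset` is a helper defined elsewhere in this test module.
  # Below is a minimal sketch of what it is assumed to do (a hypothetical
  # reimplementation, mirroring the `from_generator` pattern used in the next
  # test): build a dataset of int64 features from a list of feature dicts.
  def _sketch_create_default_dataset(self, x):
    output_types = {name: tf.int64 for name in x[0]}
    output_shapes = {name: [len(values)] for name, values in x[0].items()}
    return tf.data.Dataset.from_generator(
        lambda: x, output_types=output_types, output_shapes=output_shapes)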

  def test_check_lengths_extra_features(self):
    # Features not listed in task_feature_lengths ("targets_pretokenized" here)
    # should be ignored by the length check.
    x = [{"targets": [3, 9, 4, 5], "targets_pretokenized": "some text"}]
    output_types = {"targets": tf.int64, "targets_pretokenized": tf.string}
    output_shapes = {"targets": [4], "targets_pretokenized": []}
    ds = tf.data.Dataset.from_generator(
        lambda: x, output_types=output_types, output_shapes=output_shapes)
    task_feature_lengths = {"targets": 4}
    ds = feature_converters._check_lengths(
        ds, task_feature_lengths, strict=True, error_label="initial")
    list(ds.as_numpy_iterator())

  def test_check_lengths_strict_exception(self):
    # "inputs" has length 5 but 7 is expected, so with strict=True iterating
    # the dataset should raise an InvalidArgumentError.
    x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5]}]
    ds = create_default_dataset(x)
    task_feature_lengths = {"inputs": 7, "targets": 4}
    expected_msg = (
        r".*Feature \\'inputs\\' has length not equal to the expected length of"
        r" 7 during initial validation.*")
    with self.assertRaisesRegex(tf.errors.InvalidArgumentError, expected_msg):
      ds = feature_converters._check_lengths(
          ds, task_feature_lengths, strict=True, error_label="initial")
      list(ds.as_numpy_iterator())
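
  # A minimal sketch (not the seqio implementation) of how a strict length
  # check like `_check_lengths` could be wired: a `Dataset.map` whose body runs
  # `tf.debugging.assert_equal`, so a mismatch only surfaces as a
  # `tf.errors.InvalidArgumentError` once the dataset is iterated, which is why
  # the tests above materialize the dataset with `list(ds.as_numpy_iterator())`.
  # Helper and argument names here are hypothetical.
  def _sketch_strict_length_check(self, ds, expected_lengths, error_label):
    def _assert_lengths(features):
      for name, expected in expected_lengths.items():
        tf.debugging.assert_equal(
            tf.size(features[name]), expected,
            message=(f"Feature '{name}' has length not equal to the expected "
                     f"length of {expected} during {error_label} validation"))
      return features
    return ds.map(_assert_lengths)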