def test_validate_dataset_incorrect_rank(self):
        """A rank-2 "inputs" feature must be rejected when rank 1 is expected."""
        # "inputs" is deliberately nested one level too deep (rank 2).
        examples = [{"inputs": [[9, 4, 3, 8, 6]], "targets": [3, 9, 4, 5]}]
        feature_dtypes = {"inputs": tf.int64, "targets": tf.int64}
        feature_shapes = {"inputs": [None, 1], "targets": [None]}
        ds = tf.data.Dataset.from_generator(
            lambda: examples,
            output_types=feature_dtypes,
            output_shapes=feature_shapes)
        task_feature_lengths = {"inputs": 5, "targets": 4}

        # Clear the abstract-method set so the ABC can be instantiated directly.
        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()):
            converter = feature_converters.FeatureConverter()
            expected_msg = (
                "Dataset has incorrect rank for feature 'inputs' during "
                "initial validation: Got 2, expected 1")
            with self.assertRaisesRegex(ValueError, expected_msg):
                converter._validate_dataset(
                    ds,
                    expected_features=task_feature_lengths.keys(),
                    expected_dtypes=feature_dtypes,
                    expected_lengths=task_feature_lengths,
                    strict=False,
                    expected_rank=1,
                    error_label="initial")
    def test_call_missing_input_lengths(self):
        """Calling the converter must fail when task_feature_lengths omits a
        feature listed in TASK_FEATURE_DTYPES ("targets" here)."""
        x = [{"inputs": [9, 4, 3, 8, 6], "targets": [3, 9, 4, 5]}]
        ds = tf.data.Dataset.from_generator(lambda: x,
                                            output_types={
                                                "inputs": tf.int64,
                                                "targets": tf.int64
                                            },
                                            output_shapes={
                                                "inputs": [5],
                                                "targets": [5]
                                            })
        # "targets" is intentionally missing from the lengths mapping.
        task_feature_lengths = {"inputs": 5}

        # Patch (rather than assign) TASK_FEATURE_DTYPES so the class attribute
        # is restored on exit and does not leak into other tests.
        # create=True allows the attribute to be absent on the base class.
        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()), \
             mock.patch.object(feature_converters.FeatureConverter,
                               "TASK_FEATURE_DTYPES",
                               {"inputs": tf.int64, "targets": tf.int64},
                               create=True):
            converter = feature_converters.FeatureConverter()
            expected_msg = (
                "The task_feature_lengths is missing features specified "
                "in the TASK_FEATURE_DTYPES: {'targets'}")
            with self.assertRaisesRegex(ValueError, expected_msg):
                converter(ds, task_feature_lengths)
    def test_validate_dataset_pretokenized_field(self):
        """Extra "*_pretokenized" features must pass validation even in strict
        mode when only "targets" is listed in expected_features."""
        examples = [{"targets": [3, 9, 4, 5], "targets_pretokenized": "some text"}]
        ds = tf.data.Dataset.from_generator(
            lambda: examples,
            output_types={"targets": tf.int64, "targets_pretokenized": tf.string},
            output_shapes={"targets": [4], "targets_pretokenized": []})
        task_feature_lengths = {"targets": 4}

        # Clear the abstract-method set so the ABC can be instantiated directly.
        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()):
            converter = feature_converters.FeatureConverter()
            # Should not raise despite the extra targets_pretokenized feature.
            ds = converter._validate_dataset(
                ds,
                expected_features=task_feature_lengths.keys(),
                expected_dtypes={"targets": tf.int64},
                expected_lengths=task_feature_lengths,
                strict=True,
                error_label="initial")
    def test_validate_dataset_missing_feature(self):
        """Validation must raise when an expected feature ("inputs") is absent
        from the dataset."""
        examples = [{"targets": [3, 9, 4, 5]}]
        ds = create_default_dataset(examples, feature_names=["targets"])
        task_feature_lengths = {"inputs": 4, "targets": 4}
        expected_dtypes = {"inputs": tf.int32, "targets": tf.int32}

        # Clear the abstract-method set so the ABC can be instantiated directly.
        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()):
            converter = feature_converters.FeatureConverter()
            expected_msg = ("Dataset is missing an expected feature during "
                            "initial validation: 'inputs'")
            with self.assertRaisesRegex(ValueError, expected_msg):
                converter._validate_dataset(
                    ds,
                    expected_features=task_feature_lengths.keys(),
                    expected_dtypes=expected_dtypes,
                    expected_lengths=task_feature_lengths,
                    strict=False,
                    error_label="initial")