def test_call_missing_input_lengths(self):
        """Calling a converter without a length for every task feature fails.

        TASK_FEATURE_DTYPES declares both "inputs" and "targets", but
        task_feature_lengths only provides "inputs", so calling the converter
        must raise a ValueError naming the missing "targets" feature.
        """
        x = [{"inputs": [9, 4, 3, 8, 6], "targets": [3, 9, 4, 5]}]
        ds = tf.data.Dataset.from_generator(lambda: x,
                                            output_types={
                                                "inputs": tf.int64,
                                                "targets": tf.int64
                                            },
                                            output_shapes={
                                                "inputs": [5],
                                                # Matches the 4-token targets
                                                # in the fixture above.
                                                "targets": [4]
                                            })
        task_feature_lengths = {"inputs": 5}

        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()):
            converter = feature_converters.FeatureConverter()
            # Patch rather than assign: a plain class-attribute assignment
            # persists after this test and leaks into every later test.
            # create=True keeps this working even if the class does not
            # define the attribute itself; the patch is undone on exit.
            with mock.patch.object(feature_converters.FeatureConverter,
                                   "TASK_FEATURE_DTYPES",
                                   {
                                       "inputs": tf.int64,
                                       "targets": tf.int64
                                   },
                                   create=True):
                expected_msg = (
                    "The task_feature_lengths is missing features specified "
                    "in the TASK_FEATURE_DTYPES: {'targets'}")
                with self.assertRaisesRegex(ValueError, expected_msg):
                    converter(ds, task_feature_lengths)
    def test_validate_dataset_incorrect_rank(self):
        """_validate_dataset rejects a feature whose rank != expected_rank.

        "inputs" is a rank-2 feature while validation is run with
        expected_rank=1, so a ValueError reporting "Got 2, expected 1" for
        'inputs' is expected.
        """
        x = [{"inputs": [[9, 4, 3, 8, 6]], "targets": [3, 9, 4, 5]}]
        ds = tf.data.Dataset.from_generator(lambda: x,
                                            output_types={
                                                "inputs": tf.int64,
                                                "targets": tf.int64
                                            },
                                            output_shapes={
                                                # Rank 2 and compatible with
                                                # the [1, 5] "inputs" element
                                                # above ([None, 1] was not).
                                                "inputs": [None, None],
                                                "targets": [None]
                                            })
        task_feature_lengths = {"inputs": 5, "targets": 4}

        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()):
            converter = feature_converters.FeatureConverter()
            expected_msg = (
                "Dataset has incorrect rank for feature 'inputs' during "
                "initial validation: Got 2, expected 1")
            with self.assertRaisesRegex(ValueError, expected_msg):
                converter._validate_dataset(
                    ds,
                    expected_features=task_feature_lengths.keys(),
                    expected_dtypes={
                        "inputs": tf.int64,
                        "targets": tf.int64
                    },
                    expected_lengths=task_feature_lengths,
                    strict=False,
                    expected_rank=1,
                    error_label="initial")
    def test_validate_dataset_missing_length(self):
        """_validate_dataset fails when a feature has no expected length.

        Both "inputs" and "targets" are expected features, but
        expected_lengths only covers "inputs", so a ValueError about the
        missing 'targets' length during final validation is expected.
        """
        x = [{"inputs": [9, 4, 3, 8, 6], "targets": [3, 9, 4, 5]}]
        ds = tf.data.Dataset.from_generator(lambda: x,
                                            output_types={
                                                "inputs": tf.int64,
                                                "targets": tf.int64
                                            },
                                            output_shapes={
                                                "inputs": [5],
                                                # Matches the 4-token targets
                                                # in the fixture above.
                                                "targets": [4]
                                            })
        input_lengths = {"inputs": 5}

        with mock.patch.object(feature_converters.FeatureConverter,
                               "__abstractmethods__", set()):
            converter = feature_converters.FeatureConverter()
            expected_msg = ("Sequence length for feature 'targets' is missing "
                            "during final validation")
            with self.assertRaisesRegex(ValueError, expected_msg):
                converter._validate_dataset(
                    ds,
                    expected_features=["inputs", "targets"],
                    expected_dtypes={
                        "inputs": tf.int64,
                        "targets": tf.int64
                    },
                    expected_lengths=input_lengths,
                    strict=True,
                    error_label="final")
# Example #4
# 0
  def test_validate_dataset_missing_feature(self):
    """_validate_dataset rejects a dataset lacking an expected feature.

    The dataset only carries "targets" while both "inputs" and "targets"
    are expected, so initial validation must report the absent 'inputs'.
    """
    # Build a dataset that deliberately omits "inputs".
    ds = create_default_dataset([{"targets": [3, 9, 4, 5]}],
                                feature_names=["targets"])
    task_feature_lengths = {"inputs": 4, "targets": 4}
    dtypes_by_feature = {"inputs": tf.int32, "targets": tf.int32}
    missing_feature_msg = ("Dataset is missing an expected feature during "
                           "initial validation: 'inputs'")

    with mock.patch.object(feature_converters.FeatureConverter,
                           "__abstractmethods__", set()):
      converter = feature_converters.FeatureConverter()
      with self.assertRaisesRegex(ValueError, missing_feature_msg):
        converter._validate_dataset(
            ds,
            expected_features=task_feature_lengths.keys(),
            expected_dtypes=dtypes_by_feature,
            expected_lengths=task_feature_lengths,
            strict=False,
            error_label="initial")
# Example #5
# 0
  def test_validate_dataset_plaintext_field(self):
    """_validate_dataset accepts a dataset with an extra string feature.

    Besides the expected "targets" feature, the dataset carries a scalar
    tf.string "targets_plaintext". Validation must not raise, and the
    validated dataset must still yield the original targets.
    """
    x = [{"targets": [3, 9, 4, 5], "targets_plaintext": "some text"}]
    output_types = {"targets": tf.int64, "targets_plaintext": tf.string}
    output_shapes = {"targets": [4], "targets_plaintext": []}
    ds = tf.data.Dataset.from_generator(
        lambda: x, output_types=output_types, output_shapes=output_shapes)

    task_feature_lengths = {"targets": 4}
    with mock.patch.object(feature_converters.FeatureConverter,
                           "__abstractmethods__", set()):
      converter = feature_converters.FeatureConverter()
      # _validate_dataset works even if ds has targets and targets_plaintext
      ds = converter._validate_dataset(
          ds,
          expected_features=task_feature_lengths.keys(),
          expected_dtypes={"targets": tf.int64},
          expected_lengths=task_feature_lengths,
          strict=True,
          error_label="initial")
      # tf.data is lazy: any per-example checks that _validate_dataset
      # attaches to the returned dataset (presumably the strict length
      # check) only run when elements are produced, so iterate to force them.
      validated = list(ds.as_numpy_iterator())

    self.assertEqual(len(validated), 1)
    self.assertEqual(validated[0]["targets"].tolist(), [3, 9, 4, 5])