def test_check_lengths_strict_no_exception(self):
  # Feature lengths exactly match the expected lengths, so strict
  # validation should pass without raising.
  x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 5, "targets": 4}
  ds = feature_converters._check_lengths(
      ds, task_feature_lengths, strict=True, error_label="initial")
  list(ds.as_numpy_iterator())

def test_check_lengths_extra_features(self):
  # Features absent from task_feature_lengths (here "targets_pretokenized")
  # should be ignored by the length check.
  x = [{"targets": [3, 9, 4, 5], "targets_pretokenized": "some text"}]
  output_types = {"targets": tf.int64, "targets_pretokenized": tf.string}
  output_shapes = {"targets": [4], "targets_pretokenized": []}
  ds = tf.data.Dataset.from_generator(
      lambda: x, output_types=output_types, output_shapes=output_shapes)
  task_feature_lengths = {"targets": 4}
  ds = feature_converters._check_lengths(
      ds, task_feature_lengths, strict=True, error_label="initial")
  list(ds.as_numpy_iterator())

def test_check_lengths_strict_exception(self):
  # "inputs" has length 5 but 7 is expected, so strict validation should
  # raise InvalidArgumentError when the dataset is iterated.
  x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 7, "targets": 4}
  expected_msg = (
      r".*Feature \\'inputs\\' has length not equal to the expected length of"
      r" 7 during initial validation.*")
  with self.assertRaisesRegex(tf.errors.InvalidArgumentError, expected_msg):
    ds = feature_converters._check_lengths(
        ds, task_feature_lengths, strict=True, error_label="initial")
    list(ds.as_numpy_iterator())