Exemplo n.º 1
0
 def test_check_lengths_strict_no_exception(self):
   """Strict length checking succeeds when every feature has its exact length.

   "inputs" has 5 tokens and "targets" has 4, matching the requested lengths,
   so draining the checked dataset must not raise.
   """
   examples = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5]}]
   checked = feature_converters._check_lengths(
       create_default_dataset(examples),
       {"inputs": 5, "targets": 4},
       strict=True,
       error_label="initial")
   # Consume every element so that any per-element validation is executed.
   list(checked.as_numpy_iterator())
Exemplo n.º 2
0
 def test_check_lengths_extra_features(self):
   """A feature absent from task_feature_lengths must not trigger an error.

   The dataset carries a "targets_plaintext" string feature that has no entry
   in the length spec; strict checking of "targets" alone must still pass.
   """
   examples = [{"targets": [3, 9, 4, 5], "targets_plaintext": "some text"}]
   dataset = tf.data.Dataset.from_generator(
       lambda: examples,
       output_types={"targets": tf.int64, "targets_plaintext": tf.string},
       output_shapes={"targets": [4], "targets_plaintext": []})
   checked = feature_converters._check_lengths(
       dataset, {"targets": 4}, strict=True, error_label="initial")
   # Consume every element so that any per-element validation is executed.
   list(checked.as_numpy_iterator())
Exemplo n.º 3
0
 def test_check_lengths_strict_exception(self):
   """Strict checking fails when a feature's length differs from the spec.

   "inputs" has 5 tokens but 7 are requested, so consuming the checked
   dataset must raise InvalidArgumentError naming the offending feature.
   """
   examples = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5]}]
   dataset = create_default_dataset(examples)
   expected_lengths = {"inputs": 7, "targets": 4}
   msg_pattern = (
       r".*Feature \\'inputs\\' has length not equal to the expected length of"
       r" 7 during initial validation.*")
   # Both building and draining stay inside the context: the error may surface
   # either when the check is attached or when elements are produced.
   with self.assertRaisesRegex(tf.errors.InvalidArgumentError, msg_pattern):
     checked = feature_converters._check_lengths(
         dataset, expected_lengths, strict=True, error_label="initial")
     list(checked.as_numpy_iterator())