Exemplo n.º 1
0
    def validate_fully_cached_task(self, name, sequence_length,
                                   expected_dataset):
        """Checks the structure and loading rules of a fully-cached task.

        Args:
          name: registry name of the fully-cached task under test.
          sequence_length: mapping of feature name to the length the task was
            cached with; a copy with every length bumped by one is used as
            the "wrong length" probe.
          expected_dataset: examples the task must yield when loaded with
            either `None` or `sequence_length`.
        """
        task = TaskRegistry.get(name)

        # Structural expectations: six preprocessors, the cache placeholder
        # sitting second-from-last and marked as required.
        self.assertLen(task.preprocessors, 6)
        self.assertEqual(task.metric_fns, self.metrics_fns)
        cache_step = task.preprocessors[-2]
        self.assertIsInstance(cache_step, CacheDatasetPlaceholder)
        self.assertTrue(cache_step.required)

        # While caching is required, an uncached load must be rejected.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                f"Task '{name}' requires caching, but was called with "
                "`use_cached=False`."):
            task.get_dataset(None)

        # Lift the caching restriction so the dataset contents can be checked.
        cache_step._required = False

        # Any length other than the cached one (or None) must be refused.
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                f"Fully-cached task '{name}' can only be loaded with "
                f'`sequence_length={sequence_length}` or `None`.'):
            bumped_lengths = {
                feature: length + 1
                for feature, length in sequence_length.items()
            }
            task.get_dataset(bumped_lengths, use_cached=False)

        # Loading with no length and with the exact cached length must agree.
        for lengths in (None, sequence_length):
            test_utils.assert_dataset(
                task.get_dataset(lengths, shuffle=False), expected_dataset)
Exemplo n.º 2
0
  def test_get_targets_and_examples(self):
    """Checks targets, examples and max lengths gathered across two tasks."""

    def _task_from_tensor_slices(name, tensor_slices, label_classes):
      # Builds a validation-only Task whose "inputs"/"targets" are
      # tf.range(...) of the given lengths, so expected values are simple.
      return dataset_providers.Task(
          name,
          dataset_providers.FunctionDataSource(
              lambda split, shuffle_files:
              tf.data.Dataset.from_tensor_slices(tensor_slices),
              # Trailing comma matters: ("validation") is just a str, and
              # membership checks on it would do substring matching.
              splits=("validation",)),
          preprocessors=[utils.map_over_dataset(lambda ex: {
              "inputs": tf.range(ex["inputs_lengths"]),
              "targets": tf.range(ex["targets_lengths"]),
              "targets_pretokenized": ex["targets_pretokenized"],
          })],
          postprocess_fn=functools.partial(
              _string_label_to_class_id_postprocessor,
              label_classes=label_classes),
          output_features={"inputs": dataset_providers.Feature(mock.Mock()),
                           "targets": dataset_providers.Feature(mock.Mock())}
      )
    task1 = _task_from_tensor_slices(
        "task1",
        {
            "inputs_lengths": [3, 2],
            "targets_lengths": [2, 3],
            "targets_pretokenized": ["e6", "e5"],
        },
        ("e4", "e5", "e6"))
    task2 = _task_from_tensor_slices(
        "task2",
        {
            "inputs_lengths": [1],
            "targets_lengths": [4],
            "targets_pretokenized": ["e4"],
        },
        ("e2", "e3", "e4"))
    cached_targets, cached_task_datasets, max_sequence_length = (
        evaluation.get_targets_and_examples(
            [task1, task2],
            lambda t: t.get_dataset(
                split="validation", sequence_length=None, shuffle=False))
        )

    # Targets are the postprocessed class ids: index into each label tuple.
    self.assertDictEqual({"task1": [2, 1], "task2": [2]}, cached_targets)
    # Max lengths are taken over all examples of all tasks.
    self.assertDictEqual({"inputs": 3, "targets": 4}, max_sequence_length)
    self.assertCountEqual(["task1", "task2"], cached_task_datasets.keys())
    self.assertLen(cached_task_datasets["task1"], 2)
    self.assertLen(cached_task_datasets["task2"], 1)
    expected_task1_examples = [
        {"inputs": [0, 1, 2], "targets": [0, 1], "targets_pretokenized": "e6"},
        {"inputs": [0, 1], "targets": [0, 1, 2], "targets_pretokenized": "e5"}
    ]
    expected_task2_examples = [
        {"inputs": [0], "targets": [0, 1, 2, 3], "targets_pretokenized": "e4"},
    ]
    test_utils.assert_dataset(cached_task_datasets["task1"],
                              expected_task1_examples)
    test_utils.assert_dataset(cached_task_datasets["task2"],
                              expected_task2_examples)
    def test_assert_dataset(self):
        """Exercises assert_dataset on matching and mismatching expectations."""
        first_dataset = tf.data.Dataset.from_tensor_slices({
            'key1': ['val1'],
            'key2': ['val2']
        })

        # Equal
        assert_dataset(first_dataset, {'key1': [b'val1'], 'key2': [b'val2']})
        assert_dataset(first_dataset,
                       {'key1': [b'val1'], 'key2': [b'val2']},
                       expected_dtypes={'key1': tf.string})

        # Unequal value
        with self.assertRaises(AssertionError):
            assert_dataset(first_dataset,
                           {'key1': [b'val1'], 'key2': [b'val2x']})

        # Wrong dtype
        with self.assertRaises(AssertionError):
            assert_dataset(first_dataset,
                           {'key1': [b'val1'], 'key2': [b'val2']},
                           expected_dtypes={'key1': tf.int32})

        # Additional key, value
        with self.assertRaises(AssertionError):
            assert_dataset(first_dataset, {
                'key1': [b'val1'],
                'key2': [b'val2'],
                'key3': [b'val3']
            })

        # Missing key: the dataset yields 'key2' but the expectation omits it.
        # (Previously this slot repeated the "additional key" case verbatim.)
        with self.assertRaises(AssertionError):
            assert_dataset(first_dataset, {'key1': [b'val1']})
    def test_caching(self):
        """Checks that Evaluator caches both task and model datasets."""
        task_name = "caching"
        raw_examples = [
            {
                "inputs": [7, 8],
                "targets": [3, 9],
                "targets_pretokenized": "ex 1",
            },
            {
                "inputs": [8, 4],
                "targets": [4],
                "targets_pretokenized": "ex 2",
            },
        ]
        ds = tf.data.Dataset.from_generator(
            lambda: raw_examples,
            output_types={
                "inputs": tf.int32,
                "targets": tf.int32,
                "targets_pretokenized": tf.string,
            },
            output_shapes={
                "inputs": [None],
                "targets": [None],
                "targets_pretokenized": [],
            })
        register_dummy_task(task_name,
                            dataset_fn=lambda split, shuffle_files: ds,
                            metrics_fn=[_sequence_accuracy_metric])

        # Mock feature converter that only trims/pads "inputs" and "targets"
        # to length 4. Each call hands out a fresh copy of the lengths dict.
        model_lengths = {"inputs": 4, "targets": 4}
        feature_converter = mock.Mock(
            get_model_feature_lengths=lambda x: dict(model_lengths))
        feature_converter.side_effect = (
            lambda ds, length: utils.trim_and_pad_dataset(
                ds, dict(model_lengths)))

        evaluator = Evaluator(mixture_or_task_name=task_name,
                              feature_converter=feature_converter,
                              eval_split="validation")

        # Task-level cache: EOS (1) appended, no padding yet.
        expected_task_examples = [
            {
                "inputs": [7, 8, 1],
                "targets": [3, 9, 1],
                "targets_pretokenized": b"ex 1",
            },
            {
                "inputs": [8, 4, 1],
                "targets": [4, 1],
                "targets_pretokenized": b"ex 2",
            },
        ]
        # Model-level cache: additionally padded with 0 up to length 4.
        expected_examples = [
            {
                "inputs": [7, 8, 1, 0],
                "targets": [3, 9, 1, 0],
                "targets_pretokenized": b"ex 1",
            },
            {
                "inputs": [8, 4, 1, 0],
                "targets": [4, 1, 0, 0],
                "targets_pretokenized": b"ex 2",
            },
        ]

        test_utils.assert_dataset(evaluator._cached_task_datasets[task_name],
                                  expected_task_examples)

        # Model datasets are stored enumerated; drop the index before checking.
        unindexed = evaluator._cached_model_datasets[task_name].map(
            lambda unused_index, example: example)
        test_utils.assert_dataset(unindexed, expected_examples)
        self.assertEqual(evaluator.cached_targets[task_name], ["ex 1", "ex 2"])
        self.assertDictEqual(evaluator.model_feature_lengths, model_lengths)
Exemplo n.º 5
0
    def test_add_fully_cached_mixture(self):
        """Checks add_fully_cached_mixture registers cached tasks and mixture.

        Registers two tasks (with and without EOS) in a mixture, builds its
        fully-cached variant for `{'targets': 6}`, and verifies rates, the
        caching requirement, the sequence-length restriction and the
        resulting examples.
        """
        TaskRegistry.add('task1',
                         source=self.fake_source,
                         preprocessors=self.preprocessors,
                         output_features={
                             'targets': Feature(self.vocabulary, add_eos=False)
                         },
                         metric_fns=self.metrics_fns)

        TaskRegistry.add('task2',
                         source=self.fake_source,
                         preprocessors=self.preprocessors,
                         output_features={
                             'targets': Feature(self.vocabulary, add_eos=True)
                         },
                         metric_fns=self.metrics_fns)

        MixtureRegistry.add('mix', [('task1', 2), ('task2', lambda x: 1)])

        experimental.add_fully_cached_mixture('mix',
                                              sequence_length={'targets': 6})

        # New names carry the cached sequence length as a suffix.
        new_mix = MixtureRegistry.get('mix_6')
        new_task_names = ('task1_6', 'task2_6')
        self.assertContainsSubset(new_task_names, TaskRegistry.names())

        new_tasks = [TaskRegistry.get(n) for n in new_task_names]

        self.assertCountEqual(new_tasks, new_mix.tasks)
        self.assertEqual(new_mix.get_rate(new_tasks[0]), 2)
        self.assertEqual(new_mix.get_rate(new_tasks[1]), 1)

        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Task 'task1_6' requires caching, but was called with "
                "`use_cached=False`."):
            new_mix.get_dataset(None)

        # Disable caching restriction to get past cache check.
        for t in new_tasks:
            t.preprocessors[-2]._required = False

        # A sequence length other than the cached one must be rejected. This
        # must be a dict mapping feature -> length (previously a set holding
        # the single string 'targets: 7' was passed by mistake).
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Fully-cached task 'task1_6' can only be loaded with "
                "`sequence_length={'targets': 6}` or `None`."):
            new_mix.get_dataset({'targets': 7}, use_cached=False)

        expected_dataset = [
            {
                'targets': [1, 6, 6]
            },
            {
                'targets': [2, 6, 6]
            },
            {
                'targets': [1, 6]
            },
            {
                'targets': [1, 6, 6]
            },
            {
                'targets': [2, 6]
            },
            {
                'targets': [2, 6, 6]
            },
        ]

        # Loading with no length and with the exact cached length must agree.
        test_utils.assert_dataset(
            new_mix.get_dataset(None, shuffle=False).take(6), expected_dataset)
        test_utils.assert_dataset(
            new_mix.get_dataset({
                'targets': 6
            }, shuffle=False).take(6), expected_dataset)