def validate_fully_cached_task(self, name, sequence_length, expected_dataset):
  """Checks a fully-cached task: structure, caching enforcement, and output.

  Args:
    name: registered name of the fully-cached task to validate.
    sequence_length: the exact sequence-length dict the task was cached with.
    expected_dataset: examples the task must yield when loaded unshuffled.
  """
  task = TaskRegistry.get(name)

  # Structural checks: the fully-cached variant appends preprocessors, ending
  # with a required CacheDatasetPlaceholder just before the final step.
  self.assertLen(task.preprocessors, 6)
  self.assertEqual(task.metric_fns, self.metrics_fns)
  self.assertIsInstance(task.preprocessors[-2], CacheDatasetPlaceholder)
  self.assertTrue(task.preprocessors[-2].required)

  # Loading without the cache must be rejected.
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      f"Task '{name}' requires caching, but was called with "
      "`use_cached=False`."):
    task.get_dataset(None)

  # Disable caching restriction to verify dataset is correct.
  task.preprocessors[-2]._required = False

  # Any sequence length other than the cached one (or None) must be rejected.
  wrong_lengths = {k: v + 1 for k, v in sequence_length.items()}
  with self.assertRaisesWithLiteralMatch(
      ValueError,
      f"Fully-cached task '{name}' can only be loaded with "
      f'`sequence_length={sequence_length}` or `None`.'):
    task.get_dataset(wrong_lengths, use_cached=False)

  # Both accepted sequence-length values produce the expected examples.
  for length in (None, sequence_length):
    test_utils.assert_dataset(
        task.get_dataset(length, shuffle=False), expected_dataset)
def test_get_targets_and_examples(self):
  """Tests target/dataset caching and max-length computation over two tasks."""

  def _task_from_tensor_slices(name, tensor_slices, label_classes):
    # Builds a tiny Task whose examples are ranges of the given lengths.
    return dataset_providers.Task(
        name,
        dataset_providers.FunctionDataSource(
            lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices(
                tensor_slices),
            # BUG FIX: `("validation")` is just a parenthesized string, not a
            # tuple — the trailing comma makes it the intended 1-tuple of
            # split names.
            splits=("validation",)),
        preprocessors=[utils.map_over_dataset(lambda ex: {
            "inputs": tf.range(ex["inputs_lengths"]),
            "targets": tf.range(ex["targets_lengths"]),
            "targets_pretokenized": ex["targets_pretokenized"],
        })],
        postprocess_fn=functools.partial(
            _string_label_to_class_id_postprocessor,
            label_classes=label_classes),
        output_features={
            "inputs": dataset_providers.Feature(mock.Mock()),
            "targets": dataset_providers.Feature(mock.Mock())
        })

  task1 = _task_from_tensor_slices(
      "task1",
      {
          "inputs_lengths": [3, 2],
          "targets_lengths": [2, 3],
          "targets_pretokenized": ["e6", "e5"],
      },
      ("e4", "e5", "e6"))
  task2 = _task_from_tensor_slices(
      "task2",
      {
          "inputs_lengths": [1],
          "targets_lengths": [4],
          "targets_pretokenized": ["e4"],
      },
      ("e2", "e3", "e4"))

  cached_targets, cached_task_datasets, max_sequence_length = (
      evaluation.get_targets_and_examples(
          [task1, task2],
          lambda t: t.get_dataset(
              split="validation", sequence_length=None, shuffle=False)))

  # Targets are postprocessed into class ids per task.
  self.assertDictEqual({"task1": [2, 1], "task2": [2]}, cached_targets)
  # Max lengths are taken across all tasks' examples.
  self.assertDictEqual({"inputs": 3, "targets": 4}, max_sequence_length)
  self.assertCountEqual(["task1", "task2"], cached_task_datasets.keys())
  self.assertLen(cached_task_datasets["task1"], 2)
  self.assertLen(cached_task_datasets["task2"], 1)

  expected_task1_examples = [
      {"inputs": [0, 1, 2], "targets": [0, 1], "targets_pretokenized": "e6"},
      {"inputs": [0, 1], "targets": [0, 1, 2], "targets_pretokenized": "e5"}
  ]
  expected_task2_examples = [
      {"inputs": [0], "targets": [0, 1, 2, 3], "targets_pretokenized": "e4"},
  ]
  test_utils.assert_dataset(cached_task_datasets["task1"],
                            expected_task1_examples)
  test_utils.assert_dataset(cached_task_datasets["task2"],
                            expected_task2_examples)
def test_assert_dataset(self):
  """Covers the passing and failing modes of `assert_dataset`."""
  first_dataset = tf.data.Dataset.from_tensor_slices({
      'key1': ['val1'],
      'key2': ['val2']
  })

  # Equal
  assert_dataset(first_dataset, {'key1': [b'val1'], 'key2': [b'val2']})
  assert_dataset(first_dataset, {
      'key1': [b'val1'],
      'key2': [b'val2']
  }, expected_dtypes={'key1': tf.string})

  # Unequal value
  with self.assertRaises(AssertionError):
    assert_dataset(first_dataset, {'key1': [b'val1'], 'key2': [b'val2x']})

  # Wrong dtype
  with self.assertRaises(AssertionError):
    assert_dataset(first_dataset, {
        'key1': [b'val1'],
        'key2': [b'val2']
    }, expected_dtypes={'key1': tf.int32})

  # Additional key, value
  # BUG FIX: this assertion block previously appeared twice verbatim; the
  # duplicate added no coverage and has been removed.
  with self.assertRaises(AssertionError):
    assert_dataset(first_dataset, {
        'key1': [b'val1'],
        'key2': [b'val2'],
        'key3': [b'val3']
    })
def test_caching(self):
  """Checks that the Evaluator caches task/model datasets and targets."""
  task_name = "caching"
  # Two variable-length tokenized examples (EOS appears to be appended by the
  # dummy task's pipeline — the expected values below end in 1).
  x = [{
      "inputs": [7, 8],
      "targets": [3, 9],
      "targets_pretokenized": "ex 1"
  }, {
      "inputs": [8, 4],
      "targets": [4],
      "targets_pretokenized": "ex 2"
  }]
  dtypes = {
      "inputs": tf.int32,
      "targets": tf.int32,
      "targets_pretokenized": tf.string
  }
  shapes = {
      "inputs": [None],
      "targets": [None],
      "targets_pretokenized": []
  }
  ds = tf.data.Dataset.from_generator(lambda: x,
                                      output_types=dtypes,
                                      output_shapes=shapes)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(task_name,
                      dataset_fn=dataset_fn,
                      metrics_fn=[_sequence_accuracy_metric])

  # Feature converter that just pads "inputs" and "targets".
  feature_converter = mock.Mock(get_model_feature_lengths=lambda x: {
      "inputs": 4,
      "targets": 4
  })
  # Calling the mock itself trims/pads each feature to length 4.
  feature_converter.side_effect = (
      lambda ds, length: utils.trim_and_pad_dataset(
          ds, {
              "inputs": 4,
              "targets": 4
          }))
  evaluator = Evaluator(mixture_or_task_name=task_name,
                        feature_converter=feature_converter,
                        eval_split="validation")

  # Task-level cache: tokenized examples before the feature converter runs.
  expected_task_examples = [{
      "inputs": [7, 8, 1],
      "targets": [3, 9, 1],
      "targets_pretokenized": b"ex 1"
  }, {
      "inputs": [8, 4, 1],
      "targets": [4, 1],
      "targets_pretokenized": b"ex 2"
  }]
  # Model-level cache: same examples after padding to the model lengths (4).
  expected_examples = [{
      "inputs": [7, 8, 1, 0],
      "targets": [3, 9, 1, 0],
      "targets_pretokenized": b"ex 1"
  }, {
      "inputs": [8, 4, 1, 0],
      "targets": [4, 1, 0, 0],
      "targets_pretokenized": b"ex 2"
  }]

  test_utils.assert_dataset(evaluator._cached_task_datasets[task_name],
                            expected_task_examples)

  # _cached_model_datasets are enumerated. Remove the index for assertion.
  eval_ds = evaluator._cached_model_datasets[task_name].map(
      lambda i, ds: ds)
  test_utils.assert_dataset(eval_ds, expected_examples)
  self.assertEqual(evaluator.cached_targets[task_name], ["ex 1", "ex 2"])
  self.assertDictEqual(evaluator.model_feature_lengths, {
      "inputs": 4,
      "targets": 4
  })
def test_add_fully_cached_mixture(self):
  """Tests that add_fully_cached_mixture creates cached tasks and mixture."""
  TaskRegistry.add('task1',
                   source=self.fake_source,
                   preprocessors=self.preprocessors,
                   output_features={
                       'targets': Feature(self.vocabulary, add_eos=False)
                   },
                   metric_fns=self.metrics_fns)
  TaskRegistry.add('task2',
                   source=self.fake_source,
                   preprocessors=self.preprocessors,
                   output_features={
                       'targets': Feature(self.vocabulary, add_eos=True)
                   },
                   metric_fns=self.metrics_fns)
  MixtureRegistry.add('mix', [('task1', 2), ('task2', lambda x: 1)])

  experimental.add_fully_cached_mixture('mix', sequence_length={'targets': 6})

  # The cached mixture and tasks are suffixed with the target length.
  new_mix = MixtureRegistry.get('mix_6')
  new_task_names = ('task1_6', 'task2_6')
  self.assertContainsSubset(new_task_names, TaskRegistry.names())

  new_tasks = [TaskRegistry.get(n) for n in new_task_names]
  self.assertCountEqual(new_tasks, new_mix.tasks)
  # Mixing rates (static and callable) carry over to the cached tasks.
  self.assertEqual(new_mix.get_rate(new_tasks[0]), 2)
  self.assertEqual(new_mix.get_rate(new_tasks[1]), 1)

  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "Task 'task1_6' requires caching, but was called with "
      "`use_cached=False`."):
    new_mix.get_dataset(None)

  # Disable caching restriction to get past cache check.
  for t in new_tasks:
    t.preprocessors[-2]._required = False

  with self.assertRaisesWithLiteralMatch(
      ValueError,
      "Fully-cached task 'task1_6' can only be loaded with "
      "`sequence_length={'targets': 6}` or `None`."):
    # BUG FIX: was `{'targets: 7'}` — a set containing one string (the colon
    # was inside the quotes) instead of the intended sequence-length dict.
    new_mix.get_dataset({'targets': 7}, use_cached=False)

  expected_dataset = [
      {
          'targets': [1, 6, 6]
      },
      {
          'targets': [2, 6, 6]
      },
      {
          'targets': [1, 6]
      },
      {
          'targets': [1, 6, 6]
      },
      {
          'targets': [2, 6]
      },
      {
          'targets': [2, 6, 6]
      },
  ]

  test_utils.assert_dataset(
      new_mix.get_dataset(None, shuffle=False).take(6), expected_dataset)
  test_utils.assert_dataset(
      new_mix.get_dataset({
          'targets': 6
      }, shuffle=False).take(6), expected_dataset)