def test_add_fully_cached_task(self):
    """Fully caches a task whose preprocessors already hold a cache placeholder.

    Registers 'encoder_decoder_task' with a `CacheDatasetPlaceholder` inserted
    at index 2 of the fixture preprocessors, runs
    `experimental.add_fully_cached_task` with inputs/targets lengths of 5/6,
    and hands the derived task 'encoder_decoder_task_i5_t6' to
    `validate_fully_cached_task` together with the expected examples.
    """
    task_name = 'encoder_decoder_task'

    # Copy the fixture list so the shared `self.preprocessors` is not mutated.
    preprocessors_with_cache = list(self.preprocessors)
    preprocessors_with_cache[2:2] = [CacheDatasetPlaceholder()]

    features = {
        'inputs': Feature(self.vocabulary, add_eos=True),
        'targets': Feature(self.vocabulary, add_eos=False),
    }
    TaskRegistry.add(
        task_name,
        source=self.fake_source,
        preprocessors=preprocessors_with_cache,
        output_features=features,
        metric_fns=self.metrics_fns)

    cached_lengths = {'inputs': 5, 'targets': 6}
    # One past every cached length; presumably the mismatched length the
    # helper checks against -- confirm in `validate_fully_cached_task`.
    mismatched_lengths = {
        feature: length + 1 for feature, length in cached_lengths.items()
    }
    experimental.add_fully_cached_task(task_name, cached_lengths)

    expected_examples = [
        {'inputs': [1, 5, 5], 'targets': [1, 6]},
        {'inputs': [2, 5, 5], 'targets': [2, 6]},
    ]
    self.validate_fully_cached_task(
        'encoder_decoder_task_i5_t6',
        cached_lengths,
        mismatched_lengths,
        expected_examples)
def test_add_fully_cached_task_unique_prefix(self):
    """Fully caches a task whose feature names share a common prefix.

    The features 'tar' and 'targets' both start with "tar". The derived task
    is expected to be named 'feature_prefix_task_tar5_targ6', so the
    abbreviations used in the name must still tell the two features apart.
    """
    task_name = 'feature_prefix_task'

    features = {
        'tar': Feature(self.vocabulary, add_eos=True),
        'targets': Feature(self.vocabulary, add_eos=False),
    }
    TaskRegistry.add(
        task_name,
        source=self.fake_source,
        preprocessors=self.preprocessors,
        output_features=features,
        metric_fns=self.metrics_fns)

    cached_lengths = {'tar': 5, 'targets': 6}
    # One past every cached length; presumably the mismatched length the
    # helper checks against -- confirm in `validate_fully_cached_task`.
    mismatched_lengths = {
        feature: length + 1 for feature, length in cached_lengths.items()
    }
    experimental.add_fully_cached_task(task_name, cached_lengths)

    expected_examples = [
        {'tar': [1, 5, 5], 'targets': [1, 6]},
        {'tar': [2, 5, 5], 'targets': [2, 6]},
    ]
    self.validate_fully_cached_task(
        'feature_prefix_task_tar5_targ6',
        cached_lengths,
        mismatched_lengths,
        expected_examples)
def test_add_fully_cached_task_single_feature(self):
    """Fully caches a task that has only a 'targets' output feature.

    Registers 'decoder_task', runs `experimental.add_fully_cached_task` with a
    targets length of 6, and validates the derived task 'decoder_task_6'
    against the expected examples.
    """
    TaskRegistry.add(
        'decoder_task',
        source=self.fake_source,
        preprocessors=self.preprocessors,
        output_features={
            'targets': Feature(self.vocabulary, add_eos=True)
        },
        metric_fns=self.metrics_fns)

    sequence_length = {'targets': 6}
    # Elsewhere in this file, `validate_fully_cached_task` is called as
    # (name, sequence_length, actual_sequence_length, expected_dataset), with
    # actual lengths one past the cached ones. Omitting this argument would
    # bind the expected examples to the `actual_sequence_length` slot.
    actual_sequence_length = {'targets': 7}
    experimental.add_fully_cached_task('decoder_task', sequence_length)
    self.validate_fully_cached_task(
        'decoder_task_6',
        sequence_length,
        actual_sequence_length,
        [
            {'targets': [1, 6, 6]},
            {'targets': [2, 6, 6]},
        ])
def test_add_fully_cached_task_disallow_shuffling(self):
    """Checks that `disallow_shuffling=True` rejects shuffled reads.

    The fully-cached 'decoder_task_6' must raise a ValueError with a specific
    message when `get_dataset` is called with `shuffle=True`. The same call
    with `shuffle=False` must complete without raising.
    """
    decoder_features = {'targets': Feature(self.vocabulary, add_eos=True)}
    TaskRegistry.add(
        'decoder_task',
        source=self.fake_source,
        preprocessors=self.preprocessors,
        output_features=decoder_features,
        metric_fns=self.metrics_fns)

    cached_lengths = {'targets': 6}
    cached_task = experimental.add_fully_cached_task(
        'decoder_task', cached_lengths, disallow_shuffling=True)

    # Disable the caching restriction (second-to-last preprocessor, presumably
    # the cache placeholder) so `use_cached=False` gets past the cache check.
    cache_step = cached_task.preprocessors[-2]
    cache_step._required = False

    shuffle_error = (
        "Shuffling is disallowed for Task 'decoder_task_6' since its "
        '`shuffle_buffer_size` was set to `None` on construction.')
    with self.assertRaisesWithLiteralMatch(ValueError, shuffle_error):
        cached_task.get_dataset(None, shuffle=True, use_cached=False)

    # Reading without shuffling must not raise.
    cached_task.get_dataset(None, shuffle=False, use_cached=False)