def test_add_fully_cached_task(self):
        """Fully cache an encoder-decoder task with an explicit cache placeholder.

        Registers 'encoder_decoder_task' with a `CacheDatasetPlaceholder`
        inserted at index 2 of a copy of the fixture preprocessors, asks
        `experimental.add_fully_cached_task` for a variant with fixed sequence
        lengths, and hands the resulting 'encoder_decoder_task_i5_t6' task to
        `validate_fully_cached_task` for checking.
        """
        # Work on a copy so the insert does not touch self.preprocessors.
        pipeline = [*self.preprocessors]
        pipeline.insert(2, CacheDatasetPlaceholder())

        features = {
            'inputs': Feature(self.vocabulary, add_eos=True),
            'targets': Feature(self.vocabulary, add_eos=False),
        }
        TaskRegistry.add(
            'encoder_decoder_task',
            source=self.fake_source,
            preprocessors=pipeline,
            output_features=features,
            metric_fns=self.metrics_fns,
        )

        sequence_length = {'inputs': 5, 'targets': 6}
        actual_sequence_length = {'inputs': 6, 'targets': 7}
        experimental.add_fully_cached_task(
            'encoder_decoder_task', sequence_length)

        expected_examples = [
            {'inputs': [1, 5, 5], 'targets': [1, 6]},
            {'inputs': [2, 5, 5], 'targets': [2, 6]},
        ]
        self.validate_fully_cached_task(
            'encoder_decoder_task_i5_t6',
            sequence_length,
            actual_sequence_length,
            expected_examples,
        )
    def test_add_fully_cached_task_unique_prefix(self):
        """Fully cache a task whose feature names share a common prefix.

        The output features are 'tar' and 'targets'. After calling
        `experimental.add_fully_cached_task`, the test validates the task
        registered under 'feature_prefix_task_tar5_targ6' via
        `validate_fully_cached_task`.
        """
        features = {
            'tar': Feature(self.vocabulary, add_eos=True),
            'targets': Feature(self.vocabulary, add_eos=False),
        }
        TaskRegistry.add(
            'feature_prefix_task',
            source=self.fake_source,
            preprocessors=self.preprocessors,
            output_features=features,
            metric_fns=self.metrics_fns,
        )

        sequence_length = {'tar': 5, 'targets': 6}
        actual_sequence_length = {'tar': 6, 'targets': 7}
        experimental.add_fully_cached_task(
            'feature_prefix_task', sequence_length)

        expected_examples = [
            {'tar': [1, 5, 5], 'targets': [1, 6]},
            {'tar': [2, 5, 5], 'targets': [2, 6]},
        ]
        self.validate_fully_cached_task(
            'feature_prefix_task_tar5_targ6',
            sequence_length,
            actual_sequence_length,
            expected_examples,
        )
    # Example no. 3
    # 0
    def test_add_fully_cached_task_single_feature(self):
        """Fully cache a task that has only a 'targets' output feature.

        Registers 'decoder_task', requests a fully cached variant with
        `{'targets': 6}`, and validates the task registered under
        'decoder_task_6' via `validate_fully_cached_task`.
        """
        features = {'targets': Feature(self.vocabulary, add_eos=True)}
        TaskRegistry.add(
            'decoder_task',
            source=self.fake_source,
            preprocessors=self.preprocessors,
            output_features=features,
            metric_fns=self.metrics_fns,
        )

        sequence_length = {'targets': 6}
        experimental.add_fully_cached_task('decoder_task', sequence_length)

        expected_examples = [
            {'targets': [1, 6, 6]},
            {'targets': [2, 6, 6]},
        ]
        # NOTE(review): other tests in this file pass an additional
        # `actual_sequence_length` argument before the expected examples;
        # confirm this 3-argument call matches the helper's signature.
        self.validate_fully_cached_task(
            'decoder_task_6', sequence_length, expected_examples)
    def test_add_fully_cached_task_disallow_shuffling(self):
        """Check that a task created with `disallow_shuffling=True` rejects shuffling.

        Requesting a shuffled dataset must raise a `ValueError` with the exact
        message below, while an unshuffled request must succeed.
        """
        features = {'targets': Feature(self.vocabulary, add_eos=True)}
        TaskRegistry.add(
            'decoder_task',
            source=self.fake_source,
            preprocessors=self.preprocessors,
            output_features=features,
            metric_fns=self.metrics_fns,
        )

        sequence_length = {'targets': 6}
        cached_task = experimental.add_fully_cached_task(
            'decoder_task', sequence_length, disallow_shuffling=True)

        # Disable caching restriction to get past cache check (the
        # second-to-last preprocessor is presumably the cache placeholder).
        cached_task.preprocessors[-2]._required = False

        expected_message = (
            "Shuffling is disallowed for Task 'decoder_task_6' since its "
            '`shuffle_buffer_size` was set to `None` on construction.')
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
            cached_task.get_dataset(None, shuffle=True, use_cached=False)

        # Without shuffling the same request must not raise.
        cached_task.get_dataset(None, shuffle=False, use_cached=False)