コード例 #1
0
 def _task_from_tensor_slices(name, tensor_slices, label_classes):
   """Builds a validation-only Task backed by in-memory tensor slices.

   Args:
     name: name of the Task to create.
     tensor_slices: structure passed to `tf.data.Dataset.from_tensor_slices`;
       the preprocessor below reads its "inputs_lengths", "targets_lengths"
       and "targets_pretokenized" entries.
     label_classes: forwarded as the `label_classes` keyword of
       `_string_label_to_class_id_postprocessor`.

   Returns:
     A `dataset_providers.Task` whose data source declares only the
     "validation" split.
   """
   return dataset_providers.Task(
       name,
       dataset_providers.FunctionDataSource(
           # The requested split and shuffle flag are ignored: the same
           # slices are served for the single declared split.
           lambda split, shuffle_files:
           tf.data.Dataset.from_tensor_slices(tensor_slices),
           # Must be a one-element tuple. `("validation")` is just the string
           # "validation", which anything iterating `splits` would see as the
           # individual characters 'v', 'a', 'l', ...
           splits=("validation",)),
       preprocessors=[utils.map_over_dataset(lambda ex: {
           # Replace each length field with the integer range 0..length-1.
           "inputs": tf.range(ex["inputs_lengths"]),
           "targets": tf.range(ex["targets_lengths"]),
           "targets_pretokenized": ex["targets_pretokenized"],
       })],
       postprocess_fn=functools.partial(
           _string_label_to_class_id_postprocessor,
           label_classes=label_classes),
       output_features={"inputs": dataset_providers.Feature(mock.Mock()),
                        "targets": dataset_providers.Feature(mock.Mock())}
   )
コード例 #2
0
    def test_dtype(self):
        """Verifies a task whose output features use different dtypes.

        "inputs" is declared without a dtype (so it uses the Feature default),
        while "targets" is declared as tf.int64. An extra preprocessor casts
        "targets" so the produced data matches that declaration.
        """
        vocab = test_utils.sentencepiece_vocab()

        def _cast_targets_to_int64(example):
            # Only "targets" is cast; every other feature passes through.
            return {
                key: tf.cast(value, tf.int64) if key == "targets" else value  # pylint:disable=g-long-lambda
                for key, value in example.items()
            }

        output_features = {
            # No dtype given: defaults to int32.
            "inputs": dataset_providers.Feature(vocabulary=vocab),
            "targets": dataset_providers.Feature(
                dtype=tf.int64, vocabulary=vocab),
        }
        preprocessors = self.DEFAULT_PREPROCESSORS + (
            utils.map_over_dataset(_cast_targets_to_int64),)

        self.add_task(
            "task_dtypes",
            self.function_source,
            preprocessors=preprocessors,
            output_features=output_features)
        self.verify_task_matches_fake_datasets("task_dtypes", use_cached=False)

    def test_fewshot_data_source(self):
        """Tests FewshotDataSource across shot counts, sharding and errors.

        Covers 0-shot (no "train" key in the output), 3-shot with a train
        preprocessor, 3-shot under a 2-way ShardInfo, and the ValueError
        raised when the wrapped source has no "train" split.
        """

        def fake_dataset_fn(split, shuffle_files):
            del shuffle_files
            # "validation" yields [0, 1]; any other split (here "train")
            # yields [3, 4].
            return tf.data.Dataset.range(
                *((0, 2) if split == 'validation' else (3, 5)))

        def make_source(splits):
            # Fresh FunctionDataSource over `fake_dataset_fn` for `splits`.
            return dataset_providers.FunctionDataSource(
                dataset_fn=fake_dataset_fn, splits=splits)

        # 0 shot: only the eval examples are emitted.
        src = experimental.FewshotDataSource(
            make_source(['train', 'validation']), num_shots=0)
        dataset = src.get_dataset('validation')
        assert_dataset(dataset, [{'eval': 0}, {'eval': 1}])

        # 3 shot: the expected targets cycle through the train examples [3, 4].
        src = experimental.FewshotDataSource(
            make_source(['train', 'validation']),
            train_preprocessors=[
                utils.map_over_dataset(lambda x: {
                    'inputs': 0,
                    'targets': x
                })
            ],
            num_shots=3)
        dataset = src.get_dataset('validation')
        assert_dataset(dataset, [
            {
                'eval': 0,
                'train': {
                    'inputs': [0, 0, 0],
                    'targets': [3, 4, 3]
                }
            },
            {
                'eval': 1,
                'train': {
                    'inputs': [0, 0, 0],
                    'targets': [4, 3, 4]
                }
            },
        ])

        # 3-shot, sharded: each shard's expected shots repeat a single train
        # example.
        assert_dataset(
            src.get_dataset('validation', shard_info=ShardInfo(0, 2)), [
                {
                    'eval': 0,
                    'train': {
                        'inputs': [0, 0, 0],
                        'targets': [3, 3, 3]
                    }
                },
            ])
        assert_dataset(
            src.get_dataset('validation', shard_info=ShardInfo(1, 2)), [
                {
                    'eval': 1,
                    'train': {
                        'inputs': [0, 0, 0],
                        'targets': [4, 4, 4]
                    }
                },
            ])

        # Missing train: requesting few-shot data without a "train" split in
        # the wrapped source must fail. The result is never used, so it is
        # not bound to a name.
        src = experimental.FewshotDataSource(
            make_source(['validation']), num_shots=3)
        with self.assertRaisesRegex(
                ValueError,
                r"Train split 'train' is not one of the original source "
                r"splits: \('validation',\)"):
            src.get_dataset('validation')