Code Example #1
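This test exercises prep.tokenize from t5.data.preprocessors. By default the preprocessor replaces each output feature with its token IDs and keeps the original string under a *_plaintext key; passing copy_plaintext=False drops the plaintext copies.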
    def test_tokenize(self):
        og_dataset = tf.data.Dataset.from_tensors({
            'prefix': 'This is',
            'suffix': 'a test.'
        })
        output_features = {
            'prefix': Feature(test_utils.MockVocabulary({'This is': [0, 1]})),
            'suffix': Feature(test_utils.MockVocabulary({'a test.': [2, 3]})),
        }

        assert_dataset(
            prep.tokenize(og_dataset, output_features=output_features), {
                'prefix': [0, 1],
                'prefix_plaintext': 'This is',
                'suffix': [2, 3],
                'suffix_plaintext': 'a test.'
            })
        assert_dataset(
            prep.tokenize(og_dataset,
                          output_features=output_features,
                          copy_plaintext=False), {
                              'prefix': [0, 1],
                              'suffix': [2, 3]
                          })
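The MockVocabulary used here simply maps each complete input string to a fixed list of token IDs ('This is' to [0, 1], 'a test.' to [2, 3]), which keeps the expected outputs above easy to read.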
Code Example #2
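This test checks that gin and the utils.map_over_dataset decorator compose correctly: denoise is configured entirely through gin bindings and invoked indirectly as a gin configurable of prep.unsupervised, so the test only asserts that a tf.data.Dataset comes back.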
  def test_denoise_nested_decorators(self):
    """Test whether gin and utils.map_over_dataset decorators are compatible."""
    bindings = """
      preprocessors.unsupervised.preprocessors = [@preprocessors.denoise]
      preprocessors.denoise.noise_density = 0.15
      preprocessors.denoise.noise_mask_fn = @preprocessors.iid_noise_mask
      preprocessors.denoise.inputs_fn = @noise_token_to_sentinel
    """
    gin.parse_config(bindings)
    og_dataset = tf.data.Dataset.from_tensor_slices({'targets': [1, 2, 3]})
    output_features = {
        'targets': Feature(test_utils.sentencepiece_vocab())
    }
    # Test denoise function when it is used as a gin-configurable of another
    # gin-configurable, prep.unsupervised.
    dataset = prep.unsupervised(og_dataset,
                                output_features=output_features)
    self.assertIsInstance(dataset, tf.data.Dataset)
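Code Example #3

This test pins the TensorFlow random seed and runs prep.denoise end to end on a real SentencePiece vocabulary, checking the exact token IDs produced when two spans are corrupted and replaced with unique sentinels.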
  def test_denoise(self):
    tf.random.set_seed(55)

    vocab = test_utils.sentencepiece_vocab()
    target_tokens = vocab.encode('The quick brown fox.')

    # This is what it encodes to.
    self.assertEqual(
        target_tokens,
        [3, 2, 20, 4, 3, 2, 8, 13, 2, 3, 2, 23, 7, 19, 22, 3, 2, 7, 2])

    og_dataset = tf.data.Dataset.from_tensor_slices({
        'targets': [target_tokens],
    })

    output_features = {
        'targets': Feature(vocab),
    }

    # These are the parameters of denoise in the operative config of 'base'.
    # Except noise_density, bumped up from 0.15 to 0.3 in order to demonstrate
    # multiple corrupted spans.
    denoised_dataset = prep.denoise(
        og_dataset,
        output_features,
        noise_density=0.3,
        noise_mask_fn=prep.random_spans_noise_mask,
        inputs_fn=prep.noise_span_to_unique_sentinel,
        targets_fn=prep.nonnoise_span_to_unique_sentinel)

    # Two spans corrupted, [2] and [22, 3, 2, 7, 2], replaced by unique
    # sentinels 25 and 24 respectively.
    assert_dataset(denoised_dataset, [
        {
            'inputs': [
                3, 25, 20, 4, 3, 2, 8, 13, 2, 3, 2, 23, 7, 19, 24
            ],
            'targets': [
                25, 2, 24, 22, 3, 2, 7, 2
            ],
        },
    ])
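noise_span_to_unique_sentinel takes sentinel IDs from the top of the vocabulary downward, so with this test vocabulary the first corrupted span maps to 25 and the second to 24.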
Code Example #4
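This test feeds the fixed sequence 1..100 through prep.prefix_lm 100 times and verifies two properties: each example's inputs and targets concatenate back to the original sequence, and the split point between prefix and continuation varies across examples.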
  def test_prefix_lm(self):
    vocab = test_utils.sentencepiece_vocab()
    inp = list(range(1, 101))
    og_dataset = tf.data.Dataset.from_tensor_slices({'targets': [inp]})
    og_dataset = og_dataset.repeat(100)
    output_features = {'targets': Feature(vocab)}
    output_dataset = prep.prefix_lm(
        og_dataset,
        {
            'inputs': 100,
            'targets': 100
        },
        output_features,
    )
    input_lengths = set()
    for ex in output_dataset.as_numpy_iterator():
      self.assertListEqual(
          ex['inputs'].tolist() + ex['targets'].tolist(), inp)
      input_lengths.add(len(ex['inputs']))
    self.assertGreater(len(input_lengths), 1)
Code Example #5
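This snippet comes from the task-definition module: DEFAULT_OUTPUT_FEATURES attaches the default vocabulary to the inputs and targets features, and TaskRegistry.add registers the C4 span-corruption pretraining task, rekeying the raw "text" field of the TFDS C4 dataset into "targets".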
import functools

from t5.data import preprocessors
from t5.data.dataset_providers import Feature
from t5.data.dataset_providers import TaskRegistry
from t5.data.dataset_providers import TfdsTask
from t5.data.glue_utils import get_glue_metric
from t5.data.glue_utils import get_glue_postprocess_fn
from t5.data.glue_utils import get_glue_text_preprocessor
from t5.data.glue_utils import get_super_glue_metric
from t5.data.utils import get_default_vocabulary
from t5.data.utils import set_global_cache_dirs
from t5.evaluation import metrics
import tensorflow_datasets as tfds

DEFAULT_OUTPUT_FEATURES = {
    "inputs": Feature(
        vocabulary=get_default_vocabulary(), add_eos=True, required=False),
    "targets": Feature(vocabulary=get_default_vocabulary(), add_eos=True)
}

# ==================================== C4 ======================================
# Final pretraining task used in Raffel et al., 2019.
TaskRegistry.add("c4_v220_span_corruption",
                 TfdsTask,
                 tfds_name="c4/en:2.2.0",
                 text_preprocessor=functools.partial(preprocessors.rekey,
                                                     key_map={
                                                         "inputs": None,
                                                         "targets": "text"
                                                     }),
                 token_preprocessor=preprocessors.span_corruption,
                 output_features=DEFAULT_OUTPUT_FEATURES,
                 metric_fns=[])
Code Example #6
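This is a later revision of the same module. The registration itself is unchanged apart from formatting; the notable difference is that Feature now receives get_default_vocabulary itself rather than get_default_vocabulary(), deferring construction of the vocabulary until it is actually needed.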
import functools

from t5.data import preprocessors
from t5.data.dataset_providers import Feature
from t5.data.dataset_providers import TaskRegistry
from t5.data.dataset_providers import TfdsTask
from t5.data.glue_utils import get_glue_metric
from t5.data.glue_utils import get_glue_postprocess_fn
from t5.data.glue_utils import get_glue_text_preprocessor
from t5.data.glue_utils import get_super_glue_metric
from t5.data.utils import get_default_vocabulary
from t5.data.utils import set_global_cache_dirs
from t5.evaluation import metrics
import tensorflow_datasets as tfds


DEFAULT_OUTPUT_FEATURES = {
    "inputs": Feature(
        vocabulary=get_default_vocabulary, add_eos=True, required=False),
    "targets": Feature(vocabulary=get_default_vocabulary, add_eos=True)
}

# ==================================== C4 ======================================
# Final pretraining task used in Raffel et al., 2019.
TaskRegistry.add(
    "c4_v220_span_corruption",
    TfdsTask,
    tfds_name="c4/en:2.2.0",
    text_preprocessor=functools.partial(
        preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
    token_preprocessor=preprocessors.span_corruption,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[])
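For context, here is a minimal usage sketch (not part of the original module) showing how a registered task is consumed, following the pattern from the T5 README. It assumes the t5 package and the c4/en:2.2.0 TFDS data are available; the sequence lengths are illustrative values.

import t5.data
import tensorflow_datasets as tfds

# Look up the task registered above and build its preprocessed dataset.
task = t5.data.TaskRegistry.get("c4_v220_span_corruption")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 114},  # illustrative lengths
    split="train")

# Each example carries span-corrupted "inputs" and the corresponding
# sentinel-delimited "targets".
for ex in tfds.as_numpy(ds.take(1)):
  print(ex["inputs"], ex["targets"])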