Example #1
# Import aliases assumed by this snippet (they match trax's tf_inputs.py,
# the likely source of this function):
from t5.data import preprocessors as t5_processors
from t5.data import sentencepiece_vocabulary as t5_spc_vocab
from t5.data import utils as t5_utils


def c4_bare_preprocess_fn(dataset,
                          training=True,
                          spm_path=None,
                          copy_plaintext=True,
                          sequence_length=None):
    """Returns a dataset that contains 'inputs' and 'targets' from C4."""
    # Set target key to be equal to the text content.
    dataset = t5_processors.rekey(dataset,
                                  key_map={
                                      'targets': 'text',
                                      'inputs': None
                                  })

    # Vocabulary for tokenization.
    vocab = t5_spc_vocab.SentencePieceVocabulary(
        sentencepiece_model_file=spm_path or t5_utils.DEFAULT_SPM_PATH)
    feature = t5_utils.Feature(vocab)
    output_features = {'targets': feature, 'inputs': feature}

    # Tokenize the targets.
    dataset = t5_utils.encode_string_features(dataset,
                                              output_features,
                                              keys=output_features,
                                              copy_plaintext=copy_plaintext)

    # Preprocess the tokens - the exact preprocessors are set via gin.
    dataset = t5_processors.unsupervised(dataset,
                                         sequence_length=sequence_length,
                                         output_features=output_features)

    # Add EOS.
    dataset = add_eos_to_output_features(dataset, training)

    return dataset
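A minimal usage sketch (hypothetical: it assumes the t5 package is installed, the gin bindings for the `unsupervised` preprocessors are in place, and C4 has been prepared for `tensorflow_datasets`):

import tensorflow_datasets as tfds

# Load raw C4 text and run the bare preprocessing pipeline on it.
raw_ds = tfds.load('c4/en', split='train', shuffle_files=True)
train_ds = c4_bare_preprocess_fn(
    raw_ds,
    training=True,
    sequence_length={'inputs': 512, 'targets': 512})

for example in train_ds.take(1):
    print(example['inputs'], example['targets'])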
Example #2
def generic_text_dataset_preprocess_fn(dataset,
                                       text_preprocess_fn=None,
                                       spm_path=None,
                                       copy_plaintext=False):
    """Applies a text preprocess fn and tokenizes the dataset."""

    # The assumption is that `text_preprocess_fn` finally gives us a dataset
    # which has `inputs` and `targets`.
    if text_preprocess_fn is not None:
        dataset = text_preprocess_fn(dataset)

    # Vocabulary for tokenization.
    vocab = t5_spc_vocab.SentencePieceVocabulary(
        sentencepiece_model_file=spm_path or t5_utils.DEFAULT_SPM_PATH)
    feature = t5_utils.Feature(vocab)
    output_features = {'targets': feature, 'inputs': feature}

    # Tokenize the inputs and targets.
    dataset = t5_utils.encode_string_features(dataset,
                                              output_features,
                                              keys=output_features,
                                              copy_plaintext=copy_plaintext)

    return dataset
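This earlier revision uses the same `t5_*` import aliases as Example #1. A sketch of a compatible `text_preprocess_fn` (the field names 'question' and 'answer' are hypothetical; the only contract is that the function returns a dataset with string 'inputs' and 'targets'):

import tensorflow as tf

def my_text_preprocess_fn(dataset):
    # Map a hypothetical {'question', 'answer'} dataset into the
    # 'inputs'/'targets' form that the tokenizer expects.
    def to_inputs_targets(example):
        return {'inputs': example['question'],
                'targets': example['answer']}
    return dataset.map(to_inputs_targets,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE)

tokenized_ds = generic_text_dataset_preprocess_fn(
    raw_ds, text_preprocess_fn=my_text_preprocess_fn)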
Example #3
# Module-level imports assumed by this snippet, in addition to the t5_*
# aliases shown in Example #1 (as in trax's tf_inputs.py):
import numpy as np
import tensorflow as tf
from absl import logging


def generic_text_dataset_preprocess_fn(dataset,
                                       training=True,
                                       text_preprocess_fns=None,
                                       token_preprocess_fns=None,
                                       spm_path=None,
                                       copy_plaintext=False,
                                       debug_print_examples=False,
                                       debug_print_examples_rate=0.01):
    """Pre-processes, tokenizes and post-processes a `tf.data.Dataset`.

  Args:
    dataset: `tf.data.Dataset` to process.
    training: boolean, set to True if training, False otherwise.
    text_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool ->
      `tf.data.Dataset` this operates before tokenization. Typically used to
      select which fields we want to learn over or change something into
      "text to text" form.
    token_preprocess_fns: None or list of callables: `tf.data.Dataset`, bool ->
      `tf.data.Dataset`, this operates after tokenization. Since this can view
      the tokenized fields, this can be used to filter on length etc.
    spm_path: None or str, path to a sentencepiece model to use for tokenization
      by default uses the 32k vocabulary from T5.
    copy_plaintext: bool, if True retains the original fields after
      tokenization.
    debug_print_examples: bool, if True this prints examples to the logging
      stream for inspection, both before and after tokenization.
    debug_print_examples_rate: float, [0, 1.0], on average this fraction of
      dataset examples will be printed out in each phase i.e. pre and post
      tokenization.

  Returns:
    a `tf.data.Dataset` with all the preprocessing and tokenization performed.
  """

    # The assumption is that `text_preprocess_fns` finally gives us a dataset
    # which has `inputs` and `targets`.
    if text_preprocess_fns is not None:
        for text_preprocess_fn in text_preprocess_fns:
            dataset = text_preprocess_fn(dataset, training)

    # Print debugging examples if needed before tokenization.
    if debug_print_examples:

        def print_examples(x):
            # Note: np.random.uniform() runs once in Python, when tf.data
            # traces this function, so it toggles printing for the whole
            # dataset rather than sampling individual examples.
            if np.random.uniform() < debug_print_examples_rate:
                tf.print(x, output_stream=logging.info)
            return x

        dataset = dataset.map(print_examples)

    # Vocabulary for tokenization.
    vocab = t5_spc_vocab.SentencePieceVocabulary(
        sentencepiece_model_file=spm_path or t5_utils.DEFAULT_SPM_PATH)
    feature = t5_utils.Feature(vocab)
    output_features = {'targets': feature, 'inputs': feature}

    # Tokenize the inputs and targets.
    dataset = t5_utils.encode_string_features(dataset,
                                              output_features,
                                              keys=output_features,
                                              copy_plaintext=copy_plaintext)

    # Apply the token-preprocessors.
    if token_preprocess_fns is not None:
        for token_preprocess_fn in token_preprocess_fns:
            dataset = token_preprocess_fn(dataset, training)

    if debug_print_examples:

        def print_examples_and_shapes(x):
            if np.random.uniform() < debug_print_examples_rate:
                tf.print(
                    {
                        'inputs_shape': tf.size(x['inputs']),
                        'targets_shape': tf.size(x['targets']),
                        'inputs': x['inputs'],
                        'targets': x['targets'],
                    },
                    output_stream=logging.info)
            return x

        dataset = dataset.map(print_examples_and_shapes)

    return dataset
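A hypothetical `token_preprocess_fn` showing the post-tokenization hook, here filtering out examples whose tokenized 'inputs' exceed a length cap (the helper name and the cap are illustrative):

import tensorflow as tf

def filter_by_length(dataset, training, max_len=512):
    # Drop examples whose tokenized 'inputs' are longer than max_len.
    del training  # This sketch filters identically in both modes.
    return dataset.filter(lambda x: tf.size(x['inputs']) <= max_len)

processed_ds = generic_text_dataset_preprocess_fn(
    raw_ds,
    training=True,
    token_preprocess_fns=[filter_by_length],
    debug_print_examples=True)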
Example #4
    def get_dataset(
        self,
        sequence_length,
        split=tfds.Split.TRAIN,
        use_cached=False,
        shuffle=True,
        shuffle_buffer_size=None,
        seed=None,
        copy_plaintext=True,
    ):
        """Returns a tf.data.Dataset from cache or generated on the fly.

    Args:
      sequence_length: dict mapping feature key to int length for that feature
      split: string, the split to return.
      use_cached: bool, whether to use the cached dataset instead of processing
        it on the fly. Defaults to False.
      shuffle: bool, whether to shuffle the dataset.  Only used when generating
        on the fly (use_cached=False).
      shuffle_buffer_size: an integer or None to use task-specific buffer size.
      seed: tf.int64 scalar tf.Tensor (or None) for shuffling tf.data.
      copy_plaintext: bool, whether to pass through copies of plaintext strings
        with a "_plaintext" suffix added to the key.
    Returns:
      A mixed tf.data.Dataset.
    """
        if seed is not None:
            logging.warning(
                ("Global random seed is now set to %d. All TF operations "
                 "are now deterministic with respect to that seed."), seed)
            tf.random.set_seed(seed)

        if use_cached and not self.supports_caching:
            logging.warning(
                "Task '%s' does not support caching. Switching to on-the-fly "
                "preprocessing.", self.name)
            use_cached = False
        if use_cached:
            ds = self._get_cached_dataset(split, shuffle)
        else:
            if seed is None:
                ds = self._dataset_fn(split=split, shuffle_files=shuffle)
            else:
                _validate_args(self._dataset_fn,
                               ["split", "shuffle_files", "seed"])
                ds = self._dataset_fn(split=split,
                                      shuffle_files=shuffle,
                                      seed=seed)
            ds = self.preprocess_text(ds)
            # Tokenize
            ds = utils.encode_string_features(ds,
                                              self.output_features,
                                              keys=self.output_features,
                                              copy_plaintext=copy_plaintext)

        if (not use_cached and self.num_input_examples(split) and
                self.num_input_examples(split) < _MAX_EXAMPLES_TO_MEM_CACHE):
            ds = ds.cache()

        # Post tokenization processing.
        ds = self.preprocess_tokens(ds, sequence_length)
        ds = maybe_print_dataset(ds)

        if shuffle:
            # Shuffle before mixing since preprocessor can output multiple
            # (correlated) examples per input.
            ds = ds.shuffle(shuffle_buffer_size or self._shuffle_buffer_size,
                            seed=seed)

        return ds
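A hedged usage sketch, assuming this method belongs to a registered `t5.data` Task (the task name below is hypothetical):

import t5.data

task = t5.data.TaskRegistry.get("my_c4_task")  # hypothetical registry entry
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 512},
    split="train",
    shuffle=True,
    seed=42)  # a seed makes shuffling and tf.data ops deterministic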
Example #5
    def get_dataset(
        self,
        sequence_length,
        split=tfds.Split.TRAIN,
        use_cached=False,
        shuffle=True,
        shuffle_buffer_size=_SHUFFLE_BUFFER_SIZE,
        mode="train",
    ):
        """Returns a tf.data.Dataset from cache or generated on the fly.
        Args:
          sequence_length: dict mapping feature key to int length for that feature
          split: string, the split to return.
          use_cached: bool, whether to use the cached dataset instead of processing
            it on the fly. Defaults to True.
          shuffle: bool, whether to shuffle the dataset.  Only used when generating
            on the fly (use_cached=False).
          shuffle_buffer_size: an integer
          mode: string, "train" or "eval".
        Returns:
          A mixed tf.data.Dataset.
        """
        if use_cached:
            ds = self._get_cached_dataset(split, shuffle)
        else:
            ds = self._dataset_fn(split=split, shuffle_files=shuffle)
            if self.balance_attributes and mode == "train":
                ds = ds.filter(
                    functools.partial(balance_fn,
                                      balance_rate=self.balance_rate))
            ds = self.preprocess_text_ll(ds)
            # Tokenize
            ds = encode_string_features(ds,
                                        self.get_vocabulary(),
                                        keys=self.output_features,
                                        copy_plaintext=True)

        if (not use_cached and self.num_input_examples(split) and
                self.num_input_examples(split) < _MAX_EXAMPLES_TO_MEM_CACHE):
            ds = ds.cache()

        # Post tokenization processing.
        if (self.denoise and mode == "train") or (not self.denoise):
            ds = self.preprocess_tokens_ll(ds, sequence_length)

        if self.denoise and mode == "eval":

            def _trim_and_append_eos(feat, v):
                # Leave control features untouched; trim everything else to
                # length - 1 and append the EOS id (1 for T5's SentencePiece
                # vocabulary).
                if (feat in ("attribute", "controlcode") or
                        feat not in self.output_features):
                    return v
                return tf.concat([v[:sequence_length[feat] - 1], [1]], axis=0)

            return ds.map(
                lambda ex:
                {k: _trim_and_append_eos(k, v)
                 for k, v in ex.items()},
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if shuffle:
            # Shuffle before mixing since preprocessor can output multiple
            # (correlated) examples per input.
            ds = ds.shuffle(shuffle_buffer_size)

        return ds
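The eval-mode denoise branch above trims each output feature to sequence_length - 1 tokens and appends the EOS id (1 for T5's SentencePiece vocabulary). A standalone sketch of that transform, with made-up token values:

import tensorflow as tf

tokens = tf.constant([101, 102, 103, 104, 105])
max_len = 4
# Keep the first max_len - 1 tokens, then append EOS (id 1).
trimmed = tf.concat([tokens[:max_len - 1], [1]], axis=0)
print(trimmed.numpy())  # -> [101 102 103   1]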