def estimator_input_fn(params):
  """Eval input function for estimator."""
  # `tasks`, `sequence_length`, `split`, and `self` are captured from the
  # enclosing scope.
  del params
  # Concatenate all dataset inputs to only have to do one decode loop.
  combined_ds = None
  for task in tasks:
    ds = t5.models.mesh_transformer.mesh_eval_dataset_fn(
        mixture_or_task_name=task.name,
        sequence_length=sequence_length,
        dataset_split=split)[0].dataset_fn()
    ds = ds.map(utils.filter_features)
    combined_ds = ds if combined_ds is None else combined_ds.concatenate(ds)
  combined_ds = combined_ds.batch(self.batch_size, drop_remainder=False)
  # Pad the final (possibly short) batch up to batch_size.
  combined_ds = transformer_dataset.trim_and_pad_dataset(
      combined_ds, length=self.batch_size)
  combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
  return combined_ds
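The pattern above concatenates every task's eval dataset so the decode loop only runs once, batches without dropping the remainder, and then pads the final short batch up to `batch_size` so the estimator always sees fixed-size batches. Below is a minimal, self-contained sketch of that concatenate → batch → pad pattern on synthetic data; `pad_final_batch` is a hypothetical stand-in for what `transformer_dataset.trim_and_pad_dataset(..., length=batch_size)` accomplishes along the batch dimension (assumes TF2 eager execution).

import tensorflow as tf

BATCH_SIZE = 4

def pad_final_batch(batch, batch_size):
    # Zero-pad the leading (batch) dimension of every feature up to
    # batch_size, so the last, possibly short batch has a fixed shape.
    def pad_feature(t):
        pad_amount = batch_size - tf.shape(t)[0]
        paddings = [[0, pad_amount]] + [[0, 0]] * (len(t.shape) - 1)
        return tf.pad(t, paddings)
    return {k: pad_feature(v) for k, v in batch.items()}

# Two toy "task" datasets with the same feature structure.
ds_a = tf.data.Dataset.from_tensor_slices(
    {"inputs": tf.ones([5, 8], tf.int32)})
ds_b = tf.data.Dataset.from_tensor_slices(
    {"inputs": 2 * tf.ones([4, 8], tf.int32)})

combined = ds_a.concatenate(ds_b)  # 9 examples total
combined = combined.batch(BATCH_SIZE, drop_remainder=False)  # sizes 4, 4, 1
combined = combined.map(lambda x: pad_final_batch(x, BATCH_SIZE))

for batch in combined:
    print(batch["inputs"].shape)  # always (4, 8); last batch is zero-padded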
Example #2
def pack_or_pad_ll(dataset,
                   length,
                   pack=True,
                   feature_keys=None,
                   ensure_eos=False,
                   shift_decoder_output=False,
                   target_prefix_attributes=None,
                   tokenizer=None):
    """Creates a 'packed' version of a dataset or pads examples with zeros.
  If pack=True, then multiple examples concatenated to form one combined
  example with the given length.
  If pack=False, then examples are padded with zeros to 'length'.
  Args:
    dataset: a tf.data.Dataset
    length: an integer or a dict from feature-key to integer
    pack: a boolean, whether to pack (True) or pad (False).
    feature_keys: (optional) list of strings, the feature names to limit
      packing or padding to. Packing will filter out other features whereas
      padding will pass them through unchanged. Defaults to all features.
    ensure_eos: a boolean, whether to replace the final token with EOS=1 if it
      is not PAD=0.
  Returns:
    a tf.data.Dataset where all features have fixed shape [length].
  """
    # Default to all features. (dataset.output_shapes is the TF1-era API;
    # under TF2, list(dataset.element_spec.keys()) is the equivalent.)
    feature_keys = feature_keys or list(dataset.output_shapes.keys())
    if shift_decoder_output:
        # Left-pad amount per prefix: its tokenized length minus one.
        left_pad_amts = [
            len(tokenizer.encode(target_prefix_attribute)) - 1
            for target_prefix_attribute in target_prefix_attributes
        ]
        dataset = shift_decoder_output_fn(dataset,
                                          left_pad_amts=left_pad_amts,
                                          feature_keys=feature_keys)
    if pack:
        dataset = pack_dataset(dataset, length=length, keys=feature_keys)
    # Pad/trim length of each example to length.
    dataset = trim_and_pad_dataset(dataset,
                                   length=length,
                                   feature_keys=feature_keys)
    if ensure_eos:
        dataset = ensure_dataset_eos_ll(dataset, feature_keys)
    return dataset
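For intuition, here is a hedged sketch of the pad path (pack=False) and of the ensure_eos behavior the docstring describes; `trim_and_pad` and `ensure_eos` below are simplified stand-ins for `trim_and_pad_dataset` and `ensure_dataset_eos_ll`, assuming 1-D int32 features with PAD=0 and EOS=1 (TF ≥ 2.4 for `output_signature`).

import tensorflow as tf

LENGTH = 8
EOS_ID = 1

def trim_and_pad(x, length):
    # Trim each feature to `length`, then zero-pad it up to `length`
    # (the pack=False path).
    def fix(t):
        t = t[:length]
        return tf.pad(t, [[0, length - tf.shape(t)[0]]])
    return {k: fix(v) for k, v in x.items()}

def ensure_eos(x):
    # If the final token is not PAD=0, overwrite it with EOS=1.
    def fix(t):
        last = t[-1]
        return tf.concat(
            [t[:-1], [tf.where(tf.equal(last, 0), last, EOS_ID)]], axis=0)
    return {k: fix(v) for k, v in x.items()}

def gen():
    yield {"targets": [3, 4, 5]}
    yield {"targets": [6, 7, 8, 9, 10, 11, 12, 13, 14]}

ds = tf.data.Dataset.from_generator(
    gen, output_signature={"targets": tf.TensorSpec([None], tf.int32)})
ds = ds.map(lambda x: trim_and_pad(x, LENGTH)).map(ensure_eos)
for ex in ds:
    print(ex["targets"].numpy())
# [3 4 5 0 0 0 0 0]      -> already ends in PAD, left unchanged
# [6 7 8 9 10 11 12 1]   -> trimmed to 8, final token replaced with EOS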
Example #3
def input_fn(params):
    """Eval input function for estimator."""
    # `eval_datasets`, `batch_size`, and `_INPUT_FEATURES_ll` are captured
    # from the enclosing scope.
    del params
    # Concatenate all dataset inputs to only have to do one decode loop.
    combined_ds = None
    for eval_dataset in eval_datasets:
        # Only include tasks for which metric (eval) functions are provided.
        if eval_dataset.metric_fns:
            ds = eval_dataset.dataset_fn()
            # Only pass through the features that will be used for decoding.
            ds = ds.map(
                lambda x: {k: v for k, v in x.items()
                           if k in _INPUT_FEATURES_ll})
            combined_ds = (ds if combined_ds is None
                           else combined_ds.concatenate(ds))
    combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
    # Pad the final (possibly short) batch up to batch_size.
    combined_ds = transformer_dataset.trim_and_pad_dataset(
        combined_ds, length=batch_size)
    combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
    return combined_ds
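The dict-comprehension map in this variant does the same job as `utils.filter_features` in the first example: it drops every feature the decode loop will not read (string features in particular cannot be fed to TPUs). A small sketch of that filtering step follows; `_INPUT_FEATURES` is a hypothetical stand-in for `_INPUT_FEATURES_ll`, whose exact contents the snippet does not show.

import tensorflow as tf

# Hypothetical whitelist; the real _INPUT_FEATURES_ll is defined elsewhere.
_INPUT_FEATURES = ("inputs", "inputs_position", "targets")

ds = tf.data.Dataset.from_tensors({
    "inputs": tf.ones([8], tf.int32),
    "targets": tf.ones([8], tf.int32),
    "targets_pretokenized": tf.constant("raw text"),  # dropped below
})

# Keep only whitelisted features; everything else is removed before batching.
ds = ds.map(lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES})

print(ds.element_spec)
# {'inputs': TensorSpec(shape=(8,), dtype=tf.int32, name=None),
#  'targets': TensorSpec(shape=(8,), dtype=tf.int32, name=None)}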