Example #1
    def decode_example(self, serialized_example):
        """Return a dict of Tensors from a serialized tensorflow.Example."""
        data_fields, data_items_to_decoders = self.example_reading_spec()
        # Necessary to rejoin examples in the correct order with the Cloud ML Engine
        # batch prediction API.
        data_fields["batch_prediction_key"] = tf.FixedLenFeature([1], tf.int64,
                                                                 0)

        if getattr(self._hparams, "sampling_method",
                   "") == "random_per_example":
            data_fields["sampling_temp"] = tf.FixedLenFeature(
                [1], tf.float32, getattr(self._hparams, "sampling_temp", 1.0))
            data_fields["sampling_keep_top_k"] = tf.FixedLenFeature(
                [1], tf.int64, getattr(self._hparams, "sampling_keep_top_k",
                                       -1))

        if data_items_to_decoders is None:
            data_items_to_decoders = {}
            for field in data_fields:
                if data_fields[field].dtype is tf.string:
                    default_value = b""
                else:
                    default_value = 0
                data_items_to_decoders[field] = (
                    contrib.slim().tfexample_decoder.Tensor(
                        field, default_value=default_value))

        decoder = contrib.slim().tfexample_decoder.TFExampleDecoder(
            data_fields, data_items_to_decoders)

        decode_items = list(sorted(data_items_to_decoders))
        decoded = decoder.decode(serialized_example, items=decode_items)
        return dict(zip(decode_items, decoded))
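A minimal sketch of how a decode_example like the one above is typically wired into a tf.data pipeline; the TF1-style import, the `problem` instance, and the file pattern are assumptions for illustration:

import tensorflow.compat.v1 as tf

def make_decoded_dataset(problem, file_pattern="/tmp/data/train-*"):  # hypothetical path
    """Map serialized tf.Examples through Problem.decode_example."""
    files = tf.data.Dataset.list_files(file_pattern)
    dataset = tf.data.TFRecordDataset(files)
    # decode_example turns each serialized record into a dict of Tensors
    # keyed by the decoder item names (e.g. "inputs", "targets").
    return dataset.map(problem.decode_example, num_parallel_calls=4)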
Example #2
    def example_reading_spec(self):
        data_fields = {"dist_targets": tf.VarLenFeature(tf.int64)}

        if self.has_inputs:
            data_fields["inputs"] = tf.VarLenFeature(tf.int64)

        # hack: ignoring true targets and putting dist_targets in targets
        data_items_to_decoders = {
            "inputs": contrib.slim().tfexample_decoder.Tensor("inputs"),
            "targets": contrib.slim().tfexample_decoder.Tensor("dist_targets"),
        }

        return (data_fields, data_items_to_decoders)
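For context, a hedged sketch of a matching on-disk record: a tf.train.Example whose "dist_targets" feature surfaces under the "targets" key once the decoder above runs (the token ids are illustrative):

import tensorflow.compat.v1 as tf

features = tf.train.Features(feature={
    "inputs": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 7, 1])),
    # Stored as "dist_targets" on disk, decoded into "targets" above.
    "dist_targets": tf.train.Feature(int64_list=tf.train.Int64List(value=[5, 2, 1])),
})
serialized = tf.train.Example(features=features).SerializeToString()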
Example #3
 @property
 def extra_reading_spec(self):
   """Additional data fields to store on disk and their decoders."""
   data_fields = {
       "frame_number": tf.FixedLenFeature([1], tf.int64),
       "action": tf.FixedLenFeature([4], tf.float32),
   }
   decoders = {
       "frame_number":
           contrib.slim().tfexample_decoder.Tensor(tensor_key="frame_number"),
       "action":
           contrib.slim().tfexample_decoder.Tensor(tensor_key="action"),
   }
   return data_fields, decoders
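These extra fields are meant to be merged into a base reading spec; a minimal sketch of that merge, following the update() pattern Example #13 uses (the base spec call is an assumption):

 def merged_reading_spec(self):
     """Combine a base example_reading_spec with these extra fields."""
     data_fields, data_items_to_decoders = self.example_reading_spec()  # assumed base
     extra_fields, extra_decoders = self.extra_reading_spec
     data_fields.update(extra_fields)
     data_items_to_decoders.update(extra_decoders)
     return data_fields, data_items_to_decoders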
Example #4
  def example_reading_spec(self):
    slim = contrib.slim()
    data_fields, data_items_to_decoders = {}, {}
    data_fields["image/feature"] = tf.FixedLenSequenceFeature(
        (), tf.float32, allow_missing=True)
    data_fields["image/spatial_feature"] = tf.FixedLenSequenceFeature(
        (), tf.float32, allow_missing=True)
    data_fields["image/image_id"] = tf.FixedLenFeature((), tf.int64)
    data_fields["image/question_id"] = tf.FixedLenFeature((), tf.int64)
    data_fields["image/question"] = tf.FixedLenSequenceFeature(
        (), tf.int64, allow_missing=True)
    data_fields["image/answer"] = tf.FixedLenSequenceFeature(
        (), tf.int64, allow_missing=True)

    data_items_to_decoders["inputs"] = slim.tfexample_decoder.Tensor(
        "image/feature")
    data_items_to_decoders["question_id"] = slim.tfexample_decoder.Tensor(
        "image/question_id")
    data_items_to_decoders["image_id"] = slim.tfexample_decoder.Tensor(
        "image/image_id")

    data_items_to_decoders["spatial_feature"] = slim.tfexample_decoder.Tensor(
        "image/spatial_feature")
    data_items_to_decoders["question"] = slim.tfexample_decoder.Tensor(
        "image/question")
    data_items_to_decoders["targets"] = slim.tfexample_decoder.Tensor(
        "image/answer")

    return data_fields, data_items_to_decoders
Example #5
    def example_reading_spec(self):
        """Data fields to store on disk and their decoders."""

        # Subclasses can override and/or extend.

        processed_reward_type = tf.float32
        if self.is_processed_rewards_discrete:
            processed_reward_type = tf.int64

        data_fields = {
            TIMESTEP_FIELD: tf.FixedLenFeature((1,), tf.int64),
            RAW_REWARD_FIELD: tf.FixedLenFeature((1,), tf.float32),
            PROCESSED_REWARD_FIELD: tf.FixedLenFeature((1,),
                                                       processed_reward_type),
            DONE_FIELD: tf.FixedLenFeature((1,), tf.int64),  # written as int.

            # Special treatment because we need to determine type and shape,
            # also enables classes to override.
            OBSERVATION_FIELD: self.observation_spec,
            ACTION_FIELD: self.action_spec,
        }

        data_items_to_decoders = {
            field: contrib.slim().tfexample_decoder.Tensor(field)
            for field in data_fields
        }

        return data_fields, data_items_to_decoders
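observation_spec and action_spec are deliberately left to the class; a hedged sketch of what one override might look like for an environment with fixed-size float observations and a single discrete action (the shapes are illustrative assumptions):

    @property
    def observation_spec(self):
        # e.g. a flat 4-dimensional float observation (an assumption).
        return tf.FixedLenFeature((4,), tf.float32)

    @property
    def action_spec(self):
        return tf.FixedLenFeature((1,), tf.int64)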
Example #6
 def example_reading_spec(self):
   label_key = "image/class/label"
   data_fields, data_items_to_decoders = (
       super(Video2ClassProblem, self).example_reading_spec())
   data_fields[label_key] = tf.FixedLenFeature((1,), tf.int64)
   data_items_to_decoders["targets"] = contrib.slim().tfexample_decoder.Tensor(
       label_key)
   return data_fields, data_items_to_decoders
Example #7
 def example_reading_spec(self):
     label_key = "image/unpadded_label"
      data_fields, data_items_to_decoders = (
          super(ImageFSNS, self).example_reading_spec())
      data_fields[label_key] = tf.VarLenFeature(tf.int64)
      data_items_to_decoders["targets"] = (
          contrib.slim().tfexample_decoder.Tensor(label_key))
     return data_fields, data_items_to_decoders
Example #8
  def decode_example(self, serialized_example):
    """Return a dict of Tensors from a serialized tensorflow.Example."""
    data_fields, data_items_to_decoders = self.example_reading_spec()
    # Necessary to rejoin examples in the correct order with the Cloud ML Engine
    # batch prediction API.
    data_fields["batch_prediction_key"] = tf.FixedLenFeature([1], tf.int64, 0)
    if data_items_to_decoders is None:
      data_items_to_decoders = {
          field: contrib.slim().tfexample_decoder.Tensor(field)
          for field in data_fields
      }

    decoder = contrib.slim().tfexample_decoder.TFExampleDecoder(
        data_fields, data_items_to_decoders)

    decode_items = list(sorted(data_items_to_decoders))
    decoded = decoder.decode(serialized_example, items=decode_items)
    return dict(zip(decode_items, decoded))
Example #9
 @property
 def extra_reading_spec(self):
     """Additional data fields to store on disk and their decoders."""
     field_names = ("frame_number", "action", "reward", "done")
     data_fields = {
         name: tf.FixedLenFeature([1], tf.int64)
         for name in field_names
     }
     decoders = {
         name: contrib.slim().tfexample_decoder.Tensor(tensor_key=name)
         for name in field_names
     }
     return (data_fields, decoders)
Example #10
def decode_example(serialized_example):
    """Return a dict of Tensors from a serialized tensorflow.Example."""

    data_fields = {'targets_rel': tf.FixedLenFeature([51*10], tf.float32),
                   'targets_rnd': tf.FixedLenFeature([64*64], tf.float32),
                   'targets_sln': tf.FixedLenFeature([1], tf.int64),
                   'targets_cls': tf.FixedLenFeature([1], tf.int64)}

    # Necessary to rejoin examples in the correct order with the Cloud ML Engine
    # batch prediction API.
    data_fields["batch_prediction_key"] = tf.FixedLenFeature([1], tf.int64, 0)

    data_items_to_decoders = {
        field: contrib.slim().tfexample_decoder.Tensor(field)
        for field in data_fields
    }

    decoder = contrib.slim().tfexample_decoder.TFExampleDecoder(
        data_fields, data_items_to_decoders)

    decode_items = list(sorted(data_items_to_decoders))
    decoded = decoder.decode(serialized_example, items=decode_items)
    return dict(zip(decode_items, decoded))
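Since this variant is a free function, a short sketch of applying it to records read straight from disk (the file path is hypothetical):

import tensorflow.compat.v1 as tf

dataset = tf.data.TFRecordDataset("/tmp/layouts.tfrecord")  # hypothetical file
decoded = dataset.map(decode_example)
# Each element is now a dict with "targets_rel" (shape [510]), "targets_rnd"
# (shape [4096]), "targets_sln", "targets_cls" and "batch_prediction_key".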
Example #11
    def example_reading_spec(self):
        data_fields = {
            "image/encoded": tf.FixedLenFeature((), tf.string),
            "image/format": tf.FixedLenFeature((), tf.string),
        }

        data_items_to_decoders = {
            "inputs":
            contrib.slim().tfexample_decoder.Image(image_key="image/encoded",
                                                   format_key="image/format",
                                                   channels=self.num_channels),
        }

        return data_fields, data_items_to_decoders
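A record matching this spec stores encoded image bytes plus a format string; a hedged sketch of writing one (the PNG bytes are assumed to come from elsewhere):

import tensorflow.compat.v1 as tf

def image_example(encoded_png):
    """Build a tf.train.Example the Image decoder above can parse."""
    feature = {
        "image/encoded": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_png])),
        "image/format": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b"png"])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))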
Example #12
  def example_reading_spec(self):
    data_fields, data_items_to_decoders = (
        super(ImageVqav2Tokens10kLabels3k, self).example_reading_spec())
    data_fields["image/image_id"] = tf.FixedLenFeature((), tf.int64)
    data_fields["image/question_id"] = tf.FixedLenFeature((), tf.int64)
    data_fields["image/question"] = tf.FixedLenSequenceFeature(
        (), tf.int64, allow_missing=True)
    data_fields["image/answer"] = tf.FixedLenSequenceFeature(
        (), tf.int64, allow_missing=True)

    slim = contrib.slim()
    data_items_to_decoders["question"] = slim.tfexample_decoder.Tensor(
        "image/question")
    data_items_to_decoders["targets"] = slim.tfexample_decoder.Tensor(
        "image/answer")
    return data_fields, data_items_to_decoders
Example #13
  def example_reading_spec(self):
    extra_data_fields, extra_data_items_to_decoders = self.extra_reading_spec

    data_fields = {
        "image/encoded": tf.FixedLenFeature((), tf.string),
        "image/format": tf.FixedLenFeature((), tf.string),
    }
    data_fields.update(extra_data_fields)

    data_items_to_decoders = {
        "frame":
            contrib.slim().tfexample_decoder.Image(
                image_key="image/encoded",
                format_key="image/format",
                shape=[self.frame_height, self.frame_width, self.num_channels],
                channels=self.num_channels),
    }
    data_items_to_decoders.update(extra_data_items_to_decoders)

    return data_fields, data_items_to_decoders
Example #14
    def example_reading_spec(self):
        """Return a mix of env and video data fields and decoders."""
        slim = contrib.slim()
        video_fields, video_decoders = (
            video_utils.VideoProblem.example_reading_spec(self))
        env_fields, env_decoders = (
            gym_env_problem.GymEnvProblem.example_reading_spec(self))

        # Remove raw observations field since we want to capture them as videos.
        env_fields.pop(env_problem.OBSERVATION_FIELD)
        env_decoders.pop(env_problem.OBSERVATION_FIELD)

        # Add frame number spec and decoder.
        env_fields[_FRAME_NUMBER_FIELD] = tf.FixedLenFeature((1,), tf.int64)
        env_decoders[_FRAME_NUMBER_FIELD] = slim.tfexample_decoder.Tensor(
            _FRAME_NUMBER_FIELD)

        # Add video fields and decoders
        env_fields.update(video_fields)
        env_decoders.update(video_decoders)
        return env_fields, env_decoders
Example #15
def input_fn(dataset,
             filepattern,
             skip_random_fraction_when_training,
             batch_size_means_tokens_param,
             batch_size_multiplier,
             max_length,
             mode,
             hparams,
             data_dir=None,
             params=None,
             config=None,
             force_repeat=False,
             prevent_repeat=False):
    """Builds input pipeline for problem.

  Args:
    dataset: the dataset to make input function from.
    filepattern: the pattern of files to read from.
    skip_random_fraction_when_training: whether to skip randomly when training.
    batch_size_means_tokens_param: whether batch size should mean tokens.
    batch_size_multiplier: how to multiply batch size when bucketing.
    max_length: maximum length.
    mode: tf.estimator.ModeKeys
    hparams: HParams, model hparams
    data_dir: str, data directory; if None, will use hparams.data_dir
    params: dict, may include "batch_size"
    config: RunConfig; should have the data_parallelism attribute if not using
      TPU
    force_repeat: bool, whether to repeat the data even if not training
    prevent_repeat: bool, whether to not repeat when in training mode.
      Overrides force_repeat.

  Returns:
    (features_dict<str name, Tensor feature>, Tensor targets)
  """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if config and config.use_tpu:
        num_threads = 64
    else:
        num_threads = cpu_count() if is_training else 1

    if config and hasattr(config,
                          "data_parallelism") and config.data_parallelism:
        num_shards = config.data_parallelism.n
    else:
        num_shards = 1

    mlperf_log.transformer_print(key=mlperf_log.INPUT_MAX_LENGTH,
                                 value=max_length)

    def tpu_valid_size(example):
        return example_valid_size(example, hparams.min_length, max_length)

    def gpu_valid_size(example):
        drop_long_sequences = is_training or hparams.eval_drop_long_sequences
        max_validate_length = max_length if drop_long_sequences else 10**9
        return example_valid_size(example, hparams.min_length,
                                  max_validate_length)

    def define_shapes(example):
        batch_size = config and config.use_tpu and params["batch_size"]
        return standardize_shapes(example, batch_size=batch_size)

    # Read and preprocess
    data_dir = data_dir or (hasattr(hparams, "data_dir") and hparams.data_dir)

    if (force_repeat or is_training) and not prevent_repeat:
        # Repeat and skip a random number of records
        dataset = dataset.repeat()

    if is_training and skip_random_fraction_when_training:
        data_files = contrib.slim().parallel_reader.get_data_files(filepattern)
        #  In continuous_train_and_eval when switching between train and
        #  eval, this input_fn method gets called multiple times and it
        #  would give you the exact same samples from the last call
        #  (because the Graph seed is set). So this skip gives you some
        #  shuffling.
        dataset = skip_random_fraction(dataset, data_files[0])

    dataset = dataset.map(cast_ints_to_int32, num_parallel_calls=num_threads)

    if batch_size_means_tokens_param:
        batch_size_means_tokens = True
    else:
        if _are_shapes_fully_defined(dataset.output_shapes):
            batch_size_means_tokens = False
        else:
            tf.logging.warning(
                "Shapes are not fully defined. Assuming batch_size means tokens."
            )
            batch_size_means_tokens = True

    # Batching
    if not batch_size_means_tokens:
        # Batch size means examples per datashard.
        if config and config.use_tpu:
            # on TPU, we use params["batch_size"], which specifies the number of
            # examples across all datashards
            batch_size = params["batch_size"]
            dataset = dataset.batch(batch_size, drop_remainder=True)
        else:
            batch_size = hparams.batch_size * num_shards
            dataset = dataset.batch(batch_size)
    else:
        # batch_size means tokens per datashard
        if config and config.use_tpu:
            dataset = dataset.filter(tpu_valid_size)
            padded_shapes = pad_for_tpu(dataset.output_shapes, hparams,
                                        max_length)
            # on TPU, we use params["batch_size"], which specifies the number of
            # examples across all datashards
            batch_size = params["batch_size"]
            if hparams.pad_batch:
                tf.logging.warn(
                    "Padding the batch to ensure that remainder eval batches are "
                    "processed. This may lead to incorrect metrics for "
                    "non-zero-padded features, e.g. images. Use a smaller batch "
                    "size that has no remainder in that case.")
                dataset = dataset.padded_batch(batch_size,
                                               padded_shapes,
                                               drop_remainder=False)
                dataset = dataset.map(
                    functools.partial(pad_batch, batch_multiple=batch_size),
                    num_parallel_calls=num_threads)
            else:
                dataset = dataset.padded_batch(batch_size,
                                               padded_shapes,
                                               drop_remainder=True)
        else:
            # On GPU, bucket by length
            dataset = dataset.filter(gpu_valid_size)
            cur_batching_scheme = hparams_to_batching_scheme(
                hparams,
                shard_multiplier=num_shards,
                length_multiplier=batch_size_multiplier)
            if hparams.use_fixed_batch_size:
                # Here batch_size really means examples per datashard.
                cur_batching_scheme["batch_sizes"] = [hparams.batch_size]
                cur_batching_scheme["boundaries"] = []
            dataset = dataset.apply(
                tf.data.experimental.bucket_by_sequence_length(
                    example_length, cur_batching_scheme["boundaries"],
                    cur_batching_scheme["batch_sizes"]))

            if not is_training:
                batch_multiple = num_shards
                if hparams.use_fixed_batch_size:
                    # Make sure the last batch has the same fixed size as the rest.
                    batch_multiple *= hparams.batch_size
                if batch_multiple > 1:
                    tf.logging.warn(
                        "Padding the batch to ensure that remainder eval batches have "
                        "a batch size divisible by the number of data shards. This may "
                        "lead to incorrect metrics for non-zero-padded features, e.g. "
                        "images. Use a single datashard (i.e. 1 GPU) in that case."
                    )
                    dataset = dataset.map(
                        functools.partial(pad_batch,
                                          batch_multiple=batch_multiple),
                        num_parallel_calls=num_threads)

    dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)

    # Add shuffling for training batches. This is necessary along with record
    # level shuffling in the dataset generation. Record shuffling will shuffle
    # the examples. However, in some cases, it's possible that the shuffle
    # buffer size for record shuffling is smaller than the batch size. In such
    # cases, adding batch shuffling ensures that the data is in random order
    # during training.
    if (is_training and hasattr(hparams, "batch_shuffle_size")
            and hparams.batch_shuffle_size):
        dataset = dataset.shuffle(hparams.batch_shuffle_size)

    # Split batches into chunks if targets are too long.
    # The new "chunk_number" feature is 0 for the first chunk and goes up then.
    # Chunks are reversed so the 0th chunk comes first, then the 1st and so on,
    # so models can attend to them in the order they arrive. The last chunk is
    # usually the one containing the end of the target sentence (EOS).
    chunk_length = hparams.get("split_targets_chunk_length", 0)
    max_chunks = hparams.get("split_targets_max_chunks", 100)
    if chunk_length > 0:

        def is_nonzero_chunk(example):
            """A chunk is zero if all targets are 0s."""
            return tf.less(0, tf.reduce_sum(tf.abs(example["targets"])))

        def split_on_length(example):
            """Split a batch of ditcs on length."""
            x = example["targets"]
            # TODO(kitaev): This code breaks if chunk_length * max_chunks < batch_size
            length_diff = chunk_length * max_chunks - tf.shape(x)[1]
            padded_x = tf.pad(x, [(0, 0), (0, length_diff), (0, 0), (0, 0)])
            chunks = [
                padded_x[:, i * chunk_length:(i + 1) * chunk_length, :, :]
                for i in range(max_chunks - 1)
            ]
            chunks.append(padded_x[:, (max_chunks - 1) * chunk_length:, :, :])
            new_example = {}
            # Setting chunk_number to be tf.range(max_chunks) is incompatible with TPU
            new_example["chunk_number"] = tf.concat([
                tf.expand_dims(tf.ones_like(c) * n, axis=0)
                for n, c in enumerate(chunks)
            ],
                                                    axis=0)
            new_example["targets"] = tf.concat(
                [tf.expand_dims(c, axis=0) for c in chunks], axis=0)
            for k in example:
                if k != "targets":
                    assert k != "chunk_number", (
                        "Chunking code expects the chunk_number feature name to be "
                        "available")
                    new_example[k] = tf.concat(
                        [tf.expand_dims(example[k], axis=0)
                         for _ in range(max_chunks)],
                        axis=0)
            return tf.data.Dataset.from_tensor_slices(new_example)

        dataset = dataset.flat_map(split_on_length)
        dataset = dataset.filter(is_nonzero_chunk)

        # The chunking data pipeline thus far creates batches of examples where all
        # of the examples have the same chunk number. This can lead to periodic
        # fluctuations in the loss; for example, when all examples in the batch have
        # chunk number 0 the loss may be higher than midway through a sequence.
        # Enabling split_targets_strided_training adjusts the data so that each
        # batch includes examples at various points within a sequence.
        if is_training and hparams.split_targets_strided_training:
            # TODO(kitaev): make sure that shape inference works on GPU, not just TPU.
            inferred_batch_size = dataset.output_shapes["targets"].as_list()[0]
            if inferred_batch_size is None:
                raise ValueError(
                    "Strided training is only implemented when the batch size can be "
                    "inferred statically, for example when training on TPU.")
            chunk_stride = inferred_batch_size * max(
                1, max_chunks // inferred_batch_size) + 1

            def collapse_nested_datasets(example):
                """Converts a dataset of datasets to a dataset of tensor features."""
                new_example = {}
                for k, v in example.items():
                    v = tf.data.experimental.get_single_element(
                        v.batch(inferred_batch_size, drop_remainder=True))
                    new_example[k] = v
                return tf.data.Dataset.from_tensor_slices(new_example)

            dataset = dataset.unbatch()
            dataset = dataset.window(inferred_batch_size, inferred_batch_size,
                                     chunk_stride)
            dataset = dataset.flat_map(collapse_nested_datasets)
            dataset = dataset.batch(inferred_batch_size, drop_remainder=True)

    def prepare_for_output(example):
        if not config or not config.use_tpu:
            _summarize_features(example, num_shards)
        if mode == tf.estimator.ModeKeys.PREDICT:
            example["infer_targets"] = example.pop("targets")
            return example
        else:
            return example, example[hparams.get(key="labels_feature_name",
                                                default="targets")]

    dataset = dataset.map(prepare_for_output, num_parallel_calls=num_threads)
    dataset = dataset.prefetch(2)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # This is because of a bug in the Estimator that short-circuits prediction
        # if it doesn't see a QueueRunner. DummyQueueRunner implements the
        # minimal expected interface but does nothing.
        tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, DummyQueueRunner())

    return dataset
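A hedged sketch of calling input_fn from an Estimator-style input function; every concrete value below (file paths, decoder, hparams fields) is an assumption for illustration:

import tensorflow.compat.v1 as tf

def train_input_fn(params, config, hparams):
    # hparams is assumed to carry min_length, batch_size, pad_batch, etc.
    dataset = tf.data.TFRecordDataset("/tmp/data/train-00000")  # hypothetical
    dataset = dataset.map(decode_example)  # e.g. the free decoder from Example #10
    return input_fn(
        dataset,
        filepattern="/tmp/data/train-*",
        skip_random_fraction_when_training=True,
        batch_size_means_tokens_param=True,
        batch_size_multiplier=1,
        max_length=256,
        mode=tf.estimator.ModeKeys.TRAIN,
        hparams=hparams,
        params=params,
        config=config)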
Example #16
    def dataset(self,
                mode,
                data_dir=None,
                num_threads=None,
                output_buffer_size=None,
                shuffle_files=None,
                hparams=None,
                preprocess=True,
                dataset_split=None,
                shard=None,
                partition_id=0,
                num_partitions=1,
                shuffle_buffer_size=1024,
                max_records=-1):
        """Build a Dataset for this problem.

    Args:
      mode: tf.estimator.ModeKeys; determines which files to read from.
      data_dir: directory that contains data files.
      num_threads: int, number of threads to use for decode and preprocess
        Dataset.map calls.
      output_buffer_size: int, how many elements to prefetch at end of pipeline.
      shuffle_files: whether to shuffle input files. Default behavior (i.e. when
        shuffle_files=None) is to shuffle if mode == TRAIN.
      hparams: HParams; hparams to be passed to
        Problem.preprocess_example and Problem.hparams. If None, will use a
        default set that is a no-op.
      preprocess: bool, whether to map the Dataset through
        Problem.preprocess_example.
      dataset_split: DatasetSplit, which split to read data
        from (TRAIN:"-train", EVAL:"-dev", "test":"-test"). Defaults to mode.
      shard: int, if provided, will only read data from the specified shard.
      partition_id: integer - which partition of the dataset to read from
      num_partitions: how many partitions in the dataset
      shuffle_buffer_size: if shuffle_files is True, this is the buffer size
        used to shuffle records.
      max_records: int, number of records to truncate to.

    Returns:
      Dataset containing dict<feature name, Tensor>.

    Raises:
      ValueError: if num_partitions is greater than the number of data files.
    """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        shuffle_files = shuffle_files or (shuffle_files is None and is_training)

        dataset_split = dataset_split or mode
        assert data_dir

        if hparams is None:
            hparams = default_model_hparams()

        if not hasattr(hparams, "data_dir"):
            hparams.add_hparam("data_dir", data_dir)
        if not hparams.data_dir:
            hparams.data_dir = data_dir
        # Construct the Problem's hparams so that items within it are accessible
        _ = self.get_hparams(hparams)

        data_filepattern = self.filepattern(data_dir,
                                            dataset_split,
                                            shard=shard)
        tf.logging.info("Reading data files from %s", data_filepattern)
        data_files = sorted(
            contrib.slim().parallel_reader.get_data_files(data_filepattern))

        # Functions used in dataset transforms below. `filenames` can be either a
        # `tf.string` tensor or `tf.data.Dataset` containing one or more filenames.
        def _load_records_and_preprocess(filenames):
            """Reads files from a string tensor or a dataset of filenames."""
            # Load records from file(s) with an 8MiB read buffer.
            dataset = tf.data.TFRecordDataset(filenames,
                                              buffer_size=8 * 1024 * 1024)
            # Decode.
            dataset = dataset.map(self.decode_example,
                                  num_parallel_calls=num_threads)
            # Preprocess if requested.
            # Note that preprocessing should happen per-file as order may matter.
            if preprocess:
                dataset = self.preprocess(dataset,
                                          mode,
                                          hparams,
                                          interleave=shuffle_files)
            return dataset

        if len(data_files) < num_partitions:
            raise ValueError(
                "number of data files (%d) must be at least the number of hosts (%d)"
                % (len(data_files), num_partitions))
        data_files = [
            f for (i, f) in enumerate(data_files)
            if i % num_partitions == partition_id
        ]
        tf.logging.info("partition: %d num_data_files: %d" %
                        (partition_id, len(data_files)))
        if shuffle_files:
            mlperf_log.transformer_print(key=mlperf_log.INPUT_ORDER)
            random.shuffle(data_files)

        dataset = tf.data.Dataset.from_tensor_slices(tf.constant(data_files))
        # Create data-set from files by parsing, pre-processing and interleaving.
        if shuffle_files:
            dataset = dataset.apply(
                tf.data.experimental.parallel_interleave(
                    _load_records_and_preprocess, sloppy=True, cycle_length=8))
        else:
            dataset = _load_records_and_preprocess(dataset)

        dataset = dataset.map(self.maybe_reverse_and_copy,
                              num_parallel_calls=num_threads)
        dataset = dataset.take(max_records)

        # Shuffle records only for training examples.
        if shuffle_files and is_training:
            dataset = dataset.shuffle(shuffle_buffer_size)
        if hparams.get("pack_dataset", False):
            dataset = generator_utils.pack_dataset(dataset,
                                                   hparams.max_length,
                                                   keys=["inputs", "targets"],
                                                   use_custom_ops=hparams.get(
                                                       "use_custom_ops",
                                                       False))
        if output_buffer_size:
            dataset = dataset.prefetch(output_buffer_size)

        return dataset
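Finally, a minimal sketch of driving this method; the registry lookup, problem name, and data directory are assumptions for illustration:

import tensorflow.compat.v1 as tf
from tensor2tensor.utils import registry  # assumed import

problem = registry.problem("translate_ende_wmt32k")  # assumed problem name
dataset = problem.dataset(
    mode=tf.estimator.ModeKeys.TRAIN,
    data_dir="/tmp/t2t_data",  # hypothetical directory
    shuffle_buffer_size=512)
iterator = tf.data.make_one_shot_iterator(dataset)
features = iterator.get_next()  # dict<feature name, Tensor>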