Example #1
  def get_dataset(self, batch_size, epochs_between_evals):
    """Construct the dataset to be used for training and eval.

    For local training, data is provided through Dataset.from_generator. For
    remote training (TPUs) the data is first serialized to files and then sent
    to the TPU through a StreamingFilesDataset.

    Args:
      batch_size: The per-replica batch size of the dataset.
      epochs_between_evals: How many epochs worth of data to yield.
        (Generator mode only.)
    """
    self.increment_request_epoch()
    if self._stream_files:
      if epochs_between_evals > 1:
        raise ValueError("epochs_between_evals > 1 not supported for file "
                         "based dataset.")
      epoch_data_dir = self._result_queue.get(timeout=300)
      if not self._is_training:
        self._result_queue.put(epoch_data_dir)  # Eval data is reused.

      file_pattern = os.path.join(
          epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
      # Note: unlike Example #2, this version does not show where
      # StreamingFilesDataset comes from; in newer TF releases it is
      # presumably imported at module level from tensorflow.python.tpu.datasets.
      from tensorflow.python.tpu.datasets import StreamingFilesDataset
      dataset = StreamingFilesDataset(
          files=file_pattern, worker_job=popen_helper.worker_job(),
          num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
          sloppy=not self._deterministic)
      map_fn = functools.partial(
          self.deserialize,
          batch_size=batch_size,
          is_training=self._is_training)
      dataset = dataset.map(map_fn, num_parallel_calls=16)

    else:
      types = {movielens.USER_COLUMN: rconst.USER_DTYPE,
               movielens.ITEM_COLUMN: rconst.ITEM_DTYPE}
      shapes = {
          movielens.USER_COLUMN: tf.TensorShape([batch_size, 1]),
          movielens.ITEM_COLUMN: tf.TensorShape([batch_size, 1])
      }

      if self._is_training:
        # np.bool was removed in NumPy 1.24; np.bool_ is the equivalent alias.
        types[rconst.VALID_POINT_MASK] = np.bool_
        shapes[rconst.VALID_POINT_MASK] = tf.TensorShape([batch_size, 1])

        # Training yields (features, labels) tuples, so the type and shape
        # specs become two-element tuples as well.
        types = (types, np.bool_)
        shapes = (shapes, tf.TensorShape([batch_size, 1]))

      else:
        types[rconst.DUPLICATE_MASK] = np.bool_
        shapes[rconst.DUPLICATE_MASK] = tf.TensorShape([batch_size, 1])

      data_generator = functools.partial(
          self.data_generator, epochs_between_evals=epochs_between_evals)
      dataset = tf.data.Dataset.from_generator(
          generator=data_generator, output_types=types,
          output_shapes=shapes)

    return dataset.prefetch(16)
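
The generator branch above wires a Python generator into tf.data via
Dataset.from_generator, with nested output_types/output_shapes that mirror the
(features, labels) structure yielded in training mode. Below is a minimal,
self-contained sketch of that pattern; the column names, dtypes, and batch
size are placeholders, not the real movielens/rconst constants.

import numpy as np
import tensorflow as tf

BATCH_SIZE = 4  # Stand-in for the per-replica batch size.

def data_generator():
  # Yields pre-batched (features, labels) tuples, matching the
  # training-mode structure used above.
  for _ in range(3):
    features = {
        "user_id": np.zeros((BATCH_SIZE, 1), dtype=np.int32),
        "item_id": np.zeros((BATCH_SIZE, 1), dtype=np.uint16),
        "valid_point_mask": np.ones((BATCH_SIZE, 1), dtype=np.bool_),
    }
    labels = np.zeros((BATCH_SIZE, 1), dtype=np.bool_)
    yield features, labels

types = ({"user_id": tf.int32,
          "item_id": tf.uint16,
          "valid_point_mask": tf.bool},
         tf.bool)
shapes = ({"user_id": tf.TensorShape([BATCH_SIZE, 1]),
           "item_id": tf.TensorShape([BATCH_SIZE, 1]),
           "valid_point_mask": tf.TensorShape([BATCH_SIZE, 1])},
          tf.TensorShape([BATCH_SIZE, 1]))

dataset = tf.data.Dataset.from_generator(
    generator=data_generator, output_types=types, output_shapes=shapes)
dataset = dataset.prefetch(16)

for features, labels in dataset:
  print(features["user_id"].shape, labels.shape)  # (4, 1) (4, 1)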
Example #2
  def get_dataset(self, batch_size, epochs_between_evals):
    """Construct the dataset to be used for training and eval.

    For local training, data is provided through Dataset.from_generator. For
    remote training (TPUs) the data is first serialized to files and then sent
    to the TPU through a StreamingFilesDataset.

    Args:
      batch_size: The per-replica batch size of the dataset.
      epochs_between_evals: How many epochs worth of data to yield.
        (Generator mode only.)
    """
    self.increment_request_epoch()
    if self._stream_files:
      if epochs_between_evals > 1:
        raise ValueError("epochs_between_evals > 1 not supported for file "
                         "based dataset.")
      epoch_data_dir = self._result_queue.get(timeout=300)
      if not self._is_training:
        self._result_queue.put(epoch_data_dir)  # Eval data is reused.

      file_pattern = os.path.join(
          epoch_data_dir, rconst.SHARD_TEMPLATE.format("*"))
      # TODO(seemuch): remove this contrib import
      # pylint: disable=line-too-long
      from tensorflow.contrib.tpu.python.tpu.datasets import StreamingFilesDataset
      # pylint: enable=line-too-long
      dataset = StreamingFilesDataset(
          files=file_pattern, worker_job=popen_helper.worker_job(),
          num_parallel_reads=rconst.NUM_FILE_SHARDS, num_epochs=1,
          sloppy=not self._deterministic)
      map_fn = functools.partial(self._deserialize, batch_size=batch_size)
      dataset = dataset.map(map_fn, num_parallel_calls=16)

    else:
      types = {movielens.USER_COLUMN: rconst.USER_DTYPE,
               movielens.ITEM_COLUMN: rconst.ITEM_DTYPE}
      shapes = {movielens.USER_COLUMN: tf.TensorShape([batch_size]),
                movielens.ITEM_COLUMN: tf.TensorShape([batch_size])}

      if self._is_training:
        # As in Example #1, np.bool_ replaces the removed np.bool alias.
        types[rconst.VALID_POINT_MASK] = np.bool_
        shapes[rconst.VALID_POINT_MASK] = tf.TensorShape([batch_size])

        types = (types, np.bool_)
        shapes = (shapes, tf.TensorShape([batch_size]))

      else:
        types[rconst.DUPLICATE_MASK] = np.bool_
        shapes[rconst.DUPLICATE_MASK] = tf.TensorShape([batch_size])

      data_generator = functools.partial(
          self.data_generator, epochs_between_evals=epochs_between_evals)
      dataset = tf.data.Dataset.from_generator(
          generator=data_generator, output_types=types,
          output_shapes=shapes)

    return dataset.prefetch(16)
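
Neither example shows the deserialization helper mapped over the
StreamingFilesDataset (Example #1 calls self.deserialize with an is_training
flag, Example #2 calls self._deserialize), and the two revisions also disagree
on output rank ([batch_size, 1] versus [batch_size]), presumably because they
come from different versions of the same codebase. As a rough sketch only, a
map function of this shape might decode one serialized, pre-batched record;
the feature names, dtypes, and encoding below are hypothetical, not the actual
implementation.

import tensorflow as tf

def deserialize(serialized_data, batch_size):
  """Hypothetical map_fn: decode one serialized, pre-batched record."""
  feature_map = {
      "user_id": tf.io.FixedLenFeature([], dtype=tf.string),
      "item_id": tf.io.FixedLenFeature([], dtype=tf.string),
  }
  features = tf.io.parse_single_example(serialized_data, feature_map)
  users = tf.reshape(
      tf.io.decode_raw(features["user_id"], tf.int32), (batch_size,))
  items = tf.reshape(
      tf.io.decode_raw(features["item_id"], tf.uint16), (batch_size,))
  return {"user_id": users, "item_id": items}

# Wired in the same way as in the examples above:
# map_fn = functools.partial(deserialize, batch_size=batch_size)
# dataset = dataset.map(map_fn, num_parallel_calls=16)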