예제 #1
0
  def __next__(self):
    """Dequeue the next item and run it through the split's placement helper.

    `self.queue.get()` is called without a timeout. Assuming `self.queue` is
    a `queue.Queue` (TODO confirm), this waits until an item is available.

    Returns:
      batch: data_types.Batch. Items for the "train" split are passed through
        `data_utils.shard`; items for any other split are passed through
        `data_utils.to_device`.
    """
    item = self.queue.get()
    # Pick the helper by split: sharding is reserved for training batches.
    place_fn = (
        data_utils.shard if self.split == "train" else data_utils.to_device)
    return place_fn(item)
예제 #2
0
  def peek(self):
    """Peek at the next training batch or test example without dequeuing it.

    Waits until the queue holds at least one item, then returns a processed
    copy of the front item. The item itself stays in the queue, so the next
    `__next__` call still receives it.

    Returns:
      batch: data_types.Batch.
    """
    # NOTE(review): assumes `self.queue` is a `queue.Queue` (the original
    # already relied on its internal `.queue` deque); `not_empty` is the
    # Condition that `Queue.put` notifies and that shares the queue's mutex.
    # Block on that condition instead of spinning on `empty()`, and read the
    # front item while holding the lock so a concurrent `get()` cannot remove
    # it between the emptiness check and the read.
    not_empty = self.queue.not_empty
    with not_empty:
      # Check the deque directly: calling `self.queue.empty()` here would try
      # to re-acquire the same (non-reentrant) mutex and deadlock.
      while not self.queue.queue:
        not_empty.wait()
      # Copy every leaf so the peeked batch does not alias the queued item.
      front = jax.tree_util.tree_map(lambda leaf: leaf.copy(),
                                     self.queue.queue[0])
      # We did not consume the item, so pass the wakeup on to any thread
      # blocked in `get()` that might otherwise miss it.
      not_empty.notify()
    if self.split == "train":
      return data_utils.shard(front)
    else:
      return data_utils.to_device(front)