Example #1
0
 def testBatch(self):
   """Re-batching a NumpySource from size 7 to size 13 yields contiguous rows."""
   in_size = 7
   out_size = 13
   num_steps = 50
   # Source over the integers [1000, 2000), initially batched 7 at a time.
   source_cols = in_memory_source.NumpySource(
       np.arange(1000, 2000), batch_size=in_size)()
   batcher = batch.Batch(
       batch_size=out_size, output_names=["index", "value"])
   rebatched = batcher([source_cols.index, source_cols.value])
   cache = {}
   index_op = rebatched.index.build(cache)
   value_op = rebatched.value.build(cache)
   with self.test_session() as sess:
     coord = coordinator.Coordinator()
     threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
     for step in range(num_steps):
       # Rows must stay in order: batch `step` covers 13 consecutive indices,
       # and values remain offset from indices by exactly 1000.
       want_index = range(step * out_size, (step + 1) * out_size)
       want_value = range(1000 + step * out_size,
                          1000 + (step + 1) * out_size)
       got_index, got_value = sess.run([index_op, value_op])
       np.testing.assert_array_equal(want_index, got_index)
       np.testing.assert_array_equal(want_value, got_value)
     coord.request_stop()
     coord.join(threads)
Example #2
0
  def batch(self,
            batch_size,
            shuffle=False,
            num_threads=1,
            queue_capacity=None,
            min_after_dequeue=None,
            seed=None):
    """Resize the batches in the `DataFrame` to the given `batch_size`.

    Args:
      batch_size: desired batch size.
      shuffle: whether records should be shuffled. Defaults to `False`.
      num_threads: the number of enqueueing threads.
      queue_capacity: capacity of the queue that will hold new batches.
      min_after_dequeue: minimum number of elements that can be left by a
        dequeue operation. Only used if `shuffle` is true.
      seed: passed to random shuffle operations. Only used if `shuffle` is true.

    Returns:
      A `DataFrame` with `batch_size` rows.
    """
    column_names = list(self._columns.keys())
    # ShuffleBatch randomizes row order (and so takes min_after_dequeue/seed);
    # Batch preserves the incoming order.
    if shuffle:
      batcher = batch.ShuffleBatch(batch_size,
                                   output_names=column_names,
                                   num_threads=num_threads,
                                   queue_capacity=queue_capacity,
                                   min_after_dequeue=min_after_dequeue,
                                   seed=seed)
    else:
      batcher = batch.Batch(batch_size,
                            output_names=column_names,
                            num_threads=num_threads,
                            queue_capacity=queue_capacity)

    # Re-wrap the batched series in a fresh DataFrame of the same concrete
    # type, keyed by the original column names.
    batched_series = batcher(list(self._columns.values()))
    dataframe = type(self)()
    dataframe.assign(**(dict(zip(column_names, batched_series))))
    return dataframe