Example #1
  def _maybe_do_batch(
      self,
      accumulator: _CombinerStatsGeneratorsCombineFnAcc,
      force: bool = False) -> None:
    """Maybe updates accumulator in place.

    Checks if accumulator has enough examples for a batch, and if so, does the
    stats computation for the batch and updates accumulator in place.

    Args:
      accumulator: Accumulator. Will be updated in place.
      force: Force computation of stats even if the accumulator has fewer
        examples than the batch size.
    """
    if self._should_do_batch(accumulator, force):
      self._combine_batch_size.update(accumulator.curr_batch_size)
      self._combine_byte_size.update(accumulator.curr_byte_size)
      if len(accumulator.input_record_batches) == 1:
        record_batch = accumulator.input_record_batches[0]
      else:
        record_batch = table_util.MergeRecordBatches(
            accumulator.input_record_batches)
      accumulator.partial_accumulators = self._for_each_generator(
          lambda gen, gen_acc: gen.add_input(gen_acc, record_batch),
          accumulator.partial_accumulators)
      del accumulator.input_record_batches[:]
      accumulator.curr_batch_size = 0
      accumulator.curr_byte_size = 0
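
A minimal, self-contained sketch of the same accumulate-then-merge pattern is
shown below. The _RecordBatchBuffer class and its threshold are hypothetical
illustrations, not part of the code above; only pyarrow and
tfx_bsl.arrow.table_util.MergeRecordBatches are assumed.

import pyarrow as pa
from tfx_bsl.arrow import table_util


class _RecordBatchBuffer(object):
  """Hypothetical buffer mirroring the accumulate-then-merge pattern."""

  def __init__(self, batch_size=1000):
    self._batch_size = batch_size
    self._record_batches = []
    self._curr_num_rows = 0

  def add(self, record_batch):
    self._record_batches.append(record_batch)
    self._curr_num_rows += record_batch.num_rows

  def maybe_merge(self, force=False):
    """Returns one merged record batch once enough rows are buffered."""
    if not self._record_batches:
      return None
    if not force and self._curr_num_rows < self._batch_size:
      return None
    if len(self._record_batches) == 1:
      merged = self._record_batches[0]
    else:
      merged = table_util.MergeRecordBatches(self._record_batches)
    del self._record_batches[:]
    self._curr_num_rows = 0
    return merged


buf = _RecordBatchBuffer(batch_size=4)
buf.add(pa.RecordBatch.from_arrays([pa.array([[1], [2]])], ["f1"]))
buf.add(pa.RecordBatch.from_arrays([pa.array([[3], [4]])], ["f1"]))
merged = buf.maybe_merge()  # 4 buffered rows >= batch_size, so a merge happens.
assert merged is not None and merged.num_rows == 4
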
Example #2
def _process_partition(
    partition: Tuple[Tuple[types.SliceKey, int],
                     List[pa.RecordBatch]], stats_fn: PartitionedStatsFn
) -> Tuple[types.SliceKey, statistics_pb2.DatasetFeatureStatistics]:
    """Process batches in a single partition."""
    (slice_key, _), record_batches = partition
    return slice_key, stats_fn.compute(
        table_util.MergeRecordBatches(record_batches))
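
To make the shape of the partition argument concrete, here is a small hedged
sketch. _FakeStatsFn is a hypothetical stand-in for a PartitionedStatsFn (only
the compute() method used above is provided), and the slice key and feature
values are made up.

import pyarrow as pa
from tensorflow_metadata.proto.v0 import statistics_pb2


class _FakeStatsFn(object):
  """Hypothetical stand-in that returns empty statistics."""

  def compute(self, record_batch):
    del record_batch  # A real stats_fn would derive statistics from this.
    return statistics_pb2.DatasetFeatureStatistics()


rb = pa.RecordBatch.from_arrays([pa.array([[1], [2]])], ["f1"])
partition = (("my_slice", 0), [rb, rb])  # ((slice_key, partition_id), batches)
slice_key, stats = _process_partition(partition, _FakeStatsFn())
assert slice_key == "my_slice"
assert isinstance(stats, statistics_pb2.DatasetFeatureStatistics)
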
Example #3
  def test_merge_0_column_record_batches(self):
    # Build 3-row record batches that have zero columns by creating a
    # one-column table and then dropping that column.
    record_batches = ([
        pa.table([pa.array([1, 2, 3])],
                 ["ignore"]).remove_column(0).to_batches(max_chunksize=None)[0]
    ] * 3)
    merged = table_util.MergeRecordBatches(record_batches)
    self.assertEqual(merged.num_rows, 9)
    self.assertEqual(merged.num_columns, 0)
Example #4
  def test_merge_record_batches(self, inputs, expected_output):
    input_record_batches = [
        pa.RecordBatch.from_arrays(list(in_dict.values()), list(in_dict.keys()))
        for in_dict in inputs
    ]
    merged = table_util.MergeRecordBatches(input_record_batches)

    self.assertLen(expected_output, merged.num_columns)
    for column, column_name in zip(merged.columns, merged.schema.names):
      self.assertTrue(
          expected_output[column_name].equals(column),
          "Column {}:\nexpected:{}\ngot: {}".format(
              column_name, expected_output[column_name], column))
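
For reference, a minimal end-to-end call to MergeRecordBatches might look like
the sketch below; the column names and values are invented for illustration,
and the expectation that rows are concatenated follows from the tests above.

import pyarrow as pa
from tfx_bsl.arrow import table_util

rb1 = pa.RecordBatch.from_arrays(
    [pa.array([[1], [2]]), pa.array([["a"], ["b"]])], ["f1", "f2"])
rb2 = pa.RecordBatch.from_arrays(
    [pa.array([[3]]), pa.array([["c"]])], ["f1", "f2"])

merged = table_util.MergeRecordBatches([rb1, rb2])
assert merged.num_rows == 3
assert merged.schema.names == ["f1", "f2"]
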
Example #5
    def _compact_impl(
        self, accumulator: _SampleRecordBatchRowsAccumulator
    ) -> _SampleRecordBatchRowsAccumulator:
        """Compacts the accumulator.

        This compact selects sample rows from the accumulated record batches
        and merges them into one record batch. The cache of all record batches
        seen so far can then be cleared. If the accumulator holds too few
        record batches, nothing is compacted.

        The sampling is done by assigning each row a random number and keeping
        the k rows with the smallest of those numbers, which yields a uniform
        sample of size k.

        Args:
          accumulator: The _SampleRecordBatchRowsAccumulator to compact.

        Returns:
          A _SampleRecordBatchRowsAccumulator that contains a single merged
          record batch.
        """
        self._combine_num_record_batches.update(len(
            accumulator.record_batches))

        # There is nothing to compact.
        if accumulator.curr_num_rows <= 1:
            return accumulator

        # There is no need to compact yet.
        if (len(accumulator.record_batches) <= 1
                and accumulator.curr_num_rows <= self._sample_size):
            return accumulator
        self._num_compacts.inc(1)
        k = min(self._sample_size, accumulator.curr_num_rows)

        rand_ints = np.concatenate(accumulator.random_ints)

        # Find the value that is the breakpoint for the top-k.
        kth_value = np.partition(rand_ints, k - 1)[k - 1]

        # This mask will always have at least one True.
        equals_to_kth = (rand_ints == kth_value)

        # This mask will always have fewer than k Trues.
        less_than_kth = rand_ints < kth_value

        # Since there may be duplicate values, `equals_to_kth + less_than_kth` might
        # be greater than `k`. We need to keep track of how many to add, without
        # surpassing `k`.
        kth_to_add = k - np.sum(less_than_kth)

        # Preserve the random integers that we had assigned to each row.
        sample_random_ints = rand_ints[rand_ints <= kth_value][:k]

        beg = 0
        sample_indices = []
        for rb in accumulator.record_batches:
            size = rb.num_rows
            end = beg + size
            less_than_kth_indices = np.nonzero(less_than_kth[beg:end])[0]
            indices = less_than_kth_indices

            # Add indices of any duplicate values that are equal to `kth_value`.
            if kth_to_add > 0:
                equals_to_kth_indices = np.nonzero(equals_to_kth[beg:end])[0]
                if equals_to_kth_indices.size > 0:
                    if equals_to_kth_indices.size >= kth_to_add:
                        indices = np.concatenate([
                            less_than_kth_indices,
                            equals_to_kth_indices[:kth_to_add]
                        ])
                        kth_to_add = 0
                    else:
                        indices = np.concatenate(
                            [less_than_kth_indices, equals_to_kth_indices])
                        kth_to_add -= equals_to_kth_indices.size

            sample_indices.append(indices)
            beg += size

        result = _SampleRecordBatchRowsAccumulator()

        # Take and merge the record batches, based on the sampled indices.
        rbs = []
        for rb, indices in zip(accumulator.record_batches, sample_indices):
            rbs.append(table_util.RecordBatchTake(rb, pa.array(indices)))
        compressed_rb = table_util.MergeRecordBatches(rbs)
        result.record_batches = [compressed_rb]
        result.curr_num_rows = compressed_rb.num_rows
        result.curr_byte_size = compressed_rb.nbytes
        result.random_ints = [sample_random_ints]

        self._combine_byte_size.update(result.curr_byte_size)

        return result
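
The row-sampling step in the docstring above (assign each row a random number,
then keep the k rows with the smallest values while capping ties so the sample
never exceeds k) can be seen in isolation in the numpy sketch below; the array
values are arbitrary and only numpy is assumed.

import numpy as np

rand_ints = np.array([7, 3, 9, 3, 5, 1])  # One random integer per row.
k = 3

# Value in the k-th smallest position; rows strictly below it are always kept.
kth_value = np.partition(rand_ints, k - 1)[k - 1]
less_than_kth = rand_ints < kth_value
equals_to_kth = rand_ints == kth_value

# Ties at kth_value could push the total past k, so cap how many are added.
kth_to_add = k - np.sum(less_than_kth)
keep = np.concatenate(
    [np.nonzero(less_than_kth)[0],
     np.nonzero(equals_to_kth)[0][:kth_to_add]])

print(np.sort(keep))  # [1 3 5] -> indices of the k sampled rows.
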
Example #6
  def test_invalid_inputs(self, inputs, expected_error_regexp):
    with self.assertRaisesRegex(Exception, expected_error_regexp):
      _ = table_util.MergeRecordBatches(inputs)