def _process_partition(
    partition: Tuple[Tuple[types.SliceKey, int], List[pa.Table]],
    stats_fn: PartitionedStatsFn
) -> Tuple[types.SliceKey, statistics_pb2.DatasetFeatureStatistics]:
  """Process batches in a single partition."""
  (slice_key, _), tables = partition
  return slice_key, stats_fn.compute(table_util.MergeTables(tables))
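# A minimal sketch, not part of the library: for the simple case where every
# input table has the same column names, table_util.MergeTables behaves like
# concatenating the tables row-wise and compacting each column into a single
# chunk. The real MergeTables also null-fills columns that are missing from
# some inputs, which this plain-pyarrow approximation omits.
def _merge_tables_sketch(tables: List[pa.Table]) -> pa.Table:
  # Row-wise concatenation, then fold each column's chunks into one chunk.
  return pa.concat_tables(tables).combine_chunks()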
def _maybe_do_batch(
    self,
    accumulator: _CombinerStatsGeneratorsCombineFnAcc,
    force: bool = False) -> None:
  """Maybe updates accumulator in place.

  Checks if accumulator has enough examples for a batch, and if so, does the
  stats computation for the batch and updates accumulator in place.

  Args:
    accumulator: Accumulator. Will be updated in place.
    force: Force computation of stats even if accumulator has fewer examples
      than the batch size.
  """
  batch_size = accumulator.curr_batch_size
  if (force and batch_size > 0) or batch_size >= self._desired_batch_size:
    self._combine_batch_size.update(batch_size)
    if len(accumulator.input_tables) == 1:
      arrow_table = accumulator.input_tables[0]
    else:
      arrow_table = table_util.MergeTables(accumulator.input_tables)
    accumulator.partial_accumulators = self._for_each_generator(
        lambda gen, gen_acc: gen.add_input(gen_acc, arrow_table),
        accumulator.partial_accumulators)
    del accumulator.input_tables[:]
    accumulator.curr_batch_size = 0
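# A self-contained toy (hypothetical helper, not library code) of the flush
# rule used above: flush when forced and the buffer is non-empty, or when the
# buffered row count has reached the desired batch size.
def _should_flush(curr_batch_size: int, desired_batch_size: int,
                  force: bool = False) -> bool:
  return (force and curr_batch_size > 0) or curr_batch_size >= desired_batch_size

assert not _should_flush(2, 3)               # under the batch size: keep buffering
assert _should_flush(4, 3)                   # at/over the batch size: flush
assert _should_flush(1, 3, force=True)       # forced flush of a partial batch
assert not _should_flush(0, 3, force=True)   # empty buffer: nothing to flush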
def test_merge_tables(self, inputs, expected_output):
  input_tables = [
      pa.Table.from_arrays(list(in_dict.values()), list(in_dict.keys()))
      for in_dict in inputs
  ]
  merged = table_util.MergeTables(input_tables)
  self.assertLen(expected_output, merged.num_columns)
  for column_name in merged.schema.names:
    column = merged.column(column_name)
    self.assertEqual(column.num_chunks, 1)
    try:
      self.assertTrue(
          expected_output[column_name].equals(column.chunk(0)))
    except AssertionError:
      self.fail(msg="Column {}:\nexpected:{}\ngot: {}".format(
          column_name, expected_output[column_name], column))
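# One hypothetical (inputs, expected_output) parameterization for the test
# above, with invented values: `inputs` holds one column-name -> array dict
# per input table, and `expected_output` maps each column of the merged table
# to the single-chunk array expected after MergeTables concatenates the
# inputs row-wise.
_EXAMPLE_MERGE_CASE = dict(
    inputs=[{"f1": pa.array([[1], [2]])}, {"f1": pa.array([[3]])}],
    expected_output={"f1": pa.array([[1], [2], [3]])})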
def test_invalid_inputs(self, inputs, expected_error_regexp):
  with self.assertRaisesRegexp(Exception, expected_error_regexp):
    _ = table_util.MergeTables(inputs)