class GenerateLoad(beam.DoFn):
  """Stateful DoFn that emits ``num_records_per_key`` payload records per key.

  Each input element only adds ``num_records_per_key`` to a per-key counter and
  arms a watermark timer. The records themselves are produced by the timer
  callback in chunks of at most ``bundle_size``; the timer re-arms itself until
  the counter reaches zero.
  """
  # Records still to be emitted for the current key. The state id says
  # "bundles", but the value counts individual records (the id is kept as-is
  # because it is part of the pipeline's state identity).
  state_spec = userstate.CombiningValueStateSpec(
      'bundles_remaining', beam.coders.VarIntCoder(), sum)
  timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, num_records_per_key, value_size, bundle_size=1000):
    """
    Args:
      num_records_per_key: number of records added to the counter for every
        input element.
      value_size: length in bytes of the random payload. The payload is
        generated once here and shared by every emitted record.
      bundle_size: maximum number of records emitted per timer firing.
    """
    self.num_records_per_key = num_records_per_key
    self.payload = os.urandom(value_size)
    self.bundle_size = bundle_size

  def process(
      self,
      element,
      records_remaining=beam.DoFn.StateParam(state_spec),
      timer=beam.DoFn.TimerParam(timer_spec)):
    # The key is deliberately NOT stored on ``self``: one DoFn instance
    # processes many keys, so an instance attribute would be overwritten by
    # whichever key came last. The timer callback gets its key via KeyParam.
    records_remaining.add(self.num_records_per_key)
    timer.set(0)

  @userstate.on_timer(timer_spec)
  def process_timer(
      self,
      records_remaining=beam.DoFn.StateParam(state_spec),
      timer=beam.DoFn.TimerParam(timer_spec),
      key=beam.DoFn.KeyParam):
    """Emits up to ``bundle_size`` (key, payload) records for the timer's key.

    Re-arms the timer while records remain for this key.
    """
    # Clamp at zero so a non-positive counter never turns into a positive
    # state update below.
    cur_bundle_size = max(0, min(self.bundle_size, records_remaining.read()))
    # One state update for the whole chunk rather than one per record.
    records_remaining.add(-cur_bundle_size)
    for _ in range(cur_bundle_size):
      yield key, self.payload
    if records_remaining.read() > 0:
      timer.set(0)
class IndexAssigningDoFn(beam.DoFn):
  """Pairs each value with the per-key 'index' state as read before the update.

  The value emitted alongside each element is the state's current reading;
  only afterwards is the state advanced by 1.
  """
  # NOTE(review): CallSequenceEnforcingCombineFn is defined elsewhere; its
  # starting value and call-order checks are not visible here — confirm.
  state_param = beam.DoFn.StateParam(
      userstate.CombiningValueStateSpec(
          'index',
          beam.coders.VarIntCoder(),
          CallSequenceEnforcingCombineFn()))

  def process(self, element, state=state_param):
    _key, payload = element
    idx = state.read()
    # Emit first, then advance the state (the generator resumes here after the
    # output is consumed).
    yield idx, payload
    state.add(1)
class StatefulCounterOperation(BaseCounterOperation):
  """Increments every configured counter, optionally tracking a state count.

  For each element it runs ``number_of_operations`` rounds. Each round
  increments every counter in ``self.counters`` and, when a state parameter
  is configured, adds 1 to the 'count' state. The element is then re-emitted
  unchanged.
  """
  # NOTE(review): there is no ``self`` inside a class body, so ``self.stateful``
  # here must resolve to a ``self`` from an enclosing function scope
  # (presumably the test method that defines this class) — confirm.
  state_param = (
      beam.DoFn.StateParam(
          userstate.CombiningValueStateSpec(
              'count',
              beam.coders.IterableCoder(beam.coders.VarIntCoder()),
              sum))
      if self.stateful
      else None)

  def process(self, element, state=state_param):
    metrics = self.counters
    for _round in range(self.number_of_operations):
      for metric in metrics:
        metric.inc()
      if state:
        state.add(1)
    yield element
def test_pardo_state_only(self):
  """Checks that per-key combining state accumulates across a key's elements."""
  index_state_spec = userstate.CombiningValueStateSpec(
      'index', beam.coders.VarIntCoder(), sum)

  # TODO(ccy): State isn't detected with Map/FlatMap.
  class AddIndex(beam.DoFn):
    """Appends the running per-key element count (after this element)."""

    def process(self, kv, index=beam.DoFn.StateParam(index_state_spec)):
      key, value = kv
      index.add(1)
      yield key, value, index.read()

  inputs = [('A', 'a'), ('A', 'a'), ('B', 'b'), ('B', 'b'), ('B', 'b')]
  expected = ([('A', 'a', i) for i in (1, 2)] +
              [('B', 'b', i) for i in (1, 2, 3)])

  with self.create_pipeline() as p:
    indexed = p | beam.Create(inputs) | beam.ParDo(AddIndex())
    assert_that(indexed, equal_to(expected))
class CounterOperation(beam.DoFn):
  """Increments a fixed set of metric counters per element, then re-emits it.

  Args:
    number_of_counters: how many counters named 'name-<i>' are created in the
      'do-not-publish' namespace.
    number_of_operations: how many rounds of increments run per element.
  """

  def __init__(self, number_of_counters, number_of_operations):
    self.number_of_operations = number_of_operations
    self.counters = [
        Metrics.counter('do-not-publish', 'name-{}'.format(idx))
        for idx in range(number_of_counters)
    ]

  # NOTE(review): there is no ``self`` inside a class body, so ``self.stateful``
  # here must resolve to a ``self`` from an enclosing function scope
  # (presumably the test method that defines this class) — confirm.
  state_param = (
      beam.DoFn.StateParam(
          userstate.CombiningValueStateSpec(
              'count',
              beam.coders.IterableCoder(beam.coders.VarIntCoder()),
              sum))
      if self.stateful
      else None)

  def process(self, element, state=state_param):
    for _round in range(self.number_of_operations):
      for metric in self.counters:
        metric.inc()
      if state:
        state.add(1)
    yield element
def test_pardo_state_with_custom_key_coder(self):
  """Checks state requests when the key coder is SDK-specific (non-standard).

  The same property is also enforced by Java's ProcessBundleDescriptorsTest
  and by Flink's ExecutableStageDoFnOperator. The latter detects an invalid
  key encoding by verifying that the encoded key maps to the correct key
  group.
  """
  index_state_spec = userstate.CombiningValueStateSpec('index', sum)

  # Test params: enough elements that every partition receives some work.
  n = 200
  duplicates = 1
  split = n // (duplicates + 1)
  inputs = [(j % split, str(j % split)) for j in range(n)]

  # The output type of this DoFn cannot be inferred, so the
  # FastPrimitivesCoder has to be used for it.
  class Input(beam.DoFn):
    def process(self, impulse):
      for pair in inputs:
        yield pair

  class AddIndex(beam.DoFn):
    def process(self, kv, index=beam.DoFn.StateParam(index_state_spec)):
      key, value = kv
      index.add(1)
      yield key, value, index.read()

  # Element ``pos`` is the (pos // split + 1)-th occurrence of its key.
  expected = [(key, value, pos // split + 1)
              for pos, (key, value) in enumerate(inputs)]

  with self.create_pipeline() as p:
    indexed = (
        p
        | beam.Impulse()
        | beam.ParDo(Input())
        | beam.ParDo(AddIndex()))
    assert_that(indexed, equal_to(expected))
class StatefulBufferingFn(DoFn):
  """Labels each value with the position it was assigned when first buffered.

  For every (key, value) element, emits ('<value>_<n>', 1). If ``value`` is
  already in the buffer state, n is its position in the buffer as read.
  Otherwise the value is appended to the buffer and n is the current
  distinct-value count, which is then incremented.
  """
  # Distinct values seen so far.
  # NOTE(review): Beam bag state makes no ordering guarantee, so the index
  # found for a repeated value only matches its first-seen index if the runner
  # returns bag entries in insertion order — confirm for target runners.
  BUFFER_STATE = BagStateSpec('buffer', StrUtf8Coder())
  # Number of distinct values appended to BUFFER_STATE, i.e. the index the
  # next new value receives.
  COUNT_STATE = userstate.CombiningValueStateSpec(
      'count', VarIntCoder(), CountCombineFn())

  def process(self, element,
              buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
              count_state=beam.DoFn.StateParam(COUNT_STATE)):
    _, value = element
    # Read state outside any exception handler. Previously a bare ``except:``
    # also swallowed state-access failures and treated them as "not buffered
    # yet", which re-added the value and bumped the count.
    seen_values = list(buffer_state.read())
    if value in seen_values:
      index_value = seen_values.index(value)
    else:
      buffer_state.add(value)
      index_value = count_state.read()
      count_state.add(1)
    yield ('{}_{}'.format(value, index_value), 1)
def test_pardo_state_only(self):
  """Checks per-key combining state; runs only on the FnApiRunner."""
  p = self.create_pipeline()
  runs_on_fn_api = isinstance(p.runner, fn_api_runner.FnApiRunner)
  if not runs_on_fn_api:
    # test is inherited by Flink PVR, which does not support the feature yet
    self.skipTest('User state not supported.')

  index_state_spec = userstate.CombiningValueStateSpec(
      'index', beam.coders.VarIntCoder(), sum)

  # TODO(ccy): State isn't detected with Map/FlatMap.
  class AddIndex(beam.DoFn):
    """Appends the running per-key element count (after this element)."""

    def process(self, kv, index=beam.DoFn.StateParam(index_state_spec)):
      key, value = kv
      index.add(1)
      yield key, value, index.read()

  inputs = [('A', 'a'), ('A', 'a'), ('B', 'b'), ('B', 'b'), ('B', 'b')]
  expected = ([('A', 'a', i) for i in (1, 2)] +
              [('B', 'b', i) for i in (1, 2, 3)])

  with p:
    indexed = p | beam.Create(inputs) | beam.ParDo(AddIndex())
    assert_that(indexed, equal_to(expected))