def __getitem__(self, window):
    target_window = self._side_input_data.window_mapping_fn(window)
    if target_window not in self._cache:
      state_handler = self._state_handler
      access_pattern = self._side_input_data.access_pattern
      if access_pattern == common_urns.side_inputs.ITERABLE.urn:
        state_key = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id=self._transform_id,
                side_input_id=self._tag,
                window=self._target_window_coder.encode(target_window)))
        raw_view = _StateBackedIterable(
            state_handler, state_key, self._element_coder)
      elif access_pattern == common_urns.side_inputs.MULTIMAP.urn:
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id=self._transform_id,
                side_input_id=self._tag,
                window=self._target_window_coder.encode(target_window),
                key=b''))
        cache = {}
        key_coder_impl = self._element_coder.key_coder().get_impl()
        value_coder = self._element_coder.value_coder()

        class MultiMap(object):
          def __getitem__(self, key):
            if key not in cache:
              keyed_state_key = beam_fn_api_pb2.StateKey()
              keyed_state_key.CopyFrom(state_key)
              keyed_state_key.multimap_side_input.key = (
                  key_coder_impl.encode_nested(key))
              cache[key] = _StateBackedIterable(
                  state_handler, keyed_state_key, value_coder)
            return cache[key]

          def __reduce__(self):
            # TODO(robertwb): Figure out how to support this.
            raise TypeError(common_urns.side_inputs.MULTIMAP.urn)

        raw_view = MultiMap()
      else:
        raise ValueError("Unknown access pattern: '%s'" % access_pattern)
      self._cache[target_window] = self._side_input_data.view_fn(raw_view)
    return self._cache[target_window]
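# Example (not part of the original source): a minimal sketch of the two
# StateKey shapes built above, with hypothetical ids ('pardo-1', 'side0') and
# placeholder window bytes; a real key encodes the window with the target
# window coder.
from apache_beam.portability.api import beam_fn_api_pb2

iterable_key = beam_fn_api_pb2.StateKey(
    iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
        transform_id='pardo-1', side_input_id='side0', window=b'<win>'))

multimap_key = beam_fn_api_pb2.StateKey(
    multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
        transform_id='pardo-1', side_input_id='side0', window=b'<win>',
        key=b''))  # the per-lookup key is filled in by MultiMap.__getitem__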
def _create_internal_map_state(
        self, name, encoded_namespace, map_key_coder, map_value_coder,
        ttl_config, cache_type):
    # Currently `beam_fn_api.proto` does not support MapState, so we use the
    # `MultimapSideInput` message to mark the state as a MapState for now.
    from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor
    state_proto = StateDescriptor()
    state_proto.state_name = name
    if ttl_config is not None:
        state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto())
    state_key = beam_fn_api_pb2.StateKey(
        multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id="",
            window=encoded_namespace,
            side_input_id=base64.b64encode(state_proto.SerializeToString()),
            key=self._encoded_current_key))
    if cache_type == SynchronousKvRuntimeState.CacheType.DISABLE_CACHE:
        write_cache_size = 0
    else:
        write_cache_size = self._map_state_write_cache_size
    return InternalSynchronousMapRuntimeState(
        self._map_state_handler, state_key, map_key_coder, map_value_coder,
        write_cache_size)
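# Example (not part of the original source): because the StateDescriptor
# travels base64-encoded inside side_input_id, the receiving side can recover
# it with the inverse round-trip. A minimal sketch assuming a `state_key`
# built as above.
import base64

from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor

descriptor = StateDescriptor()
descriptor.ParseFromString(
    base64.b64decode(state_key.multimap_side_input.side_input_id))
print(descriptor.state_name)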
def __getitem__(self, window):
    target_window = self._side_input_data.window_mapping_fn(window)
    if target_window not in self._cache:
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=self._transform_id,
              side_input_id=self._tag,
              window=self._target_window_coder.encode(target_window),
              key=''))
      state_handler = self._state_handler
      access_pattern = self._side_input_data.access_pattern

      class AllElements(object):
        def __init__(self, state_key, coder):
          self._state_key = state_key
          self._coder_impl = coder.get_impl()

        def __iter__(self):
          # TODO(robertwb): Support pagination.
          input_stream = coder_impl.create_InputStream(
              state_handler.blocking_get(self._state_key, None))
          while input_stream.size() > 0:
            yield self._coder_impl.decode_from_stream(input_stream, True)

        def __reduce__(self):
          return list, (list(self), )

      if access_pattern == common_urns.ITERABLE_SIDE_INPUT:
        raw_view = AllElements(state_key, self._element_coder)
      elif access_pattern == common_urns.MULTIMAP_SIDE_INPUT:
        cache = {}
        key_coder_impl = self._element_coder.key_coder().get_impl()
        value_coder = self._element_coder.value_coder()

        class MultiMap(object):
          def __getitem__(self, key):
            if key not in cache:
              keyed_state_key = beam_fn_api_pb2.StateKey()
              keyed_state_key.CopyFrom(state_key)
              keyed_state_key.multimap_side_input.key = (
                  key_coder_impl.encode_nested(key))
              cache[key] = AllElements(keyed_state_key, value_coder)
            return cache[key]

          def __reduce__(self):
            # TODO(robertwb): Figure out how to support this.
            raise TypeError(common_urns.MULTIMAP_SIDE_INPUT)

        raw_view = MultiMap()
      else:
        raise ValueError("Unknown access pattern: '%s'" % access_pattern)
      self._cache[target_window] = self._side_input_data.view_fn(raw_view)
    return self._cache[target_window]
def __getitem__(self, window):
    target_window = self._side_input_data.window_mapping_fn(window)
    if target_window not in self._cache:
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=self._transform_id,
              side_input_id=self._tag,
              window=self._target_window_coder.encode(target_window)))
      element_coder_impl = self._element_coder.get_impl()
      state_handler = self._state_handler

      class AllElements(object):
        def __iter__(self):
          # TODO(robertwb): Support pagination.
          input_stream = coder_impl.create_InputStream(
              state_handler.blocking_get(state_key, None))
          while input_stream.size() > 0:
            yield element_coder_impl.decode_from_stream(input_stream, True)

        def __reduce__(self):
          return list, (list(self), )

      self._cache[target_window] = self._side_input_data.view_fn(
          AllElements())
    return self._cache[target_window]
def test_extend_fetches_initial_state(self):
    coder = VarIntCoder()
    coder_impl = coder.get_impl()

    class UnderlyingStateHandler(object):
      """A fake state handler that stores the raw encoded bytes for one key."""
      def set_value(self, value):
        self._encoded_values = coder.encode(value)

      def get_raw(self, *args):
        return self._encoded_values, None

      def append_raw(self, _key, data):
        self._encoded_values += data

      def clear(self, *args):
        self._encoded_values = bytes()

      @contextlib.contextmanager
      def process_instruction_id(self, bundle_id):
        yield

    underlying_state_handler = UnderlyingStateHandler()
    state_cache = statecache.StateCache(100)
    handler = sdk_worker.CachingStateHandler(
        state_cache, underlying_state_handler)
    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))
    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())

    def get():
      return list(handler.blocking_get(state, coder_impl, True))

    def append(value):
      handler.extend(state, coder_impl, [value], True)

    def clear():
      handler.clear(state, True)

    # Initialize state
    underlying_state_handler.set_value(42)
    with handler.process_instruction_id('bundle', [cache_token]):
      # Append without reading beforehand
      append(43)
      self.assertEqual(get(), [42, 43])
      clear()
      self.assertEqual(get(), [])
      append(44)
      self.assertEqual(get(), [44])
def iterable_state_write(values, element_coder_impl):
    token = unique_name(None, 'iter').encode('ascii')
    out = create_OutputStream()
    for element in values:
      element_coder_impl.encode_to_stream(element, out, True)
    controller.state.blocking_append(
        beam_fn_api_pb2.StateKey(
            runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
        out.get())
    return token
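# Example (not part of the original source): the returned token is the handle
# a consumer later uses to read the iterable back. A read-path sketch assuming
# the same `controller` and `element_coder_impl` are in scope; it mirrors the
# blocking_get reads elsewhere in this section, whose exact signature may
# differ between SDK versions.
from apache_beam.coders.coder_impl import create_InputStream

state_key = beam_fn_api_pb2.StateKey(
    runner=beam_fn_api_pb2.StateKey.Runner(key=token))
input_stream = create_InputStream(
    controller.state.blocking_get(state_key, None))
elements = []
while input_stream.size() > 0:
  elements.append(element_coder_impl.decode_from_stream(input_stream, True))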
def _create_state(
    self, state_spec: userstate.StateSpec
) -> userstate.AccumulatingRuntimeState:
  if isinstance(state_spec, userstate.BagStateSpec):
    bag_state = SynchronousBagRuntimeState(
        self._state_handler,
        state_key=beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                transform_id="",
                user_state_id=state_spec.name,
                key=self._encoded_current_key)),
        value_coder=state_spec.coder)
    return bag_state
  else:
    raise NotImplementedError(state_spec)
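# Example (not part of the original source): a hypothetical caller-side sketch
# inside the same class, assuming the usual BagRuntimeState interface
# (read/add/clear); the spec name and coder are made up.
from apache_beam.coders import VarIntCoder
from apache_beam.transforms import userstate

bag = self._create_state(userstate.BagStateSpec('counters', VarIntCoder()))
bag.add(1)
bag.add(2)
assert list(bag.read()) == [1, 2]
bag.clear()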
def _store_side_inputs_in_state(
    self,
    runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
    data_side_input,  # type: DataSideInput
):
  # type: (...) -> None
  for (transform_id, tag), (buffer_id, si) in data_side_input.items():
    _, pcoll_id = split_buffer_id(buffer_id)
    value_coder = runner_execution_context.pipeline_context.coders[
        runner_execution_context.safe_coders[
            runner_execution_context.data_channel_coders[pcoll_id]]]
    elements_by_window = WindowGroupingBuffer(si, value_coder)
    if buffer_id not in runner_execution_context.pcoll_buffers:
      runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer(
          coder_impl=value_coder.get_impl())
    for element_data in runner_execution_context.pcoll_buffers[buffer_id]:
      elements_by_window.append(element_data)

    if si.urn == common_urns.side_inputs.ITERABLE.urn:
      for _, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id=transform_id,
                side_input_id=tag,
                window=window))
        (
            runner_execution_context.worker_handler_manager.state_servicer.
            append_raw(state_key, elements_data))
    elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
      for key, window, elements_data in elements_by_window.encoded_items():
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id=transform_id,
                side_input_id=tag,
                window=window,
                key=key))
        (
            runner_execution_context.worker_handler_manager.state_servicer.
            append_raw(state_key, elements_data))
    else:
      raise ValueError("Unknown access pattern: '%s'" % si.urn)
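# Example (not part of the original source): the two access-pattern URNs
# dispatched on above are the standard Beam side-input URNs.
from apache_beam.portability import common_urns

assert common_urns.side_inputs.ITERABLE.urn == 'beam:side_input:iterable:v1'
assert common_urns.side_inputs.MULTIMAP.urn == 'beam:side_input:multimap:v1'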
def test_caching(self):
    coder = VarIntCoder()
    coder_impl = coder.get_impl()

    class FakeUnderlyingState(object):
      """Simply returns an incremented counter as the state "value"."""
      def set_counter(self, n):
        self._counter = n

      def get_raw(self, *args):
        self._counter += 1
        return coder.encode(self._counter), None

      @contextlib.contextmanager
      def process_instruction_id(self, bundle_id):
        yield

    underlying_state = FakeUnderlyingState()
    state_cache = statecache.StateCache(100)
    caching_state_handler = sdk_worker.CachingStateHandler(
        state_cache, underlying_state)

    state1 = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))
    state2 = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state2'))
    side1 = beam_fn_api_pb2.StateKey(
        multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id='transform', side_input_id='side1'))
    side2 = beam_fn_api_pb2.StateKey(
        iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
            transform_id='transform', side_input_id='side2'))

    state_token1 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())
    state_token2 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token2',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState())
    side1_token1 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'side1_token1',
        side_input=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.SideInput(
            transform_id='transform', side_input_id='side1'))
    side1_token2 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'side1_token2',
        side_input=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.SideInput(
            transform_id='transform', side_input_id='side1'))

    def get_as_list(key):
      return list(caching_state_handler.blocking_get(key, coder_impl))

    underlying_state.set_counter(100)
    with caching_state_handler.process_instruction_id('bundle1', []):
      self.assertEqual(get_as_list(state1), [101])  # uncached
      self.assertEqual(get_as_list(state2), [102])  # uncached
      self.assertEqual(get_as_list(state1), [101])  # cached on bundle
      self.assertEqual(get_as_list(side1), [103])  # uncached
      self.assertEqual(get_as_list(side2), [104])  # uncached

    underlying_state.set_counter(200)
    with caching_state_handler.process_instruction_id(
        'bundle2', [state_token1, side1_token1]):
      self.assertEqual(get_as_list(state1), [201])  # uncached
      self.assertEqual(get_as_list(state2), [202])  # uncached
      self.assertEqual(get_as_list(state1), [201])  # cached on state token1
      self.assertEqual(get_as_list(side1), [203])  # uncached
      self.assertEqual(get_as_list(side1), [203])  # cached on side1_token1
      self.assertEqual(get_as_list(side2), [204])  # uncached
      self.assertEqual(get_as_list(side2), [204])  # cached on bundle

    underlying_state.set_counter(300)
    with caching_state_handler.process_instruction_id(
        'bundle3', [state_token1, side1_token1]):
      self.assertEqual(get_as_list(state1), [201])  # cached on state token1
      self.assertEqual(get_as_list(state2), [202])  # cached on state token1
      self.assertEqual(get_as_list(state1), [201])  # cached on state token1
      self.assertEqual(get_as_list(side1), [203])  # cached on side1_token1
      self.assertEqual(get_as_list(side1), [203])  # cached on side1_token1
      self.assertEqual(get_as_list(side2), [301])  # uncached
      self.assertEqual(get_as_list(side2), [301])  # cached on bundle

    underlying_state.set_counter(400)
    with caching_state_handler.process_instruction_id(
        'bundle4', [state_token2, side1_token1]):
      self.assertEqual(get_as_list(state1), [401])  # uncached
      self.assertEqual(get_as_list(state2), [402])  # uncached
      self.assertEqual(get_as_list(state1), [401])  # cached on state token2
      self.assertEqual(get_as_list(side1), [203])  # cached on side1_token1
      self.assertEqual(get_as_list(side1), [203])  # cached on side1_token1
      self.assertEqual(get_as_list(side2), [403])  # uncached
      self.assertEqual(get_as_list(side2), [403])  # cached on bundle

    underlying_state.set_counter(500)
    with caching_state_handler.process_instruction_id(
        'bundle5', [state_token2, side1_token2]):
      self.assertEqual(get_as_list(state1), [401])  # cached on state token2
      self.assertEqual(get_as_list(state2), [402])  # cached on state token2
      self.assertEqual(get_as_list(state1), [401])  # cached on state token2
      self.assertEqual(get_as_list(side1), [501])  # uncached
      self.assertEqual(get_as_list(side1), [501])  # cached on side1_token2
      self.assertEqual(get_as_list(side2), [502])  # uncached
      self.assertEqual(get_as_list(side2), [502])  # cached on bundle
def run_stage(
    self, controller, pipeline_components, stage, pcoll_buffers, safe_coders):
  context = pipeline_context.PipelineContext(pipeline_components)
  data_operation_spec = controller.data_operation_spec()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data_operation_spec.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          data_input[target] = pcoll_buffers[pcoll_id]
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
        else:
          raise NotImplementedError
        if data_operation_spec:
          transform.spec.payload = data_operation_spec.SerializeToString()
        else:
          transform.spec.payload = b""
      elif transform.spec.urn == urns.PARDO_TRANSFORM:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              'materialize:' + transform.inputs[tag],
              beam.pvalue.SideInputData.from_runner_api(si, None))
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug('  %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={
          transform.unique_name: transform for transform in stage.transforms
      },
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  # Store the required side inputs into state.
  for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
    elements_by_window = _WindowGroupingBuffer(si)
    for element_data in pcoll_buffers[pcoll_id]:
      elements_by_window.append(element_data)
    for window, elements_data in elements_by_window.items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window))
      controller.state_handler.blocking_append(state_key, elements_data, None)

  def get_buffer(pcoll_id):
    if pcoll_id.startswith('materialize:'):
      if pcoll_id not in pcoll_buffers:
        # Just store the data chunks for replay.
        pcoll_buffers[pcoll_id] = list()
    elif pcoll_id.startswith('group:'):
      # This is a grouping write, create a grouping buffer if needed.
      if pcoll_id not in pcoll_buffers:
        original_gbk_transform = pcoll_id.split(':', 1)[1]
        transform_proto = pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(transform_proto.inputs.values())
        output_pcoll = only_element(transform_proto.outputs.values())
        pre_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[input_pcoll].coder_id]]
        post_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[output_pcoll].coder_id]]
        windowing_strategy = context.windowing_strategies[
            pipeline_components.pcollections[output_pcoll]
            .windowing_strategy_id]
        pcoll_buffers[pcoll_id] = _GroupingBuffer(
            pre_gbk_coder, post_gbk_coder, windowing_strategy)
    else:
      # These should be the only two identifiers we produce for now,
      # but special side input writes may go here.
      raise NotImplementedError(pcoll_id)
    return pcoll_buffers[pcoll_id]

  return BundleManager(
      controller, get_buffer, process_bundle_descriptor,
      self._progress_frequency).process_bundle(data_input, data_output)
def run_stage(
    self, worker_handler_factory, pipeline_components, stage, pcoll_buffers,
    safe_coders):
  def iterable_state_write(values, element_coder_impl):
    token = unique_name(None, 'iter').encode('ascii')
    out = create_OutputStream()
    for element in values:
      element_coder_impl.encode_to_stream(element, out, True)
    controller.state.blocking_append(
        beam_fn_api_pb2.StateKey(
            runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
        out.get())
    return token

  controller = worker_handler_factory(stage.environment)
  context = pipeline_context.PipelineContext(
      pipeline_components, iterable_state_write=iterable_state_write)
  data_api_service_descriptor = controller.data_api_service_descriptor()

  def extract_endpoints(stage):
    # Returns maps of transform names to PCollection identifiers.
    # Also mutates IO stages to point to the data ApiServiceDescriptor.
    data_input = {}
    data_side_input = {}
    data_output = {}
    for transform in stage.transforms:
      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                bundle_processor.DATA_OUTPUT_URN):
        pcoll_id = transform.spec.payload
        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
          target = transform.unique_name, only_element(transform.outputs)
          if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
            data_input[target] = [ENCODED_IMPULSE_VALUE]
          else:
            data_input[target] = pcoll_buffers[pcoll_id]
          coder_id = pipeline_components.pcollections[
              only_element(transform.outputs.values())].coder_id
        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
          target = transform.unique_name, only_element(transform.inputs)
          data_output[target] = pcoll_id
          coder_id = pipeline_components.pcollections[
              only_element(transform.inputs.values())].coder_id
        else:
          raise NotImplementedError
        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
        if data_api_service_descriptor:
          data_spec.api_service_descriptor.url = (
              data_api_service_descriptor.url)
        transform.spec.payload = data_spec.SerializeToString()
      elif transform.spec.urn == common_urns.primitives.PAR_DO.urn:
        payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
        for tag, si in payload.side_inputs.items():
          data_side_input[transform.unique_name, tag] = (
              create_buffer_id(transform.inputs[tag]), si.access_pattern)
    return data_input, data_side_input, data_output

  logging.info('Running %s', stage.name)
  logging.debug('  %s', stage)
  data_input, data_side_input, data_output = extract_endpoints(stage)

  process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
      id=self._next_uid(),
      transforms={
          transform.unique_name: transform for transform in stage.transforms
      },
      pcollections=dict(pipeline_components.pcollections.items()),
      coders=dict(pipeline_components.coders.items()),
      windowing_strategies=dict(
          pipeline_components.windowing_strategies.items()),
      environments=dict(pipeline_components.environments.items()))

  if controller.state_api_service_descriptor():
    process_bundle_descriptor.state_api_service_descriptor.url = (
        controller.state_api_service_descriptor().url)

  # Store the required side inputs into state.
  for (transform_id, tag), (buffer_id, si) in data_side_input.items():
    _, pcoll_id = split_buffer_id(buffer_id)
    value_coder = context.coders[safe_coders[
        pipeline_components.pcollections[pcoll_id].coder_id]]
    elements_by_window = _WindowGroupingBuffer(si, value_coder)
    for element_data in pcoll_buffers[buffer_id]:
      elements_by_window.append(element_data)
    for key, window, elements_data in elements_by_window.encoded_items():
      state_key = beam_fn_api_pb2.StateKey(
          multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
              ptransform_id=transform_id,
              side_input_id=tag,
              window=window,
              key=key))
      controller.state.blocking_append(state_key, elements_data)

  def get_buffer(buffer_id):
    kind, name = split_buffer_id(buffer_id)
    if kind in ('materialize', 'timers'):
      if buffer_id not in pcoll_buffers:
        # Just store the data chunks for replay.
        pcoll_buffers[buffer_id] = list()
    elif kind == 'group':
      # This is a grouping write, create a grouping buffer if needed.
      if buffer_id not in pcoll_buffers:
        original_gbk_transform = name
        transform_proto = pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(list(transform_proto.inputs.values()))
        output_pcoll = only_element(list(transform_proto.outputs.values()))
        pre_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[input_pcoll].coder_id]]
        post_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[output_pcoll].coder_id]]
        windowing_strategy = context.windowing_strategies[
            pipeline_components.pcollections[output_pcoll]
            .windowing_strategy_id]
        pcoll_buffers[buffer_id] = _GroupingBuffer(
            pre_gbk_coder, post_gbk_coder, windowing_strategy)
    else:
      # These should be the only two identifiers we produce for now,
      # but special side input writes may go here.
      raise NotImplementedError(buffer_id)
    return pcoll_buffers[buffer_id]

  for k in range(self._bundle_repeat):
    try:
      controller.state.checkpoint()
      BundleManager(
          controller, lambda pcoll_id: [], process_bundle_descriptor,
          self._progress_frequency, k).process_bundle(data_input, data_output)
    finally:
      controller.state.restore()

  result = BundleManager(
      controller, get_buffer, process_bundle_descriptor,
      self._progress_frequency).process_bundle(data_input, data_output)

  while True:
    timer_inputs = {}
    for transform_id, timer_writes in stage.timer_pcollections:
      windowed_timer_coder_impl = context.coders[
          pipeline_components.pcollections[timer_writes].coder_id].get_impl()
      written_timers = get_buffer(
          create_buffer_id(timer_writes, kind='timers'))
      if written_timers:
        # Keep only the "last" timer set per key and window.
        timers_by_key_and_window = {}
        for elements_data in written_timers:
          input_stream = create_InputStream(elements_data)
          while input_stream.size() > 0:
            windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
                input_stream, True)
            key, _ = windowed_key_timer.value
            # TODO: Explode and merge windows.
            assert len(windowed_key_timer.windows) == 1
            timers_by_key_and_window[
                key, windowed_key_timer.windows[0]] = windowed_key_timer
        out = create_OutputStream()
        for windowed_key_timer in timers_by_key_and_window.values():
          windowed_timer_coder_impl.encode_to_stream(
              windowed_key_timer, out, True)
        timer_inputs[transform_id, 'out'] = [out.get()]
        written_timers[:] = []
    if timer_inputs:
      # The worker will be waiting on these inputs as well.
      for other_input in data_input:
        if other_input not in timer_inputs:
          timer_inputs[other_input] = []
      # TODO(robertwb): merge results
      BundleManager(
          controller, get_buffer, process_bundle_descriptor,
          self._progress_frequency, True).process_bundle(
              timer_inputs, data_output)
    else:
      break
  return result
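# Example (not part of the original source): create_buffer_id/split_buffer_id
# are not shown in this section; a plausible sketch consistent with the
# 'materialize:', 'group:' and 'timers' prefixes used above (the real helpers
# may differ in detail).
def create_buffer_id(name, kind='materialize'):
  # Buffer ids are b'<kind>:<name>', e.g. b'group:MyGBK'.
  return ('%s:%s' % (kind, name)).encode('utf-8')


def split_buffer_id(buffer_id):
  # Inverse of create_buffer_id; splits only on the first ':'.
  kind, name = buffer_id.decode('utf-8').split(':', 1)
  return kind, name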