def State(
    self,
    request_stream,  # type: Iterable[beam_fn_api_pb2.StateRequest]
    context=None  # type: Any
):
  # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse]
  """Handles a stream of state requests, yielding one response per request.

  State is mutated eagerly, assuming any failures are fatal.
  Thus it is safe to ignore instruction_id.
  """
  for req in request_stream:
    kind = req.WhichOneof('request')
    if kind == 'get':
      payload, token = self._state.get_raw(
          req.state_key, req.get.continuation_token)
      response = beam_fn_api_pb2.StateResponse(
          id=req.id,
          get=beam_fn_api_pb2.StateGetResponse(
              data=payload, continuation_token=token))
    elif kind == 'append':
      self._state.append_raw(req.state_key, req.append.data)
      response = beam_fn_api_pb2.StateResponse(
          id=req.id, append=beam_fn_api_pb2.StateAppendResponse())
    elif kind == 'clear':
      self._state.clear(req.state_key)
      response = beam_fn_api_pb2.StateResponse(
          id=req.id, clear=beam_fn_api_pb2.StateClearResponse())
    else:
      raise NotImplementedError('Unknown state request: %s' % kind)
    yield response
def State(self, request_stream, context=None):
  """Handles a stream of state requests, yielding one response per request.

  State is mutated eagerly, assuming any failures are fatal.
  Thus it is safe to ignore instruction_reference.
  """
  for request in request_stream:
    # A singular proto message field is always truthy in Python, so
    # `if request.get:` would match every request. Dispatch on the oneof
    # that is actually set instead (same pattern as the typed handler).
    request_type = request.WhichOneof('request')
    if request_type == 'get':
      yield beam_fn_api_pb2.StateResponse(
          id=request.id,
          get=beam_fn_api_pb2.StateGetResponse(
              data=self.blocking_get(request.state_key)))
    elif request_type == 'append':
      self.blocking_append(request.state_key, request.append.data)
      # Fixed response class name: StateAppendResponse (there is no
      # AppendResponse message in beam_fn_api_pb2).
      yield beam_fn_api_pb2.StateResponse(
          id=request.id,
          append=beam_fn_api_pb2.StateAppendResponse())
    elif request_type == 'clear':
      self.blocking_clear(request.state_key)
      # Fixed response class name: StateClearResponse.
      yield beam_fn_api_pb2.StateResponse(
          id=request.id,
          clear=beam_fn_api_pb2.StateClearResponse())
    else:
      raise NotImplementedError('Unknown state request: %s' % request_type)
def extend(
    self,
    state_key,  # type: beam_fn_api_pb2.StateKey
    coder,  # type: coder_impl.CoderImpl
    elements,  # type: Iterable[Any]
):
  # type: (...) -> _Future
  """Appends the given elements to the state for state_key.

  If a cache token is active for this key, the locally cached value is
  extended in place as well, so later reads can be served from the cache.
  The encoded elements are flushed to the underlying state handler in
  chunks; returns a _DeferredCall future that aggregates the append
  responses (concatenating any error strings).
  """
  cache_token = self._get_cache_token(state_key)
  if cache_token:
    # Update the cache
    cache_key = self._convert_to_cache_key(state_key)
    cached_value = self._state_cache.get(cache_key, cache_token)
    # Keep in mind that the state for this key can be evicted
    # while executing this function. Either read or write to the cache
    # but never do both here!
    if cached_value is None:
      # We have never cached this key before, first retrieve state
      cached_value = self.blocking_get(state_key, coder)
    # Just extend the already cached value
    if isinstance(cached_value, list):
      # Materialize provided iterable to ensure reproducible iterations,
      # here and when writing to the state handler below.
      elements = list(elements)
      # The state is fully cached and can be extended
      cached_value.extend(elements)
    elif isinstance(cached_value, self.ContinuationIterable):
      # The state is too large to be fully cached (continuation token used),
      # only the first part is cached; the rest is enumerated via the runner,
      # so there is nothing local to extend.
      pass
    else:
      # When a corrupt value made it into the cache, we have to fail.
      raise Exception("Unexpected cached value: %s" % cached_value)
  # Write to state handler. Encode elements into an output stream and flush
  # a chunk to the runner whenever the buffer exceeds the flush threshold,
  # collecting one future per append_raw call.
  futures = []
  out = coder_impl.create_OutputStream()
  for element in elements:
    coder.encode_to_stream(element, out, True)
    if out.size() > data_plane._DEFAULT_SIZE_FLUSH_THRESHOLD:
      futures.append(self._underlying.append_raw(state_key, out.get()))
      out = coder_impl.create_OutputStream()
  # Flush any remaining buffered bytes.
  if out.size():
    futures.append(self._underlying.append_raw(state_key, out.get()))
  # Combine all pending append responses into a single StateResponse whose
  # error field joins any individual errors (empty string when all succeed).
  return _DeferredCall(
      lambda *results: beam_fn_api_pb2.StateResponse(
          error='\n'.join(
              result.error for result in results if result and result.error),
          append=beam_fn_api_pb2.StateAppendResponse()),
      *futures)