def test_wrapped_successful():
    """The id-adding wrapper returns the transformed message plus a message_id.

    The expected message_id is the encoded MessageMetadata for Partition(1)
    and Cursor(offset=10), i.e. the partition given to the wrapper and the
    cursor of the input message.
    """
    transformer = add_id_to_cps_subscribe_transformer(
        Partition(1), MessageTransformer.of_callable(to_cps_subscribe_message)
    )

    event_time_attribute = encode_attribute_event_time(
        Timestamp(seconds=55).ToDatetime()
    )
    encoded_id = MessageMetadata(Partition(1), Cursor(offset=10)).encode()
    want = PubsubMessage(
        data=b"xyz",
        ordering_key="def",
        attributes={
            "x": "abc",
            "y": "abc",
            PUBSUB_LITE_EVENT_TIME: event_time_attribute,
        },
        message_id=encoded_id,
        publish_time=Timestamp(seconds=10),
    )

    payload = PubSubMessage(
        data=b"xyz",
        key=b"def",
        event_time=Timestamp(seconds=55),
        attributes={
            "x": AttributeValues(values=[b"abc"]),
            "y": AttributeValues(values=[b"abc"]),
        },
    )
    incoming = SequencedMessage(
        message=payload,
        publish_time=Timestamp(seconds=10),
        cursor=Cursor(offset=10),
        size_bytes=10,
    )

    got = transformer.transform(incoming)
    assert got == want
async def test_ack(
    subscriber: AsyncSingleSubscriber, underlying, transformer, ack_set_tracker
):
    """Acks made out of read order reach the tracker in the order they were made.

    Two messages (offsets 1 and 2) are read, then acked as 2 followed by 1.
    The tracker's ``ack`` is expected to see call(2) and then call(1).
    """
    # tracker.ack is driven through make_queue_waiter with these two queues:
    # the test reads from the first and feeds a result into the second.
    ack_started = asyncio.Queue()
    ack_outcome = asyncio.Queue()
    ack_set_tracker.ack.side_effect = make_queue_waiter(ack_started, ack_outcome)

    async with subscriber:
        first = SequencedMessage(cursor=Cursor(offset=1), size_bytes=5)
        second = SequencedMessage(cursor=Cursor(offset=2), size_bytes=10)

        underlying.read.return_value = first
        received_first: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1)])
        assert received_first.message_id == "1"

        underlying.read.return_value = second
        received_second: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1), call(2)])
        assert received_second.message_id == "2"

        # Ack the later message first.
        received_second.ack()
        await ack_started.get()
        await ack_outcome.put(None)
        ack_set_tracker.ack.assert_has_calls([call(2)])

        received_first.ack()
        await ack_started.get()
        await ack_outcome.put(None)
        ack_set_tracker.ack.assert_has_calls([call(2), call(1)])
def test_subscribe_transform_correct():
    """to_cps_subscribe_message maps a SequencedMessage onto the expected PubsubMessage.

    The expected output carries the key as ``ordering_key``, each attribute as
    a plain string, and the event time under the PUBSUB_LITE_EVENT_TIME
    attribute.
    """
    event_time_attribute = encode_attribute_event_time(
        Timestamp(seconds=55).ToDatetime()
    )
    want = PubsubMessage(
        data=b"xyz",
        ordering_key="def",
        attributes={
            "x": "abc",
            "y": "abc",
            PUBSUB_LITE_EVENT_TIME: event_time_attribute,
        },
        publish_time=Timestamp(seconds=10),
    )

    payload = PubSubMessage(
        data=b"xyz",
        key=b"def",
        event_time=Timestamp(seconds=55),
        attributes={
            "x": AttributeValues(values=[b"abc"]),
            "y": AttributeValues(values=[b"abc"]),
        },
    )
    incoming = SequencedMessage(
        message=payload,
        publish_time=Timestamp(seconds=10),
        cursor=Cursor(offset=10),
        size_bytes=10,
    )

    got = to_cps_subscribe_message(incoming)
    assert got == want
async def test_ack_failure(
    subscriber: SinglePartitionSingleSubscriber,
    underlying,
    transformer,
    ack_set_tracker,
):
    """A FailedPrecondition fed back as the ack result surfaces from the next read()."""
    ack_started = asyncio.Queue()
    ack_outcome = asyncio.Queue()
    ack_set_tracker.ack.side_effect = make_queue_waiter(ack_started, ack_outcome)

    async with subscriber:
        underlying.read.return_value = SequencedMessage(
            cursor=Cursor(offset=1), size_bytes=5
        )
        received: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1)])

        received.ack()
        await ack_started.get()
        ack_set_tracker.ack.assert_has_calls([call(1)])
        # Hand the failure to the waiter installed on tracker.ack.
        await ack_outcome.put(FailedPrecondition("Bad ack"))

        async def sleep_forever():
            await asyncio.sleep(float("inf"))

        # The underlying read never completes, so the exception raised below
        # cannot come from a delivered message.
        underlying.read.side_effect = sleep_forever
        with pytest.raises(FailedPrecondition):
            await subscriber.read()
async def test_nack_calls_ack(
    subscriber: SinglePartitionSingleSubscriber,
    underlying,
    transformer,
    ack_set_tracker,
    nack_handler,
):
    """When the nack handler invokes its ``ack`` callback, tracker.ack(1) is called."""
    ack_started = asyncio.Queue()
    ack_outcome = asyncio.Queue()
    ack_set_tracker.ack.side_effect = make_queue_waiter(ack_started, ack_outcome)

    async with subscriber:
        underlying.read.return_value = SequencedMessage(
            cursor=Cursor(offset=1), size_bytes=5
        )
        received: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1)])

        def ack_when_nacked(nacked: PubsubMessage, ack: Callable[[], None]):
            # Check the nacked message, then acknowledge it via the callback.
            assert nacked.message_id == "1"
            ack()

        nack_handler.on_nack.side_effect = ack_when_nacked

        received.nack()
        await ack_started.get()
        await ack_outcome.put(None)
        ack_set_tracker.ack.assert_has_calls([call(1)])
async def reinitialize(
    self, connection: Connection[SubscribeRequest, SubscribeResponse]
):
    """Replay the subscribe handshake on ``connection``.

    Sequence, in order:
      1. Set ``_reinitializing`` and await ``_stop_loopers()``.
      2. Write the initial request and read one response; it must contain
         ``initial``.
      3. If an offset has been received before, write a seek to the offset
         after it and read one response; it must contain ``seek``.
      4. Write the flow control request returned by
         ``request_for_restart()``, if there is one.
      5. Clear ``_reinitializing`` and call ``_start_loopers()``.

    If either response is invalid, ``self._connection.fail`` is called with a
    FailedPrecondition and the method returns early, leaving
    ``_reinitializing`` set and the loopers not restarted.

    Args:
        connection: The connection to perform the handshake on.
    """
    self._reinitializing = True
    await self._stop_loopers()

    await connection.write(SubscribeRequest(initial=self._initial))
    initial_response = await connection.read()
    if "initial" not in initial_response:
        invalid_initial = FailedPrecondition(
            "Received an invalid initial response on the subscribe stream."
        )
        self._connection.fail(invalid_initial)
        return

    last_offset = self._last_received_offset
    if last_offset is not None:
        # Perform a seek to get the next message after the one we received.
        resume_from = Cursor(offset=last_offset + 1)
        seek_request = SubscribeRequest(seek=SeekRequest(cursor=resume_from))
        await connection.write(seek_request)
        seek_response = await connection.read()
        if "seek" not in seek_response:
            invalid_seek = FailedPrecondition(
                "Received an invalid seek response on the subscribe stream."
            )
            self._connection.fail(invalid_seek)
            return

    restart_tokens = self._outstanding_flow_control.request_for_restart()
    if restart_tokens is not None:
        await connection.write(SubscribeRequest(flow_control=restart_tokens))

    self._reinitializing = False
    self._start_loopers()
async def reinitialize(
    self, connection: Connection[SubscribeRequest, SubscribeResponse]
):
    """Send the initial subscribe request on ``connection`` and restart the loopers.

    The initial request is a deep copy of ``_base_initial`` whose
    ``initial_location`` is the offset after the last received one, or the
    COMMITTED_CURSOR named target when no offset has been received yet.

    If the response does not contain ``initial``, ``self._connection.fail`` is
    called with a FailedPrecondition and the method returns without starting
    the loopers.

    Args:
        connection: The connection to perform the handshake on.
    """
    request = deepcopy(self._base_initial)
    last_offset = self._last_received_offset
    if last_offset is None:
        start_at = SeekRequest(
            named_target=SeekRequest.NamedTarget.COMMITTED_CURSOR
        )
    else:
        start_at = SeekRequest(cursor=Cursor(offset=last_offset + 1))
    request.initial_location = start_at

    await connection.write(SubscribeRequest(initial=request))
    initial_response = await connection.read()
    if "initial" not in initial_response:
        invalid_initial = FailedPrecondition(
            "Received an invalid initial response on the subscribe stream."
        )
        self._connection.fail(invalid_initial)
        return

    restart_tokens = self._outstanding_flow_control.request_for_restart()
    if restart_tokens is not None:
        await connection.write(SubscribeRequest(flow_control=restart_tokens))
    self._start_loopers()
def test_invalid_subscribe_transform_key():
    """A message whose key is NOT_UTF8 makes the transform raise InvalidArgument."""
    with pytest.raises(InvalidArgument):
        # Built inside the context so any InvalidArgument raised during
        # construction is also captured, as in the single-expression form.
        bad_key_message = SequencedMessage(
            message=PubSubMessage(key=NOT_UTF8),
            publish_time=Timestamp(),
            cursor=Cursor(offset=10),
            size_bytes=10,
        )
        to_cps_subscribe_message(bad_key_message)
async def test_handle_reset(
    subscriber: SinglePartitionSingleSubscriber,
    underlying,
    transformer,
    ack_set_tracker,
):
    """After handle_reset(), a pre-reset ack is not forwarded to the tracker.

    Message 1 is read before the reset and acked after it; message 2 is read
    and acked after the reset. The expected ack ids are ``ack_id(0, 1)`` and
    ``ack_id(1, 2)`` respectively. Flow control is expected to be granted for
    both messages, but only offset 2 is asserted on the tracker's ``ack``.
    """
    ack_started = asyncio.Queue()
    ack_outcome = asyncio.Queue()
    ack_set_tracker.ack.side_effect = make_queue_waiter(ack_started, ack_outcome)

    async with subscriber:
        underlying.read.return_value = SequencedMessage(
            cursor=Cursor(offset=1), size_bytes=5
        )
        before_reset: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1)])
        assert before_reset.message_id == "1"
        assert before_reset.ack_id == ack_id(0, 1)

        await subscriber.handle_reset()
        ack_set_tracker.clear_and_commit.assert_called_once()

        # Message ACKed after reset. Its flow control tokens are refilled
        # but offset not committed (verified below after message 2).
        before_reset.ack()

        underlying.read.return_value = SequencedMessage(
            cursor=Cursor(offset=2), size_bytes=10
        )
        after_reset: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1), call(2)])
        assert after_reset.message_id == "2"
        assert after_reset.ack_id == ack_id(1, 2)

        after_reset.ack()
        await ack_started.get()
        await ack_outcome.put(None)

        expected_flow_calls = [
            call(FlowControlRequest(allowed_messages=1000, allowed_bytes=1000)),
            call(FlowControlRequest(allowed_messages=1, allowed_bytes=5)),
            call(FlowControlRequest(allowed_messages=1, allowed_bytes=10)),
        ]
        underlying.allow_flow.assert_has_calls(expected_flow_calls)
        ack_set_tracker.ack.assert_has_calls([call(2)])
def test_invalid_subscribe_contains_non_utf8_attributes():
    """An attribute value of NOT_UTF8 makes the transform raise InvalidArgument."""
    with pytest.raises(InvalidArgument):
        # Everything is built inside the context, matching the scope of the
        # single-expression form.
        bad_attributes = {"xyz": AttributeValues(values=[NOT_UTF8])}
        incoming = SequencedMessage(
            message=PubSubMessage(key=b"def", attributes=bad_attributes),
            publish_time=Timestamp(seconds=10),
            cursor=Cursor(offset=10),
            size_bytes=10,
        )
        to_cps_subscribe_message(incoming)
async def test_track_failure(
    subscriber: SinglePartitionSingleSubscriber,
    underlying,
    transformer,
    ack_set_tracker,
):
    """A FailedPrecondition raised by tracker.track propagates out of read()."""
    async with subscriber:
        ack_set_tracker.track.side_effect = FailedPrecondition("Bad track")
        underlying.read.return_value = SequencedMessage(
            cursor=Cursor(offset=1), size_bytes=5
        )

        with pytest.raises(FailedPrecondition):
            await subscriber.read()

        # track was still invoked with the message's offset.
        ack_set_tracker.track.assert_has_calls([call(1)])
async def test_track_and_aggregate_acks(committer, tracker: AckSetTracker):
    """Commits are issued only once a contiguous prefix of tracked offsets is acked.

    Offsets 1, 3, 5 and 7 are tracked. Acking 3 and then 5 must not commit
    anything; acking 1 is expected to commit Cursor(offset=6). After tracking
    8 and acking 7, a second commit of Cursor(offset=8) is expected.
    """
    async with tracker:
        committer.__aenter__.assert_called_once()
        tracker.track(offset=1)
        tracker.track(offset=3)
        tracker.track(offset=5)
        tracker.track(offset=7)

        # assert_not_called() is used for the "nothing committed yet" checks:
        # assert_has_calls([]) is satisfied by any call history and would not
        # detect a premature commit.
        committer.commit.assert_not_called()
        await tracker.ack(offset=3)
        committer.commit.assert_not_called()
        await tracker.ack(offset=5)
        committer.commit.assert_not_called()

        # Acking the lowest outstanding offset completes the prefix 1, 3, 5.
        await tracker.ack(offset=1)
        committer.commit.assert_has_calls([call(Cursor(offset=6))])

        tracker.track(offset=8)
        await tracker.ack(offset=7)
        committer.commit.assert_has_calls(
            [call(Cursor(offset=6)), call(Cursor(offset=8))]
        )
    committer.__aexit__.assert_called_once()
def test_invalid_subscribe_contains_magic_attribute():
    """An input attribute named PUBSUB_LITE_EVENT_TIME makes the transform raise InvalidArgument."""
    with pytest.raises(InvalidArgument):
        # Everything is built inside the context, matching the scope of the
        # single-expression form.
        reserved_attributes = {
            PUBSUB_LITE_EVENT_TIME: AttributeValues(values=[b"abc"])
        }
        incoming = SequencedMessage(
            message=PubSubMessage(key=b"def", attributes=reserved_attributes),
            publish_time=Timestamp(seconds=10),
            cursor=Cursor(offset=10),
            size_bytes=10,
        )
        to_cps_subscribe_message(incoming)
async def ack(self, offset: int):
    """Record an ack for ``offset`` and commit the longest fully-acked prefix.

    The new ack is added to ``_acks``. Receipts are then consumed from the
    left of ``_receipts`` for as long as the next ack taken from ``_acks``
    equals the next receipt. If at least one receipt was consumed, the cursor
    one past the last consumed offset is committed.

    Args:
        offset: The offset being acknowledged.
    """
    # Note: put_nowait is used here and below to ensure that the below logic is executed without yielding
    # to another coroutine in the event loop. The queue is unbounded so it will never throw.
    self._acks.put_nowait(offset)
    prefix_acked_offset: Optional[int] = None
    while self._receipts and not self._acks.empty():
        receipt = self._receipts.popleft()
        acked = self._acks.get_nowait()
        if receipt != acked:
            # The oldest receipt is not acked yet: restore both values.
            # The receipt must go back on the LEFT so it stays the first one
            # examined next time; appending on the right would reorder the
            # outstanding receipts. The ack is re-queued with put_nowait
            # because a bare put() on an asyncio queue only creates a
            # coroutine and never re-inserts the value.
            self._receipts.appendleft(receipt)
            self._acks.put_nowait(acked)
            break
        prefix_acked_offset = receipt
    if prefix_acked_offset is None:
        return
    # Convert from last acked to first unacked.
    await self._committer.commit(Cursor(offset=prefix_acked_offset + 1))
def ack(self, offset: int):
    """Record an ack for ``offset`` and commit the longest fully-acked prefix.

    The offset is pushed onto ``_acks``. Receipts are then removed from the
    left of ``_receipts`` for as long as the value at the front of ``_acks``
    equals the next receipt. If at least one receipt was removed, a cursor one
    past the last removed offset is passed to ``_committer.commit``.

    Args:
        offset: The offset being acknowledged.
    """
    pending_acks = self._acks
    outstanding = self._receipts

    pending_acks.push(offset)
    last_acked: Optional[int] = None
    while len(outstanding) != 0 and not pending_acks.empty():
        oldest = outstanding.popleft()
        if oldest != pending_acks.peek():
            # The oldest receipt is still unacked: put it back at the front.
            outstanding.appendleft(oldest)
            break
        pending_acks.pop()
        last_acked = oldest

    if last_acked is None:
        return

    # Convert from last acked to first unacked.
    cursor = Cursor()
    cursor._pb.offset = last_acked + 1
    self._committer.commit(cursor)
async def test_clear_and_commit(committer, tracker: AckSetTracker):
    """clear_and_commit() drops tracked state so earlier offsets can be tracked again.

    Before clearing, tracking offset 1 after 3 and 5 must raise
    FailedPrecondition, and acking 5 alone must not commit. After clearing,
    offset 1 can be tracked and acked, and Cursor(offset=2) is expected to be
    committed.
    """
    async with tracker:
        committer.__aenter__.assert_called_once()
        tracker.track(offset=3)
        tracker.track(offset=5)

        with pytest.raises(FailedPrecondition):
            tracker.track(offset=1)

        await tracker.ack(offset=5)
        # assert_not_called() rather than assert_has_calls([]): the latter is
        # satisfied by any call history and would not detect a premature
        # commit.
        committer.commit.assert_not_called()

        await tracker.clear_and_commit()
        committer.wait_until_empty.assert_called_once()

        # After clearing, it should be possible to track earlier offsets.
        tracker.track(offset=1)
        await tracker.ack(offset=1)
        committer.commit.assert_has_calls([call(Cursor(offset=2))])
    committer.__aexit__.assert_called_once()
async def test_nack_failure(
    subscriber: SinglePartitionSingleSubscriber,
    underlying,
    transformer,
    ack_set_tracker,
    nack_handler,
):
    """A FailedPrecondition raised by the nack handler surfaces from the next read()."""
    async with subscriber:
        underlying.read.return_value = SequencedMessage(
            cursor=Cursor(offset=1), size_bytes=5
        )
        received: Message = await subscriber.read()
        ack_set_tracker.track.assert_has_calls([call(1)])

        nack_handler.on_nack.side_effect = FailedPrecondition("Bad nack")
        received.nack()

        async def sleep_forever():
            await asyncio.sleep(float("inf"))

        # The underlying read never completes, so the exception raised below
        # cannot come from a delivered message.
        underlying.read.side_effect = sleep_forever
        with pytest.raises(FailedPrecondition):
            await subscriber.read()
def test_wrapped_sets_id_error():
    """Wrapping a transformer that already sets message_id raises InvalidArgument on transform."""

    def sets_message_id(x):
        # Ignores its input and returns a message with message_id set.
        return PubsubMessage(message_id="a")

    transformer = add_id_to_cps_subscribe_transformer(
        Partition(1),
        MessageTransformer.of_callable(sets_message_id),
    )

    with pytest.raises(InvalidArgument):
        # The input is built inside the context, matching the scope of the
        # single-expression form.
        payload = PubSubMessage(
            data=b"xyz",
            key=b"def",
            event_time=Timestamp(seconds=55),
            attributes={
                "x": AttributeValues(values=[b"abc"]),
                "y": AttributeValues(values=[b"abc"]),
            },
        )
        incoming = SequencedMessage(
            message=payload,
            publish_time=Timestamp(seconds=10),
            cursor=Cursor(offset=10),
            size_bytes=10,
        )
        transformer.transform(incoming)
async def reinitialize(
    self,
    connection: Connection[SubscribeRequest, SubscribeResponse],
    last_error: Optional[GoogleAPICallError],
):
    """Replay the subscribe handshake on ``connection``, handling reset signals.

    Sequence, in order:
      1. Set ``_reinitializing`` and await ``_stop_loopers()``.
      2. If ``last_error`` is a reset signal (per ``is_reset_signal``), drain
         ``_message_queue``, adding one message and ``size_bytes`` bytes of
         flow control per drained message; await the reset handler; and
         forget the last received offset.
      3. Write an initial request (deep copy of ``_base_initial``) located at
         the offset after the last received one, or at the COMMITTED_CURSOR
         named target when there is none, and read one response.
      4. Write the flow control request returned by
         ``request_for_restart()``, if there is one.
      5. Clear ``_reinitializing`` and call ``_start_loopers()``.

    If the response does not contain ``initial``, ``self._connection.fail`` is
    called with a FailedPrecondition and the method returns early, leaving
    ``_reinitializing`` set and the loopers not restarted.

    Args:
        connection: The connection to perform the handshake on.
        last_error: The error passed in by the caller, if any.
    """
    self._reinitializing = True
    await self._stop_loopers()

    if last_error and is_reset_signal(last_error):
        # Discard undelivered messages and refill flow control tokens.
        undelivered = self._message_queue
        while not undelivered.empty():
            dropped = undelivered.get_nowait()
            refund = FlowControlRequest(
                allowed_messages=1,
                allowed_bytes=dropped.size_bytes,
            )
            self._outstanding_flow_control.add(refund)
        await self._reset_handler.handle_reset()
        self._last_received_offset = None

    request = deepcopy(self._base_initial)
    last_offset = self._last_received_offset
    if last_offset is None:
        start_at = SeekRequest(
            named_target=SeekRequest.NamedTarget.COMMITTED_CURSOR
        )
    else:
        start_at = SeekRequest(cursor=Cursor(offset=last_offset + 1))
    request.initial_location = start_at

    await connection.write(SubscribeRequest(initial=request))
    initial_response = await connection.read()
    if "initial" not in initial_response:
        invalid_initial = FailedPrecondition(
            "Received an invalid initial response on the subscribe stream."
        )
        self._connection.fail(invalid_initial)
        return

    restart_tokens = self._outstanding_flow_control.request_for_restart()
    if restart_tokens is not None:
        await connection.write(SubscribeRequest(flow_control=restart_tokens))

    self._reinitializing = False
    self._start_loopers()