def test_wrapped_successful():
    """The ID-adding wrapper returns the inner transform's output with a message_id.

    The inner transformer is ``to_cps_subscribe_message``. The expected result has
    the payload, the ordering key, flattened attributes, the encoded event time and
    the publish time. Its message_id is
    ``MessageMetadata(Partition(1), Cursor(offset=10)).encode()``.
    """
    inner = MessageTransformer.of_callable(to_cps_subscribe_message)
    wrapped = add_id_to_cps_subscribe_transformer(Partition(1), inner)

    event_time_attr = encode_attribute_event_time(Timestamp(seconds=55).ToDatetime())
    expected_id = MessageMetadata(Partition(1), Cursor(offset=10)).encode()
    expected = PubsubMessage(
        data=b"xyz",
        ordering_key="def",
        attributes={"x": "abc", "y": "abc", PUBSUB_LITE_EVENT_TIME: event_time_attr},
        message_id=expected_id,
        publish_time=Timestamp(seconds=10),
    )

    lite_attributes = {
        "x": AttributeValues(values=[b"abc"]),
        "y": AttributeValues(values=[b"abc"]),
    }
    incoming = SequencedMessage(
        message=PubSubMessage(
            data=b"xyz",
            key=b"def",
            event_time=Timestamp(seconds=55),
            attributes=lite_attributes,
        ),
        publish_time=Timestamp(seconds=10),
        cursor=Cursor(offset=10),
        size_bytes=10,
    )

    actual = wrapped.transform(incoming)
    assert actual == expected
async def test_simple_publish(mock_publishers, mock_policies, mock_watcher, publisher):
    """With 2 partitions, one publish goes through the 2-partition policy to partition 1."""
    mock_watcher.get_partition_count.return_value = 2
    async with publisher:
        policy = mock_policies[2]
        policy.route.return_value = Partition(1)
        target_publisher = mock_publishers[Partition(1)]
        target_publisher.publish.return_value = "a"

        await publisher.publish(PubSubMessage())

        policy.route.assert_called_with(PubSubMessage())
        target_publisher.publish.assert_called()
def route(self, message: PubSubMessage) -> Partition:
    """Pick the partition for ``message``.

    Keyed messages: SHA-256 the key, read the digest as a big-endian integer, and
    reduce it modulo the partition count.
    Unkeyed messages (falsy key): return the current round-robin partition, then
    advance the round-robin cursor by one, wrapping at the partition count.
    """
    if message.key:
        digest = hashlib.sha256(message.key).digest()
        bucket = int.from_bytes(digest, byteorder="big") % self._num_partitions
        return Partition(bucket)

    chosen = Partition(self._current_round_robin.value)
    following = (self._current_round_robin.value + 1) % self._num_partitions
    self._current_round_robin = Partition(following)
    return chosen
async def _handle_partition_count_update(self, partition_count: int):
    """Grow the publisher set when a larger partition count is observed.

    If ``partition_count`` is not larger than the number of publishers already held,
    this does nothing. Otherwise it does the following, in order:
    1. Create a publisher for each new partition index.
    2. Enter all of the new publishers concurrently.
    3. Build a routing policy for the new count.
    4. Install the new publishers and the new policy.
    """
    known = len(self._publishers)
    if partition_count <= known:
        # NOTE(review): a smaller count is ignored, not acted on; presumably
        # partitions are never removed. Confirm.
        return

    added = {}
    for index in range(known, partition_count):
        added[Partition(index)] = self._publisher_factory(Partition(index))
    await asyncio.gather(*(pub.__aenter__() for pub in added.values()))
    new_policy = self._policy_factory(partition_count)

    # State is only swapped in after entering and policy creation have succeeded.
    self._publishers.update(added)
    self._routing_policy = new_policy
async def test_initial_assignment(subscriber, assigner, subscriber_factory):
    """The first assignment creates and enters one subscriber per assigned partition.

    Leaving the outer ``subscriber`` context must also exit both per-partition
    subscribers.
    """
    # Presumably exposes get_assignment calls (.called) and lets the test supply
    # their results (.results).
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        # Wait for the first get_assignment call.
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))
        sub2 = mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))
        # Partition 1 gets sub1; any other partition gets sub2.
        subscriber_factory.side_effect = (
            lambda partition: sub1 if partition == Partition(1) else sub2)
        await assign_queues.results.put({Partition(1), Partition(2)})
        # The next get_assignment call presumably means the assignment was processed.
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls(
            [call(Partition(1)), call(Partition(2))], any_order=True)
        sub1.__aenter__.assert_called_once()
        sub2.__aenter__.assert_called_once()
    # Checked after the outer context exits: that exit must close both children.
    sub1.__aexit__.assert_called_once()
    sub2.__aexit__.assert_called_once()
async def test_basic_assign(assigner: Assigner, default_connection, initial_request):
    """The assigner sends the initial request, delivers assignments, and acks.

    The initial request is sent on connect. Each assignment read from the stream is
    delivered through get_assignment. An ack is written before the next assignment
    is handed out.
    """
    # Each write/read call is signalled on a *_called_queue and blocks until the test
    # puts a value on the matching *_result_queue (presumably; see make_queue_waiter).
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(
        write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(
        read_called_queue, read_result_queue)
    # Pre-load the result for the initial-request write so setup can proceed.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # Wait for the first assignment
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        assert not assign_fut1.done()

        partitions = {Partition(2), Partition(7)}

        # Send the first assignment.
        await read_result_queue.put(as_response(partitions=partitions))
        assert (await assign_fut1) == partitions

        # Get the next assignment: should send an ack on the stream
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        await write_result_queue.put(None)
        default_connection.write.assert_has_calls(
            [call(initial_request), call(ack_request())])

        partitions = {Partition(5)}

        # Send the second assignment.
        await read_called_queue.get()
        await read_result_queue.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory):
    """Messages from two partition subscribers are both surfaced by subscriber.read().

    The delivery order between partitions is not asserted; only the set of IDs is.
    """
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        # Wait for the first get_assignment call.
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))
        sub2 = mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))
        # Let the test feed each partition subscriber's read() results.
        sub1_queues = wire_queues(sub1.read)
        sub2_queues = wire_queues(sub2.read)
        # Partition 1 gets sub1; any other partition gets sub2.
        subscriber_factory.side_effect = (
            lambda partition: sub1 if partition == Partition(1) else sub2)
        await assign_queues.results.put({Partition(1), Partition(2)})
        # One message per partition, distinguished by message_id.
        await sub1_queues.results.put(
            Message(PubsubMessage(message_id="1")._pb, "", 0, None))
        await sub2_queues.results.put(
            Message(PubsubMessage(message_id="2")._pb, "", 0, None))
        # Collect into a set because the arrival order is not fixed.
        message_ids: Set[str] = set()
        message_ids.add((await subscriber.read()).message_id)
        message_ids.add((await subscriber.read()).message_id)
        assert message_ids == {"1", "2"}
async def test_subscriber_failure(subscriber, assigner, subscriber_factory):
    """An error raised by a partition subscriber's read() surfaces from subscriber.read()."""
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        # Wait for the first get_assignment call.
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))
        sub1_queues = wire_queues(sub1.read)
        subscriber_factory.return_value = sub1
        await assign_queues.results.put({Partition(1)})
        # Wait until sub1.read() is being awaited, then make it raise.
        await sub1_queues.called.get()
        await sub1_queues.results.put(FailedPrecondition("sub failed"))
        with pytest.raises(FailedPrecondition):
            await subscriber.read()
def test_routing_cases():
    """Check DefaultRoutingPolicy(29) against the key -> partition table in routing_tests.json.

    Lines starting with ``//`` in the fixture are comments and are dropped before the
    rest is parsed as JSON. Keys are UTF-8 encoded to bytes before routing.
    """
    policy = DefaultRoutingPolicy(num_partitions=29)
    path = os.path.join(os.path.dirname(__file__), "routing_tests.json")
    # Decode explicitly as UTF-8. The keys are re-encoded with UTF-8 below, and the
    # platform default encoding from open() is locale-dependent.
    with open(path, encoding="utf-8") as f:
        json_lines = [line for line in f if not line.startswith("//")]

    # The lines keep their own newlines, so a plain concatenation restores the text.
    loaded = json.loads("".join(json_lines))
    target = {bytes(k, "utf-8"): Partition(v) for k, v in loaded.items()}
    result = {key: policy.route(PubSubMessage(key=key)) for key in target}
    assert result == target
async def _receive_loop(self):
    """Read assignment responses forever and queue each as a set of Partitions.

    An assignment counts as a duplicate if one is already outstanding, or if the
    queue still holds an unconsumed assignment. On a duplicate, the connection is
    failed with FailedPrecondition and the loop returns.
    """
    while True:
        response = await self._connection.read()
        duplicate = self._outstanding_assignment or not self._new_assignment.empty()
        if duplicate:
            self._connection.fail(
                FailedPrecondition(
                    "Received a duplicate assignment on the stream while one was outstanding."
                )
            )
            return
        self._outstanding_assignment = True
        self._new_assignment.put_nowait({Partition(p) for p in response.partitions})
def test_wrapped_sets_id_error():
    """The wrapper raises InvalidArgument if the inner transform already set message_id."""

    def preset_id(_message):
        return PubsubMessage(message_id="a")

    wrapped = add_id_to_cps_subscribe_transformer(
        Partition(1),
        MessageTransformer.of_callable(preset_id),
    )
    with pytest.raises(InvalidArgument):
        incoming = SequencedMessage(
            message=PubSubMessage(
                data=b"xyz",
                key=b"def",
                event_time=Timestamp(seconds=55),
                attributes={
                    "x": AttributeValues(values=[b"abc"]),
                    "y": AttributeValues(values=[b"abc"]),
                },
            ),
            publish_time=Timestamp(seconds=10),
            cursor=Cursor(offset=10),
            size_bytes=10,
        )
        wrapped.transform(incoming)
async def test_publish_after_increase(
    mock_publishers, mock_policies, mock_watcher, publisher
):
    """After the partition count grows from 2 to 3, publishing uses the 3-partition policy.

    That policy can route to the newly added partition 2.
    """
    # The test supplies each get_partition_count result; start at 2 partitions.
    get_queues = wire_queues(mock_watcher.get_partition_count)
    await get_queues.results.put(2)
    async with publisher:
        # Consume the record of the call made while entering the publisher.
        get_queues.called.get_nowait()

        mock_policies[2].route.return_value = Partition(1)
        mock_publishers[Partition(1)].publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        mock_policies[2].route.assert_called_with(PubSubMessage())
        mock_publishers[Partition(1)].publish.assert_called()

        # Answer the next poll with 3. Presumably, seeing the poll after that one
        # means the count update has been applied.
        await get_queues.called.get()
        await get_queues.results.put(3)
        await get_queues.called.get()

        mock_policies[3].route.return_value = Partition(2)
        mock_publishers[Partition(2)].publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        mock_policies[3].route.assert_called_with(PubSubMessage())
        mock_publishers[Partition(2)].publish.assert_called()
def _partition(self) -> Partition:
    """Wrap ``self._initial.partition`` in a ``Partition``."""
    partition_index = self._initial.partition
    return Partition(partition_index)
def __init__(self, num_partitions: int):
    """Store the partition count and pick a random round-robin starting partition.

    The start is drawn with ``random.randint`` from ``0`` through
    ``num_partitions - 1`` inclusive.
    """
    self._num_partitions = num_partitions
    start_index = random.randint(0, num_partitions - 1)
    self._current_round_robin = Partition(start_index)
def mock_publishers():
    """Map ``Partition(0)`` through ``Partition(9)`` to fresh ``MagicMock(spec=Publisher)`` mocks."""
    publishers = {}
    for index in range(10):
        publishers[Partition(index)] = MagicMock(spec=Publisher)
    return publishers
async def test_restart(
    assigner: Assigner,
    default_connection,
    connection_factory,
    initial_request,
    asyncio_sleep,
    sleep_queues,
):
    """A failed ack write makes the assigner back off, reconnect, and continue.

    After the failure, the assigner sleeps for _MIN_BACKOFF_SECS and reconnects via
    connection_factory.new. It resends initial_request and delivers the next
    assignment from the new connection, without writing an ack on that connection.
    """
    # NOTE(review): asyncio_sleep is never referenced below; it is presumably
    # requested only for its fixture side effect (patching sleep). Confirm.
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(
        write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(
        read_called_queue, read_result_queue)
    # Pre-load the result for the initial-request write so setup can proceed.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # Wait for the first assignment
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        assert not assign_fut1.done()

        partitions = {Partition(2), Partition(7)}

        # Send the first assignment.
        await read_result_queue.put(as_response(partitions=partitions))
        await read_called_queue.get()
        assert (await assign_fut1) == partitions

        # Get the next assignment: should attempt to send an ack on the stream
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        default_connection.write.assert_has_calls(
            [call(initial_request), call(ack_request())])

        # Set up the next connection
        conn2 = MagicMock(spec=Connection)
        conn2.__aenter__.return_value = conn2
        connection_factory.new.return_value = conn2
        write_called_queue_2 = asyncio.Queue()
        write_result_queue_2 = asyncio.Queue()
        conn2.write.side_effect = make_queue_waiter(write_called_queue_2,
                                                    write_result_queue_2)
        read_called_queue_2 = asyncio.Queue()
        read_result_queue_2 = asyncio.Queue()
        conn2.read.side_effect = make_queue_waiter(read_called_queue_2,
                                                   read_result_queue_2)

        # Fail the connection by failing the write call.
        await write_result_queue.put(InternalServerError("failed"))
        # The retry should back off for exactly _MIN_BACKOFF_SECS; release that sleep.
        await sleep_queues[_MIN_BACKOFF_SECS].called.get()
        await sleep_queues[_MIN_BACKOFF_SECS].results.put(None)

        # Reinitialize
        await write_called_queue_2.get()
        write_result_queue_2.put_nowait(None)
        conn2.write.assert_has_calls([call(initial_request)])

        partitions = {Partition(5)}

        # Send the second assignment on the new connection.
        await read_called_queue_2.get()
        await read_result_queue_2.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
        # No ack call ever made.
        conn2.write.assert_has_calls([call(initial_request)])