def test_wrapped_successful():
    """The ID-adding wrapper converts a sequenced Lite message and sets its ID.

    The wrapped transformer is ``to_cps_subscribe_message``. The expected
    Cloud Pub/Sub message carries the converted data, ordering key and
    attributes (plus the encoded event time), a ``message_id`` encoded from
    partition 1 / offset 10, and the original publish time.
    """
    transformer = add_id_to_cps_subscribe_transformer(
        Partition(1), MessageTransformer.of_callable(to_cps_subscribe_message)
    )

    # Input: a Lite message with two single-valued attributes.
    lite_attributes = {
        name: AttributeValues(values=[b"abc"]) for name in ("x", "y")
    }
    sequenced = SequencedMessage(
        message=PubSubMessage(
            data=b"xyz",
            key=b"def",
            event_time=Timestamp(seconds=55),
            attributes=lite_attributes,
        ),
        publish_time=Timestamp(seconds=10),
        cursor=Cursor(offset=10),
        size_bytes=10,
    )

    # Expected output: flattened attributes plus the encoded event time.
    cps_attributes = dict.fromkeys(("x", "y"), "abc")
    cps_attributes[PUBSUB_LITE_EVENT_TIME] = encode_attribute_event_time(
        Timestamp(seconds=55).ToDatetime()
    )
    expected = PubsubMessage(
        data=b"xyz",
        ordering_key="def",
        attributes=cps_attributes,
        message_id=MessageMetadata(Partition(1), Cursor(offset=10)).encode(),
        publish_time=Timestamp(seconds=10),
    )

    assert transformer.transform(sequenced) == expected
async def test_simple_publish(mock_publishers, mock_policies, mock_watcher, publisher):
    """With 2 partitions, a publish is routed by the 2-partition policy to partition 1."""
    mock_watcher.get_partition_count.return_value = 2
    target = Partition(1)
    async with publisher:
        policy = mock_policies[2]
        partition_publisher = mock_publishers[target]
        policy.route.return_value = target
        partition_publisher.publish.return_value = "a"

        await publisher.publish(PubSubMessage())

        policy.route.assert_called_with(PubSubMessage())
        partition_publisher.publish.assert_called()
 def route(self, message: PubSubMessage) -> Partition:
     """Choose the partition a message should be published to.

     A message with a non-empty key is routed deterministically. The key's
     SHA-256 digest is read as a big-endian unsigned integer and reduced
     modulo the partition count.

     A message without a key gets the current round-robin partition. The
     round-robin cursor then advances by one, wrapping back to 0 after the
     last partition.
     """
     if message.key:
         digest = hashlib.sha256(message.key).digest()
         key_hash = int.from_bytes(digest, byteorder="big")
         return Partition(key_hash % self._num_partitions)

     current = self._current_round_robin.value
     self._current_round_robin = Partition(
         (current + 1) % self._num_partitions)
     return Partition(current)
    async def _handle_partition_count_update(self, partition_count: int):
        """Add publishers for partitions beyond those already tracked.

        A count equal to or smaller than the number of existing publishers is
        ignored; shrinking is not handled here. Otherwise:

        - a publisher is created for every new partition index;
        - all new publishers are entered concurrently;
        - a routing policy for the new count is built;
        - only then are the publishers and the policy installed on ``self``.
        """
        known = len(self._publishers)
        if partition_count <= known:
            return

        fresh = {
            Partition(idx): self._publisher_factory(Partition(idx))
            for idx in range(known, partition_count)
        }
        # NOTE(review): installation is deferred until every new publisher
        # has been entered, presumably so routing never selects a publisher
        # that is not yet ready. Confirm against publish().
        await asyncio.gather(*(pub.__aenter__() for pub in fresh.values()))
        policy = self._policy_factory(partition_count)

        self._publishers.update(fresh)
        self._routing_policy = policy
Esempio n. 5
0
async def test_initial_assignment(subscriber, assigner, subscriber_factory):
    """Each assigned partition gets a subscriber that is entered, then exited on close."""

    def new_partition_subscriber():
        return mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))

    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        first, second = new_partition_subscriber(), new_partition_subscriber()
        subscriber_factory.side_effect = (
            lambda partition: first if partition == Partition(1) else second)

        await assignment_queues.results.put({Partition(1), Partition(2)})
        # The next get_assignment() call shows the first assignment was handled.
        await assignment_queues.called.get()

        subscriber_factory.assert_has_calls(
            [call(Partition(1)), call(Partition(2))], any_order=True)
        for partition_subscriber in (first, second):
            partition_subscriber.__aenter__.assert_called_once()
    for partition_subscriber in (first, second):
        partition_subscriber.__aexit__.assert_called_once()
Esempio n. 6
0
async def test_basic_assign(assigner: Assigner, default_connection,
                            initial_request):
    """The assigner delivers successive assignments, acking before the second.

    ``default_connection.write``/``read`` are backed by queue pairs, so the
    test controls when each call is observed and what it returns. On entry the
    initial request is written. The first assignment is returned by
    ``get_assignment()``. Asking for the next assignment writes an ack before
    the second assignment is read and returned.
    """
    # NOTE(review): make_queue_waiter presumably signals on the *_called_queue
    # when invoked and then returns/raises whatever is put on the matching
    # *_result_queue. This is inferred from how the queues are used below.
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(
        write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(
        read_called_queue, read_result_queue)
    # Pre-load a result so the write of the initial request completes.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection: the initial request is written, then a read starts.
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # Wait for the first assignment.
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        # NOTE(review): no await has happened since ensure_future, so the task
        # has not run yet. This check is trivially true and does not by itself
        # prove that get_assignment() blocks.
        assert not assign_fut1.done()

        partitions = {Partition(2), Partition(7)}

        # Send the first assignment.
        await read_result_queue.put(as_response(partitions=partitions))
        assert (await assign_fut1) == partitions

        # Get the next assignment: should send an ack on the stream.
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        await write_result_queue.put(None)
        default_connection.write.assert_has_calls(
            [call(initial_request), call(ack_request())])

        partitions = {Partition(5)}

        # Send the second assignment once the next read is in flight.
        await read_called_queue.get()
        await read_result_queue.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
Esempio n. 7
0
async def test_delivery_from_multiple(subscriber, assigner,
                                      subscriber_factory):
    """Messages produced by two partition subscribers are both delivered by read()."""

    def new_partition_subscriber():
        return mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))

    def message_with_id(message_id):
        return Message(PubsubMessage(message_id=message_id)._pb, "", 0, None)

    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        first, second = new_partition_subscriber(), new_partition_subscriber()
        first_reads = wire_queues(first.read)
        second_reads = wire_queues(second.read)
        subscriber_factory.side_effect = (
            lambda partition: first if partition == Partition(1) else second)

        await assignment_queues.results.put({Partition(1), Partition(2)})
        await first_reads.results.put(message_with_id("1"))
        await second_reads.results.put(message_with_id("2"))

        # Delivery order between partitions is not fixed, so compare as a set.
        received: Set[str] = set()
        for _ in range(2):
            received.add((await subscriber.read()).message_id)
        assert received == {"1", "2"}
Esempio n. 8
0
async def test_subscriber_failure(subscriber, assigner, subscriber_factory):
    """An error raised by a partition subscriber's read() surfaces from subscriber.read()."""
    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        partition_subscriber = mock_async_context_manager(
            MagicMock(spec=AsyncSingleSubscriber))
        partition_reads = wire_queues(partition_subscriber.read)
        subscriber_factory.return_value = partition_subscriber

        await assignment_queues.results.put({Partition(1)})
        # Wait until the partition subscriber is being read from, then fail it.
        await partition_reads.called.get()
        await partition_reads.results.put(FailedPrecondition("sub failed"))

        with pytest.raises(FailedPrecondition):
            await subscriber.read()
Esempio n. 9
0
def test_routing_cases():
    """The default policy maps each key in routing_tests.json to its listed partition.

    The fixture is JSON (key string -> partition index) that may contain
    whole-line ``//`` comments. Those lines are dropped before parsing. Keys
    are UTF-8 encoded to bytes, matching how message keys are hashed.
    """
    policy = DefaultRoutingPolicy(num_partitions=29)
    path = os.path.join(os.path.dirname(__file__), "routing_tests.json")
    # Decode explicitly as UTF-8. The keys are re-encoded as UTF-8 below, so
    # relying on the platform's locale encoding could corrupt non-ASCII keys.
    with open(path, encoding="utf-8") as f:
        json_lines = [line for line in f if not line.lstrip().startswith("//")]

    # Lines keep their own newlines, so joining with "" preserves the layout.
    loaded = json.loads("".join(json_lines))
    target = {key.encode("utf-8"): Partition(value)
              for key, value in loaded.items()}
    result = {key: policy.route(PubSubMessage(key=key)) for key in target}
    assert result == target
 async def _receive_loop(self):
     """Read assignments from the connection forever and queue each one.

     Each response's partitions become a set of ``Partition`` values. That set
     is put on ``self._new_assignment``, and ``self._outstanding_assignment``
     is set.

     If a response arrives while the previous assignment is still pending, the
     connection is failed with ``FailedPrecondition`` and the loop returns.
     Pending means the flag is set or the queue is not empty.
     """
     while True:
         response = await self._connection.read()
         still_pending = (self._outstanding_assignment
                          or not self._new_assignment.empty())
         if still_pending:
             self._connection.fail(
                 FailedPrecondition(
                     "Received a duplicate assignment on the stream while one was outstanding."
                 ))
             return
         self._outstanding_assignment = True
         assignment = {Partition(raw) for raw in response.partitions}
         self._new_assignment.put_nowait(assignment)
def test_wrapped_sets_id_error():
    """Wrapping raises InvalidArgument when the inner result already has a message_id.

    The inner transformer always returns a message with ``message_id="a"``.
    """
    inner = MessageTransformer.of_callable(
        lambda _: PubsubMessage(message_id="a"))
    wrapped = add_id_to_cps_subscribe_transformer(Partition(1), inner)
    with pytest.raises(InvalidArgument):
        lite_message = PubSubMessage(
            data=b"xyz",
            key=b"def",
            event_time=Timestamp(seconds=55),
            attributes={
                name: AttributeValues(values=[b"abc"]) for name in ("x", "y")
            },
        )
        wrapped.transform(
            SequencedMessage(
                message=lite_message,
                publish_time=Timestamp(seconds=10),
                cursor=Cursor(offset=10),
                size_bytes=10,
            )
        )
async def test_publish_after_increase(
    mock_publishers, mock_policies, mock_watcher, publisher
):
    """After the partition count grows from 2 to 3, publishes use the 3-partition policy."""
    count_queues = wire_queues(mock_watcher.get_partition_count)

    async def publish_and_check(partition_count, partition_index):
        # Route one message through the policy for `partition_count` partitions
        # and check it reached the publisher for `partition_index`.
        policy = mock_policies[partition_count]
        target = Partition(partition_index)
        policy.route.return_value = target
        mock_publishers[target].publish.return_value = "a"
        await publisher.publish(PubSubMessage())
        policy.route.assert_called_with(PubSubMessage())
        mock_publishers[target].publish.assert_called()

    await count_queues.results.put(2)
    async with publisher:
        # The initial partition-count query was answered with 2.
        count_queues.called.get_nowait()
        await publish_and_check(2, 1)

        # Answer the next poll with 3, and wait for the poll after that so the
        # increase has been applied.
        await count_queues.called.get()
        await count_queues.results.put(3)
        await count_queues.called.get()

        await publish_and_check(3, 2)
Esempio n. 13
0
 def _partition(self) -> Partition:
     """Wrap ``self._initial.partition`` in a ``Partition``."""
     raw = self._initial.partition
     return Partition(raw)
 def __init__(self, num_partitions: int):
     """Store the partition count and pick a random round-robin start.

     Args:
         num_partitions: Number of partitions to route across. The starting
             round-robin partition is drawn uniformly from
             ``[0, num_partitions - 1]``.
     """
     self._num_partitions = num_partitions
     first_index = random.randint(0, num_partitions - 1)
     self._current_round_robin = Partition(first_index)
def mock_publishers():
    """Return a ``Publisher``-spec'd MagicMock for each of partitions 0-9, keyed by ``Partition``."""
    publishers = {}
    for index in range(10):
        partition = Partition(index)
        publishers[partition] = MagicMock(spec=Publisher)
    return publishers
Esempio n. 16
0
async def test_restart(
    assigner: Assigner,
    default_connection,
    connection_factory,
    initial_request,
    asyncio_sleep,
    sleep_queues,
):
    """A failed ack write restarts the stream, and the next assignment arrives unacked.

    Steps:

    1. The first assignment is delivered on ``default_connection``.
    2. Requesting the next assignment attempts an ack write, which is failed
       with ``InternalServerError``.
    3. After the minimum backoff sleep, a second connection (``conn2``) is
       created and sent only the initial request.
    4. The second assignment arrives on ``conn2``, and no ack is ever written
       there.
    """
    # NOTE(review): make_queue_waiter presumably signals on the *_called_queue
    # when invoked and then returns/raises whatever is put on the matching
    # *_result_queue. This is inferred from how the queues are used below.
    # `asyncio_sleep` is not referenced here; it is presumably a fixture that
    # patches sleeping so `sleep_queues` can control backoff. Confirm.
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(
        write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(
        read_called_queue, read_result_queue)
    # Pre-load a result so the write of the initial request completes.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection: the initial request is written, then a read starts.
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])

        # Wait for the first assignment.
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        assert not assign_fut1.done()

        partitions = {Partition(2), Partition(7)}

        # Send the first assignment; the receive loop then issues its next read.
        await read_result_queue.put(as_response(partitions=partitions))
        await read_called_queue.get()
        assert (await assign_fut1) == partitions

        # Get the next assignment: should attempt to send an ack on the stream.
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        default_connection.write.assert_has_calls(
            [call(initial_request), call(ack_request())])

        # Set up the next connection before failing the current one.
        conn2 = MagicMock(spec=Connection)
        conn2.__aenter__.return_value = conn2
        connection_factory.new.return_value = conn2
        write_called_queue_2 = asyncio.Queue()
        write_result_queue_2 = asyncio.Queue()
        conn2.write.side_effect = make_queue_waiter(write_called_queue_2,
                                                    write_result_queue_2)
        read_called_queue_2 = asyncio.Queue()
        read_result_queue_2 = asyncio.Queue()
        conn2.read.side_effect = make_queue_waiter(read_called_queue_2,
                                                   read_result_queue_2)

        # Fail the connection by failing the pending ack write, then let the
        # minimum backoff sleep complete.
        await write_result_queue.put(InternalServerError("failed"))
        await sleep_queues[_MIN_BACKOFF_SECS].called.get()
        await sleep_queues[_MIN_BACKOFF_SECS].results.put(None)

        # Reinitialize: the new connection receives the initial request.
        await write_called_queue_2.get()
        write_result_queue_2.put_nowait(None)
        conn2.write.assert_has_calls([call(initial_request)])

        partitions = {Partition(5)}

        # Send the second assignment on the new connection.
        await read_called_queue_2.get()
        await read_result_queue_2.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
        # No ack call ever made: compare the exact call list.
        # assert_has_calls only checks containment, so it would also pass if
        # an extra ack had been written.
        assert conn2.write.call_args_list == [call(initial_request)]