def test_recv_message_zmqlet(mocker):
    """Send one message from a first Zmqlet and receive it on a second one.

    ``args1``, ``args2`` and ``default_logger`` are defined outside this
    block.  Both zmqlets are opened as context managers (the same way other
    tests in this file use ``Zmqlet``) so that they are closed even when an
    assertion fails, instead of being left open after the test.

    :param mocker: fixture used to create the ``Mock`` passed as the receive
        callback
    """
    req = Request()
    req.request_id = random_identity()
    doc = req.data.docs.add()
    doc.tags['id'] = 2
    msg = Message(None, req, 'tmp', '')

    def callback(msg_):
        # the received document must carry the same ``id`` tag as the sent one
        received_id = msg_.request.docs[0].tags['id']
        sent_id = msg.request.data.docs[0].tags['id']
        assert received_id == sent_id

    mock = mocker.Mock()
    with Zmqlet(args1, default_logger) as zmqlet1, Zmqlet(
        args2, default_logger
    ) as zmqlet2:
        zmqlet1.send_message(msg)
        # presumably leaves time for the message to arrive before receiving
        time.sleep(1)
        zmqlet2.recv_message(mock)

    validate_callback(mock, callback)
Exemple #2
0
def test_data_request_handler_new_docs(logger):
    """Handle a 10-document request with a handler using ``NewDocsExecutor``.

    Before handling, the message holds 10 documents; afterwards it is
    expected to hold exactly one document whose text is ``'new document'``.

    :param logger: logger handed to the ``DataRequestHandler``
    """
    parsed_args = set_pea_parser().parse_args(['--uses', 'NewDocsExecutor'])
    request_handler = DataRequestHandler(parsed_args, logger)

    input_docs = DocumentArray(
        [Document(text='input document') for _ in range(10)]
    )
    generated_requests = list(request_generator('/', input_docs))
    message = Message(None, generated_requests[0], 'test', '123')
    assert len(message.request.docs) == 10

    request_handler.handle(
        msg=message, partial_requests=None, peapod_name='name'
    )

    assert len(message.request.docs) == 1
    assert message.request.docs[0].text == 'new document'
def test_not_decompressed_zmqlet(mocker):
    """Send and receive one message on a single Zmqlet while a MockPea is open.

    The receive callback is a ``Mock``; it is validated afterwards with a
    no-op check, so the message content itself is never inspected here.
    ``args1``, ``args2`` and ``default_logger`` are defined outside this block.

    :param mocker: fixture used to create the ``Mock`` receive callback
    """

    def callback(msg_):
        # intentionally empty: the received message is not examined
        pass

    with MockPea(args2) as pea, Zmqlet(args1, default_logger) as zmqlet:
        request = Request()
        request.request_id = random_identity()
        doc = request.data.docs.add()
        doc.tags['id'] = 2
        outgoing = Message(None, request, 'tmp', '')
        mock = mocker.Mock()
        zmqlet.send_message(outgoing)
        time.sleep(1)
        zmqlet.recv_message(mock)

    validate_callback(mock, callback)
    print(' joining pea')
    pea.join()
    print(' joined pea')
def test_simple_zmqlet():
    """Send one message from a connecting Zmqlet while a binding BasePea is open.

    The pea's arguments mirror the zmqlet's: its ``--port-in`` is the
    zmqlet's ``port_out`` and its ``--port-out`` is the zmqlet's ``port_in``.
    The test contains no assertions; it passes when nothing raises.
    """
    connect_cli = [
        '--host-in', '0.0.0.0',
        '--host-out', '0.0.0.0',
        '--socket-in', 'PULL_CONNECT',
        '--socket-out', 'PUSH_CONNECT',
        '--timeout-ctrl', '-1',
    ]
    connect_args = set_pea_parser().parse_args(connect_cli)

    # swap the two ports so each side's output faces the other side's input
    bind_cli = [
        '--host-in', '0.0.0.0',
        '--host-out', '0.0.0.0',
        '--port-in', str(connect_args.port_out),
        '--port-out', str(connect_args.port_in),
        '--socket-in', 'PULL_BIND',
        '--socket-out', 'PUSH_BIND',
        '--timeout-ctrl', '-1',
    ]
    bind_args = set_pea_parser().parse_args(bind_cli)

    zmq_logger = JinaLogger('zmq-test')
    with BasePea(bind_args), Zmqlet(connect_args, zmq_logger) as zmqlet:
        request = jina_pb2.RequestProto()
        request.request_id = random_identity()
        doc = request.data.docs.add()
        doc.tags['id'] = 2
        zmqlet.send_message(Message(None, request, 'tmp', ''))
Exemple #5
0
def test_lazy_msg_access():
    """Check which accesses on a ``Message`` flip ``request.is_used``.

    Messages are built from serialized requests.  Reading ``envelope`` or
    ``request`` and calling ``dump()`` must leave ``request.is_used`` false;
    only reading ``request.index.docs`` is expected to make it true.  In every
    stage ``dump()`` must return three parts.
    """
    messages = []
    for generated in request_generator(random_docs(10)):
        messages.append(
            Message(
                None,
                generated.SerializeToString(),
                'test',
                '123',
                request_id='123',
                request_type='IndexRequest',
            )
        )

    # (what to touch on the message, whether the request is used afterwards);
    # each stage runs over all messages before the next stage starts
    stages = (
        (lambda m: m.envelope, False),
        (lambda m: m.request, False),
        (lambda m: m.request.index.docs, True),
    )
    for touch, used_afterwards in stages:
        for message in messages:
            assert not message.request.is_used
            assert touch(message)
            assert len(message.dump()) == 3
            assert bool(message.request.is_used) is used_afterwards
Exemple #6
0
def test_simple_zmqlet():
    """Send one message from a connecting Zmqlet while a binding BasePea is open.

    Fixed ports are used: the zmqlet reads on 12346 and writes on 12347, the
    pea (which uses ``_logforward``) has the two swapped.  The test contains
    no assertions; it passes when nothing raises.
    """
    connect_cli = [
        '--host-in', '0.0.0.0',
        '--host-out', '0.0.0.0',
        '--port-in', '12346',
        '--port-out', '12347',
        '--socket-in', 'PULL_CONNECT',
        '--socket-out', 'PUSH_CONNECT',
        '--timeout-ctrl', '-1',
    ]
    connect_args = set_pea_parser().parse_args(connect_cli)

    bind_cli = [
        '--host-in', '0.0.0.0',
        '--host-out', '0.0.0.0',
        '--port-in', '12347',
        '--port-out', '12346',
        '--socket-in', 'PULL_BIND',
        '--socket-out', 'PUSH_BIND',
        '--uses', '_logforward',
        '--timeout-ctrl', '-1',
    ]
    bind_args = set_pea_parser().parse_args(bind_cli)

    zmq_logger = logging.getLogger('zmq-test')
    with BasePea(bind_args), Zmqlet(connect_args, zmq_logger) as zmqlet:
        request = jina_pb2.RequestProto()
        request.request_id = get_random_identity()
        doc = request.index.docs.add()
        doc.tags['id'] = 2
        zmqlet.send_message(Message(None, request, 'tmp', ''))
Exemple #7
0
def test_lazy_msg_access_with_envelope():
    """Check that a message with a given envelope decompresses its request lazily.

    Every message shares one envelope (compression ``'NONE'``, request type
    ``'DataRequest'``) and wraps a serialized request.  Reading ``envelope``
    and calling ``dump()`` must leave the request undecompressed, with
    ``_buffer`` set and ``_pb_body`` unset; reading ``proto`` must flip all
    three.
    """
    shared_envelope = jina_pb2.EnvelopeProto()
    shared_envelope.compression.algorithm = 'NONE'
    shared_envelope.request_type = 'DataRequest'

    messages = []
    for generated in request_generator('/', random_docs(10)):
        messages.append(Message(shared_envelope, generated.SerializeToString()))

    for message in messages:
        # envelope access and dumping must not touch the request body
        assert not message.request.is_decompressed
        assert message.envelope
        assert len(message.dump()) == 3
        assert not message.request.is_decompressed
        assert message.request._pb_body is None
        assert message.request._buffer is not None

        # reading ``proto`` is what triggers decompression
        assert message.proto
        assert message.request.is_decompressed
        assert message.request._pb_body is not None
        assert message.request._buffer is None
Exemple #8
0
def _parse_from_frames(sock_type, frames: List[bytes]) -> 'Message':
    """
    Build :class:`Message` from a list of frames.

    After the socket-specific adjustment below, the list of frames
    (length >= 3) is read as follows:

        - offset 0: the client id, can be empty
        - offset 1: passed to :class:`Message` as its first argument
          (the envelope)
        - offset 2: passed to :class:`Message` as its second argument
          (the body of the serialized protobuf message)

    The list passed in by the caller is never modified; the adjustment is
    always made on a new list.

    :param sock_type: the recv socket type
    :param frames: list of bytes to parse from
    :return: a :class:`Message` object
    """
    if sock_type == zmq.DEALER:
        # dealer consumes the first part of the message as id, we need to prepend it back
        frames = [b' '] + frames
    elif sock_type == zmq.ROUTER:
        # the router puts the dealer id in front when it receives a message;
        # drop it with a slice so the caller's list is left intact
        frames = frames[1:]

    return Message(frames[1], frames[2])
def test_compression(compress_algo, low_bytes, high_ratio):
    """Compare dumped message sizes with and without compression.

    The same 100 documents are dumped once with ``CompressAlgo.NONE`` and
    once with ``compress_algo``.  When the algorithm is ``NONE``, or when
    ``low_bytes`` / ``high_ratio`` is set, the second total must be at least
    as large as the first; otherwise it must be strictly smaller.

    :param compress_algo: algorithm used for the second pass
    :param low_bytes: if truthy, ``compress_min_bytes`` for the second pass is
        twice the total uncompressed size
    :param high_ratio: if truthy, ``compress_min_ratio`` is 10 instead of 1
    """
    documents = list(random_docs(100, embed_dim=100))
    min_ratio = 10 if high_ratio else 1

    def _message_options(min_bytes):
        # keyword arguments shared by every Message built in this test
        return dict(
            identity='gateway',
            pod_name='123',
            compress_min_bytes=min_bytes,
            compress_min_ratio=min_ratio,
        )

    baseline_sizes = []
    # nothing has been measured yet, so the byte threshold is 0 either way
    baseline_options = _message_options(0)
    with TimeContext('no compress'):
        for request in request_generator('/', documents):
            message = Message(
                None, request, compress=CompressAlgo.NONE, **baseline_options
            )
            message.dump()
            baseline_sizes.append(message.size)

    compressed_sizes = []
    compressed_options = _message_options(
        2 * sum(baseline_sizes) if low_bytes else 0
    )
    with TimeContext(f'compressing with {str(compress_algo)}') as tc:
        for request in request_generator('/', documents):
            message = Message(
                None, request, compress=compress_algo, **compressed_options
            )
            message.dump()
            compressed_sizes.append(message.size)

    baseline_total = sum(baseline_sizes)
    compressed_total = sum(compressed_sizes)
    if compress_algo == CompressAlgo.NONE or low_bytes or high_ratio:
        assert compressed_total >= baseline_total
    else:
        assert compressed_total < baseline_total
    print(
        f'{str(compress_algo)}: size {compressed_total / len(compressed_sizes)} (ratio: {baseline_total / compressed_total:.2f}) with {tc.duration:.2f}s'
    )
Exemple #10
0
def test_data_request_handler_change_docs_from_partial_requests(logger):
    """Handle a message together with partial requests using ``MergeChangeDocsExecutor``.

    One 10-document request is generated and repeated five times as the list
    of partial requests.  After handling, the message is expected to hold
    50 documents, each with the text ``'changed document'``.

    :param logger: logger handed to the ``DataRequestHandler``
    """
    partial_count = 5
    parsed_args = set_pea_parser().parse_args(
        ['--uses', 'MergeChangeDocsExecutor']
    )
    request_handler = DataRequestHandler(parsed_args, logger)

    input_docs = DocumentArray(
        [Document(text='input document') for _ in range(10)]
    )
    shared_request = list(request_generator('/', input_docs))[0]
    # list multiplication repeats the SAME request object, not copies of it
    partial_reqs = [shared_request] * partial_count

    message = Message(None, partial_reqs[-1], 'test', '123')
    assert len(message.request.docs) == 10

    request_handler.handle(
        msg=message, partial_requests=partial_reqs, peapod_name='name'
    )

    assert len(message.request.docs) == 10 * partial_count
    assert all(
        doc.text == 'changed document' for doc in message.request.docs
    )
Exemple #11
0
def ctrl_messages():
    """Return one ``Message`` per request generated from ``random_docs(10)``.

    :return: list of messages, each built with ``'test'`` and ``'123'`` as the
        third and fourth ``Message`` arguments
    """
    messages = []
    for request in request_generator('/', random_docs(10)):
        messages.append(Message(None, request, 'test', '123'))
    return messages
Exemple #12
0
 async def CallUnary(self, request, context):
     """Send ``request`` through a fresh ``AsyncZmqlet`` and return the reply.

     The request is wrapped in a ``Message`` built with ``'gateway'`` and the
     attributes of ``self.args`` as keyword arguments; the value returned by
     ``recv_message`` (called with ``self.handle`` as callback) is passed on.

     :param request: the request to forward
     :param context: not used by this method
     :return: whatever ``zmqlet.recv_message`` resolves to
     """
     with AsyncZmqlet(self.args, logger=self.logger) as zmqlet:
         outgoing = Message(None, request, 'gateway', **vars(self.args))
         await zmqlet.send_message(outgoing)
         reply = await zmqlet.recv_message(callback=self.handle)
         return reply
Exemple #13
0
def test_double_dynamic_routing_zmqlet():
    """Route messages from one Zmqlet to two others via a routing table.

    ``pod1`` (the first zmqlet) has out-edges to ``pod2`` and ``pod3``.  Over
    10 trips, 100 messages are sent per trip from the first zmqlet and then
    received on the other two.  Afterwards the sender must have counted two
    sends per message, and each receiver must have received every message
    and sent none.  ``get_args`` and ``callback`` are defined outside this
    block.
    """

    def _leaf_pod(port):
        # a pod with no outgoing edges that expects a single message part
        return {
            'host': '0.0.0.0',
            'port': port,
            'expected_parts': 1,
            'out_edges': [],
        }

    args1 = get_args()
    args2 = get_args()
    args3 = get_args()

    logger = JinaLogger('zmq-test')
    with Zmqlet(args1, logger) as z1, Zmqlet(args2, logger) as z2, Zmqlet(
        args3, logger
    ) as z3:
        for zmqlet in (z1, z2, z3):
            assert zmqlet.msg_sent == 0

        request = jina_pb2.RequestProto()
        request.request_id = random_identity()
        doc = request.data.docs.add()
        doc.tags['id'] = 2
        msg = Message(None, request, 'tmp', '')

        routing_table = {
            'active_pod': 'pod1',
            'pods': {
                'pod1': {
                    'host': '0.0.0.0',
                    'port': args1.port_in,
                    'expected_parts': 0,
                    'out_edges': [{'pod': 'pod2'}, {'pod': 'pod3'}],
                },
                'pod2': _leaf_pod(args2.port_in),
                'pod3': _leaf_pod(args3.port_in),
            },
        }
        msg.envelope.routing_table.CopyFrom(RoutingTable(routing_table).proto)

        number_messages = 100
        trips = 10
        for _ in range(trips):
            for _ in range(number_messages):
                z1.send_message(msg)
            time.sleep(1)
            for _ in range(number_messages):
                z2.recv_message(callback)
                z3.recv_message(callback)

        total_number_messages = number_messages * trips

        # every message goes out twice: once per out-edge of pod1
        assert z1.msg_sent == 2 * total_number_messages
        for receiver in (z2, z3):
            assert receiver.msg_sent == 0
            assert receiver.msg_recv == total_number_messages
Exemple #14
0
def test_double_dynamic_routing_zmqstreamlet():
    """Route messages from one ZmqStreamlet to two others via a routing table.

    ``pod1`` (the first streamlet) has out-edges to ``pod2`` and ``pod3``.
    Each streamlet is started in its own daemon thread, 1000 messages are
    sent from the first one, and after a 5 second wait the counters are
    checked: two sends per message and no receives on the sender, one
    receive per message and no sends on each of the others.  ``get_args``
    and ``callback`` are defined outside this block.
    """

    def _leaf_pod(port):
        # a pod with no outgoing edges that expects a single message part
        return {
            'host': '0.0.0.0',
            'port': port,
            'expected_parts': 1,
            'out_edges': [],
        }

    args1 = get_args()
    args2 = get_args()
    args3 = get_args()

    logger = JinaLogger('zmq-test')
    with ZmqStreamlet(args=args1, logger=logger) as z1, ZmqStreamlet(
        args=args2, logger=logger
    ) as z2, ZmqStreamlet(args=args3, logger=logger) as z3:
        for streamlet in (z1, z2, z3):
            assert streamlet.msg_sent == 0

        request = jina_pb2.RequestProto()
        request.request_id = random_identity()
        doc = request.data.docs.add()
        doc.tags['id'] = 2
        msg = Message(None, request, 'tmp', '')

        routing_pb = jina_pb2.RoutingTableProto()
        routing_table = {
            'active_pod': 'pod1',
            'pods': {
                'pod1': {
                    'host': '0.0.0.0',
                    'port': args1.port_in,
                    'expected_parts': 0,
                    'out_edges': [{'pod': 'pod2'}, {'pod': 'pod3'}],
                },
                'pod2': _leaf_pod(args2.port_in),
                'pod3': _leaf_pod(args3.port_in),
            },
        }
        json_format.ParseDict(routing_table, routing_pb)
        msg.envelope.routing_table.CopyFrom(routing_pb)

        # run every streamlet in a background daemon thread
        for streamlet in (z1, z2, z3):
            worker = threading.Thread(
                target=streamlet.start, args=(callback,), daemon=True
            )
            worker.start()

        number_messages = 1000
        for _ in range(number_messages):
            z1.send_message(msg)

        time.sleep(5)

        # every message goes out twice: once per out-edge of pod1
        assert z1.msg_sent == 2 * number_messages
        assert z1.msg_recv == 0
        for receiver in (z2, z3):
            assert receiver.msg_sent == 0
            assert receiver.msg_recv == number_messages
def _create_test_data_message(counter=0):
    """Build a ``Message`` around a request holding one document.

    :param counter: value whose ``str()`` becomes the document's text
    :return: a ``Message`` wrapping the first generated request
    """
    single_doc = DocumentArray([Document(text=str(counter))])
    first_request = list(request_generator('/', single_doc))[0]
    return Message(None, first_request, 'test', '123')