示例#1
0
def test_lazy_msg_access():
    """Check ``request.is_decompressed`` around envelope/request/docs access.

    Messages are built from raw serialized requests with no envelope (first
    argument ``None``).  As the original note below says, the request header
    is then read during construction, so every ``is_decompressed`` check is
    expected to be True.  Each loop also asserts that ``dump()`` yields 3 items.
    """
    # this test does not make much sense, when `message` is instantiated without `envelope`, the `request` header is accessed and therefore decompressed
    messages = [
        Message(
            None,
            r.SerializeToString(),
            'test',
            '123',
            request_id='123',
            request_type='DataRequest',
        ) for r in request_generator('/', random_docs(10))
    ]
    # touching the envelope: request stays decompressed
    for m in messages:
        assert m.request.is_decompressed
        assert m.envelope
        assert len(m.dump()) == 3
        assert m.request.is_decompressed

    # touching the request wrapper itself
    for m in messages:
        assert m.request.is_decompressed
        assert m.request
        assert len(m.dump()) == 3
        assert m.request.is_decompressed

    # reading the actual docs payload
    for m in messages:
        assert m.request.is_decompressed
        assert m.request.data.docs
        assert len(m.dump()) == 3
        assert m.request.is_decompressed
示例#2
0
def test_simple_zmqlet():
    """Send one index request from a connecting Zmqlet to a binding BasePea.

    ``args`` connects on ports 12346/12347, and ``args2`` binds the mirrored
    pair (and uses ``_logforward``), so the two ends face each other.  The
    test only performs the send.
    """
    def _pea_args(port_in, port_out, socket_in, socket_out, *extra):
        # every pea here listens/sends on 0.0.0.0 and disables the ctrl timeout
        return set_pea_parser().parse_args(
            ['--host-in', '0.0.0.0',
             '--host-out', '0.0.0.0',
             '--port-in', port_in,
             '--port-out', port_out,
             '--socket-in', socket_in,
             '--socket-out', socket_out,
             *extra,
             '--timeout-ctrl', '-1'])

    args = _pea_args('12346', '12347', 'PULL_CONNECT', 'PUSH_CONNECT')
    args2 = _pea_args('12347', '12346', 'PULL_BIND', 'PUSH_BIND',
                      '--uses', '_logforward')

    logger = logging.getLogger('zmq-test')
    with BasePea(args2), Zmqlet(args, logger) as zmqlet:
        request = jina_pb2.RequestProto()
        request.request_id = get_random_identity()
        request.index.docs.add().tags['id'] = 2
        zmqlet.send_message(Message(None, request, 'tmp', ''))
示例#3
0
def test_data_request_handler_change_docs_dam(logger, tmpdir):
    """An executor may return a DocumentArrayMemmap in place of the input docs.

    After handling, the message must still hold the ten original documents,
    now round-tripped through the memmap the executor returned.
    """
    class MemmapExecutor(Executor):
        @requests
        def foo(self, docs, **kwargs):
            # persist the incoming docs into a memmap under tmpdir and hand it back
            memmap = DocumentArrayMemmap(tmpdir + '/dam')
            memmap.extend(docs)
            return memmap

    handler = DataRequestHandler(
        set_pea_parser().parse_args(['--uses', 'MemmapExecutor']), logger
    )

    input_docs = DocumentArray([Document(text='input document') for _ in range(10)])
    generated = list(request_generator('/', input_docs))
    msg = Message(None, generated[0], 'test', '123')
    assert len(msg.request.docs) == 10

    handler.handle(msg=msg, partial_requests=None, peapod_name='name')

    assert len(msg.request.docs) == 10
    for result_doc in msg.request.docs:
        assert result_doc.text == 'input document'
示例#4
0
def test_double_dynamic_routing_zmqstreamlet():
    """Fan messages out from executor1 to executor2 and executor3 via ZmqStreamlets.

    executor1's routing entry has two out edges, so each of the
    ``number_messages`` sends from ``z1`` is expected to count twice in
    ``z1.msg_sent``, while each downstream streamlet receives every message.
    """
    args1 = get_args()
    args2 = get_args()
    args3 = get_args()

    logger = JinaLogger('zmq-test')
    with ZmqStreamlet(args=args1, logger=logger) as z1, ZmqStreamlet(
        args=args2, logger=logger
    ) as z2, ZmqStreamlet(args=args3, logger=logger) as z3:
        assert z1.msg_sent == 0
        assert z2.msg_sent == 0
        assert z3.msg_sent == 0
        req = jina_pb2.RequestProto()
        req.request_id = random_identity()
        d = req.data.docs.add()
        d.tags['id'] = 2
        msg = Message(None, req, 'tmp', '')
        routing_pb = jina_pb2.RoutingTableProto()
        # executor1 is the sender (0 expected parts) and fans out to the two others
        routing_table = {
            'active_pod': 'executor1',
            'pods': {
                'executor1': {
                    'host': __default_host__,
                    'port': args1.port_in,
                    'expected_parts': 0,
                    'out_edges': [{'pod': 'executor2'}, {'pod': 'executor3'}],
                },
                'executor2': {
                    'host': __default_host__,
                    'port': args2.port_in,
                    'expected_parts': 1,
                    'out_edges': [],
                },
                'executor3': {
                    'host': __default_host__,
                    'port': args3.port_in,
                    'expected_parts': 1,
                    'out_edges': [],
                },
            },
        }
        json_format.ParseDict(routing_table, routing_pb)
        msg.envelope.routing_table.CopyFrom(routing_pb)
        # run each streamlet's loop in a daemon thread so the test process can exit
        # even if a loop never returns; `callback` is defined elsewhere in the module
        for pea in [z1, z2, z3]:
            thread = threading.Thread(target=pea.start, args=(callback,))
            thread.daemon = True
            thread.start()

        number_messages = 1000
        for i in range(number_messages):
            z1.send_message(msg)

        # fixed wait for the background loops to deliver everything; there is no
        # explicit synchronisation, so this may be timing-sensitive
        time.sleep(5)

        assert z1.msg_sent == 2 * number_messages
        assert z1.msg_recv == 0
        assert z2.msg_sent == 0
        assert z2.msg_recv == number_messages
        assert z3.msg_sent == 0
        assert z3.msg_recv == number_messages
示例#5
0
def test_lazy_msg_access():
    """Check that a message's request is parsed lazily.

    Messages are built from serialized IndexRequests.  Accessing the envelope,
    the ``request`` wrapper itself, or calling ``dump()`` must leave
    ``request.is_used`` False; only reading ``request.index.docs`` sets it.
    """
    reqs = [
        Message(
            None,
            r.SerializeToString(),
            'test',
            '123',
            request_id='123',
            request_type='IndexRequest',
        )
        for r in request_generator(random_docs(10))
    ]
    # envelope access does not touch the request body
    for r in reqs:
        assert not r.request.is_used
        assert r.envelope
        assert len(r.dump()) == 3
        assert not r.request.is_used

    # obtaining the request wrapper alone does not parse it either
    for r in reqs:
        assert not r.request.is_used
        assert r.request
        assert len(r.dump()) == 3
        assert not r.request.is_used

    # reading the docs is what marks the request as used
    for r in reqs:
        assert not r.request.is_used
        assert r.request.index.docs
        assert len(r.dump()) == 3
        assert r.request.is_used
示例#6
0
def test_double_dynamic_routing_zmqlet():
    """Fan messages out from executor1 to executor2 and executor3 over Zmqlets.

    Over ``trips`` rounds, ``z1`` sends ``number_messages`` messages and the
    two downstream zmqlets then receive the same number each.  Because
    executor1 has two out edges, every send counts twice in ``z1.msg_sent``.
    """
    args1 = get_args()
    args2 = get_args()
    args3 = get_args()

    logger = JinaLogger('zmq-test')
    with Zmqlet(args1, logger) as z1, Zmqlet(args2, logger) as z2, Zmqlet(
        args3, logger
    ) as z3:
        assert z1.msg_sent == 0
        assert z2.msg_sent == 0
        assert z3.msg_sent == 0
        req = jina_pb2.RequestProto()
        req.request_id = random_identity()
        d = req.data.docs.add()
        d.tags['id'] = 2
        msg = Message(None, req, 'tmp', '')
        routing_table = {
            'active_pod': 'executor1',
            'pods': {
                'executor1': {
                    'host': __default_host__,
                    'port': args1.port_in,
                    'expected_parts': 0,
                    'out_edges': [{'pod': 'executor2'}, {'pod': 'executor3'}],
                },
                'executor2': {
                    'host': __default_host__,
                    'port': args2.port_in,
                    'expected_parts': 1,
                    'out_edges': [],
                },
                'executor3': {
                    'host': __default_host__,
                    'port': args3.port_in,
                    'expected_parts': 1,
                    'out_edges': [],
                },
            },
        }
        msg.envelope.routing_table.CopyFrom(RoutingTable(routing_table).proto)

        number_messages = 100
        trips = 10
        # distinct, unused loop names: the original reused `i` for both the
        # outer trip loop and the inner receive loop, shadowing the outer one
        for _trip in range(trips):
            for _ in range(number_messages):
                z1.send_message(msg)
            # give the sockets time to deliver before draining them
            time.sleep(1)
            for _ in range(number_messages):
                z2.recv_message(callback)
                z3.recv_message(callback)

        total_number_messages = number_messages * trips

        assert z1.msg_sent == 2 * total_number_messages
        assert z2.msg_sent == 0
        assert z2.msg_recv == total_number_messages
        assert z3.msg_sent == 0
        assert z3.msg_recv == total_number_messages
示例#7
0
def _create_test_data_message():
    """Build a Message around the first request generated from ten text docs."""
    docs = DocumentArray([Document(text='input document') for _ in range(10)])
    generated = list(request_generator('/', docs))
    return Message(None, generated[0], 'test', '123')
示例#8
0
def test_not_read_zmqlet():
    """Send one index request from a Zmqlet while a MockBasePeaNotRead is running."""
    with MockBasePeaNotRead(args3), Zmqlet(args1, default_logger) as zmqlet:
        proto = jina_pb2.RequestProto()
        proto.request_id = get_random_identity()
        proto.index.docs.add().tags['id'] = 2
        zmqlet.send_message(Message(None, proto, 'tmp', ''))
示例#9
0
def test_add_route():
    """``add_route`` on a fresh message yields a second route with the given pod and id."""
    control_req = jina_pb2.RequestProto()
    control_req.control.command = jina_pb2.RequestProto.ControlRequestProto.IDLE
    msg = Message(None, control_req, pod_name='test1', identity='sda')

    msg.add_route('name', 'identity')

    assert len(msg.envelope.routes) == 2
    added = msg.envelope.routes[1]
    assert added.pod == 'name'
    assert added.pod_id == 'identity'
示例#10
0
def test_read_zmqlet():
    """Send one data request from a Zmqlet while a MockBasePeaRead is running."""
    with MockBasePeaRead(args2), Zmqlet(args1, default_logger) as zmqlet:
        request = Request()
        request.request_id = random_identity()
        request.data.docs.add().tags['id'] = 2
        zmqlet.send_message(Message(None, request, 'tmp', ''))
示例#11
0
def test_read_zmqlet():
    """Send one index request, keyed by a uuid1 hex id, to a MockBasePeaRead."""
    with MockBasePeaRead(args2), Zmqlet(args1, default_logger) as zmqlet:
        proto = jina_pb2.RequestProto()
        proto.request_id = uuid.uuid1().hex
        proto.index.docs.add().tags['id'] = 2
        zmqlet.send_message(Message(None, proto, 'tmp', ''))
示例#12
0
async def test_double_dynamic_routing_async_zmqlet():
    """Fan one message out from pod1 to pod2 and pod3 over AsyncZmqlets.

    pod1's routing entry has two out edges, so a single send from ``z1`` is
    expected to count as two sent messages, and each downstream zmqlet should
    receive exactly one.
    """
    args1, args2, args3 = get_args(), get_args(), get_args()

    logger = JinaLogger('zmq-test')
    with AsyncZmqlet(args1, logger) as z1, AsyncZmqlet(args2, logger) as z2, \
            AsyncZmqlet(args3, logger) as z3:
        for zmqlet in (z1, z2, z3):
            assert zmqlet.msg_sent == 0

        req = jina_pb2.RequestProto()
        req.request_id = random_identity()
        req.data.docs.add().tags['id'] = 2
        msg = Message(None, req, 'tmp', '')

        def _pod(port, expected_parts, out_edges):
            # every pod in this test lives on 0.0.0.0
            return {
                'host': '0.0.0.0',
                'port': port,
                'expected_parts': expected_parts,
                'out_edges': out_edges,
            }

        routing_pb = jina_pb2.RoutingTableProto()
        routing_table = {
            'active_pod': 'pod1',
            'pods': {
                'pod1': _pod(args1.port_in, 0, [{'pod': 'pod2'}, {'pod': 'pod3'}]),
                'pod2': _pod(args2.port_in, 1, []),
                'pod3': _pod(args3.port_in, 1, []),
            },
        }
        json_format.ParseDict(routing_table, routing_pb)
        msg.envelope.routing_table.CopyFrom(routing_pb)

        await send_msg(z1, msg)
        await z2.recv_message(callback)
        await z3.recv_message(callback)

        assert (z1.msg_sent, z1.msg_recv) == (2, 0)
        assert (z2.msg_sent, z2.msg_recv) == (0, 1)
        assert (z3.msg_sent, z3.msg_recv) == (0, 1)
示例#13
0
def test_message_size():
    """Check that ``Message.size`` is 0 before ``dump()`` and positive afterwards."""
    reqs = [Message(None, r, 'test', '123') for r in request_generator(random_docs(10))]
    for r in reqs:
        # nothing has been serialized yet
        assert r.size == 0
        # NOTE(review): sys.getsizeof on a bytes object is always > 0 (it includes
        # interpreter object overhead), so these two asserts cannot fail; len()
        # may have been intended — confirm before changing.
        assert sys.getsizeof(r.envelope.SerializeToString())
        assert sys.getsizeof(r.request.SerializeToString())
        assert len(r.dump()) == 3
        # NOTE(review): this bound uses getsizeof, which adds per-object overhead
        # to each payload length — confirm this is the intended comparison.
        assert r.size > sys.getsizeof(r.envelope.SerializeToString()) \
               + sys.getsizeof(r.request.SerializeToString())
示例#14
0
文件: __init__.py 项目: JoanFM/jina
    def _add_envelope(self, msg, routing_table):
        """Return ``msg`` carrying ``routing_table`` in its envelope.

        With a static routing table the message is returned unchanged.
        Otherwise the envelope is copied into a new EnvelopeProto (leaving
        ``msg.envelope`` untouched), its routing table is replaced, and a new
        Message wrapping the same request is returned.
        """
        if self._static_routing_table:
            return msg

        envelope = jina_pb2.EnvelopeProto()
        envelope.CopyFrom(msg.envelope)
        envelope.routing_table.CopyFrom(routing_table.proto)
        return Message(request=msg.request, envelope=envelope)
示例#15
0
def test_remote_local_dynamic_routing_zmqlet():
    """Route one message from pod1 to a pod2 that connects to pod1's output.

    ``args2`` gets a fixed zmq identity and connects to ``args1``'s output
    address via ``hosts_in_connect``.  pod1's out edge is marked
    ``send_as_bind`` and pod2's entry carries ``target_identity`` —
    presumably so z1 can address z2 by identity over its bound socket
    (confirm against Zmqlet routing).
    """
    args1 = get_args()

    args2 = get_args()
    args2.zmq_identity = 'test-identity'
    args2.hosts_in_connect = [f'{args1.host}:{args1.port_out}']

    logger = JinaLogger('zmq-test')
    with Zmqlet(args1, logger) as z1, Zmqlet(args2, logger) as z2:
        assert z1.msg_sent == 0
        assert z1.msg_recv == 0
        assert z2.msg_sent == 0
        assert z2.msg_recv == 0
        req = jina_pb2.RequestProto()
        req.request_id = random_identity()
        d = req.data.docs.add()
        d.tags['id'] = 2
        msg = Message(None, req, 'tmp', '')
        routing_pb = jina_pb2.RoutingTableProto()
        routing_table = {
            'active_pod': 'pod1',
            'pods': {
                'pod1': {
                    'host': '0.0.0.0',
                    'port': args1.port_in,
                    'expected_parts': 0,
                    'out_edges': [{
                        'pod': 'pod2',
                        'send_as_bind': True
                    }],
                },
                'pod2': {
                    'host': '0.0.0.0',
                    'port': args2.port_in,
                    'expected_parts': 1,
                    'out_edges': [],
                    'target_identity': args2.zmq_identity,
                },
            },
        }
        json_format.ParseDict(routing_table, routing_pb)
        msg.envelope.routing_table.CopyFrom(routing_pb)
        # receiving before anything was sent must not count as a received message
        z2.recv_message(callback)

        assert z2.msg_sent == 0
        assert z2.msg_recv == 0

        z1.send_message(msg)
        z2.recv_message(callback)

        # a single out edge: exactly one send from z1, one receive at z2
        assert z1.msg_sent == 1
        assert z1.msg_recv == 0
        assert z2.msg_sent == 0
        assert z2.msg_recv == 1
示例#16
0
文件: gateway.py 项目: serge-m/jina
 def prefetch_req(num_req, fetch_to):
     """Schedule up to ``num_req`` send/receive round trips on ``zmqlet``.

     For each request pulled from ``request_iterator`` a send task is
     scheduled, and a matching receive task (handled by ``self.handle``) is
     appended to ``fetch_to``.

     :param num_req: maximum number of requests to prefetch
     :param fetch_to: list collecting the scheduled receive tasks
     :return: True if ``request_iterator`` was exhausted, otherwise False
     """
     for _ in range(num_req):
         # only the iterator may signal exhaustion; keep StopIteration from any
         # other call from being mistaken for "no more requests"
         try:
             next_request = next(request_iterator)
         except StopIteration:
             return True
         # NOTE(review): the send task is not kept anywhere; asyncio holds only weak
         # references to tasks, so consider retaining it until it completes.
         asyncio.create_task(
             zmqlet.send_message(
                 Message(None, next_request, 'gateway', **vars(self.args))))
         fetch_to.append(asyncio.create_task(zmqlet.recv_message(callback=self.handle)))
     return False
示例#17
0
def test_compression(compress_algo, low_bytes, high_ratio):
    """Compare serialized message sizes with and without compression.

    A baseline pass dumps every request with ``CompressAlgo.NONE`` and records
    the sizes.  A second pass repeats it with ``compress_algo``.  When
    compression is disabled, or ``low_bytes`` / ``high_ratio`` set thresholds
    that presumably prevent it from applying, the compressed total must not be
    smaller than the baseline.  Otherwise it must be strictly smaller.
    """
    no_comp_sizes = []
    sizes = []
    docs = list(random_docs(100, embed_dim=100))

    def _message_kwargs():
        # compress_min_bytes depends on the baseline sizes gathered so far, so it is
        # recomputed for each pass (during the baseline pass it is always 0)
        return dict(
            identity='gateway',
            pod_name='123',
            compress_min_bytes=2 * sum(no_comp_sizes) if low_bytes else 0,
            compress_min_ratio=10 if high_ratio else 1,
        )

    kwargs = _message_kwargs()
    with TimeContext('no compress'):
        for r in request_generator('/', docs):
            m = Message(None, r, compress=CompressAlgo.NONE, **kwargs)
            m.dump()
            no_comp_sizes.append(m.size)

    kwargs = _message_kwargs()
    with TimeContext(f'compressing with {str(compress_algo)}') as tc:
        for r in request_generator('/', docs):
            m = Message(None, r, compress=compress_algo, **kwargs)
            m.dump()
            sizes.append(m.size)

    if compress_algo == CompressAlgo.NONE or low_bytes or high_ratio:
        assert sum(sizes) >= sum(no_comp_sizes)
    else:
        assert sum(sizes) < sum(no_comp_sizes)
    print(
        f'{str(compress_algo)}: size {sum(sizes) / len(sizes)} (ratio: {sum(no_comp_sizes) / sum(sizes):.2f}) with {tc.duration:.2f}s'
    )
def test_recv_message_zmqlet(mocker):
    """A message sent by one Zmqlet is received by the other with its doc tags intact.

    Both Zmqlets are used as context managers so their sockets are closed
    even if sending or receiving raises.
    """
    req = Request()
    req.request_id = random_identity()
    doc = req.data.docs.add()
    doc.tags['id'] = 2
    msg = Message(None, req, 'tmp', '')

    def callback(msg_):
        assert msg_.request.docs[0].tags['id'] == msg.request.data.docs[0].tags['id']

    mock = mocker.Mock()
    with Zmqlet(args1, default_logger) as zmqlet1, Zmqlet(args2, default_logger) as zmqlet2:
        zmqlet1.send_message(msg)
        # give the message time to arrive before the blocking receive
        time.sleep(1)
        zmqlet2.recv_message(mock)
    validate_callback(mock, callback)
示例#19
0
def test_simple_dynamic_routing_zmqlet():
    """Route one message from executor1 to executor2 using a dynamic routing table.

    executor1 has a single out edge to executor2, so one send from ``z1``
    should produce exactly one receive at ``z2``.
    """
    args1 = get_args()
    args2 = get_args()

    logger = JinaLogger('zmq-test')
    with Zmqlet(args1, logger) as z1, Zmqlet(args2, logger) as z2:
        assert z1.msg_sent == 0
        assert z1.msg_recv == 0
        assert z2.msg_sent == 0
        assert z2.msg_recv == 0
        req = jina_pb2.RequestProto()
        req.request_id = random_identity()
        d = req.data.docs.add()
        d.tags['id'] = 2
        msg = Message(None, req, 'tmp', '')
        routing_pb = jina_pb2.RoutingTableProto()
        routing_table = {
            'active_pod': 'executor1',
            'pods': {
                'executor1': {
                    'host': __default_host__,
                    'port': args1.port_in,
                    'expected_parts': 0,
                    'out_edges': [{'pod': 'executor2'}],
                },
                'executor2': {
                    'host': __default_host__,
                    'port': args2.port_in,
                    'expected_parts': 1,
                    'out_edges': [],
                },
            },
        }
        json_format.ParseDict(routing_table, routing_pb)
        msg.envelope.routing_table.CopyFrom(routing_pb)
        # receiving before anything was sent must not count as a received message
        z2.recv_message(callback)

        assert z2.msg_sent == 0
        assert z2.msg_recv == 0

        z1.send_message(msg)
        z2.recv_message(callback)
        assert z1.msg_sent == 1
        assert z1.msg_recv == 0
        assert z2.msg_sent == 0
        assert z2.msg_recv == 1
示例#20
0
def test_data_request_handler_new_docs(logger):
    """``NewDocsExecutor`` replaces the ten input docs with a single 'new document'."""
    handler = DataRequestHandler(
        set_pea_parser().parse_args(['--uses', 'NewDocsExecutor']), logger
    )
    input_docs = DocumentArray([Document(text='input document') for _ in range(10)])
    generated = list(request_generator('/', input_docs))
    msg = Message(None, generated[0], 'test', '123')
    assert len(msg.request.docs) == 10

    handler.handle(msg=msg, partial_requests=None, peapod_name='name')

    assert len(msg.request.docs) == 1
    assert msg.request.docs[0].text == 'new document'
def test_not_decompressed_zmqlet(mocker):
    """Round-trip a message through a MockPea, then validate the recorded callback."""
    def callback(msg_):
        pass

    with MockPea(args2) as pea, Zmqlet(args1, default_logger) as zmqlet:
        request = Request()
        request.request_id = random_identity()
        request.data.docs.add().tags['id'] = 2
        outgoing = Message(None, request, 'tmp', '')
        mock = mocker.Mock()
        zmqlet.send_message(outgoing)
        # let the pea process the message before the blocking receive
        time.sleep(1)
        zmqlet.recv_message(mock)

    validate_callback(mock, callback)
    print(' joining pea')
    pea.join()
    print(' joined pea')
示例#22
0
def test_simple_zmqlet():
    """Send one data request from a connecting Zmqlet to a binding BasePea.

    The pea's in/out ports are the Zmqlet's out/in ports, so each side's
    output is wired to the other's input.
    """
    hosts = ['--host-in', '0.0.0.0', '--host-out', '0.0.0.0']
    no_ctrl_timeout = ['--timeout-ctrl', '-1']

    args = set_pea_parser().parse_args(
        hosts
        + ['--socket-in', 'PULL_CONNECT', '--socket-out', 'PUSH_CONNECT']
        + no_ctrl_timeout
    )
    args2 = set_pea_parser().parse_args(
        hosts
        + ['--port-in', str(args.port_out), '--port-out', str(args.port_in)]
        + ['--socket-in', 'PULL_BIND', '--socket-out', 'PUSH_BIND']
        + no_ctrl_timeout
    )

    logger = JinaLogger('zmq-test')
    with BasePea(args2), Zmqlet(args, logger) as zmqlet:
        request = jina_pb2.RequestProto()
        request.request_id = random_identity()
        request.data.docs.add().tags['id'] = 2
        zmqlet.send_message(Message(None, request, 'tmp', ''))
示例#23
0
def test_lazy_msg_access_with_envelope():
    """Check lazy decompression when a Message is given an explicit envelope.

    With an envelope declaring no compression and a DataRequest type,
    accessing ``envelope`` and calling ``dump()`` leave the request as a raw
    buffer (``_pb_body`` unset).  Accessing ``m.proto`` then parses it, after
    which ``_pb_body`` is set and ``_buffer`` is cleared.
    """
    envelope_proto = jina_pb2.EnvelopeProto()
    envelope_proto.compression.algorithm = 'NONE'
    envelope_proto.request_type = 'DataRequest'
    messages = [
        Message(
            envelope_proto,
            r.SerializeToString(),
        ) for r in request_generator('/', random_docs(10))
    ]
    for m in messages:
        # still raw after envelope access and dump()
        assert not m.request.is_decompressed
        assert m.envelope
        assert len(m.dump()) == 3
        assert not m.request.is_decompressed
        assert m.request._pb_body is None
        assert m.request._buffer is not None
        # building the full proto forces the request body to be parsed
        assert m.proto
        assert m.request.is_decompressed
        assert m.request._pb_body is not None
        assert m.request._buffer is None
示例#24
0
文件: zmq.py 项目: phuocddat/jina
def _parse_from_frames(sock_type, frames: List[bytes]) -> 'Message':
    """
    Build :class:`Message` from a list of frames.

    After adjusting for the socket type, the list of frames (length >= 3)
    has the following structure:

        - offset 0: the client id, can be empty
        - offset 1: the serialized envelope, passed to :class:`Message` as its
          first (envelope) argument
        - offset 2: the body of the serialized protobuf request

    The caller's ``frames`` list is never modified.

    :param sock_type: the recv socket type
    :param frames: list of bytes to parse from
    :return: a :class:`Message` object
    """
    if sock_type == zmq.DEALER:
        # dealer consumes the first part of the message as id, we need to prepend it back
        frames = [b' '] + frames
    elif sock_type == zmq.ROUTER:
        # the router prepends the peer's identity frame on receive; drop it
        # without mutating the caller's list (the old pop(0) did)
        frames = frames[1:]

    return Message(frames[1], frames[2])
示例#25
0
def test_data_request_handler_change_docs_from_partial_requests(logger):
    """Handle a message alongside several partial requests with MergeChangeDocsExecutor.

    Afterwards the message is expected to hold the docs of all partial
    requests, each rewritten to 'changed document'.
    """
    NUM_PARTIAL_REQUESTS = 5
    handler = DataRequestHandler(
        set_pea_parser().parse_args(['--uses', 'MergeChangeDocsExecutor']), logger
    )

    docs = DocumentArray([Document(text='input document') for _ in range(10)])
    single_request = list(request_generator('/', docs))[0]
    # the very same request object stands in for every partial request
    partial_reqs = [single_request for _ in range(NUM_PARTIAL_REQUESTS)]
    msg = Message(None, partial_reqs[-1], 'test', '123')
    assert len(msg.request.docs) == 10

    handler.handle(msg=msg, partial_requests=partial_reqs, peapod_name='name')

    assert len(msg.request.docs) == 10 * NUM_PARTIAL_REQUESTS
    for merged_doc in msg.request.docs:
        assert merged_doc.text == 'changed document'
示例#26
0
 async def CallUnary(self, request, context):
     """Send ``request`` through a fresh AsyncZmqlet and return the ``recv_message`` result."""
     with AsyncZmqlet(self.args, logger=self.logger) as zmqlet:
         outgoing = Message(None, request, 'gateway', **vars(self.args))
         await zmqlet.send_message(outgoing)
         reply = await zmqlet.recv_message(callback=self.handle)
     return reply
示例#27
0
def _create_test_data_message(counter=0):
    """Wrap a single document whose text is ``str(counter)`` into a Message."""
    single_doc = DocumentArray([Document(text=str(counter))])
    generated = list(request_generator('/', single_doc))
    return Message(None, generated[0], 'test', '123')
示例#28
0
def ctrl_messages():
    """Return one Message per request generated from ten random docs."""
    generated = request_generator('/', random_docs(10))
    return [Message(None, req, 'test', '123') for req in generated]