示例#1
0
    def test_rerank_train(self):
        """Send a search response through the router, then a labelled train request.

        Both replies are only printed; the test checks that the two round
        trips complete without raising.
        """
        with RouterService(self.args), ZmqClient(self.c_args) as c1:
            msg = gnes_pb2.Message()
            # NOTE(review): ``request`` and ``response`` appear to share the
            # ``body`` oneof of Message, so the first ``response`` write below
            # clears this query before the message is sent. Confirm how the
            # query is meant to reach the router.
            msg.request.search.query.raw_text = 'This is a query'

            for i, line in enumerate(self.test_str[:5]):
                s = msg.response.search.topk_results.add()
                s.score.value = 0.1
                s.doc.doc_id = i
                s.doc.raw_text = line

            msg.envelope.num_part.extend([1])
            msg.response.search.top_k = 5
            c1.send_message(msg)

            r = c1.recv_message()
            print(r)

            # Train request: each doc carries a JSON (query, candidate, label)
            # record in ``raw_bytes``; even indices get label 1.0, odd 0.0.
            msg = gnes_pb2.Message()
            for i, line in enumerate(self.test_str):
                doc = msg.request.train.docs.add()
                doc.doc_id = i
                doc.raw_bytes = json.dumps({
                    'Query': 'test query',
                    'Candidate': line,
                    'Label': 1.0 if i % 2 == 0 else 0.0
                }).encode('utf-8')

            msg.envelope.num_part.extend([1])
            c1.send_message(msg)
            r = c1.recv_message()
            print(r)
示例#2
0
    def test_preprocessor_service_realdata(self):
        """Push a real Chinese text corpus through train, index and search requests."""
        args = set_preprocessor_service_parser().parse_args([])
        c_args = _set_client_parser().parse_args(
            ['--port_in',
             str(args.port_out), '--port_out',
             str(args.port_in)])

        # Read the corpus up front so the file is closed before the service
        # starts. Blank lines are skipped; each kept line becomes one doc.
        msg = gnes_pb2.Message()
        lines = []
        with open(os.path.join(self.dirname, '26-doc-chinese.txt'),
                  'r',
                  encoding='utf8') as fp:
            for v in fp:
                if v.strip():
                    msg.request.train.docs.add().raw_text = v
                    lines.append(v)
        # Join once instead of growing a string with ``+=`` in the loop.
        all_text = ''.join(lines)

        with PreprocessorService(args), ZmqClient(c_args) as client:
            # 1) train request, one doc per non-empty line
            client.send_message(msg)
            r = client.recv_message()
            print(r)

            # 2) the same docs sent as an index request
            msg1 = gnes_pb2.Message()
            msg1.request.index.docs.extend(msg.request.train.docs)
            client.send_message(msg1)
            r = client.recv_message()
            print(r)

            # 3) the whole corpus as a single search query
            msg2 = gnes_pb2.Message()
            msg2.request.search.query.raw_text = all_text
            client.send_message(msg2)
            r = client.recv_message()
            print(r)
示例#3
0
 def test_block_router(self):
     """Send a train request and then an index request, and read back one reply."""
     with RouterService(self.args), ZmqClient(self.c_args) as client:
         train_msg = gnes_pb2.Message()
         train_msg.request.train.docs.add()
         client.send_message(train_msg)

         index_msg = gnes_pb2.Message()
         index_msg.request.index.docs.add()
         client.send_message(index_msg)

         # Only one reply is awaited for the two requests sent.
         # NOTE(review): the reply is not asserted on.
         reply = client.recv_message()
示例#4
0
    def test_keyword(self):
        """Index one doc per test vector, then query each vector and expect its own doc first.

        The service replies as a response (``as_response``) over a PUB/SUB pair.
        """
        # NOTE(review): args come from ``set_router_parser`` but configure an
        # ``IndexerService``; confirm the indexer parser is not intended.
        args = set_router_parser().parse_args([
            '--yaml_path',
            self.yaml,
            '--socket_out',
            str(SocketType.PUB_BIND),
            '--py_path',
            self.python_code,
        ])
        args.as_response = True
        c_args = _set_client_parser().parse_args([
            '--port_in',
            str(args.port_out), '--port_out',
            str(args.port_in), '--socket_in',
            str(SocketType.SUB_CONNECT)
        ])
        with IndexerService(args), ZmqClient(c_args) as c1:
            msg = gnes_pb2.Message()
            for i, vec in enumerate(self.test_vec):
                doc = msg.request.index.docs.add()
                doc.doc_id = i
                doc.raw_text = self.test_str[i]
                c = doc.chunks.add()
                c.doc_id = i
                c.offset = 0
                c.embedding.data = vec.tobytes()
                c.embedding.shape.extend(vec.shape)
                c.embedding.dtype = str(vec.dtype)
                c.text = self.test_str[i]
            c1.send_message(msg)

            r = c1.recv_message()
            # ``assert_`` was removed in Python 3.12. A protobuf sub-message
            # object is truthy whether or not it was set, so check presence.
            self.assertTrue(r.response.HasField('index'))

            for i, vec in enumerate(self.test_vec):
                msg = gnes_pb2.Message()
                msg.request.search.query.doc_id = 1
                msg.request.search.top_k = 1
                c = msg.request.search.query.chunks.add()
                c.doc_id = 1
                c.embedding.data = vec.tobytes()
                c.embedding.shape.extend(vec.shape)
                c.embedding.dtype = str(vec.dtype)
                c.offset = 0
                c.weight = 1
                c.text = self.test_str[i]
                c1.send_message(msg)
                r = c1.recv_message()
                self.assertEqual(
                    r.response.search.topk_results[0].chunk.doc_id, i)
示例#5
0
    def test_fasterrcnn_preprocessor(self):
        """Every image in the test archive yields at least one (224, 224, 3) chunk."""
        args = set_preprocessor_parser().parse_args(
            ['--yaml_path', self.fasterrcnn_yaml])
        c_args = _set_client_parser().parse_args(
            ['--port_in',
             str(args.port_out), '--port_out',
             str(args.port_in)])
        # Read every member, then close the archive.
        with zipfile.ZipFile(self.data_path) as all_zips:
            all_bytes = [all_zips.read(v) for v in all_zips.namelist()]

        with ServiceManager(PreprocessorService,
                            args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(all_bytes):
                msg = gnes_pb2.Message()
                msg.request.index.CopyFrom(req.index)
                client.send_message(msg)
                r = client.recv_message()
                for d in r.request.index.docs:
                    self.assertGreater(len(d.chunks), 0)
                    for chunk in d.chunks:
                        # Decode once per chunk rather than once per check.
                        img = blob2array(chunk.blob)
                        self.assertEqual(img.shape, (224, 224, 3))
                        print(img.dtype)
示例#6
0
    def test_chunk_sum_reduce_router(self):
        """Top-k lists from two shards are summed per doc by the chunk-sum reducer.

        Expected per-doc totals: doc 2 -> 0.4 + 0.5, doc 1 -> 0.6 + 0.1,
        doc 3 -> 0.3 + 0.3, returned in descending score order.
        """
        args = set_router_parser().parse_args([
            '--yaml_path', self.chunk_sum_yaml,
            '--socket_out', str(SocketType.PUB_BIND)
        ])
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in),
            '--socket_in', str(SocketType.SUB_CONNECT)
        ])

        # (score, explained chunk ids, doc id) for each hit, one list per shard
        shard_hits = [
            [(0.6, ['1-c1', '1-c3', '2-c1'], 1),
             (0.4, ['1-c2', '2-c2'], 2),
             (0.3, ['2-c3'], 3)],
            [(0.5, ['2-c1', '1-c2', '1-c1'], 2),
             (0.3, ['1-c3', '2-c2'], 3),
             (0.1, ['2-c3'], 1)],
        ]

        with RouterService(args), ZmqClient(c_args) as client:
            msg = gnes_pb2.Message()
            msg.envelope.num_part.extend([1, 2])
            for hits in shard_hits:
                msg.response.search.ClearField('topk_results')
                for value, chunk_ids, doc_id in hits:
                    hit = msg.response.search.topk_results.add()
                    hit.score.value = value
                    hit.score.explained = json.dumps(chunk_ids)
                    hit.doc.doc_id = doc_id
                client.send_message(msg)

            reply = client.recv_message()
            results = reply.response.search.topk_results
            self.assertSequenceEqual(reply.envelope.num_part, [1])
            self.assertEqual(len(results), 3)
            self.assertGreaterEqual(results[0].score.value,
                                    results[-1].score.value)
            print(results)

            expected_explained = [
                '1-c2\n2-c2\n\n2-c1\n1-c2\n1-c1\n\n',
                '1-c1\n1-c3\n2-c1\n\n2-c3\n\n',
                '2-c3\n\n1-c3\n2-c2\n\n',
            ]
            for result, explained in zip(results, expected_explained):
                self.assertEqual(result.score.explained, explained)
            for result, total in zip(results, (0.9, 0.7, 0.6)):
                self.assertAlmostEqual(result.score.value, total)
示例#7
0
    def test_concat_router(self):
        """Three copies of a message are concatenated along the last axis.

        First the query's chunk embeddings (3 x [5, 2] -> [5, 6]), then each
        indexed doc's embeddings (3 x [5, 2j] -> [5, 6j]).
        """
        args = set_router_parser().parse_args([
            '--yaml_path', self.concat_router_yaml,
            '--socket_out', str(SocketType.PUSH_BIND)
        ])
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in),
            '--socket_in', str(SocketType.PULL_CONNECT)
        ])
        with RouterService(args), ZmqClient(c_args) as client:
            msg = gnes_pb2.Message()
            query_blob = array2blob(np.random.random([5, 2]))
            msg.request.search.query.chunk_embeddings.CopyFrom(query_blob)
            msg.envelope.num_part.extend([1, 3])
            for _ in range(3):
                client.send_message(msg)
            reply = client.recv_message()
            self.assertSequenceEqual(reply.envelope.num_part, [1])
            print(reply.envelope.routes)
            self.assertEqual(reply.request.search.query.chunk_embeddings.shape, [5, 6])

            # The same message (and envelope) is reused; adding index docs
            # presumably switches the request body from search to index.
            for width in (1, 2, 3):
                doc = msg.request.index.docs.add()
                doc.chunk_embeddings.CopyFrom(array2blob(np.random.random([5, 2 * width])))

            for _ in range(3):
                client.send_message(msg)
            reply = client.recv_message()
            self.assertSequenceEqual(reply.envelope.num_part, [1])
            for width in (1, 2, 3):
                self.assertEqual(reply.request.index.docs[width - 1].chunk_embeddings.shape,
                                 [5, 6 * width])
示例#8
0
    def test_singleton_preprocessor_service_realdata(self):
        """Each test image becomes exactly one (224, 224, 3) chunk via the singleton preprocessor."""
        args = set_preprocessor_service_parser().parse_args(
            ['--yaml_path', self.singleton_img_pre_yaml])
        c_args = _set_client_parser().parse_args(
            ['--port_in',
             str(args.port_out), '--port_out',
             str(args.port_in)])
        # Read every member, then close the archive.
        with zipfile.ZipFile(os.path.join(self.dirname, 'imgs/test.zip')) as all_zips:
            all_bytes = [all_zips.read(v) for v in all_zips.namelist()]

        with PreprocessorService(args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(all_bytes):
                msg = gnes_pb2.Message()
                msg.request.index.CopyFrom(req.index)
                client.send_message(msg)
                r = client.recv_message()
                self.assertEqual(
                    r.envelope.routes[0].service,
                    'PreprocessorService:BaseSingletonPreprocessor')
                for d in r.request.index.docs:
                    self.assertEqual(len(d.chunks), 1)
                    # Decode once instead of once per check.
                    img = blob2array(d.chunks[0].blob)
                    self.assertEqual(img.shape, (224, 224, 3))
                    print(img.dtype)
示例#9
0
 def test_send_recv_response(self):
     """A train response sent with ``squeeze_pb=True`` arrives with the same status."""
     with ZmqClient(self.c1_args) as sender, ZmqClient(self.c2_args) as receiver:
         outgoing = gnes_pb2.Message()
         outgoing.envelope.client_id = sender.args.identity
         outgoing.response.train.status = 2
         sender.send_message(outgoing, squeeze_pb=True)
         incoming = receiver.recv_message()
         self.assertEqual(outgoing.response.train.status,
                          incoming.response.train.status)
示例#10
0
    def test_doc_reduce_router(self):
        """Doc-level top-k lists from two shards are merged into one list of 3 docs.

        Shard 1 carries the text of d1; shard 2 carries the text of d2 and d3.
        """
        args = set_router_service_parser().parse_args([
            '--yaml_path', self.doc_router_yaml, '--socket_out',
            str(SocketType.PUB_BIND)
        ])
        c_args = _set_client_parser().parse_args([
            '--port_in',
            str(args.port_out), '--port_out',
            str(args.port_in), '--socket_in',
            str(SocketType.SUB_CONNECT)
        ])
        with RouterService(args), ZmqClient(c_args) as c1:
            msg = gnes_pb2.Message()
            # ``num_part`` is a repeated field and ``score`` a sub-message, as
            # in the other router tests. Assigning to either raises
            # AttributeError, so extend ``num_part`` and set ``score.value``.
            msg.envelope.num_part.extend([1, 2])

            # shard1 only has the text of d1
            s = msg.response.search.topk_results.add()
            s.score.value = 0.1
            s.doc.doc_id = 1
            s.doc.raw_text = 'd1'

            s = msg.response.search.topk_results.add()
            s.score.value = 0.2
            s.doc.doc_id = 2

            s = msg.response.search.topk_results.add()
            s.score.value = 0.3
            # Keyed on ``doc`` like every other hit sent to this doc reducer.
            s.doc.doc_id = 3

            c1.send_message(msg)

            msg.response.search.ClearField('topk_results')

            # shard2 has the text of d2 and d3
            s = msg.response.search.topk_results.add()
            s.score.value = 0.1
            s.doc.doc_id = 1

            s = msg.response.search.topk_results.add()
            s.score.value = 0.2
            s.doc.doc_id = 2
            s.doc.raw_text = 'd2'

            s = msg.response.search.topk_results.add()
            s.score.value = 0.3
            s.doc.doc_id = 3
            s.doc.raw_text = 'd3'

            msg.response.search.top_k = 5
            c1.send_message(msg)
            r = c1.recv_message()

            print(r.response.search.topk_results)
            self.assertSequenceEqual(r.envelope.num_part, [1])
            self.assertEqual(len(r.response.search.topk_results), 3)
            self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
                                    r.response.search.topk_results[-1].score.value)
示例#11
0
    def test_rerank(self):
        """The rerank router returns 5 hits for the full list and all 3 for a 3-hit list.

        Both requests use ``top_k = 5``.
        """
        args = set_router_parser().parse_args([
            '--yaml_path', self.rerank_router_yaml,
            '--socket_out', str(SocketType.PUB_BIND),
            '--py_path', self.python_code
        ])
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in),
            '--socket_in', str(SocketType.SUB_CONNECT)
        ])

        def build_msg(lines):
            # One response part: a 0.1-scored hit per line, with top_k = 5.
            msg = gnes_pb2.Message()
            for i, line in enumerate(lines):
                s = msg.response.search.topk_results.add()
                s.score.value = 0.1
                s.doc.doc_id = i
                s.doc.raw_text = line
            msg.envelope.num_part.extend([1])
            msg.response.search.top_k = 5
            return msg

        # The second ZmqClient opened here originally was never used; one is enough.
        with RouterService(args), ZmqClient(c_args) as c1:
            for lines, expected_len in ((self.test_str, 5), (self.test_str[:3], 3)):
                c1.send_message(build_msg(lines))
                r = c1.recv_message()
                self.assertSequenceEqual(r.envelope.num_part, [1])
                self.assertEqual(len(r.response.search.topk_results), expected_len)
示例#12
0
 def test_new_msg(self):
     """Print a Message as it goes from an index response to a train request."""
     msg = gnes_pb2.Message()
     msg.response.index.status = gnes_pb2.Response.SUCCESS
     print(msg)

     def fresh_docs(count):
         return [gnes_pb2.Document() for _ in range(count)]

     msg.request.train.docs.extend(fresh_docs(2))
     print(msg)

     # Replace the two docs with three new ones.
     msg.request.train.ClearField('docs')
     msg.request.train.docs.extend(fresh_docs(3))
     print(msg)
示例#13
0
 def test_send_recv(self):
     """The raw bytes of an index doc survive a client-to-client round trip."""
     with ZmqClient(self.c1_args) as sender, ZmqClient(self.c2_args) as receiver:
         outgoing = gnes_pb2.Message()
         outgoing.envelope.client_id = sender.args.identity
         doc = outgoing.request.index.docs.add()
         doc.raw_bytes = b'aa'
         sender.send_message(outgoing)
         received = receiver.recv_message()
         self.assertEqual(received.request.index.docs[0].raw_bytes, doc.raw_bytes)
示例#14
0
 def build_msgs(self, num_msg=20, seed=None):
     """Build index messages whose docs carry large raw-bytes payloads.

     Each message has client id ``'abc'`` and 10-20 docs. Each doc holds
     1,000,000-10,000,000 bytes of ``b'a'``.

     :param num_msg: number of messages to build (default 20)
     :param seed: if not ``None``, seed :mod:`random` first so the sizes are
         reproducible; the default leaves the global RNG untouched
     :return: list of ``gnes_pb2.Message``
     """
     if seed is not None:
         random.seed(seed)
     all_msgs = []
     for _ in range(num_msg):
         msg = gnes_pb2.Message()
         msg.envelope.client_id = 'abc'
         for _ in range(random.randint(10, 20)):
             d = msg.request.index.docs.add()
             # each doc is about 1MB to 10MB
             d.raw_bytes = b'a' * random.randint(1000000, 10000000)
         all_msgs.append(msg)
     return all_msgs
示例#15
0
    def apply(self, msg: 'gnes_pb2.Message', *args, **kwargs):
        """
        Log the incoming message

        Prints a fixed banner, the whole message, and the name of the body
        kind nested inside the message's outer body (e.g. ``'index'``).

        :param msg: incoming message
        :param args: not used in the lines shown
        :param kwargs: not used in the lines shown
        """

        # Outer ``WhichOneof('body')`` names the message body (presumably
        # 'request' or 'response'); the inner one names its kind. If no body
        # is set, the outer call returns None and ``getattr`` raises TypeError.
        runtime = getattr(msg, msg.WhichOneof('body')).WhichOneof('body')
        # NOTE: the banner text contains the misspelling 'recieved'.
        print('recieved msg')
        print(msg)
        print(runtime)
        if runtime == 'index':
            # NOTE(review): ``req`` is never used in the visible code; this
            # snippet looks truncated -- check against the full source.
            req = gnes_pb2.Message()
示例#16
0
    def test_rerank(self):
        """The router returns 5 hits for the full list and the single hit for a 1-hit list.

        Both requests use ``top_k = 5``.
        """

        def build_msg(lines, query=None):
            # One response part: a 0.1-scored hit per line, with top_k = 5.
            msg = gnes_pb2.Message()
            if query is not None:
                # NOTE(review): ``request`` and ``response`` appear to share
                # the ``body`` oneof of Message, so the ``response`` writes
                # below clear this query before sending. Confirm the intent.
                msg.request.search.query.raw_text = query
            for i, line in enumerate(lines):
                s = msg.response.search.topk_results.add()
                s.score.value = 0.1
                s.doc.doc_id = i
                s.doc.raw_text = line
            msg.envelope.num_part.extend([1])
            msg.response.search.top_k = 5
            return msg

        with RouterService(self.args), ZmqClient(self.c_args) as c1:
            c1.send_message(build_msg(self.test_str, query='This is a query'))
            r = c1.recv_message()
            self.assertSequenceEqual(r.envelope.num_part, [1])
            self.assertEqual(len(r.response.search.topk_results), 5)

            c1.send_message(build_msg(self.test_str[:1]))
            r = c1.recv_message()
            self.assertSequenceEqual(r.envelope.num_part, [1])
            self.assertEqual(len(r.response.search.topk_results), 1)
示例#17
0
 def test_send_recv_raw_bytes(self):
     """With ``squeeze_pb=True`` the receiver gets every doc's raw bytes back.

     The sender's own docs are left with empty ``raw_bytes`` after sending.
     """
     with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
         msg = gnes_pb2.Message()
         msg.envelope.client_id = c1.args.identity
         for _ in range(random.randint(10, 20)):
             d = msg.request.index.docs.add()
             d.raw_bytes = b'a' * random.randint(100, 1000)
         # bytes are immutable, so a plain list is a sufficient snapshot
         raw_bytes = [d.raw_bytes for d in msg.request.index.docs]
         c1.send_message(msg, squeeze_pb=True)
         r_msg = c2.recv_message()
         # zip() stops at the shortest input; make sure no doc went missing.
         self.assertEqual(len(r_msg.request.index.docs), len(raw_bytes))
         for d, o_d, r_d in zip(msg.request.index.docs, raw_bytes, r_msg.request.index.docs):
             self.assertEqual(d.raw_bytes, b'')
             self.assertEqual(o_d, r_d.raw_bytes)
             print('.', end='')
         print('checked %d docs' % len(msg.request.index.docs))
    def test_video_cut_by_clustering(self):
        """Every doc built from the test videos is cut into exactly 6 chunks."""
        args = set_preprocessor_parser().parse_args(
            ['--yaml_path', self.yml_path_4])
        c_args = _set_client_parser().parse_args(
            ['--port_in', str(args.port_out),
             '--port_out', str(args.port_in)])

        with PreprocessorService(args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(self.video_bytes):
                outgoing = gnes_pb2.Message()
                outgoing.request.index.CopyFrom(req.index)
                client.send_message(outgoing)
                docs = client.recv_message().request.index.docs
                for doc in docs:
                    self.assertEqual(len(doc.chunks), 6)
示例#19
0
 def test_preprocessor_service_echo(self):
     """A default preprocessor service replies to an index request and then a train request."""
     args = set_preprocessor_service_parser().parse_args([])
     c_args = _set_client_parser().parse_args([
         '--port_in', str(args.port_out),
         '--port_out', str(args.port_in)
     ])

     def five_docs():
         return [gnes_pb2.Document() for _ in range(5)]

     with PreprocessorService(args), ZmqClient(c_args) as client:
         msg = gnes_pb2.Message()
         msg.request.index.docs.extend(five_docs())
         client.send_message(msg)
         print(client.recv_message())

         # The same message is reused; adding train docs presumably
         # replaces the index body.
         msg.request.train.docs.extend(five_docs())
         client.send_message(msg)
         print(client.recv_message())
示例#20
0
 def build_msgs2(self, seed=0, num_msg=20):
     """Build reproducible index messages whose docs carry random numeric chunks.

     Seeds both :mod:`random` and ``np.random`` with ``seed``, which resets the
     global generators. Each message has client id ``'abc'`` and 10-20 docs.
     Each doc has 10-20 chunks, and every chunk has an ``embedding`` and a
     ``blob`` made from independent float64 arrays of shape [10, 20, 30].

     :param seed: seed for both RNGs (default 0)
     :param num_msg: number of messages to build (default 20)
     :return: list of ``gnes_pb2.Message``
     """
     all_msgs = []
     random.seed(seed)
     np.random.seed(seed)
     for _ in range(num_msg):
         msg = gnes_pb2.Message()
         msg.envelope.client_id = 'abc'
         for _ in range(random.randint(10, 20)):
             d = msg.request.index.docs.add()
             # each chunk holds 2 x 6000 float64 values (~96 KB of raw array
             # data), so a doc carries roughly 1-2 MB before encoding
             for _ in range(random.randint(10, 20)):
                 c = d.chunks.add()
                 c.embedding.CopyFrom(array2blob(np.random.random([10, 20, 30])))
                 c.blob.CopyFrom(array2blob(np.random.random([10, 20, 30])))
         all_msgs.append(msg)
     return all_msgs
    def test_empty_service(self):
        """An encoder service built from an inline YAML spec handles one image doc."""
        args = set_encoder_parser().parse_args(['--yaml_path', '!TestEncoder {gnes_config: {name: EncoderService, is_trained: true}}'])
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in)])

        outgoing = gnes_pb2.Message()
        doc = outgoing.request.index.docs.add()
        doc.doc_type = gnes_pb2.Document.IMAGE
        doc.chunks.add().blob.CopyFrom(array2blob(self.test_numeric))

        with ServiceManager(EncoderService, args), ZmqClient(c_args) as client:
            client.send_message(outgoing)
            reply = client.recv_message()
            self.assertEqual(len(reply.request.index.docs), 1)
            # NOTE(review): if request/response share a oneof, only one of the
            # two bodies checked here can be populated -- confirm the status
            # check is not trivially satisfied by a default value.
            self.assertEqual(reply.response.index.status, gnes_pb2.Response.SUCCESS)
示例#22
0
 def test_map_router(self):
     """The batching map router splits 5 docs into consecutive batches of 2, 2 and 1."""
     args = set_router_parser().parse_args([
         '--yaml_path', self.batch_router_yaml,
     ])
     c_args = _set_client_parser().parse_args([
         '--port_in', str(args.port_out),
         '--port_out', str(args.port_in),
     ])
     with RouterService(args), ZmqClient(c_args) as client:
         outgoing = gnes_pb2.Message()
         outgoing.request.index.docs.extend([gnes_pb2.Document() for _ in range(5)])
         client.send_message(outgoing)
         for batch_size in (2, 2, 1):
             reply = client.recv_message()
             self.assertEqual(len(reply.request.index.docs), batch_size)
示例#23
0
 def test_publish_router(self):
     """Both subscribers receive the published message with ``num_part == [1, 2]``."""
     args = set_router_parser().parse_args([
         '--yaml_path', self.publish_router_yaml,
         '--socket_out', str(SocketType.PUB_BIND)
     ])
     c_args = _set_client_parser().parse_args([
         '--port_in', str(args.port_out),
         '--port_out', str(args.port_in),
         '--socket_in', str(SocketType.SUB_CONNECT)
     ])
     with RouterService(args), ZmqClient(c_args) as c1, ZmqClient(c_args) as c2:
         outgoing = gnes_pb2.Message()
         outgoing.request.index.docs.extend([gnes_pb2.Document() for _ in range(5)])
         outgoing.envelope.num_part.append(1)
         c1.send_message(outgoing)
         for subscriber in (c1, c2):
             reply = subscriber.recv_message()
             self.assertSequenceEqual(reply.envelope.num_part, [1, 2])
    def test_video_preprocessor_service_realdata(self):
        """Every chunk produced from the test videos decodes to a (168, 192, 3) array."""
        args = set_preprocessor_parser().parse_args(
            ['--yaml_path', self.yml_path])

        c_args = _set_client_parser().parse_args(
            ['--port_in',
             str(args.port_out), '--port_out',
             str(args.port_in)])

        with PreprocessorService(args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(self.video_bytes):
                msg = gnes_pb2.Message()
                msg.request.index.CopyFrom(req.index)
                client.send_message(msg)
                r = client.recv_message()
                for d in r.request.index.docs:
                    self.assertGreater(len(d.chunks), 0)
                    for chunk in d.chunks:
                        shape = blob2array(chunk.blob).shape
                        self.assertEqual(shape, (168, 192, 3))
示例#25
0
    def test_video_decode_preprocessor(self):
        """Every decoded chunk of every test video has trailing shape (299, 299, 3)."""
        args = set_preprocessor_parser().parse_args(['--yaml_path', self.yml_path])
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in)])
        # Read each file inside ``with`` so its handle is closed right away.
        video_bytes = []
        for name in os.listdir(self.video_path):
            with open(os.path.join(self.video_path, name), 'rb') as fp:
                video_bytes.append(fp.read())

        with ServiceManager(PreprocessorService, args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(video_bytes):
                msg = gnes_pb2.Message()
                msg.request.index.CopyFrom(req.index)
                client.send_message(msg)
                r = client.recv_message()
                for d in r.request.index.docs:
                    self.assertGreater(len(d.chunks), 0)
                    for chunk in d.chunks:
                        shape = blob2array(chunk.blob).shape
                        self.assertEqual(shape[1:], (299, 299, 3))
    def test_empty_service(self):
        """A BaseChunkIndexer service accepts one embedded chunk and reports SUCCESS."""
        args = set_indexer_parser().parse_args([
            '--yaml_path',
            '!BaseChunkIndexer {gnes_config: {name: IndexerService}}'
        ])
        c_args = _set_client_parser().parse_args(
            ['--port_in', str(args.port_out),
             '--port_out', str(args.port_in)])

        outgoing = gnes_pb2.Message()
        chunk = outgoing.request.index.docs.add().chunks.add()
        chunk.doc_id = 0
        chunk.embedding.CopyFrom(array2blob(self.test_numeric))
        chunk.offset = 0
        chunk.weight = 1.0

        with ServiceManager(IndexerService, args), ZmqClient(c_args) as client:
            client.send_message(outgoing)
            reply = client.recv_message()
            self.assertEqual(reply.response.index.status,
                             gnes_pb2.Response.SUCCESS)
示例#27
0
    def test_multimap_multireduce(self):
        """Fan a message out over two publish levels and fold it back through two reduce levels.

        Wiring (data flow):
            client -> p1 (PUB) -> p21, p22
            p21 (PUB) -> r311, r312 -> r41 -> r5
            p22 (PUB) -> r321, r322 -> r42 -> r5
            r5 -> client

        The client expects a single message back with ``num_part == [1]``.
        """

        def router_args(yaml_path, socket_in, socket_out, port_in=None, port_out=None):
            # Ports are passed only when given, so the parser defaults apply
            # otherwise; argparse does not depend on flag order.
            argv = ['--yaml_path', yaml_path,
                    '--socket_in', str(socket_in),
                    '--socket_out', str(socket_out)]
            if port_in is not None:
                argv += ['--port_in', str(port_in)]
            if port_out is not None:
                argv += ['--port_out', str(port_out)]
            return set_router_parser().parse_args(argv)

        publish_yaml = self.publish_router_yaml
        reduce_yaml = self.reduce_router_yaml

        # Later routers read ports from earlier ones, so creation order matters.
        p1 = router_args(publish_yaml, SocketType.PULL_CONNECT, SocketType.PUB_BIND)
        r5 = router_args(reduce_yaml, SocketType.PULL_BIND, SocketType.PUSH_CONNECT)
        r41 = router_args(reduce_yaml, SocketType.PULL_BIND, SocketType.PUSH_CONNECT,
                          port_out=r5.port_in)
        r42 = router_args(reduce_yaml, SocketType.PULL_BIND, SocketType.PUSH_CONNECT,
                          port_out=r5.port_in)
        p21 = router_args(publish_yaml, SocketType.SUB_CONNECT, SocketType.PUB_BIND,
                          port_in=p1.port_out)
        p22 = router_args(publish_yaml, SocketType.SUB_CONNECT, SocketType.PUB_BIND,
                          port_in=p1.port_out)
        r311 = router_args('BaseRouter', SocketType.SUB_CONNECT, SocketType.PUSH_CONNECT,
                           port_in=p21.port_out, port_out=r41.port_in)
        r312 = router_args('BaseRouter', SocketType.SUB_CONNECT, SocketType.PUSH_CONNECT,
                           port_in=p21.port_out, port_out=r41.port_in)
        r321 = router_args('BaseRouter', SocketType.SUB_CONNECT, SocketType.PUSH_CONNECT,
                           port_in=p22.port_out, port_out=r42.port_in)
        r322 = router_args('BaseRouter', SocketType.SUB_CONNECT, SocketType.PUSH_CONNECT,
                           port_in=p22.port_out, port_out=r42.port_in)

        c_args = _set_client_parser().parse_args([
            '--port_in', str(r5.port_out),
            '--port_out', str(p1.port_in),
            '--socket_in', str(SocketType.PULL_BIND),
            '--socket_out', str(SocketType.PUSH_BIND),
        ])
        with RouterService(p1), RouterService(r5), \
             RouterService(p21), RouterService(p22), \
             RouterService(r311), RouterService(r312), RouterService(r321), RouterService(r322), \
             RouterService(r41), RouterService(r42), \
             ZmqClient(c_args) as client:
            outgoing = gnes_pb2.Message()
            outgoing.envelope.num_part.append(1)
            client.send_message(outgoing)
            reply = client.recv_message()
            self.assertSequenceEqual(reply.envelope.num_part, [1])
            print(reply.envelope.routes)
示例#28
0
    def test_doc_combine_score_fn(self):
        """Reduce two shards of chunk hits, then rescore them against a DictIndexer.

        Docs 1-3, each with three chunks of weight 1/3, are added to a
        ``DictIndexer`` using ``CoordDocScoreFn``, and the reduced top-k hits
        are passed to ``query_and_score`` with ``top_k=2``.
        """
        from gnes.indexer.doc.dict import DictIndexer

        document_id_list = [1, 2, 3]
        document_list = []
        for j in document_id_list:
            d = gnes_pb2.Document()
            for i in range(1, 4):
                c = d.chunks.add()
                c.doc_id = j
                c.offset = i
                c.weight = 1 / 3
            document_list.append(d)

        # A local name avoids overwriting the test instance's chunk_router_yaml.
        chunk_router_yaml = 'Chunk2DocTopkReducer'

        args = set_router_parser().parse_args([
            '--yaml_path', chunk_router_yaml, '--socket_out',
            str(SocketType.PUB_BIND)
        ])
        c_args = _set_client_parser().parse_args([
            '--port_in',
            str(args.port_out), '--port_out',
            str(args.port_in), '--socket_in',
            str(SocketType.SUB_CONNECT)
        ])

        # (score, explained, chunk doc id) for each hit, one list per shard
        shard_hits = [
            [(0.1, '"1-c1"', 1), (0.2, '"1-c2"', 2), (0.3, '"1-c3"', 1)],
            [(0.2, '"2-c1"', 1), (0.2, '"2-c2"', 2), (0.3, '"2-c3"', 3)],
        ]

        with RouterService(args), ZmqClient(c_args) as c1:
            msg = gnes_pb2.Message()
            msg.envelope.num_part.extend([1, 2])
            for hits in shard_hits:
                msg.response.search.ClearField('topk_results')
                for value, explained, doc_id in hits:
                    s = msg.response.search.topk_results.add()
                    s.score.value = value
                    s.score.explained = explained
                    s.chunk.doc_id = doc_id
                c1.send_message(msg)
            r = c1.recv_message()

            doc_indexer = DictIndexer(score_fn=CoordDocScoreFn())
            doc_indexer.add(keys=document_id_list, docs=document_list)

            # NOTE(review): the rescored result is not asserted on, so this
            # only checks that query_and_score runs without raising.
            doc_indexer.query_and_score(
                docs=r.response.search.topk_results, top_k=2)
示例#29
0
    def test_chunk_reduce_router(self):
        """Merge two chunk-level partial responses through a reduce router.

        Each part carries three scored chunk hits. The reply should hold one
        entry per distinct chunk ``doc_id``. Entries are ordered by descending
        score, carry summed scores, and have their explanations concatenated.

        NOTE(review): unlike the other examples in this file, this test
        assigns scalar ``score`` / ``score_explained`` fields and an int to
        ``envelope.num_part``. It presumably targets an older message schema;
        confirm before running it against the current protos.
        """
        args = set_router_service_parser().parse_args([
            '--yaml_path', self.chunk_router_yaml,
            '--socket_out', str(SocketType.PUB_BIND),
        ])
        # Cross-wire the client to the router: the client's input is the
        # router's output and the client's output is the router's input.
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in),
            '--socket_in', str(SocketType.SUB_CONNECT),
        ])

        # (score, explanation, chunk doc_id) triples for the two parts.
        parts = (
            ((0.1, '1-c1', 1), (0.2, '1-c2', 2), (0.3, '1-c3', 1)),
            ((0.2, '2-c1', 1), (0.2, '2-c2', 2), (0.3, '2-c3', 3)),
        )

        with RouterService(args), ZmqClient(c_args) as c1:
            msg = gnes_pb2.Message()
            for hits in parts:
                # One message object is reused, so clear the prior hits first.
                msg.response.search.ClearField('topk_results')
                for score, explanation, doc_id in hits:
                    hit = msg.response.search.topk_results.add()
                    hit.score = score
                    hit.score_explained = explanation
                    hit.chunk.doc_id = doc_id
                msg.envelope.num_part = 2
                c1.send_message(msg)

            reply = c1.recv_message()
            merged = reply.response.search.topk_results

            self.assertEqual(reply.envelope.num_part, 1)
            self.assertEqual(len(merged), 3)
            self.assertGreaterEqual(merged[0].score, merged[-1].score)
            print(merged)

            expected_explanations = ('1-c1\n1-c3\n2-c1\n',
                                     '1-c2\n2-c2\n',
                                     '2-c3\n')
            for entry, explanation in zip(merged, expected_explanations):
                self.assertEqual(entry.score_explained, explanation)

            for entry, total in zip(merged, (0.6, 0.4, 0.3)):
                self.assertAlmostEqual(entry.score, total)
示例#30
0
    def test_doc_sum_reduce_router(self):
        args = set_router_parser().parse_args([
            '--yaml_path', self.doc_sum_yaml,
            '--socket_out', str(SocketType.PUB_BIND)
        ])
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in),
            '--socket_in', str(SocketType.SUB_CONNECT)
        ])
        with RouterService(args), ZmqClient(c_args) as c1:
            msg = gnes_pb2.Message()

            s = msg.response.search.topk_results.add()
            s.score.value = 0.4
            s.doc.doc_id = 1
            s.doc.raw_text = 'd3'
            s.score.explained = '1-d3\n'

            s = msg.response.search.topk_results.add()
            s.score.value = 0.3
            s.doc.doc_id = 2
            s.doc.raw_text = 'd2'
            s.score.explained = '1-d2\n'

            s = msg.response.search.topk_results.add()
            s.score.value = 0.2
            s.doc.doc_id = 3
            s.doc.raw_text = 'd1'
            s.score.explained = '1-d3\n'

            msg.envelope.num_part.extend([1, 2])
            c1.send_message(msg)

            msg.response.search.ClearField('topk_results')

            s = msg.response.search.topk_results.add()
            s.score.value = 0.5
            s.doc.doc_id = 1
            s.doc.raw_text = 'd2'
            s.score.explained = '2-d2\n'

            s = msg.response.search.topk_results.add()
            s.score.value = 0.2
            s.doc.doc_id = 2
            s.doc.raw_text = 'd1'
            s.score.explained = '2-d1\n'

            s = msg.response.search.topk_results.add()
            s.score.value = 0.1
            s.doc.doc_id = 3
            s.doc.raw_text = 'd3'
            s.score.explained = '2-d3\n'

            msg.response.search.top_k = 5
            c1.send_message(msg)
            r = c1.recv_message()

            print(r.response.search.topk_results)
            self.assertSequenceEqual(r.envelope.num_part, [1])
            self.assertEqual(len(r.response.search.topk_results), 3)
            self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
                                    r.response.search.topk_results[-1].score.value)