def test_send_recv_response(self):
    """Round-trip a response message between two clients with ``squeeze_pb=True``.

    The train status set on the outgoing message must match the status
    found on the message read by the second client.
    """
    with ZmqClient(self.c1_args) as sender:
        with ZmqClient(self.c2_args) as receiver:
            outgoing = gnes_pb2.Message()
            outgoing.envelope.client_id = sender.args.identity
            outgoing.response.train.status = 2

            sender.send_message(outgoing, squeeze_pb=True)
            incoming = receiver.recv_message()

            self.assertEqual(outgoing.response.train.status,
                             incoming.response.train.status)
def test_send_recv(self):
    """Send an index request holding one document and check its raw bytes arrive intact."""
    with ZmqClient(self.c1_args) as sender:
        with ZmqClient(self.c2_args) as receiver:
            outgoing = gnes_pb2.Message()
            outgoing.envelope.client_id = sender.args.identity
            sent_doc = outgoing.request.index.docs.add()
            sent_doc.raw_bytes = b'aa'

            sender.send_message(outgoing)
            incoming = receiver.recv_message()

            received_doc = incoming.request.index.docs[0]
            self.assertEqual(received_doc.raw_bytes, sent_doc.raw_bytes)
def test_benchmark3(self):
    """Send every message from ``build_msgs`` with ``squeeze_pb=True`` and compare payloads.

    After sending, each document of the sent message is expected to have
    empty ``raw_bytes``, while the received document must carry the bytes
    kept in a deep copy taken before sending.
    """
    to_send = self.build_msgs()
    pristine = copy.deepcopy(to_send)

    with ZmqClient(self.c1_args) as sender, ZmqClient(self.c2_args) as receiver:
        for sent, backup in zip(to_send, pristine):
            sender.send_message(sent, squeeze_pb=True)
            received = receiver.recv_message()

            doc_triples = zip(sent.request.index.docs,
                              backup.request.index.docs,
                              received.request.index.docs)
            for sent_doc, backup_doc, received_doc in doc_triples:
                self.assertEqual(sent_doc.raw_bytes, b'')
                self.assertEqual(backup_doc.raw_bytes, received_doc.raw_bytes)
def test_benchmark4(self):
    """Time send->recv round trips, first with ``squeeze_pb=False`` and then with ``True``.

    Each timing run opens its own pair of clients; nothing is asserted.
    """
    messages = self.build_msgs2()
    runs = (
        ('send->recv, squeeze_pb=False', False),
        ('send->recv, squeeze_pb=True', True),
    )

    for label, squeeze in runs:
        with ZmqClient(self.c1_args) as sender, ZmqClient(self.c2_args) as receiver:
            with TimeContext(label):
                for message in messages:
                    sender.send_message(message, squeeze_pb=squeeze)
                    receiver.recv_message()
def test_benchmark5(self):
    """Round-trip messages with ``squeeze_pb=True`` and verify chunk arrays survive.

    A deep copy of the messages is taken before sending. For every chunk,
    the ``embedding`` and ``blob`` arrays decoded from the received message
    must be close (``np.allclose``) to those decoded from the copy.
    """
    all_msgs = self.build_msgs2()
    all_msgs_bak = copy.deepcopy(all_msgs)

    with ZmqClient(self.c1_args) as c1, ZmqClient(self.c2_args) as c2:
        with TimeContext('send->recv, squeeze_pb=True'):
            for m, m1 in zip(all_msgs, all_msgs_bak):
                c1.send_message(m, squeeze_pb=True)
                r_m = c2.recv_message()

                for d, r_d in zip(m1.request.index.docs, r_m.request.index.docs):
                    for c, r_c in zip(d.chunks, r_d.chunks):
                        # The np.allclose results used to be thrown away, so a
                        # corrupted payload could never fail the test.
                        self.assertTrue(np.allclose(blob2array(c.embedding),
                                                    blob2array(r_c.embedding)))
                        self.assertTrue(np.allclose(blob2array(c.blob),
                                                    blob2array(r_c.blob)))
def test_send_recv_raw_bytes(self):
    """Send a random number of random-length byte documents with ``squeeze_pb=True``.

    After sending, the sent documents are expected to have empty
    ``raw_bytes`` while the received ones must equal the payloads
    captured before sending.
    """
    with ZmqClient(self.c1_args) as sender, ZmqClient(self.c2_args) as receiver:
        outgoing = gnes_pb2.Message()
        outgoing.envelope.client_id = sender.args.identity

        doc_count = random.randint(10, 20)
        for _ in range(doc_count):
            new_doc = outgoing.request.index.docs.add()
            new_doc.raw_bytes = b'a' * random.randint(100, 1000)

        sent_docs = outgoing.request.index.docs
        # Snapshot the payloads before the send call.
        payloads = copy.deepcopy([doc.raw_bytes for doc in sent_docs])

        sender.send_message(outgoing, squeeze_pb=True)
        incoming = receiver.recv_message()

        for sent_doc, payload, received_doc in zip(sent_docs, payloads,
                                                   incoming.request.index.docs):
            self.assertEqual(sent_doc.raw_bytes, b'')
            self.assertEqual(payload, received_doc.raw_bytes)
            print('.', end='')
        print('checked %d docs' % len(outgoing.request.index.docs))
def test_fasterrcnn_preprocessor(self):
    """Run zipped images through the preprocessor service and check every chunk blob.

    Each returned document must have at least one chunk, and every chunk
    blob must decode to a 3-dimensional array whose first two dimensions
    are 224 and whose last dimension is 3.
    """
    args = set_preprocessor_parser().parse_args(
        ['--yaml_path', self.fasterrcnn_yaml])
    c_args = _set_client_parser().parse_args(
        ['--port_in', str(args.port_out), '--port_out', str(args.port_in)])

    # Read every member inside a context manager so the archive (and the
    # per-member handles) are closed instead of leaking until GC.
    with zipfile.ZipFile(self.data_path) as all_zips:
        all_bytes = [all_zips.read(name) for name in all_zips.namelist()]

    with ServiceManager(PreprocessorService, args), ZmqClient(c_args) as client:
        for req in RequestGenerator.index(all_bytes):
            msg = gnes_pb2.Message()
            msg.request.index.CopyFrom(req.index)
            client.send_message(msg)
            r = client.recv_message()
            for d in r.request.index.docs:
                self.assertGreater(len(d.chunks), 0)
                for chunk in d.chunks:
                    # Decode once per chunk rather than once per assertion.
                    shape = blob2array(chunk.blob).shape
                    self.assertEqual(len(shape), 3)
                    self.assertEqual(shape[-1], 3)
                    self.assertEqual(shape[0], 224)
                    self.assertEqual(shape[1], 224)
                print(blob2array(d.chunks[0].blob).dtype)
def test_preprocessor_service_realdata(self):
    """Push train, index and search requests built from a text file through the preprocessor.

    Every non-blank line of ``26-doc-chinese.txt`` becomes a training
    document; the same documents are then sent as an index request, and
    their concatenated text as a search query. Replies are only printed.
    """
    args = set_preprocessor_parser().parse_args(['--yaml_path', self.yaml_path])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
    ])

    corpus_path = os.path.join(self.dirname, '26-doc-chinese.txt')
    with open(corpus_path, 'r', encoding='utf8') as fp:
        train_msg = gnes_pb2.Message()
        kept_lines = []
        for line in fp:
            if not line.strip():
                continue
            doc = train_msg.request.train.docs.add()
            doc.raw_text = line
            kept_lines.append(line)
        all_text = ''.join(kept_lines)

        with PreprocessorService(args), ZmqClient(c_args) as client:
            client.send_message(train_msg)
            reply = client.recv_message()
            print(reply)

            index_msg = gnes_pb2.Message()
            index_msg.request.index.docs.extend(train_msg.request.train.docs)
            client.send_message(index_msg)
            reply = client.recv_message()
            print(reply)

            search_msg = gnes_pb2.Message()
            search_msg.request.search.query.raw_text = all_text
            client.send_message(search_msg)
            reply = client.recv_message()
            print(reply)
def test_chunk_sum_reduce_router(self):
    """Send two partial top-k responses to the router and check the merged reply.

    The reply must report a single part, contain three results ordered so
    the first score is not below the last, and carry the expected
    ``explained`` strings and summed score values.
    """
    args = set_router_parser().parse_args([
        '--yaml_path', self.chunk_sum_yaml,
        '--socket_out', str(SocketType.PUB_BIND),
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT),
    ])

    # Rows are (score value, explained chunk list, doc id).
    first_part = (
        (0.6, ['1-c1', '1-c3', '2-c1'], 1),
        (0.4, ['1-c2', '2-c2'], 2),
        (0.3, ['2-c3'], 3),
    )
    second_part = (
        (0.5, ['2-c1', '1-c2', '1-c1'], 2),
        (0.3, ['1-c3', '2-c2'], 3),
        (0.1, ['2-c3'], 1),
    )
    expected_explained = (
        '1-c2\n2-c2\n\n2-c1\n1-c2\n1-c1\n\n',
        '1-c1\n1-c3\n2-c1\n\n2-c3\n\n',
        '2-c3\n\n1-c3\n2-c2\n\n',
    )
    expected_values = (0.9, 0.7, 0.6)

    def add_results(message, rows):
        for value, chunk_names, doc_id in rows:
            entry = message.response.search.topk_results.add()
            entry.score.value = value
            entry.score.explained = json.dumps(chunk_names)
            entry.doc.doc_id = doc_id

    with RouterService(args), ZmqClient(c_args) as c1:
        msg = gnes_pb2.Message()
        add_results(msg, first_part)
        msg.envelope.num_part.extend([1, 2])
        c1.send_message(msg)

        # Reuse the same message for the second part.
        msg.response.search.ClearField('topk_results')
        add_results(msg, second_part)
        c1.send_message(msg)

        r = c1.recv_message()
        self.assertSequenceEqual(r.envelope.num_part, [1])
        self.assertEqual(len(r.response.search.topk_results), 3)
        self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
                                r.response.search.topk_results[-1].score.value)
        print(r.response.search.topk_results)

        for rank, explained in enumerate(expected_explained):
            self.assertEqual(r.response.search.topk_results[rank].score.explained, explained)
        for rank, value in enumerate(expected_values):
            self.assertAlmostEqual(r.response.search.topk_results[rank].score.value, value)
def test_rerank_train(self):
    """Send a search response and then a train request through the router.

    The first message carries the first five test strings as top-k
    results; the second carries one JSON-encoded training document per
    test string. Both replies are only printed.
    """
    with RouterService(self.args), ZmqClient(self.c_args) as c1:
        search_msg = gnes_pb2.Message()
        search_msg.response.search.ClearField('topk_results')
        search_msg.request.search.query.raw_text = 'This is a query'
        for doc_id, text in enumerate(self.test_str[:5]):
            result = search_msg.response.search.topk_results.add()
            result.score.value = 0.1
            result.doc.doc_id = doc_id
            result.doc.raw_text = text
        search_msg.envelope.num_part.extend([1])
        search_msg.response.search.top_k = 5

        c1.send_message(search_msg)
        reply = c1.recv_message()
        print(reply)

        train_msg = gnes_pb2.Message()
        for doc_id, text in enumerate(self.test_str):
            # Labels alternate: 1.0 for even indices, 0.0 for odd ones.
            label = 1.0 if doc_id % 2 == 0 else 0.0
            sample = {
                'Query': 'test query',
                'Candidate': text,
                'Label': label
            }
            train_doc = train_msg.request.train.docs.add()
            train_doc.doc_id = doc_id
            train_doc.raw_bytes = json.dumps(sample).encode('utf-8')
        train_msg.envelope.num_part.extend([1])

        c1.send_message(train_msg)
        reply = c1.recv_message()
        print(reply)
def test_concat_router(self):
    """Send the same three-part message three times and check the concatenated embedding shapes.

    A [5, 2] query embedding sent three times is expected back with shape
    [5, 6]; document embeddings of shape [5, 2*j] are expected back with
    shape [5, 6*j].
    """
    args = set_router_parser().parse_args([
        '--yaml_path', self.concat_router_yaml,
        '--socket_out', str(SocketType.PUSH_BIND),
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.PULL_CONNECT),
    ])
    copies = 3

    with RouterService(args), ZmqClient(c_args) as c1:
        msg = gnes_pb2.Message()
        query_blob = array2blob(np.random.random([5, 2]))
        msg.request.search.query.chunk_embeddings.CopyFrom(query_blob)
        msg.envelope.num_part.extend([1, 3])
        for _ in range(copies):
            c1.send_message(msg)

        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        print(reply.envelope.routes)
        self.assertEqual(reply.request.search.query.chunk_embeddings.shape, [5, 6])

        for j in range(1, 4):
            doc = msg.request.index.docs.add()
            doc_blob = array2blob(np.random.random([5, 2 * j]))
            doc.chunk_embeddings.CopyFrom(doc_blob)
        for _ in range(copies):
            c1.send_message(msg)

        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        for pos in range(3):
            self.assertEqual(reply.request.index.docs[pos].chunk_embeddings.shape,
                             [5, 6 * (pos + 1)])
def test_segmentation_preprocessor_service_realdata(self):
    """Run zipped test images through the segmentation preprocessor service.

    The first route of each reply must name ``PipelinePreprocessor``, and
    the first chunk blob of each document must decode to a 3-dimensional
    array whose first two dimensions are 224 and whose last dimension is 3.
    """
    args = set_preprocessor_parser().parse_args([
        '--yaml_path', self.segmentation_img_pre_yaml
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in)
    ])

    # Close the archive (and each member handle) deterministically instead
    # of leaving the open ZipFile to the garbage collector.
    with zipfile.ZipFile(os.path.join(self.dirname, 'imgs/test.zip')) as all_zips:
        all_bytes = [all_zips.read(name) for name in all_zips.namelist()]

    with PreprocessorService(args), ZmqClient(c_args) as client:
        for req in RequestGenerator.index(all_bytes):
            msg = gnes_pb2.Message()
            msg.request.index.CopyFrom(req.index)
            client.send_message(msg)
            r = client.recv_message()
            self.assertEqual(r.envelope.routes[0].service, 'PipelinePreprocessor')
            for d in r.request.index.docs:
                # Decode the first chunk once and reuse it for all checks.
                first_chunk = blob2array(d.chunks[0].blob)
                self.assertEqual(len(first_chunk.shape), 3)
                self.assertEqual(first_chunk.shape[-1], 3)
                self.assertEqual(first_chunk.shape[0], 224)
                self.assertEqual(first_chunk.shape[1], 224)
                print(first_chunk.dtype)
def test_rerank(self):
    """Send search responses to the rerank router and check how many results come back.

    With ``top_k`` set to 5, sending all test strings must yield five
    results and sending only the first three must yield three.
    """
    args = set_router_parser().parse_args([
        '--yaml_path', self.rerank_router_yaml,
        '--socket_out', str(SocketType.PUB_BIND),
        '--py_path', self.python_code,
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT),
    ])

    def build_response(lines):
        message = gnes_pb2.Message()
        message.response.search.ClearField('topk_results')
        for doc_id, text in enumerate(lines):
            result = message.response.search.topk_results.add()
            result.score.value = 0.1
            result.doc.doc_id = doc_id
            result.doc.raw_text = text
        message.envelope.num_part.extend([1])
        message.response.search.top_k = 5
        return message

    # c2 is opened but never used below; it is kept as in the original setup.
    with RouterService(args), ZmqClient(c_args) as c1, ZmqClient(c_args) as c2:
        reply = None
        c1.send_message(build_response(self.test_str))
        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        self.assertEqual(len(reply.response.search.topk_results), 5)

        c1.send_message(build_response(self.test_str[:3]))
        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        self.assertEqual(len(reply.response.search.topk_results), 3)
def test_doc_reduce_router(self):
    """Send two partial top-k responses and check the router returns one merged reply.

    The reply must report a single part and contain three results whose
    first score is not below the last.
    """
    args = set_router_service_parser().parse_args([
        '--yaml_path', self.doc_router_yaml,
        '--socket_out', str(SocketType.PUB_BIND),
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT),
    ])

    # Rows are (score, doc id, raw text); None means raw_text is left unset.
    shard1_rows = ((0.1, 1, 'd1'), (0.2, 2, None), (0.3, 3, None))
    shard2_rows = ((0.1, 1, None), (0.2, 2, 'd2'), (0.3, 3, 'd3'))

    def add_results(message, rows):
        for score, doc_id, raw_text in rows:
            entry = message.response.search.topk_results.add()
            entry.score = score
            entry.doc.doc_id = doc_id
            if raw_text is not None:
                entry.doc.raw_text = raw_text

    with RouterService(args), ZmqClient(c_args) as c1:
        msg = gnes_pb2.Message()

        # shard1 only has d1
        add_results(msg, shard1_rows)
        msg.envelope.num_part.extend([1, 2])
        c1.send_message(msg)

        msg.response.search.ClearField('topk_results')

        # shard2 has d2 and d3
        add_results(msg, shard2_rows)
        msg.response.search.top_k = 5
        c1.send_message(msg)

        reply = c1.recv_message()
        print(reply.response.search.topk_results)
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        self.assertEqual(len(reply.response.search.topk_results), 3)
        self.assertGreaterEqual(reply.response.search.topk_results[0].score,
                                reply.response.search.topk_results[-1].score)
def test_block_router(self):
    """Send a train request and then an index request, then wait for one reply.

    Nothing is asserted on the reply; the test only requires that a
    message can be received.
    """
    with RouterService(self.args), ZmqClient(self.c_args) as c1:
        train_msg = gnes_pb2.Message()
        train_msg.request.train.docs.add()
        c1.send_message(train_msg)

        index_msg = gnes_pb2.Message()
        index_msg.request.index.docs.add()
        c1.send_message(index_msg)

        r = c1.recv_message()
def test_publish_router(self):
    """Publish one message and check both subscribed clients receive it.

    Each client must see ``num_part`` equal to ``[1, 2]`` on its copy.
    """
    args = set_router_parser().parse_args([
        '--yaml_path', self.publish_router_yaml,
        '--socket_out', str(SocketType.PUB_BIND),
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT),
    ])

    with RouterService(args), ZmqClient(c_args) as c1, ZmqClient(c_args) as c2:
        outgoing = gnes_pb2.Message()
        empty_docs = [gnes_pb2.Document() for _ in range(5)]
        outgoing.request.index.docs.extend(empty_docs)
        outgoing.envelope.num_part.append(1)
        c1.send_message(outgoing)

        for subscriber in (c1, c2):
            received = subscriber.recv_message()
            self.assertSequenceEqual(received.envelope.num_part, [1, 2])
def test_keyword(self):
    """Index one chunk per test vector, then query each vector back.

    For the i-th query, the first top-k result must reference a chunk
    whose ``doc_id`` equals ``i``.
    """
    args = set_router_parser().parse_args([
        '--yaml_path', self.yaml,
        '--socket_out', str(SocketType.PUB_BIND),
        '--py_path', self.python_code,
    ])
    args.as_response = True
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT)
    ])
    with IndexerService(args), ZmqClient(c_args) as c1:
        msg = gnes_pb2.Message()
        for i, vec in enumerate(self.test_vec):
            doc = msg.request.index.docs.add()
            doc.doc_id = i
            doc.raw_text = self.test_str[i]
            c = doc.chunks.add()
            c.doc_id = i
            c.offset = 0
            c.embedding.data = vec.tobytes()
            c.embedding.shape.extend(vec.shape)
            c.embedding.dtype = str(vec.dtype)
            c.text = self.test_str[i]
        c1.send_message(msg)
        r = c1.recv_message()
        # ``assert_`` is a deprecated alias that was removed in Python 3.12;
        # ``assertTrue`` is the supported spelling of the same check.
        self.assertTrue(r.response.index)

        for i, vec in enumerate(self.test_vec):
            msg = gnes_pb2.Message()
            msg.request.search.query.doc_id = 1
            msg.request.search.top_k = 1
            c = msg.request.search.query.chunks.add()
            c.doc_id = 1
            c.embedding.data = vec.tobytes()
            c.embedding.shape.extend(vec.shape)
            c.embedding.dtype = str(vec.dtype)
            c.offset = 0
            c.weight = 1
            c.text = self.test_str[i]
            c1.send_message(msg)
            r = c1.recv_message()
            # assertEqual reports both values on failure, unlike a bare
            # truthiness check on an ``==`` expression.
            self.assertEqual(r.response.search.topk_results[0].chunk.doc_id, i)
def test_video_cut_by_clustering(self):
    """Index the test videos and expect every returned document to have six chunks."""
    args = set_preprocessor_parser().parse_args(['--yaml_path', self.yml_path_4])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
    ])

    with PreprocessorService(args), ZmqClient(c_args) as client:
        for request in RequestGenerator.index(self.video_bytes):
            outgoing = gnes_pb2.Message()
            outgoing.request.index.CopyFrom(request.index)
            client.send_message(outgoing)

            reply = client.recv_message()
            for doc in reply.request.index.docs:
                self.assertEqual(len(doc.chunks), 6)
def test_segmentation_preprocessor_service_echo(self):
    """Send empty documents as an index request, then as a train request, and read both replies.

    Nothing is asserted on the replies.
    """
    args = set_preprocessor_parser().parse_args(
        ['--yaml_path', self.segmentation_img_pre_yaml])
    c_args = _set_client_parser().parse_args(
        ['--port_in', str(args.port_out), '--port_out', str(args.port_in)])

    with PreprocessorService(args), ZmqClient(c_args) as client:
        msg = gnes_pb2.Message()

        index_docs = [gnes_pb2.Document() for _ in range(5)]
        msg.request.index.docs.extend(index_docs)
        client.send_message(msg)
        r = client.recv_message()
        # print(r)

        # The same message object is reused for the train request.
        train_docs = [gnes_pb2.Document() for _ in range(5)]
        msg.request.train.docs.extend(train_docs)
        client.send_message(msg)
        r = client.recv_message()
def test_map_router(self):
    """Send five documents and expect three replies holding 2, 2 and 1 documents."""
    args = set_router_parser().parse_args(['--yaml_path', self.batch_router_yaml])
    c_args = _set_client_parser().parse_args(
        ['--port_in', str(args.port_out), '--port_out', str(args.port_in)])

    with RouterService(args), ZmqClient(c_args) as c1:
        outgoing = gnes_pb2.Message()
        outgoing.request.index.docs.extend([gnes_pb2.Document() for _ in range(5)])
        c1.send_message(outgoing)

        for expected_size in (2, 2, 1):
            batch = c1.recv_message()
            self.assertEqual(len(batch.request.index.docs), expected_size)
def test_empty_service(self):
    """Send one image document with a single blob chunk to an encoder service.

    The reply must still hold one index document and report a SUCCESS
    index status.
    """
    encoder_yaml = '!TestEncoder {gnes_config: {name: EncoderService, is_trained: true}}'
    args = set_encoder_parser().parse_args(['--yaml_path', encoder_yaml])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
    ])

    with ServiceManager(EncoderService, args), ZmqClient(c_args) as client:
        outgoing = gnes_pb2.Message()
        doc = outgoing.request.index.docs.add()
        doc.doc_type = gnes_pb2.Document.IMAGE
        chunk = doc.chunks.add()
        chunk.blob.CopyFrom(array2blob(self.test_numeric))

        client.send_message(outgoing)
        reply = client.recv_message()

        self.assertEqual(len(reply.request.index.docs), 1)
        self.assertEqual(reply.response.index.status, gnes_pb2.Response.SUCCESS)
def test_video_preprocessor_service_realdata(self):
    """Index the test videos and check every chunk blob has shape (168, 192, 3).

    Each returned document must also contain at least one chunk.
    """
    args = set_preprocessor_parser().parse_args(['--yaml_path', self.yml_path])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
    ])

    with PreprocessorService(args), ZmqClient(c_args) as client:
        for request in RequestGenerator.index(self.video_bytes):
            outgoing = gnes_pb2.Message()
            outgoing.request.index.CopyFrom(request.index)
            client.send_message(outgoing)

            reply = client.recv_message()
            for doc in reply.request.index.docs:
                self.assertGreater(len(doc.chunks), 0)
                for chunk in doc.chunks:
                    frame_shape = blob2array(chunk.blob).shape
                    self.assertEqual(frame_shape, (168, 192, 3))
def test_video_decode_preprocessor(self):
    """Index every file under ``self.video_path`` and check the decoded chunk shapes.

    Each returned document must contain at least one chunk, and every
    chunk blob must decode to an array whose trailing dimensions are
    (299, 299, 3).
    """
    args = set_preprocessor_parser().parse_args(['--yaml_path', self.yml_path])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in)])

    # Read each file inside ``with`` so its handle is closed right away;
    # ``open(...).read()`` in a comprehension leaves the handles to the GC.
    video_bytes = []
    for name in os.listdir(self.video_path):
        with open(os.path.join(self.video_path, name), 'rb') as fp:
            video_bytes.append(fp.read())

    with ServiceManager(PreprocessorService, args), ZmqClient(c_args) as client:
        for req in RequestGenerator.index(video_bytes):
            msg = gnes_pb2.Message()
            msg.request.index.CopyFrom(req.index)
            client.send_message(msg)
            r = client.recv_message()
            for d in r.request.index.docs:
                self.assertGreater(len(d.chunks), 0)
                for chunk in d.chunks:
                    shape = blob2array(chunk.blob).shape
                    self.assertEqual(shape[1:], (299, 299, 3))
def test_rerank(self):
    """Send search responses to the router and check how many results come back.

    With ``top_k`` set to 5, sending all test strings must yield five
    results and sending only the first must yield one.
    """

    def build_response(lines, query_text=None):
        message = gnes_pb2.Message()
        message.response.search.ClearField('topk_results')
        if query_text is not None:
            message.request.search.query.raw_text = query_text
        for doc_id, text in enumerate(lines):
            result = message.response.search.topk_results.add()
            result.score.value = 0.1
            result.doc.doc_id = doc_id
            result.doc.raw_text = text
        message.envelope.num_part.extend([1])
        message.response.search.top_k = 5
        return message

    with RouterService(self.args), ZmqClient(self.c_args) as c1:
        # Only the first message also carries a query text.
        c1.send_message(build_response(self.test_str, 'This is a query'))
        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        self.assertEqual(len(reply.response.search.topk_results), 5)

        c1.send_message(build_response(self.test_str[:1]))
        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        self.assertEqual(len(reply.response.search.topk_results), 1)
def test_empty_service(self):
    """Send one document with a single embedded chunk to an indexer service.

    The reply must report a SUCCESS index status.
    """
    indexer_yaml = '!BaseChunkIndexer {gnes_config: {name: IndexerService}}'
    args = set_indexer_parser().parse_args(['--yaml_path', indexer_yaml])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
    ])

    with ServiceManager(IndexerService, args), ZmqClient(c_args) as client:
        outgoing = gnes_pb2.Message()
        doc = outgoing.request.index.docs.add()
        chunk = doc.chunks.add()
        chunk.doc_id = 0
        chunk.embedding.CopyFrom(array2blob(self.test_numeric))
        chunk.offset = 0
        chunk.weight = 1.0

        client.send_message(outgoing)
        reply = client.recv_message()
        self.assertEqual(reply.response.index.status, gnes_pb2.Response.SUCCESS)
def test_multimap_multireduce(self):
    """Wire ten routers into a two-level publish/reduce tree and pass one message through it.

    The client pushes into ``p1`` and pulls from ``r5``; the reply must
    report ``num_part`` equal to ``[1]``.
    """
    # p1 ->
    #       p21 ->
    #               r311
    #               r312
    #                       -> r41
    #                                   -> r5
    #       p22 ->
    #               r321
    #               r322
    #                       -> r42
    #                                   -> r5
    #                                           -> client
    pull_connect = str(SocketType.PULL_CONNECT)
    pull_bind = str(SocketType.PULL_BIND)
    push_connect = str(SocketType.PUSH_CONNECT)
    push_bind = str(SocketType.PUSH_BIND)
    pub_bind = str(SocketType.PUB_BIND)
    sub_connect = str(SocketType.SUB_CONNECT)

    def publisher(socket_in, *extra):
        return set_router_parser().parse_args([
            '--yaml_path', self.publish_router_yaml,
            '--socket_in', socket_in,
            '--socket_out', pub_bind,
            *extra
        ])

    def reducer(*extra):
        return set_router_parser().parse_args([
            '--yaml_path', self.reduce_router_yaml,
            '--socket_in', pull_bind,
            '--socket_out', push_connect,
            *extra
        ])

    def relay(upstream, downstream):
        return set_router_parser().parse_args([
            '--socket_in', sub_connect,
            '--socket_out', push_connect,
            '--port_in', str(upstream.port_out),
            '--port_out', str(downstream.port_in),
            '--yaml_path', 'BaseRouter'
        ])

    # Arguments are parsed in the same order as before: p1, r5, r41, r42,
    # p21, p22, then the four relays, then the client.
    p1 = publisher(pull_connect)
    r5 = reducer()
    r41 = reducer('--port_out', str(r5.port_in))
    r42 = reducer('--port_out', str(r5.port_in))
    p21 = publisher(sub_connect, '--port_in', str(p1.port_out))
    p22 = publisher(sub_connect, '--port_in', str(p1.port_out))
    r311 = relay(p21, r41)
    r312 = relay(p21, r41)
    r321 = relay(p22, r42)
    r322 = relay(p22, r42)

    c_args = _set_client_parser().parse_args([
        '--port_in', str(r5.port_out),
        '--port_out', str(p1.port_in),
        '--socket_in', pull_bind,
        '--socket_out', push_bind,
    ])

    with RouterService(p1), \
            RouterService(r5), \
            RouterService(p21), \
            RouterService(p22), \
            RouterService(r311), \
            RouterService(r312), \
            RouterService(r321), \
            RouterService(r322), \
            RouterService(r41), \
            RouterService(r42), \
            ZmqClient(c_args) as c1:
        outgoing = gnes_pb2.Message()
        outgoing.envelope.num_part.append(1)
        c1.send_message(outgoing)

        reply = c1.recv_message()
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        print(reply.envelope.routes)
def test_doc_sum_reduce_router(self):
    """Send two partial top-k responses and check the router returns one merged reply.

    The reply must report a single part and contain three results whose
    first score is not below the last.
    """
    args = set_router_service_parser().parse_args([
        '--yaml_path', self.doc_sum_yaml,
        '--socket_out', str(SocketType.PUB_BIND),
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT),
    ])

    # Rows are (score, meta info, raw text, score explanation).
    first_part = (
        (0.4, b'1', 'd3', '1-d3\n'),
        (0.3, b'2', 'd2', '1-d2\n'),
        (0.2, b'3', 'd1', '1-d3\n'),
    )
    second_part = (
        (0.5, b'1', 'd2', '2-d2\n'),
        (0.2, b'2', 'd1', '2-d1\n'),
        (0.1, b'3', 'd3', '2-d3\n'),
    )

    def add_results(message, rows):
        for score, meta_info, raw_text, explained in rows:
            entry = message.response.search.topk_results.add()
            entry.score = score
            entry.doc.meta_info = meta_info
            entry.doc.raw_text = raw_text
            entry.score_explained = explained

    with RouterService(args), ZmqClient(c_args) as c1:
        msg = gnes_pb2.Message()
        add_results(msg, first_part)
        msg.envelope.num_part.extend([1, 2])
        c1.send_message(msg)

        # Reuse the same message for the second part.
        msg.response.search.ClearField('topk_results')
        add_results(msg, second_part)
        msg.response.search.top_k = 5
        c1.send_message(msg)

        reply = c1.recv_message()
        print(reply.response.search.topk_results)
        self.assertSequenceEqual(reply.envelope.num_part, [1])
        self.assertEqual(len(reply.response.search.topk_results), 3)
        self.assertGreaterEqual(reply.response.search.topk_results[0].score,
                                reply.response.search.topk_results[-1].score)
def test_doc_combine_score_fn(self):
    """Reduce two partial chunk-level responses, then score the result with a ``DictIndexer``.

    Three documents with three equally weighted chunks each are added to
    the indexer before querying. The query result is computed but nothing
    is asserted on it.
    """
    from gnes.indexer.doc.dict import DictIndexer

    document_list = []
    document_id_list = []
    for doc_id in range(1, 4):
        document = gnes_pb2.Document()
        for offset in range(1, 4):
            chunk = document.chunks.add()
            chunk.doc_id = doc_id
            chunk.offset = offset
            chunk.weight = 1 / 3
        document_id_list.append(doc_id)
        document_list.append(document)

    self.chunk_router_yaml = 'Chunk2DocTopkReducer'
    args = set_router_parser().parse_args([
        '--yaml_path', self.chunk_router_yaml,
        '--socket_out', str(SocketType.PUB_BIND),
    ])
    c_args = _set_client_parser().parse_args([
        '--port_in', str(args.port_out),
        '--port_out', str(args.port_in),
        '--socket_in', str(SocketType.SUB_CONNECT),
    ])

    # Rows are (score value, explanation, chunk doc id).
    first_part = (
        (0.1, '"1-c1"', 1),
        (0.2, '"1-c2"', 2),
        (0.3, '"1-c3"', 1),
    )
    second_part = (
        (0.2, '"2-c1"', 1),
        (0.2, '"2-c2"', 2),
        (0.3, '"2-c3"', 3),
    )

    def add_results(message, rows):
        for value, explained, chunk_doc_id in rows:
            entry = message.response.search.topk_results.add()
            entry.score.value = value
            entry.score.explained = explained
            entry.chunk.doc_id = chunk_doc_id

    with RouterService(args), ZmqClient(c_args) as c1:
        msg = gnes_pb2.Message()
        add_results(msg, first_part)
        msg.envelope.num_part.extend([1, 2])
        c1.send_message(msg)

        msg.response.search.ClearField('topk_results')
        add_results(msg, second_part)
        c1.send_message(msg)

        r = c1.recv_message()
        doc_indexer = DictIndexer(score_fn=CoordDocScoreFn())
        doc_indexer.add(keys=document_id_list, docs=document_list)
        queried_result = doc_indexer.query_and_score(
            docs=r.response.search.topk_results, top_k=2)