コード例 #1
0
def test_request_extend_queryset():
    """Check that ``Request.queryset`` accepts mixed query-language inputs.

    The same ``SliceQL(start=3, end=4)`` spec is supplied in three forms: a
    ``SliceQL`` object, a ``QueryLang`` wrapper with priority 1, and a raw
    ``QueryLangProto`` with priority 2. The test checks five things:

    - bulk ``extend`` stores them in order with those priorities;
    - the queryset is a ``Sequence``;
    - copying one request's queryset into another keeps the distinct-id count;
    - one-at-a-time ``append`` gives the same result;
    - a non-iterable argument to ``extend`` raises ``TypeError``.
    """
    first = SliceQL(start=3, end=4)
    second = QueryLang(SliceQL(start=3, end=4, priority=1))
    third = jina_pb2.QueryLangProto()
    third.name = 'SliceQL'
    third.parameters['start'] = 3
    third.parameters['end'] = 4
    third.priority = 2

    def check_order_and_params(queryset):
        # the i-th entry must carry priority i and the shared start/end slice
        for expected_priority, item in enumerate(queryset):
            assert item.priority == expected_priority
            assert item.parameters['start'] == 3
            assert item.parameters['end'] == 4

    req = Request()
    req.queryset.extend([first, second, third])
    assert isinstance(req.queryset, Sequence)
    check_order_and_params(req.queryset)

    # Three entries, but only two distinct object ids are expected when
    # iterating (original note: "q1 and q2 refer to the same").
    # NOTE(review): this depends on the identity of objects yielded by
    # iteration; confirm the intended invariant.
    assert len({id(item) for item in req.queryset}) == 2

    # extending a fresh request from another request's queryset keeps that count
    copied = Request()
    copied.queryset.extend(req.queryset)
    assert len({id(item) for item in copied.queryset}) == 2

    # the same three inputs, added one at a time
    req = Request()
    for item in (first, second, third):
        req.queryset.append(item)
    check_order_and_params(req.queryset)

    with pytest.raises(TypeError):
        req.queryset.extend(1)
コード例 #2
0
def test_queryset_with_struct(random_workspace):
    """Filter indexed docs with a ``FilterQL`` queryset whose params are a Struct.

    Builds ``total_docs`` documents whose ``label`` tag alternates between
    ``label1`` (even ids) and ``label2`` (odd ids). The Flow's ``FilterQL``
    config keeps both labels. A request-level ``QueryLangProto`` with a nested
    ``lookups`` dict and priority 1 narrows the result to ``label2`` docs only.

    :param random_workspace: pytest fixture, not referenced in the body.
        Presumably it is requested for its workspace setup side effect (confirm).
    """
    total_docs = 4
    docs = []
    for doc_id in range(total_docs):
        doc = jina_pb2.DocumentProto()
        doc.text = f'I am doc{doc_id}'
        NdArray(doc.embedding).value = np.array([doc_id])
        doc.tags['label'] = f'label{doc_id % 2 + 1}'
        docs.append(doc)

    # Odd doc ids get label2, so there are floor(total_docs / 2) of them.
    # Integer division keeps this an exact int count. Plain `/` would yield
    # e.g. 2.5 for 5 docs and could never equal a length.
    expected_label2 = total_docs // 2

    f = (Flow()
         .add(uses='- !FilterQL | {lookups: {tags__label__in: [label1, label2]}, traversal_paths: [r]}'))

    def validate_all_docs(resp):
        assert len(resp.docs) == total_docs

    def validate_label2_docs(resp):
        assert len(resp.docs) == expected_label2

    with f:
        # keep all the docs (driver config matches both labels)
        f.index(docs, output_fn=validate_all_docs, callback_on='body')

        # keep only the docs with label2, via a request-level queryset
        qs = jina_pb2.QueryLangProto(name='FilterQL', priority=1)
        qs.parameters['lookups'] = {'tags__label': 'label2'}
        qs.parameters['traversal_paths'] = ['r']
        f.index(docs, queryset=qs, output_fn=validate_label2_docs, callback_on='body')
コード例 #3
0
def test_request_extend_queryset():
    """Exercise ``Request.extend_queryset`` with batched, single and invalid input.

    The same ``SliceQL(start=3, end=4)`` spec is supplied in three forms: a
    ``SliceQL`` object, a ``QueryLang`` wrapper with priority 1, and a raw
    ``QueryLangProto`` with priority 2. Whether added as one list or one item
    per call, the resulting queryset must list them in order with matching
    priorities. Passing an int must raise ``TypeError``.
    """
    slice_obj = SliceQL(start=3, end=4)
    wrapped = QueryLang(SliceQL(start=3, end=4, priority=1))
    raw_proto = jina_pb2.QueryLangProto()
    raw_proto.name = 'SliceQL'
    raw_proto.parameters['start'] = 3
    raw_proto.parameters['end'] = 4
    raw_proto.priority = 2

    def assert_ordered_slices(request):
        # entry i must have priority i and the shared start/end slice
        for want_priority, entry in enumerate(request.queryset):
            assert entry.priority == want_priority
            assert entry.parameters['start'] == 3
            assert entry.parameters['end'] == 4

    # all three in a single call
    batched = Request()
    batched.extend_queryset([slice_obj, wrapped, raw_proto])
    assert_ordered_slices(batched)

    # one call per item
    r = Request()
    for one in (slice_obj, wrapped, raw_proto):
        r.extend_queryset(one)
    assert_ordered_slices(r)

    # an int is not an accepted queryset input
    with pytest.raises(TypeError):
        r.extend_queryset(1)
コード例 #4
0
def test_read_from_req():
    """A request-level ``SliceQL`` applies only when its priority is high enough.

    The Flow's driver is configured to slice docs to ``[0:5]``. Three runs
    check the outcome:

    - with no queryset, 5 docs come back;
    - with a priority-1 queryset slicing ``[1:4]``, 3 docs come back;
    - with the same queryset at priority -1, the driver's slice applies again
      and 5 docs come back.
    """
    def expect_five(req):
        assert len(req.docs) == 5

    def expect_three(req):
        assert len(req.docs) == 3

    slice_query = jina_pb2.QueryLangProto(name='SliceQL', priority=1)
    slice_query.parameters['start'] = 1
    slice_query.parameters['end'] = 4

    flow = Flow(callback_on='body').add(uses='- !SliceQL | {start: 0, end: 5}')

    # driver-level slice only
    with flow:
        flow.index(random_docs(10), output_fn=expect_five)

    # request-level slice at priority 1 wins over the driver's config
    with flow:
        flow.index(random_docs(10), queryset=slice_query, output_fn=expect_three)

    # At priority -1 the request is not above the driver's own priority, so the
    # driver's slice is used. The driver's default priority is presumably 0;
    # confirm.
    slice_query.priority = -1
    with flow:
        flow.index(random_docs(10), queryset=slice_query, output_fn=expect_five)
コード例 #5
0
def test_topk_override(config):
    """Search with a request-level ``top_k`` override for ``VectorSearchDriver``.

    Indexes 100 random docs with ``flow.yml``, then searches ``JINA_NDOCS``
    random docs. The search passes a priority-1 queryset that sets the driver's
    ``top_k`` to ``JINA_TOPK_OVERRIDE``. Results are checked by
    ``validate_override_results``.

    :param config: pytest fixture, not referenced in the body. Presumably it
        sets the ``JINA_*`` environment variables read below (confirm).
    """
    # Making queryset
    top_k_queryset = jina_pb2.QueryLangProto()
    top_k_queryset.name = 'VectorSearchDriver'
    top_k_queryset.priority = 1
    # Environment values are strings. Store top_k in the Struct as a number, as
    # other numeric query-lang parameters are, instead of as a string value the
    # driver would have to coerce.
    top_k_queryset.parameters['top_k'] = int(os.environ['JINA_TOPK_OVERRIDE'])

    # NOTE(review): 'flow.yml' is resolved relative to the current working directory
    with Flow().load_config('flow.yml') as index_flow:
        index_flow.index(input_fn=random_docs(100))
    with Flow().load_config('flow.yml') as search_flow:
        search_flow.search(input_fn=random_docs(int(os.environ['JINA_NDOCS'])),
                           output_fn=validate_override_results,
                           queryset=[top_k_queryset])
コード例 #6
0
 def queryset(self):
     """Return a one-entry queryset that overrides ``top_k`` for ``SimpleVectorSearchDriver``.

     The single ``QueryLangProto`` has priority 1 and sets ``top_k`` to 4.
     A new list and proto are built on every call.
     """
     top_k_query = jina_pb2.QueryLangProto(name='SimpleVectorSearchDriver', priority=1)
     top_k_query.parameters['top_k'] = 4
     return [top_k_query]