def test_match2docranker_batching_flow(ranker, mocker):
    """Search through a Flow with a 'ranker' step and check the match ordering.

    Fifteen queries are built, each carrying ten matches whose ``dummy_score``
    tag equals the match index. The response callback asserts that every
    query's matches come back ordered from the highest index to the lowest,
    with ``score.value`` equal to that index.
    """
    NUM_DOCS_QUERIES = 15
    NUM_MATCHES = 10

    def _make_query(query_idx):
        # One query document with NUM_MATCHES matches attached in index order.
        query_doc = Document(id=f'query-{query_idx}')
        for match_idx in range(NUM_MATCHES):
            candidate = Document(
                id=f'match-{query_idx}-{match_idx}',
                tags={'dummy_score': match_idx},
            )
            query_doc.matches.append(candidate)
        return query_doc

    queries = DocumentArray([])
    for query_idx in range(NUM_DOCS_QUERIES):
        queries.append(_make_query(query_idx))

    def validate_response(resp):
        returned_docs = resp.search.docs
        assert len(returned_docs) == NUM_DOCS_QUERIES
        for query_idx, returned_query in enumerate(returned_docs):
            # Ranks are 1-based, so rank 1 must hold the highest index.
            for rank, match in enumerate(returned_query.matches, 1):
                expected_idx = NUM_MATCHES - rank
                assert match.id == f'match-{query_idx}-{expected_idx}'
                assert match.score.value == expected_idx

    on_done_mock = mocker.Mock()

    with Flow().add(name='ranker', uses=ranker) as flow:
        flow.search(inputs=queries, on_done=on_done_mock)

    on_done_mock.assert_called_once()
    # NOTE(review): validate_callback is defined outside this view; presumably
    # it applies validate_response to what the mock received - confirm.
    validate_callback(on_done_mock, validate_response)
def test_union(docarray, document_factory):
    """Check that ``docarray + other`` keeps both operands' documents in order.

    Six extra documents (ids 4 through 9) are appended to a fresh
    ``DocumentArray``. The sum must then expose the first three documents of
    ``docarray`` at positions 0-2 and the six extra ones at positions 3-8.
    """
    extra_docs = DocumentArray([])
    for doc_id in range(4, 10):
        extra_docs.append(document_factory.create(doc_id, f'test {doc_id}'))

    combined = docarray + extra_docs

    # Number of leading positions compared against the left operand.
    left_count = 3
    for pos in range(left_count):
        assert combined[pos].id == docarray[pos].id

    for pos in range(6):
        assert combined[left_count + pos].id == extra_docs[pos].id
def test_get_content(stack, num_rows, field):
    """Check the shape of the contents returned by ``DocumentArray.extract_docs``.

    Ten documents are created with the same random ``(num_rows, 20)`` array
    stored under ``field``, followed by one empty ``Document``. The assertions
    expect exactly ten content entries back: a single stacked ``np.ndarray``
    when ``stack`` is truthy, otherwise a collection of per-document arrays.
    """
    batch_size, embed_size = 10, 20
    per_doc_shape = (num_rows, embed_size)

    # The same array object is reused for every document in the batch.
    field_values = {field: np.random.random(per_doc_shape)}
    docs = DocumentArray([Document(**field_values) for _ in range(batch_size)])
    # Extra document without the field set; the length checks below still
    # expect only batch_size entries.
    docs.append(Document())

    contents, _ = docs.extract_docs(field, stack_contents=stack)

    if not stack:
        assert len(contents) == batch_size
        for item in contents:
            assert item.shape == per_doc_shape
        return

    assert isinstance(contents, np.ndarray)
    assert contents.shape == (batch_size, num_rows, embed_size)
def test_match2docranker_batching(ranker):
    """Call ``ranker.score`` on a batch of queries and check the resulting order.

    Fifteen queries with ten matches each are built; every match carries a
    ``dummy_score`` tag equal to its index. The ranker is called with all-zero
    old scores, ``None`` query metas and the per-match tag values. It must
    return one score per match equal to the match index. After assigning the
    scores and sorting each query's matches in descending order, the matches
    must run from the highest index down to the lowest.
    """
    NUM_DOCS_QUERIES = 15
    NUM_MATCHES = 10

    # Old scores are all zero and queries carry no meta information.
    old_matches_scores = [[0] * NUM_MATCHES for _ in range(NUM_DOCS_QUERIES)]
    queries_metas = [None] * NUM_DOCS_QUERIES
    matches_metas = []
    queries = DocumentArray([])

    for query_idx in range(NUM_DOCS_QUERIES):
        query_doc = Document(id=f'query-{query_idx}')
        metas_for_query = []
        for match_idx in range(NUM_MATCHES):
            candidate = Document(
                id=f'match-{query_idx}-{match_idx}',
                tags={'dummy_score': match_idx},
            )
            query_doc.matches.append(candidate)
            metas_for_query.append(candidate.get_attrs('tags__dummy_score'))
        queries.append(query_doc)
        matches_metas.append(metas_for_query)

    queries_scores = ranker.score(old_matches_scores, queries_metas, matches_metas)
    assert len(queries_scores) == NUM_DOCS_QUERIES

    for query_idx, (query_doc, scores_for_query) in enumerate(
        zip(queries, queries_scores)
    ):
        assert len(scores_for_query) == NUM_MATCHES

        for match_idx, (match, new_score) in enumerate(
            zip(query_doc.matches, scores_for_query)
        ):
            match.score = NamedScore(value=match_idx)
            assert new_score == match_idx

        query_doc.matches.sort(key=lambda doc: doc.score.value, reverse=True)

        # Ranks are 1-based, so rank 1 must hold the highest index.
        for rank, match in enumerate(query_doc.matches, 1):
            expected_idx = NUM_MATCHES - rank
            assert match.id == f'match-{query_idx}-{expected_idx}'
            assert match.score.value == expected_idx