Exemplo n.º 1
0
def test_issue_3527_delete_and_match(tmpdir):
    """Regression test for issue 3527: match against a memmap after a delete.

    A document is appended to a ``DocumentArrayMemmap`` and deleted again,
    then a second document with the same embedding is appended. Matching a
    one-document query array against the memmap must return ``'c'`` (the
    document that is still stored) as the first match.
    """
    stored_vector = [1, 2, 3]
    index = DocumentArrayMemmap(tmpdir)

    # Insert 'a' and remove it right away so the memmap has seen a deletion.
    doc_a = Document(id='a', embedding=np.array(stored_vector, dtype=np.float32))
    index.append(doc_a)
    del index['a']

    # 'c' is the only document left in the memmap when matching happens.
    doc_c = Document(id='c', embedding=np.array(stored_vector, dtype=np.float32))
    index.append(doc_c)

    query = Document(embedding=np.array([5, 6, 7], dtype=np.float32))
    queries = DocumentArray([query])
    queries.match(index)

    first_match = queries[0].matches[0]
    assert first_match.id == 'c'
Exemplo n.º 2
0
def test_match_exclude_self(exclude_self, num_matches, only_id):
    """Check the number of matches per document for a given ``exclude_self``.

    Two separate ``DocumentArray`` objects are built with the same ids and
    embeddings. After matching the first against the second, every document
    of the first array must hold exactly ``num_matches`` matches.
    """

    def make_docs():
        # Each call builds fresh Document objects, so both arrays are
        # independent but carry identical ids and embeddings.
        specs = (('1', [1, 2]), ('2', [3, 4]))
        return DocumentArray([
            Document(id=doc_id, embedding=np.array(vector))
            for doc_id, vector in specs
        ])

    queries = make_docs()
    candidates = make_docs()

    queries.match(candidates, exclude_self=exclude_self, only_id=only_id)

    match_counts = [len(doc.matches) for doc in queries]
    assert match_counts == [num_matches] * len(match_counts)
Exemplo n.º 3
0
def test_exclude_self_should_keep_limit(limit, exclude_self):
    """Check that ``limit`` matches are returned with or without ``exclude_self``.

    A four-document array is matched against itself. Every document must end
    up with exactly ``limit`` matches, and when ``exclude_self`` is truthy no
    match may share its id with the document it belongs to.
    """
    vectors = ([3, 1, 0], [3, 0, 1], [3, 0, 0], [3, 1, 1])
    docs = DocumentArray(
        [Document(embedding=np.array(vector)) for vector in vectors]
    )

    docs.match(docs, exclude_self=exclude_self, limit=limit)

    for doc in docs:
        assert len(doc.matches) == limit
        if not exclude_self:
            continue
        # With self-exclusion enabled a document must never match itself.
        assert all(doc.id != match.id for match in doc.matches)
Exemplo n.º 4
0
def test_match_inclusive(only_id):
    """Match a :class:`DocumentArray` against itself and against an array
    that shares some of its :class:`Document` objects.

    In both scenarios the array sizes must stay at 3 and the flattened
    traversal over the ``'m'``, ``'mm'`` and ``'mmm'`` paths must contain
    9 documents.
    """
    paths = ('m', 'mm', 'mmm')
    vectors = ([1, 2, 3], [1, 0, 1], [1, 1, 2])

    # Scenario 1: the array is matched against itself.
    source = DocumentArray(
        [Document(embedding=np.array(vector)) for vector in vectors]
    )
    source.match(source, only_id=only_id)
    assert len(source) == 3
    flattened = source.traverse_flat(traversal_paths=list(paths))
    assert len(flattened) == 9

    # Scenario 2: the other array reuses two documents taken from ``source``
    # next to one new document.
    new_doc = Document(embedding=np.array([4, 1, 3]))
    shared = DocumentArray([new_doc, source[0], source[1]])
    source.match(shared, only_id=only_id)
    assert len(shared) == 3
    flattened = source.traverse_flat(traversal_paths=list(paths))
    assert len(flattened) == 9