def test_batching(
    docs: DocumentArray,
    batch_size: int,
    filter_attr: str,
    expected_sizes: List[int],
    traversal_path: List[str],
):
    """Check that ``docs.batch`` yields batches of exactly the expected sizes.

    :param docs: the document array whose ``batch`` method is exercised
    :param batch_size: forwarded to ``batch`` as ``batch_size``
    :param filter_attr: forwarded to ``batch`` as ``require_attr``
    :param expected_sizes: expected length of every yielded batch, in order;
        its length is also the expected number of batches
    :param traversal_path: forwarded to ``batch`` as ``traversal_paths``

    NOTE(review): the arguments are presumably supplied by pytest fixtures /
    parametrization defined outside this block -- confirm against the decorators.
    """
    generator = docs.batch(
        traversal_paths=traversal_path,
        batch_size=batch_size,
        require_attr=filter_attr,
    )
    # Materialize the batches so their number can be verified as well:
    # a bare ``zip`` stops at the shorter iterable and would silently accept
    # missing or surplus batches (including no batches at all).
    batches = list(generator)

    for batch, expected_size in zip(batches, expected_sizes):
        assert (
            len(batch) == expected_size
        ), f'Expected size {expected_size} but got {len(batch)}'

    assert len(batches) == len(
        expected_sizes
    ), f'Expected {len(expected_sizes)} batches but got {len(batches)}'
def test_needs_attr_empty(attr_name: str, attr_value: Any):
    """
    Verify that batching with ``require_attr`` filters correctly when some
    documents are empty.

    Two empty documents are created and ``attr_name`` is set to ``attr_value``
    on the second one only. Batching with ``batch_size=1`` and
    ``require_attr=attr_name`` is then expected to produce exactly one batch
    holding exactly one document, whose ``attr_name`` equals ``attr_value``.

    :param attr_name: name of the document attribute to set and to require
    :param attr_value: value assigned to that attribute
    """
    docs = DocumentArray([Document(), Document()])
    setattr(docs[1], attr_name, attr_value)

    batches = list(docs.batch(batch_size=1, require_attr=attr_name))

    # Exactly one batch, containing exactly one document.
    assert len(batches) == 1
    assert len(batches[0]) == 1

    actual_value = getattr(batches[0][0], attr_name)
    if attr_name in ('blob', 'embedding'):
        # These attributes are compared element-wise via numpy rather than
        # with ``==``.
        np.testing.assert_array_equal(actual_value, attr_value)
    else:
        assert actual_value == attr_value