def test_dataset_filter_3():
    """Filter on a ClassLabel column and read images back through the view.

    Every fifth sample gets label 0 and all others label 1; sample 4 is then
    relabelled 2. Filtering on label 0 must keep exactly the multiples of 5,
    and filtering on label 2 must keep only sample 4, whose image is checked
    both as a whole column and item by item.
    """
    schema = {
        "img": Image((None, None, 3), max_shape=(100, 100, 3)),
        "cl": ClassLabel(names=["cat", "dog", "horse"]),
    }
    ds = Dataset("./data/tests/filtering_3", shape=(100,), schema=schema, mode="w")
    for idx in range(100):
        label = 1
        if idx % 5 == 0:
            label = 0
        ds["cl", idx] = label
        ds["img", idx] = idx * np.ones((5, 6, 3))
    ds["cl", 4] = 2

    label_zero = ds.filter(lambda sample: sample["cl"].compute() == 0)
    # Multiples of 5 in [0, 100): 0, 5, ..., 95.
    assert label_zero.indexes == list(range(0, 100, 5))

    label_two = ds.filter(lambda sample: sample["cl"].compute() == 2)
    assert (label_two["img"].compute() == 4 * np.ones((1, 5, 6, 3))).all()
    expected_img = 4 * np.ones((5, 6, 3))
    for item in label_two:
        assert (item["img"].compute() == expected_img).all()
        assert item["cl"].compute() == 2
def test_dataset_filter_4():
    """Slice a filtered view and check the labels it yields.

    Samples 0..9 get label 0 and the rest label 1, so filtering on label 0
    gives a ten-item view; positions 3..7 of that view must all be 0.
    """
    schema = {
        "img": Image((None, None, 3), max_shape=(100, 100, 3)),
        "cl": ClassLabel(names=["cat", "dog", "horse"]),
    }
    ds = Dataset("./data/tests/filtering_4", shape=(100,), schema=schema, mode="w")
    for idx in range(100):
        label = 1
        if idx < 10:
            label = 0
        ds["cl", idx] = label
        ds["img", idx] = idx * np.ones((5, 6, 3))

    zero_view = ds.filter(lambda sample: sample["cl"].compute() == 0)
    sliced_labels = zero_view[3:8, "cl"].compute()
    assert (sliced_labels == np.zeros((5,))).all()
def test_dataset_filter():
    """Filter a dataset with a prefix test on a Text column.

    Even-indexed samples hold "abc<i>" and odd-indexed ones "def<i>", so a
    filter keeping strings that start with "abc" must select exactly the even
    indexes, and every sample in the resulting view must carry that prefix.
    """

    def abc_filter(sample):
        return sample["ab"].compute().startswith("abc")

    my_schema = {"img": Tensor((100, 100)), "ab": Text((None,), max_shape=(10,))}
    # mode="w" (as the other filter tests use) so the test always starts from a
    # freshly written dataset rather than depending on what is already at this
    # path from a previous run.
    ds = Dataset("./data/new_filter", shape=(10,), schema=my_schema, mode="w")
    for i in range(10):
        ds["img", i] = i * np.ones((100, 100))
        ds["ab", i] = "abc" + str(i) if i % 2 == 0 else "def" + str(i)

    ds2 = ds.filter(abc_filter)
    assert ds2.indexes == [0, 2, 4, 6, 8]
    # Also check the contents seen through the view, not just its indexes.
    for item in ds2:
        assert item["ab"].compute().startswith("abc")
def test_dataset_store():
    """Store a dataset and a filtered view of it, then verify both copies.

    The full dataset is stored to a new path and every sample of the copy is
    compared against the values originally written. The view keeping samples
    whose "abc" value is a multiple of 5 is stored as well, and its 20 samples
    must hold the values of source samples 0, 5, ..., 95 in order.
    """
    my_schema = {"image": Tensor((100, 100), "uint8"), "abc": "uint8"}
    # mode="w" for consistency with the other tests: always start from a
    # freshly written dataset.
    ds = Dataset("./test/ds_store", schema=my_schema, shape=(100,), mode="w")
    for i in range(100):
        ds["image", i] = i * np.ones((100, 100))
        ds["abc", i] = i

    def my_filter(sample):
        return sample["abc"].compute() % 5 == 0

    dsv = ds.filter(my_filter)

    ds2 = ds.store("./test/ds2_store")
    for i in range(100):
        assert (ds2["image", i].compute() == i * np.ones((100, 100))).all()
        # Verify the stored copy; checking `ds` here would only re-test the source.
        assert ds2["abc", i].compute() == i

    ds3 = dsv.store("./test/ds3_store")
    for i in range(20):
        assert (ds3["image", i].compute() == 5 * i * np.ones((100, 100))).all()
        assert ds3["abc", i].compute() == 5 * i
def test_dataset_filter_2():
    """Exercise combined and chained filters on Text columns.

    Rows default to ("John", "Doe"). Some rows get fname "Active" and some get
    lname "loop". The test checks that:

    - a single filter with both conditions selects the same rows as chaining
      the two single-condition filters in either order;
    - items and column views of each result hold the expected strings.

    Finally, a filter that reads a key absent from the schema must raise
    KeyError, both on a dataset and on an already filtered view.
    """
    my_schema = {
        "fname": Text((None,), max_shape=(10,)),
        "lname": Text((None,), max_shape=(10,)),
    }
    ds = Dataset("./data/tests/filtering", shape=(100,), schema=my_schema, mode="w")
    for idx in range(100):
        ds["fname", idx] = "John"
        ds["lname", idx] = "Doe"
    for idx in [1, 3, 6, 15, 63, 96, 75]:
        ds["fname", idx] = "Active"
    for idx in [15, 31, 25, 75, 3, 6]:
        ds["lname", idx] = "loop"

    both = {"fname": "Active", "lname": "loop"}

    def is_active(sample):
        return sample["fname"].compute() == "Active"

    def is_loop(sample):
        return sample["lname"].compute() == "loop"

    def check_all(iterable, expected):
        # Every element yielded must compute to `expected`.
        for entry in iterable:
            assert entry.compute() == expected

    # Both conditions in a single filter (`and` short-circuits as before).
    combined = ds.filter(lambda sample: is_active(sample) and is_loop(sample))
    combined_fname = combined["fname"]
    combined_lname = combined["lname"]
    check_all(combined, both)
    check_all(combined_fname, "Active")
    check_all(combined_lname, "loop")

    # fname filter first, then lname filter on the resulting view.
    active = ds.filter(is_active)
    active_loop = active.filter(is_loop)
    for entry in active:
        assert entry.compute()["fname"] == "Active"
    active_fname = active["fname"]
    active_loop_lname = active_loop["lname"]
    check_all(active_fname, "Active")
    check_all(active_loop_lname, "loop")
    check_all(active_loop, both)

    assert combined.indexes == [3, 6, 15, 75]
    assert active.indexes == [1, 3, 6, 15, 63, 75, 96]
    assert active_loop.indexes == [3, 6, 15, 75]

    # lname filter first, then fname filter on the resulting view.
    loop = ds.filter(is_loop)
    loop_active = loop.filter(is_active)
    for entry in loop:
        assert entry.compute()["lname"] == "loop"
    check_all(loop_active, both)
    assert loop.indexes == [3, 6, 15, 25, 31, 75]
    assert loop_active.indexes == [3, 6, 15, 75]

    # Filters touching a key missing from the schema must raise KeyError.
    my_schema2 = {
        "fname": Text((None,), max_shape=(10,)),
        "lname": Text((None,), max_shape=(10,)),
        "image": Image((1920, 1080, 3)),
    }
    image_ds = Dataset(
        "./data/tests/filtering2", shape=(100,), schema=my_schema2, mode="w"
    )

    def uses_missing_key(sample):
        return (sample["random"].compute() == np.ones((1920, 1080, 3))).all()

    with pytest.raises(KeyError):
        image_ds.filter(uses_missing_key)
    for idx in [1, 3, 6, 15, 63, 96, 75]:
        image_ds["fname", idx] = "Active"
    active_view = image_ds.filter(is_active)
    with pytest.raises(KeyError):
        active_view.filter(uses_missing_key)