def test_delitem(self):
    """Check ``del dataset[...]`` with slices, integer indices and names.

    A plain list of the same segments is mutated in parallel with the
    dataset; after every deletion both must have the same length and
    hold the identical segment objects in the same order.
    """
    dataset = DatasetBase("test_name")
    mirror = [Segment(str(number)) for number in range(5)]
    for item in mirror:
        dataset.add_segment(item)

    def assert_matches_mirror():
        # Same length, and identical objects position by position.
        assert len(dataset) == len(mirror)
        for held, wanted in zip(dataset, mirror):
            assert held is wanted

    # Delete by slice.
    del mirror[1:3]
    del dataset[1:3]
    assert_matches_mirror()

    # Delete by integer index.
    del mirror[1]
    del dataset[1]
    assert_matches_mirror()

    # Delete by segment name; look the segment up before it is removed.
    mirror.remove(dataset["4"])
    del dataset["4"]
    assert_matches_mirror()

    # The test expects an out-of-range slice to leave the dataset unchanged.
    del dataset[100:200]
    assert_matches_mirror()

    # Out-of-range index and unknown name must raise.
    with pytest.raises(IndexError):
        del dataset[100]

    with pytest.raises(KeyError):
        del dataset["100"]
def test_load_catalog(self):
    """Check that ``load_catalog`` yields a catalog matching the JSON file.

    The catalog file is located relative to this test module, loaded into
    the dataset, and the dumped catalog is compared with the raw JSON.
    """
    tests_dir = os.path.dirname(__file__)
    catalog_path = os.path.join(
        tests_dir, "..", "..", "opendataset", "HeadPoseImage", "catalog.json"
    )

    dataset = DatasetBase("test_name")
    dataset.load_catalog(catalog_path)

    # Read the same file directly to get the expected contents.
    with open(catalog_path, encoding="utf-8") as catalog_file:
        expected = json.load(catalog_file)

    assert dataset.catalog.dumps() == expected
def test_contains(self):
    """Check the ``in`` operator on a dataset.

    Names of added segments must be reported as contained; a name that
    was never added and a non-string value must not be.
    """
    added_names = ("test", "train")
    dataset = DatasetBase("test_name")
    for name in added_names:
        dataset.add_segment(Segment(name))

    for name in added_names:
        assert name in dataset

    # An unknown name and an integer are both expected to be absent.
    for missing in ("val", 100):
        assert missing not in dataset
def test_getitem(self):
    """Check ``dataset[...]`` lookup by integer index and by name.

    "train" is added before "test", yet the test expects "test" at
    index 0 and "train" at index 1. Index and name lookups must return
    the very same segment object. An out-of-range index raises
    ``IndexError`` and an unknown name raises ``KeyError``.
    """
    dataset = DatasetBase("test_name")
    train_segment = Segment("train")
    test_segment = Segment("test")
    for item in (train_segment, test_segment):
        dataset.add_segment(item)

    at_zero = dataset[0]
    assert at_zero is test_segment
    assert dataset["test"] is at_zero

    at_one = dataset[1]
    assert at_one is train_segment
    assert dataset["train"] is at_one

    with pytest.raises(IndexError):
        dataset[2]

    with pytest.raises(KeyError):
        dataset["unknown"]
def test_enable_cache(self, mocker):
    """Check that ``enable_cache`` on a plain ``DatasetBase`` raises ``TypeError``.

    The ``mocker`` parameter is not used in the body; it is kept so the
    test's signature stays unchanged.
    """
    plain_dataset = DatasetBase("test_name")

    with pytest.raises(TypeError):
        plain_dataset.enable_cache()
def test_add_and_getitem(self):
    """Check that an added segment is returned, by identity, at index 0."""
    dataset = DatasetBase("test_name")
    added = Segment("train")
    dataset.add_segment(added)

    fetched = dataset[0]
    assert fetched is added
def test_len(self):
    """Check that ``len(dataset)`` is 1 after adding a single segment."""
    dataset = DatasetBase("test_name")
    only_segment = Segment("train")
    dataset.add_segment(only_segment)

    segment_count = len(dataset)
    assert segment_count == 1
def test_keys(self):
    """Check that ``dataset.keys()`` equals the tuple of added segment names."""
    expected_names = ("test", "train")
    dataset = DatasetBase("test_name")
    for name in expected_names:
        dataset.add_segment(Segment(name))

    actual_names = dataset.keys()
    assert actual_names == expected_names