Beispiel #1
0
    def test_iter(self):
        """Iterating a LocalImageCollection yields its (name, image) pairs."""
        expected = {}
        for idx in range(3):
            name = str(idx)
            # The name doubles as a dummy file path; this test never reads it.
            expected[name] = data.LocalImage(name)
        collection = data.LocalImageCollection(expected)

        # Wrapping in iter() hands dict() a plain iterator of pairs, so dict()
        # cannot take its mapping path (keys()/__getitem__) and only the
        # collection's iteration protocol is exercised.
        actual = dict(iter(collection))
        # Values are compared with ==; if LocalImage does not override __eq__
        # this is an identity check on the stored instances.
        assert actual == expected
Beispiel #2
0
    def test_getitem(self):
        """Indexing a LocalImageCollection by name returns the image stored under it."""
        images = {str(idx): data.LocalImage(str(idx)) for idx in range(3)}
        collection = data.LocalImageCollection(images)

        # Look up every known name directly instead of regenerating names from
        # range(len(collection)): that made this test depend on __len__ being
        # correct and on the names happening to be "0".."n-1", so a __len__ bug
        # would have shown up here as a misleading KeyError.
        actual = {name: collection[name] for name in images}
        desired = images
        assert actual == desired
Beispiel #3
0
    def test_len(self):
        """len() of a LocalImageCollection equals the number of images it was built from."""
        num_images = 3
        names = [str(idx) for idx in range(num_images)]
        collection = data.LocalImageCollection(
            {name: data.LocalImage(name) for name in names}
        )

        assert len(collection) == num_images
Beispiel #4
0
    def test_read(self, tmpdir):
        """Reading a LocalImageCollection matches reading each image file directly."""

        def write_random_images(root):
            # Seed so the written images are the same on every run.
            torch.manual_seed(0)
            name_to_file = {}
            for name in map(str, range(3)):
                file = path.join(root, f"{name}.png")
                write_image(torch.rand(1, 3, 32, 32), file)
                name_to_file[name] = file
            return name_to_file

        files = write_random_images(tmpdir)
        collection = data.LocalImageCollection(
            {name: data.LocalImage(file) for name, file in files.items()}
        )

        # Compare against read_image() on the same files rather than the random
        # tensors themselves: the PNG round trip presumably quantizes values.
        ptu.assert_allclose(
            collection.read(),
            {name: read_image(file) for name, file in files.items()},
        )

    def test_LocalImageCollection_read(self):
        """Reading a LocalImageCollection matches reading each PNG file directly."""

        def save_random_pngs(root):
            # Fixed seed keeps the generated images identical between runs.
            torch.manual_seed(0)
            paths = {}
            for idx in range(3):
                name = str(idx)
                tensor = torch.rand(1, 3, 32, 32)
                paths[name] = path.join(root, f"{name}.png")
                write_image(tensor, paths[name])
            return paths

        with get_tmp_dir() as root:
            # Everything stays inside the context so the files still exist
            # while the collection and the reference reads access them.
            files = save_random_pngs(root)
            collection = data.LocalImageCollection(
                {name: data.LocalImage(file) for name, file in files.items()}
            )
            self.assertTensorDictAlmostEqual(
                collection.read(),
                {name: read_image(file) for name, file in files.items()},
            )
Beispiel #6
0
    def test_repr_smoke(self):
        """Smoke test: repr() of a LocalImageCollection succeeds."""
        collection = data.LocalImageCollection(
            {str(idx): data.LocalImage(str(idx)) for idx in range(3)}
        )

        # repr() itself raises TypeError when __repr__ returns a non-str, so this
        # isinstance check effectively only guards that the call did not raise.
        text = repr(collection)
        assert isinstance(text, str)