Пример #1
0
    def read(
        self,
        root: Optional[str] = None,
        download: Optional[bool] = None,
        overwrite: bool = False,
        **read_image_kwargs: Any,
    ) -> torch.Tensor:
        r"""Read the image from file with :func:`pystiche.image.read_image`. If
        available the :attr:`.transform` is applied afterwards.

        Args:
            root: Optional root directory if the file is a relative path.
                Defaults to :func:`pystiche.home`.
            download: If ``True``, downloads the image first. Defaults to ``False`` if
                the file already exists and the MD5 checksum is not known. Otherwise
                defaults to ``True``.
            overwrite: If downloaded, overwrites files if they already exist or the
                MD5 checksum does not match. Defaults to ``False``.
            **read_image_kwargs: Optional parameters passed to
                :func:`pystiche.image.read_image`.

        Returns:
            The tensor returned by the parent class's ``read``.
        """
        home_dir = pystiche.home() if root is None else root

        should_download = download
        if should_download is None:
            # Skip the download only when a local copy exists and there is no
            # MD5 checksum to verify it against.
            already_present = path.isfile(path.join(home_dir, self.file))
            should_download = self.md5 is not None or not already_present

        if should_download:
            self.download(root=home_dir, overwrite=overwrite)

        return super().read(root=home_dir, **read_image_kwargs)
Пример #2
0
def demo_images():
    """Render ``api/pystiche.demo.rst`` listing the demo images.

    The demo images are downloaded into :func:`pystiche.home`, symlinked into
    ``GRAPHICS / "demo_images"``, and presented as an RST grid table whose image
    column refers to one image substitution definition per image. The output is
    produced by rendering ``api/pystiche.demo.rst.template`` with Jinja2.

    Returns:
        ``(None, None)``.
    """
    cache_path = pathlib.Path(pystiche.home())
    graphics_path = GRAPHICS / "demo_images"
    graphics_path.mkdir(exist_ok=True)
    api_path = HERE / "api"

    images = pystiche.demo.images()
    images.download()

    entries = {}
    for name, image in images:
        entries[name] = (image.file, extract_aspect_ratio(image.read()))
        link = graphics_path / image.file
        # Path.exists() follows symlinks, so a dangling link (e.g. after the cache
        # moved) reports False and symlink_to() would then raise FileExistsError.
        # Remove it so it can be recreated.
        if link.is_symlink() and not link.exists():
            link.unlink()
        if not link.exists():
            link.symlink_to(cache_path / image.file)

    # Width of both table columns: wide enough for the "|name|" substitutions and
    # for the header. default=0 keeps this working if there are no images.
    field_len = max(
        max((len(name) for name in entries), default=0) + 2, len("images"))

    def sep(char):
        # Horizontal border line spanning both columns, e.g. "+------+------+".
        cell = char * (field_len + 2)
        return f"+{cell}+{cell}+"

    def row(name, image):
        # One table row; both cells are left-aligned and padded to field_len.
        # Padding the rendered value itself (rather than `image or name`) keeps
        # the fallback "|name|" cell the same width as all other cells.
        key = f"{name:{field_len}}"
        value = image or f"|{name}|"
        return f"| {key} | {value:{field_len}} |"

    images_table = [
        sep("-"),
        row("name", "image"),
        sep("="),
        *itertools.chain.from_iterable(
            (row(name, f"|{name}|"), sep("-")) for name in sorted(entries)),
    ]

    width = 300
    aliases = [
        f".. |{name}|\n"
        f"  image:: ../graphics/demo_images/{file}\n"
        f"    :width: {width}px\n"
        f"    :height: {width / aspect_ratio:.0f}px\n"
        for name, (file, aspect_ratio) in entries.items()
    ]

    loader = jinja2.FileSystemLoader(searchpath=api_path)
    env = jinja2.Environment(loader=loader)
    template = env.get_template("pystiche.demo.rst.template")

    with open(api_path / "pystiche.demo.rst", "w", encoding="utf-8") as fh:
        fh.write(
            template.render(
                images_table="\n".join(images_table),
                aliases="\n".join(aliases),
            ))

    return None, None
Пример #3
0
 def test_env(self):
     """pystiche.home() returns the directory given by $PYSTICHE_HOME."""
     pystiche_home = os.getenv("PYSTICHE_HOME")
     # TemporaryDirectory removes the directory afterwards so the test does not
     # leave it behind.
     with tempfile.TemporaryDirectory() as tmp_dir:
         os.environ["PYSTICHE_HOME"] = tmp_dir
         try:
             actual = pystiche.home()
             desired = tmp_dir
             assert actual == desired
         finally:
             # Restore the previous environment so other tests are unaffected.
             if pystiche_home is None:
                 os.environ.pop("PYSTICHE_HOME", None)
             else:
                 os.environ["PYSTICHE_HOME"] = pystiche_home
Пример #4
0
    def download(self,
                 root: Optional[str] = None,
                 overwrite: bool = False) -> None:
        r"""Download the image and, if applicable, the guides from their URL. If the
        file already exists and the correct MD5 checksum is known, the existing file
        is verified first. If it checks out, the file is not re-downloaded.

        Args:
            root: Optional root directory for the download if the file is a relative
                path. Defaults to :func:`pystiche.home`.
            overwrite: If the file already exists and either no MD5 checksum is known
                or the checksum does not match, download it again instead of
                raising. Defaults to ``False``.

        Raises:
            FileExistsError: If the file already exists, ``overwrite`` is ``False``,
                and either no MD5 checksum is known or it does not match.
        """
        def _download(file: str) -> None:
            # path.dirname() is empty for a bare file name (e.g. root=""), and
            # os.makedirs("") raises FileNotFoundError.
            dirname = path.dirname(file)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            download_file(self.url, file=file, md5=self.md5)

        if root is None:
            root = pystiche.home()

        if isinstance(self.guides, DownloadableImageCollection):
            self.guides.download(root=root, overwrite=overwrite)

        file = self.file
        if not path.isabs(file):
            file = path.join(root, file)

        if not path.isfile(file):
            _download(file)
            return

        msg_overwrite = "If you want to overwrite it, set overwrite=True."
        # Report the directory the file actually lives in; this differs from root
        # for absolute paths or files in subdirectories.
        location = path.dirname(file)

        if self.md5 is None:
            if overwrite:
                _download(file)
                return
            msg = f"{path.basename(file)} already exists in {location}. {msg_overwrite}"
            raise FileExistsError(msg)

        if not check_md5(file, self.md5):
            if overwrite:
                _download(file)
                return
            msg = (
                f"{path.basename(file)} with a different MD5 hash already exists "
                f"in {location}. {msg_overwrite}")
            raise FileExistsError(msg)
Пример #5
0
 def read(
     self,
     root: Optional[str] = None,
     download: Optional[bool] = None,
     overwrite: bool = False,
     **read_image_kwargs: Any,
 ) -> torch.Tensor:
     r"""Read the image from file, downloading it first when needed.

     Args:
         root: Optional root directory if the file is a relative path.
             Defaults to :func:`pystiche.home`.
         download: If ``True``, downloads the image first. If ``None``, the
             download is skipped only when the file already exists and no MD5
             checksum is known; otherwise it is performed.
         overwrite: Forwarded to :meth:`download`.
         **read_image_kwargs: Optional parameters passed to the parent ``read``.
     """
     root = pystiche.home() if root is None else root

     if download is None:
         exists_locally = path.isfile(path.join(root, self.file))
         unverifiable = self.md5 is None
         download = not (exists_locally and unverifiable)

     if download:
         self.download(root=root, overwrite=overwrite)

     return super().read(root=root, **read_image_kwargs)
Пример #6
0
    def test_file_starting_point(self, mock_execution_with):
        """A starting point given as a file is read and resized to the content image."""
        bird = demo.images()["bird2"]
        expected = bird.read()
        starting_point = pathlib.Path(pystiche.home()) / bird.file
        execution_mock = self._mock_execution(
            mock_execution_with, option="starting_point", value=str(starting_point)
        )

        with exits(), pytest.warns(UserWarning):
            main()

        call_args, _ = execution_mock.call_args
        input_image, perceptual_loss = call_args
        content_size = list(extract_image_size(perceptual_loss.content_image))

        ptu.assert_allclose(input_image, resize(expected, content_size))
Пример #7
0
    def test_home_env(self):
        """pystiche.home() returns $PYSTICHE_HOME and creates that directory."""
        # mkdtemp() followed by rmdir() yields a unique path that does not exist yet.
        tmp_dir = tempfile.mkdtemp()
        os.rmdir(tmp_dir)

        pystiche_home = os.getenv("PYSTICHE_HOME")
        os.environ["PYSTICHE_HOME"] = tmp_dir
        try:
            actual = pystiche.home()
            desired = tmp_dir
            self.assertEqual(actual, desired)
            self.assertTrue(path.isdir(desired))
        finally:
            if pystiche_home is None:
                os.environ.pop("PYSTICHE_HOME", None)
            else:
                os.environ["PYSTICHE_HOME"] = pystiche_home
            # Only clean up if the directory was actually created; an unconditional
            # rmdir() would raise FileNotFoundError and mask the real test failure.
            if path.isdir(tmp_dir):
                os.rmdir(tmp_dir)
Пример #8
0
    def test_file(self, mock_execution_with, option):
        """A content/style image given as a file path matches ``read(size=500)``."""
        bird = demo.images()["bird2"]
        # TODO: make this independent of the default size value
        expected = bird.read(size=500)
        image_path = pathlib.Path(pystiche.home()) / bird.file
        execution_mock = self._mock_execution(
            mock_execution_with, option=option, value=str(image_path)
        )

        with exits():
            cli.main()

        call_args, _ = execution_mock.call_args
        _, perceptual_loss = call_args
        actual = (
            perceptual_loss.content_image
            if option == "content"
            else perceptual_loss.style_image
        )

        ptu.assert_allclose(actual, expected)
Пример #9
0
    def download(self,
                 root: Optional[str] = None,
                 overwrite: bool = False) -> None:
        r"""Download the image and, if applicable, its guides from their URL.

        If the file already exists and an MD5 checksum is known, the existing file is
        checked against it and kept if it matches.

        Args:
            root: Optional root directory for the download if the file is a relative
                path. Defaults to :func:`pystiche.home`.
            overwrite: If the file already exists and either no MD5 checksum is known
                or the checksum does not match, download it again instead of
                raising. Defaults to ``False``.

        Raises:
            FileExistsError: If the file already exists, ``overwrite`` is ``False``,
                and either no MD5 checksum is known or it does not match.
        """
        def _download(file: str) -> None:
            # path.dirname() is empty for a bare file name (e.g. root=""), and
            # os.makedirs("") raises FileNotFoundError.
            dirname = path.dirname(file)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            download_file(self.url, file)

        if root is None:
            root = pystiche.home()

        if isinstance(self.guides, DownloadableImageCollection):
            self.guides.download(root=root, overwrite=overwrite)

        file = self.file
        if not path.isabs(file):
            file = path.join(root, file)

        if not path.isfile(file):
            _download(file)
            return

        msg_overwrite = "If you want to overwrite it, set overwrite=True."
        # Report the directory the file actually lives in; this differs from root
        # for absolute paths or files in subdirectories.
        location = path.dirname(file)

        if self.md5 is None:
            if overwrite:
                _download(file)
                return
            msg = f"{path.basename(file)} already exists in {location}. {msg_overwrite}"
            raise FileExistsError(msg)

        if not check_md5(file, self.md5):
            if overwrite:
                _download(file)
                return
            msg = (
                f"{path.basename(file)} with a different MD5 hash already exists "
                f"in {location}. {msg_overwrite}")
            raise FileExistsError(msg)
Пример #10
0
def patch_home(home, copy=True, mocker=DEFAULT_MOCKER):
    """Point ``PYSTICHE_HOME`` at ``home``, optionally seeding it first.

    Args:
        home: Directory that ``PYSTICHE_HOME`` is patched to.
        copy: If ``True``, the directory tree at the current :func:`pystiche.home`
            is copied into ``home`` before patching.
        mocker: Object whose ``patch.dict`` creates the patch.

    Returns:
        Whatever ``mocker.patch.dict`` returns for ``os.environ``.
    """
    # NOTE(review): distutils (and thus dir_util) is deprecated by PEP 632 and
    # removed in Python 3.12; shutil.copytree(..., dirs_exist_ok=True) is the usual
    # replacement on Python >= 3.8.
    if copy:
        current_home = pystiche.home()
        dir_util.copy_tree(current_home, home)

    patched_values = {"PYSTICHE_HOME": home}
    return mocker.patch.dict(os.environ, values=patched_values)
Пример #11
0
 def test_default(self):
     """pystiche.home() defaults to ``~/.cache/pystiche``."""
     # NOTE(review): relies on PYSTICHE_HOME being unset in the test environment.
     home_dir = pystiche.home()
     default_home = path.expanduser(path.join("~", ".cache", "pystiche"))
     assert home_dir == default_home
Пример #12
0
def watch_pystiche_home():
    """Yield once while ``utils.watch_dir`` watches :func:`pystiche.home`."""
    home_dir = pystiche.home()
    with utils.watch_dir(home_dir):
        yield
Пример #13
0
 def test_home_default(self, makedirs_mock):
     """pystiche.home() defaults to ``~/.cache/pystiche``."""
     # NOTE(review): relies on PYSTICHE_HOME being unset; makedirs_mock is
     # presumably a patched os.makedirs so no directory is created — confirm.
     home_dir = pystiche.home()
     default_home = path.expanduser(path.join("~", ".cache", "pystiche"))
     self.assertEqual(home_dir, default_home)