Beispiel #1
0
def evaluate_steady_state(images_root, results_root, device):
    """Record loss and SSIM-score trajectories of long NCR runs.

    The NPR general proxy image is reconstructed twice with ``perform_ncr``
    (``level_steps=(0, 200_000)``): once with the squared-error (SE) content
    loss and once with the SSIM content loss. The recorded content loss and
    SSIM score are written to ``<results_root>/steady_state/raw/se.csv`` and
    ``.../ssim.csv`` respectively.

    Args:
        images_root: Directory containing the proxy target image.
        results_root: Root directory for the result CSV files.
        device: Device the target image is moved to.
    """
    proxy_file = path.join(images_root, get_npr_general_proxy_file())
    total_steps = 200_000

    target_image = read_image(proxy_file).to(device)
    schedule = (0, total_steps)
    log_at = intgeomspace(1, total_steps, num=1000)

    for use_ssim in (False, True):
        with record_nst(quiet=True) as recorder:
            perform_ncr(
                target_image,
                level_steps=schedule,
                quiet=False,
                print_steps=log_at,
                ssim_loss=use_ssim,
                diagnose_ssim_score=True,
            )
            records = recorder.extract()

        loss_type = "SSIM" if use_ssim else "SE"
        renamed_columns = {
            f"Content loss ({loss_type})": "loss",
            "SSIM score": "ssim_score",
        }
        # Keep only the two tracked quantities and drop rows where both are
        # missing.
        records = records.rename(columns=renamed_columns)[["ssim_score", "loss"]]
        records = records.dropna(axis="index", how="all")

        csv_file = path.join(
            results_root, "steady_state", "raw", f"{loss_type.lower()}.csv"
        )
        df_to_csv(records, csv_file, index=False)
Beispiel #2
0
    def __init__(
        self,
        style_image: Optional[Union[str, torch.Tensor]] = None,
        model: Optional[nn.Module] = None,
        backbone: str = "vgg16",
        content_layer: str = "relu2_2",
        content_weight: float = 1e5,
        style_layers: Union[Sequence[str], str] = ("relu1_2", "relu2_2",
                                                   "relu3_3", "relu4_3"),
        style_weight: float = 1e10,
        optimizer: Union[Type[torch.optim.Optimizer],
                         torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        scheduler: Optional[Union[Type[_LRScheduler], str,
                                  _LRScheduler]] = None,
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
        learning_rate: float = 1e-3,
        serializer: Optional[Union[Serializer, Mapping[str,
                                                       Serializer]]] = None,
    ):
        """Set up the style-transfer task around a perceptual loss.

        Args:
            style_image: Style image as a tensor, or a path read with
                ``read_image``. ``None`` selects ``self.default_style_image()``.
            model: Network to train. ``None`` selects
                ``pystiche.demo.transformer()``.
            backbone: Backbone name passed to ``self._get_perceptual_loss``.
            content_layer: Layer used for the content loss.
            content_weight: Weight of the content loss.
            style_layers: A single layer name, or a list/tuple of layer names,
                used for the style loss.
            style_weight: Weight of the style loss.
            optimizer: Forwarded to the base class.
            optimizer_kwargs: Forwarded to the base class.
            scheduler: Forwarded to the base class.
            scheduler_kwargs: Forwarded to the base class.
            learning_rate: Forwarded to the base class.
            serializer: Forwarded to the base class.

        Raises:
            ModuleNotFoundError: If ``_IMAGE_STLYE_TRANSFER`` is falsy, i.e.
                the optional image style transfer dependencies are missing.
        """
        if not _IMAGE_STLYE_TRANSFER:
            raise ModuleNotFoundError(
                "Please, pip install -e '.[image_style_transfer]'")

        # NOTE(review): presumably records the constructor arguments from the
        # calling frame, so it is kept ahead of every reassignment below.
        self.save_hyperparameters(ignore="style_image")

        if isinstance(style_image, str):
            style_image = read_image(style_image)
        elif style_image is None:
            style_image = self.default_style_image()

        model = pystiche.demo.transformer() if model is None else model

        # A bare layer name is shorthand for a one-element tuple.
        if isinstance(style_layers, (List, Tuple)):
            layers = style_layers
        else:
            layers = (style_layers, )

        loss = self._get_perceptual_loss(
            backbone=backbone,
            content_layer=content_layer,
            content_weight=content_weight,
            style_layers=layers,
            style_weight=style_weight,
        )
        loss.set_style_image(style_image)

        super().__init__(
            model=model,
            loss_fn=loss,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            learning_rate=learning_rate,
            serializer=serializer,
        )

        self.perceptual_loss = loss
Beispiel #3
0
def read_local_or_builtin_image(root, name, builtin_images, **read_image_kwargs):
    """Read an image from disk if it exists, otherwise from a builtin collection.

    ``name`` is first treated as a file name. Absolute paths are used as-is.
    Relative paths are resolved against ``root`` when ``root`` is given. If
    that file exists, it is read with ``image.read_image``. Otherwise
    ``builtin_images[name]`` is read with ``root`` as its root directory.

    Args:
        root: Directory that relative file names are resolved against. May be
            ``None``, in which case relative names are used unchanged.
        name: Absolute or relative file name, or a key of ``builtin_images``.
        builtin_images: Mapping from name to an object providing
            ``read(root, **kwargs)``; used as the fallback.
        **read_image_kwargs: Forwarded to whichever reader is used.

    Returns:
        Whatever the chosen reader returns.

    Raises:
        Whatever ``builtin_images[name]`` raises for an unknown name
        (``KeyError`` for a plain ``dict``) when no local file exists.
    """
    file = name
    # ``path.isabs`` decides whether ``root`` must be prepended;
    # ``path.abspath`` would always return a truthy string here.
    if not path.isabs(file) and root is not None:
        file = path.join(root, name)
    if path.exists(file):
        return image.read_image(file, **read_image_kwargs)

    return builtin_images[name].read(root, **read_image_kwargs)
Beispiel #4
0
def test_download_file(tmpdir, test_image_url, test_image):
    """Downloading the test image with a matching MD5 yields the test image."""
    target = path.join(tmpdir, path.basename(test_image_url))
    misc.download_file(
        test_image_url, target, md5="a858d33c424eaac1322cf3cab6d3d568"
    )

    ptu.assert_allclose(read_image(target), test_image)
Beispiel #5
0
    def test_download(self, tmpdir, test_image_url, test_image):
        """``DownloadableImage.download`` stores the image below ``tmpdir``."""
        downloadable = data.DownloadableImage(test_image_url)
        downloadable.download(tmpdir)

        downloaded_file = path.join(tmpdir, downloadable.file)
        assert path.exists(downloaded_file)

        ptu.assert_allclose(read_image(downloaded_file), test_image)
Beispiel #6
0
def test_write_image(tmpdir):
    """An image written with ``write_image`` reads back almost unchanged."""
    torch.manual_seed(0)
    original = torch.rand(3, 100, 100)

    png_file = path.join(tmpdir, "tmp_image.png")
    image_.write_image(original, png_file)

    round_tripped = image_.read_image(file=png_file)
    pyimagetest.assert_images_almost_equal(round_tripped, original)
Beispiel #7
0
def test_read_image_resize_scalar(test_image_file, test_image_pil):
    """A scalar ``size`` acts as an edge size that keeps the aspect ratio."""
    edge_size = 200
    pil_shape = (test_image_pil.height, test_image_pil.width)
    image_size = image_.edge_to_image_size(
        edge_size, image_.calculate_aspect_ratio(pil_shape)
    )

    actual = image_.read_image(test_image_file, size=edge_size)
    # PIL's resize takes (width, height), hence the reversal.
    expected = test_image_pil.resize(image_size[::-1])
    pyimagetest.assert_images_almost_equal(actual, expected)
Beispiel #8
0
    def test_download_guides(self, tmpdir, test_image_url, test_image):
        """Guides attached to an image are downloaded together with it."""
        extension = path.splitext(test_image_url)[1]
        guide = data.DownloadableImage(test_image_url, file="guide" + extension)
        guide_collection = data.DownloadableImageCollection({"guide": guide})

        main_image = data.DownloadableImage(test_image_url, guides=guide_collection)
        main_image.download(tmpdir)

        guide_file = path.join(tmpdir, guide.file)
        ptu.assert_allclose(read_image(guide_file), test_image)
Beispiel #9
0
    def test_download_overwrite(self, tmpdir, test_image_url, test_image):
        """``overwrite=True`` replaces an existing file, with and without an MD5."""
        image = data.DownloadableImage(test_image_url)
        file = path.join(tmpdir, image.file)

        def replace_placeholder_and_check():
            # An empty placeholder file that the download has to overwrite.
            with open(file, "wb"):
                pass
            image.download(tmpdir, overwrite=True)
            ptu.assert_allclose(read_image(file), test_image)

        replace_placeholder_and_check()

        image.md5 = "a858d33c424eaac1322cf3cab6d3d568"
        replace_placeholder_and_check()
Beispiel #10
0
def test_write_guides(tmpdir):
    """``write_guides`` stores each guide as ``<region>.png`` in the directory."""
    guides, _ = get_test_guides()
    image.write_guides(guides, tmpdir)

    def read_region(region):
        return image.read_image(file=path.join(tmpdir, f"{region}.png"))

    read_back = {region: read_region(region) for region in guides.keys()}
    ptu.assert_allclose(read_back, guides)
Beispiel #11
0
    def read(
        self,
        root: Optional[str] = None,
        **read_image_kwargs: Any,
    ) -> torch.Tensor:
        """Read the image file and apply ``self.transform`` if it is set.

        Args:
            root: Directory a relative ``self.file`` is resolved against.
                Ignored for absolute paths and when ``None``.
            **read_image_kwargs: Forwarded to ``read_image``.

        Returns:
            The image read from disk, passed through ``self.transform`` when
            that is not ``None``.
        """
        file = self.file
        if not path.isabs(file) and root is not None:
            file = path.join(root, file)

        image = read_image(file, **read_image_kwargs)
        if self.transform is not None:
            image = self.transform(image)
        return image
Beispiel #12
0
    def __init__(
        self,
        style_image: Optional[Union[str, torch.Tensor]] = None,
        model: Optional[nn.Module] = None,
        backbone: str = "vgg16",
        content_layer: str = "relu2_2",
        content_weight: float = 1e5,
        style_layers: Union[Sequence[str], str] = [
            "relu1_2", "relu2_2", "relu3_3", "relu4_3"
        ],
        style_weight: float = 1e10,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        learning_rate: Optional[float] = None,
    ):
        """Set up the style-transfer task around a perceptual loss.

        Args:
            style_image: Style image as a tensor, or a path read with
                ``read_image``. ``None`` selects ``self.default_style_image()``.
            model: Network to train. ``None`` selects
                ``pystiche.demo.transformer()``.
            backbone: Backbone name passed to ``self._get_perceptual_loss``.
            content_layer: Layer used for the content loss.
            content_weight: Weight of the content loss.
            style_layers: A single layer name, or a list/tuple of layer names,
                used for the style loss.
            style_weight: Weight of the style loss.
            optimizer: Forwarded to the base class.
            lr_scheduler: Forwarded to the base class.
            learning_rate: Forwarded to the base class.
        """
        # NOTE(review): presumably records the constructor arguments from the
        # calling frame, so it is kept ahead of every reassignment below.
        self.save_hyperparameters(ignore="style_image")

        if isinstance(style_image, str):
            style_image = read_image(style_image)
        elif style_image is None:
            style_image = self.default_style_image()

        model = pystiche.demo.transformer() if model is None else model

        # A bare layer name is shorthand for a one-element tuple. The default
        # list is shared between calls; it is only read here, never mutated.
        if isinstance(style_layers, (List, Tuple)):
            layers = style_layers
        else:
            layers = (style_layers, )

        loss = self._get_perceptual_loss(
            backbone=backbone,
            content_layer=content_layer,
            content_weight=content_weight,
            style_layers=layers,
            style_weight=style_weight,
        )
        loss.set_style_image(style_image)

        super().__init__(
            model=model,
            loss_fn=loss,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
        )

        self.perceptual_loss = loss
Beispiel #13
0
def benchmark_ncr(images_root, results_root, device):
    """Benchmark NCR with the SE loss and with SSIM losses of several weight ratios.

    For every NPR general target image, ``perform_ncr`` runs once per loss
    variation and seed. The output and the target are passed through the same
    evaluation transform and compared with ``SimplifiedMSSIM``. All scores are
    written to ``<results_root>/ncr_benchmark/raw.csv``.

    Args:
        images_root: Directory containing the target images.
        results_root: Root directory for the result CSV file.
        device: Device images and the SSIM module are moved to.
    """
    target_files = get_npr_general_files()
    weight_ratios = (0.0, 3.0, 9.0, np.inf)
    num_seeds = 5

    # The SE loss (ssim_loss=False) has no component weight ratio.
    loss_variations = [(False, None)]
    loss_variations.extend((True, ratio) for ratio in weight_ratios)
    seeds = np.arange(num_seeds)

    score_fn = SimplifiedMSSIM().to(device)
    records = []
    for target_file in target_files:
        target_name, _ = path.splitext(path.basename(target_file))
        target_image = read_image(path.join(images_root, target_file)).to(device)

        eval_transform = get_eval_transform(target_image)
        target_eval = eval_transform(target_image)

        for (ssim_loss, ratio), seed in itertools.product(loss_variations, seeds):
            output_image = perform_ncr(
                target_image,
                seed=seed,
                ssim_loss=ssim_loss,
                ssim_component_weight_ratio=ratio,
            )
            score = score_fn(eval_transform(output_image), target_eval).cpu().item()
            records.append((target_name, ssim_loss, ratio, seed, score))

    columns = ("name", "ssim_loss", "ssim_component_weight_ratio", "seed", "ssim_score")
    df = pd.DataFrame.from_records(records, columns=columns)
    df_to_csv(df, path.join(results_root, "ncr_benchmark", "raw.csv"))
Beispiel #14
0
    def test_read(self, tmpdir):
        """``LocalImageCollection.read`` returns every image keyed by its name."""

        def write_random_images(root):
            torch.manual_seed(0)
            written = {}
            for name in map(str, range(3)):
                random_image = torch.rand(1, 3, 32, 32)
                target = path.join(root, f"{name}.png")
                write_image(random_image, target)
                written[name] = target
            return written

        files = write_random_images(tmpdir)
        local_images = {name: data.LocalImage(file) for name, file in files.items()}
        collection = data.LocalImageCollection(local_images)

        actual = collection.read()
        expected = {name: read_image(file) for name, file in files.items()}
        ptu.assert_allclose(actual, expected)
Beispiel #15
0
def evaluate_ssim_window(images_root, results_root, device):
    """Evaluate NCR quality for different SSIM window filters.

    Every combination of window type (``"gauss"``/``"box"``), output shape
    (``"same"``/``"valid"``) and radius 1..9 is used as ``image_filter`` for
    ``perform_ncr``, with 5 seeds each. The output is scored against the
    proxy target with ``SimplifiedMSSIM`` after the evaluation transform.
    Results are written to ``<results_root>/ssim_window/raw.csv``.

    Args:
        images_root: Directory containing the proxy target image.
        results_root: Root directory for the result CSV file.
        device: Device images and the SSIM module are moved to.
    """
    target_file = path.join(images_root, get_npr_general_proxy_file())
    window_types = ("gauss", "box")
    output_shapes = ("same", "valid")
    radii = range(1, 10)
    num_seeds = 5

    target_image = read_image(target_file).to(device)
    eval_transform = get_eval_transform(target_image)
    target_eval = eval_transform(target_image)

    def make_filter(window_type, output_shape, radius):
        filter_kwargs = {"output_shape": output_shape, "padding_mode": "replicate"}
        if window_type != "gauss":  # any other window_type means "box"
            return BoxFilter(radius=radius, **filter_kwargs)
        return GaussFilter(radius=radius, std=radius / 3.0, **filter_kwargs)

    score_fn = SimplifiedMSSIM().to(device)
    records = []
    for filter_params in itertools.product(window_types, output_shapes, radii):
        image_filter = make_filter(*filter_params)
        for seed in range(num_seeds):
            output_image = perform_ncr(
                target_image, seed=seed, image_filter=image_filter
            )
            score = score_fn(eval_transform(output_image), target_eval).cpu().item()
            records.append((*filter_params, seed, score))

    columns = ("window_type", "output_shape", "radius", "seed", "ssim_score")
    df = pd.DataFrame.from_records(records, columns=columns)
    df_to_csv(df, path.join(results_root, "ssim_window", "raw.csv"))
    def test_LocalImageCollection_read(self):
        """``LocalImageCollection.read`` returns every image keyed by its name."""

        def write_random_images(root):
            torch.manual_seed(0)
            written = {}
            for name in map(str, range(3)):
                random_image = torch.rand(1, 3, 32, 32)
                target = path.join(root, f"{name}.png")
                write_image(random_image, target)
                written[name] = target
            return written

        with get_tmp_dir() as root:
            files = write_random_images(root)
            local_images = {
                name: data.LocalImage(file) for name, file in files.items()
            }
            collection = data.LocalImageCollection(local_images)

            actual = collection.read()
            expected = {name: read_image(file) for name, file in files.items()}
            self.assertTensorDictAlmostEqual(actual, expected)
Beispiel #17
0
    def read(
        self,
        root: Optional[str] = None,
        **read_image_kwargs: Any,
    ) -> torch.Tensor:
        r"""Load the image with :func:`pystiche.image.read_image`, then apply
        :attr:`.transform` if one is set.

        Args:
            root: Directory that a relative :attr:`.file` is resolved against.
                Ignored for absolute paths or when ``None``.
            **read_image_kwargs: Additional keyword arguments for
                :func:`pystiche.image.read_image`.

        Returns:
            The loaded image, transformed if :attr:`.transform` is not ``None``.
        """
        file = self.file
        needs_root = not path.isabs(file) and root is not None
        if needs_root:
            file = path.join(root, file)

        image = read_image(file, **read_image_kwargs)
        return image if self.transform is None else self.transform(image)
Beispiel #18
0
def read_images():
    """Read the content, style and stylized images via relative file names.

    Returns:
        The tuple ``(content_image, style_image, stylized_image)``.
    """
    files = ("content.jpg", "style.jpg", "stylized.jpg")
    return tuple(read_image(file) for file in files)
Beispiel #19
0
def test_read_image_resize(test_image_file, test_image_pil):
    """An explicit two-element ``size`` resizes the image to exactly that size."""
    height, width = 200, 300
    actual = image_.read_image(test_image_file, size=(height, width))
    # PIL's resize takes (width, height).
    expected = test_image_pil.resize((width, height))
    pyimagetest.assert_images_almost_equal(actual, expected)
Beispiel #20
0
def test_read_image_resize_other(test_image_file):
    """A string ``size`` is rejected with ``TypeError``."""
    bad_size = "invalid_size"
    with pytest.raises(TypeError):
        image_.read_image(test_image_file, size=bad_size)
Beispiel #21
0
 def importer(file: str) -> torch.Tensor:
     """Read ``file`` as an image tensor without a batch dimension added."""
     loaded = read_image(file, make_batched=False)
     return loaded
Beispiel #22
0
def test_read_image(test_image_file, test_image):
    """``read_image`` returns a batched image that matches the test image."""
    loaded = image_.read_image(test_image_file)
    assert image_.is_batched_image(loaded)
    pyimagetest.assert_images_almost_equal(loaded, test_image)
Beispiel #23
0
 def process_image(file):
     """Return ``(name, image)`` for ``file``.

     ``name`` is the base name of ``file`` without its extension. The image is
     read from ``images_root`` and moved to ``device``; both names are
     resolved outside this function.
     """
     stem, _ = path.splitext(path.basename(file))
     loaded = read_image(path.join(images_root, file)).to(device)
     return stem, loaded