Example #1
def test_extract_normalized_patches2d_no_overlap(subtests):
    height = 4
    width = 4
    patch_size = 2
    stride = 2

    input = torch.ones(1, 1, height, width).requires_grad_(True)
    input_normalized = torch.ones(1, 1, height, width).requires_grad_(True)
    target = torch.zeros(1, 1, height, width).detach()

    input_patches = pystiche.extract_patches2d(input,
                                               patch_size=patch_size,
                                               stride=stride)
    input_patches_normalized = paper.extract_normalized_patches2d(
        input_normalized, patch_size=patch_size, stride=stride)
    target_patches = pystiche.extract_patches2d(target,
                                                patch_size=patch_size,
                                                stride=stride)

    loss = 0.5 * torch.sum((input_patches - target_patches)**2.0)
    loss.backward()

    loss_normalized = 0.5 * torch.sum(
        (input_patches_normalized - target_patches)**2.0)
    loss_normalized.backward()

    with subtests.test("forward"):
        ptu.assert_allclose(input_patches_normalized, input_patches)

    with subtests.test("backward"):
        ptu.assert_allclose(input_normalized.grad, input.grad)
Example #2
def test_MRFOperator_call_guided():
    patch_size = 2
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.cat(
        (torch.zeros(1, 1, 16, 32), torch.ones(1, 1, 16, 32)), dim=2)
    input_guide = target_guide.flip(2)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1), ))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)

    actual = op(input_image)

    input_enc = encoder(input_image)[:, :, :16, :]
    target_enc = encoder(target_image)[:, :, 16:, :]
    desired = F.mrf_loss(
        pystiche.extract_patches2d(input_enc, patch_size, stride=stride),
        pystiche.extract_patches2d(target_enc, patch_size, stride=stride),
    )
    ptu.assert_allclose(actual, desired)
Example #3
    def test_call_guided(self, encoder):
        patch_size = 2
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        target_guide = torch.cat(
            (torch.zeros(1, 1, 16, 32), torch.ones(1, 1, 16, 32)), dim=2
        )
        input_guide = target_guide.flip(2)

        loss = loss_.MRFLoss(encoder, patch_size, stride=stride)
        loss.set_target_image(target_image, guide=target_guide)
        loss.set_input_guide(input_guide)

        actual = loss(input_image)

        input_enc = encoder(input_image)[:, :, :16, :]
        target_enc = encoder(target_image)[:, :, 16:, :]
        desired = F.mrf_loss(
            pystiche.extract_patches2d(input_enc, patch_size, stride=stride),
            pystiche.extract_patches2d(target_enc, patch_size, stride=stride),
            batched_input=True,
        )
        ptu.assert_allclose(actual, desired)
Example #4
    def test_MRFOperator_enc_to_repr_guided(self):
        class Identity(pystiche.Module):
            def forward(self, image):
                return image

        patch_size = 2
        stride = 2

        op = ops.MRFOperator(SequentialEncoder((Identity(), )),
                             patch_size,
                             stride=stride)

        with self.subTest(enc="constant"):
            enc = torch.ones(1, 4, 8, 8)

            actual = op.enc_to_repr(enc, is_guided=True)
            desired = torch.ones(0, 4, stride, stride)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="spatial_mix"):
            constant = torch.ones(1, 4, 4, 8)
            varying = torch.rand(1, 4, 4, 8)
            enc = torch.cat((constant, varying), dim=2)

            actual = op.enc_to_repr(enc, is_guided=True)
            desired = pystiche.extract_patches2d(varying,
                                                 patch_size,
                                                 stride=stride)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="channel_mix"):
            constant = torch.ones(1, 2, 8, 8)
            varying = torch.rand(1, 2, 8, 8)
            enc = torch.cat((constant, varying), dim=1)

            actual = op.enc_to_repr(enc, is_guided=True)
            desired = pystiche.extract_patches2d(enc,
                                                 patch_size,
                                                 stride=stride)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="varying"):
            enc = torch.rand(1, 4, 8, 8)

            actual = op.enc_to_repr(enc, is_guided=True)
            desired = pystiche.extract_patches2d(enc,
                                                 patch_size,
                                                 stride=stride)
            self.assertTensorAlmostEqual(actual, desired)
Example #5
    def test_enc_to_repr_guided_varying(self, mrf_op):
        enc_ = torch.rand(1, 4, 8, 8)

        actual = mrf_op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(
            enc_, mrf_op.patch_size[0], stride=mrf_op.stride[0]
        )
        ptu.assert_allclose(actual, desired)
Example #6
    def enc_to_repr(self, enc: torch.Tensor, is_guided: bool) -> torch.Tensor:
        if self.normalize_patches_grad:
            repr = extract_normalized_patches2d(enc, self.patch_size, self.stride)
        else:
            repr = pystiche.extract_patches2d(enc, self.patch_size, self.stride)
        if not is_guided:
            return repr

        return self._guide_repr(repr)
Example #7
    def test_enc_to_repr_guided_channel_mix(self, mrf_op):
        constant = torch.ones(1, 2, 8, 8)
        varying = torch.rand(1, 2, 8, 8)
        enc_ = torch.cat((constant, varying), dim=1)

        actual = mrf_op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(
            enc_, mrf_op.patch_size[0], stride=mrf_op.stride[0]
        )
        ptu.assert_allclose(actual, desired)
Example #8
    def test_call(self):
        patch_size = 3
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(target_image)

        actual = op(input_image)
        desired = F.mrf_loss(
            pystiche.extract_patches2d(encoder(input_image), patch_size, stride=stride),
            pystiche.extract_patches2d(
                encoder(target_image), patch_size, stride=stride
            ),
        )
        ptu.assert_allclose(actual, desired)
Example #9
    def test_MRFOperator_call(self):
        patch_size = 3
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(target_image)

        actual = op(input_image)
        desired = F.patch_matching_loss(
            pystiche.extract_patches2d(encoder(input_image), patch_size, stride=stride),
            pystiche.extract_patches2d(
                encoder(target_image), patch_size, stride=stride
            ),
        )
        self.assertFloatAlmostEqual(actual, desired)
Example #10
    def test_call(self, encoder):
        patch_size = 3
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)

        loss = loss_.MRFLoss(encoder, patch_size, stride=stride)
        loss.set_target_image(target_image)

        actual = loss(input_image)
        desired = F.mrf_loss(
            pystiche.extract_patches2d(encoder(input_image), patch_size, stride=stride),
            pystiche.extract_patches2d(
                encoder(target_image), patch_size, stride=stride
            ),
            batched_input=True,
        )
        ptu.assert_allclose(actual, desired)
Example #11
    def test_MRFOperator_target_image_to_repr(self):
        patch_size = 3
        stride = 2
        scale_step_width = 10e-2
        rotation_step_width = 30.0

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1), ))

        op = ops.MRFOperator(
            encoder,
            patch_size,
            stride=stride,
            num_scale_steps=1,
            scale_step_width=scale_step_width,
            num_rotation_steps=1,
            rotation_step_width=rotation_step_width,
        )
        op.set_target_image(image)

        actual = op.target_repr

        reprs = []
        factors = (1.0 - scale_step_width, 1.0, 1.0 + scale_step_width)
        angles = (-rotation_step_width, 0.0, rotation_step_width)
        for factor, angle in itertools.product(factors, angles):
            transformed_image = transform_motif_affinely(image,
                                                         rotation_angle=angle,
                                                         scaling_factor=factor)
            enc = encoder(transformed_image)
            repr = pystiche.extract_patches2d(enc, patch_size, stride)
            reprs.append(repr)
        desired = torch.cat(reprs)

        self.assertTensorAlmostEqual(actual, desired)
Example #12
    def enc_to_repr(self, enc: torch.Tensor, is_guided: bool) -> torch.Tensor:
        repr = pystiche.extract_patches2d(enc, self.patch_size, self.stride)
        if not is_guided:
            return repr

        return self._guide_repr(repr)
Example #13
def extract_normalized_patches2d(
    input: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
) -> torch.Tensor:
    for dim, size, step in zip(range(2, input.dim()), patch_size, stride):
        input = normalize_unfold_grad(input, dim, size, step)
    return pystiche.extract_patches2d(input, patch_size, stride)
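
Both variants of extract_normalized_patches2d shown here delegate the actual gradient normalization to normalize_unfold_grad, which is not included on this page. As a rough sketch of the underlying idea (a hypothetical stand-in, not the pystiche_papers implementation), an identity forward pass whose backward divides each element's gradient by the number of sliding windows covering it along one dimension could look like this:

import torch


class _NormalizeUnfoldGrad(torch.autograd.Function):
    # Hypothetical identity op: the forward pass returns the input unchanged,
    # the backward pass divides the gradient by the number of sliding windows
    # of length ``size`` (step ``step``) that cover each index along ``dim``.
    @staticmethod
    def forward(ctx, input, dim, size, step):
        length = input.size(dim)
        counts = torch.zeros(length, dtype=input.dtype, device=input.device)
        for start in range(0, length - size + 1, step):
            counts[start : start + size] += 1.0
        # Indices covered by no window keep their gradient unchanged.
        counts = counts.clamp(min=1.0)
        shape = [1] * input.dim()
        shape[dim] = length
        ctx.save_for_backward(counts.view(shape))
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        (counts,) = ctx.saved_tensors
        # Only ``input`` receives a gradient; ``dim``, ``size``, and ``step``
        # are plain integers.
        return grad_output / counts, None, None, None


def normalize_unfold_grad(
    input: torch.Tensor, dim: int, size: int, step: int
) -> torch.Tensor:
    return _NormalizeUnfoldGrad.apply(input, dim, size, step)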
Example #14
def extract_normalized_patches2d(
    input: torch.Tensor,
    patch_size: Union[int, Sequence[int]],
    stride: Union[int, Sequence[int]],
) -> torch.Tensor:
    r"""Extract 2-dimensional patches from the input with normalized gradient.

    If ``stride >= patch_size``, this behaves just like
    :func:`pystiche.extract_patches2d`. Otherwise, the gradient of the input is
    normalized such that every value is divided by the number of patches it appears in.

    Examples:
        >>> import torch
        >>> import pystiche
        >>> input = torch.ones(1, 1, 4, 4).requires_grad_(True)
        >>> target = torch.zeros(1, 1, 4, 4).detach()
        >>> # without normalized gradient
        >>> input_patches = pystiche.extract_patches2d(
        ...     input, patch_size=2, stride=1
        ... )
        >>> target_patches = pystiche.extract_patches2d(
        ...     target, patch_size=2, stride=1
        ... )
        >>> loss = 0.5 * torch.sum((input_patches - target_patches) ** 2.0)
        >>> loss.backward()
        >>> input.grad
        tensor([[[[1., 2., 2., 1.],
                  [2., 4., 4., 2.],
                  [2., 4., 4., 2.],
                  [1., 2., 2., 1.]]]])

        >>> import torch
        >>> import pystiche
        >>> import pystiche_papers.li_wand_2016 as paper
        >>> input = torch.ones(1, 1, 4, 4).requires_grad_(True)
        >>> target = torch.zeros(1, 1, 4, 4).detach()
        >>> # with normalized gradient
        >>> input_patches = paper.extract_normalized_patches2d(
        ...     input, patch_size=2, stride=1
        ... )
        >>> target_patches = pystiche.extract_patches2d(
        ...     target, patch_size=2, stride=1
        ... )
        >>> loss = 0.5 * torch.sum((input_patches - target_patches) ** 2.0)
        >>> loss.backward()
        >>> input.grad
        tensor([[[[1., 1., 1., 1.],
                  [1., 1., 1., 1.],
                  [1., 1., 1., 1.],
                  [1., 1., 1., 1.]]]])

    Args:
        input: Input tensor of shape :math:`B \times C \times H \times W`
        patch_size: Patch size
        stride: Stride
    """
    patch_size = misc.to_2d_arg(patch_size)
    stride = misc.to_2d_arg(stride)
    for dim, size, step in zip(range(2, input.dim()), patch_size, stride):
        input = normalize_unfold_grad(input, dim, size, step)
    return pystiche.extract_patches2d(input, patch_size, stride)
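
As a quick check of the docstring example (a hypothetical verification snippet, not part of the library), the un-normalized gradient shown above is exactly the number of patch_size=2, stride=1 patches each pixel belongs to, which can be reproduced with torch.nn.functional.unfold and torch.nn.functional.fold:

import torch
import torch.nn.functional as F

# For a 4x4 image, count how many 2x2 patches with stride 1 contain each
# pixel: unfold an all-ones image into patches and fold (sum) them back.
ones = torch.ones(1, 1, 4, 4)
patches = F.unfold(ones, kernel_size=2, stride=1)  # shape: (1, 4, 9)
counts = F.fold(patches, output_size=(4, 4), kernel_size=2, stride=1)
print(counts)
# tensor([[[[1., 2., 2., 1.],
#           [2., 4., 4., 2.],
#           [2., 4., 4., 2.],
#           [1., 2., 2., 1.]]]])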