Ejemplo n.º 1
0
    def mrf_op(self):
        """Build an ``MRFOperator`` that wraps an identity encoder.

        Patch size and stride are both 2, so the patches tile the encoding
        without overlapping.
        """
        patch_size = stride = 2
        identity_encoder = enc.SequentialEncoder((self.Identity(),))
        return ops.MRFOperator(identity_encoder, patch_size, stride=stride)
Ejemplo n.º 2
0
def test_MRFOperator_call_guided():
    """A guided call compares only the regions selected by the guides.

    The target guide marks the lower half of the target image. The input
    guide (the target guide flipped vertically) marks the upper half of the
    input image. The loss is therefore expected to equal the plain MRF loss
    between the patches of those two halves.
    """
    patch_size = stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    # Rows 0..15 are 0 and rows 16..31 are 1.
    upper_half = torch.zeros(1, 1, 16, 32)
    lower_half = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((upper_half, lower_half), dim=2)
    # Flipping along the height axis selects the upper half instead.
    input_guide = target_guide.flip(2)

    # Created after the random images so the RNG stream matches the original.
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)

    actual = op(input_image)

    def to_patches(x):
        return pystiche.extract_patches2d(x, patch_size, stride=stride)

    input_region = encoder(input_image)[:, :, :16, :]
    target_region = encoder(target_image)[:, :, 16:, :]
    desired = F.mrf_loss(to_patches(input_region), to_patches(target_region))
    ptu.assert_allclose(actual, desired)
Ejemplo n.º 3
0
    def test_MRFOperator_enc_to_repr_guided(self):
        """Check guided ``enc_to_repr`` on constant, mixed and varying input.

        The expectations encoded here are:
        - an entirely constant encoding yields no patches;
        - a spatially mixed encoding yields only the varying region's patches;
        - if only some channels are constant, all patches are kept;
        - a fully varying encoding yields all of its patches.
        """

        class Identity(pystiche.Module):
            def forward(self, image):
                return image

        patch_size = 2
        stride = 2

        def extract(x):
            return pystiche.extract_patches2d(x, patch_size, stride=stride)

        op = ops.MRFOperator(SequentialEncoder((Identity(),)),
                             patch_size,
                             stride=stride)

        # ``encoding`` (not ``enc``) avoids shadowing the ``enc`` module alias.
        with self.subTest(enc="constant"):
            encoding = torch.ones(1, 4, 8, 8)

            actual = op.enc_to_repr(encoding, is_guided=True)
            # Zero patches. The trailing dims are the patch extent, which is
            # ``patch_size``, not ``stride`` (they only coincide here).
            desired = torch.ones(0, 4, patch_size, patch_size)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="spatial_mix"):
            constant = torch.ones(1, 4, 4, 8)
            varying = torch.rand(1, 4, 4, 8)
            encoding = torch.cat((constant, varying), dim=2)

            actual = op.enc_to_repr(encoding, is_guided=True)
            desired = extract(varying)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="channel_mix"):
            constant = torch.ones(1, 2, 8, 8)
            varying = torch.rand(1, 2, 8, 8)
            encoding = torch.cat((constant, varying), dim=1)

            actual = op.enc_to_repr(encoding, is_guided=True)
            desired = extract(encoding)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="varying"):
            encoding = torch.rand(1, 4, 8, 8)

            actual = op.enc_to_repr(encoding, is_guided=True)
            desired = extract(encoding)
            self.assertTensorAlmostEqual(actual, desired)
Ejemplo n.º 4
0
def test_MRFOperator_set_target_guide_without_recalc():
    """``set_target_guide(..., recalc_repr=False)`` must keep ``target_repr``."""
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    # Named ``target_repr`` rather than ``repr`` to avoid shadowing the builtin.
    target_repr = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1), ))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.register_buffer("target_repr", target_repr)
    # register_buffer stores this very tensor, so comparing against the same
    # object could never catch an in-place change. Snapshot the values first.
    desired = target_repr.clone()

    op.set_target_guide(guide, recalc_repr=False)

    actual = op.target_repr
    ptu.assert_allclose(actual, desired)
Ejemplo n.º 5
0
    def test_set_target_guide_without_recalc(self):
        """``recalc_repr=False`` leaves the existing target representation intact."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)
        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(image)
        # Copy so that an in-place modification of the buffer would be caught.
        repr_before = op.target_repr.clone()

        op.set_target_guide(guide, recalc_repr=False)

        ptu.assert_allclose(op.target_repr, repr_before)
Ejemplo n.º 6
0
    def test_call(self):
        """An unguided call equals ``F.mrf_loss`` on the encoded images' patches."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(target_image)
        actual = op(input_image)

        input_patches, target_patches = (
            pystiche.extract_patches2d(encoder(image), patch_size, stride=stride)
            for image in (input_image, target_image)
        )
        desired = F.mrf_loss(input_patches, target_patches)
        ptu.assert_allclose(actual, desired)
Ejemplo n.º 7
0
    def test_MRFOperator_call(self):
        """An unguided call equals ``F.patch_matching_loss`` on encoded patches."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(target_image)
        actual = op(input_image)

        def encoded_patches(image):
            return pystiche.extract_patches2d(
                encoder(image), patch_size, stride=stride
            )

        desired = F.patch_matching_loss(
            encoded_patches(input_image), encoded_patches(target_image)
        )
        self.assertFloatAlmostEqual(actual, desired)
Ejemplo n.º 8
0
def test_MRFOperator_set_target_guide():
    """Setting a target guide stores it and keeps the previously set target image."""
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(image)
    assert not op.has_target_guide

    op.set_target_guide(guide)
    assert op.has_target_guide

    # Guide first, then image, in the same order as the explicit checks.
    for attribute, expected in (("target_guide", guide), ("target_image", image)):
        ptu.assert_allclose(getattr(op, attribute), expected)
Ejemplo n.º 9
0
    def test_MRFOperator_set_target_guide(self):
        """Setting a target guide stores it and keeps the target image."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(image)
        self.assertFalse(op.has_target_guide)

        op.set_target_guide(guide)
        self.assertTrue(op.has_target_guide)

        checks = (("target_guide", guide), ("target_image", image))
        for attribute, expected in checks:
            self.assertTensorAlmostEqual(getattr(op, attribute), expected)
Ejemplo n.º 10
0
    def test_MRFOperator_target_image_to_repr(self):
        """The target representation stacks patches of transformed targets.

        With one scale step and one rotation step, the test expects the
        operator to use 3 scaling factors x 3 rotation angles = 9 affinely
        transformed copies of the target image. It then expects their encoded
        patches to be concatenated along the first dimension.
        """
        patch_size = 3
        stride = 2
        # Same value as the original ``10e-2``, written unambiguously.
        scale_step_width = 0.1
        rotation_step_width = 30.0

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1), ))

        op = ops.MRFOperator(
            encoder,
            patch_size,
            stride=stride,
            num_scale_steps=1,
            scale_step_width=scale_step_width,
            num_rotation_steps=1,
            rotation_step_width=rotation_step_width,
        )
        op.set_target_image(image)

        actual = op.target_repr

        def transformed_patches(factor, angle):
            # Named ``patches`` rather than ``repr`` to avoid shadowing the builtin.
            transformed_image = transform_motif_affinely(image,
                                                         rotation_angle=angle,
                                                         scaling_factor=factor)
            return pystiche.extract_patches2d(encoder(transformed_image),
                                              patch_size,
                                              stride=stride)

        factors = (1.0 - scale_step_width, 1.0, 1.0 + scale_step_width)
        angles = (-rotation_step_width, 0.0, rotation_step_width)
        # Assumes the operator enumerates factors in the outer loop and angles
        # in the inner loop. The concatenation order must match that.
        desired = torch.cat([
            transformed_patches(factor, angle)
            for factor, angle in itertools.product(factors, angles)
        ])

        self.assertTensorAlmostEqual(actual, desired)
Ejemplo n.º 11
0
def get_style_op(encoder, layer_weight):
    """Create an ``MRFOperator`` for a single encoder layer.

    The operator uses ``patch_size=3`` and ``stride=2``. ``layer_weight`` is
    forwarded as the operator's ``score_weight``.
    """
    mrf_options = dict(patch_size=3, stride=2, score_weight=layer_weight)
    return ops.MRFOperator(encoder, **mrf_options)