Esempio n. 1
0
    def test_EncodingComparisonOperator_set_target_image(self):
        """Setting a target image stores it and derives repr/ctx from its encoding."""

        class TestOperator(ops.EncodingComparisonOperator):
            def target_enc_to_repr(self, image):
                # Representation is the doubled encoding; context is its norm.
                return image * 2.0, torch.norm(image)

            def input_enc_to_repr(self, image, ctx):
                pass

            def calculate_score(self, input_repr, target_repr, ctx):
                pass

        torch.manual_seed(0)
        target = torch.rand(1, 3, 128, 128)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        test_op = TestOperator(encoder)
        self.assertFalse(test_op.has_target_image)

        test_op.set_target_image(target)
        self.assertTrue(test_op.has_target_image)

        target_enc = encoder(target)
        # (stored value, expected value) pairs checked after setting the target.
        checks = (
            (test_op.target_image, target),
            (test_op.target_repr, target_enc * 2.0),
            (test_op.ctx, torch.norm(target_enc)),
        )
        for actual, desired in checks:
            self.assertTensorAlmostEqual(actual, desired)
Esempio n. 2
0
    def test_MRFOperator_call_guided(self):
        """A guided MRF loss only matches patches inside the guided regions.

        The target guide is zero on the upper half and one on the lower half;
        the input guide is its vertical flip. The expected loss therefore
        compares the input's upper half with the target's lower half.
        """
        patch_size = stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        upper_half = torch.zeros(1, 1, 16, 32)
        lower_half = torch.ones(1, 1, 16, 32)
        target_guide = torch.cat((upper_half, lower_half), dim=2)
        input_guide = target_guide.flip(2)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_guide(target_guide)
        op.set_target_image(target_image)
        op.set_input_guide(input_guide)

        actual = op(input_image)

        def to_patches(enc):
            return pystiche.extract_patches2d(enc, patch_size, stride=stride)

        input_patches = to_patches(encoder(input_image)[:, :, :16, :])
        target_patches = to_patches(encoder(target_image)[:, :, 16:, :])
        desired = F.patch_matching_loss(input_patches, target_patches)
        self.assertFloatAlmostEqual(actual, desired)
Esempio n. 3
0
    def test_EncodingComparisonOperator_call_guided(self):
        """Calling a guided operator applies both propagated guides before scoring."""

        class TestOperator(ops.EncodingComparisonOperator):
            def target_enc_to_repr(self, image):
                # No context is needed by calculate_score below.
                return image + 1.0, None

            def input_enc_to_repr(self, image, ctx):
                return image + 2.0

            def calculate_score(self, input_repr, target_repr, ctx):
                return input_repr * target_repr

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        target_guide = torch.rand(1, 1, 32, 32)
        input_guide = torch.rand(1, 1, 32, 32)

        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
        target_enc_guide = encoder.propagate_guide(target_guide)
        input_enc_guide = encoder.propagate_guide(input_guide)

        test_op = TestOperator(encoder)
        test_op.set_target_guide(target_guide)
        test_op.set_target_image(target_image)
        test_op.set_input_guide(input_guide)

        actual = test_op(input_image)

        guided_target = TestOperator.apply_guide(
            encoder(target_image), target_enc_guide
        )
        guided_input = TestOperator.apply_guide(encoder(input_image), input_enc_guide)
        desired = (guided_target + 1.0) * (guided_input + 2.0)
        self.assertTensorAlmostEqual(actual, desired)
Esempio n. 4
0
    def test_EncodingComparisonOperator_set_target_guide_without_recalc(self):
        """With recalc_repr=False, setting a guide leaves target_repr untouched."""

        class TestOperator(ops.EncodingComparisonOperator):
            def target_enc_to_repr(self, image):
                return image * 2.0, torch.norm(image)

            def input_enc_to_repr(self, image, ctx):
                pass

            def calculate_score(self, input_repr, target_repr, ctx):
                pass

        torch.manual_seed(0)
        stored_repr = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        test_op = TestOperator(encoder)
        # Inject a representation directly; no target image is ever set.
        test_op.register_buffer("target_repr", stored_repr)
        test_op.set_target_guide(guide, recalc_repr=False)

        self.assertTensorAlmostEqual(test_op.target_repr, stored_repr)
Esempio n. 5
0
    def test_PerceptualLoss_set_style_image(self):
        """set_style_image forwards the image to the style loss as its target."""
        torch.manual_seed(0)
        image = torch.rand(1, 1, 100, 100)

        def make_operator():
            # Single-channel 1x1 conv encoder matching the 1-channel image.
            return FeatureReconstructionOperator(
                SequentialEncoder((nn.Conv2d(1, 1, 1),))
            )

        content_loss = make_operator()
        style_loss = make_operator()

        perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
        perceptual_loss.set_style_image(image)

        self.assertTensorAlmostEqual(style_loss.target_image, image)
Esempio n. 6
0
    def test_MSEEncodingOperator_call(self):
        """The operator's score is the MSE between input and target encodings."""
        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 128, 128)
        input_image = torch.rand(1, 3, 128, 128)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MSEEncodingOperator(encoder)
        op.set_target_image(target_image)
        actual = op(input_image)

        input_enc = encoder(input_image)
        target_enc = encoder(target_image)
        self.assertTensorAlmostEqual(actual, mse_loss(input_enc, target_enc))
Esempio n. 7
0
    def test_MRFOperator_enc_to_repr_guided(self):
        """Guided enc_to_repr drops constant patches and keeps all others.

        An identity encoder is used so the encodings handed to
        ``enc_to_repr`` can be constructed directly in each subtest.
        """

        class Identity(pystiche.Module):
            def forward(self, image):
                return image

        patch_size = 2
        stride = 2

        op = ops.MRFOperator(
            SequentialEncoder((Identity(),)), patch_size, stride=stride
        )

        def extract_patches(enc):
            return pystiche.extract_patches2d(enc, patch_size, stride=stride)

        # Seed so the random encodings below are reproducible, like the other
        # tests in this file.
        torch.manual_seed(0)

        with self.subTest(enc="constant"):
            enc = torch.ones(1, 4, 8, 8)

            actual = op.enc_to_repr(enc, is_guided=True)
            # Every patch is constant, so none survive. Extracted patches are
            # patch_size x patch_size (not stride x stride); the two only
            # coincide here because both are 2.
            desired = torch.ones(0, 4, patch_size, patch_size)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="spatial_mix"):
            constant = torch.ones(1, 4, 4, 8)
            varying = torch.rand(1, 4, 4, 8)
            enc = torch.cat((constant, varying), dim=2)

            actual = op.enc_to_repr(enc, is_guided=True)
            # Only patches from the varying (lower) half are expected to remain.
            desired = extract_patches(varying)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="channel_mix"):
            constant = torch.ones(1, 2, 8, 8)
            varying = torch.rand(1, 2, 8, 8)
            enc = torch.cat((constant, varying), dim=1)

            actual = op.enc_to_repr(enc, is_guided=True)
            # Constant channels alone do not make a patch constant: all kept.
            desired = extract_patches(enc)
            self.assertTensorAlmostEqual(actual, desired)

        with self.subTest(enc="varying"):
            enc = torch.rand(1, 4, 8, 8)

            actual = op.enc_to_repr(enc, is_guided=True)
            desired = extract_patches(enc)
            self.assertTensorAlmostEqual(actual, desired)
Esempio n. 8
0
    def test_GramOperator_call(self):
        """The operator's score is the MSE between normalized Gram matrices."""
        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 128, 128)
        input_image = torch.rand(1, 3, 128, 128)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.GramOperator(encoder)
        op.set_target_image(target_image)
        actual = op(input_image)

        input_gram, target_gram = (
            pystiche.gram_matrix(encoder(img), normalize=True)
            for img in (input_image, target_image)
        )
        self.assertTensorAlmostEqual(actual, mse_loss(input_gram, target_gram))
Esempio n. 9
0
    def test_MRFOperator_set_target_guide_without_recalc(self):
        """With recalc_repr=False, a new target guide keeps the stored repr."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        stored_repr = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        # Inject the representation directly; no target image is ever set.
        op.register_buffer("target_repr", stored_repr)
        op.set_target_guide(guide, recalc_repr=False)

        self.assertTensorAlmostEqual(op.target_repr, stored_repr)
Esempio n. 10
0
    def test_EncodingRegularizationOperator_call(self):
        """Calling a regularization operator scores the input's representation."""

        class TestOperator(ops.EncodingRegularizationOperator):
            def input_enc_to_repr(self, image):
                return image * 2.0

            def calculate_score(self, input_repr):
                return input_repr + 1.0

        torch.manual_seed(0)
        image = torch.rand(1, 3, 128, 128)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        actual = TestOperator(encoder)(image)
        # Mirror the hooks above: encode -> repr (x2) -> score (+1).
        desired = encoder(image) * 2.0 + 1.0
        self.assertTensorAlmostEqual(actual, desired)
Esempio n. 11
0
    def test_EncodingComparisonOperator_call_no_target(self):
        """Calling a comparison operator without a target image raises RuntimeError."""

        class TestOperator(ops.EncodingComparisonOperator):
            # Stub hooks: this test only checks the error raised on call.
            def target_enc_to_repr(self, image):
                pass

            def input_enc_to_repr(self, image, ctx):
                pass

            def calculate_score(self, input_repr, target_repr, ctx):
                pass

        torch.manual_seed(0)
        input_image = torch.rand(1, 3, 128, 128)
        test_op = TestOperator(SequentialEncoder((nn.Conv2d(3, 3, 1),)))

        self.assertRaises(RuntimeError, test_op, input_image)
Esempio n. 12
0
    def test_EncodingRegularizationOperator_call_guided(self):
        """A guided regularization operator scores the guided encoding.

        The expected value applies the encoder-propagated guide to the
        encoding before the representation (x2) and score (+1) hooks.
        """

        class TestOperator(ops.EncodingRegularizationOperator):
            def input_enc_to_repr(self, image):
                return image * 2.0

            def calculate_score(self, input_repr):
                return input_repr + 1.0

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        # Single-channel guide, consistent with every other guided test here
        # (guides are (1, 1, H, W) masks). A 3-channel guide would only line up
        # because the image/encoding also happen to have 3 channels.
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
        enc_guide = encoder.propagate_guide(guide)

        test_op = TestOperator(encoder)
        test_op.set_input_guide(guide)

        actual = test_op(image)
        desired = TestOperator.apply_guide(encoder(image), enc_guide) * 2.0 + 1.0
        self.assertTensorAlmostEqual(actual, desired)
Esempio n. 13
0
    def test_MRFOperator_call(self):
        """The MRF score is the patch-matching loss between encoded images."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(target_image)
        actual = op(input_image)

        input_patches, target_patches = [
            pystiche.extract_patches2d(encoder(img), patch_size, stride=stride)
            for img in (input_image, target_image)
        ]
        desired = F.patch_matching_loss(input_patches, target_patches)
        self.assertFloatAlmostEqual(actual, desired)
Esempio n. 14
0
    def test_MRFOperator_set_target_guide(self):
        """Setting a target guide stores it and keeps the existing target image."""
        patch_size, stride = 3, 2

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(image)
        self.assertFalse(op.has_target_guide)

        op.set_target_guide(guide)
        self.assertTrue(op.has_target_guide)

        self.assertTensorAlmostEqual(op.target_guide, guide)
        # The target image set before the guide must still be intact.
        self.assertTensorAlmostEqual(op.target_image, image)
Esempio n. 15
0
    def test_MRFOperator_target_image_to_repr(self):
        """The target repr concatenates patches of every scale/rotation variant.

        With one step in each direction, the expected repr is built from the
        target image transformed by each (scaling factor, rotation angle)
        combination, in ``itertools.product`` order, then encoded and split
        into patches.
        """
        patch_size = 3
        stride = 2
        scale_step_width = 0.1
        rotation_step_width = 30.0

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(
            encoder,
            patch_size,
            stride=stride,
            num_scale_steps=1,
            scale_step_width=scale_step_width,
            num_rotation_steps=1,
            rotation_step_width=rotation_step_width,
        )
        op.set_target_image(image)
        actual = op.target_repr

        factors = (1.0 - scale_step_width, 1.0, 1.0 + scale_step_width)
        angles = (-rotation_step_width, 0.0, rotation_step_width)

        def variant_patches(factor, angle):
            transformed = transform_motif_affinely(
                image, rotation_angle=angle, scaling_factor=factor
            )
            return pystiche.extract_patches2d(encoder(transformed), patch_size, stride)

        desired = torch.cat(
            [
                variant_patches(factor, angle)
                for factor, angle in itertools.product(factors, angles)
            ]
        )

        self.assertTensorAlmostEqual(actual, desired)
Esempio n. 16
0
 def get_guided_perceptual_loss():
     """Build a guided perceptual loss from a content and a multi-region style loss.

     ``regions`` and ``get_op`` are free names resolved from a scope not
     visible here (an enclosing function or the module).
     """
     content_encoder = SequentialEncoder((nn.Conv2d(1, 1, 1),))
     content_loss = FeatureReconstructionOperator(content_encoder)
     style_loss = MultiRegionOperator(regions, get_op)
     return loss.GuidedPerceptualLoss(content_loss, style_loss)