示例#1
0
    def test_MRFOperator_call_guided(self):
        """Check the guided MRF operator against a hand-built patch-matching loss.

        The target guide is 1 only on the bottom half (rows 16-31) of the
        32x32 target.  The input guide is its vertical flip, so it is 1 only on
        the top half (rows 0-15) of the input.  The expected value is the
        patch-matching loss between the encodings of exactly those two halves.
        """
        patch_size, stride = 2, 2

        # Seed before any random draw: both images, and presumably the conv's
        # default weight init, consume the global RNG in this order.
        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)

        top_zeros = torch.zeros(1, 1, 16, 32)
        bottom_ones = torch.ones(1, 1, 16, 32)
        target_guide = torch.cat((top_zeros, bottom_ones), dim=2)
        input_guide = torch.flip(target_guide, dims=(2,))

        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        mrf_op = ops.MRFOperator(encoder, patch_size, stride=stride)
        mrf_op.set_target_guide(target_guide)
        mrf_op.set_target_image(target_image)
        mrf_op.set_input_guide(input_guide)
        actual = mrf_op(input_image)

        def to_patches(enc):
            return pystiche.extract_patches2d(enc, patch_size, stride=stride)

        # Restrict each encoding to the half its guide selects.
        guided_input = encoder(input_image)[:, :, :16, :]
        guided_target = encoder(target_image)[:, :, 16:, :]
        desired = F.patch_matching_loss(
            to_patches(guided_input), to_patches(guided_target)
        )

        self.assertFloatAlmostEqual(actual, desired)
示例#2
0
    def test_patch_matching_loss(self):
        """Check patch matching on a tiny hand-made set of 3x3x3 patches.

        Both input patches are perturbations of the random patch: one is
        shifted by 0.1 and the other is scaled by 0.9.  The expected value
        assumes each input patch is matched to that random target patch
        rather than the all-zeros or all-ones patch.  It is therefore the MSE
        between the inputs and two copies of the random patch.
        """
        torch.manual_seed(0)
        zeros = torch.zeros(3, 3, 3)
        ones = torch.ones(3, 3, 3)
        noise = torch.randn(3, 3, 3)

        inputs = torch.stack([noise + 0.1, noise * 0.9])
        targets = torch.stack([zeros, ones, noise])

        actual = F.patch_matching_loss(inputs, targets)

        expected_matches = torch.stack([noise, noise])
        desired = mse_loss(inputs, expected_matches)

        self.assertFloatAlmostEqual(actual, desired)
示例#3
0
    def test_MRFOperator_call(self):
        """Check the unguided MRF operator against a hand-built patch-matching loss.

        The expected value is the patch-matching loss between the patches
        extracted from the encoded input image and the encoded target image.
        """
        patch_size, stride = 3, 2

        # Seed before any random draw; the image draws and the conv
        # construction keep their original order.
        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

        mrf_op = ops.MRFOperator(encoder, patch_size, stride=stride)
        mrf_op.set_target_image(target_image)
        actual = mrf_op(input_image)

        input_patches = pystiche.extract_patches2d(
            encoder(input_image), patch_size, stride=stride
        )
        target_patches = pystiche.extract_patches2d(
            encoder(target_image), patch_size, stride=stride
        )
        desired = F.patch_matching_loss(input_patches, target_patches)

        self.assertFloatAlmostEqual(actual, desired)
示例#4
0
 def calculate_score(self, input_repr, target_repr, ctx):
     """Return the scaled patch-matching loss between two representations.

     ``self.loss_reduction`` is forwarded as the ``reduction`` argument of
     ``F.patch_matching_loss``.  The resulting loss is multiplied by
     ``self.score_correction_factor``.  ``ctx`` is not used by this
     implementation.
     """
     loss = F.patch_matching_loss(
         input_repr,
         target_repr,
         reduction=self.loss_reduction,
     )
     scaled = loss * self.score_correction_factor
     return scaled