コード例 #1
0
    def test_EncodingComparisonOperator_call_guided(self):
        """Check that a guided EncodingComparisonOperator applies both guides.

        The expected score is built by hand. Each image is encoded and passed
        through ``apply_guide`` together with its encoder-propagated guide. It
        is then transformed like the toy ``*_enc_to_repr`` hooks below, and the
        two results are combined exactly like ``calculate_score``.
        """

        class TestOperator(ops.EncodingComparisonOperator):
            def target_enc_to_repr(self, image):
                # This toy operator needs no context, hence ``None``.
                target_repr = image + 1.0
                return target_repr, None

            def input_enc_to_repr(self, image, ctx):
                return image + 2.0

            def calculate_score(self, input_repr, target_repr, ctx):
                return input_repr * target_repr

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        target_guide = torch.rand(1, 1, 32, 32)
        input_guide = torch.rand(1, 1, 32, 32)

        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1), ))
        target_enc_guide = encoder.propagate_guide(target_guide)
        input_enc_guide = encoder.propagate_guide(input_guide)

        test_op = TestOperator(encoder)
        test_op.set_target_guide(target_guide)
        test_op.set_target_image(target_image)
        test_op.set_input_guide(input_guide)

        actual = test_op(input_image)

        # Reproduce target_enc_to_repr, input_enc_to_repr and calculate_score
        # on the guided encodings.
        target_repr = (
            TestOperator.apply_guide(encoder(target_image), target_enc_guide)
            + 1.0
        )
        input_repr = (
            TestOperator.apply_guide(encoder(input_image), input_enc_guide)
            + 2.0
        )
        desired = input_repr * target_repr
        self.assertTensorAlmostEqual(actual, desired)
コード例 #2
0
ファイル: test_ops.py プロジェクト: sourcery-ai-bot/pystiche
    def test_EncodingRegularizationOperator_call_guided(self):
        """Check that a guided EncodingRegularizationOperator applies its guide.

        The expected score passes the encoded image and the encoder-propagated
        guide through ``apply_guide``. It then reproduces the toy
        ``input_enc_to_repr`` (``* 2.0``) and ``calculate_score`` (``+ 1.0``)
        hooks.
        """

        class TestOperator(ops.EncodingRegularizationOperator):
            def input_enc_to_repr(self, image):
                return image * 2.0

            def calculate_score(self, input_repr):
                return input_repr + 1.0

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        # Single-channel guide (1, 1, H, W), matching the guide shape used by
        # the other guided operator tests in this module.
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1), ))
        enc_guide = encoder.propagate_guide(guide)

        test_op = TestOperator(encoder)
        test_op.set_input_guide(guide)

        actual = test_op(image)
        guided_enc = TestOperator.apply_guide(encoder(image), enc_guide)
        desired = guided_enc * 2.0 + 1.0
        self.assertTensorAlmostEqual(actual, desired)
コード例 #3
0
    def test_EncodingComparisonOperator_set_target_guide(self):
        """Check that set_target_guide stores the raw and propagated guide.

        Before a guide is set, ``has_target_guide`` is False; afterwards it is
        True. ``target_guide`` must then equal the guide as given, and
        ``target_enc_guide`` must equal ``encoder.propagate_guide(guide)``.
        ``target_image`` must still equal the previously set image.
        """

        class TestOperator(ops.EncodingComparisonOperator):
            def target_enc_to_repr(self, image):
                target_repr = image * 2.0
                ctx = torch.norm(image)
                return target_repr, ctx

            # Stub hooks: this test never calls the operator on an input.
            def input_enc_to_repr(self, image, ctx):
                pass

            def calculate_score(self, input_repr, target_repr, ctx):
                pass

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)
        encoder = SequentialEncoder((nn.Conv2d(3, 3, 1), ))
        enc_guide = encoder.propagate_guide(guide)

        test_op = TestOperator(encoder)
        test_op.set_target_image(image)
        self.assertFalse(test_op.has_target_guide)

        test_op.set_target_guide(guide)
        self.assertTrue(test_op.has_target_guide)

        actual = test_op.target_guide
        desired = guide
        self.assertTensorAlmostEqual(actual, desired)

        actual = test_op.target_enc_guide
        desired = enc_guide
        self.assertTensorAlmostEqual(actual, desired)

        # Setting the guide after the image must not alter the stored image.
        actual = test_op.target_image
        desired = image
        self.assertTensorAlmostEqual(actual, desired)