Beispiel #1
0
def test_PerceptualLoss_set_content_image():
    """``PerceptualLoss.set_content_image`` sets the content operator's target.

    Two ``FeatureReconstructionOperator`` instances (content and style) are
    wrapped in a ``PerceptualLoss``. After ``set_content_image`` is called on
    the loss, the content operator's ``target_image`` must match the image.
    """
    torch.manual_seed(0)
    content_image = torch.rand(1, 1, 100, 100)

    def make_operator():
        # Each operator gets its own freshly constructed 1x1 convolution.
        return ops.FeatureReconstructionOperator(
            enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))
        )

    content_op = make_operator()
    style_op = make_operator()

    criterion = loss.PerceptualLoss(content_op, style_op)
    criterion.set_content_image(content_image)

    ptu.assert_allclose(content_op.target_image, content_image)
Beispiel #2
0
def test_EncodingComparisonOperator_set_target_image():
    """Check the state stored on the operator by ``set_target_image``.

    The stub operator turns its argument into ``(argument * 2, norm(argument))``.
    After the target image is set, the test expects ``has_target_image`` to
    flip to true, ``target_image`` to equal the image, and ``target_repr`` /
    ``ctx`` to equal the stub's outputs for ``encoder(image)``.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # Representation is the doubled argument, context is its norm.
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    target = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = TestOperator(encoder)
    assert not operator.has_target_image

    operator.set_target_image(target)
    assert operator.has_target_image

    # The raw image is kept as-is.
    ptu.assert_allclose(operator.target_image, target)

    # Representation and context are derived from the encoded image.
    ptu.assert_allclose(operator.target_repr, encoder(target) * 2.0)
    ptu.assert_allclose(operator.ctx, torch.norm(encoder(target)))
Beispiel #3
0
    def test_non_persistent_images(self):
        """Load one operator's ``state_dict`` into a fresh one with ``strict=True``.

        The source operator has a target guide, a target image and an input
        guide set before its state is exported. There is no explicit assertion;
        the strict load into an operator that never had them set is the check.
        NOTE(review): presumably this relies on ``load_state_dict`` raising on
        missing/unexpected keys -- confirm against the operator base class.
        """

        class TestOperator(ops.EncodingComparisonOperator):
            def target_enc_to_repr(self, enc):
                return enc, None

            def input_enc_to_repr(self, enc, ctx):
                return None

            def calculate_score(self, input_repr, target_repr, ctx):
                return None

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        target_guide = torch.rand(1, 1, 32, 32)
        input_guide = torch.rand(1, 1, 32, 32)

        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        # Populate the source operator and export its state.
        source_op = TestOperator(encoder)
        source_op.set_target_guide(target_guide)
        source_op.set_target_image(target_image)
        source_op.set_input_guide(input_guide)
        exported_state = source_op.state_dict()

        # A pristine operator must accept that state without complaints.
        fresh_op = TestOperator(encoder)
        fresh_op.load_state_dict(exported_state, strict=True)
Beispiel #4
0
def test_EncodingComparisonOperator_set_target_guide_without_recalc():
    """``set_target_guide(..., recalc_repr=False)`` leaves ``target_repr`` alone.

    A representation is registered directly as the ``target_repr`` buffer.
    Setting a target guide with ``recalc_repr=False`` afterwards must not
    change it.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    stored_repr = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = TestOperator(encoder)
    # Inject the representation directly instead of going through an image.
    operator.register_buffer("target_repr", stored_repr)
    operator.set_target_guide(target_guide, recalc_repr=False)

    ptu.assert_allclose(operator.target_repr, stored_repr)
Beispiel #5
0
def test_EncodingComparisonOperator_call_batch_size_error():
    """Calling the operator raises ``RuntimeError`` here.

    The target image has batch size 2 and the input image batch size 1. The
    stub operator passes encodings through unchanged.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.batch_size_equal = False

        def target_enc_to_repr(self, enc):
            return enc, None

        def input_enc_to_repr(self, enc, ctx):
            return enc

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    two_targets = torch.rand(2, 1, 1, 1)
    one_input = torch.rand(1, 1, 1, 1)
    encoder = enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))

    operator = TestOperator(encoder)
    operator.set_target_image(two_targets)

    with pytest.raises(RuntimeError):
        operator(one_input)
Beispiel #6
0
def test_MRFOperator_call_guided():
    """Guided ``MRFOperator`` call equals ``F.mrf_loss`` on the guided halves.

    The target guide is zero on rows 0-15 and one on rows 16-31; the input
    guide is the same mask flipped along the height axis. The expected score
    is therefore built from patches of the top half of the input encoding and
    the bottom half of the target encoding.
    """
    patch_size = 2
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    masked_rows = torch.zeros(1, 1, 16, 32)
    active_rows = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((masked_rows, active_rows), dim=2)
    input_guide = target_guide.flip(2)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = ops.MRFOperator(encoder, patch_size, stride=stride)
    operator.set_target_guide(target_guide)
    operator.set_target_image(target_image)
    operator.set_input_guide(input_guide)

    actual = operator(input_image)

    # Only the regions where the respective guide is one contribute.
    input_top = encoder(input_image)[:, :, :16, :]
    target_bottom = encoder(target_image)[:, :, 16:, :]
    input_patches = pystiche.extract_patches2d(input_top, patch_size, stride=stride)
    target_patches = pystiche.extract_patches2d(
        target_bottom, patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches)

    ptu.assert_allclose(actual, desired)
Beispiel #7
0
def test_EncodingComparisonOperator_call_batch_size_mismatch():
    """Input and target representations reach ``calculate_score`` with equal batch size.

    The target image has batch size 1 and the input image batch size 2. The
    stub's ``calculate_score`` records whether both representations share the
    same leading dimension, and the test requires that they do.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.batch_size_equal = False

        def target_enc_to_repr(self, enc):
            return enc, None

        def input_enc_to_repr(self, enc, ctx):
            return enc

        def calculate_score(self, input_repr, target_repr, ctx):
            # Record the comparison so the test body can inspect it afterwards.
            self.batch_size_equal = input_repr.size(0) == target_repr.size(0)
            return 0.0

    torch.manual_seed(0)
    single_target = torch.rand(1, 1, 1, 1)
    paired_input = torch.rand(2, 1, 1, 1)
    encoder = enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))

    operator = TestOperator(encoder)
    operator.set_target_image(single_target)

    operator(paired_input)
    assert operator.batch_size_equal
Beispiel #8
0
    def mrf_op(self):
        """Return an ``MRFOperator`` with patch size 2 and stride 2.

        The operator's encoder is a ``SequentialEncoder`` wrapping a single
        ``self.Identity()`` module.
        """
        identity_encoder = enc.SequentialEncoder((self.Identity(),))
        return ops.MRFOperator(identity_encoder, 2, stride=2)
Beispiel #9
0
def test_EncodingComparisonOperator_call_guided():
    """Guided call applies each propagated guide before the stub hooks run.

    The stub adds 1 on the target side, adds 2 on the input side and multiplies
    both representations as the score. The expected score applies the
    propagated target/input guides to the respective encodings first.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return image + 1.0, None

        def input_enc_to_repr(self, image, ctx):
            return image + 2.0

        def calculate_score(self, input_repr, target_repr, ctx):
            return input_repr * target_repr

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    target_enc_guide = encoder.propagate_guide(target_guide)
    input_enc_guide = encoder.propagate_guide(input_guide)

    operator = TestOperator(encoder)
    operator.set_target_guide(target_guide)
    operator.set_target_image(target_image)
    operator.set_input_guide(input_guide)

    actual = operator(input_image)

    # Rebuild the score by hand: guide the encodings, then mimic the stub hooks.
    guided_target = TestOperator.apply_guide(encoder(target_image), target_enc_guide)
    guided_input = TestOperator.apply_guide(encoder(input_image), input_enc_guide)
    desired = (guided_target + 1.0) * (guided_input + 2.0)

    ptu.assert_allclose(actual, desired)
Beispiel #10
0
    def test_SequentialEncoder_call(self):
        """``SequentialEncoder`` output matches ``nn.Sequential`` on the same modules."""
        torch.manual_seed(0)
        layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
        image = torch.rand(1, 3, 256, 256)

        # Both containers share the very same module instances (and weights).
        encoder_under_test = enc.SequentialEncoder(layers)
        reference = nn.Sequential(*layers)

        self.assertTensorAlmostEqual(encoder_under_test(image), reference(image))
Beispiel #11
0
def test_SequentialEncoder_call():
    """``SequentialEncoder`` output matches ``nn.Sequential`` on the same modules."""
    torch.manual_seed(0)
    layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
    image = torch.rand(1, 3, 256, 256)

    # Both containers share the very same module instances (and weights).
    encoder_under_test = enc.SequentialEncoder(layers)
    reference = nn.Sequential(*layers)

    ptu.assert_allclose(encoder_under_test(image), reference(image))
Beispiel #12
0
def test_FeatureReconstructionOperator_call():
    """The operator's score equals ``mse_loss`` between input and target encodings."""
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = ops.FeatureReconstructionOperator(encoder)
    operator.set_target_image(target_image)

    score = operator(input_image)
    expected = mse_loss(encoder(input_image), encoder(target_image))
    ptu.assert_allclose(score, expected)
Beispiel #13
0
def test_GramOperator_call():
    """The operator's score equals ``mse_loss`` between normalized Gram matrices.

    The Gram matrices are computed with ``pystiche.gram_matrix(...,
    normalize=True)`` from the input and target encodings.
    """
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = ops.GramOperator(encoder)
    operator.set_target_image(target_image)

    score = operator(input_image)

    input_gram = pystiche.gram_matrix(encoder(input_image), normalize=True)
    target_gram = pystiche.gram_matrix(encoder(target_image), normalize=True)
    expected = mse_loss(input_gram, target_gram)

    ptu.assert_allclose(score, expected)
Beispiel #14
0
def test_MRFOperator_set_target_guide_without_recalc():
    """``set_target_guide(..., recalc_repr=False)`` leaves ``target_repr`` alone.

    A tensor is registered directly as the ``target_repr`` buffer of an
    ``MRFOperator``. Setting a guide with ``recalc_repr=False`` afterwards
    must not change it.
    """
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    stored_repr = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = ops.MRFOperator(encoder, patch_size, stride=stride)
    # Inject the representation directly instead of going through an image.
    operator.register_buffer("target_repr", stored_repr)
    operator.set_target_guide(target_guide, recalc_repr=False)

    ptu.assert_allclose(operator.target_repr, stored_repr)
Beispiel #15
0
def test_EncodingRegularizationOperator_call():
    """Calling the operator chains encoder, ``input_enc_to_repr`` and ``calculate_score``.

    The stub doubles its argument as representation and adds 1 as score, so
    the expected result is ``encoder(image) * 2 + 1``.
    """

    class TestOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    torch.manual_seed(0)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = TestOperator(encoder)

    score = operator(input_image)
    expected = encoder(input_image) * 2.0 + 1.0
    ptu.assert_allclose(score, expected)
Beispiel #16
0
    def test_set_target_guide_without_recalc(self):
        """``set_target_guide(..., recalc_repr=False)`` leaves ``target_repr`` alone.

        The representation produced by ``set_target_image`` is snapshotted
        (cloned) and compared with ``target_repr`` after the guide is set.
        """
        patch_size = 3
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        target_guide = torch.rand(1, 1, 32, 32)
        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        operator = ops.MRFOperator(encoder, patch_size, stride=stride)
        operator.set_target_image(target_image)
        # Clone so a later in-place change could not hide a recalculation.
        snapshot = operator.target_repr.clone()

        operator.set_target_guide(target_guide, recalc_repr=False)

        ptu.assert_allclose(operator.target_repr, snapshot)
Beispiel #17
0
def test_MRFOperator_enc_to_repr_guided(subtests):
    """Check which patches ``MRFOperator.enc_to_repr`` keeps when ``is_guided=True``.

    An operator with an identity encoder (patch size 2, stride 2) is fed four
    encodings, each in its own subtest:

    * ``constant``: all ones -- an empty set of patches is expected.
    * ``spatial_mix``: constant top half, random bottom half -- only the
      patches of the random half are expected.
    * ``channel_mix``: two constant and two random channels -- all patches
      are expected.
    * ``varying``: fully random -- all patches are expected.

    Args:
        subtests: The ``subtests`` fixture; each scenario runs inside
            ``subtests.test(...)``.
    """

    class Identity(pystiche.Module):
        def forward(self, image):
            return image

    patch_size = 2
    stride = 2

    # Seed the RNG so the random encodings below are reproducible.
    torch.manual_seed(0)

    op = ops.MRFOperator(
        enc.SequentialEncoder((Identity(),)), patch_size, stride=stride
    )

    with subtests.test(enc="constant"):
        enc_ = torch.ones(1, 4, 8, 8)

        actual = op.enc_to_repr(enc_, is_guided=True)
        # The spatial extent of a patch is set by patch_size, not by stride;
        # the two only happen to be equal in this test.
        # NOTE(review): assumes extract_patches2d yields
        # patch_size x patch_size patches -- confirm.
        desired = torch.ones(0, 4, patch_size, patch_size)
        ptu.assert_allclose(actual, desired)

    with subtests.test(enc="spatial_mix"):
        constant = torch.ones(1, 4, 4, 8)
        varying = torch.rand(1, 4, 4, 8)
        enc_ = torch.cat((constant, varying), dim=2)

        actual = op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(varying, patch_size, stride=stride)
        ptu.assert_allclose(actual, desired)

    with subtests.test(enc="channel_mix"):
        constant = torch.ones(1, 2, 8, 8)
        varying = torch.rand(1, 2, 8, 8)
        enc_ = torch.cat((constant, varying), dim=1)

        actual = op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(enc_, patch_size, stride=stride)
        ptu.assert_allclose(actual, desired)

    with subtests.test(enc="varying"):
        enc_ = torch.rand(1, 4, 8, 8)

        actual = op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(enc_, patch_size, stride=stride)
        ptu.assert_allclose(actual, desired)
Beispiel #18
0
def test_EncodingComparisonOperator_call_no_target():
    """Calling the operator before any target image was set raises ``RuntimeError``."""

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return None

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    # Deliberately skip set_target_image().
    operator = TestOperator(encoder)

    with pytest.raises(RuntimeError):
        operator(input_image)
Beispiel #19
0
    def test_call(self):
        """``MRFOperator`` score equals ``F.mrf_loss`` on patches of both encodings.

        Patches are extracted from the input and target encodings with patch
        size 3 and stride 2, matching the operator's configuration.
        """
        patch_size = 3
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        operator = ops.MRFOperator(encoder, patch_size, stride=stride)
        operator.set_target_image(target_image)

        actual = operator(input_image)

        input_patches = pystiche.extract_patches2d(
            encoder(input_image), patch_size, stride=stride
        )
        target_patches = pystiche.extract_patches2d(
            encoder(target_image), patch_size, stride=stride
        )
        desired = F.mrf_loss(input_patches, target_patches)

        ptu.assert_allclose(actual, desired)
Beispiel #20
0
def test_EncodingRegularizationOperator_call_guided():
    """Guided call applies the propagated input guide before the stub hooks run.

    The stub doubles its argument as representation and adds 1 as score. The
    expected result is ``apply_guide(encoder(image), enc_guide) * 2 + 1``,
    where ``enc_guide`` is the input guide propagated through the encoder.
    """

    class TestOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    torch.manual_seed(0)
    input_image = torch.rand(1, 3, 32, 32)
    input_guide = torch.rand(1, 3, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    enc_guide = encoder.propagate_guide(input_guide)

    operator = TestOperator(encoder)
    operator.set_input_guide(input_guide)

    actual = operator(input_image)

    guided_enc = TestOperator.apply_guide(encoder(input_image), enc_guide)
    desired = guided_enc * 2.0 + 1.0

    ptu.assert_allclose(actual, desired)
Beispiel #21
0
def test_MRFOperator_set_target_guide():
    """``set_target_guide`` stores the guide and keeps the target image.

    ``has_target_guide`` must be false before the call and true afterwards.
    ``target_guide`` must equal the guide, and the previously set
    ``target_image`` must be unchanged.
    """
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    operator = ops.MRFOperator(encoder, patch_size, stride=stride)
    operator.set_target_image(target_image)
    assert not operator.has_target_guide

    operator.set_target_guide(target_guide)
    assert operator.has_target_guide

    ptu.assert_allclose(operator.target_guide, target_guide)
    ptu.assert_allclose(operator.target_image, target_image)
Beispiel #22
0
def encoder():
    """Return a ``SequentialEncoder`` holding one 3-to-3 channel 1x1 convolution."""
    conv = nn.Conv2d(3, 3, 1)
    return enc.SequentialEncoder((conv,))
Beispiel #23
0
 def get_guided_perceptual_loss():
     """Build a ``GuidedPerceptualLoss`` from fresh content and style operators.

     The content operator is a ``FeatureReconstructionOperator`` on a 1x1
     convolution encoder. The style operator is a ``MultiRegionOperator``
     built from ``regions`` and ``get_op`` of the enclosing scope.
     """
     content_encoder = enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))
     content_op = ops.FeatureReconstructionOperator(content_encoder)
     style_op = ops.MultiRegionOperator(regions, get_op)
     return loss.GuidedPerceptualLoss(content_op, style_op)