コード例 #1
0
def test_MRFOperator_call_guided():
    """Check that a guided MRFOperator compares only the guided image regions.

    The target guide is 0 on the top half and 1 on the bottom half. The
    input guide is the same mask flipped vertically, so it is 1 on the top
    half. The expected value is ``F.mrf_loss`` evaluated on the patches of
    the input's top-half encoding and the target's bottom-half encoding.
    """
    patch_size = 2
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    # Build a half/half mask along the height axis (dim 2); flipping it along
    # the same axis swaps which half of the input is selected.
    upper_half = torch.zeros(1, 1, 16, 32)
    lower_half = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((upper_half, lower_half), dim=2)
    input_guide = target_guide.flip(2)

    # nn.Conv2d(3, 3, 1) uses a 1x1 kernel, so the encodings keep the 32x32
    # spatial size and the halves can be sliced out of them directly below.
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)

    actual = op(input_image)

    # Reference value: keep only the half of each encoding where its guide
    # is 1, then apply the functional loss to the extracted patches.
    guided_input_enc = encoder(input_image)[:, :, :16, :]
    guided_target_enc = encoder(target_image)[:, :, 16:, :]
    input_patches = pystiche.extract_patches2d(
        guided_input_enc, patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        guided_target_enc, patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches)
    ptu.assert_allclose(actual, desired)
コード例 #2
0
def test_mrf_loss():
    """Check ``F.mrf_loss`` on a target set with two constant distractors.

    Both input patches are derived from ``rand_patch`` (one shifted by 0.1,
    one scaled by 0.9). The expected value assumes each input patch is paired
    with ``rand_patch`` from the target set rather than with the all-zero or
    all-one patch. The loss therefore equals the MSE against two copies of
    ``rand_patch``.
    """
    torch.manual_seed(0)
    zero_patch = torch.zeros(3, 3, 3)
    one_patch = torch.ones(3, 3, 3)
    rand_patch = torch.randn(3, 3, 3)

    input_patches = torch.stack((rand_patch + 0.1, rand_patch * 0.9))
    target_patches = torch.stack((zero_patch, one_patch, rand_patch))

    actual = F.mrf_loss(input_patches, target_patches)

    matched_targets = torch.stack((rand_patch, rand_patch))
    desired = mse_loss(input_patches, matched_targets)
    ptu.assert_allclose(actual, desired)
コード例 #3
0
ファイル: test_comparison.py プロジェクト: pystiche/pystiche
    def test_call(self):
        """Check that an MRFOperator without guides equals ``F.mrf_loss``.

        The reference applies ``F.mrf_loss`` to the patches of the full
        input and target encodings.
        """
        patch_size = 3
        stride = 2

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

        op = ops.MRFOperator(encoder, patch_size, stride=stride)
        op.set_target_image(target_image)
        actual = op(input_image)

        def patches_of(image):
            # Encode the image and cut the encoding into patches using the
            # same patch size and stride the operator was built with.
            return pystiche.extract_patches2d(
                encoder(image), patch_size, stride=stride
            )

        desired = F.mrf_loss(patches_of(input_image), patches_of(target_image))
        ptu.assert_allclose(actual, desired)
コード例 #4
0
def test_mrf_loss_future_warning():
    """Check that ``F.mrf_loss(input, target)`` with no other args warns.

    The test expects a ``FutureWarning`` to be emitted.
    """
    # Only the warning is asserted; neither the (uninitialized) tensor
    # contents nor the returned loss value are inspected.
    input_patches = torch.empty(1, 2)
    target_patches = torch.empty(1, 2)
    with pytest.warns(FutureWarning):
        F.mrf_loss(input_patches, target_patches)