Exemplo n.º 1
0
def test_BranchBlock(subtests, input_image):
    """Check the wiring and forward pass of ``paper.BranchBlock``.

    Two 1x1 convolutions serve as the deep and shallow branches. The forward
    pass is compared against instance-normalizing each branch output (with the
    momentum read from the join block's first norm module) and concatenating
    the results along dimension 1.
    """
    deep_branch = nn.Conv2d(3, 3, 1)
    shallow_branch = nn.Conv2d(3, 3, 1)
    block = paper.BranchBlock(deep_branch, shallow_branch)

    # The block should hold the very modules it was given. ``nn.Module`` does
    # not define a value-based ``__eq__``, so ``is`` states the identity check
    # explicitly instead of relying on the default fallback of ``==``.
    with subtests.test("deep"):
        assert block.deep is deep_branch

    with subtests.test("shallow"):
        assert block.shallow is shallow_branch

    with subtests.test("join_block"):
        assert isinstance(block.join, paper.JoinBlock)

    with subtests.test("out_channels"):
        assert (
            block.out_channels == deep_branch.out_channels + shallow_branch.out_channels
        )

    with subtests.test("forward"):
        actual = block(input_image)
        assert isinstance(actual, torch.Tensor)
        momentum = block.join.norm_modules[0].momentum
        inputs = (deep_branch(input_image), shallow_branch(input_image))
        desired_inputs = tuple(
            F.instance_norm(image, momentum=momentum) for image in inputs
        )

        desired = torch.cat(desired_inputs, 1)
        ptu.assert_allclose(actual, desired)
Exemplo n.º 2
0
def test_extract_normalized_patches2d_no_overlap(subtests):
    """Normalized patch extraction must agree with plain extraction without overlap.

    With ``stride == patch_size`` the patches do not share pixels, so both the
    extracted patches and the gradients of a half squared-error loss against a
    zero target are expected to match ``pystiche.extract_patches2d``.
    """
    size = 4
    patch_size = stride = 2
    shape = (1, 1, size, size)

    plain_image = torch.ones(*shape).requires_grad_(True)
    normalized_image = torch.ones(*shape).requires_grad_(True)
    zero_target = torch.zeros(*shape).detach()

    plain_patches = pystiche.extract_patches2d(
        plain_image, patch_size=patch_size, stride=stride
    )
    normalized_patches = paper.extract_normalized_patches2d(
        normalized_image, patch_size=patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        zero_target, patch_size=patch_size, stride=stride
    )

    def half_sse(patches):
        # 0.5 * sum of squared differences to the target patches.
        return 0.5 * torch.sum((patches - target_patches) ** 2.0)

    half_sse(plain_patches).backward()
    half_sse(normalized_patches).backward()

    with subtests.test("forward"):
        ptu.assert_allclose(normalized_patches, plain_patches)

    with subtests.test("backward"):
        ptu.assert_allclose(normalized_image.grad, plain_image.grad)
Exemplo n.º 3
0
def test_Uint8ToFloatRange():
    """``Uint8ToFloatRange`` should divide its input by 255."""
    value = torch.tensor(255.0)
    converted = transforms.Uint8ToFloatRange()(value)
    ptu.assert_allclose(converted, value / 255.0)
Exemplo n.º 4
0
def test_FloatToUint8Range():
    """``FloatToUint8Range`` should multiply its input by 255."""
    value = torch.tensor(1.0)
    converted = transforms.FloatToUint8Range()(value)
    ptu.assert_allclose(converted, value * 255.0)
Exemplo n.º 5
0
def test_nst_smoke(subtests, default_image_pyramid_optim_loop_patch,
                   content_image, style_image):
    """Smoke-test ``paper.nst`` by inspecting the arguments it hands to the mock.

    The mock comes from the ``default_image_pyramid_optim_loop_patch`` fixture;
    each argument is checked against the corresponding ``paper`` factory.
    """
    paper.nst(content_image, style_image)

    call_args, call_kwargs = default_image_pyramid_optim_loop_patch.call_args
    input_image, criterion, image_pyramid = call_args
    get_optimizer, preprocessor, postprocessor = (
        call_kwargs[key]
        for key in ("get_optimizer", "preprocessor", "postprocessor")
    )

    with subtests.test("input_image"):
        resized_content = image_pyramid[-1].resize_image(content_image)
        ptu.assert_allclose(input_image, resized_content)

    with subtests.test("criterion"):
        expected_type = type(paper.perceptual_loss())
        assert isinstance(criterion, expected_type)

    with subtests.test("image_pyramid"):
        expected_type = type(paper.image_pyramid())
        assert isinstance(image_pyramid, expected_type)

    with subtests.test("optimizer"):
        assert is_callable(get_optimizer)
        built_optimizer = get_optimizer(input_image)
        expected_type = type(paper.optimizer(input_image))
        assert isinstance(built_optimizer, expected_type)

    with subtests.test("preprocessor"):
        expected_type = type(paper.preprocessor())
        assert isinstance(preprocessor, expected_type)

    with subtests.test("postprocessor"):
        expected_type = type(paper.postprocessor())
        assert isinstance(postprocessor, expected_type)
Exemplo n.º 6
0
def test_stylization_smoke(stylization, postprocessor_mocks, input_image):
    """Smoke-test ``stylization`` against a resized, shifted copy of the input.

    ``postprocessor_mocks`` is requested only for its side effect. Presumably it
    replaces postprocessing so the output is predictable; verify against the
    fixture.
    """
    size = paper.hyper_parameters().content_transform.edge_size
    _, _, stylized = stylization(input_image)
    expected = F.resize(input_image, (size, size)) + 0.5
    ptu.assert_allclose(stylized, expected, rtol=1e-6)
Exemplo n.º 7
0
    def test_call_guided(self, encoder):
        """Guided ``MRFLoss`` should only compare the regions selected by the guides.

        The target guide is zero in the top 16 rows and one in the bottom 16. The
        input guide is its vertical flip. The expected loss is the MRF loss
        between patches of rows ``:16`` of the input encoding and rows ``16:`` of
        the target encoding. This assumes the ``encoder`` fixture keeps the 32x32
        resolution; TODO confirm.
        """
        patch_size = stride = 2
        rows = 16

        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)
        upper = torch.zeros(1, 1, rows, 32)
        lower = torch.ones(1, 1, rows, 32)
        target_guide = torch.cat((upper, lower), dim=2)
        input_guide = target_guide.flip(2)

        mrf_loss = loss_.MRFLoss(encoder, patch_size, stride=stride)
        mrf_loss.set_target_image(target_image, guide=target_guide)
        mrf_loss.set_input_guide(input_guide)

        actual = mrf_loss(input_image)

        input_rows = encoder(input_image)[:, :, :rows, :]
        target_rows = encoder(target_image)[:, :, rows:, :]

        def to_patches(enc):
            return pystiche.extract_patches2d(enc, patch_size, stride=stride)

        desired = F.mrf_loss(
            to_patches(input_rows),
            to_patches(target_rows),
            batched_input=True,
        )
        ptu.assert_allclose(actual, desired)
Exemplo n.º 8
0
def test_content_transform_grayscale_image(subtests, content_image,
                                           impl_params, instance_norm):
    """A grayscale content image should be transformed and repeated to 3 channels.

    The reference is a resize to ``edge_size`` when ``impl_params`` is set and a
    center crop otherwise. The combination ``impl_params`` + ``instance_norm``
    is not compared against a reference.
    """
    gray_image = F.rgb_to_grayscale(content_image)
    edge_size = 16

    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)
    hyper_parameters.content_transform.edge_size = edge_size

    transform_under_test = paper.content_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )
    if instance_norm:
        utils.make_reproducible()
    actual = transform_under_test(gray_image)

    if impl_params and instance_norm:
        # Since the transform involves an uncontrollable random component, there
        # is no deterministic reference to compare against.
        return

    if impl_params:
        reference = F.resize(gray_image, edge_size)
    else:
        reference = transforms.CenterCrop(edge_size)(gray_image)

    ptu.assert_allclose(actual, reference.repeat(1, 3, 1, 1))
Exemplo n.º 9
0
def test_MRFOperator_call_guided():
    """Guided ``MRFOperator`` should only compare the regions selected by the guides.

    The target guide is one in the bottom 16 rows and the input guide, its
    vertical flip, is one in the top 16 rows. A 1x1 convolution keeps the 32x32
    resolution, so the expected value is the MRF loss between patches of rows
    ``:16`` of the input encoding and rows ``16:`` of the target encoding.
    """
    patch_size = stride = 2
    rows = 16

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    upper = torch.zeros(1, 1, rows, 32)
    lower = torch.ones(1, 1, rows, 32)
    target_guide = torch.cat((upper, lower), dim=2)
    input_guide = target_guide.flip(2)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    mrf_op = ops.MRFOperator(encoder, patch_size, stride=stride)
    mrf_op.set_target_guide(target_guide)
    mrf_op.set_target_image(target_image)
    mrf_op.set_input_guide(input_guide)

    actual = mrf_op(input_image)

    input_rows = encoder(input_image)[:, :, :rows, :]
    target_rows = encoder(target_image)[:, :, rows:, :]

    def to_patches(enc_):
        return pystiche.extract_patches2d(enc_, patch_size, stride=stride)

    desired = F.mrf_loss(to_patches(input_rows), to_patches(target_rows))
    ptu.assert_allclose(actual, desired)
Exemplo n.º 10
0
def test_MultiLayerEncoder_encode(forward_pass_counter):
    """After ``encode`` the encoder call must not trigger another forward pass.

    A forward-pass counter sits in front of ``conv`` and ``relu``. Both
    registered layers must yield the expected encodings, yet the counter must
    read exactly 1. Presumably the second call is served from a cache filled by
    ``encode``.
    """
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 1, 1)
    relu = nn.ReLU(inplace=False)
    image = torch.rand(1, 3, 128, 128)

    encoder = enc.MultiLayerEncoder(
        (("count", forward_pass_counter), ("conv", conv), ("relu", relu))
    )

    layer_names = ("conv", "relu")
    encoder.registered_layers.update(layer_names)
    encoder.encode(image)
    encodings = encoder(image, layer_names)

    references = (conv, lambda x: relu(conv(x)))
    for idx, reference in enumerate(references):
        ptu.assert_allclose(encodings[idx], reference(image))

    assert forward_pass_counter.count == 1
Exemplo n.º 11
0
def test_PixelComparisonOperator_call_guided():
    """A guided pixel operator must apply the guides before computing the score.

    The toy operator shifts the target by +1 and the input by +2 and multiplies
    both representations. The expected score applies each guide to its image
    first.
    """
    class TestOperator(ops.PixelComparisonOperator):
        def target_image_to_repr(self, image):
            return image + 1.0, None

        def input_image_to_repr(self, image, ctx):
            return image + 2.0

        def calculate_score(self, input_repr, target_repr, ctx):
            return input_repr * target_repr

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)

    test_op = TestOperator()
    test_op.set_target_guide(target_guide)
    test_op.set_target_image(target_image)
    test_op.set_input_guide(input_guide)

    actual = test_op(input_image)

    guided_target = TestOperator.apply_guide(target_image, target_guide)
    guided_input = TestOperator.apply_guide(input_image, input_guide)
    desired = (guided_target + 1.0) * (guided_input + 2.0)
    ptu.assert_allclose(actual, desired)
Exemplo n.º 12
0
def test_segmentation_to_guides():
    """``segmentation_to_guides`` should reproduce the reference guides."""
    expected_guides, _ = get_test_guides()
    segmentation, region_map = get_test_segmentation()

    converted = image.segmentation_to_guides(segmentation, region_map=region_map)
    ptu.assert_allclose(converted, expected_guides)
Exemplo n.º 13
0
def test_PixelComparisonOperator_set_target_guide():
    """Setting a target guide must flag it and store the guide alongside the image."""
    class TestOperator(ops.PixelComparisonOperator):
        def target_image_to_repr(self, image):
            return image * 2.0, torch.norm(image)

        def input_image_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)

    test_op = TestOperator()
    test_op.set_target_image(image)
    assert not test_op.has_target_guide

    test_op.set_target_guide(guide)
    assert test_op.has_target_guide

    ptu.assert_allclose(test_op.target_guide, guide)
    ptu.assert_allclose(test_op.target_image, image)
Exemplo n.º 14
0
def test_EncodingComparisonOperator_call_guided():
    """A guided encoding operator must apply propagated guides to the encodings.

    The toy operator shifts the target encoding by +1 and the input encoding by
    +2 and multiplies them. The guides are propagated through the encoder and
    applied to the encodings before the shifts.
    """
    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return image + 1.0, None

        def input_enc_to_repr(self, image, ctx):
            return image + 2.0

        def calculate_score(self, input_repr, target_repr, ctx):
            return input_repr * target_repr

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    target_enc_guide = encoder.propagate_guide(target_guide)
    input_enc_guide = encoder.propagate_guide(input_guide)

    test_op = TestOperator(encoder)
    test_op.set_target_guide(target_guide)
    test_op.set_target_image(target_image)
    test_op.set_input_guide(input_guide)

    actual = test_op(input_image)

    guided_target_enc = TestOperator.apply_guide(encoder(target_image), target_enc_guide)
    guided_input_enc = TestOperator.apply_guide(encoder(input_image), input_enc_guide)
    desired = (guided_target_enc + 1.0) * (guided_input_enc + 2.0)
    ptu.assert_allclose(actual, desired)
Exemplo n.º 15
0
def test_guides_to_segmentation():
    """``guides_to_segmentation`` should reproduce the reference segmentation."""
    guides, color_map = get_test_guides()
    expected_segmentation, _ = get_test_segmentation()

    converted = image.guides_to_segmentation(guides, color_map=color_map)
    ptu.assert_allclose(converted, expected_segmentation)
Exemplo n.º 16
0
def test_AutoPadAvgPool1d_count_include_pad():
    """The auto-padding mixin must honour ``count_include_pad=False``.

    A 1D pooling layer built from ``_AutoPadAvgPoolNdMixin`` is compared against
    ``nn.AvgPool1d`` with explicit padding ``(kernel_size - 1) // 2``.
    """
    kernel_size = 5

    from pystiche.misc import to_1d_arg
    from pystiche_papers.utils.modules import _AutoPadAvgPoolNdMixin

    class AutoPadAvgPool1d(_AutoPadAvgPoolNdMixin, nn.AvgPool1d):
        def __init__(self, kernel_size, stride=None, **kwargs) -> None:
            kernel_size = to_1d_arg(kernel_size)
            # Default the stride to the kernel size, like the pooling layers do.
            if stride is None:
                stride = kernel_size
            else:
                stride = to_1d_arg(stride)
            super().__init__(kernel_size, stride=stride, **kwargs)

    torch.manual_seed(0)
    signal = torch.rand(1, 1, 32)

    explicit_padding = (kernel_size - 1) // 2
    manual = nn.AvgPool1d(
        kernel_size, stride=1, padding=explicit_padding, count_include_pad=False
    )
    auto = AutoPadAvgPool1d(kernel_size, stride=1, count_include_pad=False)

    ptu.assert_allclose(auto(signal), manual(signal), rtol=1e-6)
Exemplo n.º 17
0
    def test_backward(self):
        """``LossDict.backward`` should yield the same gradients as summing the losses."""
        losses = [
            torch.tensor(val, dtype=torch.float, requires_grad=True)
            for val in range(3)
        ]

        def gradients_after(run_backward):
            # Clear stale gradients, run the backward pass, snapshot the result.
            for loss in losses:
                loss.grad = None
            run_backward()
            return [loss.grad.clone() for loss in losses]

        loss_dict = pystiche.LossDict(
            [(str(idx), loss) for idx, loss in enumerate(losses)]
        )
        actuals = gradients_after(loss_dict.backward)
        desireds = gradients_after(lambda: sum(losses).backward())

        for actual, desired in zip(actuals, desireds):
            ptu.assert_allclose(actual, desired)
Exemplo n.º 18
0
def test_nst_smoke(subtests, mocker, content_image, style_image):
    """Smoke-test ``paper.nst`` via its call to ``pystiche.optim.image_optimization``.

    The optimization entry point is replaced by a mock. The positional and
    keyword arguments it received are checked against the ``paper`` factories
    and hyper-parameters.
    """
    image_optimization = mocker.patch("pystiche.optim.image_optimization")

    paper.nst(content_image, style_image)

    (input_image, criterion), call_kwargs = image_optimization.call_args
    optimizer, num_steps, preprocessor, postprocessor = (
        call_kwargs[key]
        for key in ("optimizer", "num_steps", "preprocessor", "postprocessor")
    )

    nst_hyper_parameters = paper.hyper_parameters().nst

    with subtests.test("input_image"):
        ptu.assert_allclose(input_image, content_image)

    with subtests.test("criterion"):
        expected_type = type(paper.perceptual_loss())
        assert isinstance(criterion, expected_type)

    with subtests.test("optimizer"):
        assert optimizer is paper.optimizer

    with subtests.test("num_steps"):
        assert num_steps == nst_hyper_parameters.num_steps

    with subtests.test("preprocessor"):
        expected_type = type(paper.preprocessor())
        assert isinstance(preprocessor, expected_type)

    with subtests.test("postprocessor"):
        expected_type = type(paper.postprocessor())
        assert isinstance(postprocessor, expected_type)
Exemplo n.º 19
0
    def test_set_target_guide_without_recalc(self):
        """With ``recalc_repr=False`` setting a guide must leave ``target_repr`` intact."""
        class TestOperator(ops.PixelComparisonOperator):
            def target_image_to_repr(self, image):
                return image * 2.0, torch.norm(image)

            def input_image_to_repr(self, image, ctx):
                return None

            def calculate_score(self, input_repr, target_repr, ctx):
                return None

        torch.manual_seed(0)
        image = torch.rand(1, 3, 32, 32)
        guide = torch.rand(1, 1, 32, 32)

        test_op = TestOperator()
        test_op.set_target_image(image)
        repr_before_guide = test_op.target_repr.clone()

        test_op.set_target_guide(guide, recalc_repr=False)

        ptu.assert_allclose(test_op.target_repr, repr_before_guide)
Exemplo n.º 20
0
    def test_read_rel_path(self, test_image_file, test_image):
        """A ``LocalImage`` given only a file name should be read relative to ``root``."""
        root = path.dirname(test_image_file)
        local_image = data.LocalImage(path.basename(test_image_file))

        ptu.assert_allclose(local_image.read(root), test_image)
Exemplo n.º 21
0
def test_EncodingComparisonOperator_set_target_guide_without_recalc():
    """With ``recalc_repr=False`` a pre-registered ``target_repr`` must survive a guide."""
    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    stored_repr = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    test_op = TestOperator(encoder)
    test_op.register_buffer("target_repr", stored_repr)
    test_op.set_target_guide(guide, recalc_repr=False)

    ptu.assert_allclose(test_op.target_repr, stored_repr)
Exemplo n.º 22
0
def test_CaffePostprocessing(pixel, caffe_mean, caffe_std):
    """``CaffePostprocessing`` should recover ``pixel`` from its Caffe-style form.

    The input is built by normalizing ``pixel`` with the Caffe mean/std, scaling
    by 255 and flipping dim 1 (presumably the channel axis, i.e. RGB -> BGR).
    """
    processing = enc.CaffePostprocessing()

    normalized = pixel.sub(caffe_mean).div(caffe_std)
    caffe_input = normalized.mul(255).flip(1)

    ptu.assert_allclose(processing(caffe_input), pixel, rtol=1e-6)
Exemplo n.º 23
0
    def test_main(self, tmpdir, test_image_url, test_image):
        """Download a one-image collection into ``tmpdir`` and read the image back."""
        collection = data.DownloadableImageCollection(
            {"test_image": data.DownloadableImage(test_image_url)}
        )
        collection.download(root=tmpdir)

        downloaded = collection["test_image"].read(root=tmpdir)
        ptu.assert_allclose(downloaded, test_image)
Exemplo n.º 24
0
    def test_enc_to_repr_guided_varying(self, mrf_op):
        """Guided ``enc_to_repr`` should return the 2D patches of the encoding.

        The expected value uses the operator's own first ``patch_size`` and
        ``stride`` entries with ``pystiche.extract_patches2d``.
        """
        # Seed like the other tests so a failing input can be reproduced.
        torch.manual_seed(0)
        enc_ = torch.rand(1, 4, 8, 8)

        actual = mrf_op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(
            enc_, mrf_op.patch_size[0], stride=mrf_op.stride[0]
        )
        ptu.assert_allclose(actual, desired)
Exemplo n.º 25
0
def test_SequentialWithOutChannels_forward_behaviour(input_image):
    """``SequentialWithOutChannels`` should apply its modules in order."""
    first, second = nn.Conv2d(3, 3, 1), nn.Conv2d(3, 5, 1)
    sequential = utils.SequentialWithOutChannels(first, second)

    actual = sequential(input_image)
    ptu.assert_allclose(actual, second(first(input_image)))
Exemplo n.º 26
0
def test_Operator_apply_guide():
    """``Operator.apply_guide`` should multiply the image by the guide.

    The single-channel guide is broadcast over the image channels.
    """
    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)

    guided = ops.Operator.apply_guide(image, guide)
    ptu.assert_allclose(guided, image * guide)
Exemplo n.º 27
0
def test_TorchPreprocessing(pixel, torch_mean, torch_std):
    """``TorchPreprocessing`` should undo ``pixel * std + mean``, recovering ``pixel``."""
    processing = enc.TorchPreprocessing()
    denormalized = pixel.mul(torch_std).add(torch_mean)
    ptu.assert_allclose(processing(denormalized), pixel, rtol=1e-6)
Exemplo n.º 28
0
def test_TorchPostprocessing(pixel, torch_mean, torch_std):
    """``TorchPostprocessing`` should undo ``(pixel - mean) / std``, recovering ``pixel``."""
    processing = enc.TorchPostprocessing()
    normalized = pixel.sub(torch_mean).div(torch_std)
    ptu.assert_allclose(processing(normalized), pixel, rtol=1e-6)
Exemplo n.º 29
0
    def test_read(self, tmpdir, test_image_url, test_image):
        """Reading a collection should return every image keyed by its name."""
        names = [str(idx) for idx in range(3)]
        collection = data.DownloadableImageCollection(
            {name: data.DownloadableImage(test_image_url) for name in names}
        )
        images = collection.read(tmpdir)

        expected = {name: test_image for name in names}
        ptu.assert_allclose(images, expected)
Exemplo n.º 30
0
def test_AutoPadConvTranspose2d_state_dict():
    """``AutoPadConvTranspose2d`` should load and round-trip an ``nn.ConvTranspose2d`` state dict."""
    conv_kwargs = {"in_channels": 1, "out_channels": 2, "kernel_size": 3, "bias": True}
    reference_conv = nn.ConvTranspose2d(**conv_kwargs)
    auto_pad_conv = utils.AutoPadConvTranspose2d(**conv_kwargs)

    reference_state = reference_conv.state_dict()
    auto_pad_conv.load_state_dict(reference_state)

    ptu.assert_allclose(auto_pad_conv.state_dict(), reference_state)