Exemplo n.º 1
0
def test_join_channelwise(subtests, image_small_0, image_small_1):
    """Check that ``join_channelwise`` stacks both images along the channel axis, in order."""
    joined = join_channelwise(image_small_0, image_small_1)
    assert isinstance(joined, torch.Tensor)

    channels_first = image.extract_num_channels(image_small_0)
    channels_second = image.extract_num_channels(image_small_1)
    assert image.extract_num_channels(joined) == channels_first + channels_second

    # Leading channels must come from the first image, trailing ones from the second.
    leading = joined[:, :channels_first, :, :]
    trailing = joined[:, channels_first:, :, :]
    ptu.assert_allclose(leading, image_small_0)
    ptu.assert_allclose(trailing, image_small_1)
Exemplo n.º 2
0
def test_JoinBlock(subtests, input_image):
    """Check ``paper.JoinBlock`` for every combination of norm type and branch naming.

    Two branches are joined: the ``input_image`` fixture and a copy of it with its
    channels duplicated (concatenated along dim 1), so the branches differ in
    channel count.
    """
    input_image1 = input_image
    input_image2 = torch.cat((input_image, input_image), 1)
    branch_in_channels = (
        image.extract_num_channels(input_image1),
        image.extract_num_channels(input_image2),
    )
    channel_dim = 1

    for instance_norm, names in itertools.product(
        (True, False), (("block1", "block2"), None)
    ):
        block = paper.JoinBlock(
            branch_in_channels,
            names=names,
            instance_norm=instance_norm,
            channel_dim=channel_dim,
        )
        # Tag each subtest with the current combination; otherwise all four loop
        # iterations are reported under the same subtest name.
        params = dict(instance_norm=instance_norm, names=names)
        norm_type = nn.InstanceNorm2d if instance_norm else nn.BatchNorm2d

        with subtests.test("norm_modules", **params):
            assert len(block.norm_modules) == len(branch_in_channels)
            # Every branch has to use the requested norm type, not just one of them.
            assert all(
                isinstance(norm_module, norm_type) for norm_module in block.norm_modules
            )

        with subtests.test("out_channels", **params):
            assert block.out_channels == sum(branch_in_channels)

        with subtests.test("channel_dim", **params):
            assert block.channel_dim == channel_dim

        with subtests.test("forward", **params):
            inputs = (input_image1, input_image2)
            actual = block(*inputs)
            momentum = block.norm_modules[0].momentum
            assert isinstance(actual, torch.Tensor)
            # Reference: normalize each branch independently with the functional API,
            # then concatenate along the configured channel dimension.
            desired_inputs = tuple(
                F.instance_norm(branch_input, momentum=momentum)
                if instance_norm
                else F.batch_norm(
                    branch_input,
                    torch.zeros(image.extract_num_channels(branch_input)),
                    torch.ones(image.extract_num_channels(branch_input)),
                    training=True,
                    momentum=momentum,
                )
                for branch_input in inputs
            )

            desired = torch.cat(desired_inputs, channel_dim)
            ptu.assert_allclose(actual, desired)
Exemplo n.º 3
0
def test_extract_num_channels():
    """``extract_num_channels`` reports the channel count for unbatched and batched images."""
    num_channels = 3

    unbatched = torch.zeros(num_channels, 1, 1)
    assert image_.extract_num_channels(unbatched) == num_channels

    # Adding a leading batch dimension must not change the reported channel count.
    batched = unbatched.unsqueeze(0)
    assert image_.extract_num_channels(batched) == num_channels
Exemplo n.º 4
0
    def forward(self, input_image: torch.Tensor) -> torch.Tensor:
        """Tile a single-channel image to three channels; return other images unchanged.

        The tiling is applied to the third-from-last dimension (index ``-3``).
        """
        if image.extract_num_channels(input_image) == 1:
            tile = [1] * input_image.ndim
            tile[-3] = 3
            return input_image.repeat(tile)

        return input_image
Exemplo n.º 5
0
def test_AddNoiseChannels(subtests, input_image):
    """``AddNoiseChannels`` appends ``num_noise_channels`` extra channels to its input."""
    in_channels = image.extract_num_channels(input_image)
    # Pick a noise-channel count that differs from the input's channel count.
    num_noise_channels = in_channels + 1
    expected_out_channels = in_channels + num_noise_channels

    module = AddNoiseChannels(in_channels, num_noise_channels=num_noise_channels)
    assert isinstance(module, nn.Module)

    with subtests.test("in_channels"):
        assert module.in_channels == in_channels

    with subtests.test("out_channels"):
        assert module.out_channels == expected_out_channels

    with subtests.test("forward"):
        noisy = module(input_image)
        assert image.extract_num_channels(noisy) == expected_out_channels
Exemplo n.º 6
0
def test_AutoPadConvTranspose2d(subtests, auto_pad_conv_params, input_image):
    """A transposed auto-pad convolution multiplies each spatial side by its stride."""
    num_channels = extract_num_channels(input_image)
    in_size = extract_image_size(input_image)

    for params in auto_pad_conv_params:
        with subtests.test(**params):
            conv = utils.AutoPadConvTranspose2d(num_channels, num_channels, **params)
            out_size = extract_image_size(conv(input_image))

            strides = to_2d_arg(params["stride"])
            expected = tuple(side * step for side, step in zip(in_size, strides))
            assert out_size == expected
Exemplo n.º 7
0
 def forward(self, input_image: torch.Tensor) -> torch.Tensor:
     """Convert single-channel images via ``grayscale_to_fakegrayscale``; pass others through."""
     if extract_num_channels(input_image) != 1:
         return input_image
     return grayscale_to_fakegrayscale(input_image)