def test_join_channelwise(subtests, image_small_0, image_small_1):
    """Check that ``join_channelwise`` concatenates two images along dim 1.

    The joined tensor must have as many channels as both inputs combined,
    with the first input in the leading channels and the second input in
    the trailing channels.
    """
    joined = join_channelwise(image_small_0, image_small_1)
    assert isinstance(joined, torch.Tensor)

    split = image.extract_num_channels(image_small_0)
    joined_channels = image.extract_num_channels(joined)
    assert joined_channels == split + image.extract_num_channels(image_small_1)

    # Leading channels belong to the first input, the rest to the second.
    head = joined[:, :split, :, :]
    tail = joined[:, split:, :, :]
    ptu.assert_allclose(head, image_small_0)
    ptu.assert_allclose(tail, image_small_1)
def test_JoinBlock(subtests, input_image):
    """Check ``paper.JoinBlock`` attributes and its forward pass.

    Every combination of ``instance_norm`` (True / False) and ``names``
    (explicit names / ``None``) is exercised. The forward output is compared
    against each branch input normalized with the matching functional norm
    (``F.instance_norm`` or ``F.batch_norm`` with batch statistics) and
    concatenated along ``channel_dim``.
    """
    input_image1 = input_image
    # The second branch has twice the channels of the first.
    input_image2 = torch.cat((input_image, input_image), 1)
    branch_in_channels = (
        image.extract_num_channels(input_image1),
        image.extract_num_channels(input_image2),
    )
    channel_dim = 1

    for instance_norm, names in itertools.product(
        (True, False), (("block1", "block2"), None)
    ):
        block = paper.JoinBlock(
            branch_in_channels,
            names=names,
            instance_norm=instance_norm,
            channel_dim=channel_dim,
        )
        # Tag every subtest with the parameter combination so failures from
        # different loop iterations can be told apart.
        params = dict(instance_norm=instance_norm, names=names)

        with subtests.test("norm_modules", **params):
            assert len(block.norm_modules) == len(branch_in_channels)
            norm_type = nn.InstanceNorm2d if instance_norm else nn.BatchNorm2d
            # Every branch must use the requested norm type; `any` would let a
            # block with only one correct norm module slip through.
            assert all(
                isinstance(norm_module, norm_type)
                for norm_module in block.norm_modules
            )

        with subtests.test("out_channels", **params):
            assert block.out_channels == sum(branch_in_channels)

        with subtests.test("channel_dim", **params):
            assert block.channel_dim == channel_dim

        with subtests.test("forward", **params):
            inputs = (input_image1, input_image2)
            actual = block(*inputs)
            assert isinstance(actual, torch.Tensor)

            momentum = block.norm_modules[0].momentum
            # `branch_input` is deliberately not named `input_image` to avoid
            # shadowing the fixture argument.
            desired_inputs = tuple(
                F.instance_norm(branch_input, momentum=momentum)
                if instance_norm
                else F.batch_norm(
                    branch_input,
                    torch.zeros(image.extract_num_channels(branch_input)),
                    torch.ones(image.extract_num_channels(branch_input)),
                    training=True,
                    momentum=momentum,
                )
                for branch_input in inputs
            )
            desired = torch.cat(desired_inputs, channel_dim)
            ptu.assert_allclose(actual, desired)
def test_extract_num_channels():
    """Check ``extract_num_channels`` on an unbatched and a batched image.

    The unbatched image is ``(C, 1, 1)``; the batched one adds a leading
    dimension of size 1. Both must report ``C`` channels.
    """
    expected_channels = 3
    unbatched = torch.zeros(expected_channels, 1, 1)
    batched = unbatched.unsqueeze(0)

    for candidate in (unbatched, batched):
        assert image_.extract_num_channels(candidate) == expected_channels
def forward(self, input_image: torch.Tensor) -> torch.Tensor:
    """Expand a single-channel image to three channels; pass others through.

    If ``image.extract_num_channels`` reports exactly one channel, the tensor
    is repeated three times along dim -3 and left unchanged along every other
    dim. Any other input is returned as-is.
    """
    if image.extract_num_channels(input_image) != 1:
        return input_image
    # NOTE(review): assumes dim -3 is the channel dim, e.g. (C, H, W) or
    # (B, C, H, W) — confirm against image.extract_num_channels.
    tile = [1 for _ in range(input_image.ndim)]
    tile[-3] = 3
    return input_image.repeat(tile)
def test_AddNoiseChannels(subtests, input_image):
    """Check ``AddNoiseChannels`` attributes and the channel count it outputs.

    The module is built with one more noise channel than the input has
    channels. Its output must have the input channels plus the noise channels.
    """
    in_channels = image.extract_num_channels(input_image)
    num_noise_channels = in_channels + 1
    module = AddNoiseChannels(in_channels, num_noise_channels=num_noise_channels)
    assert isinstance(module, nn.Module)

    with subtests.test("in_channels"):
        assert module.in_channels == in_channels

    expected_out_channels = in_channels + num_noise_channels

    with subtests.test("out_channels"):
        assert module.out_channels == expected_out_channels

    with subtests.test("forward"):
        noisy_image = module(input_image)
        assert image.extract_num_channels(noisy_image) == expected_out_channels
def test_AutoPadConvTranspose2d(subtests, auto_pad_conv_params, input_image):
    """Check that ``utils.AutoPadConvTranspose2d`` multiplies the image size by the stride.

    For every parameter set in ``auto_pad_conv_params``, the output image size
    must equal the input image size multiplied element-wise by the 2D stride.
    """
    channels = extract_num_channels(input_image)
    input_size = extract_image_size(input_image)

    for conv_kwargs in auto_pad_conv_params:
        with subtests.test(**conv_kwargs):
            conv = utils.AutoPadConvTranspose2d(channels, channels, **conv_kwargs)
            output_size = extract_image_size(conv(input_image))

            strides = to_2d_arg(conv_kwargs["stride"])
            expected = tuple(
                length * step for length, step in zip(input_size, strides)
            )
            assert output_size == expected
def forward(self, input_image: torch.Tensor) -> torch.Tensor:
    """Convert a single-channel image to fake grayscale; pass others through.

    Images for which ``extract_num_channels`` reports exactly one channel are
    handed to ``grayscale_to_fakegrayscale``. Any other input is returned
    unchanged.
    """
    if extract_num_channels(input_image) != 1:
        return input_image
    return grayscale_to_fakegrayscale(input_image)