def test_extract_normalized_patches2d_no_overlap(subtests):
    """Normalized patch extraction is a no-op when patches do not overlap.

    With a 4x4 input, ``patch_size == stride == 2`` places every pixel in
    exactly one patch, so both the extracted patches and the gradients must
    match those of the plain :func:`pystiche.extract_patches2d`.
    """
    height = width = 4
    patch_size = stride = 2

    plain = torch.ones(1, 1, height, width).requires_grad_(True)
    normalized = torch.ones(1, 1, height, width).requires_grad_(True)
    target = torch.zeros(1, 1, height, width).detach()

    plain_patches = pystiche.extract_patches2d(
        plain, patch_size=patch_size, stride=stride
    )
    normalized_patches = paper.extract_normalized_patches2d(
        normalized, patch_size=patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        target, patch_size=patch_size, stride=stride
    )

    # Identical squared-error losses so the two gradients can be compared.
    plain_loss = 0.5 * torch.sum((plain_patches - target_patches) ** 2.0)
    plain_loss.backward()
    normalized_loss = 0.5 * torch.sum((normalized_patches - target_patches) ** 2.0)
    normalized_loss.backward()

    with subtests.test("forward"):
        ptu.assert_allclose(normalized_patches, plain_patches)

    with subtests.test("backward"):
        ptu.assert_allclose(normalized.grad, plain.grad)
def test_MRFOperator_call_guided():
    """A guided ``MRFOperator`` only compares the regions selected by the guides.

    The target guide is one on the lower half of the image. The input guide is
    its vertical flip, so it is one on the upper half. The expected loss is
    therefore the MRF loss between the patches of those two encoded halves.
    """
    patch_size = 2
    stride = 2

    # Keep the RNG order: both images first, then the Conv2d initialization.
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    zeros_half = torch.zeros(1, 1, 16, 32)
    ones_half = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((zeros_half, ones_half), dim=2)
    input_guide = torch.flip(target_guide, dims=(2,))

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)
    actual = op(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image)[..., :16, :], patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image)[..., 16:, :], patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches)
    ptu.assert_allclose(actual, desired)
def test_call_guided(self, encoder):
    """A guided ``MRFLoss`` only compares the regions selected by the guides.

    The target guide is one on the lower half and the flipped input guide is
    one on the upper half. The expected value is the batched MRF loss between
    patches of the encoded upper input half and the encoded lower target half.
    """
    patch_size = 2
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    zeros_half = torch.zeros(1, 1, 16, 32)
    ones_half = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((zeros_half, ones_half), dim=2)
    input_guide = torch.flip(target_guide, dims=(2,))

    criterion = loss_.MRFLoss(encoder, patch_size, stride=stride)
    criterion.set_target_image(target_image, guide=target_guide)
    criterion.set_input_guide(input_guide)
    actual = criterion(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image)[..., :16, :], patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image)[..., 16:, :], patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches, batched_input=True)
    ptu.assert_allclose(actual, desired)
def test_MRFOperator_enc_to_repr_guided(self):
    """Check which patches ``MRFOperator.enc_to_repr`` keeps when guided.

    The subtests encode the expectation that, with ``is_guided=True``:

    * an entirely constant encoding yields no patches at all,
    * constant spatial regions are dropped while varying regions are kept,
    * encodings that mix constant and varying channels keep every patch,
    * fully varying encodings keep every patch.
    """

    class Identity(pystiche.Module):
        # Pass-through encoder so enc_to_repr sees the raw test tensors.
        def forward(self, image):
            return image

    patch_size = 2
    stride = 2
    # Seed the RNG so the random "varying" encodings are reproducible.
    torch.manual_seed(0)

    op = ops.MRFOperator(
        SequentialEncoder((Identity(),)), patch_size, stride=stride
    )

    with self.subTest(enc="constant"):
        enc = torch.ones(1, 4, 8, 8)
        actual = op.enc_to_repr(enc, is_guided=True)
        # Every patch is removed. The trailing dimensions are still those of a
        # single patch, so they are patch_size (not stride) wide.
        desired = torch.ones(0, 4, patch_size, patch_size)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="spatial_mix"):
        constant = torch.ones(1, 4, 4, 8)
        varying = torch.rand(1, 4, 4, 8)
        enc = torch.cat((constant, varying), dim=2)
        actual = op.enc_to_repr(enc, is_guided=True)
        desired = pystiche.extract_patches2d(varying, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="channel_mix"):
        constant = torch.ones(1, 2, 8, 8)
        varying = torch.rand(1, 2, 8, 8)
        enc = torch.cat((constant, varying), dim=1)
        actual = op.enc_to_repr(enc, is_guided=True)
        desired = pystiche.extract_patches2d(enc, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="varying"):
        enc = torch.rand(1, 4, 8, 8)
        actual = op.enc_to_repr(enc, is_guided=True)
        desired = pystiche.extract_patches2d(enc, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)
def test_enc_to_repr_guided_varying(self, mrf_op):
    """A fully random encoding keeps all of its patches when guided."""
    encoding = torch.rand(1, 4, 8, 8)
    actual = mrf_op.enc_to_repr(encoding, is_guided=True)

    patch_size = mrf_op.patch_size[0]
    stride = mrf_op.stride[0]
    desired = pystiche.extract_patches2d(encoding, patch_size, stride=stride)
    ptu.assert_allclose(actual, desired)
def enc_to_repr(self, enc: torch.Tensor, is_guided: bool) -> torch.Tensor:
    """Convert an encoding into its patch representation.

    If ``self.normalize_patches_grad`` is set, the patches are extracted with
    ``extract_normalized_patches2d``. Otherwise ``pystiche.extract_patches2d``
    is used. For guided encodings the patches are additionally passed through
    ``self._guide_repr``.

    Args:
        enc: Encoded image.
        is_guided: Whether the encoding belongs to a guided image.
    """
    extract = (
        extract_normalized_patches2d
        if self.normalize_patches_grad
        else pystiche.extract_patches2d
    )
    patches = extract(enc, self.patch_size, self.stride)
    return self._guide_repr(patches) if is_guided else patches
def test_enc_to_repr_guided_channel_mix(self, mrf_op):
    """Mixing constant and random channels keeps all patches when guided."""
    # Channels 0-1 are constant ones, channels 2-3 are random.
    encoding = torch.cat(
        (torch.ones(1, 2, 8, 8), torch.rand(1, 2, 8, 8)), dim=1
    )
    actual = mrf_op.enc_to_repr(encoding, is_guided=True)

    patch_size = mrf_op.patch_size[0]
    stride = mrf_op.stride[0]
    desired = pystiche.extract_patches2d(encoding, patch_size, stride=stride)
    ptu.assert_allclose(actual, desired)
def test_call(self):
    """An unguided ``MRFOperator`` equals ``F.mrf_loss`` of the encoded patches."""
    patch_size, stride = 3, 2

    # Keep the RNG order: both images first, then the Conv2d initialization.
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)
    actual = op(input_image)

    def patches_of(image):
        # Patches of the encoded image, built the same way for input and target.
        return pystiche.extract_patches2d(
            encoder(image), patch_size, stride=stride
        )

    desired = F.mrf_loss(patches_of(input_image), patches_of(target_image))
    ptu.assert_allclose(actual, desired)
def test_MRFOperator_call(self):
    """An unguided ``MRFOperator`` equals ``F.patch_matching_loss`` of the patches."""
    patch_size, stride = 3, 2

    # Keep the RNG order: both images first, then the Conv2d initialization.
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)
    actual = op(input_image)

    def patches_of(image):
        # Patches of the encoded image, built the same way for input and target.
        return pystiche.extract_patches2d(
            encoder(image), patch_size, stride=stride
        )

    desired = F.patch_matching_loss(
        patches_of(input_image), patches_of(target_image)
    )
    self.assertFloatAlmostEqual(actual, desired)
def test_call(self, encoder):
    """An unguided ``MRFLoss`` equals the batched ``F.mrf_loss`` of the patches."""
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    criterion = loss_.MRFLoss(encoder, patch_size, stride=stride)
    criterion.set_target_image(target_image)
    actual = criterion(input_image)

    def patches_of(image):
        # Patches of the encoded image, built the same way for input and target.
        return pystiche.extract_patches2d(
            encoder(image), patch_size, stride=stride
        )

    desired = F.mrf_loss(
        patches_of(input_image), patches_of(target_image), batched_input=True
    )
    ptu.assert_allclose(actual, desired)
def test_MRFOperator_target_image_to_repr(self):
    """The target repr concatenates the patches of all affine variants.

    With one scale step and one rotation step, the expected representation
    covers every combination of three scaling factors and three rotation
    angles, in ``itertools.product`` order.
    """
    patch_size = 3
    stride = 2
    scale_step_width = 10e-2
    rotation_step_width = 30.0

    # Keep the RNG order: image first, then the Conv2d initialization.
    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(
        encoder,
        patch_size,
        stride=stride,
        num_scale_steps=1,
        scale_step_width=scale_step_width,
        num_rotation_steps=1,
        rotation_step_width=rotation_step_width,
    )
    op.set_target_image(image)
    actual = op.target_repr

    factors = (1.0 - scale_step_width, 1.0, 1.0 + scale_step_width)
    angles = (-rotation_step_width, 0.0, rotation_step_width)

    def variant_patches(factor, angle):
        # Patches of the encoded image after one scale/rotation combination.
        transformed = transform_motif_affinely(
            image, rotation_angle=angle, scaling_factor=factor
        )
        return pystiche.extract_patches2d(encoder(transformed), patch_size, stride)

    desired = torch.cat(
        [
            variant_patches(factor, angle)
            for factor, angle in itertools.product(factors, angles)
        ]
    )
    self.assertTensorAlmostEqual(actual, desired)
def enc_to_repr(self, enc: torch.Tensor, is_guided: bool) -> torch.Tensor:
    """Convert an encoding into its patch representation.

    Patches are extracted with ``pystiche.extract_patches2d`` using
    ``self.patch_size`` and ``self.stride``. For guided encodings they are
    additionally passed through ``self._guide_repr``.

    Args:
        enc: Encoded image.
        is_guided: Whether the encoding belongs to a guided image.
    """
    patches = pystiche.extract_patches2d(enc, self.patch_size, self.stride)
    if is_guided:
        return self._guide_repr(patches)
    return patches
def extract_normalized_patches2d(
    input: torch.Tensor, patch_size: Tuple[int, int], stride: Tuple[int, int]
) -> torch.Tensor:
    """Extract 2D patches after applying ``normalize_unfold_grad`` per dimension.

    For every dimension from 2 onwards, ``normalize_unfold_grad`` is applied
    with the matching entry of ``patch_size`` and ``stride``. The patches are
    then extracted with :func:`pystiche.extract_patches2d`.

    Args:
        input: Input tensor. Dimensions 2 and up are the ones patched.
        patch_size: Patch size, one value per patched dimension.
        stride: Stride, one value per patched dimension.
    """
    spatial_dims = range(2, input.dim())
    normalized = input
    for dim, dim_size, dim_step in zip(spatial_dims, patch_size, stride):
        normalized = normalize_unfold_grad(normalized, dim, dim_size, dim_step)
    return pystiche.extract_patches2d(normalized, patch_size, stride)
def extract_normalized_patches2d(
    input: torch.Tensor,
    patch_size: Union[int, Sequence[int]],
    stride: Union[int, Sequence[int]],
) -> torch.Tensor:
    r"""Extract 2D patches whose backward pass normalizes the input gradient.

    Patches are extracted like :func:`pystiche.extract_patches2d`. When the
    patches overlap (``stride < patch_size``), every entry of the gradient
    w.r.t. ``input`` is divided by the number of patches that entry belongs
    to. Without overlap (``stride >= patch_size``) the result is the same as
    with :func:`pystiche.extract_patches2d`.

    Examples:
        Plain extraction accumulates the gradient over overlapping patches:

        >>> import torch
        >>> import pystiche
        >>> input = torch.ones(1, 1, 4, 4).requires_grad_(True)
        >>> target = torch.zeros(1, 1, 4, 4).detach()
        >>> input_patches = pystiche.extract_patches2d(
        ...     input, patch_size=2, stride=1
        ... )
        >>> target_patches = pystiche.extract_patches2d(
        ...     target, patch_size=2, stride=1
        ... )
        >>> loss = 0.5 * torch.sum((input_patches - target_patches) ** 2.0)
        >>> loss.backward()
        >>> input.grad
        tensor([[[[1., 2., 2., 1.],
                  [2., 4., 4., 2.],
                  [2., 4., 4., 2.],
                  [1., 2., 2., 1.]]]])

        Normalized extraction divides by the overlap count instead:

        >>> import pystiche_papers.li_wand_2016 as paper
        >>> input = torch.ones(1, 1, 4, 4).requires_grad_(True)
        >>> input_patches = paper.extract_normalized_patches2d(
        ...     input, patch_size=2, stride=1
        ... )
        >>> loss = 0.5 * torch.sum((input_patches - target_patches) ** 2.0)
        >>> loss.backward()
        >>> input.grad
        tensor([[[[1., 1., 1., 1.],
                  [1., 1., 1., 1.],
                  [1., 1., 1., 1.],
                  [1., 1., 1., 1.]]]])

    Args:
        input: Input tensor of shape :math:`B \times C \times H \times W`.
        patch_size: Patch size, either a single ``int`` or one value per
            spatial dimension (passed through ``misc.to_2d_arg``).
        stride: Stride, either a single ``int`` or one value per spatial
            dimension (passed through ``misc.to_2d_arg``).
    """
    patch_size, stride = misc.to_2d_arg(patch_size), misc.to_2d_arg(stride)

    normalized = input
    for dim, dim_size, dim_step in zip(range(2, input.dim()), patch_size, stride):
        normalized = normalize_unfold_grad(normalized, dim, dim_size, dim_step)

    return pystiche.extract_patches2d(normalized, patch_size, stride)