def mrf_op(self):
    """Build an ``ops.MRFOperator`` with patch size 2 and stride 2.

    The operator's encoder is an ``enc.SequentialEncoder`` wrapping a single
    ``self.Identity()`` module.
    """
    single_layer_encoder = enc.SequentialEncoder((self.Identity(),))
    return ops.MRFOperator(single_layer_encoder, 2, stride=2)
def test_MRFOperator_call_guided():
    """Compare a guided ``MRFOperator`` call with a manually computed loss.

    The target guide is zero on the first 16 rows (dim 2) and one on the
    remaining 16; the input guide is that tensor flipped along dim 2. The
    expected score is ``F.mrf_loss`` on patches taken from the first 16 rows
    of the input encoding and the last 16 rows of the target encoding.
    """
    patch_size, stride = 2, 2

    torch.manual_seed(0)
    # Keep the draw order (target, then input, then the conv weights) so the
    # seeded values stay the same.
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    zeros_half = torch.zeros(1, 1, 16, 32)
    ones_half = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((zeros_half, ones_half), dim=2)
    input_guide = target_guide.flip(2)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)

    actual = op(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image)[:, :, :16, :], patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image)[:, :, 16:, :], patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches)

    ptu.assert_allclose(actual, desired)
def test_MRFOperator_enc_to_repr_guided(self):
    """Exercise ``MRFOperator.enc_to_repr`` with ``is_guided=True``.

    Four encodings are checked in separate sub-tests:

    * ``constant``: an all-ones encoding; the expected result has a leading
      dimension of 0, i.e. no patches at all.
    * ``spatial_mix``: all-ones rows concatenated (dim 2) with random rows;
      only the patches of the random part are expected.
    * ``channel_mix``: all-ones channels concatenated (dim 1) with random
      channels; the patches of the full encoding are expected.
    * ``varying``: a fully random encoding; the patches of the full encoding
      are expected.
    """

    class Identity(pystiche.Module):
        """Module whose forward pass returns its input unchanged."""

        def forward(self, image):
            return image

    patch_size = 2
    stride = 2

    # Seed the RNG so the random encodings below are reproducible between
    # runs, as the sibling MRFOperator tests do.
    torch.manual_seed(0)

    op = ops.MRFOperator(
        SequentialEncoder((Identity(),)), patch_size, stride=stride
    )

    with self.subTest(enc="constant"):
        encoding = torch.ones(1, 4, 8, 8)

        actual = op.enc_to_repr(encoding, is_guided=True)
        # Empty batch of patches. The spatial dims are taken from the patch
        # size rather than the stride (the two only coincide in this test);
        # presumably this matches the patch shape produced by
        # extract_patches2d -- TODO confirm.
        desired = torch.ones(0, 4, patch_size, patch_size)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="spatial_mix"):
        constant = torch.ones(1, 4, 4, 8)
        varying = torch.rand(1, 4, 4, 8)
        encoding = torch.cat((constant, varying), dim=2)

        actual = op.enc_to_repr(encoding, is_guided=True)
        desired = pystiche.extract_patches2d(varying, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="channel_mix"):
        constant = torch.ones(1, 2, 8, 8)
        varying = torch.rand(1, 2, 8, 8)
        encoding = torch.cat((constant, varying), dim=1)

        actual = op.enc_to_repr(encoding, is_guided=True)
        desired = pystiche.extract_patches2d(encoding, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="varying"):
        encoding = torch.rand(1, 4, 8, 8)

        actual = op.enc_to_repr(encoding, is_guided=True)
        desired = pystiche.extract_patches2d(encoding, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)
def test_MRFOperator_set_target_guide_without_recalc():
    """Check ``set_target_guide(..., recalc_repr=False)`` keeps ``target_repr``.

    A random tensor is registered as the ``target_repr`` buffer, a guide is
    set with ``recalc_repr=False``, and the buffer is then compared with the
    tensor that was registered.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    registered_repr = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)

    conv = nn.Conv2d(3, 3, 1)
    op = ops.MRFOperator(
        enc.SequentialEncoder((conv,)), patch_size, stride=stride
    )
    op.register_buffer("target_repr", registered_repr)
    op.set_target_guide(guide, recalc_repr=False)

    ptu.assert_allclose(op.target_repr, registered_repr)
def test_set_target_guide_without_recalc(self):
    """Check ``set_target_guide(..., recalc_repr=False)`` keeps ``target_repr``.

    The target representation is snapshotted (cloned) after the target image
    is set, and compared with ``op.target_repr`` after the guide is set with
    ``recalc_repr=False``.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)

    conv_encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(conv_encoder, patch_size, stride=stride)
    op.set_target_image(image)

    # Clone so the snapshot cannot be affected by anything done to the
    # operator's own tensor afterwards.
    repr_before_guide = op.target_repr.clone()

    op.set_target_guide(guide, recalc_repr=False)

    ptu.assert_allclose(op.target_repr, repr_before_guide)
def test_call(self):
    """Compare an unguided ``MRFOperator`` call with ``F.mrf_loss``.

    The expected score is ``F.mrf_loss`` applied to the patches extracted
    from the encoded input image and the encoded target image.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)

    actual = op(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image), patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image), patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches)

    ptu.assert_allclose(actual, desired)
def test_MRFOperator_call(self):
    """Compare an unguided ``MRFOperator`` call with ``F.patch_matching_loss``.

    The expected score is ``F.patch_matching_loss`` applied to the patches
    extracted from the encoded input image and the encoded target image.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)

    actual = op(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image), patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image), patch_size, stride=stride
    )
    desired = F.patch_matching_loss(input_patches, target_patches)

    self.assertFloatAlmostEqual(actual, desired)
def test_MRFOperator_set_target_guide():
    """Check the operator's state before and after ``set_target_guide``.

    ``has_target_guide`` must be falsy before the guide is set and truthy
    afterwards; ``target_guide`` and ``target_image`` are then compared with
    the tensors that were passed in.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)

    conv_encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(conv_encoder, patch_size, stride=stride)
    op.set_target_image(image)
    assert not op.has_target_guide

    op.set_target_guide(guide)
    assert op.has_target_guide

    ptu.assert_allclose(op.target_guide, guide)
    ptu.assert_allclose(op.target_image, image)
def test_MRFOperator_set_target_guide(self):
    """Check the operator's state before and after ``set_target_guide``.

    ``has_target_guide`` must be false before the guide is set and true
    afterwards; ``target_guide`` and ``target_image`` are then compared with
    the tensors that were passed in.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)

    conv_encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(conv_encoder, patch_size, stride=stride)
    op.set_target_image(image)
    self.assertFalse(op.has_target_guide)

    op.set_target_guide(guide)
    self.assertTrue(op.has_target_guide)

    self.assertTensorAlmostEqual(op.target_guide, guide)
    self.assertTensorAlmostEqual(op.target_image, image)
def test_MRFOperator_target_image_to_repr(self):
    """Compare ``target_repr`` with patches built from transformed targets.

    The operator is created with one scale step (width 0.1) and one rotation
    step (width 30.0). The expected representation concatenates the patches
    of the encoded image for every combination of the scaling factors
    ``(0.9, 1.0, 1.1)`` and rotation angles ``(-30.0, 0.0, 30.0)``, in
    ``itertools.product`` order.
    """
    patch_size, stride = 3, 2
    scale_step_width = 0.1
    rotation_step_width = 30.0

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(
        encoder,
        patch_size,
        stride=stride,
        num_scale_steps=1,
        scale_step_width=scale_step_width,
        num_rotation_steps=1,
        rotation_step_width=rotation_step_width,
    )
    op.set_target_image(image)
    actual = op.target_repr

    def patches_for(scaling_factor, rotation_angle):
        # Transform the image, encode it, and cut the encoding into patches.
        transformed_image = transform_motif_affinely(
            image, rotation_angle=rotation_angle, scaling_factor=scaling_factor
        )
        return pystiche.extract_patches2d(
            encoder(transformed_image), patch_size, stride
        )

    scaling_factors = (1.0 - scale_step_width, 1.0, 1.0 + scale_step_width)
    rotation_angles = (-rotation_step_width, 0.0, rotation_step_width)
    desired = torch.cat(
        [
            patches_for(factor, angle)
            for factor, angle in itertools.product(scaling_factors, rotation_angles)
        ]
    )

    self.assertTensorAlmostEqual(actual, desired)
def get_style_op(encoder, layer_weight):
    """Return an ``ops.MRFOperator`` for ``encoder``.

    The operator is created with ``patch_size=3`` and ``stride=2``;
    ``layer_weight`` is passed on as the operator's ``score_weight``.
    """
    style_op = ops.MRFOperator(
        encoder,
        patch_size=3,
        stride=2,
        score_weight=layer_weight,
    )
    return style_op