def test_EncodingComparisonOperator_call_guided(self):
    """Call a comparison operator that has target and input guides set.

    The expected score is rebuilt by hand: both encodings are passed
    through ``apply_guide`` with their propagated guides and then through
    the same arithmetic as ``TestOperator`` (``+ 1.0``, ``+ 2.0``, product).
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # The second element is the context; this operator has none.
            return image + 1.0, None

        def input_enc_to_repr(self, image, ctx):
            return image + 2.0

        def calculate_score(self, input_repr, target_repr, ctx):
            return input_repr * target_repr

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    target_enc_guide = encoder.propagate_guide(target_guide)
    input_enc_guide = encoder.propagate_guide(input_guide)

    test_op = TestOperator(encoder)
    test_op.set_target_guide(target_guide)
    test_op.set_target_image(target_image)
    test_op.set_input_guide(input_guide)
    score = test_op(input_image)

    guided_target_enc = TestOperator.apply_guide(
        encoder(target_image), target_enc_guide
    )
    guided_input_enc = TestOperator.apply_guide(
        encoder(input_image), input_enc_guide
    )
    expected = (guided_target_enc + 1.0) * (guided_input_enc + 2.0)

    self.assertTensorAlmostEqual(score, expected)
def test_EncodingComparisonOperator_set_target_image(self):
    """Check the state of a comparison operator after ``set_target_image``.

    Verifies that ``has_target_image`` flips from false to true and that
    ``target_image``, ``target_repr`` and ``ctx`` match the values derived
    by hand from the encoder output.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # Representation and context, in that order.
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            pass

        def calculate_score(self, input_repr, target_repr, ctx):
            pass

    torch.manual_seed(0)
    image = torch.rand(1, 3, 128, 128)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    test_op = TestOperator(encoder)

    self.assertFalse(test_op.has_target_image)
    test_op.set_target_image(image)
    self.assertTrue(test_op.has_target_image)

    # The image itself is stored unchanged.
    self.assertTensorAlmostEqual(test_op.target_image, image)

    # The representation mirrors ``target_enc_to_repr`` applied to the encoding.
    expected_repr = encoder(image) * 2.0
    self.assertTensorAlmostEqual(test_op.target_repr, expected_repr)

    # So does the context.
    expected_ctx = torch.norm(encoder(image))
    self.assertTensorAlmostEqual(test_op.ctx, expected_ctx)
def test_MRFOperator_call_guided(self):
    """Call an ``MRFOperator`` with complementary target and input guides.

    The target guide is zero in the upper 16 rows and one in the lower 16;
    the input guide is the same tensor flipped along dim 2. The expected
    loss is computed from patches of the lower half of the target encoding
    and the upper half of the input encoding.
    """
    patch_size = 2
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    upper_zeros = torch.zeros(1, 1, 16, 32)
    lower_ones = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((upper_zeros, lower_ones), dim=2)
    input_guide = target_guide.flip(2)

    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)
    loss_value = op(input_image)

    input_enc = encoder(input_image)[:, :, :16, :]
    target_enc = encoder(target_image)[:, :, 16:, :]
    input_patches = pystiche.extract_patches2d(input_enc, patch_size, stride=stride)
    target_patches = pystiche.extract_patches2d(target_enc, patch_size, stride=stride)
    expected = F.patch_matching_loss(input_patches, target_patches)

    self.assertFloatAlmostEqual(loss_value, expected)
def test_EncodingComparisonOperator_set_target_guide_without_recalc(self):
    """``set_target_guide(..., recalc_repr=False)`` keeps ``target_repr``.

    A representation is registered directly as the ``target_repr`` buffer;
    after setting the guide without recalculation the buffer is expected to
    still hold that tensor.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            pass

        def calculate_score(self, input_repr, target_repr, ctx):
            pass

    torch.manual_seed(0)
    preset_repr = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    test_op = TestOperator(encoder)
    test_op.register_buffer("target_repr", preset_repr)
    test_op.set_target_guide(guide, recalc_repr=False)

    self.assertTensorAlmostEqual(test_op.target_repr, preset_repr)
def test_PerceptualLoss_set_style_image(self):
    """``PerceptualLoss.set_style_image`` sets the style loss' target image."""
    torch.manual_seed(0)
    image = torch.rand(1, 1, 100, 100)

    content_encoder = SequentialEncoder((nn.Conv2d(1, 1, 1),))
    content_loss = FeatureReconstructionOperator(content_encoder)
    style_encoder = SequentialEncoder((nn.Conv2d(1, 1, 1),))
    style_loss = FeatureReconstructionOperator(style_encoder)

    perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
    perceptual_loss.set_style_image(image)

    # Read the image back from the style operator handed to the loss.
    self.assertTensorAlmostEqual(style_loss.target_image, image)
def test_EncodingRegularizationOperator_call_guided(self):
    """Call a regularization operator that has an input guide set.

    The expected score applies the propagated guide to the encoding via
    ``apply_guide`` and then repeats the arithmetic of ``TestOperator``
    (``* 2.0`` followed by ``+ 1.0``).
    """

    class TestOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    # NOTE(review): this guide has 3 channels, whereas the guides in the
    # other guided tests of this file are single-channel - confirm intent.
    guide = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    enc_guide = encoder.propagate_guide(guide)

    test_op = TestOperator(encoder)
    test_op.set_input_guide(guide)
    score = test_op(image)

    guided_enc = TestOperator.apply_guide(encoder(image), enc_guide)
    expected = guided_enc * 2.0 + 1.0

    self.assertTensorAlmostEqual(score, expected)
def test_MSEEncodingOperator_call(self):
    """``MSEEncodingOperator`` is compared to ``mse_loss`` of the encodings."""
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MSEEncodingOperator(encoder)
    op.set_target_image(target_image)
    score = op(input_image)

    input_enc = encoder(input_image)
    target_enc = encoder(target_image)
    expected = mse_loss(input_enc, target_enc)

    self.assertTensorAlmostEqual(score, expected)
def test_MRFOperator_enc_to_repr_guided(self):
    """Check ``MRFOperator.enc_to_repr(enc, is_guided=True)`` on four encodings.

    Sub-tests and their expectations:

    * ``constant``: an all-ones encoding; zero patches are expected.
    * ``spatial_mix``: constant upper half, random lower half; only the
      patches of the random half are expected.
    * ``channel_mix``: constant and random channels concatenated; all
      patches of the full encoding are expected.
    * ``varying``: a fully random encoding; all patches are expected.
    """

    class Identity(pystiche.Module):
        # ``forward`` returns its input unchanged.
        def forward(self, image):
            return image

    patch_size = 2
    stride = 2
    op = ops.MRFOperator(
        SequentialEncoder((Identity(),)), patch_size, stride=stride
    )

    with self.subTest(enc="constant"):
        enc = torch.ones(1, 4, 8, 8)
        actual = op.enc_to_repr(enc, is_guided=True)
        # The spatial extent of a patch is given by patch_size. The original
        # used stride here, which only worked because both values are 2.
        desired = torch.ones(0, 4, patch_size, patch_size)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="spatial_mix"):
        constant = torch.ones(1, 4, 4, 8)
        varying = torch.rand(1, 4, 4, 8)
        enc = torch.cat((constant, varying), dim=2)
        actual = op.enc_to_repr(enc, is_guided=True)
        desired = pystiche.extract_patches2d(varying, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="channel_mix"):
        constant = torch.ones(1, 2, 8, 8)
        varying = torch.rand(1, 2, 8, 8)
        enc = torch.cat((constant, varying), dim=1)
        actual = op.enc_to_repr(enc, is_guided=True)
        desired = pystiche.extract_patches2d(enc, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)

    with self.subTest(enc="varying"):
        enc = torch.rand(1, 4, 8, 8)
        actual = op.enc_to_repr(enc, is_guided=True)
        desired = pystiche.extract_patches2d(enc, patch_size, stride=stride)
        self.assertTensorAlmostEqual(actual, desired)
def test_GramOperator_call(self):
    """``GramOperator`` is compared to ``mse_loss`` of normalized Gram matrices."""
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.GramOperator(encoder)
    op.set_target_image(target_image)
    score = op(input_image)

    input_gram = pystiche.gram_matrix(encoder(input_image), normalize=True)
    target_gram = pystiche.gram_matrix(encoder(target_image), normalize=True)
    expected = mse_loss(input_gram, target_gram)

    self.assertTensorAlmostEqual(score, expected)
def test_MRFOperator_set_target_guide_without_recalc(self):
    """``set_target_guide(..., recalc_repr=False)`` keeps ``target_repr``.

    The representation is registered directly as the ``target_repr`` buffer
    and is expected to be unchanged after the guide has been set.
    """
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    preset_repr = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.register_buffer("target_repr", preset_repr)
    op.set_target_guide(guide, recalc_repr=False)

    self.assertTensorAlmostEqual(op.target_repr, preset_repr)
def test_EncodingRegularizationOperator_call(self):
    """Call an unguided regularization operator.

    The expected score repeats the arithmetic of ``TestOperator`` on the
    encoder output: ``* 2.0`` followed by ``+ 1.0``.
    """

    class TestOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    torch.manual_seed(0)
    image = torch.rand(1, 3, 128, 128)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    test_op = TestOperator(encoder)
    score = test_op(image)

    expected = encoder(image) * 2.0 + 1.0
    self.assertTensorAlmostEqual(score, expected)
def test_EncodingComparisonOperator_call_no_target(self):
    """Calling a comparison operator without a target raises ``RuntimeError``."""

    class TestOperator(ops.EncodingComparisonOperator):
        # The hooks are stubs; the test expects the call to fail.
        def target_enc_to_repr(self, image):
            pass

        def input_enc_to_repr(self, image, ctx):
            pass

        def calculate_score(self, input_repr, target_repr, ctx):
            pass

    torch.manual_seed(0)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    test_op = TestOperator(encoder)

    # No target image has been set on the operator at this point.
    with self.assertRaises(RuntimeError):
        test_op(input_image)
def test_MRFOperator_call(self):
    """``MRFOperator`` is compared to ``patch_matching_loss`` of the patches.

    The expected value extracts patches from the input and target encodings
    with the operator's patch size and stride.
    """
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)
    loss_value = op(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image), patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image), patch_size, stride=stride
    )
    expected = F.patch_matching_loss(input_patches, target_patches)

    self.assertFloatAlmostEqual(loss_value, expected)
def test_MRFOperator_set_target_guide(self):
    """Check the state of an ``MRFOperator`` after ``set_target_guide``.

    Verifies that ``has_target_guide`` flips from false to true, that the
    guide is stored, and that the previously set target image is still
    available afterwards.
    """
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(image)

    self.assertFalse(op.has_target_guide)
    op.set_target_guide(guide)
    self.assertTrue(op.has_target_guide)

    self.assertTensorAlmostEqual(op.target_guide, guide)
    self.assertTensorAlmostEqual(op.target_image, image)
def test_MRFOperator_target_image_to_repr(self):
    """Target representation of an ``MRFOperator`` with scale/rotation steps.

    With one scale step and one rotation step, the expected representation
    concatenates the patches of the encoded image for every combination of
    three scaling factors and three rotation angles, in the order produced
    by ``itertools.product(factors, angles)``.
    """
    patch_size = 3
    stride = 2
    scale_step_width = 10e-2
    rotation_step_width = 30.0

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(
        encoder,
        patch_size,
        stride=stride,
        num_scale_steps=1,
        scale_step_width=scale_step_width,
        num_rotation_steps=1,
        rotation_step_width=rotation_step_width,
    )
    op.set_target_image(image)
    actual = op.target_repr

    def patches_for(factor, angle):
        # Transform the image, encode it, then cut the encoding into patches.
        transformed_image = transform_motif_affinely(
            image, rotation_angle=angle, scaling_factor=factor
        )
        return pystiche.extract_patches2d(
            encoder(transformed_image), patch_size, stride
        )

    factors = (1.0 - scale_step_width, 1.0, 1.0 + scale_step_width)
    angles = (-rotation_step_width, 0.0, rotation_step_width)
    # The concatenation order follows the product order: factors outer,
    # angles inner.
    desired = torch.cat(
        [patches_for(factor, angle) for factor, angle in itertools.product(factors, angles)]
    )

    self.assertTensorAlmostEqual(actual, desired)
def get_guided_perceptual_loss():
    """Build a ``loss.GuidedPerceptualLoss`` from fresh content and style losses.

    The content loss is a ``FeatureReconstructionOperator`` on a single
    1x1 convolution encoder; the style loss is a ``MultiRegionOperator``.
    ``regions`` and ``get_op`` are not defined in this function and are
    resolved from the surrounding scope.
    """
    content_encoder = SequentialEncoder((nn.Conv2d(1, 1, 1),))
    return loss.GuidedPerceptualLoss(
        FeatureReconstructionOperator(content_encoder),
        MultiRegionOperator(regions, get_op),
    )