def test_PerceptualLoss_set_content_image():
    """Setting the content image on the loss must reach the content operator.

    After ``PerceptualLoss.set_content_image`` is called, the content
    operator's ``target_image`` is expected to match the supplied image.
    """
    torch.manual_seed(0)
    content_image = torch.rand(1, 1, 100, 100)

    # Each operator gets its own single-module 1x1 convolution encoder.
    content_op = ops.FeatureReconstructionOperator(
        enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))
    )
    style_op = ops.FeatureReconstructionOperator(
        enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))
    )

    criterion = loss.PerceptualLoss(content_op, style_op)
    criterion.set_content_image(content_image)

    ptu.assert_allclose(content_op.target_image, content_image)
def test_EncodingComparisonOperator_set_target_image():
    """Check the state an operator exposes after ``set_target_image``.

    The helper operator doubles the target encoding and uses its norm as
    context. After the target image is set, ``has_target_image`` must flip to
    true, and ``target_image``, ``target_repr`` and ``ctx`` must match the
    values computed by hand from the same encoder.
    """

    class DoublingOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # Representation first, context second.
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    target = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = DoublingOperator(encoder)

    assert not op.has_target_image
    op.set_target_image(target)
    assert op.has_target_image

    ptu.assert_allclose(op.target_image, target)
    ptu.assert_allclose(op.target_repr, encoder(target) * 2.0)
    ptu.assert_allclose(op.ctx, torch.norm(encoder(target)))
def test_non_persistent_images(self):
    """Round-trip an operator's state dict into a freshly built operator.

    The source operator has its target guide, target image and input guide
    set before ``state_dict()`` is taken. That state dict is then loaded with
    ``strict=True`` into a new operator that never had them set; the test
    passes if loading does not raise.
    """
    # NOTE(review): judging by the test name, images and guides are presumably
    # kept out of the state dict (non-persistent buffers) - confirm against
    # the operator implementation.

    class PassThroughOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, enc):
            return enc, None

        def input_enc_to_repr(self, enc, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    source_op = PassThroughOperator(encoder)
    source_op.set_target_guide(target_guide)
    source_op.set_target_image(target_image)
    source_op.set_input_guide(input_guide)
    saved_state = source_op.state_dict()

    # Both operators share the same encoder instance.
    fresh_op = PassThroughOperator(encoder)
    fresh_op.load_state_dict(saved_state, strict=True)
def test_EncodingComparisonOperator_set_target_guide_without_recalc():
    """``set_target_guide(..., recalc_repr=False)`` must leave ``target_repr`` alone.

    A representation is registered directly as the ``target_repr`` buffer;
    setting a target guide with recalculation disabled must not change it.
    """

    class DoublingOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # Representation first, context second.
            return image * 2.0, torch.norm(image)

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    stored_repr = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = DoublingOperator(encoder)
    # Inject the representation without going through set_target_image.
    op.register_buffer("target_repr", stored_repr)
    op.set_target_guide(target_guide, recalc_repr=False)

    ptu.assert_allclose(op.target_repr, stored_repr)
def test_EncodingComparisonOperator_call_batch_size_error():
    """Calling the operator must raise ``RuntimeError`` for these batch sizes.

    The target image has batch size 2 while the input image has batch size 1.
    """

    class PassThroughOperator(ops.EncodingComparisonOperator):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Set for parity with the original helper; never read in this test.
            self.batch_size_equal = False

        def target_enc_to_repr(self, enc):
            return enc, None

        def input_enc_to_repr(self, enc, ctx):
            return enc

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    two_targets = torch.rand(2, 1, 1, 1)
    one_input = torch.rand(1, 1, 1, 1)
    encoder = enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))

    op = PassThroughOperator(encoder)
    op.set_target_image(two_targets)

    with pytest.raises(RuntimeError):
        op(one_input)
def test_MRFOperator_call_guided():
    """Compare a guided ``MRFOperator`` call against a hand-built MRF loss.

    The target guide is zeros for indices 0-15 along dim 2 and ones for
    indices 16-31; the input guide is the same mask flipped along dim 2.
    The reference value uses only the first 16 rows of the input encoding
    and the last 16 rows of the target encoding.
    """
    patch_size, stride = 2, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    masked_half = torch.zeros(1, 1, 16, 32)
    active_half = torch.ones(1, 1, 16, 32)
    target_guide = torch.cat((masked_half, active_half), dim=2)
    input_guide = target_guide.flip(2)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)

    actual = op(input_image)

    # Crop each encoding to the region where its guide is one.
    input_enc = encoder(input_image)[:, :, :16, :]
    target_enc = encoder(target_image)[:, :, 16:, :]
    input_patches = pystiche.extract_patches2d(input_enc, patch_size, stride=stride)
    target_patches = pystiche.extract_patches2d(target_enc, patch_size, stride=stride)
    desired = F.mrf_loss(input_patches, target_patches)

    ptu.assert_allclose(actual, desired)
def test_EncodingComparisonOperator_call_batch_size_mismatch():
    """Representations reaching ``calculate_score`` must share a batch size.

    The target image has batch size 1 and the input image batch size 2. The
    helper operator records inside ``calculate_score`` whether the two
    representations it receives agree in their first dimension.
    """

    class BatchProbeOperator(ops.EncodingComparisonOperator):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.batch_size_equal = False

        def target_enc_to_repr(self, enc):
            return enc, None

        def input_enc_to_repr(self, enc, ctx):
            return enc

        def calculate_score(self, input_repr, target_repr, ctx):
            # Remember the comparison so the test body can assert on it.
            self.batch_size_equal = input_repr.size(0) == target_repr.size(0)
            return 0.0

    torch.manual_seed(0)
    one_target = torch.rand(1, 1, 1, 1)
    two_inputs = torch.rand(2, 1, 1, 1)
    encoder = enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))

    op = BatchProbeOperator(encoder)
    op.set_target_image(one_target)
    op(two_inputs)

    assert op.batch_size_equal
def mrf_op(self):
    """Return an ``MRFOperator`` with patch size 2 and stride 2.

    The operator's encoder is a ``SequentialEncoder`` wrapping a single
    ``self.Identity()`` module.
    """
    wrapped_encoder = enc.SequentialEncoder((self.Identity(),))
    return ops.MRFOperator(wrapped_encoder, 2, stride=2)
def test_EncodingComparisonOperator_call_guided():
    """Compare a guided operator call against a hand-computed score.

    The helper operator adds 1 to the target encoding, adds 2 to the input
    encoding and multiplies the two. The reference applies the propagated
    guides to the raw encodings via ``apply_guide`` before those offsets.
    """

    class OffsetProductOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return image + 1.0, None

        def input_enc_to_repr(self, image, ctx):
            return image + 2.0

        def calculate_score(self, input_repr, target_repr, ctx):
            return input_repr * target_repr

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    target_enc_guide = encoder.propagate_guide(target_guide)
    input_enc_guide = encoder.propagate_guide(input_guide)

    op = OffsetProductOperator(encoder)
    op.set_target_guide(target_guide)
    op.set_target_image(target_image)
    op.set_input_guide(input_guide)

    actual = op(input_image)

    guided_target = OffsetProductOperator.apply_guide(
        encoder(target_image), target_enc_guide
    )
    target_term = guided_target + 1.0
    guided_input = OffsetProductOperator.apply_guide(
        encoder(input_image), input_enc_guide
    )
    input_term = guided_input + 2.0
    desired = target_term * input_term

    ptu.assert_allclose(actual, desired)
def test_SequentialEncoder_call(self):
    """``SequentialEncoder`` must produce the same output as ``nn.Sequential``.

    Both containers are built from the same module instances and fed the
    same random image.
    """
    torch.manual_seed(0)
    layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
    image = torch.rand(1, 3, 256, 256)

    encoder_under_test = enc.SequentialEncoder(layers)
    reference = nn.Sequential(*layers)

    actual = encoder_under_test(image)
    desired = reference(image)
    self.assertTensorAlmostEqual(actual, desired)
def test_SequentialEncoder_call():
    """``SequentialEncoder`` must produce the same output as ``nn.Sequential``.

    Both containers are built from the same module instances and fed the
    same random image.
    """
    torch.manual_seed(0)
    layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
    image = torch.rand(1, 3, 256, 256)

    encoder_under_test = enc.SequentialEncoder(layers)
    reference = nn.Sequential(*layers)

    actual = encoder_under_test(image)
    desired = reference(image)
    ptu.assert_allclose(actual, desired)
def test_FeatureReconstructionOperator_call():
    """The operator's score must equal the MSE between the two encodings."""
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.FeatureReconstructionOperator(encoder)
    op.set_target_image(target_image)
    actual = op(input_image)

    # Reference: encode both images by hand and compare them directly.
    input_enc = encoder(input_image)
    target_enc = encoder(target_image)
    desired = mse_loss(input_enc, target_enc)

    ptu.assert_allclose(actual, desired)
def test_GramOperator_call():
    """The operator's score must equal the MSE between normalized Gram matrices."""
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 128, 128)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.GramOperator(encoder)
    op.set_target_image(target_image)
    actual = op(input_image)

    # Reference: Gram matrices of the hand-encoded images.
    input_gram = pystiche.gram_matrix(encoder(input_image), normalize=True)
    target_gram = pystiche.gram_matrix(encoder(target_image), normalize=True)
    desired = mse_loss(input_gram, target_gram)

    ptu.assert_allclose(actual, desired)
def test_MRFOperator_set_target_guide_without_recalc():
    """``set_target_guide(..., recalc_repr=False)`` must leave ``target_repr`` alone.

    A tensor is registered directly as the ``target_repr`` buffer; setting a
    target guide with recalculation disabled must not change it.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    stored_repr = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    # Inject the representation without going through set_target_image.
    op.register_buffer("target_repr", stored_repr)
    op.set_target_guide(target_guide, recalc_repr=False)

    ptu.assert_allclose(op.target_repr, stored_repr)
def test_EncodingRegularizationOperator_call():
    """Calling the operator must chain encoder, representation and score.

    The helper operator doubles the encoding and then adds 1, so the result
    must equal ``encoder(image) * 2.0 + 1.0``.
    """

    class AffineOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    torch.manual_seed(0)
    image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = AffineOperator(encoder)

    actual = op(image)
    desired = encoder(image) * 2.0 + 1.0
    ptu.assert_allclose(actual, desired)
def test_set_target_guide_without_recalc(self):
    """``set_target_guide(..., recalc_repr=False)`` must leave ``target_repr`` alone.

    The representation produced by ``set_target_image`` is snapshotted and
    compared with ``target_repr`` after the guide has been set.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)
    # Clone so a later in-place change could not hide a difference.
    repr_before = op.target_repr.clone()

    op.set_target_guide(target_guide, recalc_repr=False)

    ptu.assert_allclose(op.target_repr, repr_before)
def test_MRFOperator_enc_to_repr_guided(subtests):
    """Check which patches ``MRFOperator.enc_to_repr`` keeps when guided.

    Four encodings are run through ``enc_to_repr(..., is_guided=True)``:

    * ``constant``: an all-ones encoding; no patch is expected to remain.
    * ``spatial_mix``: ones in the first half along dim 2 and random values
      in the second half; only the patches of the random half are expected.
    * ``channel_mix``: half the channels are ones and half are random; all
      patches are expected.
    * ``varying``: a fully random encoding; all patches are expected.

    Args:
        subtests: Object whose ``test(...)`` context manager wraps each case.
    """

    class Identity(pystiche.Module):
        def forward(self, image):
            return image

    patch_size = 2
    stride = 2

    # Seed the generator like the other tests so the random cases are repeatable.
    torch.manual_seed(0)

    op = ops.MRFOperator(
        enc.SequentialEncoder((Identity(),)), patch_size, stride=stride
    )

    with subtests.test(enc="constant"):
        enc_ = torch.ones(1, 4, 8, 8)

        actual = op.enc_to_repr(enc_, is_guided=True)
        # The trailing dimensions of a patch are its size, not the stride.
        # The two only coincide here because both are 2.
        desired = torch.ones(0, 4, patch_size, patch_size)
        ptu.assert_allclose(actual, desired)

    with subtests.test(enc="spatial_mix"):
        constant = torch.ones(1, 4, 4, 8)
        varying = torch.rand(1, 4, 4, 8)
        enc_ = torch.cat((constant, varying), dim=2)

        actual = op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(varying, patch_size, stride=stride)
        ptu.assert_allclose(actual, desired)

    with subtests.test(enc="channel_mix"):
        constant = torch.ones(1, 2, 8, 8)
        varying = torch.rand(1, 2, 8, 8)
        enc_ = torch.cat((constant, varying), dim=1)

        actual = op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(enc_, patch_size, stride=stride)
        ptu.assert_allclose(actual, desired)

    with subtests.test(enc="varying"):
        enc_ = torch.rand(1, 4, 8, 8)

        actual = op.enc_to_repr(enc_, is_guided=True)
        desired = pystiche.extract_patches2d(enc_, patch_size, stride=stride)
        ptu.assert_allclose(actual, desired)
def test_EncodingComparisonOperator_call_no_target():
    """Calling an operator before a target image is set must raise ``RuntimeError``."""

    class NoOpOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            return None

        def input_enc_to_repr(self, image, ctx):
            return None

        def calculate_score(self, input_repr, target_repr, ctx):
            return None

    torch.manual_seed(0)
    input_image = torch.rand(1, 3, 128, 128)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    op = NoOpOperator(encoder)

    # set_target_image is deliberately never called.
    with pytest.raises(RuntimeError):
        op(input_image)
def test_call(self):
    """An unguided ``MRFOperator`` call must equal the hand-built MRF loss.

    The reference extracts patches (size 3, stride 2) from the encodings of
    the input and target image and passes them to ``F.mrf_loss``.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)
    actual = op(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image), patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image), patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches)

    ptu.assert_allclose(actual, desired)
def test_EncodingRegularizationOperator_call_guided():
    """A guided call must apply the propagated guide before the representation.

    The helper operator doubles the encoding and then adds 1. The reference
    applies the encoder-propagated guide to the raw encoding via
    ``apply_guide`` and then performs the same two steps.
    """

    class AffineOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    # NOTE(review): this guide has three channels, whereas the other guided
    # tests in this file use single-channel guides - confirm this is intended.
    input_guide = torch.rand(1, 3, 32, 32)

    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))
    enc_guide = encoder.propagate_guide(input_guide)

    op = AffineOperator(encoder)
    op.set_input_guide(input_guide)
    actual = op(image)

    guided_enc = AffineOperator.apply_guide(encoder(image), enc_guide)
    desired = guided_enc * 2.0 + 1.0

    ptu.assert_allclose(actual, desired)
def test_MRFOperator_set_target_guide():
    """Check the state an ``MRFOperator`` exposes after ``set_target_guide``.

    ``has_target_guide`` must flip from false to true, ``target_guide`` must
    match the supplied guide, and the previously set ``target_image`` must
    still match the original image.
    """
    patch_size, stride = 3, 2

    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    encoder = enc.SequentialEncoder((nn.Conv2d(3, 3, 1),))

    op = ops.MRFOperator(encoder, patch_size, stride=stride)
    op.set_target_image(target_image)

    assert not op.has_target_guide
    op.set_target_guide(target_guide)
    assert op.has_target_guide

    ptu.assert_allclose(op.target_guide, target_guide)
    ptu.assert_allclose(op.target_image, target_image)
def encoder():
    """Return a ``SequentialEncoder`` holding one 3-to-3 channel 1x1 convolution."""
    conv = nn.Conv2d(3, 3, 1)
    return enc.SequentialEncoder((conv,))
def get_guided_perceptual_loss():
    """Assemble a ``GuidedPerceptualLoss`` from a content and a style operator.

    The content operator is a ``FeatureReconstructionOperator`` on a 1x1
    convolution encoder; the style operator is a ``MultiRegionOperator``.
    ``regions`` and ``get_op`` are not defined here and are resolved from an
    outer scope.
    """
    content_encoder = enc.SequentialEncoder((nn.Conv2d(1, 1, 1),))
    content_op = ops.FeatureReconstructionOperator(content_encoder)
    style_op = ops.MultiRegionOperator(regions, get_op)
    return loss.GuidedPerceptualLoss(content_op, style_op)