def test_EncodingComparisonOperator_call_guided(self):
    """Check the guided call of an ``EncodingComparisonOperator``.

    Both the target and the input guide are propagated through the encoder.
    Each propagated guide is applied to the matching encoding, and the result
    must equal ``(guided target enc + 1) * (guided input enc + 2)``.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # ``target_repr`` instead of ``repr`` so the builtin is not shadowed.
            target_repr = image + 1.0
            return target_repr, None

        def input_enc_to_repr(self, image, ctx):
            return image + 2.0

        def calculate_score(self, input_repr, target_repr, ctx):
            return input_repr * target_repr

    # Seed first; the rand calls and the Conv2d initialisation below all draw
    # from the global RNG, so their order must stay fixed.
    torch.manual_seed(0)
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)
    target_guide = torch.rand(1, 1, 32, 32)
    input_guide = torch.rand(1, 1, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))

    target_enc_guide = encoder.propagate_guide(target_guide)
    input_enc_guide = encoder.propagate_guide(input_guide)

    test_op = TestOperator(encoder)
    test_op.set_target_guide(target_guide)
    test_op.set_target_image(target_image)
    test_op.set_input_guide(input_guide)
    actual = test_op(input_image)

    desired = (
        TestOperator.apply_guide(encoder(target_image), target_enc_guide) + 1.0
    ) * (TestOperator.apply_guide(encoder(input_image), input_enc_guide) + 2.0)
    self.assertTensorAlmostEqual(actual, desired)
def test_EncodingRegularizationOperator_call_guided(self):
    """Check the guided call of an ``EncodingRegularizationOperator``.

    The input guide is propagated through the encoder and applied to the
    encoded image. The operator output must then equal
    ``guided_encoding * 2 + 1``.
    """

    class TestOperator(ops.EncodingRegularizationOperator):
        def input_enc_to_repr(self, image):
            return image * 2.0

        def calculate_score(self, input_repr):
            return input_repr + 1.0

    # RNG order matters: image, then guide, then the Conv2d weight init.
    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    # NOTE(review): the other guided tests in this file use a single-channel
    # guide of shape (1, 1, 32, 32); confirm a 3-channel guide is intended.
    guide = torch.rand(1, 3, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    propagated_guide = encoder.propagate_guide(guide)

    operator = TestOperator(encoder)
    operator.set_input_guide(guide)
    result = operator(image)

    guided_encoding = TestOperator.apply_guide(encoder(image), propagated_guide)
    expected = guided_encoding * 2.0 + 1.0
    self.assertTensorAlmostEqual(result, expected)
def test_EncodingComparisonOperator_set_target_guide(self):
    """Check ``set_target_guide`` on an ``EncodingComparisonOperator``.

    After a target image is set, ``has_target_guide`` is False. Once a guide
    is set it becomes True, and the operator then exposes:

    - the raw guide,
    - the guide propagated through the encoder,
    - the unchanged target image.
    """

    class TestOperator(ops.EncodingComparisonOperator):
        def target_enc_to_repr(self, image):
            # ``target_repr`` instead of ``repr`` so the builtin is not shadowed.
            target_repr = image * 2.0
            ctx = torch.norm(image)
            return target_repr, ctx

        # Stubs: only the target side is exercised by this test.
        def input_enc_to_repr(self, image, ctx):
            pass

        def calculate_score(self, input_repr, target_repr, ctx):
            pass

    # Seed first; the rand calls and the Conv2d initialisation consume RNG.
    torch.manual_seed(0)
    image = torch.rand(1, 3, 32, 32)
    guide = torch.rand(1, 1, 32, 32)
    encoder = SequentialEncoder((nn.Conv2d(3, 3, 1),))
    enc_guide = encoder.propagate_guide(guide)

    test_op = TestOperator(encoder)
    test_op.set_target_image(image)
    self.assertFalse(test_op.has_target_guide)

    test_op.set_target_guide(guide)
    self.assertTrue(test_op.has_target_guide)

    actual = test_op.target_guide
    desired = guide
    self.assertTensorAlmostEqual(actual, desired)

    actual = test_op.target_enc_guide
    desired = enc_guide
    self.assertTensorAlmostEqual(actual, desired)

    actual = test_op.target_image
    desired = image
    self.assertTensorAlmostEqual(actual, desired)