Example #1
0
    def test_MultiOperatorLoss_trim(self):
        """Check that ``MultiOperatorLoss(..., trim=True)`` trims the encoder.

        A three-layer ``MultiLayerEncoder`` (layers ``"0"``, ``"1"``, ``"2"``)
        backs a single operator whose encoder is the single-layer encoder
        extracted for layer ``"0"``. After constructing the loss with
        ``trim=True``, only layer ``"0"`` is expected to remain in the
        multi-layer encoder; layers ``"1"`` and ``"2"`` are expected to be gone.
        """

        class TestOperator(EncodingOperator):
            # Minimal concrete EncodingOperator: it stores the encoder it is
            # given and exposes it through the ``encoder`` property. Image
            # processing is a no-op because this test only constructs the
            # loss and never evaluates it.
            def __init__(self, encoder, **kwargs):
                super().__init__(**kwargs)
                self._encoder = encoder

            @property
            def encoder(self):
                return self._encoder

            def process_input_image(self, image):
                pass

        layers = [str(idx) for idx in range(3)]
        modules = [(layer, nn.Module()) for layer in layers]
        multi_layer_encoder = MultiLayerEncoder(modules)

        # Only the first layer is referenced by an operator.
        ops = ((
            "op",
            TestOperator(
                multi_layer_encoder.extract_single_layer_encoder(layers[0])),
        ), )
        loss.MultiOperatorLoss(ops, trim=True)

        # assertIn/assertNotIn report the offending layer on failure, unlike
        # assertTrue(x in y), which only reports "False is not true".
        self.assertIn(layers[0], multi_layer_encoder)
        for layer in layers[1:]:
            self.assertNotIn(layer, multi_layer_encoder)
Example #2
0
    def test_MultiOperatorLoss_call_encode(self):
        """Check that operators sharing one encoder cost one forward pass per call.

        Three operators each use the single-layer encoder extracted for layer
        ``"count"`` of the same ``MultiLayerEncoder``, and that layer is a
        ``ForwardPassCounter``. Each call of the resulting loss is expected to
        raise ``counter.count`` by exactly one, even though three operators
        consume the encoding. This assumes ``ForwardPassCounter`` increments
        ``count`` once per forward pass.
        """

        class TestOperator(EncodingOperator):
            # Minimal concrete EncodingOperator: it stores the encoder it is
            # given and exposes it through the ``encoder`` property.
            def __init__(self, encoder, **kwargs):
                super().__init__(**kwargs)
                self._encoder = encoder

            @property
            def encoder(self):
                return self._encoder

            def process_input_image(self, image):
                # Sum over all elements so the operator yields a scalar tensor.
                return torch.sum(image)

        counter = ForwardPassCounter()
        modules = (("count", counter), )
        multi_layer_encoder = MultiLayerEncoder(modules)

        # Every operator shares the same underlying "count" layer.
        ops = [(
            str(idx),
            TestOperator(
                multi_layer_encoder.extract_single_layer_encoder("count")),
        ) for idx in range(3)]
        multi_op_loss = loss.MultiOperatorLoss(ops)

        torch.manual_seed(0)
        # Named ``image`` rather than ``input`` to avoid shadowing the builtin.
        image = torch.rand(1, 3, 128, 128)

        # The counter should advance by exactly one per loss call.
        for expected_count in (1, 2):
            multi_op_loss(image)
            self.assertEqual(counter.count, expected_count)