Beispiel #1
0
def test_MultiOperatorLoss_trim():
    class TestOperator(ops.EncodingOperator):
        def __init__(self, encoder, **kwargs):
            super().__init__(**kwargs)
            self._encoder = encoder

        @property
        def encoder(self):
            return self._encoder

        def forward(self, image):
            pass

    layers = [str(idx) for idx in range(3)]
    modules = [(layer, nn.Module()) for layer in layers]
    multi_layer_encoder = enc.MultiLayerEncoder(modules)

    ops_ = ((
        "op",
        TestOperator(multi_layer_encoder.extract_encoder(layers[0])),
    ), )
    loss.MultiOperatorLoss(ops_, trim=True)

    assert layers[0] in multi_layer_encoder
    assert all(layer not in multi_layer_encoder for layer in layers[1:])
Beispiel #2
0
def test_MultiOperatorLoss_call_encode(forward_pass_counter):
    class TestOperator(ops.EncodingOperator):
        def __init__(self, encoder, **kwargs):
            super().__init__(**kwargs)
            self._encoder = encoder

        @property
        def encoder(self):
            return self._encoder

        def forward(self, image):
            return torch.sum(self.encoder(image))

    modules = (("count", forward_pass_counter), )
    multi_layer_encoder = enc.MultiLayerEncoder(modules)

    ops_ = [(
        str(idx),
        TestOperator(multi_layer_encoder.extract_encoder("count")),
    ) for idx in range(3)]
    multi_op_loss = loss.MultiOperatorLoss(ops_)

    torch.manual_seed(0)
    input = torch.rand(1, 3, 128, 128)

    multi_op_loss(input)
    actual = forward_pass_counter.count
    desired = 1
    assert actual == desired

    multi_op_loss(input)
    actual = forward_pass_counter.count
    desired = 2
    assert actual == desired
Beispiel #3
0
    def test_MultiOperatorLoss_trim(self):
        class TestOperator(EncodingOperator):
            def __init__(self, encoder, **kwargs):
                super().__init__(**kwargs)
                self._encoder = encoder

            @property
            def encoder(self):
                return self._encoder

            def process_input_image(self, image):
                pass

        layers = [str(idx) for idx in range(3)]
        modules = [(layer, nn.Module()) for layer in layers]
        multi_layer_encoder = MultiLayerEncoder(modules)

        ops = ((
            "op",
            TestOperator(
                multi_layer_encoder.extract_single_layer_encoder(layers[0])),
        ), )
        loss.MultiOperatorLoss(ops, trim=True)

        self.assertTrue(layers[0] in multi_layer_encoder)
        for layer in layers[1:]:
            self.assertFalse(layer in multi_layer_encoder)
Beispiel #4
0
    def test_MultiOperatorLoss_call_encode(self):
        class TestOperator(EncodingOperator):
            def __init__(self, encoder, **kwargs):
                super().__init__(**kwargs)
                self._encoder = encoder

            @property
            def encoder(self):
                return self._encoder

            def process_input_image(self, image):
                return torch.sum(image)

        count = ForwardPassCounter()
        modules = (("count", count),)
        multi_layer_encoder = MultiLayerEncoder(modules)

        ops = [
            (str(idx), TestOperator(multi_layer_encoder.extract_encoder("count")),)
            for idx in range(3)
        ]
        multi_op_loss = loss.MultiOperatorLoss(ops)

        torch.manual_seed(0)
        input = torch.rand(1, 3, 128, 128)

        multi_op_loss(input)
        actual = count.count
        desired = 1
        self.assertEqual(actual, desired)

        multi_op_loss(input)
        actual = count.count
        desired = 2
        self.assertEqual(actual, desired)
Beispiel #5
0
def test_MultiOperatorLoss():
    class TestOperator(ops.Operator):
        def process_input_image(self, image):
            pass

    named_ops = [(str(idx), TestOperator()) for idx in range(3)]
    multi_op_loss = loss.MultiOperatorLoss(named_ops)

    actuals = multi_op_loss.named_children()
    desireds = named_ops
    assert_named_modules_identical(actuals, desireds)
Beispiel #6
0
    def test_MultiOperatorLoss(self):
        class TestOperator(Operator):
            def process_input_image(self, image):
                pass

        named_ops = [(str(idx), TestOperator()) for idx in range(3)]
        multi_op_loss = loss.MultiOperatorLoss(named_ops)

        actuals = multi_op_loss.named_children()
        desireds = named_ops
        self.assertNamedChildrenEqual(actuals, desireds)
Beispiel #7
0
    def test_MultiOperatorLoss_call(self):
        class TestOperator(Operator):
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def process_input_image(self, image):
                return image + self.bias

        input = torch.tensor(0.0)

        named_ops = [(str(idx), TestOperator(idx + 1.0)) for idx in range(3)]
        multi_op_loss = loss.MultiOperatorLoss(named_ops)

        actual = multi_op_loss(input)
        desired = pystiche.LossDict([(name, input + op.bias) for name, op in named_ops])
        self.assertTensorDictAlmostEqual(actual, desired)