Пример #1
0
    def test_LossDict_aggregate_max_depth_gt_0(self):
        """Aggregating to a positive depth merges losses by dotted key prefix.

        Keys are dot-separated. Aggregating to depth ``d`` is expected to sum all
        losses that share their first ``d`` key components. A depth at or beyond
        the deepest key (3 here) is expected to leave the dict unchanged.
        """

        def unit_loss():
            return torch.tensor(1.0)

        loss_dict = pystiche.LossDict(
            (
                ("0.0.0", unit_loss()),
                ("0.0.1", unit_loss()),
                ("0.1", unit_loss()),
                ("1", unit_loss()),
            )
        )

        # (max_depth, expected aggregation result)
        expectations = (
            (1, pystiche.LossDict((("0", 3 * unit_loss()), ("1", unit_loss())))),
            (
                2,
                pystiche.LossDict(
                    (
                        ("0.0", 2 * unit_loss()),
                        ("0.1", unit_loss()),
                        ("1", unit_loss()),
                    )
                ),
            ),
            (3, loss_dict),
            (4, loss_dict),
        )

        for max_depth, desired in expectations:
            actual = loss_dict.aggregate(max_depth)
            self.assertDictEqual(actual, desired)
Пример #2
0
    def test_setitem_LossDict(self):
        """Storing a LossDict under a key exposes its entries as ``key.sub_key``."""
        name = "loss"
        num_sub_losses = 3
        loss_dict = pystiche.LossDict()

        sub_losses = pystiche.LossDict(
            [(str(idx), torch.tensor(idx)) for idx in range(num_sub_losses)]
        )
        loss_dict[name] = sub_losses

        for idx in range(num_sub_losses):
            sub_key = str(idx)
            ptu.assert_allclose(loss_dict[f"{name}.{sub_key}"], sub_losses[sub_key])
Пример #3
0
    def test_default_image_optim_log_fn_loss_dict_smoke(self):
        """The default log function reports ``loss_dict.format(max_depth=...)``."""

        class MockOptimLogger:
            # Stand-in logger that only remembers the last message it received.
            def __init__(self):
                self.msg = None

            @contextlib.contextmanager
            def environment(self, header):
                # No-op environment; the header is ignored.
                yield

            def message(self, msg):
                self.msg = msg

        loss_dict = pystiche.LossDict(
            (("a", torch.tensor(0.0)), ("b.c", torch.tensor(1.0)))
        )

        log_freq = max_depth = 1
        optim_logger = MockOptimLogger()
        log_fn = optim.default_image_optim_log_fn(
            optim_logger, log_freq=log_freq, max_depth=max_depth
        )

        # The test expects a message to be emitted at step == log_freq.
        log_fn(log_freq, loss_dict)

        self.assertEqual(optim_logger.msg, loss_dict.format(max_depth=max_depth))
Пример #4
0
    def test_backward(self):
        """``LossDict.backward`` gives the same grads as backpropagating the sum."""
        losses = [
            torch.tensor(val, dtype=torch.float, requires_grad=True)
            for val in range(3)
        ]

        def reset_grads():
            for tensor in losses:
                tensor.grad = None

        def snapshot_grads():
            # Clone defensively so the snapshot is independent of later passes.
            return [tensor.grad.clone() for tensor in losses]

        reset_grads()
        pystiche.LossDict(
            [(str(idx), tensor) for idx, tensor in enumerate(losses)]
        ).backward()
        actuals = snapshot_grads()

        reset_grads()
        sum(losses).backward()
        desireds = snapshot_grads()

        for actual, desired in zip(actuals, desireds):
            ptu.assert_allclose(actual, desired)
Пример #5
0
    def test_item(self):
        """``LossDict.item`` equals the sum of the individual loss values."""
        values = (1.0, 2.0)
        loss_dict = pystiche.LossDict(
            [(f"loss{idx}", torch.tensor(value)) for idx, value in enumerate(values)]
        )

        assert loss_dict.item() == pytest.approx(sum(values))
Пример #6
0
    def test_LossDict_item(self):
        """``LossDict.item`` equals the sum of the individual loss values."""
        values = (1.0, 2.0)
        loss_dict = pystiche.LossDict(
            [(f"loss{idx}", torch.tensor(value)) for idx, value in enumerate(values)]
        )

        self.assertAlmostEqual(loss_dict.item(), sum(values))
Пример #7
0
    def test_call(self):
        """An OperatorContainer call yields a LossDict of ``input + bias`` per operator."""
        input_image = torch.tensor(0.0)

        named_ops = [(str(idx), self.Operator(idx + 1.0)) for idx in range(3)]
        actual = ops.OperatorContainer(named_ops)(input_image)

        desired = pystiche.LossDict(
            [(name, input_image + op.bias) for name, op in named_ops]
        )
        ptu.assert_allclose(actual, desired)
Пример #8
0
    def test_LossDict_total(self):
        """``LossDict.total`` returns the sum of the stored loss tensors."""
        loss1, loss2 = torch.tensor(1.0), torch.tensor(2.0)
        loss_dict = pystiche.LossDict((("loss1", loss1), ("loss2", loss2)))

        self.assertTensorAlmostEqual(loss_dict.total(), loss1 + loss2)
Пример #9
0
    def test_total(self):
        """``LossDict.total`` returns the sum of the stored loss tensors."""
        loss1, loss2 = torch.tensor(1.0), torch.tensor(2.0)
        loss_dict = pystiche.LossDict((("loss1", loss1), ("loss2", loss2)))

        ptu.assert_allclose(loss_dict.total(), loss1 + loss2)
Пример #10
0
    def test_setitem_Tensor(self):
        """A scalar tensor stored in a LossDict is read back with the same value."""
        name = "loss"
        loss = torch.tensor(1.0)

        loss_dict = pystiche.LossDict()
        loss_dict[name] = loss

        ptu.assert_allclose(loss_dict[name], loss)
Пример #11
0
    def test_LossDict_setitem_Tensor(self):
        """A scalar tensor stored in a LossDict is read back with the same value."""
        name = "loss"
        loss = torch.tensor(1.0)

        loss_dict = pystiche.LossDict()
        loss_dict[name] = loss

        self.assertTensorAlmostEqual(loss_dict[name], loss)
Пример #12
0
    def forward(self, input_image: torch.Tensor) -> pystiche.LossDict:
        """Evaluate every child operator on ``input_image``.

        Each encoder in ``self._multi_layer_encoders`` first encodes the input
        (presumably so child operators can reuse the stored encodings — confirm).
        The encoders' storage is emptied afterwards, also when an operator raises.

        Args:
            input_image: Image passed to every encoder and every child operator.

        Returns:
            A ``pystiche.LossDict`` mapping each child's name to its output.
        """
        for encoder in self._multi_layer_encoders:
            encoder.encode(input_image)

        try:
            return pystiche.LossDict(
                [(name, op(input_image)) for name, op in self.named_children()]
            )
        finally:
            # Previously an exception from any operator skipped this cleanup,
            # leaving stale encodings in storage for the next call.
            for encoder in self._multi_layer_encoders:
                encoder.empty_storage()
Пример #13
0
    def test_mul(self):
        """Multiplying a LossDict by a scalar scales each individual loss."""
        values = (2.0, 3.0)
        factor = 4.0

        scaled = (
            pystiche.LossDict(
                [
                    (f"loss{idx}", torch.tensor(value))
                    for idx, value in enumerate(values)
                ]
            )
            * factor
        )

        for idx, value in enumerate(values):
            assert float(scaled[f"loss{idx}"]) == ptu.approx(value * factor)
Пример #14
0
    def test_OperatorContainer_call(self):
        """Calling an OperatorContainer yields a LossDict of its operators' outputs."""

        class TestOperator(ops.Operator):
            # Minimal operator that shifts its input image by a constant bias.
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def process_input_image(self, image):
                return image + self.bias

        input_image = torch.tensor(0.0)
        named_ops = [(str(idx), TestOperator(idx + 1.0)) for idx in range(3)]

        actual = ops.OperatorContainer(named_ops)(input_image)
        desired = pystiche.LossDict(
            [(name, input_image + op.bias) for name, op in named_ops]
        )
        self.assertTensorDictAlmostEqual(actual, desired)
Пример #15
0
def test_MultiOperatorLoss_call():
    """Calling a MultiOperatorLoss yields a LossDict of ``input + bias`` per operator."""

    class TestOperator(ops.Operator):
        # Minimal operator that shifts its input image by a constant bias.
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def process_input_image(self, image):
            return image + self.bias

    input_image = torch.tensor(0.0)
    named_ops = [(str(idx), TestOperator(idx + 1.0)) for idx in range(3)]

    actual = loss.MultiOperatorLoss(named_ops)(input_image)
    desired = pystiche.LossDict(
        [(name, input_image + op.bias) for name, op in named_ops]
    )
    ptu.assert_allclose(actual, desired)
Пример #16
0
    def test_LossMeter_update_LossDict(self):
        """Updating with a LossDict matches updating with its summed value."""
        actual_meter = optim.LossMeter("actual_meter")
        desired_meter = optim.LossMeter("desired_meter")

        losses = torch.arange(3, dtype=torch.float)
        actual_meter.update(
            pystiche.LossDict([(str(idx), loss) for idx, loss in enumerate(losses)])
        )
        desired_meter.update(torch.sum(losses).item())

        compared_attrs = (
            "count",
            "last_val",
            "global_sum",
            "global_min",
            "global_max",
            "global_avg",
            "local_avg",
        )
        for attr in compared_attrs:
            self.assertAlmostEqual(
                getattr(actual_meter, attr), getattr(desired_meter, attr)
            )
Пример #17
0
    def test_setitem_other(self):
        """Storing a plain Python float in a LossDict raises ``TypeError``."""
        loss_dict = pystiche.LossDict()

        with pytest.raises(TypeError):
            loss_dict["loss"] = 1.0
Пример #18
0
    def test_setitem_non_scalar_Tensor(self):
        """Storing a tensor of shape ``(1,)`` (not 0-dim) raises ``TypeError``."""
        loss_dict = pystiche.LossDict()

        with pytest.raises(TypeError):
            loss_dict["loss"] = torch.ones(1)
Пример #19
0
 def test_LossDict_float(self):
     """``float(loss_dict)`` agrees with ``loss_dict.item()``."""
     loss_dict = pystiche.LossDict(
         (("a", torch.tensor(0.0)), ("b", torch.tensor(1.0)))
     )
     as_float = float(loss_dict)
     as_item = loss_dict.item()
     self.assertAlmostEqual(as_float, as_item)
Пример #20
0
 def process_input_image(self, input_image: torch.Tensor) -> pystiche.LossDict:
     """Apply every child operator to ``input_image``.

     Returns:
         A ``pystiche.LossDict`` mapping each child's name to its output.
     """
     named_outputs = [
         (child_name, child(input_image))
         for child_name, child in self.named_children()
     ]
     return pystiche.LossDict(named_outputs)
Пример #21
0
 def forward(self, input_image: torch.Tensor) -> pystiche.LossDict:
     """Evaluate every child operator on ``input_image`` inside ``_mle_handler``.

     ``self._mle_handler(input_image)`` is used as a context manager around the
     evaluation (presumably it encodes the input and clears encoder storage on
     exit — confirm against its definition).

     Returns:
         A ``pystiche.LossDict`` mapping each child's name to its output.
     """
     with self._mle_handler(input_image):
         named_outputs = [
             (child_name, child(input_image))
             for child_name, child in self.named_children()
         ]
         # Return from inside the ``with`` block, matching the original
         # control flow if the handler ever suppresses an exception.
         return pystiche.LossDict(named_outputs)
Пример #22
0
 def test_repr_smoke(self):
     """Smoke test: ``repr`` of a LossDict produces a string."""
     loss_dict = pystiche.LossDict(
         (("a", torch.tensor(0.0)), ("b", torch.tensor(1.0)))
     )
     representation = repr(loss_dict)
     assert isinstance(representation, str)
Пример #23
0
 def test_float(self):
     """``float(loss_dict)`` agrees with ``loss_dict.item()``."""
     loss_dict = pystiche.LossDict(
         (("a", torch.tensor(0.0)), ("b", torch.tensor(1.0)))
     )
     as_float = float(loss_dict)
     assert as_float == pytest.approx(loss_dict.item())
Пример #24
0
    def test_LossDict_setitem_other(self):
        """Storing a plain Python float in a LossDict raises ``TypeError``."""
        loss_dict = pystiche.LossDict()

        with self.assertRaises(TypeError):
            loss_dict["loss"] = 1.0