def test_LossDict_aggregate_max_depth_gt_0(self):
    """Check ``LossDict.aggregate`` for several positive ``max_depth`` values.

    At depths 1 and 2 the test expects entries sharing a dotted key prefix to
    be summed. At depths 3 and 4 it expects the dict to come back equal to the
    original, since the deepest key ("0.0.0") has three levels.
    """

    def loss():
        return torch.tensor(1.0)

    loss_dict = pystiche.LossDict(
        (("0.0.0", loss()), ("0.0.1", loss()), ("0.1", loss()), ("1", loss()))
    )

    # (max_depth, expected aggregation) pairs, checked in increasing depth.
    cases = (
        (1, pystiche.LossDict((("0", 3 * loss()), ("1", loss())))),
        (
            2,
            pystiche.LossDict((("0.0", 2 * loss()), ("0.1", loss()), ("1", loss()))),
        ),
        (3, loss_dict),
        (4, loss_dict),
    )
    for max_depth, expected in cases:
        self.assertDictEqual(loss_dict.aggregate(max_depth), expected)
def test_setitem_LossDict(self):
    """Assigning a LossDict under ``name`` exposes each entry as ``"<name>.<key>"``."""
    name = "loss"
    num_sub_losses = 3
    nested = pystiche.LossDict(
        [(str(idx), torch.tensor(idx)) for idx in range(num_sub_losses)]
    )

    loss_dict = pystiche.LossDict()
    loss_dict[name] = nested

    for idx in range(num_sub_losses):
        sub_key = str(idx)
        ptu.assert_allclose(loss_dict[f"{name}.{sub_key}"], nested[sub_key])
def test_default_image_optim_log_fn_loss_dict_smoke(self):
    """The default log fn should emit ``loss_dict.format(max_depth=...)`` as its message."""

    class MockOptimLogger:
        # Stand-in logger: remembers the last message instead of emitting it.
        def __init__(self):
            self.msg = None

        @contextlib.contextmanager
        def environment(self, header):
            yield

        def message(self, msg):
            self.msg = msg

    loss_dict = pystiche.LossDict(
        (("a", torch.tensor(0.0)), ("b.c", torch.tensor(1.0)))
    )
    log_freq = 1
    max_depth = 1

    optim_logger = MockOptimLogger()
    log_fn = optim.default_image_optim_log_fn(
        optim_logger, log_freq=log_freq, max_depth=max_depth
    )

    # The test expects a message to be produced at step == log_freq.
    log_fn(log_freq, loss_dict)

    self.assertEqual(optim_logger.msg, loss_dict.format(max_depth=max_depth))
def test_backward(self):
    """``LossDict.backward()`` should yield the same grads as backprop through the plain sum."""
    losses = [
        torch.tensor(val, dtype=torch.float, requires_grad=True) for val in range(3)
    ]

    def reset_grads():
        for tensor in losses:
            tensor.grad = None

    def snapshot_grads():
        return [tensor.grad.clone() for tensor in losses]

    # Gradients produced via the LossDict.
    reset_grads()
    loss_dict = pystiche.LossDict(
        [(str(idx), tensor) for idx, tensor in enumerate(losses)]
    )
    loss_dict.backward()
    actuals = snapshot_grads()

    # Reference gradients from summing the raw tensors.
    reset_grads()
    reference_total = sum(losses)
    reference_total.backward()
    desireds = snapshot_grads()

    for actual, desired in zip(actuals, desireds):
        ptu.assert_allclose(actual, desired)
def test_item(self):
    """``item()`` should equal the Python sum of the stored loss values."""
    values = (1.0, 2.0)
    named_losses = [
        (f"loss{idx}", torch.tensor(value)) for idx, value in enumerate(values)
    ]
    loss_dict = pystiche.LossDict(named_losses)

    assert loss_dict.item() == pytest.approx(sum(values))
def test_LossDict_item(self):
    """``item()`` should equal the Python sum of the stored loss values."""
    values = (1.0, 2.0)
    named_losses = [
        (f"loss{idx}", torch.tensor(value)) for idx, value in enumerate(values)
    ]
    loss_dict = pystiche.LossDict(named_losses)

    self.assertAlmostEqual(loss_dict.item(), sum(values))
def test_call(self):
    """Calling an OperatorContainer should return a LossDict keyed by operator name.

    The expected value ``image + op.bias`` assumes ``self.Operator`` adds its
    bias to the input image.
    """
    image = torch.tensor(0.0)
    named_ops = [(str(idx), self.Operator(idx + 1.0)) for idx in range(3)]

    op_container = ops.OperatorContainer(named_ops)
    actual = op_container(image)

    expected = pystiche.LossDict([(name, image + op.bias) for name, op in named_ops])
    ptu.assert_allclose(actual, expected)
def test_LossDict_total(self):
    """``total()`` should return the tensor sum of all entries."""
    first = torch.tensor(1.0)
    second = torch.tensor(2.0)
    loss_dict = pystiche.LossDict((("loss1", first), ("loss2", second)))

    self.assertTensorAlmostEqual(loss_dict.total(), first + second)
def test_total(self):
    """``total()`` should return the tensor sum of all entries."""
    first = torch.tensor(1.0)
    second = torch.tensor(2.0)
    loss_dict = pystiche.LossDict((("loss1", first), ("loss2", second)))

    ptu.assert_allclose(loss_dict.total(), first + second)
def test_setitem_Tensor(self):
    """A scalar tensor stored under a key should be returned unchanged on lookup."""
    key = "loss"
    loss_dict = pystiche.LossDict()
    value = torch.tensor(1.0)

    loss_dict[key] = value

    ptu.assert_allclose(loss_dict[key], value)
def test_LossDict_setitem_Tensor(self):
    """A scalar tensor stored under a key should be returned unchanged on lookup."""
    key = "loss"
    loss_dict = pystiche.LossDict()
    value = torch.tensor(1.0)

    loss_dict[key] = value

    self.assertTensorAlmostEqual(loss_dict[key], value)
def forward(self, input_image: torch.Tensor) -> pystiche.LossDict:
    """Evaluate every child operator on ``input_image`` and collect the losses.

    Each encoder in ``self._multi_layer_encoders`` first encodes the input
    (presumably so the child operators can reuse the cached encodings; confirm).
    Afterwards every encoder's storage is emptied again.

    Args:
        input_image: Image passed to every encoder and every child operator.

    Returns:
        A ``pystiche.LossDict`` mapping each child's name to its output.
    """
    for encoder in self._multi_layer_encoders:
        encoder.encode(input_image)
    try:
        return pystiche.LossDict(
            [(name, op(input_image)) for name, op in self.named_children()]
        )
    finally:
        # Always release the encodings, even if an operator raises, so the
        # encoders do not keep stale storage around for the next call.
        for encoder in self._multi_layer_encoders:
            encoder.empty_storage()
def test_mul(self):
    """Multiplying a LossDict by a scalar should scale every entry."""
    values = (2.0, 3.0)
    factor = 4.0
    loss_dict = pystiche.LossDict(
        [(f"loss{idx}", torch.tensor(value)) for idx, value in enumerate(values)]
    )

    scaled = loss_dict * factor

    for idx, value in enumerate(values):
        assert float(scaled[f"loss{idx}"]) == ptu.approx(value * factor)
def test_OperatorContainer_call(self):
    """Calling an OperatorContainer should return each operator's output by name."""

    class BiasOperator(ops.Operator):
        # Minimal operator: adds a constant bias to the input image.
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def process_input_image(self, image):
            return image + self.bias

    image = torch.tensor(0.0)
    named_ops = [(str(idx), BiasOperator(idx + 1.0)) for idx in range(3)]

    actual = ops.OperatorContainer(named_ops)(image)

    expected = pystiche.LossDict([(name, image + op.bias) for name, op in named_ops])
    self.assertTensorDictAlmostEqual(actual, expected)
def test_MultiOperatorLoss_call():
    """Calling a MultiOperatorLoss should return each operator's output by name."""

    class BiasOperator(ops.Operator):
        # Minimal operator: adds a constant bias to the input image.
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def process_input_image(self, image):
            return image + self.bias

    image = torch.tensor(0.0)
    named_ops = [(str(idx), BiasOperator(idx + 1.0)) for idx in range(3)]

    actual = loss.MultiOperatorLoss(named_ops)(image)

    expected = pystiche.LossDict([(name, image + op.bias) for name, op in named_ops])
    ptu.assert_allclose(actual, expected)
def test_LossMeter_update_LossDict(self):
    """Updating with a LossDict should match updating with its summed float value."""
    compared_attrs = (
        "count",
        "last_val",
        "global_sum",
        "global_min",
        "global_max",
        "global_avg",
        "local_avg",
    )

    actual_meter = optim.LossMeter("actual_meter")
    desired_meter = optim.LossMeter("desired_meter")

    losses = torch.arange(3, dtype=torch.float)
    loss_dict = pystiche.LossDict(
        [(str(idx), value) for idx, value in enumerate(losses)]
    )

    actual_meter.update(loss_dict)
    desired_meter.update(torch.sum(losses).item())

    for attr in compared_attrs:
        self.assertAlmostEqual(getattr(actual_meter, attr), getattr(desired_meter, attr))
def test_setitem_other(self):
    """Assigning a plain Python float should raise ``TypeError``."""
    key = "loss"
    non_tensor = 1.0
    loss_dict = pystiche.LossDict()

    with pytest.raises(TypeError):
        loss_dict[key] = non_tensor
def test_setitem_non_scalar_Tensor(self):
    """Assigning a tensor with one or more dimensions should raise ``TypeError``."""
    key = "loss"
    non_scalar = torch.ones(1)
    loss_dict = pystiche.LossDict()

    with pytest.raises(TypeError):
        loss_dict[key] = non_scalar
def test_LossDict_float(self):
    """``float(loss_dict)`` should agree with ``loss_dict.item()``."""
    entries = (("a", torch.tensor(0.0)), ("b", torch.tensor(1.0)))
    loss_dict = pystiche.LossDict(entries)

    as_float = float(loss_dict)
    self.assertAlmostEqual(as_float, loss_dict.item())
def process_input_image(self, input_image: torch.Tensor) -> pystiche.LossDict:
    """Apply every child operator to ``input_image``.

    Returns:
        A ``pystiche.LossDict`` mapping each child's name to ``op(input_image)``,
        in ``named_children()`` order.
    """
    named_results = []
    for child_name, child_op in self.named_children():
        named_results.append((child_name, child_op(input_image)))
    return pystiche.LossDict(named_results)
def forward(self, input_image: torch.Tensor) -> pystiche.LossDict:
    """Evaluate every child operator on ``input_image``.

    All operators run inside the ``self._mle_handler(input_image)`` context.
    That context presumably prepares and later releases multi-layer encoder
    storage (confirm).

    Returns:
        A ``pystiche.LossDict`` mapping each child's name to its output.
    """
    with self._mle_handler(input_image):
        named_results = [
            (child_name, child_op(input_image))
            for child_name, child_op in self.named_children()
        ]
        return pystiche.LossDict(named_results)
def test_repr_smoke(self):
    """``repr`` of a LossDict should succeed and produce a string."""
    entries = (("a", torch.tensor(0.0)), ("b", torch.tensor(1.0)))
    loss_dict = pystiche.LossDict(entries)

    representation = repr(loss_dict)
    assert isinstance(representation, str)
def test_float(self):
    """``float(loss_dict)`` should agree with ``loss_dict.item()``."""
    entries = (("a", torch.tensor(0.0)), ("b", torch.tensor(1.0)))
    loss_dict = pystiche.LossDict(entries)

    as_float = float(loss_dict)
    assert as_float == pytest.approx(loss_dict.item())
def test_LossDict_setitem_other(self):
    """Assigning a plain Python float should raise ``TypeError``."""
    key = "loss"
    non_tensor = 1.0
    loss_dict = pystiche.LossDict()

    with self.assertRaises(TypeError):
        loss_dict[key] = non_tensor