Example #1
0
    def __setitem__(self, name: str, loss: Union[torch.Tensor,
                                                 "LossDict"]) -> None:
        r"""Store ``loss`` in this dictionary under ``name``.

        Args:
            name: Key under which the loss is stored.
            loss: Either a scalar :class:`torch.Tensor`, which is stored directly
                under ``name``, or a :class:`LossDict`, whose entries are copied
                one level below ``name`` with the keys ``"{name}.{child_name}"``.

        Raises:
            TypeError: If ``loss`` is a :class:`torch.Tensor` that is not scalar, or
                if ``loss`` is neither a :class:`torch.Tensor` nor a
                :class:`LossDict`.
        """
        if isinstance(loss, torch.Tensor):
            if is_scalar_tensor(loss):
                super().__setitem__(name, loss)
                return
            raise TypeError("loss is a torch.Tensor but is not scalar.")

        if isinstance(loss, LossDict):
            # Flatten the nested dictionary by prefixing every child key with
            # ``name``; the child values are stored as-is.
            for child_key, child_value in loss.items():
                super().__setitem__(f"{name}.{child_key}", child_value)
            return

        # Only reachable when the static type annotation is violated at runtime.
        raise TypeError(  # type: ignore[unreachable]
            f"loss can be a scalar torch.Tensor or a pystiche.LossDict, but got "
            f"a {type(loss)} instead.")
Example #2
0
 def __setitem__(self, name: str, loss: Union[torch.Tensor,
                                              "LossDict"]) -> None:
     r"""Add a named loss to the entries.

     Args:
         name: Name of the loss.
         loss: If :class:`torch.Tensor`, it has to be scalar and is stored under
             ``name``. If :class:`LossDict`, its entries are added one level below
             ``name`` with the keys ``"{name}.{child_name}"``.

     Raises:
         TypeError: If ``loss`` is a :class:`torch.Tensor` but isn't scalar, or if
             it is neither a :class:`torch.Tensor` nor a :class:`LossDict`.
     """
     if isinstance(loss, torch.Tensor):
         if not is_scalar_tensor(loss):
             msg = "loss is a torch.Tensor but is not scalar."
             raise TypeError(msg)
         super().__setitem__(name, loss)
     elif isinstance(loss, LossDict):
         # Flatten the nested dictionary by prefixing each child key with ``name``.
         for child_name, child_loss in loss.items():
             super().__setitem__(f"{name}.{child_name}", child_loss)
     else:
         # Only reachable when the static type annotation is violated at runtime.
         msg = (  # type: ignore[unreachable]
             f"loss can be a scalar torch.Tensor or a pystiche.LossDict, but got "
             f"a {type(loss)} instead.")
         raise TypeError(msg)
Example #3
0
    def test_is_scalar_tensor(self):
        """Check that only 0-dimensional tensors are classified as scalar."""
        for scalar_tensor in (torch.tensor(0.0), torch.empty(())):
            self.assertTrue(meta.is_scalar_tensor(scalar_tensor))

        # Distinct non-scalar shapes: an empty 1-D tensor, a single-element 1-D
        # tensor (holds one value but still has a dimension), and a 2-D tensor.
        # Previously ``torch.empty(0)`` and ``torch.empty((0, ))`` were listed,
        # which are the same shape, so only one case was actually covered.
        nd_tensors = (
            torch.empty(0),
            torch.tensor([0.0]),
            torch.empty((2, 3)),
        )
        for nd_tensor in nd_tensors:
            self.assertFalse(meta.is_scalar_tensor(nd_tensor))
Example #4
0
def test_is_scalar_tensor():
    """Check that only 0-dimensional tensors are classified as scalar."""
    for scalar_tensor in (torch.tensor(0.0), torch.empty(())):
        assert meta.is_scalar_tensor(scalar_tensor)

    # Distinct non-scalar shapes: an empty 1-D tensor, a single-element 1-D
    # tensor (holds one value but still has a dimension), and a 2-D tensor.
    # Previously ``torch.empty(0)`` and ``torch.empty((0, ))`` were listed,
    # which are the same shape, so only one case was actually covered.
    nd_tensors = (
        torch.empty(0),
        torch.tensor([0.0]),
        torch.empty((2, 3)),
    )
    for nd_tensor in nd_tensors:
        assert not meta.is_scalar_tensor(nd_tensor)