def __setitem__(self, name: str, loss: Union[torch.Tensor, "LossDict"]) -> None:
    r"""Add a named loss to the entries.

    Args:
        name: Name of the loss.
        loss: If :class:`torch.Tensor`, it has to be scalar. If
            :class:`LossDict`, it is unpacked and the entries are added as
            level below of ``name``.

    Raises:
        TypeError: If loss is :class:`torch.Tensor` but isn't scalar.
    """
    if isinstance(loss, torch.Tensor):
        if not is_scalar_tensor(loss):
            raise TypeError("loss is a torch.Tensor but is not scalar.")
        super().__setitem__(name, loss)
        return

    if isinstance(loss, LossDict):
        # Flatten the nested dict: each child is stored as "<name>.<child>".
        for sub_name, sub_loss in loss.items():
            super().__setitem__(f"{name}.{sub_name}", sub_loss)
        return

    raise TypeError(  # type: ignore[unreachable]
        f"loss can be a scalar torch.Tensor or a pystiche.LossDict, but got "
        f"a {type(loss)} instead."
    )
def __setitem__(self, name: str, loss: Union[torch.Tensor, "LossDict"]) -> None:
    r"""Add a named loss to the entries.

    Args:
        name: Name of the loss.
        loss: If :class:`torch.Tensor`, it has to be scalar. If
            :class:`LossDict`, it is unpacked and the entries are added as
            level below of ``name``.

    Raises:
        TypeError: If loss is :class:`torch.Tensor` but isn't scalar, or is
            neither a :class:`torch.Tensor` nor a :class:`LossDict`.
    """
    if isinstance(loss, torch.Tensor):
        if not is_scalar_tensor(loss):
            # FIXME resolved: raise with an explanatory message instead of a
            # bare TypeError.
            msg = "loss is a torch.Tensor but is not scalar."
            raise TypeError(msg)
        super().__setitem__(name, loss)
    elif isinstance(loss, LossDict):
        # Flatten nested LossDicts: entries become "<name>.<child_name>".
        for child_name, child_loss in loss.items():
            super().__setitem__(f"{name}.{child_name}", child_loss)
    else:
        # FIXME resolved: reject unsupported types with a descriptive message.
        msg = (
            f"loss can be a scalar torch.Tensor or a pystiche.LossDict, but got "
            f"a {type(loss)} instead."
        )
        raise TypeError(msg)
def test_is_scalar_tensor(self):
    # Zero-dimensional tensors must be reported as scalar.
    scalar_cases = [torch.tensor(0.0), torch.empty(())]
    for tensor in scalar_cases:
        self.assertTrue(meta.is_scalar_tensor(tensor))

    # Tensors with one or more dimensions (even empty ones) are not scalar.
    non_scalar_cases = [torch.empty(0), torch.empty((0,))]
    for tensor in non_scalar_cases:
        self.assertFalse(meta.is_scalar_tensor(tensor))
def test_is_scalar_tensor():
    # Zero-dimensional tensors must be reported as scalar; tensors with one
    # or more dimensions (even empty ones) must not.
    scalar_tensors = [torch.tensor(0.0), torch.empty(())]
    non_scalar_tensors = [torch.empty(0), torch.empty((0,))]

    assert all(meta.is_scalar_tensor(t) for t in scalar_tensors)
    assert not any(meta.is_scalar_tensor(t) for t in non_scalar_tensors)