예제 #1
0
파일: utils.py 프로젝트: ahatamiz/MONAI
def assert_allclose(
    actual: NdarrayOrTensor,
    desired: NdarrayOrTensor,
    type_test: Union[bool, str] = True,
    device_test: bool = False,
    *args,
    **kwargs,
):
    """
    Assert that types and all values of two data objects are close.

    Args:
        actual: Pytorch Tensor or numpy array for comparison.
        desired: Pytorch Tensor or numpy array to compare against.
        type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors.
            if type_test == "tensor", it checks whether the `actual` is a torch.tensor or metatensor according to
            `get_track_meta`; in that mode the type of `desired` is not checked.
        device_test: whether to test the device property. The check only runs when at least one input is a
            torch Tensor; an input without a `device` attribute (e.g. a numpy array) is treated as being on "cpu".
        args: extra arguments to pass on to `np.testing.assert_allclose`.
        kwargs: extra arguments to pass on to `np.testing.assert_allclose`.

    Raises:
        AssertionError: if the type check, the device check or the value comparison fails.
    """
    if isinstance(type_test, str) and type_test == "tensor":
        # "tensor" mode: `actual` must be a MetaTensor when meta tracking is on, a plain torch.Tensor otherwise.
        if get_track_meta():
            np.testing.assert_equal(isinstance(actual, MetaTensor), True,
                                    "must be a MetaTensor")
        else:
            np.testing.assert_equal(
                isinstance(actual, torch.Tensor)
                and not isinstance(actual, MetaTensor), True,
                "must be a torch.Tensor")
    elif type_test:
        # check both actual and desired are of the same type
        np.testing.assert_equal(isinstance(actual, np.ndarray),
                                isinstance(desired, np.ndarray), "numpy type")
        np.testing.assert_equal(isinstance(actual, torch.Tensor),
                                isinstance(desired, torch.Tensor),
                                "torch type")

    if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor):
        if device_test:
            # numpy < 2.0 arrays have no `device` attribute (numpy >= 2.0 reports "cpu"); default to "cpu" so a
            # mixed numpy/tensor comparison reports a device mismatch instead of raising AttributeError.
            np.testing.assert_equal(str(getattr(actual, "device", "cpu")),
                                    str(getattr(desired, "device", "cpu")),
                                    "torch device check")
        # move tensors to host memory so numpy can compare the values
        actual = actual.detach().cpu().numpy() if isinstance(
            actual, torch.Tensor) else actual
        desired = desired.detach().cpu().numpy() if isinstance(
            desired, torch.Tensor) else desired
    np.testing.assert_allclose(actual, desired, *args, **kwargs)
예제 #2
0
    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Keep only the elements of `img` whose value is listed in `self.applied_labels`; zero out the rest.

        Args:
            img: Pytorch tensor or numpy array of any shape.

        Raises:
            NotImplementedError: `img` is neither a Pytorch Tensor nor a numpy array.

        Returns:
            Pytorch tensor or numpy array of the same shape as `img`, where every element whose value is not in
            `applied_labels` has been replaced by zero.
        """
        if isinstance(img, np.ndarray):
            keep = np.isin(img, self.applied_labels)
            return np.asarray(np.where(keep, img, 0))

        if not isinstance(img, torch.Tensor):
            raise NotImplementedError(
                f"{self.__class__} can not handle data of type {type(img)}.")

        if not hasattr(torch, "isin"):
            # `torch.isin` only exists from torch 1.10.0: run the numpy path and convert the result back.
            filtered = self(img.detach().cpu().numpy())
            converted, *_ = convert_to_dst_type(filtered, img)
            return converted

        labels = torch.as_tensor(self.applied_labels, device=img.device)
        mask = torch.isin(img, labels)
        # the zero fill value is cast to img's dtype/device so the output matches the input
        return torch.where(mask, img, torch.tensor(0.0).to(img))
예제 #3
0
def assert_allclose(
    actual: NdarrayOrTensor,
    desired: NdarrayOrTensor,
    type_test: bool = True,
    device_test: bool = False,
    *args,
    **kwargs,
):
    """
    Assert that types and all values of two data objects are close.

    Args:
        actual: Pytorch Tensor or numpy array for comparison.
        desired: Pytorch Tensor or numpy array to compare against.
        type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors.
        device_test: whether to test the device property. The check only runs when at least one input is a
            torch Tensor; an input without a `device` attribute (e.g. a numpy array) is treated as being on "cpu".
        args: extra arguments to pass on to `np.testing.assert_allclose`.
        kwargs: extra arguments to pass on to `np.testing.assert_allclose`.

    Raises:
        AssertionError: if the type check, the device check or the value comparison fails.
    """
    if type_test:
        # check both actual and desired are of the same type
        np.testing.assert_equal(isinstance(actual, np.ndarray),
                                isinstance(desired, np.ndarray), "numpy type")
        np.testing.assert_equal(isinstance(actual, torch.Tensor),
                                isinstance(desired, torch.Tensor),
                                "torch type")

    if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor):
        if device_test:
            # numpy < 2.0 arrays have no `device` attribute (numpy >= 2.0 reports "cpu"); default to "cpu" so a
            # mixed numpy/tensor comparison reports a device mismatch instead of raising AttributeError.
            np.testing.assert_equal(str(getattr(actual, "device", "cpu")),
                                    str(getattr(desired, "device", "cpu")),
                                    "torch device check")
        # move tensors to host memory so numpy can compare the values
        actual = actual.detach().cpu().numpy() if isinstance(
            actual, torch.Tensor) else actual
        desired = desired.detach().cpu().numpy() if isinstance(
            desired, torch.Tensor) else desired
    np.testing.assert_allclose(actual, desired, *args, **kwargs)