Example #1
def write_png(
    data: np.ndarray,
    file_name: str,
    output_spatial_shape: Optional[Sequence[int]] = None,
    mode: str = InterpolateMode.BICUBIC,
    scale: Optional[int] = None,
) -> None:
    """
    Write numpy data into a PNG file on disk.
    Spatially it supports 2D data: (H, W), (H, W, 3) or (H, W, 4).
    If `scale` is None, expect the input data in `np.uint8` or `np.uint16` type.
    It's based on the Image module in PIL library:
    https://pillow.readthedocs.io/en/stable/reference/Image.html

    Args:
        data: input data to write to file.
        file_name: expected file name to save on disk.
        output_spatial_shape: spatial shape of the output image.
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode. Defaults to ``"bicubic"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling to
            [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.

    Raises:
        ValueError: When ``scale`` is not one of [255, 65535].

    .. deprecated:: 0.8
        Use :py:meth:`monai.data.PILWriter` instead.

    """
    if not isinstance(data, np.ndarray):
        raise ValueError("input data must be numpy array.")
    if len(data.shape) == 3 and data.shape[2] == 1:  # PIL Image can't save image with 1 channel
        data = data.squeeze(2)
    if output_spatial_shape is not None:
        output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2)
        mode = look_up_option(mode, InterpolateMode)
        align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False
        xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners)
        _min, _max = np.min(data), np.max(data)
        if len(data.shape) == 3:
            data = np.moveaxis(data, -1, 0)  # to channel first
            data = xform(data)  # type: ignore
            data = np.moveaxis(data, 0, -1)
        else:  # (H, W)
            data = np.expand_dims(data, 0)  # make a channel
            data = xform(data)[0]  # type: ignore
        if mode != InterpolateMode.NEAREST:
            data = np.clip(data, _min, _max)

    if scale is not None:
        data = np.clip(data, 0.0, 1.0)  # png writer can only scale data in the range [0, 1]
        if scale == np.iinfo(np.uint8).max:
            data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8)[0]
        elif scale == np.iinfo(np.uint16).max:
            data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16)[0]
        else:
            raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]")

    # PNG data must be an integer type (uint8 or uint16)
    if data.dtype not in (np.uint8, np.uint16):
        data = data.astype(np.uint8, copy=False)

    data = np.moveaxis(data, 0, 1)
    img = Image.fromarray(data)
    img.save(file_name, "PNG")
    return
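
# A minimal usage sketch for the function above. Assumption: MONAI <= 0.8, where
# write_png still ships (it is deprecated in favour of monai.data.PILWriter).
import numpy as np
from monai.data.png_writer import write_png

image = np.random.randint(0, 255, size=(64, 64), dtype=np.uint8)  # (H, W) uint8 input
write_png(image, file_name="sample.png", output_spatial_shape=(128, 128))  # saved at 128x128, bicubic resize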
Example #2
    def __init__(
        self,
        transform: InvertibleTransform,
        loader: TorchDataLoader,
        output_keys: KeysCollection = CommonKeys.PRED,
        batch_keys: KeysCollection = CommonKeys.IMAGE,
        meta_keys: Optional[KeysCollection] = None,
        batch_meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix: str = "meta_dict",
        collate_fn: Optional[Callable] = no_collation,
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        num_workers: Optional[int] = 0,
    ) -> None:
        """
        Args:
            transform: a callable data transform on input data.
            loader: data loader used to run transforms and generate the batch of data.
            output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it.
                it also can be a list of keys, will invert transform for each of them.
                Default to "pred". it's in-place operation.
            batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms
                for this input data, then invert them for the expected data with `output_keys`.
                It can also be a list of keys, each matches to the `output_keys` data. default to "image".
            meta_keys: explicitly indicate the key for the inverted meta data dictionary.
                the meta data is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`.
            batch_meta_keys: the key of the meta data of input data in `ignite.engine.batch`,
                will get the `affine`, `data_shape`, etc.
                the meta data is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.
                meta data will also be inverted and stored in `meta_keys`.
            meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the
                meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`.
                default is `meta_dict`, the meta data is a dictionary object.
                For example, to handle orig_key `image`,  read/write `affine` matrices from the
                metadata `image_meta_dict` dictionary's `affine` field.
                the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}".
            collate_fn: how to collate data after inverse transformations. default won't do any collation,
                so the output will be a list of PyTorch Tensor or numpy array without batch dim.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
                it also can be a list of bool, each matches to the `output_keys` data.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
                it also can be a list of bool, each matches to the `output_keys` data.
            device: if converted to Tensor, move the inverted results to target device before `post_func`,
                default to "cpu", it also can be a list of string or `torch.device`,
                each matches to the `output_keys` data.
            post_func: post processing for the inverted data, should be a callable function.
                it also can be a list of callable, each matches to the `output_keys` data.
            num_workers: number of workers when running the data loader for inverse transforms,
                default to 0 as it only runs one iteration and multi-processing may be even slower.
                Set to `None`, to use the `num_workers` of the input transform data loader.

        """
        self.inverter = Invertd(
            keys=output_keys,
            transform=transform,
            loader=loader,
            orig_keys=batch_keys,
            meta_keys=meta_keys,
            orig_meta_keys=batch_meta_keys,
            meta_key_postfix=meta_key_postfix,
            collate_fn=collate_fn,
            nearest_interp=nearest_interp,
            to_tensor=to_tensor,
            device=device,
            post_func=post_func,
            num_workers=num_workers,
        )
        self.output_keys = ensure_tuple(output_keys)
        self.meta_keys = ensure_tuple_rep(None, len(
            self.output_keys)) if meta_keys is None else ensure_tuple(
                meta_keys)
        if len(self.output_keys) != len(self.meta_keys):
            raise ValueError(
                "meta_keys should have the same length as output_keys.")
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix,
                                                 len(self.output_keys))
Example #3
    def __init__(
        self,
        transform: InvertibleTransform,
        loader: TorchDataLoader,
        output_keys: Union[str, Sequence[str]] = CommonKeys.PRED,
        batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE,
        meta_key_postfix: str = "meta_dict",
        collate_fn: Optional[Callable] = no_collation,
        postfix: str = "inverted",
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        num_workers: Optional[int] = 0,
    ) -> None:
        """
        Args:
            transform: a callable data transform on input data.
            loader: data loader used to run transforms and generate the batch of data.
            output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it.
                it also can be a list of keys, will invert transform for each of them. Default to "pred".
            batch_keys: the key of input data in `ignite.engine.batch`. will get the applied transforms
                for this input data, then invert them for the expected data with `output_keys`.
                It can also be a list of keys, each matches to the `output_keys` data. default to "image".
            meta_key_postfix: use `{batch_key}_{meta_key_postfix}` to fetch the meta data according to the key data,
                default is `meta_dict`, the meta data is a dictionary object.
                For example, to handle key `image`,  read/write affine matrices from the
                metadata `image_meta_dict` dictionary's `affine` field.
            collate_fn: how to collate data after inverse transformations.
                default won't do any collation, so the output will be a list of size batch size.
            postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}_{postfix}`.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
                it also can be a list of bool, each matches to the `output_keys` data.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
                it also can be a list of bool, each matches to the `output_keys` data.
            device: if converted to Tensor, move the inverted results to target device before `post_func`,
                default to "cpu", it also can be a list of string or `torch.device`,
                each matches to the `output_keys` data.
            post_func: post processing for the inverted data, should be a callable function.
                it also can be a list of callable, each matches to the `output_keys` data.
            num_workers: number of workers when running the data loader for inverse transforms,
                default to 0 as it only runs one iteration and multi-processing may be even slower.
                Set to `None`, to use the `num_workers` of the input transform data loader.

        """
        self.transform = transform
        self.inverter = BatchInverseTransform(
            transform=transform,
            loader=loader,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )
        self.output_keys = ensure_tuple(output_keys)
        self.batch_keys = ensure_tuple_rep(batch_keys, len(self.output_keys))
        self.meta_key_postfix = meta_key_postfix
        self.postfix = postfix
        self.nearest_interp = ensure_tuple_rep(nearest_interp,
                                               len(self.output_keys))
        self.to_tensor = ensure_tuple_rep(to_tensor, len(self.output_keys))
        self.device = ensure_tuple_rep(device, len(self.output_keys))
        self.post_func = ensure_tuple_rep(post_func, len(self.output_keys))
        self._totensor = ToTensor()
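
# The per-key options above (nearest_interp, to_tensor, device, post_func) are broadcast
# to the number of output keys with ensure_tuple_rep; a quick sketch of its behaviour:
from monai.utils import ensure_tuple_rep

print(ensure_tuple_rep(True, 3))             # (True, True, True): a scalar is repeated
print(ensure_tuple_rep(("cpu", "cuda"), 2))  # ('cpu', 'cuda'): a matching-length sequence passes through
# ensure_tuple_rep((1, 2), 3) raises ValueError: a sequence must already have the requested length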
Example #4
    def __init__(
        self,
        keys: KeysCollection,
        argmax: Union[Sequence[bool], bool] = False,
        to_onehot: Union[Sequence[Optional[int]], Optional[int]] = None,
        threshold: Union[Sequence[Optional[float]], Optional[float]] = None,
        rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
        allow_missing_keys: bool = False,
        n_classes: Optional[Union[Sequence[int], int]] = None,  # deprecated
        num_classes: Optional[Union[Sequence[int], int]] = None,  # deprecated
        logit_thresh: Union[Sequence[float], float] = 0.5,  # deprecated
        threshold_values: Union[Sequence[bool], bool] = False,  # deprecated
        **kwargs,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to model output and label.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            argmax: whether to execute argmax function on input data before transform.
                it also can be a sequence of bool, each element corresponds to a key in ``keys``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
            threshold: if not None, threshold the float values to int number 0 or 1 with the specified threshold value.
                defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"]. it also can be a sequence of str or None,
                each element corresponds to a key in ``keys``.
            allow_missing_keys: don't raise exception if key is missing.
            kwargs: additional parameters to ``AsDiscrete``.
                ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
                These default to ``0``, ``True``, ``torch.float`` respectively.
        .. deprecated:: 0.6.0
            ``n_classes`` is deprecated, use ``to_onehot`` instead.
        .. deprecated:: 0.7.0
            ``num_classes`` is deprecated, use ``to_onehot`` instead.
            ``logit_thresh`` is deprecated, use ``threshold`` instead.
            ``threshold_values`` is deprecated, use ``threshold`` instead.
        """
        super().__init__(keys, allow_missing_keys)
        self.argmax = ensure_tuple_rep(argmax, len(self.keys))
        to_onehot_ = ensure_tuple_rep(to_onehot, len(self.keys))
        num_classes = ensure_tuple_rep(num_classes, len(self.keys))
        self.to_onehot = []
        for flag, val in zip(to_onehot_, num_classes):
            if isinstance(flag, bool):
                warnings.warn(
                    "`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead."
                )
                self.to_onehot.append(val if flag else None)
            else:
                self.to_onehot.append(flag)

        threshold_ = ensure_tuple_rep(threshold, len(self.keys))
        logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
        self.threshold = []
        for flag, val in zip(threshold_, logit_thresh):
            if isinstance(flag, bool):
                warnings.warn(
                    "`threshold_values=True/False` is deprecated, please use `threshold=value` instead."
                )
                self.threshold.append(val if flag else None)
            else:
                self.threshold.append(flag)

        self.rounding = ensure_tuple_rep(rounding, len(self.keys))
        self.converter = AsDiscreteEx()
        self.converter.kwargs = kwargs
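
# A hedged sketch of the new-style arguments with stock monai.transforms.AsDiscreted;
# the snippet above wraps a custom AsDiscreteEx, so treat this call as illustrative only.
from monai.transforms import AsDiscreted

post = AsDiscreted(
    keys=["pred", "label"],
    argmax=(True, False),  # argmax only on the model output
    to_onehot=3,           # broadcast to both keys; replaces the deprecated n_classes/num_classes
)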
Example #5
def generate_param_groups(
    network: torch.nn.Module,
    layer_matches: Sequence[Callable],
    match_types: Sequence[str],
    lr_values: Sequence[float],
    include_others: bool = True,
):
    """
    Utility function to generate parameter groups with different LR values for optimizer.
    The output parameter groups have the same order as the `layer_matches` functions.

    Args:
        network: source network to generate parameter groups from.
        layer_matches: a list of callable functions to select or filter out network layer groups,
            for "select" type, the input will be the `network`, for "filter" type,
            the input will be every item of `network.named_parameters()`.
            for "select", the parameters will be
            `select_func(network).parameters()`.
            for "filter", the parameters will be
            `(x[1] for x in filter(f, network.named_parameters()))`
        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
            can be "select" or "filter".
        lr_values: a list of LR values corresponding to the `layer_matches` functions.
        include_others: whether to include the remaining layers as the last group, default to True.

    It's mainly used to set different LR values for different network elements, for example:

    .. code-block:: python

        net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
        print(net)  # print out network components to select expected items
        print(net.named_parameters())  # print out all the named parameters to filter out expected items
        params = generate_param_groups(
            network=net,
            layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]],
            match_types=["select", "filter"],
            lr_values=[1e-2, 1e-3],
        )
        # the groups will be a list of dictionaries:
        # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},
        #  {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},
        #  {'params': <filter object at 0x7f9088fd0da0>}]
        optimizer = torch.optim.Adam(params, 1e-4)

    """
    layer_matches = ensure_tuple(layer_matches)
    match_types = ensure_tuple_rep(match_types, len(layer_matches))
    lr_values = ensure_tuple_rep(lr_values, len(layer_matches))

    def _get_select(f):
        def _select():
            return f(network).parameters()

        return _select

    def _get_filter(f):
        def _filter():
            # should eventually generate a list of network parameters
            return (x[1] for x in filter(f, network.named_parameters()))

        return _filter

    params = []
    _layers = []
    for func, ty, lr in zip(layer_matches, match_types, lr_values):
        if ty.lower() == "select":
            layer_params = _get_select(func)
        elif ty.lower() == "filter":
            layer_params = _get_filter(func)
        else:
            raise ValueError(f"unsupported layer match type: {ty}.")

        params.append({"params": layer_params(), "lr": lr})
        _layers.extend([id(x) for x in layer_params()])

    if include_others:
        params.append({
            "params":
            filter(lambda p: id(p) not in _layers, network.parameters())
        })

    return params
Example #6
    def _get_size(self, sample: Dict):
        if self.patch_size is None:
            return ensure_tuple_rep(sample.get(WSIPatchKeys.SIZE), 2)
        return self.patch_size
Example #7
    def __init__(
        self,
        dimensions: int = 3,
        in_channels: int = 1,
        out_channels: int = 2,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: Union[str, tuple] = ("LeakyReLU", {
            "negative_slope": 0.1,
            "inplace": True
        }),
        norm: Union[str, tuple] = ("instance", {
            "affine": True
        }),
        dropout: Union[float, tuple] = 0.0,
        upsample: str = "deconv",
    ):
        print(" ####################-------------------- Triggering your own Arch code ")
        print(" --------------------#################### You can change this as you see fit")
        """
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            dimensions: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

        """
        super().__init__()

        fea = ensure_tuple_rep(features, 6)
        print(f"BasicUNet features: {fea}.")

        self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm,
                              dropout)
        self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, dropout)
        self.down_2 = Down(dimensions, fea[1], fea[2], act, norm, dropout)
        self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, dropout)
        self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, dropout)

        self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm,
                             dropout, upsample)
        self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm,
                             dropout, upsample)
        self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm,
                             dropout, upsample)
        self.upcat_1 = UpCat(dimensions,
                             fea[1],
                             fea[0],
                             fea[5],
                             act,
                             norm,
                             dropout,
                             upsample,
                             halves=False)

        self.final_conv = Conv["conv", dimensions](fea[5],
                                                   out_channels,
                                                   kernel_size=1)
Example #8
    def __init__(
        self,
        keys: KeysCollection,
        sigma_range: Tuple[float, float],
        magnitude_range: Tuple[float, float],
        spatial_size: Optional[Union[Sequence[int], int]] = None,
        prob: float = 0.1,
        rotate_range: Optional[Union[Sequence[float], float]] = None,
        shear_range: Optional[Union[Sequence[float], float]] = None,
        translate_range: Optional[Union[Sequence[float], float]] = None,
        scale_range: Optional[Union[Sequence[float], float]] = None,
        mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
        padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION,
        as_tensor_output: bool = False,
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            sigma_range: a Gaussian kernel with standard deviation sampled from
                ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid.
            magnitude_range: the random offsets on the grid will be generated from
                ``uniform[magnitude[0], magnitude[1])``.
            spatial_size: specifying output image spatial size [h, w, d].
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if the components of the `spatial_size` are non-positive values, the transform will use the
                corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted
                to `(32, 32, 64)` if the third spatial dimension size of img is `64`.
            prob: probability of returning a randomized affine grid.
                defaults to 0.1, with 10% chance returns a randomized grid,
                otherwise returns a ``spatial_size`` centered area extracted from the input image.
            rotate_range: angle range in radians. rotate_range[0] will be used to generate the 1st rotation
                parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and
                `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes.
            shear_range: shear_range[0] will be used to generate the 1st shearing parameter from
                `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` and `shear_range[2]`
                controls the range of the uniform distribution used to generate the 2nd and 3rd parameters.
            translate_range: translate_range[0] will be used to generate the 1st shift parameter from
                `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` and
                `translate_range[2]` controls the range of the uniform distribution used to generate
                the 2nd and 3rd parameters.
            scale_range: scale_range[0] will be used to generate the 1st scaling factor from
                `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` and `scale_range[2]`
                controls the range of the uniform distribution used to generate the 2nd and 3rd parameters.
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                It also can be a sequence of string, each element corresponds to a key in ``keys``.
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"reflection"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                It also can be a sequence of string, each element corresponds to a key in ``keys``.
            as_tensor_output: the computation is implemented using pytorch tensors, this option specifies
                whether to convert it back to numpy arrays.
            device: device on which the tensor will be allocated.

        See also:
            - :py:class:`RandAffineGrid` for the random affine parameters configurations.
            - :py:class:`Affine` for the affine transformation parameters configurations.
        """
        super().__init__(keys)
        self.rand_3d_elastic = Rand3DElastic(
            sigma_range=sigma_range,
            magnitude_range=magnitude_range,
            prob=prob,
            rotate_range=rotate_range,
            shear_range=shear_range,
            translate_range=translate_range,
            scale_range=scale_range,
            spatial_size=spatial_size,
            as_tensor_output=as_tensor_output,
            device=device,
        )
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))
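
# A hedged usage sketch for the dictionary transform above (older MONAI API with
# as_tensor_output); mode and padding_mode are broadcast per key by ensure_tuple_rep.
from monai.transforms import Rand3DElasticd

elastic = Rand3DElasticd(
    keys=["image", "label"],
    sigma_range=(5.0, 8.0),
    magnitude_range=(100.0, 200.0),
    prob=0.5,
    mode=("bilinear", "nearest"),  # per-key interpolation: bilinear for image, nearest for label
    padding_mode="zeros",
)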
Example #9
    def __init__(
        self,
        keys: KeysCollection,
        transform: InvertibleTransform,
        orig_keys: KeysCollection,
        meta_keys: Optional[KeysCollection] = None,
        orig_meta_keys: Optional[KeysCollection] = None,
        meta_key_postfix: str = "meta_dict",
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Args:
            keys: the key of expected data in the dict, invert transforms on it, in-place operation.
                it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"].
            transform: the previous callable transform that applied on input data.
            orig_keys: the key of the original input data in the dict. will get the applied transform information
                for this input data, then invert them for the expected data with `keys`.
                It can also be a list of keys, each matches to the `keys` data.
            meta_keys: explicitly indicate the key for the inverted meta data dictionary.
                the meta data is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`.
            orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc.
                the meta data is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.
                meta data will also be inverted and stored in `meta_keys`.
            meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the
                meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`.
                default is `meta_dict`, the meta data is a dictionary object.
                For example, to handle orig_key `image`,  read/write `affine` matrices from the
                metadata `image_meta_dict` dictionary's `affine` field.
                the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}".
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
                it also can be a list of bool, each matches to the `keys` data.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
                it also can be a list of bool, each matches to the `keys` data.
            device: if converted to Tensor, move the inverted results to target device before `post_func`,
                default to "cpu", it also can be a list of string or `torch.device`,
                each matches to the `keys` data.
            post_func: post processing for the inverted data, should be a callable function.
                it also can be a list of callable, each matches to the `keys` data.
            allow_missing_keys: don't raise exception if key is missing.

        """
        super().__init__(keys, allow_missing_keys)
        if not isinstance(transform, InvertibleTransform):
            raise ValueError(
                "transform is not invertible, can't invert transform for the data."
            )
        self.transform = transform
        self.orig_keys = ensure_tuple_rep(orig_keys, len(self.keys))
        self.meta_keys = ensure_tuple_rep(None, len(
            self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
        if len(self.keys) != len(self.meta_keys):
            raise ValueError("meta_keys should have the same length as keys.")
        self.orig_meta_keys = ensure_tuple_rep(orig_meta_keys, len(self.keys))
        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix,
                                                 len(self.keys))
        self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.keys))
        self.to_tensor = ensure_tuple_rep(to_tensor, len(self.keys))
        self.device = ensure_tuple_rep(device, len(self.keys))
        self.post_func = ensure_tuple_rep(post_func, len(self.keys))
        self._totensor = ToTensor()
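
# A hedged sketch of wiring Invertd into post-processing. The preprocessing Compose
# below is a made-up example; any invertible spatial transform (here Spacingd) works.
from monai.transforms import Compose, Invertd, LoadImaged, Spacingd

val_transforms = Compose([
    LoadImaged(keys="image"),
    Spacingd(keys="image", pixdim=(1.0, 1.0, 1.0)),
])
post_transforms = Compose([
    Invertd(
        keys="pred",
        transform=val_transforms,
        orig_keys="image",
        meta_key_postfix="meta_dict",
        nearest_interp=False,  # keep the original interpolation mode when inverting
        to_tensor=True,
    ),
])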
Example #10
    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        strides: Union[Sequence[int], int] = 1,
        kernel_size: Union[Sequence[int], int] = 3,
        subunits: int = 2,
        adn_ordering: Union[Sequence[str], str] = ("NDA", "NDA"),
        act: Optional[Union[Tuple, str]] = "PRELU",
        norm: Optional[Union[Tuple, str]] = "INSTANCE",
        dropout: Optional[Union[Tuple, str, float]] = None,
        dropout_dim: Optional[int] = 1,
        dilation: Union[Sequence[int], int] = 1,
        bias: bool = True,
        last_conv_only: bool = False,
        is_prunable: bool = False,
        padding: Optional[Union[Sequence[int], int]] = None,
    ) -> None:
        super().__init__()
        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Sequential()
        self.residual = nn.Identity()
        self.adn_ordering = ensure_tuple_rep(adn_ordering, subunits)
        if padding is None:  # only compute the default padding when none was given
            padding = same_padding(kernel_size, dilation)
        schannels = in_channels
        sstrides = strides
        subunits = max(1, subunits)

        for su in range(subunits):
            conv_only = last_conv_only and su == (subunits - 1)
            unit = ConvolutionEx(
                dimensions,
                schannels,
                out_channels,
                strides=sstrides,
                kernel_size=kernel_size,
                adn_ordering=self.adn_ordering[su],
                act=act,
                norm=norm,
                dropout=dropout,
                dropout_dim=dropout_dim,
                dilation=dilation,
                bias=bias,
                conv_only=conv_only,
                is_prunable=is_prunable,
                padding=padding,
            )

            self.conv.add_module(f"unit{su:d}", unit)

            # after first loop set channels and strides to what they should be for subsequent units
            schannels = out_channels
            sstrides = 1

        # apply convolution to input to change number of output channels and size to match that coming from self.conv
        if np.prod(strides) != 1 or in_channels != out_channels:
            rkernel_size = kernel_size
            rpadding = padding

            if np.prod(strides) == 1:  # if only adapting number of channels, a 1x1 kernel is used with no padding
                rkernel_size = 1
                rpadding = 0

            conv_type = Conv[Conv.CONV, dimensions]
            self.residual = conv_type(in_channels,
                                      out_channels,
                                      rkernel_size,
                                      strides,
                                      rpadding,
                                      bias=bias)
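
# Quick check of the default padding computed above; same_padding is assumed to come
# from monai.networks.layers.convutils (as in upstream MONAI) and keeps the spatial
# size unchanged for stride-1 convolutions.
from monai.networks.layers.convutils import same_padding

print(same_padding(3))              # 1
print(same_padding((3, 5)))         # (1, 2)
print(same_padding(3, dilation=2))  # 2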
Example #11
    def __init__(
        self,
        block: Type[Union[ResNetBlock, ResNetBottleneck]],
        layers: List[int],
        block_inplanes: List[int],
        spatial_dims: int = 3,
        n_input_channels: int = 3,
        conv1_t_size: Union[Tuple[int], int] = 7,
        conv1_t_stride: Union[Tuple[int], int] = 1,
        no_max_pool: bool = False,
        shortcut_type: str = "B",
        widen_factor: float = 1.0,
        num_classes: int = 400,
        feed_forward: bool = True,
        n_classes: Optional[int] = None,
    ) -> None:

        super().__init__()
        # in case the new num_classes is default but you still call deprecated n_classes
        if n_classes is not None and num_classes == 400:
            num_classes = n_classes

        conv_type: Type[Union[nn.Conv1d, nn.Conv2d,
                              nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
        norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d,
                              nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
        pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d,
                              nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims]
        avgp_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d,
                              nn.AdaptiveAvgPool3d]] = Pool[Pool.ADAPTIVEAVG,
                                                            spatial_dims]

        block_avgpool = get_avgpool()
        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims)
        conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims)

        self.conv1 = conv_type(
            n_input_channels,
            self.in_planes,
            kernel_size=conv1_kernel_size,  # type: ignore
            stride=conv1_stride,  # type: ignore
            padding=tuple(k // 2 for k in conv1_kernel_size),  # type: ignore
            bias=False,
        )
        self.bn1 = norm_type(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       spatial_dims, shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       spatial_dims,
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       spatial_dims,
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       spatial_dims,
                                       shortcut_type,
                                       stride=2)
        self.avgpool = avgp_type(block_avgpool[spatial_dims])
        self.fc = nn.Linear(block_inplanes[3] * block.expansion,
                            num_classes) if feed_forward else None

        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight),
                                        mode="fan_out",
                                        nonlinearity="relu")
            elif isinstance(m, norm_type):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)
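
# A hedged construction sketch; ResNetBlock and get_inplanes are assumed to come from
# monai.networks.nets.resnet, as in upstream MONAI.
from monai.networks.nets.resnet import ResNet, ResNetBlock, get_inplanes

net = ResNet(
    block=ResNetBlock,
    layers=[2, 2, 2, 2],            # ResNet-18-style depth
    block_inplanes=get_inplanes(),  # [64, 128, 256, 512]
    spatial_dims=3,
    n_input_channels=1,
    num_classes=2,
)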
Example #12
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Union[Sequence[int], int],
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "conv",
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size.
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dims.

        Examples::

            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')

             # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.num_layers = 12
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.patch_size = ensure_tuple_rep(16, spatial_dims)
        self.feat_size = tuple(
            img_d // p_d for img_d, p_d in zip(img_size, self.patch_size))
        self.hidden_size = hidden_size
        self.classification = False
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            classification=self.classification,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=spatial_dims,
                                in_channels=feature_size,
                                out_channels=out_channels)
        self.proj_axes = (0, spatial_dims + 1) + tuple(
            d + 1 for d in range(spatial_dims))
        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]
Example #13
    def __init__(
        self,
        data: Sequence,
        patch_size: Optional[Union[int, Tuple[int, int]]] = None,
        patch_level: Optional[int] = None,
        mask_level: int = 0,
        overlap: Union[Tuple[float, float], float] = 0.0,
        offset: Union[Tuple[int, int], int, str] = (0, 0),
        offset_limits: Optional[Union[Tuple[Tuple[int, int], Tuple[int, int]],
                                      Tuple[int, int]]] = None,
        transform: Optional[Callable] = None,
        include_label: bool = False,
        center_location: bool = False,
        additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION,
                                               ProbMapKeys.SIZE,
                                               ProbMapKeys.COUNT),
        reader="cuCIM",
        map_level: int = 0,
        seed: int = 0,
        **kwargs,
    ):
        super().__init__(
            data=[],
            patch_size=patch_size,
            patch_level=patch_level,
            transform=transform,
            include_label=include_label,
            center_location=center_location,
            additional_meta_keys=additional_meta_keys,
            reader=reader,
            **kwargs,
        )
        self.overlap = overlap
        self.set_random_state(seed)
        # Set the offset config
        self.random_offset = False
        if isinstance(offset, str):
            if offset == "random":
                self.random_offset = True
                self.offset_limits: Optional[Tuple[Tuple[int, int],
                                                   Tuple[int, int]]]
                if offset_limits is None:
                    self.offset_limits = None
                elif isinstance(offset_limits, tuple):
                    if isinstance(offset_limits[0], int):
                        self.offset_limits = (offset_limits, offset_limits)
                    elif isinstance(offset_limits[0], tuple):
                        self.offset_limits = offset_limits
                    else:
                        raise ValueError(
                            "The offset limits should be either a tuple of integers or tuple of tuple of integers."
                        )
                else:
                    raise ValueError("The offset limits should be a tuple.")
            else:
                raise ValueError(
                    f'Invalid string for offset "{offset}". It should be either "random" as a string, '
                    "an integer, or a tuple of integers defining the offset.")
        else:
            self.offset = ensure_tuple_rep(offset, 2)

        self.mask_level = mask_level
        # Create single sample for each patch (in a sliding window manner)
        self.data: list
        self.image_data = list(data)
        for sample in self.image_data:
            patch_samples = self._evaluate_patch_locations(sample)
            self.data.extend(patch_samples)
Example #14
    def __init__(
        self,
        img_size: Union[Sequence[int], int],
        in_channels: int,
        out_channels: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 48,
        norm_name: Union[Tuple, str] = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            img_size: dimension of input image.
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            feature_size: dimension of network feature size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            norm_name: feature normalization type and arguments.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims.

        Examples::

            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
            >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)

            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
            >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))

            # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
            >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

        """

        super().__init__()

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(2, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)

        if not (spatial_dims == 2 or spatial_dims == 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        for m, p in zip(img_size, patch_size):
            for i in range(5):
                if m % np.power(p, i + 1) != 0:
                    raise ValueError(
                        "input image size (img_size) should be divisible by stage-wise image resolution."
                    )

        if not (0 <= drop_rate <= 1):
            raise ValueError("dropout rate should be between 0 and 1.")

        if not (0 <= attn_drop_rate <= 1):
            raise ValueError(
                "attention dropout rate should be between 0 and 1.")

        if not (0 <= dropout_path_rate <= 1):
            raise ValueError("drop path rate should be between 0 and 1.")

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize

        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
        )

        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.encoder10 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.decoder1 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=True,
        )

        self.out = UnetOutBlock(spatial_dims=spatial_dims,
                                in_channels=feature_size,
                                out_channels=out_channels)  # type: ignore
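
# The divisibility loop above effectively requires every spatial dimension of img_size
# to be a multiple of patch_size**5 (2**5 = 32 here); a standalone check:
img_size, patch = (96, 96, 96), 2
assert all(m % patch ** 5 == 0 for m in img_size)  # 96 = 3 * 32, so this configuration passes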
Example #15
    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int,
        num_heads: int,
        pos_embed: str,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.


        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        for m, p in zip(img_size, patch_size):
            if m < p:
                raise ValueError("patch_size should be smaller than img_size.")
            if self.pos_embed == "perceptron" and m % p != 0:
                raise ValueError("patch_size should be divisible by img_size for perceptron.")
        self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
        self.patch_dim = int(in_channels * np.prod(patch_size))

        self.patch_embeddings: nn.Module
        if self.pos_embed == "conv":
            self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
            )
        elif self.pos_embed == "perceptron":
            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
            chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
            from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
            to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
            axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
            self.patch_embeddings = nn.Sequential(
                Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
            )
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)
        trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
        self.apply(self._init_weights)
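
The constructor above matches MONAI's PatchEmbeddingBlock. Assuming that class (and that its forward pass flattens the image into patch tokens, as in monai.networks.blocks.PatchEmbeddingBlock), a minimal usage sketch:

import torch
from monai.networks.blocks import PatchEmbeddingBlock

# hypothetical configuration: a 96^3 volume with 16^3 patches -> (96 / 16)^3 = 216 tokens
embed = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    hidden_size=768,
    num_heads=12,
    pos_embed="conv",
    dropout_rate=0.0,
    spatial_dims=3,
)
x = torch.rand(2, 1, 96, 96, 96)  # (batch, channels, H, W, D)
tokens = embed(x)
print(tokens.shape)  # expected: torch.Size([2, 216, 768])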
Exemple #16
0
    def __init__(
        self,
        in_shape: Sequence[int],
        out_shape: Sequence[int],
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
    ) -> None:
        """
        Construct the regressor network with the number of layers defined by `channels` and `strides`. Inputs are
        first passed through the convolutional layers in the forward pass; the output from this is then passed
        through a fully connected layer to relate it to the final output tensor.

        Args:
            in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
            out_shape: tuple of integers stating the dimension of the final output tensor
            channels: tuple of integers stating the output channels of each convolutional layer
            strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
            kernel_size: integer or tuple of integers stating size of convolutional kernels
            num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
            act: name or type defining activation layers
            norm: name or type defining normalization layers
            dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
            bias: boolean stating if convolution layers should have a bias component
        """
        super().__init__()

        self.in_channels, *self.in_shape = ensure_tuple(in_shape)
        self.dimensions = len(self.in_shape)
        self.channels = ensure_tuple(channels)
        self.strides = ensure_tuple(strides)
        self.out_shape = ensure_tuple(out_shape)
        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias
        self.net = nn.Sequential()

        echannel = self.in_channels

        padding = same_padding(kernel_size)

        self.final_size = np.asarray(self.in_shape, dtype=int)  # np.int is removed in recent NumPy versions
        self.reshape = Reshape(*self.out_shape)

        # encode stage
        for i, (c, s) in enumerate(zip(self.channels, self.strides)):
            layer = self._get_layer(echannel, c, s, i == len(channels) - 1)
            echannel = c  # use the output channel number as the input for the next loop
            self.net.add_module("layer_%i" % i, layer)
            self.final_size = calculate_out_shape(self.final_size, kernel_size, s, padding)

        self.final = self._get_final_layer((echannel,) + self.final_size)
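
Assuming this constructor belongs to MONAI's Regressor (monai.networks.nets.Regressor shares this signature), a minimal usage sketch:

import torch
from monai.networks.nets import Regressor

# map a single-channel 64x64 image to two regression targets
net = Regressor(
    in_shape=(1, 64, 64),   # (channels, H, W), batch dimension excluded
    out_shape=(2,),         # two outputs per sample
    channels=(8, 16, 32),   # three convolutional layers
    strides=(2, 2, 2),      # each layer halves the spatial size
)
x = torch.rand(4, 1, 64, 64)
print(net(x).shape)  # expected: torch.Size([4, 2])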
Exemple #17
0
    def __init__(
        self,
        spatial_dims: int,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        scale_factor: Union[Sequence[float], float] = 2,
        size: Optional[Union[Tuple[int], int]] = None,
        mode: Union[UpsampleMode, str] = UpsampleMode.DECONV,
        pre_conv: Optional[Union[nn.Module, str]] = "default",
        interp_mode: Union[InterpolateMode, str] = InterpolateMode.LINEAR,
        align_corners: Optional[bool] = True,
        bias: bool = True,
        apply_pad_pool: bool = True,
        dimensions: Optional[int] = None,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: number of channels of the output image. Defaults to `in_channels`.
            scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
            size: spatial size of the output image.
                Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
                In torch.nn.functional.interpolate, only one of `size` or `scale_factor` should be defined,
                thus if size is defined, `scale_factor` will not be used.
                Defaults to None.
            mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
            pre_conv: a conv block applied before upsampling. Defaults to "default".
                When ``pre_conv`` is ``"default"``, one reserved conv layer will be utilized when
                ``in_channels`` and ``out_channels`` are different.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
                If the mode ends with ``"linear"``, ``spatial_dims`` is used to determine the correct interpolation.
                This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
                The interpolation mode. Defaults to ``"linear"``.
                See also: https://pytorch.org/docs/stable/nn.html#upsample
            align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
                Only used in the "nontrainable" mode.
            bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.
            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
                size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
                Only used in the "pixelshuffle" mode.

        .. deprecated:: 0.6.0
            ``dimensions`` is deprecated, use ``spatial_dims`` instead.
        """
        super().__init__()
        if dimensions is not None:
            spatial_dims = dimensions
        scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
        up_mode = look_up_option(mode, UpsampleMode)
        if up_mode == UpsampleMode.DECONV:
            if not in_channels:
                raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
            self.add_module(
                "deconv",
                Conv[Conv.CONVTRANS, spatial_dims](
                    in_channels=in_channels,
                    out_channels=out_channels or in_channels,
                    kernel_size=scale_factor_,
                    stride=scale_factor_,
                    bias=bias,
                ),
            )
        elif up_mode == UpsampleMode.NONTRAINABLE:
            if pre_conv == "default" and (
                    out_channels !=
                    in_channels):  # defaults to no conv if out_chns==in_chns
                if not in_channels:
                    raise ValueError(
                        f"in_channels needs to be specified in the '{mode}' mode."
                    )
                self.add_module(
                    "preconv",
                    Conv[Conv.CONV, spatial_dims](in_channels=in_channels,
                                                  out_channels=out_channels
                                                  or in_channels,
                                                  kernel_size=1,
                                                  bias=bias),
                )
            elif pre_conv is not None and pre_conv != "default":
                self.add_module("preconv", pre_conv)  # type: ignore
            elif pre_conv is None and (out_channels != in_channels):
                raise ValueError(
                    "in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels."
                )

            interp_mode = InterpolateMode(interp_mode)
            linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
            if interp_mode in linear_mode:  # choose mode based on dimensions
                interp_mode = linear_mode[spatial_dims - 1]
            self.add_module(
                "upsample_non_trainable",
                nn.Upsample(
                    size=size,
                    scale_factor=None if size else scale_factor_,
                    mode=interp_mode.value,
                    align_corners=align_corners,
                ),
            )
        elif up_mode == UpsampleMode.PIXELSHUFFLE:
            self.add_module(
                "pixelshuffle",
                SubpixelUpsample(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    scale_factor=scale_factor_[0],  # isotropic
                    conv_block=pre_conv,
                    apply_pad_pool=apply_pad_pool,
                    bias=bias,
                ),
            )
        else:
            raise NotImplementedError(f"Unsupported upsampling mode {mode}.")
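
Assuming the block above is MONAI's UpSample (monai.networks.blocks.UpSample exposes the same arguments), a minimal sketch of the two most common modes:

import torch
from monai.networks.blocks import UpSample

x = torch.rand(1, 16, 32, 32)

# trainable transposed convolution, doubling the spatial size and reducing channels
deconv_up = UpSample(spatial_dims=2, in_channels=16, out_channels=8, scale_factor=2, mode="deconv")
print(deconv_up(x).shape)  # expected: torch.Size([1, 8, 64, 64])

# non-trainable interpolation; "linear" is promoted to bilinear for 2D inputs
interp_up = UpSample(
    spatial_dims=2, in_channels=16, out_channels=16, scale_factor=2, mode="nontrainable", interp_mode="linear"
)
print(interp_up(x).shape)  # expected: torch.Size([1, 16, 64, 64])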
Exemple #18
0
    def __init__(
        self,
        keys: KeysCollection,
        pixdim: Sequence[float],
        diagonal: bool = False,
        mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
        padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER,
        align_corners: Union[Sequence[bool], bool] = False,
        dtype: Optional[Union[Sequence[np.dtype], np.dtype]] = np.float64,
        meta_key_postfix: str = "meta_dict",
    ) -> None:
        """
        Args:
            pixdim: output voxel spacing.
            diagonal: whether to resample the input to have a diagonal affine matrix.
                If True, the input data is resampled to the following affine::

                    np.diag((pixdim_0, pixdim_1, pixdim_2, 1))

                This effectively resets the volume to the world coordinate system (RAS+ in nibabel).
                The original orientation, rotation, shearing are not preserved.

                If False, the axes orientation, orthogonal rotation and
                translations components from the original affine will be
                preserved in the target affine. This option will not flip/swap
                axes against the original ones.
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                It can also be a sequence of strings, each element corresponding to a key in ``keys``.
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                It can also be a sequence of strings, each element corresponding to a key in ``keys``.
            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
                It can also be a sequence of bools, each element corresponding to a key in ``keys``.
            dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
                If None, use the data type of input data. To be compatible with other modules,
                the output data type is always ``np.float32``.
                It can also be a sequence of np.dtype, each element corresponding to a key in ``keys``.
            meta_key_postfix: use `key_{postfix}` to fetch the meta data according to the key data,
                default is `meta_dict`; the meta data is a dictionary object.
                For example, to handle key `image`, read/write affine matrices from the
                metadata `image_meta_dict` dictionary's `affine` field.

        Raises:
            TypeError: When ``meta_key_postfix`` is not a ``str``.

        """
        super().__init__(keys)
        self.spacing_transform = Spacing(pixdim, diagonal=diagonal)
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
        self.dtype = ensure_tuple_rep(dtype, len(self.keys))
        if not isinstance(meta_key_postfix, str):
            raise TypeError(
                f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}."
            )
        self.meta_key_postfix = meta_key_postfix
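
Assuming the dictionary transform above is an older-style MONAI Spacingd (which reads the affine from the `image_meta_dict` entry), a minimal usage sketch:

import numpy as np
from monai.transforms import Spacingd

data = {
    "image": np.random.rand(1, 64, 64, 32).astype(np.float32),  # channel-first volume
    "image_meta_dict": {"affine": np.eye(4)},                   # 1 mm isotropic spacing
}
spacing = Spacingd(keys="image", pixdim=(2.0, 2.0, 2.0), mode="bilinear")
out = spacing(data)
print(out["image"].shape)  # roughly halved spatial size, about (1, 32, 32, 16)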
Exemple #19
0
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 2,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: Union[str, tuple] = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: Union[float, tuple] = 0.0,
        upsample: str = "deconv",
        dimensions: Optional[int] = None,
    ):
        """
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        .. deprecated:: 0.6.0
            ``dimensions`` is deprecated, use ``spatial_dims`` instead.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

        """
        super().__init__()
        if dimensions is not None:
            spatial_dims = dimensions

        fea = ensure_tuple_rep(features, 6)
        print(f"BasicUNet features: {fea}.")

        self.conv_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout)
        self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)
        self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)
        self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)
        self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)

        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)

        self.final_conv = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1)
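
Assuming the class is MONAI's BasicUNet, a minimal forward-pass sketch (spatial sizes should be divisible by 16 so the four down-samplings round-trip cleanly):

import torch
from monai.networks.nets import BasicUNet

net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=2)
x = torch.rand(1, 1, 128, 128)
print(net(x).shape)  # expected: torch.Size([1, 2, 128, 128])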
Exemple #20
0
    def __init__(self, keys: KeysCollection, func: Union[Sequence[Callable], Callable]) -> None:
        super().__init__(keys)
        self.func = ensure_tuple_rep(func, len(self.keys))
        self.lambd = Lambda()
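
Assuming this wrapper is MONAI's Lambdad, which applies a user callable to each listed key through the array-based Lambda transform, a minimal usage sketch:

import numpy as np
from monai.transforms import Lambdad

add_channel = Lambdad(keys="image", func=lambda x: x[None])  # prepend a channel axis
out = add_channel({"image": np.zeros((64, 64), dtype=np.float32)})
print(out["image"].shape)  # expected: (1, 64, 64)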
Exemple #21
0
    def __init__(
        self,
        latent_shape: Sequence[int],
        start_shape: Sequence[int],
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
    ) -> None:
        """
        Construct the generator network with the number of layers defined by `channels` and `strides`. In the
        forward pass a `nn.Linear` layer relates the input latent vector to a tensor of dimensions `start_shape`,
        this is then fed forward through the sequence of convolutional layers. The number of layers is defined by
        the length of `channels` and `strides` which must match, each layer having the number of output channels
        given in `channels` and an upsample factor given in `strides` (ie. a transpose convolution with that stride
        size).

        Args:
            latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)
            start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork
            channels: tuple of integers stating the output channels of each convolutional layer
            strides: tuple of integers stating the stride (upscale factor) of each convolutional layer
            kernel_size: integer or tuple of integers stating size of convolutional kernels
            num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
            act: name or type defining activation layers
            norm: name or type defining normalization layers
            dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
            bias: boolean stating if convolution layers should have a bias component
        """
        super().__init__()

        self.in_channels, *self.start_shape = ensure_tuple(start_shape)
        self.dimensions = len(self.start_shape)

        self.latent_shape = ensure_tuple(latent_shape)
        self.channels = ensure_tuple(channels)
        self.strides = ensure_tuple(strides)
        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(int(np.prod(self.latent_shape)), int(np.prod(start_shape)))
        self.reshape = Reshape(*start_shape)
        self.conv = nn.Sequential()

        echannel = self.in_channels

        # transform tensor of shape `start_shape` into output shape through transposed convolutions and residual units
        for i, (c, s) in enumerate(zip(channels, strides)):
            is_last = i == len(channels) - 1
            layer = self._get_layer(echannel, c, s, is_last)
            self.conv.add_module("layer_%i" % i, layer)
            echannel = c
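
Assuming the constructor above is MONAI's Generator (monai.networks.nets.Generator), a minimal usage sketch mapping a latent vector to a single-channel 64x64 image:

import torch
from monai.networks.nets import Generator

net = Generator(
    latent_shape=(64,),      # 64-dimensional latent vector per sample
    start_shape=(8, 8, 8),   # (channels, H, W) tensor fed to the conv subnetwork
    channels=(16, 8, 1),     # output channels of the three transposed-conv layers
    strides=(2, 2, 2),       # three 2x upsamplings: 8 -> 64 spatial size
)
z = torch.rand(4, 64)
print(net(z).shape)  # expected: torch.Size([4, 1, 64, 64])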