Beispiel #1
0
    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 8,
        in_channels: int = 1,
        out_channels: int = 2,
        dropout_prob: Optional[float] = None,
        norm_name: str = "group",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
    ):
        """
        Validate the configuration, store it on the instance and build the layers.

        Args:
            spatial_dims: number of spatial dimensions; must be 2 or 3. Defaults to 3.
            init_filters: number of output channels of the initial convolution
                (``self.convInit``). Defaults to 8.
            in_channels: number of input channels of the initial convolution. Defaults to 1.
            out_channels: forwarded to ``self._make_final_conv``. Defaults to 2.
            dropout_prob: when not None, a dropout layer with this probability is created
                as ``self.dropout``; when None, ``self.dropout`` is not set at all.
            norm_name: stored as ``self.norm_name``; presumably consumed by the layer
                builders — confirm. Defaults to "group".
            num_groups: stored as ``self.num_groups``; presumably the group count for
                group normalization — confirm. Defaults to 8.
            use_conv_final: stored as ``self.use_conv_final``. Defaults to True.
            blocks_down: stored as ``self.blocks_down``. Defaults to (1, 2, 2, 4).
            blocks_up: stored as ``self.blocks_up``. Defaults to (1, 1, 1).
            upsample_mode: converted via ``UpsampleMode(...)`` and stored as
                ``self.upsample_mode``. Defaults to ``UpsampleMode.NONTRAINABLE``.

        Raises:
            AssertionError: when ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()

        # Explicit raise instead of ``assert``: assert statements are removed when Python
        # runs with ``-O``, which would silently skip this validation. The exception type
        # stays AssertionError so callers that catch it keep working.
        if spatial_dims not in (2, 3):
            raise AssertionError("spatial_dims can only be 2 or 3.")

        # Configuration is stored before the ``_make_*`` helpers run: they receive no
        # arguments for these settings, so they presumably read them from ``self``.
        self.spatial_dims = spatial_dims
        self.init_filters = init_filters
        self.blocks_down = blocks_down
        self.blocks_up = blocks_up
        self.dropout_prob = dropout_prob
        self.norm_name = norm_name
        self.num_groups = num_groups
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.down_layers = self._make_down_layers()
        self.up_layers, self.up_samples = self._make_up_layers()
        self.relu = Act[Act.RELU](inplace=True)
        self.conv_final = self._make_final_conv(out_channels)

        if dropout_prob is not None:
            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
Beispiel #2
0
 def __init__(
     self,
     spatial_dims: int,
     in_channels: int,
     out_channels: Optional[int] = None,
     scale_factor: Union[Sequence[float], float] = 2,
     with_conv: bool = False,
     mode: Union[UpsampleMode, str] = UpsampleMode.LINEAR,
     align_corners: Optional[bool] = True,
 ) -> None:
     """
     Args:
         spatial_dims: number of spatial dimensions of the input image.
         in_channels: number of channels of the input image.
         out_channels: number of channels of the output image. Defaults to `in_channels`.
         scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
         with_conv: whether to use a transposed convolution for upsampling. Defaults to False.
             When True, a transposed convolution with ``kernel_size == stride == scale_factor``
             is used. When False, a 1x1 convolution (``in_channels`` -> ``out_channels``) is
             applied first, even if the channel counts are equal, followed by ``nn.Upsample``.
         mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
             If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
             This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
             The interpolation mode. Defaults to ``"linear"``.
             Only used when ``with_conv`` is False.
             See also: https://pytorch.org/docs/stable/nn.html#upsample
         align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
             It is only forwarded for the interpolating modes (``"linear"``, ``"bilinear"``,
             ``"bicubic"``, ``"trilinear"``); for other modes such as ``"nearest"`` it is
             replaced by None, because PyTorch rejects ``align_corners`` for those modes.
     """
     super().__init__()
     scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
     if not out_channels:
         out_channels = in_channels
     if not with_conv:
         mode = UpsampleMode(mode)
         linear_mode = [
             UpsampleMode.LINEAR, UpsampleMode.BILINEAR,
             UpsampleMode.TRILINEAR
         ]
         if mode in linear_mode:  # choose mode based on spatial_dims
             mode = linear_mode[spatial_dims - 1]
         # torch.nn.functional.interpolate raises ValueError at forward time when
         # align_corners is not None for non-interpolating modes (e.g. "nearest"),
         # so only pass it through for the modes that accept it.
         if mode.value not in ("linear", "bilinear", "bicubic", "trilinear"):
             align_corners = None
         self.upsample = nn.Sequential(
             Conv[Conv.CONV, spatial_dims](in_channels=in_channels,
                                           out_channels=out_channels,
                                           kernel_size=1),
             nn.Upsample(scale_factor=scale_factor_,
                         mode=mode.value,
                         align_corners=align_corners),
         )
     else:
         self.upsample = Conv[Conv.CONVTRANS,
                              spatial_dims](in_channels=in_channels,
                                            out_channels=out_channels,
                                            kernel_size=scale_factor_,
                                            stride=scale_factor_)
Beispiel #3
0
    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 8,
        in_channels: int = 1,
        out_channels: int = 2,
        dropout_prob: Optional[float] = None,
        act: Union[Tuple, str] = ("RELU", {
            "inplace": True
        }),
        norm: Union[Tuple, str] = ("GROUP", {
            "num_groups": 8
        }),
        norm_name: str = "",
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
    ):
        """
        Check the configuration, record it on the instance, then build the layers.

        Args:
            spatial_dims: number of spatial dimensions; only 2 and 3 are accepted.
            init_filters: output channels of the initial convolution (``self.convInit``).
            in_channels: input channels of the initial convolution.
            out_channels: forwarded to ``self._make_final_conv``.
            dropout_prob: if not None, ``self.dropout`` is created with this probability.
            act: activation option; kept verbatim as ``self.act`` and instantiated
                through ``get_act_layer`` as ``self.act_mod``.
            norm: normalization option stored as ``self.norm`` (unless overridden by
                ``norm_name``).
            norm_name: legacy option. When non-empty it must equal "group"
                (case-insensitive), and ``self.norm`` becomes
                ``("group", {"num_groups": num_groups})``.
            num_groups: group count used only by the legacy ``norm_name`` path.
            use_conv_final: stored as ``self.use_conv_final``.
            blocks_down: stored as ``self.blocks_down``.
            blocks_up: stored as ``self.blocks_up``.
            upsample_mode: converted via ``UpsampleMode(...)`` into ``self.upsample_mode``.

        Raises:
            ValueError: if ``spatial_dims`` is not 2 or 3, or if ``norm_name`` is
                non-empty and not "group".
        """
        super().__init__()

        allowed_dims = (2, 3)
        if spatial_dims not in allowed_dims:
            raise ValueError("`spatial_dims` can only be 2 or 3.")

        # Plain settings; the ``_make_*`` helpers below take no such arguments and
        # presumably read them from ``self``.
        self.spatial_dims, self.init_filters, self.in_channels = spatial_dims, init_filters, in_channels
        self.blocks_down, self.blocks_up = blocks_down, blocks_up
        self.dropout_prob = dropout_prob

        # Keep the user-facing option next to the instantiated activation module.
        self.act = act
        self.act_mod = get_act_layer(act)

        # Legacy ``norm_name`` support: only the group-norm spelling is still honoured.
        legacy_norm = bool(norm_name)
        if legacy_norm and norm_name.lower() != "group":
            raise ValueError(
                f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead."
            )
        self.norm = ("group", {"num_groups": num_groups}) if legacy_norm else norm

        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final

        # Submodules are registered in this order (affects module/state_dict ordering).
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.down_layers = self._make_down_layers()
        self.up_layers, self.up_samples = self._make_up_layers()
        self.conv_final = self._make_final_conv(out_channels)

        if dropout_prob is None:
            return
        self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
Beispiel #4
0
    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 8,
        in_channels: int = 1,
        out_channels: int = 2,
        dropout_prob: Optional[float] = None,
        act: Union[Tuple, str] = ("RELU", {
            "inplace": True
        }),
        norm: Union[Tuple, str] = ("GROUP", {
            "num_groups": 8
        }),
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
    ):
        """
        Validate the arguments, save them on the instance and construct all layers.

        Args:
            spatial_dims: number of spatial dimensions; must be 2 or 3.
            init_filters: output channels of the initial convolution (``self.convInit``).
            in_channels: input channels of the initial convolution.
            out_channels: forwarded to ``self._make_final_conv``.
            dropout_prob: if not None, ``self.dropout`` is created with this probability;
                otherwise ``self.dropout`` is never assigned.
            act: activation option, instantiated with ``get_act_layer`` into ``self.act``.
            norm: normalization option stored as ``self.norm``; presumably consumed by
                the layer builders — confirm.
            use_conv_final: stored as ``self.use_conv_final``.
            blocks_down: stored as ``self.blocks_down``.
            blocks_up: stored as ``self.blocks_up``.
            upsample_mode: converted via ``UpsampleMode(...)`` into ``self.upsample_mode``.

        Raises:
            AssertionError: if ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()

        if spatial_dims != 2 and spatial_dims != 3:
            raise AssertionError("spatial_dims can only be 2 or 3.")

        # Settings read later by the ``_make_*`` helpers (they take no such arguments).
        self.spatial_dims, self.init_filters, self.in_channels = spatial_dims, init_filters, in_channels
        self.blocks_down, self.blocks_up = blocks_down, blocks_up
        self.dropout_prob = dropout_prob
        activation = get_act_layer(act)
        self.act = activation
        self.norm = norm
        self.upsample_mode = UpsampleMode(upsample_mode)
        self.use_conv_final = use_conv_final

        # Layer construction; submodule registration order is kept stable.
        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
        self.down_layers = self._make_down_layers()
        self.up_layers, self.up_samples = self._make_up_layers()
        self.conv_final = self._make_final_conv(out_channels)

        if dropout_prob is None:
            return
        self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
Beispiel #5
0
    def __init__(
        self,
        dimensions: int,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        scale_factor: Union[Sequence[float], float] = 2,
        mode: Union[UpsampleMode, str] = UpsampleMode.DECONV,
        pre_conv: Optional[Union[nn.Module, str]] = "default",
        interp_mode: Union[InterpolateMode, str] = InterpolateMode.LINEAR,
        align_corners: Optional[bool] = True,
        bias: bool = True,
        apply_pad_pool: bool = True,
    ) -> None:
        """
        Args:
            dimensions: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: number of channels of the output image. Defaults to `in_channels`.
            scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
            mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
            pre_conv: a conv block applied before upsampling. Defaults to ``"default"``.
                Only used in the "nontrainable" or "pixelshuffle" mode.
                In the "nontrainable" mode: ``"default"`` adds a 1x1 convolution (registered as
                ``"preconv"``) only when ``out_channels != in_channels``; an ``nn.Module`` is
                registered as ``"preconv"`` as-is; None means no pre-conv.
                In the "pixelshuffle" mode the value is forwarded to
                :py:class:`monai.networks.blocks.SubpixelUpsample` as ``conv_block``.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
                If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
                This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
                The interpolation mode. Defaults to ``"linear"``.
                See also: https://pytorch.org/docs/stable/nn.html#upsample
            align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
                Only used in the nontrainable mode, and only forwarded for the interpolating
                modes (``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``); for other
                modes such as ``"nearest"`` it is replaced by None because PyTorch rejects it.
            bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.
            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
                size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
                Only used in the pixelshuffle mode.

        Raises:
            ValueError: if ``in_channels`` is missing where a default conv layer must be built.
            NotImplementedError: if ``mode`` is not one of the handled upsampling modes.
        """
        super().__init__()
        scale_factor_ = ensure_tuple_rep(scale_factor, dimensions)
        up_mode = UpsampleMode(mode)
        if up_mode == UpsampleMode.DECONV:
            if not in_channels:
                raise ValueError(
                    f"in_channels needs to be specified in the '{mode}' mode.")
            self.add_module(
                "deconv",
                Conv[Conv.CONVTRANS, dimensions](
                    in_channels=in_channels,
                    out_channels=out_channels or in_channels,
                    kernel_size=scale_factor_,
                    stride=scale_factor_,
                    bias=bias,
                ),
            )
        elif up_mode == UpsampleMode.NONTRAINABLE:
            # NOTE(review): when out_channels is None but in_channels is set, this condition
            # is still True (None != in_channels), so an in->in 1x1 conv is created even though
            # out_channels "defaults to in_channels". Left unchanged so existing checkpoints
            # containing "preconv" weights keep loading — confirm intended behavior.
            if pre_conv == "default" and (
                    out_channels !=
                    in_channels):  # defaults to no conv if out_chns==in_chns
                if not in_channels:
                    raise ValueError(
                        f"in_channels needs to be specified in the '{mode}' mode."
                    )
                self.add_module(
                    "preconv",
                    Conv[Conv.CONV, dimensions](in_channels=in_channels,
                                                out_channels=out_channels
                                                or in_channels,
                                                kernel_size=1,
                                                bias=bias),
                )
            elif pre_conv is not None and pre_conv != "default":
                self.add_module("preconv", pre_conv)  # type: ignore

            interp_mode = InterpolateMode(interp_mode)
            linear_mode = [
                InterpolateMode.LINEAR, InterpolateMode.BILINEAR,
                InterpolateMode.TRILINEAR
            ]
            if interp_mode in linear_mode:  # choose mode based on dimensions
                interp_mode = linear_mode[dimensions - 1]
            # torch.nn.functional.interpolate raises ValueError at forward time when
            # align_corners is not None for non-interpolating modes (e.g. "nearest",
            # "area"), so only pass it through for the modes that accept it.
            if interp_mode.value not in ("linear", "bilinear", "bicubic", "trilinear"):
                align_corners = None
            self.add_module(
                "upsample_non_trainable",
                nn.Upsample(scale_factor=scale_factor_,
                            mode=interp_mode.value,
                            align_corners=align_corners),
            )
        elif up_mode == UpsampleMode.PIXELSHUFFLE:
            self.add_module(
                "pixelshuffle",
                SubpixelUpsample(
                    dimensions=dimensions,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    scale_factor=scale_factor_[0],  # isotropic
                    conv_block=pre_conv,
                    apply_pad_pool=apply_pad_pool,
                    bias=bias,
                ),
            )
        else:
            raise NotImplementedError(f"Unsupported upsampling mode {mode}.")