Example 1
    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        strides=1,
        kernel_size=3,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=None,
        dilation=1,
        bias: bool = True,
        conv_only: bool = False,
        is_transposed: bool = False,
    ) -> None:
        super().__init__()
        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_transposed = is_transposed

        padding = same_padding(kernel_size, dilation)
        conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]

        # define the normalisation type and the arguments to the constructor
        norm_name, norm_args = split_args(norm)
        norm_type = Norm[norm_name, dimensions]

        # define the activation type and the arguments to the constructor
        if act is not None:
            act_name, act_args = split_args(act)
            act_type = Act[act_name]
        else:
            act_type = act_args = None

        if dropout:
            # if dropout was specified simply as a p value, use default name and make a keyword map with the value
            if isinstance(dropout, (int, float)):
                drop_name = Dropout.DROPOUT
                drop_args = {"p": dropout}
            else:
                drop_name, drop_args = split_args(dropout)

            drop_type = Dropout[drop_name, dimensions]

        if is_transposed:
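            # positional args correspond to kernel_size, stride, padding, output_padding, groups, bias, dilation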
            conv = conv_type(in_channels, out_channels, kernel_size, strides, padding, strides - 1, 1, bias, dilation)
        else:
            conv = conv_type(in_channels, out_channels, kernel_size, strides, padding, dilation, bias=bias)

        self.add_module("conv", conv)

        if not conv_only:
            self.add_module("norm", norm_type(out_channels, **norm_args))
            if dropout:
                self.add_module("dropout", drop_type(**drop_args))
            if act is not None:
                self.add_module("act", act_type(**act_args))
Example 2
    def __init__(
        self,
        ordering: str = "NDA",
        in_channels: Optional[int] = None,
        act: Optional[Union[Tuple, str]] = "RELU",
        norm: Optional[Union[Tuple, str]] = None,
        norm_dim: Optional[int] = None,
        dropout: Optional[Union[Tuple, str, float]] = None,
        dropout_dim: Optional[int] = None,
    ) -> None:
        super().__init__()

        op_dict = {"A": None, "D": None, "N": None}
        # define the normalization type and the arguments to the constructor
        if norm is not None:
            if norm_dim is None and dropout_dim is None:
                raise ValueError(
                    "norm_dim or dropout_dim needs to be specified.")
            norm_name, norm_args = split_args(norm)
            norm_type = Norm[norm_name, norm_dim or dropout_dim]
            kw_args = dict(norm_args)
            if has_option(norm_type, "num_features") and "num_features" not in kw_args:
                kw_args["num_features"] = in_channels
            if has_option(norm_type, "num_channels") and "num_channels" not in kw_args:
                kw_args["num_channels"] = in_channels
            op_dict["N"] = norm_type(**kw_args)

        # define the activation type and the arguments to the constructor
        if act is not None:
            act_name, act_args = split_args(act)
            act_type = Act[act_name]
            op_dict["A"] = act_type(**act_args)

        if dropout is not None:
            # if dropout was specified simply as a p value, use default name and make a keyword map with the value
            if isinstance(dropout, (int, float)):
                drop_name = Dropout.DROPOUT
                drop_args = {"p": float(dropout)}
            else:
                drop_name, drop_args = split_args(dropout)

            if norm_dim is None and dropout_dim is None:
                raise ValueError(
                    "norm_dim or dropout_dim needs to be specified.")
            drop_type = Dropout[drop_name, dropout_dim or norm_dim]
            op_dict["D"] = drop_type(**drop_args)

        for item in ordering.upper():
            if item not in op_dict:
                raise ValueError(
                    f"ordering must be a string of {list(op_dict)}, got {item} in it."
                )
            if op_dict[item] is not None:
                self.add_module(item, op_dict[item])  # type: ignore
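
A hedged usage sketch, assuming this constructor belongs to MONAI's ADN block; the ordering string picks which of norm (N), dropout (D) and activation (A) are applied, and in what order:

from monai.networks.blocks import ADN

# N -> instance norm over 8 channels, D -> dropout p=0.1, A -> relu
adn = ADN(ordering="NDA", in_channels=8, act="relu", norm="instance",
          norm_dim=2, dropout=0.1, dropout_dim=2)
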
Example 3
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        r: int = 2,
        acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}),
        acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid",
        add_residual: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            add_residual: whether to add the input as a residual to the recalibrated output. Defaults to False.

        Raises:
            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.

        See also:

            :py:class:`monai.networks.layers.Act`

        """
        super(ChannelSELayer, self).__init__()

        self.add_residual = add_residual

        pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
        self.avg_pool = pool_type(1)  # spatial size (1, 1, ...)

        channels = int(in_channels // r)
        if channels <= 0:
            raise ValueError(
                f"r must be positive and smaller than in_channels, got r={r} in_channels={in_channels}."
            )

        act_1, act_1_args = split_args(acti_type_1)
        act_2, act_2_args = split_args(acti_type_2)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, channels, bias=True),
            Act[act_1](**act_1_args),
            nn.Linear(channels, in_channels, bias=True),
            Act[act_2](**act_2_args),
        )
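
A quick usage sketch, assuming this is MONAI's ChannelSELayer; the output keeps the input shape, with channels rescaled by the learned squeeze-and-excitation weights:

import torch
from monai.networks.blocks import ChannelSELayer

se = ChannelSELayer(spatial_dims=2, in_channels=4, r=2)
y = se(torch.randn(1, 4, 16, 16))  # shape (1, 4, 16, 16)
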
Example 4
def get_norm_layer(name: Union[Tuple, str],
                   spatial_dims: Optional[int] = 1,
                   channels: Optional[int] = 1):
    """
    Create a normalization layer instance.

    For example, to create normalization layers:

    .. code-block:: python

        from monai.networks.layers import get_norm_layer

        g_layer = get_norm_layer(name=("group", {"num_groups": 1}))
        n_layer = get_norm_layer(name="instance", spatial_dims=2)

    Args:
        name: a normalization type string or a tuple of type string and parameters.
        spatial_dims: number of spatial dimensions of the input.
        channels: number of features/channels when the normalization layer requires this parameter
            but it is not specified in the norm parameters.
    """
    norm_name, norm_args = split_args(name)
    norm_type = Norm[norm_name, spatial_dims]
    kw_args = dict(norm_args)
    if has_option(norm_type, "num_features") and "num_features" not in kw_args:
        kw_args["num_features"] = channels
    if has_option(norm_type, "num_channels") and "num_channels" not in kw_args:
        kw_args["num_channels"] = channels
    return norm_type(**kw_args)
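
For norm types that require a channel count (e.g. batch norm's num_features), the channels argument fills it in when the norm parameters do not, for example:

b_layer = get_norm_layer(name="batch", spatial_dims=3, channels=16)  # BatchNorm3d(16)
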
Example 5
def get_dropout_layer(name: Union[Tuple, str, float, int],
                      dropout_dim: Optional[int] = 1):
    """
    Create a dropout layer instance.

    For example, to create dropout layers:

    .. code-block:: python

        from monai.networks.layers import get_dropout_layer

        d_layer = get_dropout_layer(name="dropout")
        a_layer = get_dropout_layer(name=("alphadropout", {"p": 0.25}))

    Args:
        name: a dropout ratio or a tuple of dropout type and parameters.
        dropout_dim: the spatial dimension of the dropout operation.
    """
    if isinstance(name, (int, float)):
        # if dropout was specified simply as a p value, use default name and make a keyword map with the value
        drop_name = Dropout.DROPOUT
        drop_args = {"p": float(name)}
    else:
        drop_name, drop_args = split_args(name)
    drop_type = Dropout[drop_name, dropout_dim]
    return drop_type(**drop_args)
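
As the first branch shows, a bare number is taken as the dropout probability, for example:

p_layer = get_dropout_layer(name=0.5)  # same as ("dropout", {"p": 0.5}) -> nn.Dropout(p=0.5)
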
Example 6
def build_net():

    import monai  # imported at module level in the original project
    from init import Options  # project-specific options parser
    opt = Options().parse()
    from monai.networks.layers import Norm
    from monai.networks.layers.factories import split_args
    act_type, args = split_args("RELU")

    # # create Unet
    # Unet = monai.networks.nets.UNet(
    #     dimensions=3,
    #     in_channels=opt.in_channels,
    #     out_channels=opt.out_channels,
    #     channels=(64, 128, 256, 512, 1024),
    #     strides=(2, 2, 2, 2),
    #     act=act_type,
    #     num_res_units=3,
    #     dropout=0.2,
    #     norm=Norm.BATCH,
    #
    # )

    # create nn-Unet
    if opt.resolution is None:
        sizes, spacings = opt.patch_size, opt.spacing
    else:
        sizes, spacings = opt.patch_size, opt.resolution

    strides, kernels = [], []

    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        res_block=True,
    )

    init_weights(nn_Unet, init_type='normal')  # init_weights: project-specific weight initialisation helper

    return nn_Unet
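
For intuition, the stride/kernel schedule loop above can be traced standalone. A sketch with hypothetical inputs, a 64^3 patch at isotropic 1 mm spacing, where every level downsamples by 2 until the size check stops it:

sizes, spacings = [64, 64, 64], [1.0, 1.0, 1.0]
strides, kernels = [], []
while True:
    ratio = [sp / min(spacings) for sp in spacings]
    stride = [2 if r <= 2 and s >= 8 else 1 for r, s in zip(ratio, sizes)]
    kernel = [3 if r <= 2 else 1 for r in ratio]
    if all(s == 1 for s in stride):
        break
    sizes = [i / j for i, j in zip(sizes, stride)]
    spacings = [i * j for i, j in zip(spacings, stride)]
    kernels.append(kernel)
    strides.append(stride)
strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])
print(strides)  # [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
print(kernels)  # [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
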
Example 7
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        r: int = 2,
        acti_type_1=("relu", {
            "inplace": True
        }),
        acti_type_2="sigmoid",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".

        See also:

            :py:class:`monai.networks.layers.Act`

        Raises:
            ValueError: r must be a positive number smaller than `in_channels`.

        """
        super(ChannelSELayer, self).__init__()

        pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
        self.avg_pool = pool_type(1)  # spatial size (1, 1, ...)

        channels = int(in_channels // r)
        if channels <= 0:
            raise ValueError(
                "r must be a positive number smaller than `in_channels`.")

        act_1, act_1_args = split_args(acti_type_1)
        act_2, act_2_args = split_args(acti_type_2)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, channels, bias=True),
            Act[act_1](**act_1_args),
            nn.Linear(channels, in_channels, bias=True),
            Act[act_2](**act_2_args),
        )
Example 8
def get_act_layer(name: Union[Tuple, str]):
    """
    Create an activation layer instance.

    For example, to create activation layers:

    .. code-block:: python

        from monai.networks.layers import get_act_layer

        s_layer = get_act_layer(name="swish")
        p_layer = get_act_layer(name=("prelu", {"num_parameters": 1, "init": 0.25}))

    Args:
        name: an activation type string or a tuple of type string and parameters.
    """
    act_name, act_args = split_args(name)
    act_type = Act[act_name]
    return act_type(**act_args)
Example 9
def get_pool_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1):
    """
    Create a pooling layer instance.

    For example, to create adaptiveavg layer:

    .. code-block:: python

        from monai.networks.layers import get_pool_layer

        pool_layer = get_pool_layer(("adaptiveavg", {"output_size": (1, 1, 1)}), spatial_dims=3)

    Args:
        name: a pooling type string or a tuple of type string and parameters.
        spatial_dims: number of spatial dimensions of the input.

    """
    pool_name, pool_args = split_args(name)
    pool_type = Pool[pool_name, spatial_dims]
    return pool_type(**pool_args)
Example 10
    def __init__(
        self,
        in_shape: Sequence[int],
        classes: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
        last_act: Optional[str] = None,
    ) -> None:
        super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias)

        if last_act is not None:
            last_act_name, last_act_args = split_args(last_act)
            last_act_type = Act[last_act_name]

            self.final.add_module("lastact", last_act_type(**last_act_args))
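
A hedged usage sketch, assuming this constructor belongs to MONAI's Classifier network; last_act goes through split_args like the other factory arguments:

from monai.networks.nets import Classifier

net = Classifier(in_shape=(1, 64, 64), classes=3, channels=(4, 8),
                 strides=(2, 2), last_act="sigmoid")
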
Example 11
    def __init__(
        self,
        in_shape,
        classes,
        channels,
        strides,
        kernel_size=3,
        num_res_units=2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=None,
        bias=True,
        last_act=None,
    ):
        super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias)

        if last_act is not None:
            last_act_name, last_act_args = split_args(last_act)
            last_act_type = Act[last_act_name]

            self.final.add_module("lastact", last_act_type(**last_act_args))
Example 12
    def __init__(
        self,
        in_shape: Sequence[int],
        classes: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout: Optional[float] = None,
        bias: bool = True,
        last_act: Optional[str] = None,
    ) -> None:
        """
        Args:
            in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)
            classes: integer stating the dimension of the final output tensor
            channels: tuple of integers stating the output channels of each convolutional layer
            strides: tuple of integers stating the stride (downscale factor) of each convolutional layer
            kernel_size: integer or tuple of integers stating size of convolutional kernels
            num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
            act: name or type defining activation layers
            norm: name or type defining normalization layers
            dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
            bias: boolean stating if convolution layers should have a bias component
            last_act: name defining the last activation layer
        """
        super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias)

        if last_act is not None:
            last_act_name, last_act_args = split_args(last_act)
            last_act_type = Act[last_act_name]

            self.final.add_module("lastact", last_act_type(**last_act_args))
Example 13
def get_acti_layer(act: Union[Tuple[str, Dict], str]):
    act_name, act_args = split_args(act)
    act_type = Act[act_name]
    return act_type(**act_args)
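
Usage mirrors the other factory helpers, for example:

layer = get_acti_layer(("leakyrelu", {"negative_slope": 0.1}))  # nn.LeakyReLU(negative_slope=0.1)
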
Example 14
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        n_chns_1: int,
        n_chns_2: int,
        n_chns_3: int,
        conv_param_1: Optional[Dict] = None,
        conv_param_2: Optional[Dict] = None,
        conv_param_3: Optional[Dict] = None,
        project: Optional[Convolution] = None,
        r: int = 2,
        acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}),
        acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid",
        acti_type_final: Optional[Union[Tuple[str, Dict], str]] = ("relu", {"inplace": True}),
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            n_chns_1: number of output channels in the 1st convolution.
            n_chns_2: number of output channels in the 2nd convolution.
            n_chns_3: number of output channels in the 3rd convolution.
            conv_param_1: additional parameters to the 1st convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_2: additional parameters to the 2nd convolution.
                Defaults to ``{"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_3: additional parameters to the 3rd convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": None}``
            project: if the residual channel count and the output channel count do not
                match, a projection (Conv) layer/block is used to adjust the number of
                channels. In SENet, it consists of a Conv layer followed by a Norm layer.
                Defaults to None (channels already match) or a Conv layer with kernel size 1.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            acti_type_final: activation type of the end of the block. Defaults to ``("relu", {"inplace": True})``.

        See also:

            :py:class:`monai.networks.blocks.ChannelSELayer`

        """
        super(SEBlock, self).__init__()

        if not conv_param_1:
            conv_param_1 = {"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}
        self.conv1 = Convolution(dimensions=spatial_dims,
                                 in_channels=in_channels,
                                 out_channels=n_chns_1,
                                 **conv_param_1)

        if not conv_param_2:
            conv_param_2 = {"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}
        self.conv2 = Convolution(dimensions=spatial_dims,
                                 in_channels=n_chns_1,
                                 out_channels=n_chns_2,
                                 **conv_param_2)

        if not conv_param_3:
            conv_param_3 = {"kernel_size": 1, "norm": Norm.BATCH, "act": None}
        self.conv3 = Convolution(dimensions=spatial_dims,
                                 in_channels=n_chns_2,
                                 out_channels=n_chns_3,
                                 **conv_param_3)

        self.se_layer = ChannelSELayer(spatial_dims=spatial_dims,
                                       in_channels=n_chns_3,
                                       r=r,
                                       acti_type_1=acti_type_1,
                                       acti_type_2=acti_type_2)

        self.project = project
        if self.project is None and in_channels != n_chns_3:
            self.project = Conv[Conv.CONV, spatial_dims](in_channels,
                                                         n_chns_3,
                                                         kernel_size=1)

        self.act = None
        if acti_type_final is not None:
            act_final, act_final_args = split_args(acti_type_final)
            self.act = Act[act_final](**act_final_args)
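
A hedged usage sketch, assuming this is MONAI's SEBlock; because in_channels differs from n_chns_3 here, the 1x1 projection branch is created automatically:

import torch
from monai.networks.blocks import SEBlock

block = SEBlock(spatial_dims=2, in_channels=4, n_chns_1=8, n_chns_2=8, n_chns_3=16)
y = block(torch.randn(1, 4, 32, 32))  # shape (1, 16, 32, 32)
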
Example 15
    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        strides: int = 1,
        kernel_size: Union[Sequence[int], int] = 3,
        act: Optional[Union[Tuple, str]] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout: Optional[Union[Tuple, str, float]] = None,
        dropout_dim: int = 1,
        dilation: Union[Sequence[int], int] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_only: bool = False,
        is_transposed: bool = False,
    ) -> None:
        super().__init__()
        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_transposed = is_transposed

        padding = same_padding(kernel_size, dilation)
        conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]

        # define the normalisation type and the arguments to the constructor
        if norm is not None:
            norm_name, norm_args = split_args(norm)
            norm_type = Norm[norm_name, dimensions]
        else:
            norm_type = norm_args = None

        # define the activation type and the arguments to the constructor
        if act is not None:
            act_name, act_args = split_args(act)
            act_type = Act[act_name]
        else:
            act_type = act_args = None

        if dropout:
            # if dropout was specified simply as a p value, use default name and make a keyword map with the value
            if isinstance(dropout, (int, float)):
                drop_name = Dropout.DROPOUT
                drop_args = {"p": dropout}
            else:
                drop_name, drop_args = split_args(dropout)

            if dropout_dim > dimensions:
                raise ValueError(
                    f"dropout_dim should be no larger than dimensions, got dropout_dim={dropout_dim} and dimensions={dimensions}."
                )
            drop_type = Dropout[drop_name, dropout_dim]

        if is_transposed:
            conv = conv_type(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
                output_padding=strides - 1,
                groups=groups,
                bias=bias,
                dilation=dilation,
            )
        else:
            conv = conv_type(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )

        self.add_module("conv", conv)

        if not conv_only:
            if norm is not None:
                self.add_module("norm", norm_type(out_channels, **norm_args))

            if dropout:
                self.add_module("dropout", drop_type(**drop_args))

            if act is not None:
                self.add_module("act", act_type(**act_args))
Example 16
File: vnet.py Project: lsho76/MONAI
def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0):
    if act == "prelu":
        act = ("prelu", {"num_parameters": nchan})
    act_name, act_args = split_args(act)
    act_type = Act[act_name]
    return act_type(**act_args)
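
The "prelu" special case above gives the activation one learnable parameter per channel, for example:

layer = get_acti_layer("prelu", nchan=16)  # nn.PReLU(num_parameters=16)
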
def build_net():

    import monai  # imported at module level in the original project
    import torch
    from torch.nn import (AvgPool1d, BatchNorm3d, Conv3d, ConvTranspose3d,
                          Dropout3d, MaxPool3d, Module, ReLU, Sequential)

    from init import Options  # project-specific options parser
    opt = Options().parse()
    from monai.networks.layers import Norm
    from monai.networks.layers.factories import split_args
    act_type, args = split_args("RELU")

    # create Unet
    Unet = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        act=act_type,
        num_res_units=3,
        dropout=0.2,
        norm=Norm.BATCH,

    )

    class UNet_David(Module):
        # __                            __
        #  1|__   ________________   __|1
        #     2|__  ____________  __|2
        #        3|__  ______  __|3
        #           4|__ __ __|4
        # The convolutions on either side are residual, with a 1*1 convolution used for channel matching

        def __init__(self, feat_channels=[32, 64, 128, 256, 512], residual='conv'):
            # residual: 'conv' passes the input x through a 1*1 conv at every layer to match channels for the skip; None disables the residuals

            super(UNet_David, self).__init__()

            class Conv3D_Block(Module):

                def __init__(self, inp_feat, out_feat, kernel=3, stride=1, padding=1, residual=None):

                    super(Conv3D_Block, self).__init__()

                    self.conv1 = Sequential(
                        Conv3d(inp_feat, out_feat, kernel_size=kernel,
                               stride=stride, padding=padding, bias=True),
                        BatchNorm3d(out_feat),
                        ReLU())

                    self.conv2 = Sequential(
                        Conv3d(out_feat, out_feat, kernel_size=kernel,
                               stride=stride, padding=padding, bias=True),
                        BatchNorm3d(out_feat),
                        ReLU())

                    self.residual = residual

                    if self.residual is not None:
                        self.residual_upsampler = Conv3d(inp_feat, out_feat, kernel_size=1, bias=False)

                def forward(self, x):

                    res = x

                    if not self.residual:
                        return self.conv2(self.conv1(x))
                    else:
                        return self.conv2(self.conv1(x)) + self.residual_upsampler(res)

            class Deconv3D_Block(Module):

                def __init__(self, inp_feat, out_feat, kernel=3, stride=2, padding=1):
                    super(Deconv3D_Block, self).__init__()

                    self.deconv = Sequential(
                        ConvTranspose3d(inp_feat, out_feat, kernel_size=(kernel, kernel, kernel),
                                        stride=(stride, stride, stride), padding=(padding, padding, padding),
                                        output_padding=1, bias=True),
                        ReLU())

                def forward(self, x):
                    return self.deconv(x)

            class ChannelPool3d(AvgPool1d):

                def __init__(self, kernel_size, stride, padding):
                    super(ChannelPool3d, self).__init__(kernel_size, stride, padding)
                    self.pool_1d = AvgPool1d(self.kernel_size, self.stride, self.padding, self.ceil_mode)

                def forward(self, inp):
                    n, c, d, w, h = inp.size()
                    # flatten the spatial dims and pool across the channel axis
                    inp = inp.view(n, c, d * w * h).permute(0, 2, 1)
                    pooled = self.pool_1d(inp)
                    c = int(c / self.kernel_size[0])
                    # reshape the pooled tensor (not the input) back to (N, C', D, W, H)
                    return pooled.permute(0, 2, 1).reshape(n, c, d, w, h)

            # Encoder downsamplers
            self.pool1 = MaxPool3d((2, 2, 2))
            self.pool2 = MaxPool3d((2, 2, 2))
            self.pool3 = MaxPool3d((2, 2, 2))
            self.pool4 = MaxPool3d((2, 2, 2))

            # Encoder convolutions
            self.conv_blk1 = Conv3D_Block(opt.in_channels, feat_channels[0], residual=residual)
            self.conv_blk2 = Conv3D_Block(feat_channels[0], feat_channels[1], residual=residual)
            self.conv_blk3 = Conv3D_Block(feat_channels[1], feat_channels[2], residual=residual)
            self.conv_blk4 = Conv3D_Block(feat_channels[2], feat_channels[3], residual=residual)
            self.conv_blk5 = Conv3D_Block(feat_channels[3], feat_channels[4], residual=residual)

            # Decoder convolutions
            self.dec_conv_blk4 = Conv3D_Block(2 * feat_channels[3], feat_channels[3], residual=residual)
            self.dec_conv_blk3 = Conv3D_Block(2 * feat_channels[2], feat_channels[2], residual=residual)
            self.dec_conv_blk2 = Conv3D_Block(2 * feat_channels[1], feat_channels[1], residual=residual)
            self.dec_conv_blk1 = Conv3D_Block(2 * feat_channels[0], feat_channels[0], residual=residual)

            # Decoder upsamplers
            self.deconv_blk4 = Deconv3D_Block(feat_channels[4], feat_channels[3])
            self.deconv_blk3 = Deconv3D_Block(feat_channels[3], feat_channels[2])
            self.deconv_blk2 = Deconv3D_Block(feat_channels[2], feat_channels[1])
            self.deconv_blk1 = Deconv3D_Block(feat_channels[1], feat_channels[0])

            # Final 1*1 Conv Segmentation map
            self.one_conv = Conv3d(feat_channels[0], opt.out_channels, kernel_size=1, stride=1, padding=0, bias=True)

        def forward(self, x):
            # Encoder part

            x1 = self.conv_blk1(x)

            x_low1 = self.pool1(x1)
            x2 = self.conv_blk2(x_low1)

            x_low2 = self.pool2(x2)
            x3 = self.conv_blk3(x_low2)

            x_low3 = self.pool3(x3)
            x4 = self.conv_blk4(x_low3)

            x_low4 = self.pool4(x4)
            base = self.conv_blk5(x_low4)

            # Decoder part

            d4 = torch.cat([self.deconv_blk4(base), x4], dim=1)
            d_high4 = self.dec_conv_blk4(d4)

            d3 = torch.cat([self.deconv_blk3(d_high4), x3], dim=1)
            d_high3 = self.dec_conv_blk3(d3)
            d_high3 = Dropout3d(p=0.5)(d_high3)

            d2 = torch.cat([self.deconv_blk2(d_high3), x2], dim=1)
            d_high2 = self.dec_conv_blk2(d2)
            d_high2 = Dropout3d(p=0.5)(d_high2)

            d1 = torch.cat([self.deconv_blk1(d_high2), x1], dim=1)
            d_high1 = self.dec_conv_blk1(d1)

            seg = self.one_conv(d_high1)

            return seg

    # create HighResNet
    HighResNet = monai.networks.nets.HighResNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
    )

    if opt.net == 'Unet_Monai':
        network = Unet
    elif opt.net == 'HighResNet':
        network = HighResNet
    elif opt.net == 'Unet_David':
        network = UNet_David(residual='pool')
    else:
        raise NotImplementedError

    init_weights(network, init_type='normal')  # init_weights: project-specific weight initialisation helper

    return network