def __init__(
            self,
            in_channel: int,
            out_channel: int,
            kernel_size: int = 3,
            padding: int = 1,
            kernels_per_layer: int = 1,
            norm_layer_name: str = "BatchNorm2d",
            act_fnc_name: str = "ReLU",
            dimension: int = 2,
            **kwargs  # For the normalization layer
    ) -> None:
        super().__init__(in_channel, out_channel, dimension=dimension)

        self.depth_wise_conv = nn.Sequential(
            self.conv(self.in_channel,
                      self.in_channel * kernels_per_layer,
                      kernel_size,
                      padding=padding,
                      groups=self.in_channel),
            get_attr_if_exists(nn, norm_layer_name)(
                self.in_channel * kernels_per_layer, **kwargs),
            get_attr_if_exists(nn, act_fnc_name)())
        self.pixel_wise_conv = nn.Sequential(
            self.conv(
                self.in_channel * kernels_per_layer,
                out_channel,
                kernel_size=1,
            ),
            # Normalize over the pixel-wise conv's output channels (`out_channel`),
            # not the expanded depth-wise channel count.
            get_attr_if_exists(nn, norm_layer_name)(out_channel, **kwargs))
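
        # Depthwise-separable factorization: `depth_wise_conv` applies a grouped KxK
        # convolution (one group per input channel, each expanded to `kernels_per_layer`
        # maps), and `pixel_wise_conv` then mixes channels with a 1x1 convolution down
        # to `out_channel`. This uses substantially fewer parameters and multiply-adds
        # than a single dense KxK convolution, at a small cost in capacity.
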
def _build_stage(self, stage_idx: str, **kwargs) -> nn.Module:
        # Stage names are 1-based, so shift to a 0-based index for list access.
        access_idx = int(stage_idx) - 1
        num_stages = len(self.num_layers)

        if 0 < access_idx < 8:  # Stages 2 ~ 8 (the body stages)
            path_dropout = kwargs.pop("path_dropout")

            return nn.Sequential(
                OrderedDict([
                    (
                        f"layer_{idx}",
                        MBConvXd(
                            in_channel=self.channels[access_idx - 1]
                            if idx == 0 else self.channels[access_idx],
                            out_channel=self.channels[access_idx],
                            expand_scale=self.expand_scales[access_idx],
                            kernel_size=self.kernel_sizes[access_idx],
                            # only for the first MBConvXd block
                            stride=self.strides[access_idx] if idx == 0 else 1,
                            # (kernel_size = 1 -> padding = 0), (kernel_size = 3 -> padding = 1), (kernel_size = 5 -> padding = 2)
                            padding=self.kernel_sizes[access_idx] // 2,
                            se_scale=self.se_scales[access_idx],
                            # Inverted `survival_prob` from: https://github.com/tensorflow/tpu/blob/3679ca6b979349dde6da7156be2528428b000c7c/models/official/efficientnet/efficientnet_model.py#L659
                            path_dropout=path_dropout * (access_idx + 1) /
                            (num_stages - 2),  # "- 2" excludes the stem and head stages
                            **kwargs))
                    for idx in range(self.num_layers[access_idx])
                ]))

        else:
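            # Stages outside 2 ~ 8, i.e. the first and last stage (the convolutional
            # stem and head): a plain Conv -> BatchNorm -> activation block built from
            # the passed keyword arguments.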
            return nn.Sequential(
                OrderedDict([
                    (f"layer_0",
                     nn.Sequential(
                         nn.Conv2d(kwargs['in_channel'],
                                   kwargs['out_channel'],
                                   kwargs['kernel_size'],
                                   padding=kwargs['kernel_size'] // 2,
                                   bias=False),
                         nn.BatchNorm2d(
                             kwargs['out_channel'],
                             eps=kwargs['eps'],
                             momentum=kwargs['momentum'],
                         ),
                         get_attr_if_exists(nn, kwargs['act_fnc_name'])()))
                ]))
    def _build_stage(self, stage_idx: int, **kwargs) -> nn.Module:
        num_stages = len(self.num_layers)

        if 0 < stage_idx < num_stages - 1:
            path_dropout = kwargs.pop("path_dropout")
            conv_block = FusedMBConvXd if self.se_scales[stage_idx] is None else MBConvXd
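            # Note: the per-stage drop rate below grows linearly with the stage index,
            #     path_dropout * (stage_idx + 1) / (num_stages - 2),
            # where the "- 2" excludes the stem and head stages. A hypothetical worked
            # example, assuming path_dropout=0.2 and num_stages=9 (i.e. 7 body stages):
            # stage 1 gets 0.2 * 2 / 7 ≈ 0.057, stage 7 gets 0.2 * 8 / 7 ≈ 0.229.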

            return nn.Sequential(OrderedDict([
                (
                    f"stage_{stage_idx}_layer_{idx}",
                    conv_block(
                        in_channel=self.channels[stage_idx - 1] if idx == 0 else self.channels[stage_idx],
                        out_channel=self.channels[stage_idx],
                        expand_scale=self.expand_scales[stage_idx],
                        kernel_size=self.kernel_sizes[stage_idx],
                        # only for the first MBConvXd block
                        stride=self.strides[stage_idx] if idx == 0 else 1,
                        padding=self.kernel_sizes[stage_idx] // 2,
                        se_scale=self.se_scales[stage_idx],
                        # Inverted `survival_prob` from: https://github.com/tensorflow/tpu/blob/3679ca6b979349dde6da7156be2528428b000c7c/models/official/efficientnet/efficientnet_model.py#L659
                        path_dropout=path_dropout * (stage_idx + 1) / (num_stages - 2),  # "- 2" excludes the stem and head stages
                        **kwargs
                    )
                ) for idx in range(self.num_layers[stage_idx])
            ]))

        else:
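            # First or last stage (the convolutional stem or head): a plain
            # Conv -> BatchNorm -> activation block built from the passed keyword arguments.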
            return nn.Sequential(OrderedDict([
                (
                    f"stage_{stage_idx}_layer_0",
                    nn.Sequential(
                        nn.Conv2d(
                            kwargs['in_channel'],
                            kwargs['out_channel'],
                            kwargs['kernel_size'],
                            padding=kwargs['kernel_size'] // 2,
                            bias=False
                        ),
                        nn.BatchNorm2d(
                            kwargs['out_channel'],
                            eps=kwargs['eps'],
                            momentum=kwargs['momentum'],
                        ),
                        get_attr_if_exists(nn, kwargs['act_fnc_name'])()
                    )
                )
            ]))
    def __init__(
        self,
        in_channel: int,
        bottleneck_channel: int,
        out_channel: Optional[int] = None,
        se_act_fnc_name: str = "SiLU",  # A valid `torch.nn` activation name is required here
        dimension: int = 2,
    ) -> None:
        super().__init__(in_channel, out_channel, dimension=dimension)

        self.layers = nn.Sequential(
            nn.AdaptiveMaxPool2d((1, 1)),
            self.conv(self.in_channel, bottleneck_channel, 1),
            get_attr_if_exists(nn, se_act_fnc_name)(),
            self.conv(bottleneck_channel, self.out_channel, 1),
            nn.Sigmoid(),
        )
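
        # Squeeze-and-excitation gating: the global pooling squeezes the spatial
        # dimensions to 1x1, the two 1x1 convolutions form a bottleneck
        # (`in_channel` -> `bottleneck_channel` -> `self.out_channel`), and the final
        # sigmoid yields per-channel weights in (0, 1) that are presumably multiplied
        # back onto the input feature map by the caller (the forward pass is not shown
        # in this excerpt). Note the pooling layer is hard-coded to 2d max pooling,
        # whereas the original SE formulation uses global average pooling and
        # `dimension` is otherwise configurable.
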
    def __init__(
        self,
        in_channel: int,
        out_channel: Optional[int] = None,
        expand_channel: Optional[int] = None,
        expand_scale: Optional[int] = None,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        norm_layer_name: str = "BatchNorm2d",
        act_fnc_name: str = "SiLU",
        se_scale: Optional[float] = None,
        se_act_fnc_name: str = "SiLU",
        dimension: int = 2,
        path_dropout: float = 0.,
        expansion_head_type: Literal["pixel_depth", "fused"] = "pixel_depth",
        **kwargs  # For example: `eps` and `elementwise_affine` for `nn.LayerNorm`
    ):
        super().__init__(in_channel, out_channel, dimension=dimension)

        assert (
            expand_channel is not None or expand_scale is not None
        ), name_with_msg(
            self,
            "Either `expand_channel` or `expand_scale` should be specified")
        expand_channel = expand_channel if expand_channel is not None else in_channel * expand_scale

        assert (
            isinstance(expansion_head_type, str)
            and expansion_head_type in ["pixel_depth", "fused"]
        ), name_with_msg(
            self,
            f"The specified `expansion_head_type` - {expansion_head_type} "
            f"({type(expansion_head_type)}) doesn't exist. "
            "Please choose from: ['pixel_depth', 'fused']")

        # Expansion Head
        if expansion_head_type == "pixel_depth":
            pixel_wise_conv_0 = nn.Sequential(
                self.conv(self.in_channel,
                          expand_channel,
                          kernel_size=1,
                          bias=False),
                get_attr_if_exists(nn, norm_layer_name)(expand_channel,
                                                        **kwargs),
                get_attr_if_exists(nn, act_fnc_name)())

            depth_wise_conv = nn.Sequential(
                self.conv(expand_channel,
                          expand_channel,
                          kernel_size,
                          stride=stride,
                          padding=padding,
                          groups=expand_channel,
                          bias=False),
                get_attr_if_exists(nn, norm_layer_name)(expand_channel,
                                                        **kwargs),
                get_attr_if_exists(nn, act_fnc_name)())

            self.expansion_head = nn.Sequential(pixel_wise_conv_0,
                                                depth_wise_conv)
        else:
            # "fused" head: a single KxK convolution expands and mixes channels in one
            # step. `self.conv` (rather than a hard-coded `nn.Conv2d`) keeps the block
            # consistent with the `dimension` argument, matching the branch above.
            self.expansion_head = nn.Sequential(
                self.conv(self.in_channel,
                          expand_channel,
                          kernel_size,
                          stride=stride,
                          padding=padding,
                          bias=False),
                get_attr_if_exists(nn, norm_layer_name)(expand_channel,
                                                        **kwargs),
                get_attr_if_exists(nn, act_fnc_name)())

        # Optional squeeze-and-excitation block
        self.se_block = None
        if se_scale is not None:
            bottleneck_channel = int(expand_channel * se_scale)

            self.se_block = SEConvXd(
                expand_channel,
                bottleneck_channel,
                se_act_fnc_name=se_act_fnc_name,
            )

        # Projection: 1x1 (pixel-wise) convolution back down to `out_channel`
        self.pixel_wise_conv_1 = nn.Sequential(
            self.conv(
                expand_channel,
                self.out_channel,
                kernel_size=1,
                bias=False,
            ),
            get_attr_if_exists(nn, norm_layer_name)(self.out_channel,
                                                    **kwargs))

        # From: https://github.com/tensorflow/tpu/blob/3679ca6b979349dde6da7156be2528428b000c7c/models/official/efficientnet/utils.py#L276
        # It's a batch-wise dropout
        self.path_dropout = PathDropout(path_dropout)
        self.skip = self.in_channel == self.out_channel and stride == 1
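
        # `skip` is True only when a residual connection is shape-compatible (same
        # channel count and stride 1). `PathDropout` implements batch-wise stochastic
        # depth, so the forward pass presumably computes `x + path_dropout(block(x))`
        # when `skip` is True (the forward method is not part of this excerpt).
        #
        # A hypothetical usage sketch (argument values are illustrative only):
        #     block = MBConvXd(in_channel=32, out_channel=64, expand_scale=4,
        #                      kernel_size=3, stride=2, padding=1,
        #                      se_scale=0.25, path_dropout=0.1)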