Example #1
    def __init__(self,
                 in_channels: int,
                 num_filters: int,
                 filter_length: int,
                 subsample_length: int,
                 groups: int = 1,
                 dilation: int = 1,
                 dropouts: Union[float, Sequence[float]] = 0,
                 **config) -> None:
        """ finished, NOT checked,

        Parameters
        ----------
        in_channels: int,
            number of features (channels) of the input
        num_filters: int,
            number of filters for the convolutional layers
        filter_length: int,
            length (size) of the filter kernels
        subsample_length: int,
            subsample length,
            including the pool size for the shortcut, and the stride for the top convolutional layer
        groups: int, default 1,
            pattern of connections between inputs and outputs,
            for more details, ref. `nn.Conv1d`
        dilation: int, default 1,
            dilation of the convolutional layers
        dropouts: float, or sequence of float, default 0.0,
            dropout ratio after each convolution (and batch normalization, and activation, etc.)
        config: dict,
            other hyper-parameters, including
            activation choices, weight initializer, and shortcut patterns, etc.
        """
        super().__init__()
        self.__num_convs = 2
        self.__in_channels = in_channels
        self.__out_channels = num_filters
        self.__kernel_size = filter_length
        self.__down_scale = subsample_length
        self.__stride = subsample_length
        self.__groups = groups
        self.__dilation = dilation
        if isinstance(dropouts, (int, float)):  # the default 0 is an int, not a float
            self.__dropouts = list(repeat(dropouts, self.__num_convs))
        else:
            self.__dropouts = list(dropouts)
        assert len(self.__dropouts) == self.__num_convs, \
            f"`dropouts` should be of length {self.__num_convs}, but got {len(self.__dropouts)}"
        self.config = ED(deepcopy(config))
        if self.__DEBUG__:
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )

        self.__increase_channels = (self.__out_channels > self.__in_channels)
        self.shortcut = self._make_shortcut_layer()

        self.main_stream = nn.Sequential()
        conv_in_channels = self.__in_channels
        for i in range(self.__num_convs):
            # the last conv carries no activation; it is applied after the
            # shortcut addition, via `out_activation`
            conv_activation = (self.config.activation
                               if i < self.__num_convs - 1 else None)
            self.main_stream.add_module(
                f"cba_{i}",
                Conv_Bn_Activation(
                    in_channels=conv_in_channels,
                    out_channels=self.__out_channels,
                    kernel_size=self.__kernel_size,
                    stride=(self.__stride if i == 0 else 1),
                    dilation=self.__dilation,
                    groups=self.__groups,
                    batch_norm=True,
                    activation=conv_activation,
                    kw_activation=self.config.kw_activation,
                    kernel_initializer=self.config.kernel_initializer,
                    kw_initializer=self.config.kw_initializer,
                    bias=self.config.bias,
                ))
            conv_in_channels = self.__out_channels
            # the dropout after the second conv is deferred to `out_dropout` below
            if i == 0 and self.__dropouts[i] > 0:
                self.main_stream.add_module(f"dropout_{i}",
                                            nn.Dropout(self.__dropouts[i]))
            if i == 1:  # append the global context block after the last conv
                self.main_stream.add_module(
                    "gcb",
                    GlobalContextBlock(
                        in_channels=self.__out_channels,
                        ratio=self.config.gcb.ratio,
                        reduction=self.config.gcb.reduction,
                        pooling_type=self.config.gcb.pooling_type,
                        fusion_types=self.config.gcb.fusion_types,
                    ))

        if isinstance(self.config.activation, str):
            self.out_activation = \
                Activations[self.config.activation.lower()](**self.config.kw_activation)
        else:
            self.out_activation = \
                self.config.activation(**self.config.kw_activation)

        if self.__dropouts[1] > 0:
            self.out_dropout = nn.Dropout(self.__dropouts[1])
        else:
            self.out_dropout = None
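
A hypothetical usage sketch of this block. The signature matches the `ResNetGCBlock` instantiated in Example #2, so that class name is assumed here; every hyper-parameter value below is illustrative, not taken from the repository's config files.

from easydict import EasyDict as ED

# illustrative hyper-parameters only; the real values live in the repo's configs
block_config = ED(
    activation="relu",
    kw_activation={"inplace": True},
    kernel_initializer="he_normal",
    kw_initializer={},
    bias=False,
    gcb=ED(ratio=8, reduction=True, pooling_type="attn", fusion_types=["mul"]),
)
block = ResNetGCBlock(        # assumed class name, cf. Example #2
    in_channels=12,           # e.g. 12-lead ECG features
    num_filters=64,
    filter_length=15,
    subsample_length=2,       # stride of the first conv and shortcut pool size
    **block_config,
)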
Example #2
    def __init__(self, stage: int, out_branches: int, in_channels: int,
                 **config) -> None:
        """ NOT finished, NOT checked,
        """
        super().__init__()
        self.stage = stage
        self.out_branches = out_branches
        self.in_channels = in_channels
        self.config = ED(config)

        self.branches = nn.ModuleList()
        for i in range(self.stage):
            w = in_channels * (2**i)  # branch width doubles per branch index
            # three identical residual blocks with global context per branch;
            # note `self.config` (an `ED`) is used here, since the raw `config`
            # kwargs dict does not support attribute access
            branch = nn.Sequential(*[
                ResNetGCBlock(in_channels=w,
                              num_filters=w,
                              **(self.config.resnet_gc))
                for _ in range(3)
            ])
            self.branches.append(branch)

        self.fuse_layers = nn.ModuleList()
        for i in range(self.out_branches):
            fl = nn.ModuleList()
            for j in range(self.stage):
                if i == j:
                    fl.append(nn.Sequential())  # identity mapping
                elif i < j:
                    # 1x1 conv + BN, then upsample the lower-resolution branch
                    # j to the length of branch i; the hard-coded target sizes
                    # (625, 313, 157) appear to correspond to an input length
                    # of 5000 downsampled by factors of 8, 16, and 32
                    up_sizes = {0: 625, 1: 313, 2: 157}
                    fl.append(
                        nn.Sequential(
                            nn.Conv1d(in_channels * (2**j),
                                      in_channels * (2**i),
                                      kernel_size=1,
                                      stride=1),
                            nn.BatchNorm1d(in_channels * (2**i)),
                            nn.Upsample(size=up_sizes[i]),
                        ))

                elif i > j:
                    # downsample the higher-resolution branch j to the length
                    # of branch i, halving the length per step via stride-2 convs
                    opts = []
                    if i == j + 1:
                        opts.append(
                            Conv_Bn_Activation(
                                in_channels=in_channels * (2**j),
                                out_channels=in_channels * (2**i),
                                kernel_size=7,
                                stride=2,
                                batch_norm=True,
                                activation=None,
                            ))
                    elif i == j + 2:
                        opts.append(
                            MultiConv(
                                in_channels=in_channels * (2**j),
                                out_channels=[
                                    in_channels * (2**(j + 1)),
                                    in_channels * (2**(j + 2))
                                ],
                                filter_lengths=7,
                                subsample_lengths=2,
                                out_activation=False,
                            ))
                    elif i == j + 3:
                        opts.append(
                            MultiConv(
                                in_channels=in_channels * (2**j),
                                out_channels=[
                                    in_channels * (2**(j + 1)),
                                    in_channels * (2**(j + 2)),
                                    in_channels * (2**(j + 3))
                                ],
                                filter_lengths=7,
                                subsample_lengths=2,
                                out_activation=False,
                            ))
                    fl.append(nn.Sequential(*opts))
            self.fuse_layers.append(fl)
        self.fuse_activation = nn.ReLU(inplace=True)
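
For orientation, a standalone sketch of the branch shapes implied by the fuse layers above, assuming `in_channels=64` and the 5000-sample input suggested by the hard-coded upsample targets:

# branch widths double per index; lengths follow the Upsample sizes above
in_channels, lengths = 64, [625, 313, 157]
for i, length in enumerate(lengths):
    print(f"branch {i}: (batch_size, {in_channels * 2 ** i}, {length})")
# branch 0: (batch_size, 64, 625)
# branch 1: (batch_size, 128, 313)
# branch 2: (batch_size, 256, 157)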
Example #3
class ECG_UNET(nn.Module):
    """ finished, checked,

    UNet for (multi-lead) ECG wave delineation

    References:
    -----------
    [1] Moskalenko, Viktor, Nikolai Zolotykh, and Grigory Osipov. "Deep Learning for ECG Segmentation." International Conference on Neuroinformatics. Springer, Cham, 2019.
    [2] https://github.com/milesial/Pytorch-UNet/
    """
    __DEBUG__ = True
    __name__ = "ECG_UNET"

    def __init__(self, classes: Sequence[str], n_leads: int,
                 config: dict) -> None:
        """ finished, checked,

        Parameters:
        -----------
        classes: sequence of str,
            names of the classes
        n_leads: int,
            number of input leads (number of input channels)
        config: dict,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)  # final out_channels
        self.__out_channels = self.n_classes
        self.__in_channels = n_leads
        self.config = ED(deepcopy(config))
        if self.__DEBUG__:
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
            __debug_seq_len = 4000

        self.init_conv = DoubleConv(
            in_channels=self.__in_channels,
            out_channels=self.config.init_num_filters,
            filter_lengths=self.config.init_filter_length,
            subsample_lengths=1,
            groups=self.config.groups,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.init_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, init_conv output shape = {__debug_output_shape}"
            )
            _, _, __debug_seq_len = __debug_output_shape

        self.down_blocks = nn.ModuleDict()
        in_channels = self.config.init_num_filters
        for idx in range(self.config.down_up_block_num):
            self.down_blocks[f"down_{idx}"] = \
                DownDoubleConv(
                    down_scale=self.config.down_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.down_num_filters[idx],
                    filter_lengths=self.config.down_filter_lengths[idx],
                    groups=self.config.groups,
                    mode=self.config.down_mode,
                    **(self.config.down_block)
                )
            in_channels = self.config.down_num_filters[idx]
            if self.__DEBUG__:
                __debug_output_shape = self.down_blocks[
                    f"down_{idx}"].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, down_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

        self.up_blocks = nn.ModuleDict()
        in_channels = self.config.down_num_filters[-1]
        for idx in range(self.config.down_up_block_num):
            self.up_blocks[f"up_{idx}"] = \
                UpDoubleConv(
                    up_scale=self.config.up_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.up_num_filters[idx],
                    filter_lengths=self.config.up_conv_filter_lengths[idx],
                    deconv_filter_length=self.config.up_deconv_filter_lengths[idx],
                    groups=self.config.groups,
                    mode=self.config.up_mode,
                    **(self.config.up_block)
                )
            in_channels = self.config.up_num_filters[idx]
            if self.__DEBUG__:
                __debug_output_shape = self.up_blocks[
                    f"up_{idx}"].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, up_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

        self.out_conv = Conv_Bn_Activation(
            in_channels=self.config.up_num_filters[-1],
            out_channels=self.__out_channels,
            kernel_size=self.config.out_filter_length,
            stride=1,
            groups=self.config.groups,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.out_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, out_conv output shape = {__debug_output_shape}"
            )

        # for inference
        # if background counted in `classes`, use softmax
        # otherwise use sigmoid
        self.softmax = nn.Softmax(-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """ finished, checked,

        Parameters:
        -----------
        input: Tensor,
            of shape (batch_size, n_channels, seq_len)

        Returns:
        --------
        output: Tensor,
            of shape (batch_size, n_classes, seq_len)
        """
        to_concat = [self.init_conv(input)]
        if self.__DEBUG__:
            print(f"shape of init conv block output = {to_concat[-1].shape}")
        for idx in range(self.config.down_up_block_num):
            to_concat.append(self.down_blocks[f"down_{idx}"](to_concat[-1]))
            if self.__DEBUG__:
                print(
                    f"shape of {idx}-th down block output = {to_concat[-1].shape}"
                )
        up_input = to_concat[-1]
        to_concat = to_concat[-2::-1]  # skip connections, deepest first
        for idx in range(self.config.down_up_block_num):
            if self.__DEBUG__:
                print(
                    f"shape of {idx}-th up block 1st input = {up_input.shape}")
                print(
                    f"shape of {idx}-th up block 2nd input (from down) = {to_concat[idx].shape}"
                )
            up_output = self.up_blocks[f"up_{idx}"](up_input, to_concat[idx])
            up_input = up_output
            if self.__DEBUG__:
                print(f"shape of {idx}-th up block output = {up_output.shape}")
        output = self.out_conv(up_output)
        if self.__DEBUG__:
            print(f"shape of out_conv layer output = {output.shape}")

        return output

    @torch.no_grad()
    def inference(self, input: Tensor, bin_pred_thr: float = 0.5) -> Tensor:
        """
        """
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
            self,
            seq_len: Optional[int] = None,
            batch_size: Optional[int] = None) -> Sequence[Union[int, None]]:
        """ finished, checked,

        Parameters:
        -----------
        seq_len: int, optional,
            length of the 1d sequence, can be None
        batch_size: int, optional,
            the batch size, can be None

        Returns:
        --------
        output_shape: sequence,
            the output shape of this `ECG_UNET` layer, given `seq_len` and `batch_size`
        """
        output_shape = (batch_size, self.n_classes, seq_len)
        return output_shape

    @property
    def module_size(self) -> int:
        """
        """
        return compute_module_size(self)
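
A hypothetical end-to-end sketch for this class; `ECG_UNET_CONFIG` is a placeholder for whichever config object in the repository supplies the keys read in `__init__`, and the class names below are illustrative:

import torch

# hypothetical delineation classes: P wave, QRS, T wave, background
model = ECG_UNET(classes=["p", "N", "t", "i"], n_leads=12, config=ECG_UNET_CONFIG)
x = torch.randn(2, 12, 4000)  # (batch_size, n_leads, seq_len)
y = model(x)                  # (batch_size, n_classes, seq_len)
# holds when the down/up scales mirror each other, so seq_len is preserved
assert y.shape == model.compute_output_shape(seq_len=4000, batch_size=2)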
Example #4
class ECG_SUBTRACT_UNET(nn.Module):
    """ finished, NOT checked,

    entry 0433 of CPSC2019
    """
    __DEBUG__ = True
    __name__ = "ECG_SUBTRACT_UNET"

    def __init__(self, classes: Sequence[str], n_leads: int,
                 config: dict) -> None:
        """ finished, checked,

        Parameters:
        -----------
        classes: sequence of str,
            names of the classes
        n_leads: int,
            number of input leads
        config: dict,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.__out_channels = len(classes)
        self.__in_channels = n_leads
        self.config = ED(deepcopy(config))
        if self.__DEBUG__:
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
            __debug_seq_len = 5000

        # TODO: an init batch normalization?
        if self.config.init_batch_norm:
            self.init_bn = nn.BatchNorm1d(
                num_features=self.__in_channels,
                eps=1e-5,  # default val
                momentum=0.1,  # default val
            )

        self.init_conv = TripleConv(
            in_channels=self.__in_channels,
            out_channels=self.config.init_num_filters,
            filter_lengths=self.config.init_filter_length,
            subsample_lengths=1,
            groups=self.config.groups,
            dropouts=self.config.init_dropouts,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.init_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, init_conv output shape = {__debug_output_shape}"
            )
            _, _, __debug_seq_len = __debug_output_shape

        self.down_blocks = nn.ModuleDict()
        in_channels = self.config.init_num_filters
        for idx in range(self.config.down_up_block_num - 1):
            self.down_blocks[f"down_{idx}"] = \
                DownTripleConv(
                    down_scale=self.config.down_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.down_num_filters[idx],
                    filter_lengths=self.config.down_filter_lengths[idx],
                    groups=self.config.groups,
                    dropouts=self.config.down_dropouts[idx],
                    mode=self.config.down_mode,
                    **(self.config.down_block)
                )
            in_channels = self.config.down_num_filters[idx][-1]
            if self.__DEBUG__:
                __debug_output_shape = self.down_blocks[
                    f"down_{idx}"].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, down_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

        self.bottom_block = DownBranchedDoubleConv(
            down_scale=self.config.down_scales[-1],
            in_channels=in_channels,
            out_channels=self.config.bottom_num_filters,
            filter_lengths=self.config.bottom_filter_lengths,
            dilations=self.config.bottom_dilations,
            groups=self.config.groups,
            dropouts=self.config.bottom_dropouts,
            mode=self.config.down_mode,
            **(self.config.down_block))
        if self.__DEBUG__:
            __debug_output_shape = self.bottom_block.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, bottom_block output shape = {__debug_output_shape}"
            )
            _, _, __debug_seq_len = __debug_output_shape

        self.up_blocks = nn.ModuleDict()
        # in_channels = sum([branch[-1] for branch in self.config.bottom_num_filters])
        in_channels = self.bottom_block.compute_output_shape(None, None)[1]
        for idx in range(self.config.down_up_block_num):
            self.up_blocks[f"up_{idx}"] = \
                UpTripleConv(
                    up_scale=self.config.up_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.up_num_filters[idx],
                    filter_lengths=self.config.up_conv_filter_lengths[idx],
                    deconv_filter_length=self.config.up_deconv_filter_lengths[idx],
                    groups=self.config.groups,
                    mode=self.config.up_mode,
                    dropouts=self.config.up_dropouts[idx],
                    **(self.config.up_block)
                )
            in_channels = self.config.up_num_filters[idx][-1]
            if self.__DEBUG__:
                __debug_output_shape = self.up_blocks[
                    f"up_{idx}"].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, up_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

        self.out_conv = Conv_Bn_Activation(
            in_channels=self.config.up_num_filters[-1][-1],
            out_channels=self.__out_channels,
            kernel_size=self.config.out_filter_length,
            stride=1,
            groups=self.config.groups,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.out_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, out_conv output shape = {__debug_output_shape}"
            )

        # for inference
        # if background counted in `classes`, use softmax
        # otherwise use sigmoid
        self.softmax = nn.Softmax(-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """ finished, checked,

        Parameters:
        -----------
        input: Tensor,
            of shape (batch_size, n_channels, seq_len)

        Returns:
        --------
        output: Tensor,
            of shape (batch_size, n_classes, seq_len)
        """
        if self.config.init_batch_norm:
            x = self.init_bn(input)
        else:
            x = input

        # down
        to_concat = [self.init_conv(x)]
        if self.__DEBUG__:
            print(
                f"shape of the init conv block output = {to_concat[-1].shape}")
        for idx in range(self.config.down_up_block_num - 1):
            to_concat.append(self.down_blocks[f"down_{idx}"](to_concat[-1]))
            if self.__DEBUG__:
                print(
                    f"shape of the {idx}-th down block output = {to_concat[-1].shape}"
                )
        to_concat.append(self.bottom_block(to_concat[-1]))
        if self.__DEBUG__:
            print(f"shape of the bottom block output = {to_concat[-1].shape}")

        # up
        up_input = to_concat[-1]
        to_concat = to_concat[-2::-1]
        for idx in range(self.config.down_up_block_num):
            up_output = self.up_blocks[f"up_{idx}"](up_input, to_concat[idx])
            up_input = up_output
            if self.__DEBUG__:
                print(
                    f"shape of the {idx}-th up block output = {up_output.shape}"
                )

        # output
        output = self.out_conv(up_output)
        if self.__DEBUG__:
            print(f"shape of out_conv layer output = {output.shape}")

        return output

    @torch.no_grad()
    def inference(self,
                  input: Union[np.ndarray, Tensor],
                  bin_pred_thr: float = 0.5) -> Tensor:
        """
        """
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
            self,
            seq_len: int,
            batch_size: Optional[int] = None) -> Sequence[Union[int, None]]:
        """ finished, NOT checked,

        Parameters:
        -----------
        seq_len: int,
            length of the 1d sequence
        batch_size: int, optional,
            the batch size, can be None

        Returns:
        --------
        output_shape: sequence,
            the output shape of this `ECG_SUBTRACT_UNET` layer, given `seq_len` and `batch_size`
        """
        output_shape = (batch_size, self.n_classes, seq_len)
        return output_shape

    @property
    def module_size(self) -> int:
        """
        """
        return compute_module_size(self)
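
Since `inference` is deliberately left task specific, here is a minimal hypothetical override for a subclass, assuming the background class is not among `classes` (hence sigmoid plus thresholding, per the comment above `self.sigmoid`):

class CPSC2019_QRS_UNET(ECG_SUBTRACT_UNET):  # hypothetical subclass name
    @torch.no_grad()
    def inference(self, input: Tensor, bin_pred_thr: float = 0.5) -> Tensor:
        proba = self.sigmoid(self.forward(input))  # (batch, n_classes, seq_len)
        return (proba >= bin_pred_thr).int()       # binary mask per class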