Example #1
    def __init__(self,
                 classes: Sequence[str],
                 n_leads: int,
                 config: Optional[ED] = None) -> None:
        """ finished, checked,

        Parameters
        ----------
        classes: list,
            list of the classes for classification
        n_leads: int,
            number of leads (number of input channels)
        config: dict, optional,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        self.config = deepcopy(RR_LSTM_CONFIG)
        self.config.update(deepcopy(config) or {})
        if self.__DEBUG__:
            print(
                f"classes (totally {self.n_classes}) for prediction:{self.classes}"
            )
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )

        self.lstm = StackedLSTM(
            input_size=self.n_leads,
            hidden_sizes=self.config.lstm.hidden_sizes,
            bias=self.config.lstm.bias,
            dropouts=self.config.lstm.dropouts,
            bidirectional=self.config.lstm.bidirectional,
            return_sequences=self.config.lstm.retseq,
        )

        if self.__DEBUG__:
            print(f"\042lstm\042 module has size {self.lstm.module_size}")

        attn_input_size = self.lstm.compute_output_shape(None, None)[-1]

        if not self.config.lstm.retseq:
            self.attn = None
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_exitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn = SelfAttention(
                in_features=attn_input_size,
                head_num=self.config.attn.sa.head_num,
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        if self.__DEBUG__ and self.attn:
            print(
                f"attn module \042{self.config.attn.name}\042 has size {self.attn.module_size}"
            )

        if not self.config.lstm.retseq:
            self.pool = None
            self.clf = None
        elif self.config.clf.name.lower() == "linear":
            if self.config.global_pool.lower() == "max":
                self.pool = nn.AdaptiveMaxPool1d((1, ))
                self.clf = SeqLin(
                    in_channels=clf_input_size,
                    out_channels=self.config.clf.linear.out_channels +
                    [self.n_classes],
                    activation=self.config.clf.linear.activation,
                    bias=self.config.clf.linear.bias,
                    dropouts=self.config.clf.linear.dropouts,
                    skip_last_activation=True,
                )
            else:
                # only global max pooling is currently supported for the linear head
                raise NotImplementedError(
                    f"global pooling \042{self.config.global_pool}\042 not implemented yet"
                )
        elif self.config.clf.name.lower() == "crf":
            self.pool = None
            self.clf = ExtendedCRF(in_channels=clf_input_size,
                                   num_tags=self.n_classes,
                                   bias=self.config.clf.crf.proj_bias)
        else:
            raise NotImplementedError(
                f"classifier \042{self.config.clf.name}\042 not implemented yet"
            )

        if self.__DEBUG__ and self.clf:
            print(
                f"clf module \042{self.config.clf.name}\042 has size {self.clf.module_size}"
            )

        # for inference, except for crf
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
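A note on the config handling at the top of this `__init__`: the package default `RR_LSTM_CONFIG` is deep-copied and then `update`d with the user config, and that update is shallow. The sketch below uses plain dicts and a hypothetical subset of the defaults (the real defaults live in the corresponding config file and use an EasyDict-style `ED`); it only illustrates that a user override replaces a whole top-level entry rather than merging into it.

from copy import deepcopy

# hypothetical subset of the default config, for illustration only
DEFAULT_CONFIG = {"lstm": {"hidden_sizes": [12, 12], "bias": True, "retseq": False}}

def merge_config(user_config=None):
    # mirrors `self.config = deepcopy(RR_LSTM_CONFIG)` followed by
    # `self.config.update(deepcopy(config) or {})`
    config = deepcopy(DEFAULT_CONFIG)
    config.update(deepcopy(user_config) or {})
    return config

print(merge_config())                                  # defaults, untouched
print(merge_config({"lstm": {"hidden_sizes": [64]}}))  # the whole "lstm" entry is replaced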
Example #2
    def __init__(self,
                 classes: Sequence[str],
                 n_leads: int,
                 config: Optional[ED] = None) -> None:
        """ finished, checked,

        Parameters
        ----------
        classes: list,
            list of the classes for classification
        n_leads: int,
            number of leads (number of input channels)
        config: dict, optional,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        self.config = deepcopy(ECG_CRNN_CONFIG)
        self.config.update(deepcopy(config) or {})
        if self.__DEBUG__:
            print(
                f"classes (totally {self.n_classes}) for prediction:{self.classes}"
            )
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
            debug_input_len = 4000

        cnn_choice = self.config.cnn.name.lower()
        if "vgg16" in cnn_choice:
            self.cnn = VGG16(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = self.config.cnn.vgg16.num_filters[-1]
        elif "resnet" in cnn_choice:
            self.cnn = ResNet(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = \
            #     2**len(self.config.cnn[cnn_choice].num_blocks) * self.config.cnn[cnn_choice].init_num_filters
        elif "multi_scopic" in cnn_choice:
            self.cnn = MultiScopicCNN(self.n_leads,
                                      **(self.config.cnn[cnn_choice]))
            # rnn_input_size = self.cnn.compute_output_shape(None, None)[1]
        elif "densenet" in cnn_choice or "dense_net" in cnn_choice:
            self.cnn = DenseNet(self.n_leads, **(self.config.cnn[cnn_choice]))
        else:
            raise NotImplementedError(
                f"the CNN \042{cnn_choice}\042 is not implemented yet")
        rnn_input_size = self.cnn.compute_output_shape(None, None)[1]

        if self.__DEBUG__:
            cnn_output_shape = self.cnn.compute_output_shape(
                debug_input_len, None)
            print(
                f"cnn output shape (batch_size, features, seq_len) = {cnn_output_shape}, given input_len = {debug_input_len}"
            )

        if self.config.rnn.name.lower() == "none":
            self.rnn = None
            attn_input_size = rnn_input_size
        elif self.config.rnn.name.lower() == "lstm":
            # hidden_sizes = self.config.rnn.lstm.hidden_sizes + [self.n_classes]
            # if self.__DEBUG__:
            #     print(f"lstm hidden sizes {self.config.rnn.lstm.hidden_sizes} ---> {hidden_sizes}")
            self.rnn = StackedLSTM(
                input_size=rnn_input_size,
                hidden_sizes=self.config.rnn.lstm.hidden_sizes,
                bias=self.config.rnn.lstm.bias,
                dropouts=self.config.rnn.lstm.dropouts,
                bidirectional=self.config.rnn.lstm.bidirectional,
                return_sequences=self.config.rnn.lstm.retseq,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        elif self.config.rnn.name.lower() == "linear":
            self.rnn = SeqLin(
                in_channels=rnn_input_size,
                out_channels=self.config.rnn.linear.out_channels,
                activation=self.config.rnn.linear.activation,
                bias=self.config.rnn.linear.bias,
                dropouts=self.config.rnn.linear.dropouts,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        # attention
        if (self.config.rnn.name.lower() == "lstm"
                and not self.config.rnn.lstm.retseq):
            self.attn = None
            clf_input_size = attn_input_size
            if self.config.attn.name.lower() != "none":
                print(
                    f"since `retseq` of rnn is False, attention `{self.config.attn.name}` is ignored"
                )
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_exitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn = SelfAttention(
                in_features=attn_input_size,
                head_num=self.config.attn.sa.head_num,
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        if self.__DEBUG__:
            print(f"clf_input_size = {clf_input_size}")

        if (self.config.rnn.name.lower() == "lstm"
                and not self.config.rnn.lstm.retseq):
            self.pool = None
            if self.config.global_pool.lower() != "none":
                print(
                    f"since `retseq` of rnn is False, global pooling `{self.config.global_pool}` is ignored"
                )
        elif self.config.global_pool.lower() == "max":
            self.pool = nn.AdaptiveMaxPool1d((1, ), return_indices=False)
        elif self.config.global_pool.lower() == "avg":
            self.pool = nn.AdaptiveAvgPool1d((1, ))
        elif self.config.global_pool.lower() == "attn":
            raise NotImplementedError("Attentive pooling not implemented yet!")
        else:
            raise NotImplementedError(
                f"pooling method {self.config.global_pool} not implemented yet!"
            )

        # input of `self.clf` has shape: batch_size, channels
        self.clf = SeqLin(
            in_channels=clf_input_size,
            out_channels=self.config.clf.out_channels + [self.n_classes],
            activation=self.config.clf.activation,
            bias=self.config.clf.bias,
            dropouts=self.config.clf.dropouts,
            skip_last_activation=True,
        )

        # sigmoid for inference
        self.sigmoid = nn.Sigmoid()
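Two details of the classifier head set up above are easy to miss: the `SeqLin` layer sizes are formed by appending `self.n_classes` to `config.clf.out_channels`, and the adaptive pooling layer collapses the time axis so the head sees one feature vector per record. A minimal sketch with made-up sizes, using only `torch` (no library code):

import torch
import torch.nn as nn

out_channels, n_classes = [256, 64], 9          # made-up head sizes
layer_sizes = out_channels + [n_classes]        # what `SeqLin` receives: [256, 64, 9]

pool = nn.AdaptiveMaxPool1d((1, ))              # same layer as `self.pool` above
features = torch.randn(2, 512, 1000)            # (batch_size, channels, seq_len)
pooled = pool(features).squeeze(dim=-1)         # (batch_size, channels) -> input of the head
print(layer_sizes, pooled.shape)                # [256, 64, 9] torch.Size([2, 512])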
Example #3
class RR_LSTM(nn.Module):
    """
    classification or sequence labeling using LSTM and using RR intervals as input
    """
    __DEBUG__ = True
    __name__ = "RR_LSTM"

    def __init__(self,
                 classes: Sequence[str],
                 n_leads: int,
                 config: Optional[ED] = None) -> None:
        """ finished, checked,

        Parameters
        ----------
        classes: list,
            list of the classes for classification
        n_leads: int,
            number of leads (number of input channels)
        config: dict, optional,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        self.config = deepcopy(RR_LSTM_CONFIG)
        self.config.update(deepcopy(config) or {})
        if self.__DEBUG__:
            print(
                f"classes (totally {self.n_classes}) for prediction:{self.classes}"
            )
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )

        self.lstm = StackedLSTM(
            input_size=self.n_leads,
            hidden_sizes=self.config.lstm.hidden_sizes,
            bias=self.config.lstm.bias,
            dropouts=self.config.lstm.dropouts,
            bidirectional=self.config.lstm.bidirectional,
            return_sequences=self.config.lstm.retseq,
        )

        if self.__DEBUG__:
            print(f"\042lstm\042 module has size {self.lstm.module_size}")

        attn_input_size = self.lstm.compute_output_shape(None, None)[-1]

        if not self.config.lstm.retseq:
            self.attn = None
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_exitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn = SelfAttention(
                in_features=attn_input_size,
                head_num=self.config.attn.sa.head_num,
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        if self.__DEBUG__ and self.attn:
            print(
                f"attn module \042{self.config.attn.name}\042 has size {self.attn.module_size}"
            )

        if not self.config.lstm.retseq:
            self.pool = None
            self.clf = None
        elif self.config.clf.name.lower() == "linear":
            if self.config.global_pool.lower() == "max":
                self.pool = nn.AdaptiveMaxPool1d((1, ))
                self.clf = SeqLin(
                    in_channels=clf_input_size,
                    out_channels=self.config.clf.linear.out_channels +
                    [self.n_classes],
                    activation=self.config.clf.linear.activation,
                    bias=self.config.clf.linear.bias,
                    dropouts=self.config.clf.linear.dropouts,
                    skip_last_activation=True,
                )
            else:
                # only global max pooling is currently supported for the linear head
                raise NotImplementedError(
                    f"global pooling \042{self.config.global_pool}\042 not implemented yet"
                )
        elif self.config.clf.name.lower() == "crf":
            self.pool = None
            self.clf = ExtendedCRF(in_channels=clf_input_size,
                                   num_tags=self.n_classes,
                                   bias=self.config.clf.crf.proj_bias)
        else:
            raise NotImplementedError(
                f"classifier \042{self.config.clf.name}\042 not implemented yet"
            )

        if self.__DEBUG__ and self.clf:
            print(
                f"clf module \042{self.config.clf.name}\042 has size {self.clf.module_size}"
            )

        # for inference, except for crf
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """ finished, checked,

        Parameters
        ----------
        input: Tensor,
            of shape (seq_len, batch_size, n_channels)

        Returns
        -------
        output: Tensor,
            of shape (batch_size, seq_len, n_classes) or (batch_size, n_classes)
        """
        x = self.lstm(
            input
        )  # (seq_len, batch_size, n_channels) or (batch_size, n_channels)
        if self.attn:
            # (seq_len, batch_size, n_channels) --> (batch_size, n_channels, seq_len)
            x = x.permute(1, 2, 0)
            x = self.attn(x)  # (batch_size, n_channels, seq_len)
        if self.pool:
            x = self.pool(x)  # (batch_size, n_channels, 1)
            x = x.squeeze(dim=-1)  # (batch_size, n_channels)
        elif x.ndim == 3:
            # (batch_size, n_channels, seq_len) --> (batch_size, seq_len, n_channels)
            x = x.permute(0, 2, 1)
        else:
            # x of shape (batch_size, n_channels),
            # in the case where config.lstm.retseq = False
            pass
        if self.clf:
            x = self.clf(
                x
            )  # (batch_size, seq_len, n_classes) or (batch_size, n_classes)
        output = x

        return output

    @torch.no_grad()
    def inference(self, input: Tensor, bin_pred_thr: float = 0.5) -> Tensor:
        """
        """
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
            self,
            seq_len: Optional[int] = None,
            batch_size: Optional[int] = None) -> Sequence[Union[int, None]]:
        """ finished, checked,

        Parameters
        ----------
        seq_len: int, optional,
            length of the 1d sequence,
            if is None, then the input is composed of single feature vectors for each batch
        batch_size: int, optional,
            the batch size, can be None

        Returns
        -------
        output_shape: sequence,
            the output shape of this module, given `seq_len` and `batch_size`
        """
        if self.config.clf.name.lower() == "crf":
            output_shape = (batch_size, seq_len, self.n_classes)
        else:
            # clf is "linear" or lstm.retseq is False
            output_shape = (batch_size, self.n_classes)
        return output_shape

    @property
    def module_size(self) -> int:
        """
        """
        return compute_module_size(self)
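The permutations in `forward` above are where most shape confusion comes from: the stacked LSTM emits `(seq_len, batch_size, n_channels)`, attention and pooling modules expect channels first, and a CRF head expects `(batch_size, seq_len, n_channels)`, which is why `compute_output_shape` reports either `(batch_size, n_classes)` or `(batch_size, seq_len, n_classes)`. A pure-`torch` sketch with fabricated sizes traces both paths:

import torch

seq_len, batch_size, n_channels = 200, 4, 24              # fabricated sizes
lstm_out = torch.randn(seq_len, batch_size, n_channels)   # `self.lstm` output when retseq=True

x = lstm_out.permute(1, 2, 0)    # (batch_size, n_channels, seq_len), fed to `self.attn`/`self.pool`
pooled = torch.nn.AdaptiveMaxPool1d((1, ))(x).squeeze(dim=-1)  # (batch_size, n_channels): classification path
crf_in = x.permute(0, 2, 1)      # (batch_size, seq_len, n_channels): sequence-labeling (CRF) path
print(pooled.shape, crf_in.shape)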
Example #4
class ECG_CRNN(nn.Module):
    """ finished, continuously improving,

    C(R)NN models modified from the following refs.

    References
    ----------
    [1] Yao, Qihang, et al. "Time-Incremental Convolutional Neural Network for Arrhythmia Detection in Varied-Length Electrocardiogram." 2018 IEEE 16th Intl Conf on Dependable, Autonomic and Secure Computing, 16th Intl Conf on Pervasive Intelligence and Computing, 4th Intl Conf on Big Data Intelligence and Computing and Cyber Science and Technology Congress (DASC/PiCom/DataCom/CyberSciTech). IEEE, 2018.
    [2] Yao, Qihang, et al. "Multi-class Arrhythmia detection from 12-lead varied-length ECG using Attention-based Time-Incremental Convolutional Neural Network." Information Fusion 53 (2020): 174-182.
    [3] Hannun, Awni Y., et al. "Cardiologist-level arrhythmia detection and classification in ambulatory electrocardiograms using a deep neural network." Nature medicine 25.1 (2019): 65.
    [4] https://stanfordmlgroup.github.io/projects/ecg2/
    [5] https://github.com/awni/ecg
    [6] CPSC2018 entry 0236
    [7] CPSC2019 entry 0416
    """
    __DEBUG__ = True
    __name__ = "ECG_CRNN"

    def __init__(self,
                 classes: Sequence[str],
                 n_leads: int,
                 config: Optional[ED] = None) -> None:
        """ finished, checked,

        Parameters
        ----------
        classes: list,
            list of the classes for classification
        n_leads: int,
            number of leads (number of input channels)
        config: dict, optional,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        self.config = deepcopy(ECG_CRNN_CONFIG)
        self.config.update(deepcopy(config) or {})
        if self.__DEBUG__:
            print(
                f"classes (totally {self.n_classes}) for prediction:{self.classes}"
            )
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
            debug_input_len = 4000

        cnn_choice = self.config.cnn.name.lower()
        if "vgg16" in cnn_choice:
            self.cnn = VGG16(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = self.config.cnn.vgg16.num_filters[-1]
        elif "resnet" in cnn_choice:
            self.cnn = ResNet(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = \
            #     2**len(self.config.cnn[cnn_choice].num_blocks) * self.config.cnn[cnn_choice].init_num_filters
        elif "multi_scopic" in cnn_choice:
            self.cnn = MultiScopicCNN(self.n_leads,
                                      **(self.config.cnn[cnn_choice]))
            # rnn_input_size = self.cnn.compute_output_shape(None, None)[1]
        elif "densenet" in cnn_choice or "dense_net" in cnn_choice:
            self.cnn = DenseNet(self.n_leads, **(self.config.cnn[cnn_choice]))
        else:
            raise NotImplementedError(
                f"the CNN \042{cnn_choice}\042 is not implemented yet")
        rnn_input_size = self.cnn.compute_output_shape(None, None)[1]

        if self.__DEBUG__:
            cnn_output_shape = self.cnn.compute_output_shape(
                debug_input_len, None)
            print(
                f"cnn output shape (batch_size, features, seq_len) = {cnn_output_shape}, given input_len = {debug_input_len}"
            )

        if self.config.rnn.name.lower() == "none":
            self.rnn = None
            attn_input_size = rnn_input_size
        elif self.config.rnn.name.lower() == "lstm":
            # hidden_sizes = self.config.rnn.lstm.hidden_sizes + [self.n_classes]
            # if self.__DEBUG__:
            #     print(f"lstm hidden sizes {self.config.rnn.lstm.hidden_sizes} ---> {hidden_sizes}")
            self.rnn = StackedLSTM(
                input_size=rnn_input_size,
                hidden_sizes=self.config.rnn.lstm.hidden_sizes,
                bias=self.config.rnn.lstm.bias,
                dropouts=self.config.rnn.lstm.dropouts,
                bidirectional=self.config.rnn.lstm.bidirectional,
                return_sequences=self.config.rnn.lstm.retseq,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        elif self.config.rnn.name.lower() == "linear":
            self.rnn = SeqLin(
                in_channels=rnn_input_size,
                out_channels=self.config.rnn.linear.out_channels,
                activation=self.config.rnn.linear.activation,
                bias=self.config.rnn.linear.bias,
                dropouts=self.config.rnn.linear.dropouts,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        # attention
        if (self.config.rnn.name.lower() == "lstm"
                and not self.config.rnn.lstm.retseq):
            self.attn = None
            clf_input_size = attn_input_size
            if self.config.attn.name.lower() != "none":
                print(
                    f"since `retseq` of rnn is False, attention `{self.config.attn.name}` is ignored"
                )
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_exitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch NOT tested
            self.attn = SelfAttention(
                in_features=attn_input_size,
                head_num=self.config.attn.sa.head_num,
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        if self.__DEBUG__:
            print(f"clf_input_size = {clf_input_size}")

        if (self.config.rnn.name.lower() == "lstm"
                and not self.config.rnn.lstm.retseq):
            self.pool = None
            if self.config.global_pool.lower() != "none":
                print(
                    f"since `retseq` of rnn is False, global pooling `{self.config.global_pool}` is ignored"
                )
        elif self.config.global_pool.lower() == "max":
            self.pool = nn.AdaptiveMaxPool1d((1, ), return_indices=False)
        elif self.config.global_pool.lower() == "avg":
            self.pool = nn.AdaptiveAvgPool1d((1, ))
        elif self.config.global_pool.lower() == "attn":
            raise NotImplementedError("Attentive pooling not implemented yet!")
        else:
            raise NotImplementedError(
                f"pooling method {self.config.global_pool} not implemented yet!"
            )

        # input of `self.clf` has shape: batch_size, channels
        self.clf = SeqLin(
            in_channels=clf_input_size,
            out_channels=self.config.clf.out_channels + [self.n_classes],
            activation=self.config.clf.activation,
            bias=self.config.clf.bias,
            dropouts=self.config.clf.dropouts,
            skip_last_activation=True,
        )

        # sigmoid for inference
        self.sigmoid = nn.Sigmoid()

    def extract_features(self, input: Tensor) -> Tensor:
        """ finished, checked,

        extract feature map before the dense (linear) classifying layer(s)

        Parameters
        ----------
        input: Tensor,
            of shape (batch_size, channels, seq_len)

        Returns
        -------
        features: Tensor,
            of shape (batch_size, channels, seq_len) or (batch_size, channels)
        """
        # CNN
        features = self.cnn(input)  # batch_size, channels, seq_len
        # print(f"cnn out shape = {features.shape}")

        # RNN (optional)
        if self.config.rnn.name.lower() in ["lstm"]:
            # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
            features = features.permute(2, 0, 1)
            features = self.rnn(
                features
            )  # (seq_len, batch_size, channels) or (batch_size, channels)
        elif self.config.rnn.name.lower() in ["linear"]:
            # (batch_size, channels, seq_len) --> (batch_size, seq_len, channels)
            features = features.permute(0, 2, 1)
            features = self.rnn(features)  # (batch_size, seq_len, channels)
            # (batch_size, seq_len, channels) --> (seq_len, batch_size, channels)
            features = features.permute(1, 0, 2)
        else:
            # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
            features = features.permute(2, 0, 1)

        # Attention (optional)
        if self.attn is None and features.ndim == 3:
            # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
        elif self.config.attn.name.lower() in ["nl", "se", "gc"]:
            # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
            features = self.attn(features)  # (batch_size, channels, seq_len)
        elif self.config.attn.name.lower() in ["sa"]:
            features = self.attn(features)  # (seq_len, batch_size, channels)
            # (seq_len, batch_size, channels) -> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
        return features

    def forward(self, input: Tensor) -> Tensor:
        """ finished, partly checked (rnn part might have bugs),

        Parameters
        ----------
        input: Tensor,
            of shape (batch_size, channels, seq_len)

        Returns
        -------
        pred: Tensor,
            of shape (batch_size, n_classes)
        """
        features = self.extract_features(input)

        if self.pool:
            features = self.pool(features)  # (batch_size, channels, 1)
            features = features.squeeze(dim=-1)
        else:
            # features of shape (batch_size, channels)
            pass

        # print(f"clf in shape = {x.shape}")
        pred = self.clf(features)  # batch_size, n_classes

        return pred

    @torch.no_grad()
    def inference(
        self,
        input: Union[np.ndarray, Tensor],
        class_names: bool = False,
        bin_pred_thr: float = 0.5
    ) -> Tuple[Union[np.ndarray, pd.DataFrame], np.ndarray]:
        """ finished, checked,

        Parameters
        ----------
        input: ndarray or Tensor,
            input tensor, of shape (batch_size, channels, seq_len)
        class_names: bool, default False,
            if True, the returned scalar predictions will be a `DataFrame`,
            with class names for each scalar prediction
        bin_pred_thr: float, default 0.5,
            the threshold for making binary predictions from scalar predictions

        Returns
        -------
        pred: ndarray or DataFrame,
            scalar predictions, (and binary predictions if `class_names` is True)
        bin_pred: ndarray,
            the array (with values 0, 1 for each class) of binary prediction
        """
        raise NotImplementedError(
            "implement a task specific inference method")

    @property
    def module_size(self) -> int:
        """
        """
        return compute_module_size(self)
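`inference` is deliberately left to task-specific subclasses, but its docstring fixes the contract: scalar (sigmoid) predictions plus binary predictions obtained with `bin_pred_thr`. One plausible minimal sketch of that post-processing, with hypothetical logits and not the library's implementation:

import numpy as np
import torch

logits = torch.randn(3, 4)                             # hypothetical model output: 3 records, 4 classes
scalar_pred = torch.sigmoid(logits).cpu().numpy()      # scalar predictions in [0, 1]
bin_pred_thr = 0.5
bin_pred = (scalar_pred >= bin_pred_thr).astype(np.int8)  # multi-label binary predictions
print(scalar_pred.round(2))
print(bin_pred)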