class RR_LSTM(nn.Module):
    """
    classification or sequence labeling using LSTM,
    with RR intervals as input
    """
    __DEBUG__ = True
    __name__ = "RR_LSTM"

    def __init__(self,
                 classes: Sequence[str],
                 n_leads: int,
                 config: Optional[ED] = None) -> NoReturn:
        """ finished, checked,

        Parameters
        ----------
        classes: list,
            list of the classes for classification
        n_leads: int,
            number of leads (number of input channels)
        config: dict, optional,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        self.config = deepcopy(RR_LSTM_CONFIG)
        self.config.update(deepcopy(config) or {})
        if self.__DEBUG__:
            print(
                f"classes (totally {self.n_classes}) for prediction: {self.classes}"
            )
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )

        self.lstm = StackedLSTM(
            input_size=self.n_leads,
            hidden_sizes=self.config.lstm.hidden_sizes,
            bias=self.config.lstm.bias,
            dropouts=self.config.lstm.dropouts,
            bidirectional=self.config.lstm.bidirectional,
            return_sequences=self.config.lstm.retseq,
        )
        if self.__DEBUG__:
            print(f"\042lstm\042 module has size {self.lstm.module_size}")

        attn_input_size = self.lstm.compute_output_shape(None, None)[-1]

        if not self.config.lstm.retseq:
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_excitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch is NOT tested
            self.attn = SelfAttention(
                in_features=attn_input_size,
                head_num=self.config.attn.sa.head_num,
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        if self.__DEBUG__ and self.attn:
            print(
                f"attn module \042{self.config.attn.name}\042 has size {self.attn.module_size}"
            )

        if not self.config.lstm.retseq:
            self.pool = None
            self.clf = None
        elif self.config.clf.name.lower() == "linear":
            if self.config.global_pool.lower() == "max":
                self.pool = nn.AdaptiveMaxPool1d((1,))
            else:
                # no global pooling for other settings
                self.pool = None
            self.clf = SeqLin(
                in_channels=clf_input_size,
                out_channels=self.config.clf.linear.out_channels + [self.n_classes],
                activation=self.config.clf.linear.activation,
                bias=self.config.clf.linear.bias,
                dropouts=self.config.clf.linear.dropouts,
                skip_last_activation=True,
            )
        elif self.config.clf.name.lower() == "crf":
            self.pool = None
            self.clf = ExtendedCRF(
                in_channels=clf_input_size,
                num_tags=self.n_classes,
                bias=self.config.clf.crf.proj_bias,
            )

        if self.__DEBUG__ and self.clf:
            print(
                f"clf module \042{self.config.clf.name}\042 has size {self.clf.module_size}"
            )

        # for inference, except for crf
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        """ finished, checked,

        Parameters
        ----------
        input: Tensor,
            of shape (seq_len, batch_size, n_channels)

        Returns
        -------
        output: Tensor,
            of shape (batch_size, seq_len, n_classes) or (batch_size, n_classes)
        """
        # (seq_len, batch_size, n_channels) or (batch_size, n_channels)
        x = self.lstm(input)
        if self.attn:
            # (seq_len, batch_size, n_channels) --> (batch_size, n_channels, seq_len)
            x = x.permute(1, 2, 0)
            x = self.attn(x)  # (batch_size, n_channels, seq_len)
        if self.pool:
            x = self.pool(x)  # (batch_size, n_channels, 1)
            x = x.squeeze(dim=-1)  # (batch_size, n_channels)
        elif x.ndim == 3:
            # (batch_size, n_channels, seq_len) --> (batch_size, seq_len, n_channels)
            x = x.permute(0, 2, 1)
        else:
            # x of shape (batch_size, n_channels),
            # in the case where config.lstm.retseq is False
            pass
        if self.clf:
            # (batch_size, seq_len, n_classes) or (batch_size, n_classes)
            x = self.clf(x)
        output = x
        return output

    @torch.no_grad()
    def inference(self, input: Tensor, bin_pred_thr: float = 0.5) -> Tensor:
        """
        """
        raise NotImplementedError("implement a task specific inference method")

    def compute_output_shape(
            self,
            seq_len: Optional[int] = None,
            batch_size: Optional[int] = None) -> Sequence[Union[int, None]]:
        """ finished, checked,

        Parameters
        ----------
        seq_len: int, optional,
            length of the 1d sequence,
            if is None, then the input is composed of single feature vectors for each batch
        batch_size: int, optional,
            the batch size, can be None

        Returns
        -------
        output_shape: sequence,
            the output shape of this module, given `seq_len` and `batch_size`
        """
        if self.config.clf.name.lower() == "crf":
            output_shape = (batch_size, seq_len, self.n_classes)
        else:
            # clf is "linear", or lstm.retseq is False
            output_shape = (batch_size, self.n_classes)
        return output_shape

    @property
    def module_size(self) -> int:
        """
        """
        return compute_module_size(self)
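
# A minimal usage sketch for `RR_LSTM` (illustrative only, not part of the
# original module): the class names, sequence length and batch size below are
# made up, and the default `RR_LSTM_CONFIG` is assumed. The sketch is wrapped
# in a hypothetical helper function so that nothing runs at import time.
def _example_rr_lstm_usage() -> None:
    model = RR_LSTM(classes=["normal", "af"], n_leads=1)
    model.eval()
    # RR interval sequences are fed as (seq_len, batch_size, n_channels)
    rr = torch.randn(200, 4, 1)
    output = model(rr)
    # (batch_size, seq_len, n_classes) or (batch_size, n_classes),
    # cf. `RR_LSTM.compute_output_shape`
    print(output.shape)
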
class ECG_CRNN(nn.Module):
    """ finished, continuously improving,

    C(R)NN models modified from the following references

    References:
    -----------
    [1] Yao, Qihang, et al. "Time-Incremental Convolutional Neural Network for Arrhythmia Detection in Varied-Length Electrocardiogram." 2018 IEEE 16th Intl Conf on Dependable, Autonomic and Secure Computing, 16th Intl Conf on Pervasive Intelligence and Computing, 4th Intl Conf on Big Data Intelligence and Computing and Cyber Science and Technology Congress (DASC/PiCom/DataCom/CyberSciTech). IEEE, 2018.
    [2] Yao, Qihang, et al. "Multi-class Arrhythmia detection from 12-lead varied-length ECG using Attention-based Time-Incremental Convolutional Neural Network." Information Fusion 53 (2020): 174-182.
    [3] Hannun, Awni Y., et al. "Cardiologist-level arrhythmia detection and classification in ambulatory electrocardiograms using a deep neural network." Nature Medicine 25.1 (2019): 65.
    [4] https://stanfordmlgroup.github.io/projects/ecg2/
    [5] https://github.com/awni/ecg
    [6] CPSC2018 entry 0236
    [7] CPSC2019 entry 0416
    """
    __DEBUG__ = True
    __name__ = "ECG_CRNN"

    def __init__(self,
                 classes: Sequence[str],
                 n_leads: int,
                 config: Optional[ED] = None) -> NoReturn:
        """ finished, checked,

        Parameters:
        -----------
        classes: list,
            list of the classes for classification
        n_leads: int,
            number of leads (number of input channels)
        config: dict, optional,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)
        self.n_leads = n_leads
        self.config = deepcopy(ECG_CRNN_CONFIG)
        self.config.update(deepcopy(config) or {})
        if self.__DEBUG__:
            print(
                f"classes (totally {self.n_classes}) for prediction: {self.classes}"
            )
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
        debug_input_len = 4000

        cnn_choice = self.config.cnn.name.lower()
        if "vgg16" in cnn_choice:
            self.cnn = VGG16(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = self.config.cnn.vgg16.num_filters[-1]
        elif "resnet" in cnn_choice:
            self.cnn = ResNet(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = \
            #     2**len(self.config.cnn[cnn_choice].num_blocks) * self.config.cnn[cnn_choice].init_num_filters
        elif "multi_scopic" in cnn_choice:
            self.cnn = MultiScopicCNN(self.n_leads, **(self.config.cnn[cnn_choice]))
            # rnn_input_size = self.cnn.compute_output_shape(None, None)[1]
        elif "densenet" in cnn_choice or "dense_net" in cnn_choice:
            self.cnn = DenseNet(self.n_leads, **(self.config.cnn[cnn_choice]))
        else:
            raise NotImplementedError(
                f"the CNN \042{cnn_choice}\042 not implemented yet")
        rnn_input_size = self.cnn.compute_output_shape(None, None)[1]

        if self.__DEBUG__:
            cnn_output_shape = self.cnn.compute_output_shape(debug_input_len, None)
            print(
                f"cnn output shape (batch_size, features, seq_len) = {cnn_output_shape}, "
                f"given input_len = {debug_input_len}"
            )

        if self.config.rnn.name.lower() == "none":
            self.rnn = None
            attn_input_size = rnn_input_size
        elif self.config.rnn.name.lower() == "lstm":
            # hidden_sizes = self.config.rnn.lstm.hidden_sizes + [self.n_classes]
            # if self.__DEBUG__:
            #     print(f"lstm hidden sizes {self.config.rnn.lstm.hidden_sizes} ---> {hidden_sizes}")
            self.rnn = StackedLSTM(
                input_size=rnn_input_size,
                hidden_sizes=self.config.rnn.lstm.hidden_sizes,
                bias=self.config.rnn.lstm.bias,
                dropouts=self.config.rnn.lstm.dropouts,
                bidirectional=self.config.rnn.lstm.bidirectional,
                return_sequences=self.config.rnn.lstm.retseq,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        elif self.config.rnn.name.lower() == "linear":
            self.rnn = SeqLin(
                in_channels=rnn_input_size,
                out_channels=self.config.rnn.linear.out_channels,
                activation=self.config.rnn.linear.activation,
                bias=self.config.rnn.linear.bias,
                dropouts=self.config.rnn.linear.dropouts,
            )
            attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        # attention
        if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
            self.attn = None
            clf_input_size = attn_input_size
            if self.config.attn.name.lower() != "none":
                print(
                    f"since `retseq` of rnn is False, "
                    f"the attention module `{self.config.attn.name}` is ignored"
                )
        elif self.config.attn.name.lower() == "none":
            self.attn = None
            clf_input_size = attn_input_size
        elif self.config.attn.name.lower() == "nl":  # non_local
            self.attn = NonLocalBlock(
                in_channels=attn_input_size,
                filter_lengths=self.config.attn.nl.filter_lengths,
                subsample_length=self.config.attn.nl.subsample_length,
                batch_norm=self.config.attn.nl.batch_norm,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "se":  # squeeze_excitation
            self.attn = SEBlock(
                in_channels=attn_input_size,
                reduction=self.config.attn.se.reduction,
                activation=self.config.attn.se.activation,
                kw_activation=self.config.attn.se.kw_activation,
                bias=self.config.attn.se.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "gc":  # global_context
            self.attn = GlobalContextBlock(
                in_channels=attn_input_size,
                ratio=self.config.attn.gc.ratio,
                reduction=self.config.attn.gc.reduction,
                pooling_type=self.config.attn.gc.pooling_type,
                fusion_types=self.config.attn.gc.fusion_types,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[1]
        elif self.config.attn.name.lower() == "sa":  # self_attention
            # NOTE: this branch is NOT tested
            self.attn = SelfAttention(
                in_features=attn_input_size,
                head_num=self.config.attn.sa.head_num,
                dropout=self.config.attn.sa.dropout,
                bias=self.config.attn.sa.bias,
            )
            clf_input_size = self.attn.compute_output_shape(None, None)[-1]
        else:
            raise NotImplementedError

        if self.__DEBUG__:
            print(f"clf_input_size = {clf_input_size}")

        if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
            self.pool = None
            if self.config.global_pool.lower() != "none":
                print(
                    f"since `retseq` of rnn is False, "
                    f"global pooling `{self.config.global_pool}` is ignored"
                )
        elif self.config.global_pool.lower() == "max":
            self.pool = nn.AdaptiveMaxPool1d((1,), return_indices=False)
        elif self.config.global_pool.lower() == "avg":
            self.pool = nn.AdaptiveAvgPool1d((1,))
        elif self.config.global_pool.lower() == "attn":
            raise NotImplementedError("attentive pooling not implemented yet!")
        else:
            raise NotImplementedError(
                f"pooling method {self.config.global_pool} not implemented yet!"
            )

        # input of `self.clf` has shape (batch_size, channels)
        self.clf = SeqLin(
            in_channels=clf_input_size,
            out_channels=self.config.clf.out_channels + [self.n_classes],
            activation=self.config.clf.activation,
            bias=self.config.clf.bias,
            dropouts=self.config.clf.dropouts,
            skip_last_activation=True,
        )

        # sigmoid for making inference
        self.sigmoid = nn.Sigmoid()

    def extract_features(self, input: Tensor) -> Tensor:
        """ finished, checked,

        extract the feature map before the dense (linear) classifying layer(s)

        Parameters:
        -----------
        input: Tensor,
            of shape (batch_size, channels, seq_len)

        Returns:
        --------
        features: Tensor,
            of shape (batch_size, channels, seq_len) or (batch_size, channels)
        """
        # CNN
        features = self.cnn(input)  # (batch_size, channels, seq_len)
        # print(f"cnn out shape = {features.shape}")

        # RNN (optional)
        if self.config.rnn.name.lower() in ["lstm"]:
            # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
            features = features.permute(2, 0, 1)
            # (seq_len, batch_size, channels) or (batch_size, channels)
            features = self.rnn(features)
        elif self.config.rnn.name.lower() in ["linear"]:
            # (batch_size, channels, seq_len) --> (batch_size, seq_len, channels)
            features = features.permute(0, 2, 1)
            features = self.rnn(features)  # (batch_size, seq_len, channels)
            # (batch_size, seq_len, channels) --> (seq_len, batch_size, channels)
            features = features.permute(1, 0, 2)
        else:
            # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
            features = features.permute(2, 0, 1)

        # attention (optional)
        if self.attn is None and features.ndim == 3:
            # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
        elif self.config.attn.name.lower() in ["nl", "se", "gc"]:
            # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
            features = self.attn(features)  # (batch_size, channels, seq_len)
        elif self.config.attn.name.lower() in ["sa"]:
            features = self.attn(features)  # (seq_len, batch_size, channels)
            # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
            features = features.permute(1, 2, 0)
        return features

    def forward(self, input: Tensor) -> Tensor:
        """ finished, partly checked (the rnn part might have bugs),

        Parameters:
        -----------
        input: Tensor,
            of shape (batch_size, channels, seq_len)

        Returns:
        --------
        pred: Tensor,
            of shape (batch_size, n_classes)
        """
        features = self.extract_features(input)

        if self.pool:
            features = self.pool(features)  # (batch_size, channels, 1)
            features = features.squeeze(dim=-1)  # (batch_size, channels)
        else:
            # features already of shape (batch_size, channels)
            pass

        # print(f"clf in shape = {features.shape}")
        pred = self.clf(features)  # (batch_size, n_classes)
        return pred

    @torch.no_grad()
    def inference(
            self,
            input: Union[np.ndarray, Tensor],
            class_names: bool = False,
            bin_pred_thr: float = 0.5
    ) -> Tuple[Union[np.ndarray, pd.DataFrame], np.ndarray]:
        """ finished, checked,

        Parameters:
        -----------
        input: ndarray or Tensor,
            input tensor, of shape (batch_size, channels, seq_len)
        class_names: bool, default False,
            if True, the returned scalar predictions will be a `DataFrame`,
            with class names for each scalar prediction
        bin_pred_thr: float, default 0.5,
            the threshold for making binary predictions from scalar predictions

        Returns:
        --------
        pred: ndarray or DataFrame,
            scalar predictions (and binary predictions if `class_names` is True)
        bin_pred: ndarray,
            the array (with values 0, 1 for each class) of binary predictions
        """
        raise NotImplementedError("implement a task specific inference method")

    @property
    def module_size(self) -> int:
        """
        """
        return compute_module_size(self)
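
# A minimal usage sketch for `ECG_CRNN` (illustrative only, not part of the
# original module): the class names, number of leads and input length below
# are made up, and the default `ECG_CRNN_CONFIG` is assumed. The sketch is
# wrapped in a hypothetical helper function so that nothing runs at import time.
def _example_ecg_crnn_usage() -> None:
    model = ECG_CRNN(classes=["AF", "PVC", "RBBB"], n_leads=12)
    model.eval()
    # ECG signals are fed as (batch_size, channels, seq_len)
    sig = torch.randn(2, 12, 4000)
    logits = model(sig)  # (batch_size, n_classes)
    # scalar (multi-label) probabilities for inference
    probs = model.sigmoid(logits)
    print(logits.shape, probs.shape)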