Example #1
# imports assumed by this example: `ED` is `easydict.EasyDict`, while
# `DoubleConv`, `DownDoubleConv`, `UpDoubleConv`, `Conv_Bn_Activation` and
# `dict_to_str` are building blocks defined elsewhere in the project
from copy import deepcopy
from typing import Optional, Sequence

import numpy as np
import torch.nn as nn
from torch import Tensor
from easydict import EasyDict as ED


class ECG_UNET(nn.Module):
    """
    """
    __DEBUG__ = True
    __name__ = "ECG_UNET"

    def __init__(self, classes: Sequence[str], n_leads: int,
                 config: dict) -> None:
        """ finished, checked,

        Parameters:
        -----------
        classes: sequence of str,
            names of the classes
        n_leads: int,
            number of input leads
        config: dict,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)  # final out_channels
        self.__in_channels = n_leads
        self.config = ED(deepcopy(config))
        if self.__DEBUG__:
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
            __debug_seq_len = 4000

        self.init_conv = DoubleConv(
            in_channels=self.__in_channels,
            out_channels=self.n_classes,
            filter_lengths=self.config.init_filter_length,
            subsample_lengths=1,
            groups=self.config.groups,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.init_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, init_conv output shape = {__debug_output_shape}"
            )
            _, _, __debug_seq_len = __debug_output_shape

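        # encoder (contracting path): a cascade of down-sampling DoubleConv blocks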
        self.down_blocks = nn.ModuleDict()
        in_channels = self.n_classes
        for idx in range(self.config.down_up_block_num):
            self.down_blocks[f'down_{idx}'] = \
                DownDoubleConv(
                    down_scale=self.config.down_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.down_num_filters[idx],
                    filter_lengths=self.config.down_filter_lengths[idx],
                    groups=self.config.groups,
                    mode=self.config.down_mode,
                    **(self.config.down_block)
                )
            in_channels = self.config.down_num_filters[idx]
            if self.__DEBUG__:
                __debug_output_shape = self.down_blocks[
                    f'down_{idx}'].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, down_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

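        # decoder (expansive path): up-sampling blocks, each concatenating the
        # matching skip connection from the encoder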
        self.up_blocks = nn.ModuleDict()
        in_channels = self.config.down_num_filters[-1]
        for idx in range(self.config.down_up_block_num):
            self.up_blocks[f'up_{idx}'] = \
                UpDoubleConv(
                    up_scale=self.config.up_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.up_num_filters[idx],
                    filter_lengths=self.config.up_conv_filter_lengths[idx],
                    deconv_filter_length=self.config.up_deconv_filter_lengths[idx],
                    groups=self.config.groups,
                    mode=self.config.up_mode,
                    **(self.config.up_block)
                )
            in_channels = self.config.up_num_filters[idx]
            if self.__DEBUG__:
                __debug_output_shape = self.up_blocks[
                    f'up_{idx}'].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, up_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

        self.out_conv = Conv_Bn_Activation(
            in_channels=self.config.up_num_filters[-1],
            out_channels=self.n_classes,
            kernel_size=self.config.out_filter_length,
            stride=1,
            groups=self.config.groups,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.out_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, out_conv output shape = {__debug_output_shape}"
            )

    def forward(self, input: Tensor) -> Tensor:
        """
        Parameters:
        -----------
        input: Tensor,
            of shape (batch_size, channels, seq_len)
        """
        to_concat = [self.init_conv(input)]
        if self.__DEBUG__:
            print(f"shape of init conv block output = {to_concat[-1].shape}")
        for idx in range(self.config.down_up_block_num):
            to_concat.append(self.down_blocks[f"down_{idx}"](to_concat[-1]))
            if self.__DEBUG__:
                print(
                    f"shape of {idx}-th down block output = {to_concat[-1].shape}"
                )
        up_input = to_concat[-1]
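        # reverse the stored encoder outputs (dropping the deepest, already held
        # in `up_input`) so that to_concat[idx] is the skip connection for the idx-th up block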
        to_concat = to_concat[-2::-1]
        for idx in range(self.config.down_up_block_num):
            up_output = self.up_blocks[f"up_{idx}"](up_input, to_concat[idx])
            up_input = up_output
            if self.__DEBUG__:
                print(f"shape of {idx}-th up block output = {up_output.shape}")
        output = self.out_conv(up_output)
        if self.__DEBUG__:
            print(f"shape of out_conv layer output = {output.shape}")

        return output

    def inference(self, input: Tensor) -> Tensor:
        """
        """
        raise NotImplementedError

    def compute_output_shape(self,
                             seq_len: int,
                             batch_size: Optional[int] = None
                             ) -> Sequence[Optional[int]]:
        """ finished, checked,

        Parameters:
        -----------
        seq_len: int,
            length of the 1d sequence
        batch_size: int, optional,
            the batch size, can be None

        Returns:
        --------
        output_shape: sequence,
            the output shape of this `ECG_UNET` layer, given `seq_len` and `batch_size`
        """
        output_shape = (batch_size, self.n_classes, seq_len)
        return output_shape

    @property
    def module_size(self) -> int:
        """
        """
        module_parameters = filter(lambda p: p.requires_grad,
                                   self.parameters())
        n_params = int(sum(np.prod(p.size()) for p in module_parameters))
        return n_params
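
A minimal usage sketch for the class above. All config values here are illustrative assumptions (the real values live in the project's config file), and the building blocks (`DoubleConv` etc.) must be importable from the project:

if __name__ == "__main__":
    import torch

    # hypothetical config; each key mirrors an attribute read in `__init__`
    unet_config = ED(
        init_filter_length=9,
        out_filter_length=1,
        groups=1,
        batch_norm=True,
        activation="relu",
        kw_activation={},
        kernel_initializer="he_normal",
        kw_initializer={},
        down_up_block_num=2,
        down_scales=[2, 2],
        down_num_filters=[8, 16],
        down_filter_lengths=[9, 9],
        down_mode="max",
        down_block={},  # extra kwargs for DownDoubleConv, signature-dependent
        up_scales=[2, 2],
        up_num_filters=[8, 4],
        up_conv_filter_lengths=[9, 9],
        up_deconv_filter_lengths=[9, 9],
        up_mode="deconv",
        up_block={},    # extra kwargs for UpDoubleConv, signature-dependent
    )
    model = ECG_UNET(classes=["p", "qrs", "t"], n_leads=12, config=unet_config)
    x = torch.randn(2, 12, 4000)   # (batch_size, n_leads, seq_len)
    y = model(x)                   # expected shape: (2, 3, 4000)
    print(y.shape, model.compute_output_shape(seq_len=4000, batch_size=2))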
Example #2
# imports assumed by this example: `ED` is `easydict.EasyDict`, while
# `TripleConv`, `DownTripleConv`, `DownBranchedDoubleConv`, `UpTripleConv`,
# `Conv_Bn_Activation` and `dict_to_str` are building blocks defined elsewhere in the project
from copy import deepcopy
from typing import Optional, Sequence

import torch.nn as nn
from torch import Tensor
from easydict import EasyDict as ED


class ECG_SUBTRACT_UNET(nn.Module):
    """
    """
    __DEBUG__ = True
    __name__ = "ECG_SUBTRACT_UNET"

    def __init__(self, classes: Sequence[str], n_leads: int,
                 config: dict) -> None:
        """ finished, checked,

        Parameters:
        -----------
        classes: sequence of str,
            names of the classes
        n_leads: int,
            number of input leads
        config: dict,
            other hyper-parameters, including kernel sizes, etc.
            ref. the corresponding config file
        """
        super().__init__()
        self.classes = list(classes)
        self.n_classes = len(classes)  # final out_channels if > 2
        self.__out_channels = len(classes) if len(classes) > 2 else 1
        self.__in_channels = n_leads
        self.config = ED(deepcopy(config))
        if self.__DEBUG__:
            print(
                f"configuration of {self.__name__} is as follows\n{dict_to_str(self.config)}"
            )
            __debug_seq_len = 5000

        # TODO: an init batch normalization?
        if self.config.init_batch_norm:
            self.init_bn = nn.BatchNorm1d(
                num_features=self.__in_channels,
                eps=1e-5,  # default val
                momentum=0.1,  # default val
            )

        self.init_conv = TripleConv(
            in_channels=self.__in_channels,
            out_channels=self.config.init_num_filters,
            filter_lengths=self.config.init_filter_length,
            subsample_lengths=1,
            groups=self.config.groups,
            dropouts=self.config.init_dropouts,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.init_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, init_conv output shape = {__debug_output_shape}"
            )
            _, _, __debug_seq_len = __debug_output_shape

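        # encoder: `down_up_block_num - 1` down-sampling TripleConv blocks;
        # the last down-sampling step is performed by the bottom block below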
        self.down_blocks = nn.ModuleDict()
        in_channels = self.config.init_num_filters
        for idx in range(self.config.down_up_block_num - 1):
            self.down_blocks[f'down_{idx}'] = \
                DownTripleConv(
                    down_scale=self.config.down_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.down_num_filters[idx],
                    filter_lengths=self.config.down_filter_lengths[idx],
                    groups=self.config.groups,
                    dropouts=self.config.down_dropouts[idx],
                    mode=self.config.down_mode,
                    **(self.config.down_block)
                )
            in_channels = self.config.down_num_filters[idx][-1]
            if self.__DEBUG__:
                __debug_output_shape = self.down_blocks[
                    f'down_{idx}'].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, down_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

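        # bottom (bridge) block: a down-sampling double-conv whose parallel
        # branches use different dilations, with their outputs concatenated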
        self.bottom_block = DownBranchedDoubleConv(
            down_scale=self.config.down_scales[-1],
            in_channels=in_channels,
            out_channels=self.config.bottom_num_filters,
            filter_lengths=self.config.bottom_filter_lengths,
            dilations=self.config.bottom_dilations,
            groups=self.config.groups,
            dropouts=self.config.bottom_dropouts,
            mode=self.config.down_mode,
            **(self.config.down_block))
        if self.__DEBUG__:
            __debug_output_shape = self.bottom_block.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, bottom_block output shape = {__debug_output_shape}"
            )
            _, _, __debug_seq_len = __debug_output_shape

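        # decoder: up-sampling TripleConv blocks; the first consumes the
        # concatenated branch outputs of the bottom block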
        self.up_blocks = nn.ModuleDict()
        in_channels = sum(
            [branch[-1] for branch in self.config.bottom_num_filters])
        for idx in range(self.config.down_up_block_num):
            self.up_blocks[f'up_{idx}'] = \
                UpTripleConv(
                    up_scale=self.config.up_scales[idx],
                    in_channels=in_channels,
                    out_channels=self.config.up_num_filters[idx],
                    filter_lengths=self.config.up_conv_filter_lengths[idx],
                    deconv_filter_length=self.config.up_deconv_filter_lengths[idx],
                    groups=self.config.groups,
                    mode=self.config.up_mode,
                    dropouts=self.config.up_dropouts[idx],
                    **(self.config.up_block)
                )
            in_channels = self.config.up_num_filters[idx][-1]
            if self.__DEBUG__:
                __debug_output_shape = self.up_blocks[
                    f'up_{idx}'].compute_output_shape(__debug_seq_len)
                print(
                    f"given seq_len = {__debug_seq_len}, up_{idx} output shape = {__debug_output_shape}"
                )
                _, _, __debug_seq_len = __debug_output_shape

        self.out_conv = Conv_Bn_Activation(
            in_channels=self.config.up_num_filters[-1][-1],
            out_channels=self.__out_channels,
            kernel_size=self.config.out_filter_length,
            stride=1,
            groups=self.config.groups,
            batch_norm=self.config.batch_norm,
            activation=self.config.activation,
            kw_activation=self.config.kw_activation,
            kernel_initializer=self.config.kernel_initializer,
            kw_initializer=self.config.kw_initializer,
        )
        if self.__DEBUG__:
            __debug_output_shape = self.out_conv.compute_output_shape(
                __debug_seq_len)
            print(
                f"given seq_len = {__debug_seq_len}, out_conv output shape = {__debug_output_shape}"
            )

    def forward(self, input: Tensor) -> Tensor:
        """
        """
        if self.config.init_batch_norm:
            x = self.init_bn(input)
        else:
            x = input

        # down
        to_concat = [self.init_conv(x)]
        if self.__DEBUG__:
            print(f"shape of init conv block output = {to_concat[-1].shape}")
        for idx in range(self.config.down_up_block_num - 1):
            to_concat.append(self.down_blocks[f"down_{idx}"](to_concat[-1]))
            if self.__DEBUG__:
                print(
                    f"shape of {idx}-th down block output = {to_concat[-1].shape}"
                )
        to_concat.append(self.bottom_block(to_concat[-1]))

        # up
        up_input = to_concat[-1]
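        # reverse the stored encoder outputs (dropping the bottom output, already
        # held in `up_input`) so that to_concat[idx] matches the idx-th up block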
        to_concat = to_concat[-2::-1]
        for idx in range(self.config.down_up_block_num):
            up_output = self.up_blocks[f"up_{idx}"](up_input, to_concat[idx])
            up_input = up_output
            if self.__DEBUG__:
                print(f"shape of {idx}-th up block output = {up_output.shape}")

        # output
        output = self.out_conv(up_output)
        if self.__DEBUG__:
            print(f"shape of out_conv layer output = {output.shape}")

        return output

    def inference(self, input: Tensor) -> Tensor:
        """
        """
        raise NotImplementedError

    def compute_output_shape(self,
                             seq_len: int,
                             batch_size: Optional[int] = None
                             ) -> Sequence[Optional[int]]:
        """ finished, NOT checked,

        Parameters:
        -----------
        seq_len: int,
            length of the 1d sequence
        batch_size: int, optional,
            the batch size, can be None

        Returns:
        --------
        output_shape: sequence,
            the output shape of this `ECG_SUBTRACT_UNET` layer, given `seq_len` and `batch_size`
        """
        # `__out_channels` rather than `n_classes`: `forward` emits a single channel in the binary case
        output_shape = (batch_size, self.__out_channels, seq_len)
        return output_shape
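
One behavioral difference from ECG_UNET worth illustrating is the output width, which collapses to a single channel for binary segmentation. A minimal, self-contained sketch of the rule from `__init__` (class names here are purely illustrative):

# the rule `len(classes) if len(classes) > 2 else 1`, in isolation
for classes in (["qrs", "background"], ["p", "qrs", "t"]):
    out_channels = len(classes) if len(classes) > 2 else 1
    print(f"{classes} -> {out_channels} output channel(s)")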