class SEmodule(torch.nn.Module):
    """Squeeze-and-excitation (SE) module.

    The input goes through a pointwise (kernel size 1) depthwise-separable
    convolution, a normalization layer and an activation. That result is
    average-pooled with ``AdaptivePool(1)``, passed through a two-layer
    bottleneck network, repeated along the second axis and finally used to
    rescale the convolution output element-wise.

    Arguments
    ---------
    input_shape : tuple
        Expected input shape, unpacked as ``(batch, time, channels)``.
    inner_dim : int
        Inner dimension of the bottleneck network of the SE module.
    activation : torch class
        Activation class; a fresh instance is created after the convolution
        and after each bottleneck ``Linear`` layer
        (default torch.nn.Sigmoid).
    norm : torch class
        Normalization layer applied after the convolution
        (default BatchNorm1d).

    Example
    -------
    >>> inp = torch.randn([8, 120, 40])
    >>> net = SEmodule(input_shape=inp.shape, inner_dim=64)
    >>> out = net(inp)
    >>> out.shape
    torch.Size([8, 120, 40])
    """

    def __init__(
        self,
        input_shape,
        inner_dim,
        activation=torch.nn.Sigmoid,
        norm=BatchNorm1d,
    ):
        super().__init__()
        self.inner_dim = inner_dim
        self.norm = norm
        self.activation = activation

        # Unpacking also requires the shape to have exactly three entries.
        _, _, channels = input_shape

        # Pointwise convolution -> normalization -> activation.
        self.conv = Sequential(input_shape=input_shape)
        self.conv.append(
            DepthwiseSeparableConv1d,
            out_channels=channels,
            kernel_size=1,
            stride=1,
        )
        self.conv.append(norm)
        self.conv.append(activation())

        # Squeeze step: pool down to a single position.
        self.avg_pool = AdaptivePool(1)

        # Excitation step: channels -> inner_dim -> channels.
        self.bottleneck = Sequential(
            Linear(input_size=channels, n_neurons=inner_dim),
            activation(),
            Linear(input_size=inner_dim, n_neurons=channels),
            activation(),
        )

    def forward(self, x):
        """Rescale the convolved input by its squeeze-and-excitation gate.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor, unpacked as ``(batch, time, channels)``.
        """
        _, frames, _ = x.shape
        feats = self.conv(x)
        gate = self.bottleneck(self.avg_pool(feats))
        # Tile the pooled gate back over every time step before scaling.
        return feats * gate.repeat(1, frames, 1)
class ContextNetBlock(torch.nn.Module):
    """A block of ContextNet.

    The block stacks ``num_layers`` depthwise-separable 1D convolutions (only
    the last one uses ``stride``), a normalization layer and a
    squeeze-and-excitation module. When ``residual`` is True, a strided
    ``Conv1d`` + normalization branch computed from the block input is added
    to the SE output. The result then goes through the activation and
    dropout. All parameters with more than one dimension are re-initialized
    with Kaiming-normal initialization.

    Arguments
    ---------
    out_channels : int
        Number of output channels of this block.
    kernel_size : int
        Kernel size of the depthwise-separable convolution layers.
    num_layers : int
        Number of depthwise-separable convolution layers in this block.
    inner_dim : int
        Inner dimension of the bottleneck network of the SE module.
    input_shape : tuple
        Expected shape of the input, e.g. ``(batch, time, channels)`` as in
        the example below.
    stride : int
        Stride of the last convolution layer and of the residual branch
        (default 1).
    beta : float
        Beta passed to the activation when it is ``Swish`` or a subclass of
        it (default 1).
    dropout : float
        Dropout probability applied to the block output (default 0.15).
    activation : torch class
        Activation applied after the residual sum (default Swish).
    se_activation : torch class
        Activation function for the SE module (default torch.nn.Sigmoid).
    norm : torch class
        Normalization layer used in the convolution stack, the SE module and
        the residual branch (default BatchNorm1d).
    residual : bool
        Whether to add the residual branch (default True).

    Example
    -------
    >>> inp = torch.randn([8, 120, 40])
    >>> block = ContextNetBlock(256, 3, 5, 12, input_shape=inp.shape, stride=2)
    >>> out = block(inp)
    >>> out.shape
    torch.Size([8, 60, 256])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        num_layers,
        inner_dim,
        input_shape,
        stride=1,
        beta=1,
        dropout=0.15,
        activation=Swish,
        se_activation=torch.nn.Sigmoid,
        norm=BatchNorm1d,
        residual=True,
    ):
        super().__init__()
        self.residual = residual

        self.Convs = Sequential(input_shape=input_shape)
        for i in range(num_layers):
            self.Convs.append(
                DepthwiseSeparableConv1d,
                out_channels,
                kernel_size,
                # Only the last layer of the stack is strided.
                stride=stride if i == num_layers - 1 else 1,
            )
        self.Convs.append(norm)

        self.SE = SEmodule(
            input_shape=self.Convs.get_output_shape(),
            inner_dim=inner_dim,
            activation=se_activation,
            norm=norm,
        )
        self.drop = Dropout(dropout)

        self.reduced_cov = None
        if residual:
            # Residual branch uses the same stride so its output length
            # matches the convolution stack.
            self.reduced_cov = Sequential(input_shape=input_shape)
            self.reduced_cov.append(
                Conv1d,
                out_channels,
                kernel_size=3,
                stride=stride,
            )
            self.reduced_cov.append(norm)

        # ``activation`` is a class, not an instance, so it must be checked
        # with ``issubclass`` (``isinstance`` would always be False and
        # ``beta`` would be ignored). The ``isinstance(..., type)`` guard keeps
        # non-class factories (e.g. functools.partial) working.
        if isinstance(activation, type) and issubclass(activation, Swish):
            self.activation = activation(beta)
        else:
            self.activation = activation()

        self._reset_params()

    def forward(self, x):
        """Apply convolutions, SE, optional residual, activation and dropout.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor matching ``input_shape`` (apart from the batch size).
        """
        out = self.SE(self.Convs(x))
        if self.reduced_cov is not None:
            out = out + self.reduced_cov(x)
        out = self.activation(out)
        return self.drop(out)

    def _reset_params(self):
        """Kaiming-normal initialization of every parameter with dim > 1."""
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.kaiming_normal_(p)
class ConvBlock(torch.nn.Module):
    """A stack of convolution layers with an optional residual connection.

    Each of the ``num_layers`` layers is ``conv_module`` followed by an
    optional normalization, an activation and dropout. Only the last layer
    uses ``stride``. When ``residual`` is True, a kernel-size-1
    ``conv_module`` (plus ``norm``, if given) computed from the block input
    is added to the stack output, and the sum is passed through dropout.

    Arguments
    ---------
    num_layers : int
        Number of convolution layers in this block.
    out_channels : int
        Number of output channels of every convolution layer.
    input_shape : tuple
        Expected shape of the input.
    kernel_size : int
        Kernel size of the convolution layers (default 3).
    stride : int
        Stride of the last convolution layer and of the residual
        convolution (default 1).
    dilation : int
        Dilation of the convolution layers (default 1).
    residual : bool
        Whether to add a residual connection (default False).
    conv_module : torch class
        Convolution class used for every layer, e.g. a 1d or 2d
        convolution (default Conv2d).
    activation : torch class
        Activation class instantiated after every convolution layer
        (default torch.nn.LeakyReLU).
    norm : torch class
        Optional normalization class applied after every convolution
        (default None, i.e. no normalization).
    dropout : float
        Dropout probability used after every layer and, when ``residual`` is
        True, after the residual sum (default 0.1).

    Example
    -------
    >>> x = torch.rand((8, 30, 10))
    >>> conv = ConvBlock(2, 16, input_shape=x.shape)
    >>> out = conv(x)
    >>> x.shape
    torch.Size([8, 30, 10])
    """

    def __init__(
        self,
        num_layers,
        out_channels,
        input_shape,
        kernel_size=3,
        stride=1,
        dilation=1,
        residual=False,
        conv_module=Conv2d,
        activation=torch.nn.LeakyReLU,
        norm=None,
        dropout=0.1,
    ):
        super().__init__()
        self.convs = Sequential(input_shape=input_shape)
        for i in range(num_layers):
            self.convs.append(
                conv_module,
                out_channels=out_channels,
                kernel_size=kernel_size,
                # Only the last layer of the stack is strided.
                stride=stride if i == num_layers - 1 else 1,
                dilation=dilation,
                layer_name=f"conv_{i}",
            )
            if norm is not None:
                self.convs.append(norm, layer_name=f"norm_{i}")
            self.convs.append(activation(), layer_name=f"act_{i}")
            self.convs.append(
                torch.nn.Dropout(dropout), layer_name=f"dropout_{i}"
            )

        self.reduce_conv = None
        self.drop = None
        if residual:
            self.reduce_conv = Sequential(input_shape=input_shape)
            self.reduce_conv.append(
                conv_module,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                layer_name="conv",
            )
            # ``norm`` is optional (default None), exactly as in the main
            # stack; appending None would break the Sequential container.
            if norm is not None:
                self.reduce_conv.append(norm, layer_name="norm")
            self.drop = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Apply the convolution stack and, if enabled, the residual branch.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor matching ``input_shape`` (apart from the batch size).
        """
        out = self.convs(x)
        if self.reduce_conv is not None:
            # Dropout acts on the residual sum; ``self.drop`` only exists when
            # the residual branch was built.
            out = out + self.reduce_conv(x)
            out = self.drop(out)
        return out