class InceptionCUnit(nn.Module):
    """
    InceptionResNetV1 type Inception-C unit.

    Residual unit on a 1792-channel input. Two parallel branches are built:
    a 1x1 branch with 192 outputs, and a 1x1 -> (1, 3) -> (3, 1) sequence
    with 192 outputs per stage. Their outputs are combined by ``Concurrent``
    (presumably channel concatenation, since the projection expects
    192 + 192 = 384 channels). The result is projected back to 1792 channels
    by a biased 1x1 conv, scaled by ``scale`` and added to the input. A ReLU
    is optionally applied to the sum.

    Parameters:
    ----------
    bn_eps : float
        Small float added to variance in Batch norm.
    scale : float, default 0.2
        Scale value for residual branch.
    activate : bool, default True
        Whether to apply ReLU after the residual addition.
    """
    def __init__(self,
                 bn_eps,
                 scale=0.2,
                 activate=True):
        super(InceptionCUnit, self).__init__()
        self.activate = activate
        self.scale = scale
        in_channels = 1792

        self.branches = Concurrent()
        self.branches.add_module("branch1", Conv1x1Branch(
            in_channels=in_channels,
            out_channels=192,
            bn_eps=bn_eps))
        # Factorized 3x3: the (1, 3) and (3, 1) kernels are padded so the
        # spatial size is kept (stride 1 throughout).
        self.branches.add_module("branch2", ConvSeqBranch(
            in_channels=in_channels,
            out_channels_list=(192, 192, 192),
            kernel_size_list=(1, (1, 3), (3, 1)),
            strides_list=(1, 1, 1),
            padding_list=(0, (0, 1), (1, 0)),
            bn_eps=bn_eps))
        # 384 = 192 (branch1) + 192 (branch2); project back to the input
        # width so the residual addition lines up.
        self.conv = conv1x1(
            in_channels=384,
            out_channels=in_channels,
            bias=True)
        # Only created when needed; forward() checks the same flag.
        if self.activate:
            self.activ = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.branches(x)
        x = self.conv(x)
        x = self.scale * x + identity
        if self.activate:
            x = self.activ(x)
        return x
class InceptionBUnit(nn.Module):
    """
    InceptionResNetV1 type Inception-B unit.

    Residual unit on an 896-channel input. A 1x1 branch (128 outputs) and a
    1x1 -> (1, 7) -> (7, 1) branch (128 outputs per stage) run in parallel
    inside ``Concurrent``. The projection takes 128 + 128 = 256 channels, so
    the branch outputs are presumably concatenated. A biased 1x1 conv maps
    them back to 896 channels. The result is scaled by 0.1, added to the
    input and passed through ReLU.

    Parameters:
    ----------
    bn_eps : float
        Small float added to variance in Batch norm.
    """
    def __init__(self, bn_eps):
        super(InceptionBUnit, self).__init__()
        self.scale = 0.10
        width = 896

        # Sub-modules are created in the same order as before (container,
        # branch1, branch2, projection, ReLU) so parameter init and
        # state_dict ordering are unchanged.
        self.branches = Concurrent()
        self.branches.add_module(
            "branch1",
            Conv1x1Branch(in_channels=width, out_channels=128, bn_eps=bn_eps))
        seven_factorized = ConvSeqBranch(
            in_channels=width,
            out_channels_list=(128, 128, 128),
            kernel_size_list=(1, (1, 7), (7, 1)),
            strides_list=(1, 1, 1),
            padding_list=(0, (0, 3), (3, 0)),
            bn_eps=bn_eps)
        self.branches.add_module("branch2", seven_factorized)

        self.conv = conv1x1(in_channels=256, out_channels=width, bias=True)
        self.activ = nn.ReLU(inplace=True)

    def forward(self, x):
        branch_out = self.conv(self.branches(x))
        return self.activ(self.scale * branch_out + x)
class ReductionAUnit(nn.Module):
    """
    InceptionResNetV1 type Reduction-A unit.

    Applies three parallel branches to a 256-channel input:
    a single 3x3 stride-2 stage (384 channels), a 1x1 -> 3x3 -> 3x3
    stride-2 sequence (192 -> 192 -> 256 channels), and a ``MaxPoolBranch``.
    The branches are combined by ``Concurrent`` (presumably channel
    concatenation -- confirm against its definition).

    Parameters:
    ----------
    bn_eps : float
        Small float added to variance in Batch norm.
    """
    def __init__(self, bn_eps):
        super(ReductionAUnit, self).__init__()
        width = 256
        # (branch name, ConvSeqBranch keyword arguments) -- built in order.
        conv_specs = (
            ("branch1", dict(
                out_channels_list=(384, ),
                kernel_size_list=(3, ),
                strides_list=(2, ),
                padding_list=(0, ))),
            ("branch2", dict(
                out_channels_list=(192, 192, 256),
                kernel_size_list=(1, 3, 3),
                strides_list=(1, 1, 2),
                padding_list=(0, 1, 0))),
        )

        self.branches = Concurrent()
        for branch_name, spec in conv_specs:
            self.branches.add_module(
                branch_name,
                ConvSeqBranch(in_channels=width, bn_eps=bn_eps, **spec))
        self.branches.add_module("branch3", MaxPoolBranch())

    def forward(self, x):
        return self.branches(x)
class ESPBlock(nn.Module):
    """
    ESP block, which is based on the following principle: Reduce ---> Split ---> Transform --> Merge.

    Stages:

    - A grouped 1x1 conv (``groups = len(kernel_sizes)``) reduces the input
      to ``out_channels // groups`` channels.
    - ``ChannelShuffle`` mixes channels across the groups.
    - One ``SBBlock`` per entry of ``kernel_sizes`` runs in parallel. The
      first branch also takes the remainder channels, so the branch output
      widths add up to ``out_channels`` (``PreActivation`` is built with
      ``out_channels``, so the branches are presumably concatenated).
    - The input is optionally added as a residual.
    - The result goes through ``PreActivation``.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_sizes : list of int
        Convolution window size for branches.
    scale_factors : list of int
        Scale factor for branches.
    use_residual : bool
        Whether to use residual connection.
    bn_eps : float
        Small float added to variance in Batch norm.

    Raises:
    ------
    ValueError
        If ``kernel_sizes`` is empty, or if ``use_residual`` is set while
        ``in_channels != out_channels``.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_sizes,
                 scale_factors,
                 use_residual,
                 bn_eps):
        super(ESPBlock, self).__init__()
        groups = len(kernel_sizes)
        # Validate before building any sub-module.
        # Zero groups would otherwise surface as an opaque ZeroDivisionError.
        if groups == 0:
            raise ValueError("ESPBlock requires at least one entry in kernel_sizes")
        # The residual sum needs matching widths. With in_channels == 1 it
        # would otherwise broadcast silently instead of failing.
        if use_residual and in_channels != out_channels:
            raise ValueError(
                "ESPBlock with use_residual=True requires in_channels == out_channels "
                "(got {} and {})".format(in_channels, out_channels))
        self.use_residual = use_residual

        # Integer split; the remainder goes to the first branch so the branch
        # outputs add up exactly to out_channels.
        mid_channels = out_channels // groups
        res_channels = out_channels - groups * mid_channels

        self.conv = conv1x1(
            in_channels=in_channels,
            out_channels=mid_channels,
            groups=groups)
        self.c_shuffle = ChannelShuffle(
            channels=mid_channels,
            groups=groups)
        self.branches = Concurrent()
        for i in range(groups):
            out_channels_i = (mid_channels + res_channels) if i == 0 else mid_channels
            self.branches.add_module("branch{}".format(i + 1), SBBlock(
                in_channels=mid_channels,
                out_channels=out_channels_i,
                kernel_size=kernel_sizes[i],
                scale_factor=scale_factors[i],
                bn_eps=bn_eps))
        self.preactiv = PreActivation(
            in_channels=out_channels,
            bn_eps=bn_eps)

    def forward(self, x):
        # Always bind the shortcut; avoids a conditionally-defined local.
        identity = x

        x = self.conv(x)
        x = self.c_shuffle(x)
        x = self.branches(x)

        if self.use_residual:
            x = identity + x

        x = self.preactiv(x)
        return x