Example #1
class InceptionCUnit(nn.Module):
    """
    InceptionResNetV1 type Inception-C unit.

    Parameters:
    ----------
    bn_eps : float
        Small float added to variance in Batch norm.
    scale : float, default 0.2
        Scale value for residual branch.
    activate : bool, default True
        Whether to activate the convolution block.
    """
    def __init__(self, bn_eps, scale=0.2, activate=True):
        super(InceptionCUnit, self).__init__()
        self.activate = activate
        self.scale = scale
        in_channels = 1792

        self.branches = Concurrent()
        self.branches.add_module(
            "branch1",
            Conv1x1Branch(in_channels=in_channels,
                          out_channels=192,
                          bn_eps=bn_eps))
        self.branches.add_module(
            "branch2",
            ConvSeqBranch(in_channels=in_channels,
                          out_channels_list=(192, 192, 192),
                          kernel_size_list=(1, (1, 3), (3, 1)),
                          strides_list=(1, 1, 1),
                          padding_list=(0, (0, 1), (1, 0)),
                          bn_eps=bn_eps))
        self.conv = conv1x1(in_channels=384,
                            out_channels=in_channels,
                            bias=True)
        if self.activate:
            self.activ = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.branches(x)
        x = self.conv(x)
        x = self.scale * x + identity
        if self.activate:
            x = self.activ(x)
        return x
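A minimal usage sketch (my addition, not part of the original example): it assumes the pytorchcv-style helpers the class depends on (Concurrent, Conv1x1Branch, ConvSeqBranch, conv1x1) are already in scope, and simply checks that the unit behaves as a shape-preserving residual block on 1792-channel inputs.

import torch

# Hypothetical smoke test: all branch strides are 1 and the scaled branch
# output is added back to the identity, so the output shape matches the input.
unit = InceptionCUnit(bn_eps=1e-5, scale=0.2, activate=True)
x = torch.randn(1, 1792, 5, 5)  # the 5x5 spatial size is arbitrary here
y = unit(x)
assert y.shape == x.shape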
Example #2
class InceptionBUnit(nn.Module):
    """
    InceptionResNetV1 type Inception-B unit.

    Parameters:
    ----------
    bn_eps : float
        Small float added to variance in Batch norm.
    """
    def __init__(self, bn_eps):
        super(InceptionBUnit, self).__init__()
        self.scale = 0.10
        in_channels = 896

        self.branches = Concurrent()
        self.branches.add_module(
            "branch1",
            Conv1x1Branch(in_channels=in_channels,
                          out_channels=128,
                          bn_eps=bn_eps))
        self.branches.add_module(
            "branch2",
            ConvSeqBranch(in_channels=in_channels,
                          out_channels_list=(128, 128, 128),
                          kernel_size_list=(1, (1, 7), (7, 1)),
                          strides_list=(1, 1, 1),
                          padding_list=(0, (0, 3), (3, 0)),
                          bn_eps=bn_eps))
        self.conv = conv1x1(in_channels=256,
                            out_channels=in_channels,
                            bias=True)
        self.activ = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.branches(x)
        x = self.conv(x)
        x = self.scale * x + identity
        x = self.activ(x)
        return x
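The same kind of hedged sketch for the Inception-B unit (again assuming the surrounding pytorchcv-style helpers are in scope): it operates on 896-channel feature maps and is also shape-preserving, since the 1x1, 1x7 and 7x1 convolutions are padded and the scaled result is added to the identity.

import torch

# Hypothetical check: 896 channels in, 896 channels out, spatial size unchanged.
unit = InceptionBUnit(bn_eps=1e-5)
x = torch.randn(1, 896, 17, 17)  # 17x17 is an illustrative spatial size
y = unit(x)
assert y.shape == x.shape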
Example #3
class ReductionAUnit(nn.Module):
    """
    InceptionResNetV1 type Reduction-A unit.

    Parameters:
    ----------
    bn_eps : float
        Small float added to variance in Batch norm.
    """
    def __init__(self, bn_eps):
        super(ReductionAUnit, self).__init__()
        in_channels = 256

        self.branches = Concurrent()
        self.branches.add_module(
            "branch1",
            ConvSeqBranch(in_channels=in_channels,
                          out_channels_list=(384, ),
                          kernel_size_list=(3, ),
                          strides_list=(2, ),
                          padding_list=(0, ),
                          bn_eps=bn_eps))
        self.branches.add_module(
            "branch2",
            ConvSeqBranch(in_channels=in_channels,
                          out_channels_list=(192, 192, 256),
                          kernel_size_list=(1, 3, 3),
                          strides_list=(1, 1, 2),
                          padding_list=(0, 1, 0),
                          bn_eps=bn_eps))
        self.branches.add_module("branch3", MaxPoolBranch())

    def forward(self, x):
        x = self.branches(x)
        return x
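A sketch of the expected shape change (the concrete sizes are my assumption, following the usual InceptionResNetV1 layout, and MaxPoolBranch is assumed to be the standard 3x3/stride-2 pooling branch): the three branches emit 384, 256 and 256 channels, so the unit maps 256 channels to 896 while roughly halving the spatial resolution.

import torch

# Hypothetical shape check: 3x3/stride-2 convolutions and pooling reduce
# 35x35 to 17x17; channels grow to 384 + 256 + 256 = 896.
unit = ReductionAUnit(bn_eps=1e-5)
x = torch.randn(1, 256, 35, 35)
y = unit(x)
assert y.shape == (1, 896, 17, 17)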
Example #4
class ESPBlock(nn.Module):
    """
    ESP block, which is based on the following principle: Reduce ---> Split ---> Transform ---> Merge.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_sizes : list of int
        Convolution window sizes, one per branch.
    scale_factors : list of int
        Scale factors, one per branch.
    use_residual : bool
        Whether to use residual connection.
    bn_eps : float
        Small float added to variance in Batch norm.
    """
    def __init__(self, in_channels, out_channels, kernel_sizes, scale_factors,
                 use_residual, bn_eps):
        super(ESPBlock, self).__init__()
        self.use_residual = use_residual
        groups = len(kernel_sizes)

        # Split the output channels evenly across the branches; the first
        # branch absorbs any remainder so the concatenation matches out_channels.
        mid_channels = out_channels // groups
        res_channels = out_channels - groups * mid_channels

        self.conv = conv1x1(in_channels=in_channels,
                            out_channels=mid_channels,
                            groups=groups)

        self.c_shuffle = ChannelShuffle(channels=mid_channels, groups=groups)

        self.branches = Concurrent()
        for i in range(groups):
            out_channels_i = (mid_channels +
                              res_channels) if i == 0 else mid_channels
            self.branches.add_module(
                "branch{}".format(i + 1),
                SBBlock(in_channels=mid_channels,
                        out_channels=out_channels_i,
                        kernel_size=kernel_sizes[i],
                        scale_factor=scale_factors[i],
                        bn_eps=bn_eps))

        self.preactiv = PreActivation(in_channels=out_channels, bn_eps=bn_eps)

    def forward(self, x):
        if self.use_residual:
            identity = x

        x = self.conv(x)
        x = self.c_shuffle(x)
        x = self.branches(x)

        if self.use_residual:
            x = identity + x

        x = self.preactiv(x)
        return x
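A hedged instantiation sketch for the ESP block (the parameter values are illustrative rather than taken from any actual SINet/ESPNet configuration, and SBBlock, ChannelShuffle, PreActivation and conv1x1 are assumed to be in scope). With two kernel sizes there are two branches, out_channels splits evenly between them, and the residual connection requires in_channels == out_channels; the branches are expected to restore the spatial resolution after their internal down- and up-scaling.

import torch

# Illustrative configuration: groups = 2, so mid_channels = 48 // 2 = 24,
# each branch emits 24 channels, and the concatenation matches out_channels.
block = ESPBlock(in_channels=48,
                 out_channels=48,
                 kernel_sizes=[3, 5],
                 scale_factors=[1, 2],
                 use_residual=True,
                 bn_eps=1e-5)
x = torch.randn(1, 48, 28, 28)
y = block(x)
assert y.shape == x.shape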