Example #1
0
 def __init__(self,
              in_channels,
              mid_channels,
              out_channels,
              stride=1,
              dilate=1,
              groups=1,
              initialW=None,
              bn_kwargs=None,
              residual_conv=False,
              stride_first=False,
              add_seblock=False):
     """Create the links of a bottleneck block.

     Three ``Conv2DBNActiv`` links are registered as ``conv1``, ``conv2``
     and ``conv3``.  An ``SEBlock`` (``se``) and a projection link
     (``residual_conv``) are registered on request.

     Args:
         in_channels: Input size of ``conv1`` and of the optional
             residual projection.
         mid_channels: Output size of ``conv1`` and input/output size of
             ``conv2``.
         out_channels: Output size of ``conv3`` and of the residual
             projection; also the argument given to ``SEBlock``.
         stride: Stride given to ``conv1`` or ``conv2`` (see
             ``stride_first``) and to the residual projection.
         dilate: Passed twice positionally to ``conv2`` (as padding and,
             presumably, as dilation -- confirm against the signature of
             ``Conv2DBNActiv``).
         groups: Passed positionally to ``conv2``.
         initialW: Forwarded as ``initialW`` to every convolution.
         bn_kwargs: Dict forwarded as ``bn_kwargs`` to every convolution.
             ``None`` (the default) is treated as an empty dict.
         residual_conv: If true, register the ``residual_conv`` link.
         stride_first: If true, ``conv1`` receives ``stride`` and
             ``conv2`` stride 1; otherwise the other way round.
         add_seblock: If true, register ``se = SEBlock(out_channels)``.
     """
     # A literal ``{}`` default would be a single dict shared by every
     # call; build a fresh one per instance instead.
     if bn_kwargs is None:
         bn_kwargs = {}

     # Exactly one of the first two convolutions carries the stride.
     if stride_first:
         first_stride, second_stride = stride, 1
     else:
         first_stride, second_stride = 1, stride

     super(Bottleneck, self).__init__()
     with self.init_scope():
         self.conv1 = Conv2DBNActiv(
             in_channels, mid_channels, 1, first_stride, 0,
             nobias=True, initialW=initialW, bn_kwargs=bn_kwargs)
         # pad = dilate
         self.conv2 = Conv2DBNActiv(
             mid_channels, mid_channels, 3, second_stride, dilate, dilate,
             groups,
             nobias=True, initialW=initialW, bn_kwargs=bn_kwargs)
         # No activation on the last convolution of the branch.
         self.conv3 = Conv2DBNActiv(
             mid_channels, out_channels, 1, 1, 0,
             nobias=True, initialW=initialW, activ=None,
             bn_kwargs=bn_kwargs)
         if add_seblock:
             self.se = SEBlock(out_channels)
         if residual_conv:
             # Projection of the input to ``out_channels`` using the
             # block's overall stride, again without activation.
             self.residual_conv = Conv2DBNActiv(
                 in_channels, out_channels, 1, stride, 0,
                 nobias=True, initialW=initialW, activ=None,
                 bn_kwargs=bn_kwargs)
Example #2
0
    def __init__(self, levels=8, scales=6, planes=1024):
        """Record the configuration and register one ``SEBlock`` per scale.

        Args:
            levels: Stored as ``self.levels``.
            scales: Stored as ``self.scales``; also the number of
                ``SEBlock`` links that are created.
            planes: Argument passed to every ``SEBlock``.
        """
        super(SFAM, self).__init__()

        self.levels, self.scales = levels, scales

        # Build the child links first, then register them as one ChainList.
        se_blocks = [SEBlock(planes) for _ in range(scales)]
        with self.init_scope():
            self.ses = ChainList(*se_blocks)
Example #3
0
 def __init__(self,
              n_channel: int = 2048,
              ratio: int = 2,
              axis: Tuple[int, ...] = (2, 3)) -> None:
     """Store the axes and register the ``SEBlock`` child link.

     Args:
         n_channel: First argument passed to ``SEBlock``.
         ratio: Second argument passed to ``SEBlock``.
         axis: Stored unchanged as ``self.axis``; presumably the axes
             reduced by the pooling step (confirm against the forward
             method).
     """
     super(GlobalcSEPooling, self).__init__()
     self.axis = axis
     with self.init_scope():
         self.cse = SEBlock(n_channel, ratio)
Example #4
0
    def __init__(self, in_channels, mid_channels, out_channels,
                 stride=1, dilate=1, groups=1, initialW=None, bn_kwargs=None,
                 residual_conv=False, stride_first=False, add_seblock=False):
        """Create the links of a bottleneck block.

        Args:
            in_channels: Input size of ``conv1`` and of the default
                residual projection.
            mid_channels: Output size of ``conv1`` and input/output size
                of ``conv2``.
            out_channels: Output size of ``conv3`` and of the default
                residual projection; also the argument given to
                ``SEBlock``.
            stride: Stride given to ``conv1`` or ``conv2`` (see
                ``stride_first``) and to the default residual projection.
                It also determines the default padding of ``conv2``.
            dilate: Value passed positionally to ``conv2`` after the
                padding.  It is halved (integer division) when
                ``residual_conv`` is truthy and ``dilate > 1``.
            groups: Passed positionally to ``conv2``.
            initialW: Forwarded as ``initialW`` to every convolution.
            bn_kwargs: Dict forwarded as ``bn_kwargs`` to every
                convolution.  ``None`` (the default) is treated as an
                empty dict.
            residual_conv: Falsy for no projection link.  A
                ``chainer.Link`` is registered as-is; any other truthy
                value creates a ``Conv2DBNActiv`` projection.
            stride_first: If true, ``conv1`` receives ``stride`` and
                ``conv2`` stride 1; otherwise the other way round.
            add_seblock: If true, register ``se = SEBlock(out_channels)``.
        """
        # A literal ``{}`` default would be a single dict shared by every
        # call; build a fresh one per instance instead.
        if bn_kwargs is None:
            bn_kwargs = {}

        # Exactly one of the first two convolutions carries the stride.
        if stride_first:
            first_stride, second_stride = stride, 1
        else:
            first_stride, second_stride = 1, stride

        # Default padding of conv2: 1 for stride 1, 0 for stride 2.
        # NOTE(review): the original author marked this formula with "?";
        # the intent is not visible from this method -- confirm.
        pad = 2 - stride
        if residual_conv and dilate > 1:
            # Blocks with a projection use half the requested dilation,
            # and the padding follows the (halved) dilation.
            dilate = dilate // 2
            pad = dilate

        if dilate > 1:
            # Any remaining dilation overrides the stride-based padding.
            pad = dilate

        super(Bottleneck, self).__init__()
        with self.init_scope():
            self.conv1 = Conv2DBNActiv(
                in_channels, mid_channels, 1, first_stride, 0,
                nobias=True, initialW=initialW, bn_kwargs=bn_kwargs)
            self.conv2 = Conv2DBNActiv(
                mid_channels, mid_channels, 3, second_stride, pad, dilate,
                groups,
                nobias=True, initialW=initialW, bn_kwargs=bn_kwargs)
            # No activation on the last convolution of the branch.
            self.conv3 = Conv2DBNActiv(
                mid_channels, out_channels, 1, 1, 0,
                nobias=True, initialW=initialW, bn_kwargs=bn_kwargs,
                activ=None)
            if add_seblock:
                self.se = SEBlock(out_channels)
            if residual_conv:
                if isinstance(residual_conv, chainer.Link):
                    # Caller supplied a ready-made projection link.
                    self.residual_conv = residual_conv
                else:
                    self.residual_conv = Conv2DBNActiv(
                        in_channels, out_channels, 1, stride, 0,
                        nobias=True, initialW=initialW,
                        activ=None, bn_kwargs=bn_kwargs)
Example #5
0
    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 stride=1,
                 scale=1,
                 dilate=1,
                 groups=1,
                 initialW=None,
                 bn_kwargs=None,
                 residual_conv=False,
                 stride_first=False,
                 add_block=None,
                 aa_kwargs=None):
        """Create the links of a bottleneck block.

        Args:
            in_channels: Input size of ``conv1`` and of the optional
                residual projection.
            mid_channels: Output size of ``conv1`` and input/output size
                of ``conv2``.
            out_channels: Output size of ``conv3`` and of the residual
                projection; also the argument given to ``SEBlock`` /
                ``GCBlock``.
            stride: Stride given to ``conv1`` or ``conv2`` (see
                ``stride_first``) and to the residual projection.
            scale: Passed positionally to ``conv2`` when ``stride`` is
                not greater than 1 and ``aa_kwargs`` is empty; otherwise
                1 is passed in its place.
            dilate: Passed twice positionally to ``conv2`` when
                ``aa_kwargs`` is empty.
            groups: Passed positionally to ``conv2`` when ``aa_kwargs``
                is empty.
            initialW: Forwarded as ``initialW`` to every convolution.
            bn_kwargs: Dict forwarded as ``bn_kwargs`` to every
                convolution.  ``None`` (the default) is treated as an
                empty dict.
            residual_conv: If true, register the ``residual_conv`` link.
            stride_first: If true, ``conv1`` receives ``stride`` and
                ``conv2`` stride 1; otherwise the other way round.
            add_block: ``'se'`` registers ``se = SEBlock(out_channels)``,
                ``'gc'`` registers ``gc = GCBlock(out_channels)`` and
                ``None`` adds nothing.
            aa_kwargs: If non-empty, forwarded as ``aa_kwargs`` to
                ``conv2``.  ``None`` (the default) or an empty dict
                disables this.

        Raises:
            ValueError: If ``add_block`` is not ``'se'``, ``'gc'`` or
                ``None``.
        """
        # A literal ``{}`` default would be a single dict shared by every
        # call; build a fresh one per instance instead.
        if bn_kwargs is None:
            bn_kwargs = {}

        # Exactly one of the first two convolutions carries the stride.
        if stride_first:
            first_stride, second_stride = stride, 1
        else:
            first_stride, second_stride = 1, stride

        super(Bottleneck, self).__init__()
        with self.init_scope():
            self.conv1 = Conv2DBNActiv(
                in_channels, mid_channels, 1, first_stride, 0,
                nobias=True, initialW=initialW, bn_kwargs=bn_kwargs)

            if aa_kwargs:
                # NOTE(review): this branch passes literal 1s for every
                # positional option, so stride, scale, dilate and groups
                # are ignored by conv2 here -- confirm that is intended.
                self.conv2 = Conv2DBNActiv(
                    mid_channels, mid_channels, 3, 1, 1, 1, 1, 1,
                    nobias=True, initialW=initialW, bn_kwargs=bn_kwargs,
                    aa_kwargs=aa_kwargs)
            else:
                # ``scale`` is only honoured when the block does not
                # stride; a striding block always passes 1 instead.
                conv2_scale = 1 if stride > 1 else scale
                self.conv2 = Conv2DBNActiv(
                    mid_channels, mid_channels, 3, second_stride, dilate,
                    conv2_scale, dilate, groups,
                    nobias=True, initialW=initialW, bn_kwargs=bn_kwargs)

            # No activation on the last convolution of the branch.
            self.conv3 = Conv2DBNActiv(
                mid_channels, out_channels, 1, 1, 0,
                nobias=True, initialW=initialW, activ=None,
                bn_kwargs=bn_kwargs)

            if add_block == 'se':
                self.se = SEBlock(out_channels)
            elif add_block == 'gc':
                self.gc = GCBlock(out_channels)
            elif add_block is not None:
                raise ValueError(
                    "add_block must be 'se', 'gc' or None, got {!r}".format(
                        add_block))

            if residual_conv:
                # Projection of the input to ``out_channels`` using the
                # block's overall stride, again without activation.
                self.residual_conv = Conv2DBNActiv(
                    in_channels, out_channels, 1, stride, 0,
                    nobias=True, initialW=initialW, activ=None,
                    bn_kwargs=bn_kwargs)
Example #6
0
    def __init__(
            self,
            in_channels,
            out_channels,
            stride=1,
            dilate=1,
            groups=1,
            initialW=None,
            bn_kwargs=None,
            residual_conv=False,  # TODO: remove this
            stride_first=False,
            add_seblock=False,
            res_scale=None,
            use_fixup=False):
        """Create the links of a two-convolution residual block.

        Args:
            in_channels: Input size of ``conv1`` and of the residual
                projection.
            out_channels: Output size of ``conv1``, input/output size of
                ``conv2`` and output size of the residual projection;
                also the argument given to ``SEBlock``.
            stride: Stride of ``conv2`` and of the residual mapping
                (``conv1`` always uses stride 1).
            dilate: Used as the ``pad`` of both convolutions.
            groups: Forwarded as ``groups`` to both convolutions.
            initialW: Forwarded as ``initialW`` to ``conv1`` and to the
                residual projection.  ``conv2`` also receives it unless
                ``use_fixup`` is set, in which case it gets ``None``.
            bn_kwargs: Dict forwarded as ``bn_kwargs`` to the
                ``Conv2DBNActiv`` links (not used with fixup).  ``None``
                (the default) is treated as an empty dict.
            residual_conv: If true, a mapping link is registered as
                ``residual``; otherwise ``residual`` is an identity
                function.
            stride_first: Accepted but not used by this constructor.
            add_seblock: If true, register ``se = SEBlock(out_channels)``.
            res_scale: Stored as ``self.res_scale``; ``None`` means 1.0.
            use_fixup: If true, ``FixupConv2D`` / ``FixupIdentity`` are
                used instead of ``Conv2DBNActiv``.

        Raises:
            ValueError: If ``use_fixup`` is set while ``res_scale`` is not
                ``None``.
        """
        super().__init__()

        # A literal ``{}`` default would be one dict shared by every
        # instance (it is stored on ``self`` below); build a fresh one.
        if bn_kwargs is None:
            bn_kwargs = {}

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.dilate = dilate
        self.groups = groups
        self.initialW = initialW
        self.bn_kwargs = bn_kwargs
        self.residual_conv = residual_conv

        if res_scale is not None:
            if use_fixup:
                raise ValueError('Cannot use fixup when res_scale is not None')
            self.res_scale = res_scale
        else:
            self.res_scale = 1.0

        # Fixup
        self.use_fixup = use_fixup
        self.cached_zeros = None

        # Link class used for both convolutions.
        ConvLink = FixupConv2D if use_fixup else Conv2DBNActiv

        # Keyword arguments of the first convolution.
        # NOTE(review): ``dilate`` is only used as padding here and is not
        # forwarded as a dilation argument -- confirm that is intended.
        conv1_kwargs = {
            'ksize': 3,
            'pad': dilate,
            'nobias': True,
            'groups': groups,
            'initialW': initialW,
        }
        if use_fixup:
            conv1_kwargs['use_scale'] = False
        else:
            conv1_kwargs['bn_kwargs'] = bn_kwargs

        # The second convolution reuses those settings but has no
        # activation; with fixup it also drops the custom initializer and
        # turns the scale on.
        conv2_kwargs = dict(conv1_kwargs, activ=None)
        if use_fixup:
            conv2_kwargs['initialW'] = None
            conv2_kwargs['use_scale'] = True

        with self.init_scope():
            # pad = dilate
            self.conv1 = ConvLink(in_channels,
                                  out_channels,
                                  stride=1,
                                  **conv1_kwargs)

            self.conv2 = ConvLink(out_channels,
                                  out_channels,
                                  stride=stride,
                                  **conv2_kwargs)  # no ReLU after conv2

            # Squeeze-and-Excitation
            if add_seblock:
                # TODO: check whether this block will affect the numerical stability of a model
                self.se = SEBlock(out_channels)

            # the additional mapping block on the residual connection
            if residual_conv:
                if not use_fixup:
                    self.residual = Conv2DBNActiv(in_channels,
                                                  out_channels,
                                                  ksize=1,
                                                  stride=stride,
                                                  pad=0,
                                                  nobias=True,
                                                  initialW=initialW,
                                                  activ=None,
                                                  bn_kwargs=bn_kwargs)
                else:  # When using fixup, we pass the residual connection through average pooling
                    self.residual = FixupIdentity(stride)
            else:
                self.residual = lambda x: x