Example #1
0
    def __init__(self,
                 in_ch,
                 out_ch,
                 expansion=2.25,
                 se_ratio=0.5,
                 kernel_size=3,
                 group_size=8,
                 stride=1,
                 beta=1.0,
                 alpha=0.2,
                 which_conv=base.WSConv2D,
                 activation=jax.nn.relu,
                 stochdepth_rate=None,
                 name=None):
        """Build the block: 1x1 expansion -> grouped NxN -> 1x1 projection.

        Args:
            in_ch: number of input channels.
            out_ch: number of output channels.
            expansion: multiplier on `in_ch` giving the expanded width.
            se_ratio: squeeze-excite bottleneck ratio.
            kernel_size: spatial size of the grouped conv.
            group_size: channels per group; the expanded width is rounded
                down to a whole multiple of this.
            stride: stride of the grouped conv.
            beta: residual-branch input scale (stored for use elsewhere).
            alpha: residual-branch output scale (stored for use elsewhere).
            which_conv: constructor used for every conv in the block.
            activation: nonlinearity (stored for use elsewhere).
            stochdepth_rate: stochastic-depth drop rate; only rates
                strictly inside (0, 1) enable it.
            name: module name forwarded to the superclass.
        """
        super().__init__(name=name)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.expansion = expansion
        self.se_ratio = se_ratio
        self.kernel_size = kernel_size
        self.activation = activation
        self.beta = beta
        self.alpha = alpha
        # Expanded width, rounded down to a whole number of groups.
        expanded = int(in_ch * expansion)
        self.groups = expanded // group_size
        self.width = self.groups * group_size
        self.stride = stride
        # 1x1 expansion conv.
        self.conv0 = which_conv(self.width,
                                kernel_shape=1,
                                padding='SAME',
                                name='conv0')
        # Grouped NxN conv; this is the only strided conv in the block.
        self.conv1 = which_conv(self.width,
                                kernel_shape=kernel_size,
                                stride=stride,
                                padding='SAME',
                                feature_group_count=self.groups,
                                name='conv1')
        # 1x1 projection conv back down to `out_ch`.
        self.conv2 = which_conv(self.out_ch,
                                kernel_shape=1,
                                padding='SAME',
                                name='conv2')
        # A shortcut conv is required whenever the skip path changes shape
        # (channel count change or spatial downsampling).
        self.use_projection = stride > 1 or in_ch != out_ch
        if self.use_projection:
            # NOTE(review): no stride on the shortcut conv — presumably the
            # shortcut is downsampled elsewhere (e.g. pooling in __call__);
            # confirm against the forward pass.
            self.conv_shortcut = which_conv(self.out_ch,
                                            kernel_shape=1,
                                            padding='SAME',
                                            name='conv_shortcut')
        # Squeeze-and-excite operating at the expanded width.
        self.se = base.SqueezeExcite(self.width, self.width, self.se_ratio)

        # Stochastic depth is active only for rates strictly in (0, 1).
        self._has_stochdepth = (stochdepth_rate is not None
                                and 0. < stochdepth_rate < 1.0)
        if self._has_stochdepth:
            self.stoch_depth = base.StochDepth(stochdepth_rate)
Example #2
0
 def __init__(self,
              in_ch,
              out_ch,
              num_blocks,
              bottleneck_ratio=0.25,
              kernel_size=3,
              stride=1,
              which_conv=hk.Conv2D,
              activation=jax.nn.relu,
              stochdepth_rate=None,
              name=None):
     """Build a bottleneck residual block with depth-scaled initializers.

     Args:
         in_ch: number of input channels.
         out_ch: number of output channels.
         num_blocks: total block count; scales conv init stddev by
             `num_blocks ** -0.25`.
         bottleneck_ratio: fraction of `out_ch` used as the hidden width.
         kernel_size: spatial size of the middle conv.
         stride: stride of the middle conv (and of the shortcut conv).
         which_conv: constructor used for every conv in the block.
         activation: nonlinearity (stored for use elsewhere).
         stochdepth_rate: stochastic-depth drop rate; only rates strictly
             inside (0, 1) enable it.
         name: module name forwarded to the superclass.
     """
     super().__init__(name=name)
     self.in_ch = in_ch
     self.out_ch = out_ch
     self.kernel_size = kernel_size
     self.activation = activation
     # Hidden (bottleneck) width.
     self.width = int(out_ch * bottleneck_ratio)
     self.stride = stride
     # 1x1 reduction conv; He-style init scaled down with depth.
     conv0_init = hk.initializers.RandomNormal(
         stddev=((2 / self.width)**0.5) * (num_blocks**(-0.25)))
     self.conv0 = which_conv(self.width,
                             kernel_shape=1,
                             padding='SAME',
                             name='conv0',
                             w_init=conv0_init)
     # NxN conv; fan-in includes the spatial kernel area.
     fan_in = self.width * (kernel_size**2)
     conv1_init = hk.initializers.RandomNormal(
         stddev=((2 / fan_in)**0.5) * (num_blocks**(-0.25)))
     self.conv1 = which_conv(self.width,
                             kernel_shape=kernel_size,
                             stride=stride,
                             padding='SAME',
                             name='conv1',
                             w_init=conv1_init)
     # 1x1 projection conv, zero-initialized so the residual branch
     # starts as an identity contribution.
     self.conv2 = which_conv(self.out_ch,
                             kernel_shape=1,
                             padding='SAME',
                             name='conv2',
                             w_init=hk.initializers.Constant(0))
     # A shortcut conv is required whenever the skip path changes shape
     # (channel count change or spatial downsampling).
     self.use_projection = stride > 1 or in_ch != out_ch
     if self.use_projection:
         shortcut_init = hk.initializers.RandomNormal(
             stddev=(2 / self.out_ch)**0.5)
         self.conv_shortcut = which_conv(self.out_ch,
                                         kernel_shape=1,
                                         stride=stride,
                                         padding='SAME',
                                         name='conv_shortcut',
                                         w_init=shortcut_init)
     # Stochastic depth is active only for rates strictly in (0, 1).
     self._has_stochdepth = (stochdepth_rate is not None
                             and 0. < stochdepth_rate < 1.0)
     if self._has_stochdepth:
         self.stoch_depth = base.StochDepth(stochdepth_rate)
Example #3
0
    def __init__(self,
                 in_ch,
                 out_ch,
                 bottleneck_ratio=0.25,
                 kernel_size=3,
                 stride=1,
                 beta=1.0,
                 alpha=0.2,
                 which_conv=base.WSConv2D,
                 activation=jax.nn.relu,
                 skipinit_gain=jnp.zeros,
                 stochdepth_rate=None,
                 use_se=False,
                 se_ratio=0.25,
                 name=None):
        """Build an NF bottleneck block with optional squeeze-excite.

        Args:
            in_ch: number of input channels.
            out_ch: number of output channels.
            bottleneck_ratio: fraction of `out_ch` used as the hidden width.
            kernel_size: spatial size of the middle conv.
            stride: stride of the middle conv (and of the shortcut conv).
            beta: residual-branch input scale (stored for use elsewhere).
            alpha: residual-branch output scale (stored for use elsewhere).
            which_conv: constructor used for every conv in the block.
            activation: nonlinearity (stored for use elsewhere).
            skipinit_gain: initializer for the skip-init gain (stored).
            stochdepth_rate: stochastic-depth drop rate; only rates
                strictly inside (0, 1) enable it.
            use_se: whether to attach a squeeze-excite module on `out_ch`.
            se_ratio: squeeze-excite bottleneck ratio.
            name: module name forwarded to the superclass.
        """
        super().__init__(name=name)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.activation = activation
        self.beta = beta
        self.alpha = alpha
        self.skipinit_gain = skipinit_gain
        self.use_se = use_se
        self.se_ratio = se_ratio
        # Hidden (bottleneck) width.
        self.width = int(out_ch * bottleneck_ratio)
        self.stride = stride
        # 1x1 reduction conv.
        self.conv0 = which_conv(self.width,
                                kernel_shape=1,
                                padding='SAME',
                                name='conv0')
        # NxN conv; this is the only strided conv on the residual branch.
        self.conv1 = which_conv(self.width,
                                kernel_shape=kernel_size,
                                stride=stride,
                                padding='SAME',
                                name='conv1')
        # 1x1 projection conv back up to `out_ch`.
        self.conv2 = which_conv(self.out_ch,
                                kernel_shape=1,
                                padding='SAME',
                                name='conv2')
        # A shortcut conv is required whenever the skip path changes shape
        # (channel count change or spatial downsampling).
        self.use_projection = stride > 1 or in_ch != out_ch
        if self.use_projection:
            self.conv_shortcut = which_conv(self.out_ch,
                                            kernel_shape=1,
                                            stride=stride,
                                            padding='SAME',
                                            name='conv_shortcut')
        # Stochastic depth is active only for rates strictly in (0, 1).
        self._has_stochdepth = (stochdepth_rate is not None
                                and 0. < stochdepth_rate < 1.0)
        if self._has_stochdepth:
            self.stoch_depth = base.StochDepth(stochdepth_rate)

        # Optional squeeze-and-excite at the output width.
        if self.use_se:
            self.se = base.SqueezeExcite(self.out_ch, self.out_ch,
                                         self.se_ratio)