示例#1
0
    def __init__(self,
                 in_ch,
                 out_ch,
                 expansion=2.25,
                 se_ratio=0.5,
                 kernel_size=3,
                 group_size=8,
                 stride=1,
                 beta=1.0,
                 alpha=0.2,
                 which_conv=base.WSConv2D,
                 activation=jax.nn.relu,
                 stochdepth_rate=None,
                 name=None):
        """Create the convolutions and optional sub-modules of the block.

        The block uses a 1x1 conv (``conv0``), a grouped
        ``kernel_size`` x ``kernel_size`` conv (``conv1``), a 1x1 conv to
        ``out_ch`` (``conv2``), a squeeze-excite module on the expanded width,
        an optional shortcut conv and an optional stochastic-depth module.

        Args:
          in_ch: Input channel count; also used to derive the expanded width.
          out_ch: Output channel count of ``conv2`` and of the shortcut conv.
          expansion: Multiplier on ``in_ch``; the product (truncated to int)
            is rounded down to a multiple of ``group_size``.
          se_ratio: Ratio forwarded to ``base.SqueezeExcite``.
          kernel_size: Kernel size of the grouped conv ``conv1``.
          group_size: Channels per group in ``conv1``. Must be positive and
            no larger than ``int(in_ch * expansion)``.
          stride: Stride of ``conv1``. ``stride > 1`` also enables the
            shortcut conv.
          beta: Stored on the instance; not used inside this method.
          alpha: Stored on the instance; not used inside this method.
          which_conv: Conv module constructor used for every conv here.
          activation: Activation function, stored for use outside __init__.
          stochdepth_rate: When strictly between 0 and 1, a
            ``base.StochDepth`` module is created; otherwise none is.
          name: Module name forwarded to the parent constructor.

        Raises:
          ValueError: If ``group_size`` is not positive, or if the expanded
            width is smaller than ``group_size`` (that would give zero groups
            and zero-width convolutions).
        """
        super().__init__(name=name)
        self.in_ch, self.out_ch = in_ch, out_ch
        self.expansion = expansion
        self.se_ratio = se_ratio
        self.kernel_size = kernel_size
        self.activation = activation
        self.beta, self.alpha = beta, alpha
        # Validate up front. A non-positive group_size would divide by zero
        # (or go negative) below. An expanded width smaller than group_size
        # rounds down to zero groups, i.e. zero-width convs that only fail
        # later, far from the bad configuration.
        if group_size <= 0:
            raise ValueError(
                f'group_size must be positive, got {group_size}.')
        width = int(self.in_ch * expansion)
        if width < group_size:
            raise ValueError(
                f'Expanded width int(in_ch * expansion) = {width} is smaller '
                f'than group_size = {group_size}; the block would have zero '
                'groups.')
        # Round the expanded width down to a multiple of group_size.
        self.groups = width // group_size
        self.width = group_size * self.groups
        self.stride = stride
        # Conv 0 (typically expansion conv)
        self.conv0 = which_conv(self.width,
                                kernel_shape=1,
                                padding='SAME',
                                name='conv0')
        # Grouped NxN conv: one group per group_size channels.
        self.conv1 = which_conv(self.width,
                                kernel_shape=kernel_size,
                                stride=stride,
                                padding='SAME',
                                feature_group_count=self.groups,
                                name='conv1')
        # Conv 2, typically projection conv
        self.conv2 = which_conv(self.out_ch,
                                kernel_shape=1,
                                padding='SAME',
                                name='conv2')
        # Use shortcut conv on channel change or downsample.
        # NOTE(review): unlike conv1, this conv has no stride even when
        # stride > 1; presumably the shortcut is downsampled separately in
        # __call__ (not visible here) -- confirm.
        self.use_projection = stride > 1 or self.in_ch != self.out_ch
        if self.use_projection:
            self.conv_shortcut = which_conv(self.out_ch,
                                            kernel_shape=1,
                                            padding='SAME',
                                            name='conv_shortcut')
        # Squeeze + Excite module on the expanded width.
        self.se = base.SqueezeExcite(self.width, self.width, self.se_ratio)

        # Stochastic depth is enabled only for rates strictly inside (0, 1).
        self._has_stochdepth = (stochdepth_rate is not None
                                and stochdepth_rate > 0.
                                and stochdepth_rate < 1.0)
        if self._has_stochdepth:
            self.stoch_depth = base.StochDepth(stochdepth_rate)
示例#2
0
    def __init__(self,
                 in_ch,
                 out_ch,
                 bottleneck_ratio=0.25,
                 kernel_size=3,
                 stride=1,
                 beta=1.0,
                 alpha=0.2,
                 which_conv=base.WSConv2D,
                 activation=jax.nn.relu,
                 skipinit_gain=jnp.zeros,
                 stochdepth_rate=None,
                 use_se=False,
                 se_ratio=0.25,
                 name=None):
        """Build the sub-modules of a 1x1 -> KxK -> 1x1 bottleneck block.

        Args:
          in_ch: Input channel count; compared with ``out_ch`` to decide
            whether a shortcut conv is needed.
          out_ch: Output channel count of ``conv2`` and the shortcut conv.
          bottleneck_ratio: Multiplier on ``out_ch`` (truncated to int) that
            gives the width of ``conv0`` and ``conv1``.
          kernel_size: Kernel size of the middle conv ``conv1``.
          stride: Stride of ``conv1`` and of the shortcut conv.
            ``stride > 1`` also forces a shortcut conv.
          beta: Stored on the instance; not used inside this method.
          alpha: Stored on the instance; not used inside this method.
          which_conv: Conv module constructor used for every conv here.
          activation: Activation function, stored for use outside __init__.
          skipinit_gain: Stored on the instance (defaults to ``jnp.zeros``);
            presumably an initializer used in ``__call__`` -- confirm.
          stochdepth_rate: When strictly between 0 and 1, a
            ``base.StochDepth`` module is created.
          use_se: Whether to create a ``base.SqueezeExcite`` on ``out_ch``.
          se_ratio: Ratio forwarded to ``base.SqueezeExcite``.
          name: Module name forwarded to the parent constructor.
        """
        super().__init__(name=name)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.activation = activation
        self.beta = beta
        self.alpha = alpha
        self.skipinit_gain = skipinit_gain
        self.use_se = use_se
        self.se_ratio = se_ratio
        self.stride = stride
        # Bottleneck width: out_ch scaled by bottleneck_ratio, truncated.
        self.width = int(out_ch * bottleneck_ratio)

        def pointwise(channels, label, **extra):
            # 1x1 'SAME' conv built from which_conv; extra kwargs pass through.
            return which_conv(channels,
                              kernel_shape=1,
                              padding='SAME',
                              name=label,
                              **extra)

        # 1x1 conv into the bottleneck width.
        self.conv0 = pointwise(self.width, 'conv0')
        # KxK conv carrying the block's stride. No feature_group_count is
        # passed, so it uses which_conv's default grouping.
        self.conv1 = which_conv(self.width,
                                kernel_shape=kernel_size,
                                stride=stride,
                                padding='SAME',
                                name='conv1')
        # 1x1 conv back out to out_ch.
        self.conv2 = pointwise(self.out_ch, 'conv2')

        # The shortcut needs its own (strided) conv whenever the residual
        # branch changes spatial size or channel count.
        self.use_projection = stride > 1 or in_ch != out_ch
        if self.use_projection:
            self.conv_shortcut = pointwise(self.out_ch,
                                           'conv_shortcut',
                                           stride=stride)

        # Stochastic depth only for rates strictly inside (0, 1).
        self._has_stochdepth = (stochdepth_rate is not None
                                and 0.0 < stochdepth_rate < 1.0)
        if self._has_stochdepth:
            self.stoch_depth = base.StochDepth(stochdepth_rate)

        if use_se:
            self.se = base.SqueezeExcite(self.out_ch, self.out_ch, se_ratio)
示例#3
0
    def __init__(self,
                 out_ch,
                 stride=1,
                 use_projection=False,
                 activation=jax.nn.relu,
                 which_norm=hk.BatchNorm,
                 which_conv=hk.Conv2D,
                 use_se=False,
                 se_ratio=0.25,
                 name=None):
        """Build the norm and conv sub-modules of a 1x1 -> 3x3 -> 1x1 block.

        Three norm modules (``bn0``, ``bn1``, ``bn2``) are created with
        ``which_norm`` and interleaved with three bias-free 'SAME' convs
        (``conv0``, ``conv1``, ``conv2``). The first two convs use width
        ``out_ch // 4``; ``conv2`` outputs ``out_ch``.

        Args:
          out_ch: Output channel count of ``conv2`` and the shortcut conv.
          stride: Stride of the 3x3 ``conv1`` and of the shortcut conv.
          use_projection: Whether to create a strided 1x1 shortcut conv.
          activation: Activation function, stored for use outside __init__.
          which_norm: Norm module constructor, called with only ``name=``.
            NOTE(review): the default ``hk.BatchNorm`` needs extra
            constructor arguments, so callers presumably pass a partial --
            confirm.
          which_conv: Conv module constructor used for every conv here.
          use_se: Whether to create a ``base.SqueezeExcite`` on ``out_ch``.
          se_ratio: Ratio forwarded to ``base.SqueezeExcite``.
          name: Module name forwarded to the parent constructor.
        """
        super().__init__(name=name)
        self.out_ch, self.stride = out_ch, stride
        self.use_projection, self.activation = use_projection, activation
        self.which_norm, self.which_conv = which_norm, which_conv
        self.use_se, self.se_ratio = use_se, se_ratio

        # Bottleneck width is a quarter of the output channels (floor).
        self.width = out_ch // 4

        def conv(channels, label, **extra):
            # Every conv in this block is bias-free with 'SAME' padding.
            return which_conv(channels,
                              with_bias=False,
                              padding='SAME',
                              name=label,
                              **extra)

        # Norms and convs are created in the same interleaved order as used.
        self.bn0 = which_norm(name='bn0')
        self.conv0 = conv(self.width, 'conv0', kernel_shape=1)
        self.bn1 = which_norm(name='bn1')
        self.conv1 = conv(self.width, 'conv1', kernel_shape=3, stride=stride)
        self.bn2 = which_norm(name='bn2')
        self.conv2 = conv(out_ch, 'conv2', kernel_shape=1)

        if use_projection:
            self.conv_shortcut = conv(out_ch,
                                      'conv_shortcut',
                                      kernel_shape=1,
                                      stride=stride)
        if use_se:
            self.se = base.SqueezeExcite(out_ch, out_ch, se_ratio)