def __init__(self, in_ch, out_ch, expansion=2.25, se_ratio=0.5,
             kernel_size=3, group_size=8, stride=1,
             beta=1.0, alpha=0.2,
             which_conv=base.WSConv2D, activation=jax.nn.relu,
             stochdepth_rate=None, name=None):
    """Builds the convolutions and auxiliary modules of the block.

    Args:
      in_ch: Number of input channels; multiplied by `expansion` to obtain
        the inner width before rounding.
      out_ch: Output channels of `conv2` and of the shortcut projection.
      expansion: Multiplier applied to `in_ch` to get the un-rounded width.
      se_ratio: Third positional argument forwarded to `base.SqueezeExcite`.
      kernel_size: `kernel_shape` of the grouped convolution `conv1`.
      group_size: Channels per group; the inner width is rounded down to a
        multiple of this value.
      stride: Stride of `conv1`. A stride above 1 also enables the shortcut
        projection.
      beta: Stored on the instance; not used by the constructor itself.
      alpha: Stored on the instance; not used by the constructor itself.
      which_conv: Callable used to construct every convolution.
      activation: Stored on the instance; not used by the constructor itself.
      stochdepth_rate: A `base.StochDepth` module is created only when this
        is not None and lies strictly between 0 and 1.
      name: Module name passed to the parent constructor.
    """
    super().__init__(name=name)
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.expansion = expansion
    self.se_ratio = se_ratio
    self.kernel_size = kernel_size
    self.activation = activation
    self.beta = beta
    self.alpha = alpha
    self.stride = stride

    # Round the expanded width down to a whole number of groups.
    unrounded_width = int(self.in_ch * expansion)
    self.groups = unrounded_width // group_size
    self.width = group_size * self.groups

    # Keyword arguments shared by every 1x1 convolution in this block.
    pointwise = dict(kernel_shape=1, padding='SAME')

    # 1x1 conv mapping the input to the inner width.
    self.conv0 = which_conv(self.width, name='conv0', **pointwise)
    # Grouped NxN conv; the only main-path conv that carries the stride.
    self.conv1 = which_conv(
        self.width,
        kernel_shape=kernel_size,
        stride=stride,
        padding='SAME',
        feature_group_count=self.groups,
        name='conv1',
    )
    # 1x1 conv mapping the inner width to the output channels.
    self.conv2 = which_conv(self.out_ch, name='conv2', **pointwise)

    # A shortcut conv is needed on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
        # NOTE(review): no stride is passed here; presumably the shortcut is
        # downsampled elsewhere (e.g. in the forward pass) -- confirm.
        self.conv_shortcut = which_conv(
            self.out_ch, name='conv_shortcut', **pointwise)

    # Squeeze + Excite module operating on the inner width.
    self.se = base.SqueezeExcite(self.width, self.width, self.se_ratio)

    # Stochastic depth is active only for rates strictly inside (0, 1).
    self._has_stochdepth = (
        stochdepth_rate is not None and 0. < stochdepth_rate < 1.0)
    if self._has_stochdepth:
        self.stoch_depth = base.StochDepth(stochdepth_rate)
def __init__(self, in_ch, out_ch, num_blocks, bottleneck_ratio=0.25,
             kernel_size=3, stride=1,
             which_conv=hk.Conv2D, activation=jax.nn.relu,
             stochdepth_rate=None, name=None):
    """Builds the convolutions of the bottleneck block with custom inits.

    Args:
      in_ch: Number of input channels; compared with `out_ch` to decide
        whether a shortcut projection is needed.
      out_ch: Output channels of `conv2` and of the shortcut projection;
        also scaled by `bottleneck_ratio` to obtain the inner width.
      num_blocks: Enters the weight-init standard deviation of `conv0` and
        `conv1` as a factor of `num_blocks**(-0.25)`.
      bottleneck_ratio: Multiplier applied to `out_ch` to get the width.
      kernel_size: `kernel_shape` of `conv1`; also enters its init stddev.
      stride: Stride of `conv1` and of the shortcut projection. A stride
        above 1 also enables the shortcut projection.
      which_conv: Callable used to construct every convolution; it must
        accept a `w_init` keyword.
      activation: Stored on the instance; not used by the constructor itself.
      stochdepth_rate: A `base.StochDepth` module is created only when this
        is not None and lies strictly between 0 and 1.
      name: Module name passed to the parent constructor.
    """
    super().__init__(name=name)
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.kernel_size = kernel_size
    self.activation = activation
    self.stride = stride

    # Bottleneck width.
    self.width = int(self.out_ch * bottleneck_ratio)

    # Keyword arguments shared by every 1x1 convolution in this block.
    pointwise = dict(kernel_shape=1, padding='SAME')

    # 1x1 conv mapping the input to the bottleneck width.
    conv0_stddev = ((2 / self.width)**0.5) * (num_blocks**(-0.25))
    self.conv0 = which_conv(
        self.width,
        name='conv0',
        w_init=hk.initializers.RandomNormal(stddev=conv0_stddev),
        **pointwise,
    )

    # NxN conv (no feature groups are configured); carries the stride.
    conv1_stddev = (
        ((2 / (self.width * (kernel_size**2)))**0.5)
        * (num_blocks**(-0.25)))
    self.conv1 = which_conv(
        self.width,
        kernel_shape=kernel_size,
        stride=stride,
        padding='SAME',
        name='conv1',
        w_init=hk.initializers.RandomNormal(stddev=conv1_stddev),
    )

    # 1x1 conv to the output channels, with a constant-zero initializer.
    self.conv2 = which_conv(
        self.out_ch,
        name='conv2',
        w_init=hk.initializers.Constant(0),
        **pointwise,
    )

    # A shortcut conv is needed on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
        shortcut_stddev = (2 / self.out_ch)**0.5
        self.conv_shortcut = which_conv(
            self.out_ch,
            stride=stride,
            name='conv_shortcut',
            w_init=hk.initializers.RandomNormal(stddev=shortcut_stddev),
            **pointwise,
        )

    # Stochastic depth is active only for rates strictly inside (0, 1).
    self._has_stochdepth = (
        stochdepth_rate is not None and 0. < stochdepth_rate < 1.0)
    if self._has_stochdepth:
        self.stoch_depth = base.StochDepth(stochdepth_rate)
def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25,
             kernel_size=3, stride=1,
             beta=1.0, alpha=0.2,
             which_conv=base.WSConv2D, activation=jax.nn.relu,
             skipinit_gain=jnp.zeros,
             stochdepth_rate=None,
             use_se=False, se_ratio=0.25,
             name=None):
    """Builds the convolutions and optional modules of the bottleneck block.

    Args:
      in_ch: Number of input channels; compared with `out_ch` to decide
        whether a shortcut projection is needed.
      out_ch: Output channels of `conv2` and of the shortcut projection;
        also scaled by `bottleneck_ratio` to obtain the inner width.
      bottleneck_ratio: Multiplier applied to `out_ch` to get the width.
      kernel_size: `kernel_shape` of `conv1`.
      stride: Stride of `conv1` and of the shortcut projection. A stride
        above 1 also enables the shortcut projection.
      beta: Stored on the instance; not used by the constructor itself.
      alpha: Stored on the instance; not used by the constructor itself.
      which_conv: Callable used to construct every convolution.
      activation: Stored on the instance; not used by the constructor itself.
      skipinit_gain: Stored on the instance; not used by the constructor
        itself.
      stochdepth_rate: A `base.StochDepth` module is created only when this
        is not None and lies strictly between 0 and 1.
      use_se: When true, a `base.SqueezeExcite` module is created on the
        output channels.
      se_ratio: Third positional argument forwarded to `base.SqueezeExcite`.
      name: Module name passed to the parent constructor.
    """
    super().__init__(name=name)
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.kernel_size = kernel_size
    self.activation = activation
    self.beta = beta
    self.alpha = alpha
    self.skipinit_gain = skipinit_gain
    self.use_se = use_se
    self.se_ratio = se_ratio
    self.stride = stride

    # Bottleneck width.
    self.width = int(self.out_ch * bottleneck_ratio)

    # Keyword arguments shared by every 1x1 convolution in this block.
    pointwise = dict(kernel_shape=1, padding='SAME')

    # 1x1 conv mapping the input to the bottleneck width.
    self.conv0 = which_conv(self.width, name='conv0', **pointwise)
    # NxN conv (no feature groups are configured); carries the stride.
    self.conv1 = which_conv(
        self.width,
        kernel_shape=kernel_size,
        stride=stride,
        padding='SAME',
        name='conv1',
    )
    # 1x1 conv mapping the bottleneck width to the output channels.
    self.conv2 = which_conv(self.out_ch, name='conv2', **pointwise)

    # A shortcut conv is needed on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
        self.conv_shortcut = which_conv(
            self.out_ch, stride=stride, name='conv_shortcut', **pointwise)

    # Stochastic depth is active only for rates strictly inside (0, 1).
    self._has_stochdepth = (
        stochdepth_rate is not None and 0. < stochdepth_rate < 1.0)
    if self._has_stochdepth:
        self.stoch_depth = base.StochDepth(stochdepth_rate)

    # Optional Squeeze + Excite module on the output channels.
    if self.use_se:
        self.se = base.SqueezeExcite(self.out_ch, self.out_ch, self.se_ratio)