def __init__(self, in_ch, out_ch, expansion=2.25, se_ratio=0.5, kernel_size=3,
             group_size=8, stride=1, beta=1.0, alpha=0.2,
             which_conv=base.WSConv2D, activation=jax.nn.relu,
             stochdepth_rate=None, name=None):
    """Construct the sub-modules of a grouped-conv block with squeeze-excite.

    Builds:
      * ``conv0``: a 1x1 conv to the inner width.
      * ``conv1``: a grouped ``kernel_size`` conv that carries the stride.
      * ``conv2``: a 1x1 conv to ``out_ch``.
      * A squeeze-excite module on the inner width.
      * ``conv_shortcut``: a 1x1 shortcut conv, only when the stride or the
        channel count changes.

    Args:
      in_ch: Number of input channels.
      out_ch: Number of output channels of ``conv2`` and of the shortcut conv.
      expansion: Multiplier applied to ``in_ch`` to get the inner width. The
        width is then rounded down to a multiple of ``group_size``.
      se_ratio: Ratio passed to ``base.SqueezeExcite``.
      kernel_size: Spatial kernel size of ``conv1``.
      group_size: Number of channels per group in ``conv1``.
      stride: Stride of ``conv1``. ``stride > 1`` also enables the shortcut
        conv.
      beta: Stored on the instance; not used during construction.
      alpha: Stored on the instance; not used during construction.
      which_conv: Constructor used for every convolution in the block.
      activation: Stored on the instance; not used during construction.
      stochdepth_rate: Stochastic-depth rate. A ``base.StochDepth`` module is
        created only when the rate is strictly between 0 and 1.
      name: Optional module name forwarded to the parent constructor.

    Raises:
      ValueError: If ``int(in_ch * expansion)`` is smaller than
        ``group_size``. The width would then round down to zero groups,
        leaving ``conv1`` with zero output channels and
        ``feature_group_count=0``. Fail here, with a clear message, instead
        of deferring to whatever the conv does with zero groups.
    """
    # Round the expanded width down to a whole number of groups. Validate it
    # before touching module state, so a bad config fails fast.
    width = int(in_ch * expansion)
    groups = width // group_size
    if groups < 1:
        raise ValueError(
            f'Expanded width int(in_ch * expansion) = {width} is smaller than '
            f'group_size = {group_size}; the grouped conv would have 0 groups.')

    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    self.expansion = expansion
    self.se_ratio = se_ratio
    self.kernel_size = kernel_size
    self.activation = activation
    self.beta, self.alpha = beta, alpha
    self.groups = groups
    self.width = group_size * groups
    self.stride = stride

    # Conv 0 (typically the expansion conv).
    self.conv0 = which_conv(self.width, kernel_shape=1, padding='SAME',
                            name='conv0')
    # Grouped NxN conv; the only conv in the main path that applies the stride.
    self.conv1 = which_conv(self.width, kernel_shape=kernel_size,
                            stride=stride, padding='SAME',
                            feature_group_count=self.groups, name='conv1')
    # Conv 2, typically the projection conv back to out_ch.
    self.conv2 = which_conv(self.out_ch, kernel_shape=1, padding='SAME',
                            name='conv2')

    # Use a shortcut conv on channel change or downsample.
    self.use_projection = stride > 1 or self.in_ch != self.out_ch
    if self.use_projection:
        # Built WITHOUT a stride, so it does not downsample the shortcut.
        # NOTE(review): presumably the forward pass downsamples the shortcut
        # separately (e.g. pooling) when stride > 1 -- confirm in __call__.
        self.conv_shortcut = which_conv(self.out_ch, kernel_shape=1,
                                        padding='SAME', name='conv_shortcut')

    # Squeeze + Excite module, operating on the inner (grouped) width.
    self.se = base.SqueezeExcite(self.width, self.width, self.se_ratio)

    # Stochastic depth is enabled only for a rate strictly inside (0, 1).
    self._has_stochdepth = (stochdepth_rate is not None
                            and 0. < stochdepth_rate < 1.0)
    if self._has_stochdepth:
        self.stoch_depth = base.StochDepth(stochdepth_rate)
def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25, kernel_size=3,
             stride=1, beta=1.0, alpha=0.2, which_conv=base.WSConv2D,
             activation=jax.nn.relu, skipinit_gain=jnp.zeros,
             stochdepth_rate=None, use_se=False, se_ratio=0.25, name=None):
    """Construct the sub-modules of a 1x1 -> NxN -> 1x1 bottleneck block.

    Args:
      in_ch: Number of input channels.
      out_ch: Number of output channels of ``conv2`` and of the shortcut conv.
      bottleneck_ratio: Fraction of ``out_ch`` used as the inner width,
        i.e. ``int(out_ch * bottleneck_ratio)``.
      kernel_size: Spatial kernel size of ``conv1``.
      stride: Stride of ``conv1`` and of the shortcut conv.
      beta: Stored on the instance; not used during construction.
      alpha: Stored on the instance; not used during construction.
      which_conv: Constructor used for every convolution in the block.
      activation: Stored on the instance; not used during construction.
      skipinit_gain: Stored on the instance; not used during construction.
        NOTE(review): presumably an initializer for a residual gain used in
        the forward pass -- confirm.
      stochdepth_rate: Stochastic-depth rate. A ``base.StochDepth`` module is
        created only when the rate is strictly between 0 and 1.
      use_se: When true, a ``base.SqueezeExcite`` on ``out_ch`` is created.
      se_ratio: Ratio passed to the squeeze-excite module.
      name: Optional module name forwarded to the parent constructor.
    """
    super().__init__(name=name)

    # Hyper-parameters kept on the instance.
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.kernel_size = kernel_size
    self.activation = activation
    self.beta = beta
    self.alpha = alpha
    self.skipinit_gain = skipinit_gain
    self.use_se = use_se
    self.se_ratio = se_ratio
    self.stride = stride
    # The bottleneck width is a fraction of the OUTPUT channel count.
    self.width = int(out_ch * bottleneck_ratio)

    same = {'padding': 'SAME'}
    # 1x1 conv into the bottleneck width.
    self.conv0 = which_conv(self.width, kernel_shape=1, name='conv0', **same)
    # Ungrouped kernel_size x kernel_size conv; applies the main-path stride.
    self.conv1 = which_conv(self.width, kernel_shape=kernel_size,
                            stride=stride, name='conv1', **same)
    # 1x1 conv back out to out_ch.
    self.conv2 = which_conv(out_ch, kernel_shape=1, name='conv2', **same)

    # Strided 1x1 shortcut conv whenever the resolution or width changes.
    self.use_projection = stride > 1 or in_ch != out_ch
    if self.use_projection:
        self.conv_shortcut = which_conv(out_ch, kernel_shape=1, stride=stride,
                                        name='conv_shortcut', **same)

    # Stochastic depth only for a rate strictly inside (0, 1).
    self._has_stochdepth = (stochdepth_rate is not None
                            and 0. < stochdepth_rate < 1.0)
    if self._has_stochdepth:
        self.stoch_depth = base.StochDepth(stochdepth_rate)

    if use_se:
        self.se = base.SqueezeExcite(out_ch, out_ch, se_ratio)
def __init__(self, out_ch, stride=1, use_projection=False,
             activation=jax.nn.relu, which_norm=hk.BatchNorm,
             which_conv=hk.Conv2D, use_se=False, se_ratio=0.25, name=None):
    """Construct the sub-modules of a normalized bottleneck block.

    Three norm layers (``bn0``..``bn2``) are interleaved with three bias-free
    convs: 1x1 to ``out_ch // 4``, a strided 3x3, and 1x1 back to ``out_ch``.

    Args:
      out_ch: Number of output channels; the inner width is ``out_ch // 4``.
      stride: Stride of the 3x3 conv and of the optional shortcut conv.
      use_projection: When true, a strided 1x1 shortcut conv is created.
      activation: Stored on the instance; not used during construction.
      which_norm: Norm constructor, called as ``which_norm(name=...)``.
        NOTE(review): the bare ``hk.BatchNorm`` default has required
        ``create_scale``/``create_offset``/``decay_rate`` arguments, so it
        cannot be built with only ``name=``. Callers presumably pass a
        ``functools.partial`` -- confirm.
      which_conv: Conv constructor used for every conv, called with
        ``with_bias=False``.
      use_se: When true, a ``base.SqueezeExcite`` on ``out_ch`` is created.
      se_ratio: Ratio passed to the squeeze-excite module.
      name: Optional module name forwarded to the parent constructor.
    """
    super().__init__(name=name)

    # Configuration kept on the instance.
    self.out_ch = out_ch
    self.stride = stride
    self.use_projection = use_projection
    self.activation = activation
    self.which_norm = which_norm
    self.which_conv = which_conv
    self.use_se = use_se
    self.se_ratio = se_ratio
    # Fixed 4x bottleneck.
    self.width = out_ch // 4

    # Every conv in this block is bias-free with SAME padding.
    conv_kw = {'with_bias': False, 'padding': 'SAME'}

    self.bn0 = which_norm(name='bn0')
    self.conv0 = which_conv(self.width, kernel_shape=1, name='conv0', **conv_kw)
    self.bn1 = which_norm(name='bn1')
    self.conv1 = which_conv(self.width, kernel_shape=3, stride=stride,
                            name='conv1', **conv_kw)
    self.bn2 = which_norm(name='bn2')
    self.conv2 = which_conv(out_ch, kernel_shape=1, name='conv2', **conv_kw)

    if use_projection:
        self.conv_shortcut = which_conv(out_ch, kernel_shape=1, stride=stride,
                                        name='conv_shortcut', **conv_kw)
    if use_se:
        self.se = base.SqueezeExcite(out_ch, out_ch, se_ratio)