def apply(self,
          x,
          channels,
          strides,
          prob,
          alpha_min,
          alpha_max,
          beta_min,
          beta_max,
          train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    prob: Probability of dropping the block (see paper for details).
    alpha_min: See paper.
    alpha_max: See paper.
    beta_min: See paper.
    beta_max: See paper.
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the bottleneck block.
  """
  y = utils.activation(x, apply_relu=False, train=train, name='bn_1_pre')
  y = nn.Conv(y, channels, (1, 1), padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='1x1_conv_contract')
  y = utils.activation(y, train=train, name='bn_1_post')
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='3x3')
  y = utils.activation(y, train=train, name='bn_2')
  y = nn.Conv(y, channels * 4, (1, 1), padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='1x1_conv_expand')
  y = utils.activation(y, apply_relu=False, train=train, name='bn_3')
  if train:
    y = utils.shake_drop_train(y, prob, alpha_min, alpha_max,
                               beta_min, beta_max)
  else:
    y = utils.shake_drop_eval(y, prob, alpha_min, alpha_max)
  x = _shortcut(x, channels * 4, strides)
  return x + y
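# Illustrative sketch only (not part of this module): one common way to
# implement the shake-drop helpers used above, following Eqn 6 of
# https://arxiv.org/abs/1802.02375. The actual utils.shake_drop_train /
# utils.shake_drop_eval may obtain their RNG differently, and whether
# `mask_prob` is a keep- or drop-probability depends on the repo's convention
# (treated here as the keep-probability); the *_sketch names and the explicit
# `rng` argument are hypothetical.
def _shake_drop_train_sketch(x, mask_prob, alpha_min, alpha_max,
                             beta_min, beta_max, rng):
  bern_rng, alpha_rng, beta_rng = jax.random.split(rng, 3)
  rnd_shape = (x.shape[0], 1, 1, 1)  # one random value per example
  # b_l: 1 leaves the branch untouched, 0 applies the alpha/beta noise.
  mask = jax.random.bernoulli(bern_rng, mask_prob, rnd_shape).astype(x.dtype)
  alpha = jax.random.uniform(alpha_rng, rnd_shape, minval=alpha_min,
                             maxval=alpha_max)
  beta = jax.random.uniform(beta_rng, rnd_shape, minval=beta_min,
                            maxval=beta_max)
  forward = (mask + alpha - mask * alpha) * x   # scaling seen in the fwd pass
  backward = (mask + beta - mask * beta) * x    # scaling seen by gradients
  return backward + jax.lax.stop_gradient(forward - backward)


def _shake_drop_eval_sketch(x, mask_prob, alpha_min, alpha_max):
  # Expected value of the training-time forward scaling.
  expected_alpha = (alpha_min + alpha_max) / 2
  return (mask_prob + expected_alpha - mask_prob * expected_alpha) * x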
def apply(self,
          x,
          channels,
          strides=(1, 1),
          activate_before_residual=False,
          train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    activate_before_residual: True if the batch norm and relu should be
      applied before the residual branches out (should be True only for the
      first block of the model).
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the resnet block.
  """
  if activate_before_residual:
    x = utils.activation(x, train, name='init_bn')
    orig_x = x
  else:
    orig_x = x
  block_x = x
  if not activate_before_residual:
    block_x = utils.activation(block_x, train, name='init_bn')
  block_x = nn.Conv(block_x, channels, (3, 3), strides, padding='SAME',
                    bias=False, kernel_init=utils.conv_kernel_init_fn,
                    name='conv1')
  block_x = utils.activation(block_x, train=train, name='bn_2')
  block_x = nn.Conv(block_x, channels, (3, 3), padding='SAME', bias=False,
                    kernel_init=utils.conv_kernel_init_fn, name='conv2')
  return _output_add(block_x, orig_x)
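# Illustrative sketch only: one plausible implementation of the _output_add
# helper used above, which adds the residual branch to the shortcut and
# reconciles shape mismatches with average pooling (spatial) and zero padding
# (channels), as in the classic WideResnet identity shortcut. The actual
# _output_add in this repo may differ in details; the _sketch name is
# hypothetical.
def _output_add_sketch(block_x, orig_x):
  stride = orig_x.shape[-2] // block_x.shape[-2]
  if stride > 1:
    orig_x = nn.avg_pool(orig_x, (stride, stride), (stride, stride))
  channels_to_add = block_x.shape[-1] - orig_x.shape[-1]
  if channels_to_add > 0:
    orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0), (0, channels_to_add)])
  return block_x + orig_x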
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          train=True):
  """Implements a WideResnet module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, 3] where
      dim is the resolution of the image.
    blocks_per_group: How many resnet blocks to add to each group (should be
      4 blocks for a WRN28, and 6 for a WRN40).
    channel_multiplier: The multiplier to apply to the number of filters in
      the model (1 is classical resnet, 10 for WRN28-10, etc.).
    num_outputs: Dimension of the output of the model (i.e. number of classes
      for a classification problem).
    train: If False, will use the moving average for batch norm statistics.

  Returns:
    The output of the WideResnet, a tensor of shape [batch_size, num_classes].
  """
  first_x = x
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv',
              kernel_init=utils.conv_kernel_init_fn, bias=False)
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      activate_before_residual=True, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      train=train)
  if FLAGS.use_additional_skip_connections:
    x = _output_add(x, first_x)
  x = utils.activation(x, train=train, name='pre-pool-bn')
  x = nn.avg_pool(x, x.shape[1:3])
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
  return x
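# Illustrative usage sketch only, assuming the pre-Linen flax.nn Module API
# that this file is written against (WideResnet being the module that owns the
# apply above). The WRN28-10 / CIFAR-10 hyper-parameters and the _sketch name
# are just examples.
def _init_wideresnet_sketch(rng):
  model_def = WideResnet.partial(
      blocks_per_group=4,      # WRN28
      channel_multiplier=10,   # WRN28-10
      num_outputs=10)          # e.g. CIFAR-10
  with nn.stateful() as init_state:  # collects the batch norm statistics
    _, initial_params = model_def.init_by_shape(
        rng, [((1, 32, 32, 3), jnp.float32)])
  return nn.Model(model_def, initial_params), init_state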
def apply(self, x, channels, strides=(1, 1), train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    train: If False, will use the moving average for batch norm statistics.

  Returns:
    The output of the resnet block. Will have shape
    [batch_size, dim, dim, channels] if strides = (1, 1) or
    [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
  """
  if x.shape[-1] == channels:
    return x

  # Skip path 1
  h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID')
  h1 = nn.Conv(h1, channels // 2, (1, 1), strides=(1, 1), padding='SAME',
               bias=False, kernel_init=utils.conv_kernel_init_fn,
               name='conv_h1')

  # Skip path 2
  # The next two lines offset the "image" by one pixel on the right and one
  # down (see Shake-Shake regularization, Xavier Gastaldi for details).
  pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
  h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
  h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID')
  h2 = nn.Conv(h2, channels // 2, (1, 1), strides=(1, 1), padding='SAME',
               bias=False, kernel_init=utils.conv_kernel_init_fn,
               name='conv_h2')

  merged_branches = jnp.concatenate([h1, h2], axis=3)
  return utils.activation(merged_branches, apply_relu=False, train=train,
                          name='bn_residual')
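# Tiny worked example (illustration only, hypothetical _sketch name) of the
# pad-then-crop trick used in skip path 2 above: padding one row/column at the
# bottom/right and then dropping the first row/column makes the strided 1x1
# average pooling sample a grid offset by one pixel down and to the right
# relative to skip path 1 (when strides are (2, 2)).
def _offset_demo_sketch():
  x = jnp.arange(16.0).reshape((1, 4, 4, 1))
  pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
  shifted = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
  # x[0, :, :, 0] starts with 0, 1, 2, ... while shifted[0, :, :, 0] starts
  # with 5, 6, 7, ... (the pixel one step down and one step right), with
  # zeros filling the last row and column.
  return shifted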
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          train=True):
  """Implements a WideResnet with ShakeShake regularization module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, 3] where
      dim is the resolution of the image.
    blocks_per_group: How many resnet blocks to add to each group (should be
      4 blocks for a WRN26 as per standard shake shake implementation).
    channel_multiplier: The multiplier to apply to the number of filters in
      the model (1 is classical resnet, 6 for WRN26-2x6, etc.).
    num_outputs: Dimension of the output of the model (i.e. number of classes
      for a classification problem).
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the WideResnet with ShakeShake regularization, a tensor of
    shape [batch_size, num_classes].
  """
  x = nn.Conv(x, 16, (3, 3), padding='SAME',
              kernel_init=utils.conv_kernel_init_fn, bias=False,
              name='init_conv')
  x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 16 * channel_multiplier,
                                train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 32 * channel_multiplier,
                                (2, 2), train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 64 * channel_multiplier,
                                (2, 2), train=train)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
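# Illustrative sketch only of what the WideResnetShakeShakeGroup module used
# above most likely does: chain `blocks_per_group` shake-shake blocks (the
# block whose apply appears later in this file, referred to here as
# ShakeShakeBlock for illustration), applying the given strides only in the
# first block of the group. The actual group module may differ; the _sketch
# name is hypothetical.
def _shake_shake_group_sketch(x, blocks_per_group, channels, strides=(1, 1),
                              train=True):
  for i in range(blocks_per_group):
    x = ShakeShakeBlock(x, channels, strides if i == 0 else (1, 1),
                        train=train)
  return x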
def apply(self,
          x,
          num_outputs,
          pyramid_alpha=200,
          pyramid_depth=272,
          train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, 3] where
      dim is the resolution of the image.
    num_outputs: Dimension of the output of the model (i.e. number of classes
      for a classification problem).
    pyramid_alpha: See paper.
    pyramid_depth: See paper.
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the PyramidNet model, a tensor of shape
    [batch_size, num_classes].
  """
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915
  num_channels = 16
  total_blocks = blocks_per_group * 3  # N in https://arxiv.org/abs/1610.02915
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv', bias=False,
              kernel_init=utils.conv_kernel_init_fn)
  x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

  layer_num = 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(round(num_channels)), (1, 1),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(round(num_channels)),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(round(num_channels)),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = utils.activation(x, train=train, name='final_bn')
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
  return x
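# Illustrative sketch only of the per-layer schedule that
# _calc_shakedrop_mask_prob presumably implements: the linear decay of Eqn 4
# in https://arxiv.org/abs/1802.02375, where the deepest block ends up with
# probability mask_prob (0.5 above) and shallower blocks get values closer to
# 1. The exact convention (keep- vs drop-probability) used by the repo's
# helper may differ; the _sketch name is hypothetical.
def _calc_shakedrop_mask_prob_sketch(curr_layer, num_layers, mask_prob):
  return 1 - (float(curr_layer) / num_layers) * (1 - mask_prob)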
def apply(self, x, channels, strides=(1, 1), train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the resnet block. Will have shape
    [batch_size, dim, dim, channels] if strides = (1, 1) or
    [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
  """
  a = b = residual = x

  a = jax.nn.relu(a)
  a = nn.Conv(a, channels, (3, 3), strides, padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='conv_a_1')
  a = utils.activation(a, train=train, name='bn_a_1')
  a = nn.Conv(a, channels, (3, 3), padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='conv_a_2')
  a = utils.activation(a, apply_relu=False, train=train, name='bn_a_2')

  b = jax.nn.relu(b)
  b = nn.Conv(b, channels, (3, 3), strides, padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='conv_b_1')
  b = utils.activation(b, train=train, name='bn_b_1')
  b = nn.Conv(b, channels, (3, 3), padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='conv_b_2')
  b = utils.activation(b, apply_relu=False, train=train, name='bn_b_2')

  if train and not self.is_initializing():
    ab = utils.shake_shake_train(a, b)
  else:
    ab = utils.shake_shake_eval(a, b)

  # Apply an up projection in case of channel mismatch.
  residual = Shortcut(residual, channels, strides, train)
  return residual + ab
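# Illustrative sketch only of the shake-shake combination used above (see
# https://arxiv.org/abs/1705.07485): during training the two branches are
# mixed with one random coefficient in the forward pass and a different one in
# the backward pass (via stop_gradient); at eval time the branches are simply
# averaged. The real utils.shake_shake_train may obtain its RNG differently;
# the *_sketch names and the explicit `rng` argument are hypothetical.
def _shake_shake_train_sketch(a, b, rng):
  fwd_rng, bwd_rng = jax.random.split(rng)
  rnd_shape = (a.shape[0], 1, 1, 1)  # one coefficient per example
  alpha = jax.random.uniform(fwd_rng, rnd_shape)
  beta = jax.random.uniform(bwd_rng, rnd_shape)
  forward = alpha * a + (1 - alpha) * b    # what the forward pass computes
  backward = beta * a + (1 - beta) * b     # what the gradients flow through
  return backward + jax.lax.stop_gradient(forward - backward)


def _shake_shake_eval_sketch(a, b):
  # The training-time mixing coefficient has expectation 0.5.
  return 0.5 * (a + b)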