def apply(self,
          x,
          *,
          train,
          num_classes,
          block_class=BottleneckResNetImageNetBlock,
          stage_sizes,
          width_factor=1,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          bias_scale=0.0,
          weight_norm='none',
          compensate_padding=True,
          softplus_scale=None,
          no_head=False,
          zero_inits=True):
  """Construct ResNet V1 with `num_classes` outputs."""
  self._stage_sizes = stage_sizes
  if std_penalty_mult > 0:
    raise NotImplementedError(
        'std_penalty_mult not supported for ResNetImageNet')

  width = 64 * width_factor

  # Root block.
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  x = conv(x, width, kernel_size=(7, 7), strides=(2, 2), name='init_conv')
  x = norm(x, name='init_bn')
  if compensate_padding:
    # NOTE: this leads to lower performance.
    x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
  else:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  # Stages.
  for i, stage_size in enumerate(stage_sizes):
    x = ResNetStage(
        x,
        stage_size,
        filters=width * 2**i,
        block_class=block_class,
        first_block_strides=(1, 1) if i == 0 else (2, 2),
        train=train,
        name=f'stage{i + 1}',
        conv=conv,
        norm=norm,
        activation_f=activation_f,
        use_residual=use_residual,
        zero_inits=zero_inits,
    )

  if not no_head:
    # Head.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        x,
        num_classes,
        kernel_init=(nn.initializers.zeros
                     if zero_inits else nn.initializers.lecun_normal()),
        name='head')

  return x, 0, {}
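

# Illustrative sketch (not part of the model code above): the per-stage widths
# implied by the stage loop are `width * 2**i` with `width = 64 * width_factor`.
# The ResNet-50-style stage count below is an assumption, not taken from this
# file, and is only here to make the arithmetic concrete.
def _example_stage_widths(width_factor=1, num_stages=4):
  """Illustrative only: per-stage filter counts produced by the loop above."""
  width = 64 * width_factor
  return [width * 2**i for i in range(num_stages)]

# e.g. _example_stage_widths() == [64, 128, 256, 512] for width_factor=1.
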
def apply(self, x, num_outputs, pyramid_alpha=200, pyramid_depth=272,
          train=True):
  """Forward pass of PyramidNet with ShakeDrop regularization."""
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params.
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size.
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915
  num_channels = 16
  # N in https://arxiv.org/abs/1610.02915
  total_blocks = blocks_per_group * 3
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='init_bn')

  layer_num = 1

  # First group: every block keeps stride (1, 1).
  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels), (1, 1), layer_mask_prob,
                            alpha_min, alpha_max, beta_min, beta_max,
                            train=train)
    layer_num += 1

  # Second group: the first block downsamples with stride (2, 2).
  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  # Third group: the first block downsamples with stride (2, 2).
  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='final_bn')
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
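

# Arithmetic check (illustrative only, not part of the model): with the default
# hyper-parameters above, the pyramidal widening schedule works out as follows.
def _example_pyramid_schedule(pyramid_alpha=200, pyramid_depth=272):
  """Returns (blocks_per_group, total_blocks, delta_channels, final_width)."""
  blocks_per_group = (pyramid_depth - 2) // 9       # 30
  total_blocks = blocks_per_group * 3               # 90
  delta_channels = pyramid_alpha / total_blocks     # ~2.22 channels added per block
  final_width = 16 + delta_channels * total_blocks  # 216.0, width fed to the last block
  return blocks_per_group, total_blocks, delta_channels, final_width
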
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          no_head=False,
          compensate_padding=True,
          softplus_scale=None):
  penalty = 0
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  x = conv(x, 16 * channel_multiplier, (3, 3), padding='SAME',
           name='init_conv')
  x, g_penalty = WideResnetGroup(
      x,
      blocks_per_group,
      16 * channel_multiplier,
      dropout_rate=dropout_rate,
      normalization=normalization,
      activation_f=activation_f,
      std_penalty_mult=std_penalty_mult,
      use_residual=use_residual,
      train=train,
      bias_scale=bias_scale,
      weight_norm=weight_norm)
  penalty += g_penalty
  x, g_penalty = WideResnetGroup(
      x,
      blocks_per_group,
      32 * channel_multiplier,
      (2, 2),
      dropout_rate=dropout_rate,
      normalization=normalization,
      activation_f=activation_f,
      std_penalty_mult=std_penalty_mult,
      use_residual=use_residual,
      train=train,
      bias_scale=bias_scale,
      weight_norm=weight_norm)
  penalty += g_penalty
  x, g_penalty = WideResnetGroup(
      x,
      blocks_per_group,
      64 * channel_multiplier,
      (2, 2),
      dropout_rate=dropout_rate,
      normalization=normalization,
      activation_f=activation_f,
      std_penalty_mult=std_penalty_mult,
      use_residual=use_residual,
      train=train,
      bias_scale=bias_scale,
      weight_norm=weight_norm)
  penalty += g_penalty
  x = norm(x, name='final_norm')
  if std_penalty_mult > 0:
    penalty += std_penalty(x)
  if not no_head:
    x = activation_f(x, features=x.shape[-1])
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs)
  return x, penalty, {}
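

# Illustrative sketch (not part of the model code above): the three groups use
# 16k, 32k and 64k channels with a (2, 2) stride entering the second and third
# group, so 32x32 inputs reach the final `nn.avg_pool(x, (8, 8))` as an 8x8 map.
# The WRN-28-10-style numbers below are an assumption, not taken from this file.
def _example_wrn_shapes(channel_multiplier=10, input_size=32):
  """Illustrative only: (group_widths, spatial_sizes) for a CIFAR-sized input."""
  group_widths = [16 * channel_multiplier, 32 * channel_multiplier,
                  64 * channel_multiplier]                        # [160, 320, 640]
  spatial_sizes = [input_size, input_size // 2, input_size // 4]  # [32, 16, 8]
  return group_widths, spatial_sizes
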
def apply(self, x, num_outputs, pyramid_alpha=200, pyramid_depth=272,
          train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, 3]
      where dim is the resolution of the image.
    num_outputs: Dimension of the output of the model (i.e., the number of
      classes for a classification problem).
    pyramid_alpha: See paper.
    pyramid_depth: See paper.
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the PyramidNet model, a tensor of shape
    [batch_size, num_outputs].
  """
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params.
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size.
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915
  num_channels = 16
  # N in https://arxiv.org/abs/1610.02915
  total_blocks = blocks_per_group * 3
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(
      x, 16, (3, 3), padding='SAME', name='init_conv', bias=False,
      kernel_init=utils.conv_kernel_init_fn)
  x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

  layer_num = 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(round(num_channels)), (1, 1),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(round(num_channels)),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(x, int(round(num_channels)),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = utils.activation(x, train=train, name='final_bn')
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
  return x
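

# `_calc_shakedrop_mask_prob` is referenced above but not defined in this
# section. The sketch below is a hypothetical stand-in, assuming the
# linear-decay schedule from the ShakeDrop paper
# (https://arxiv.org/abs/1802.02375); the helper actually used here may differ.
def _calc_shakedrop_mask_prob(layer_num, total_blocks, mask_prob):
  """Per-block probability decaying linearly from ~1.0 to 1 - mask_prob with depth."""
  return 1.0 - (float(layer_num) / total_blocks) * mask_prob
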