    def apply(self,
              x,
              *,
              train,
              num_classes,
              block_class=BottleneckResNetImageNetBlock,
              stage_sizes,
              width_factor=1,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True,
              softplus_scale=None,
              no_head=False,
              zero_inits=True):
        """Construct ResNet V1 with `num_classes` outputs."""
        self._stage_sizes = stage_sizes
        if std_penalty_mult > 0:
            raise NotImplementedError(
                'std_penalty_mult not supported for ResNetImageNet')

        width = 64 * width_factor

        # Root block.
        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)
        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        x = conv(x,
                 width,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 name='init_conv')
        x = norm(x, name='init_bn')

        if compensate_padding:
            # NOTE: this leads to lower performance.
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
        else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Stages.
        for i, stage_size in enumerate(stage_sizes):
            x = ResNetStage(
                x,
                stage_size,
                filters=width * 2**i,
                block_class=block_class,
                first_block_strides=(1, 1) if i == 0 else (2, 2),
                train=train,
                name=f'stage{i + 1}',
                conv=conv,
                norm=norm,
                activation_f=activation_f,
                use_residual=use_residual,
                zero_inits=zero_inits,
            )

        if not no_head:
            # Head.
            x = jnp.mean(x, axis=(1, 2))
            x = nn.Dense(x,
                         num_classes,
                         kernel_init=nn.initializers.zeros
                         if zero_inits else nn.initializers.lecun_normal(),
                         name='head')
        return x, 0, {}
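    # A minimal usage sketch of this old-style functional flax.nn module
    # (assumptions: `apply` lives on a flax.deprecated.nn.Module subclass,
    # here called `ResNetImageNet` as a hypothetical name, and
    # stage_sizes=[3, 4, 6, 3] gives the usual ResNet-50 layout with stage
    # widths 64, 128, 256, 512):
    #
    #   model_def = ResNetImageNet.partial(
    #       num_classes=1000, stage_sizes=[3, 4, 6, 3], train=False)
    #   _, params = model_def.init_by_shape(
    #       jax.random.PRNGKey(0), [((1, 224, 224, 3), jnp.float32)])
    #   logits, penalty, _ = model_def.call(params, images)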
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        assert (pyramid_depth - 2) % 9 == 0

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)
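        # (Per the ShakeDrop paper, arXiv:1802.02375: alpha scales the residual
        # branch on the forward pass and beta on the backward pass when a block
        # is dropped, while mask_prob sets the final-layer drop rate; the
        # per-layer value returned by _calc_shakedrop_mask_prob is presumably
        # the usual linear decay rule, roughly p_l = 1 - (l / L) * mask_prob.)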

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks
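        # (Worked example with the defaults pyramid_depth=272, pyramid_alpha=200:
        # blocks_per_group = (272 - 2) // 9 = 30, total_blocks = 90, and
        # delta_channels = 200 / 90 ~= 2.22, so the channel count grows from 16
        # by about 2.22 per block, ending near 16 + 200 = 216 at the last block.)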

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='init_bn')

        layer_num = 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels), (1, 1),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        assert layer_num - 1 == total_blocks
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='final_bn')
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              no_head=False,
              compensate_padding=True,
              softplus_scale=None):

        penalty = 0

        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)
        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        x = conv(x,
                 16 * channel_multiplier, (3, 3),
                 padding='SAME',
                 name='init_conv')
        x, g_penalty = WideResnetGroup(x,
                                       blocks_per_group,
                                       16 * channel_multiplier,
                                       dropout_rate=dropout_rate,
                                       normalization=normalization,
                                       activation_f=activation_f,
                                       std_penalty_mult=std_penalty_mult,
                                       use_residual=use_residual,
                                       train=train,
                                       bias_scale=bias_scale,
                                       weight_norm=weight_norm)
        penalty += g_penalty
        x, g_penalty = WideResnetGroup(x,
                                       blocks_per_group,
                                       32 * channel_multiplier, (2, 2),
                                       dropout_rate=dropout_rate,
                                       normalization=normalization,
                                       activation_f=activation_f,
                                       std_penalty_mult=std_penalty_mult,
                                       use_residual=use_residual,
                                       train=train,
                                       bias_scale=bias_scale,
                                       weight_norm=weight_norm)
        penalty += g_penalty
        x, g_penalty = WideResnetGroup(x,
                                       blocks_per_group,
                                       64 * channel_multiplier, (2, 2),
                                       dropout_rate=dropout_rate,
                                       normalization=normalization,
                                       activation_f=activation_f,
                                       std_penalty_mult=std_penalty_mult,
                                       use_residual=use_residual,
                                       train=train,
                                       bias_scale=bias_scale,
                                       weight_norm=weight_norm)
        penalty += g_penalty

        x = norm(x, name='final_norm')
        if std_penalty_mult > 0:
            penalty += std_penalty(x)
        if not no_head:
            x = activation_f(x, features=x.shape[-1])
            x = nn.avg_pool(x, (8, 8))
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(x, num_outputs)
        return x, penalty, {}
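    # (A sizing note under the usual WideResNet convention: with three groups of
    # two-conv blocks the total depth is 6 * blocks_per_group + 4, so e.g.
    # blocks_per_group=4 with channel_multiplier=10 corresponds to the standard
    # WRN-28-10.)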
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      num_outputs: Dimension of the output of the model (i.e. the number of
        classes for a classification problem).
      pyramid_alpha: See paper.
      pyramid_depth: See paper.
      train: If False, uses the moving averages for the batch norm statistics;
        otherwise uses statistics computed on the current batch.

    Returns:
      The output of the PyramidNet model, a tensor of shape
        [batch_size, num_outputs].
    """
        assert (pyramid_depth - 2) % 9 == 0

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn)
        x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

        layer_num = 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(round(num_channels)), (1, 1),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(round(num_channels)),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(round(num_channels)),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        assert layer_num - 1 == total_blocks
        x = utils.activation(x, train=train, name='final_bn')
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
        return x