Example #1
    def apply(self,
              x,
              channels,
              strides,
              prob,
              alpha_min,
              alpha_max,
              beta_min,
              beta_max,
              train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      prob: Probability of dropping the block (see the ShakeDrop paper for
        details).
      alpha_min: Lower bound of the uniform noise applied to the residual
        branch in the forward pass (see the ShakeDrop paper).
      alpha_max: Upper bound of the forward-pass noise.
      beta_min: Lower bound of the uniform noise applied in the backward pass.
      beta_max: Upper bound of the backward-pass noise.
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the bottleneck block.
    """
        y = utils.activation(x, apply_relu=False, train=train, name='bn_1_pre')
        y = nn.Conv(y,
                    channels, (1, 1),
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='1x1_conv_contract')
        y = utils.activation(y, train=train, name='bn_1_post')
        y = nn.Conv(y,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='3x3')
        y = utils.activation(y, train=train, name='bn_2')
        y = nn.Conv(y,
                    channels * 4, (1, 1),
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='1x1_conv_expand')
        y = utils.activation(y, apply_relu=False, train=train, name='bn_3')

        if train:
            y = utils.shake_drop_train(y, prob, alpha_min, alpha_max, beta_min,
                                       beta_max)
        else:
            y = utils.shake_drop_eval(y, prob, alpha_min, alpha_max)

        x = _shortcut(x, channels * 4, strides)
        return x + y
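
The residual branch above is gated by utils.shake_drop_train and
utils.shake_drop_eval, which are not shown in this example. The sketch below is
a hypothetical reconstruction of what such helpers might compute, following the
ShakeDrop paper (https://arxiv.org/abs/1802.02375): at training time a
per-example Bernoulli mask b and uniform noise scale the branch by
b + (1 - b) * alpha on the forward pass and b + (1 - b) * beta on the backward
pass, and at evaluation time the branch is scaled by the expectation of that
gate. The function names (shake_drop_train_sketch, shake_drop_eval_sketch), the
explicit rng argument, and the assumption that prob is used as the Bernoulli
keep probability are all mine, not taken from the original utils module.

import jax


def shake_drop_train_sketch(rng, x, prob, alpha_min, alpha_max,
                            beta_min, beta_max):
    """Training-time shake-drop gate (sketch, not the original helper)."""
    mask_rng, alpha_rng, beta_rng = jax.random.split(rng, 3)
    noise_shape = (x.shape[0], 1, 1, 1)  # One gate per example.
    # Assumption: `prob` is treated as the probability that the branch is kept.
    mask = jax.random.bernoulli(mask_rng, prob, noise_shape).astype(x.dtype)
    alpha = jax.random.uniform(alpha_rng, noise_shape, minval=alpha_min, maxval=alpha_max)
    beta = jax.random.uniform(beta_rng, noise_shape, minval=beta_min, maxval=beta_max)
    forward_gate = mask + alpha - mask * alpha   # b + (1 - b) * alpha
    backward_gate = mask + beta - mask * beta    # b + (1 - b) * beta
    # The forward value uses forward_gate; gradients flow as if backward_gate
    # had been used.
    return x * backward_gate + jax.lax.stop_gradient(x * (forward_gate - backward_gate))


def shake_drop_eval_sketch(x, prob, alpha_min, alpha_max):
    """Evaluation-time shake-drop gate: scale by the expected training gate."""
    expected_alpha = (alpha_min + alpha_max) / 2
    return x * (prob + expected_alpha - prob * expected_alpha)
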
Example #2
    def apply(self,
              x,
              channels,
              strides=(1, 1),
              activate_before_residual=False,
              train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      activate_before_residual: True if the batch norm and relu should be
        applied before the residual branches out (should be True only for the
        first block of the model).
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the resnet block.
    """
        if activate_before_residual:
            x = utils.activation(x, train, name='init_bn')
            orig_x = x
        else:
            orig_x = x

        block_x = x
        if not activate_before_residual:
            block_x = utils.activation(block_x, train, name='init_bn')

        block_x = nn.Conv(block_x,
                          channels, (3, 3),
                          strides,
                          padding='SAME',
                          bias=False,
                          kernel_init=utils.conv_kernel_init_fn,
                          name='conv1')
        block_x = utils.activation(block_x, train=train, name='bn_2')
        block_x = nn.Conv(block_x,
                          channels, (3, 3),
                          padding='SAME',
                          bias=False,
                          kernel_init=utils.conv_kernel_init_fn,
                          name='conv2')

        return _output_add(block_x, orig_x)
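
The block ends with _output_add, a helper that is not shown here. The sketch
below is a hypothetical reconstruction of a parameter-free WideResNet shortcut
add: when the residual branch has a smaller spatial size the shortcut is
subsampled, and when it has more channels the shortcut is zero-padded before
the two are summed. The name output_add_sketch and the use of strided slicing
(the original may use average pooling instead) are assumptions.

import jax.numpy as jnp


def output_add_sketch(block_x, orig_x):
    """Adds block_x and orig_x, reconciling stride and channel mismatches (sketch)."""
    stride = orig_x.shape[-2] // block_x.shape[-2]
    if stride > 1:
        # Subsample the shortcut spatially; the real helper may average-pool.
        orig_x = orig_x[:, ::stride, ::stride, :]
    channels_to_add = block_x.shape[-1] - orig_x.shape[-1]
    if channels_to_add > 0:
        # Zero-pad the channel dimension of the shortcut.
        orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0), (0, channels_to_add)])
    return block_x + orig_x
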
Example #3
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              train=True):
        """Implements a WideResnet module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      blocks_per_group: How many resnet blocks to add to each group (should be
        4 blocks for a WRN28, and 6 for a WRN40).
      channel_multiplier: The multiplier to apply to the number of filters in
        the model (1 gives the classical ResNet width, 10 gives WRN28-10, etc.).
      num_outputs: Dimension of the output of the model (i.e. the number of
        classes for a classification problem).
      train: If False, will use the moving average for batch norm statistics.

    Returns:
      The output of the WideResnet, a tensor of shape [batch_size, num_classes].
    """
        first_x = x
        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=utils.conv_kernel_init_fn,
                    bias=False)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            activate_before_residual=True,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            train=train)
        if FLAGS.use_additional_skip_connections:
            x = _output_add(x, first_x)
        x = utils.activation(x, train=train, name='pre-pool-bn')
        x = nn.avg_pool(x, x.shape[1:3])
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
        return x
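
WideResnetGroup is not shown in these examples. The sketch below is a
hypothetical reconstruction, assuming the group simply stacks blocks_per_group
copies of the residual block from Example #2 (referred to here as
WideResnetBlock) and applies the spatial stride and the
activate_before_residual flag only in the first block; it is written against
the same pre-Linen Flax nn.Module API as the surrounding code. For WRN28-10
(blocks_per_group=4, channel_multiplier=10) the three groups above would then
use 160, 320 and 640 channels respectively.

class WideResnetGroup(nn.Module):
    """Stacks several WideResnetBlock modules (sketch, assumed structure)."""

    def apply(self,
              x,
              blocks_per_group,
              channels,
              strides=(1, 1),
              activate_before_residual=False,
              train=True):
        for i in range(blocks_per_group):
            x = WideResnetBlock(x,
                                channels,
                                strides if i == 0 else (1, 1),
                                activate_before_residual=activate_before_residual and i == 0,
                                train=train)
        return x
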
Example #4
    def apply(self, x, channels, strides=(1, 1), train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      train: If False, will use the moving average for batch norm statistics.

    Returns:
      The output of the shortcut connection. Will have shape
        [batch_size, dim, dim, channels] if strides = (1, 1) or
        [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
    """

        if x.shape[-1] == channels:
            return x

        # Skip path 1
        h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID')
        h1 = nn.Conv(h1,
                     channels // 2, (1, 1),
                     strides=(1, 1),
                     padding='SAME',
                     bias=False,
                     kernel_init=utils.conv_kernel_init_fn,
                     name='conv_h1')

        # Skip path 2
        # The next two lines offset the "image" by one pixel to the right and
        # one pixel down (see "Shake-Shake regularization", Xavier Gastaldi,
        # for details).
        pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
        h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
        h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID')
        h2 = nn.Conv(h2,
                     channels // 2, (1, 1),
                     strides=(1, 1),
                     padding='SAME',
                     bias=False,
                     kernel_init=utils.conv_kernel_init_fn,
                     name='conv_h2')
        merged_branches = jnp.concatenate([h1, h2], axis=3)
        return utils.activation(merged_branches,
                                apply_relu=False,
                                train=train,
                                name='bn_residual')
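
The padding-and-slicing trick in skip path 2 can be checked with a small
stand-alone example. The snippet below uses only jax.numpy and an arbitrary
4x4 single-channel "image" to show that, with strides=(2, 2), the two skip
paths end up pooling pixel grids that are offset by one pixel from each other,
which is what gives the second path a different view of the input.

import jax.numpy as jnp

x = jnp.arange(16.0).reshape(1, 4, 4, 1)      # A toy 4x4 single-channel image.
pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
shifted = jnp.pad(x, pad_arr)[:, 1:, 1:, :]   # Same shape, offset by one pixel.

# With a (1, 1) window, stride (2, 2) and VALID padding, average pooling just
# picks every other pixel, so the two paths sample these grids:
print(x[0, :, :, 0][::2, ::2])        # [[ 0.  2.] [ 8. 10.]] -> pixels (0,0), (0,2), (2,0), (2,2)
print(shifted[0, :, :, 0][::2, ::2])  # [[ 5.  7.] [13. 15.]] -> pixels (1,1), (1,3), (3,1), (3,3)
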
Example #5
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              train=True):
        """Implements a WideResnet with ShakeShake regularization module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      blocks_per_group: How many resnet blocks to add to each group (should be
        4 blocks for a WRN26, as in the standard shake-shake implementation).
      channel_multiplier: The multiplier to apply to the number of filters in
        the model (1 gives the classical ResNet width, 6 gives WRN26-2x6, etc.).
      num_outputs: Dimension of the output of the model (i.e. the number of
        classes for a classification problem).
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the WideResnet with ShakeShake regularization, a tensor of
      shape [batch_size, num_classes].
    """
        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    kernel_init=utils.conv_kernel_init_fn,
                    bias=False,
                    name='init_conv')
        x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      16 * channel_multiplier,
                                      train=train)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      32 * channel_multiplier, (2, 2),
                                      train=train)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      64 * channel_multiplier, (2, 2),
                                      train=train)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
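
As with Example #3, the group module is not shown. The sketch below is a
hypothetical reconstruction of WideResnetShakeShakeGroup, assuming it stacks
blocks_per_group copies of the two-branch shake-shake block from Example #7
below (referred to here as ShakeShakeBlock) and applies the spatial stride only
in the first block; it is written against the same pre-Linen Flax nn.Module API
as the surrounding code.

class WideResnetShakeShakeGroup(nn.Module):
    """Stacks several ShakeShakeBlock modules (sketch, assumed structure)."""

    def apply(self, x, blocks_per_group, channels, strides=(1, 1), train=True):
        for i in range(blocks_per_group):
            x = ShakeShakeBlock(x,
                                channels,
                                strides if i == 0 else (1, 1),
                                train=train)
        return x
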
Example #6
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      num_outputs: Dimension of the output of the model (i.e. the number of
        classes for a classification problem).
      pyramid_alpha: Widening factor alpha from the PyramidNet paper (the total
        number of channels added across all blocks).
      pyramid_depth: Total depth of the network (see the PyramidNet paper,
        https://arxiv.org/abs/1610.02915).
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the PyramidNet model, a tensor of shape
        [batch_size, num_classes].
    """
        assert (pyramid_depth - 2) % 9 == 0

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        # Per-block channel increment: alpha / N (Eqn 2 in the PyramidNet paper).
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn)
        x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

        layer_num = 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(round(num_channels)), (1, 1),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(round(num_channels)),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(round(num_channels)),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        assert layer_num - 1 == total_blocks
        x = utils.activation(x, train=train, name='final_bn')
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
        return x
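
Two of the quantities above are worth spelling out. With the defaults
pyramid_alpha=200 and pyramid_depth=272, blocks_per_group = 30,
total_blocks = 90 and delta_channels ≈ 2.22, so the bottleneck width grows
linearly from 16 to 216 channels (864 features after the 4x expansion) over the
90 blocks. The per-layer gate probability comes from _calc_shakedrop_mask_prob,
which is not shown; the sketch below assumes the linear decay rule of the
ShakeDrop paper (https://arxiv.org/abs/1802.02375), where the first block is
gated with probability close to 1 and the final block with probability
mask_prob (0.5 above). The function name and the exact formula are assumptions.

def calc_shakedrop_mask_prob_sketch(curr_layer, total_layers, mask_prob):
    """Linearly decaying gate probability per layer (sketch, assumed formula)."""
    return 1.0 - (float(curr_layer) / total_layers) * (1.0 - mask_prob)
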
Example #7
    def apply(self, x, channels, strides=(1, 1), train=True):
        """Implements the forward pass in the module.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.

    Returns:
      The output of the resnet block. Will have shape
        [batch_size, dim, dim, channels] if strides = (1, 1) or
        [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
    """
        a = b = residual = x

        a = jax.nn.relu(a)
        a = nn.Conv(a,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='conv_a_1')
        a = utils.activation(a, train=train, name='bn_a_1')
        a = nn.Conv(a,
                    channels, (3, 3),
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='conv_a_2')
        a = utils.activation(a, apply_relu=False, train=train, name='bn_a_2')

        b = jax.nn.relu(b)
        b = nn.Conv(b,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='conv_b_1')
        b = utils.activation(b, train=train, name='bn_b_1')
        b = nn.Conv(b,
                    channels, (3, 3),
                    padding='SAME',
                    bias=False,
                    kernel_init=utils.conv_kernel_init_fn,
                    name='conv_b_2')
        b = utils.activation(b, apply_relu=False, train=train, name='bn_b_2')

        if train and not self.is_initializing():
            ab = utils.shake_shake_train(a, b)
        else:
            ab = utils.shake_shake_eval(a, b)

        # Apply an up projection in case of channel mismatch.
        residual = Shortcut(residual, channels, strides, train)

        return residual + ab
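
The branch mixing above relies on utils.shake_shake_train and
utils.shake_shake_eval, which are not shown. The sketch below is a hypothetical
reconstruction following "Shake-Shake regularization"
(https://arxiv.org/abs/1705.07485): at training time the two branches are mixed
with a random per-example coefficient that differs between the forward and
backward passes, and at evaluation time they are simply averaged. The function
names and the explicit rng argument are mine; the original helpers may obtain
randomness differently (e.g. through Flax's stochastic rng mechanism).

import jax


def shake_shake_train_sketch(rng, xa, xb):
    """Training-time shake-shake mixing (sketch, not the original helper)."""
    forward_rng, backward_rng = jax.random.split(rng)
    gate_shape = (xa.shape[0], 1, 1, 1)  # One mixing coefficient per example.
    gate_forward = jax.random.uniform(forward_rng, gate_shape)
    gate_backward = jax.random.uniform(backward_rng, gate_shape)
    x_forward = xa * gate_forward + xb * (1.0 - gate_forward)
    x_backward = xa * gate_backward + xb * (1.0 - gate_backward)
    # The forward value uses gate_forward; gradients flow as if gate_backward
    # had been used.
    return x_backward + jax.lax.stop_gradient(x_forward - x_backward)


def shake_shake_eval_sketch(xa, xb):
    """Evaluation-time shake-shake mixing: average the two branches."""
    return (xa + xb) * 0.5
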