Exemplo n.º 1
0
    def apply(self,
              x: jnp.ndarray,
              channels: int,
              strides: Tuple[int, int] = (1, 1),
              activate_before_residual: bool = False,
              train: bool = True) -> jnp.ndarray:
        """Runs one WideResnet block: two 3x3 convolutions plus a shortcut.

        Args:
          x: Input to the module. Should have shape
            [batch_size, dim, dim, features] where dim is the resolution
            (width and height if the input is an image).
          channels: Number of filters used by both 3x3 convolutions.
          strides: Strides of the first 3x3 convolution ('conv1').
          activate_before_residual: If True, the 'init_bn' activation (batch
            norm and relu) is applied to the input before it branches, so the
            shortcut also carries the activated input (should be True only for
            the first block of the model). If False, only the residual branch
            goes through 'init_bn' and the shortcut is the raw input.
          train: If False, will use the moving average for batch norm
            statistics. Else, will use statistics computed on the batch.

        Returns:
          The output of the resnet block: the residual branch combined with the
          shortcut by `_output_add`.
        """
        if activate_before_residual:
            # Activate once; both the shortcut and the residual branch start
            # from the activated tensor.
            x = utils.activation(x, train, name='init_bn')
            shortcut = x
            residual = x
        else:
            # Shortcut keeps the un-activated input; only the residual branch
            # is activated.
            shortcut = x
            residual = utils.activation(x, train, name='init_bn')

        conv_kwargs = dict(padding='SAME',
                           bias=False,
                           kernel_init=utils.conv_kernel_init_fn)
        residual = nn.Conv(residual, channels, (3, 3), strides,
                           name='conv1', **conv_kwargs)
        residual = utils.activation(residual, train=train, name='bn_2')
        residual = nn.Conv(residual, channels, (3, 3),
                           name='conv2', **conv_kwargs)
        return _output_add(residual, shortcut)
Exemplo n.º 2
0
    def apply(self,
              x: jnp.ndarray,
              blocks_per_group: int,
              channel_multiplier: int,
              num_outputs: int,
              train: bool = True,
              true_gradient: bool = False) -> jnp.ndarray:
        """Builds a WideResnet with ShakeShake regularization.

        Stem conv + batch norm, three WideResnetShakeShakeGroup stages
        (16/32/64 times `channel_multiplier` filters; the last two use
        strides (2, 2)), relu, global average pooling and a dense head.

        Args:
          x: Input to the module. Should have shape [batch_size, dim, dim, 3]
            where dim is the resolution of the image.
          blocks_per_group: How many resnet blocks to add to each group (should
            be 4 blocks for a WRN26 as per standard shake shake
            implementation).
          channel_multiplier: The multiplier to apply to the number of filters
            in the model (1 is classical resnet, 6 for WRN26-2x6, etc...).
          num_outputs: Dimension of the output of the model (ie number of
            classes for a classification problem).
          train: If False, will use the moving average for batch norm
            statistics. Else, will use statistics computed on the batch.
          true_gradient: If true, the same mixing parameter will be used for
            the forward and backward pass (see paper for more details).

        Returns:
          The output of the WideResnet with ShakeShake regularization, a tensor
          of shape [batch_size, num_outputs].
        """

        def shake_shake_group(inputs, base_channels, *strides):
            # `strides` is forwarded only when given, so the first stage keeps
            # relying on WideResnetShakeShakeGroup's own default.
            return WideResnetShakeShakeGroup(inputs,
                                             blocks_per_group,
                                             base_channels * channel_multiplier,
                                             *strides,
                                             train=train,
                                             true_gradient=true_gradient)

        stem = nn.Conv(x,
                       16, (3, 3),
                       padding='SAME',
                       kernel_init=utils.conv_kernel_init_fn,
                       bias=False,
                       name='init_conv')
        features = utils.activation(stem, apply_relu=False, train=train,
                                    name='init_bn')
        features = shake_shake_group(features, 16)
        features = shake_shake_group(features, 32, (2, 2))
        features = shake_shake_group(features, 64, (2, 2))
        features = jax.nn.relu(features)

        # Pooling window spans the whole spatial extent -> global average pool.
        pooled = nn.avg_pool(features, features.shape[1:3])
        flat = pooled.reshape((pooled.shape[0], -1))
        return nn.Dense(flat, num_outputs,
                        kernel_init=utils.dense_layer_init_fn)
Exemplo n.º 3
0
    def apply(self,
              x: jnp.ndarray,
              channels: int,
              strides: Tuple[int, int] = (1, 1),
              train: bool = True) -> jnp.ndarray:
        """Implements the shortcut (skip) connection of a shake-shake block.

        If the input already has `channels` channels and no subsampling is
        requested, it is returned unchanged. Otherwise it is projected by two
        paths: a 1x1 average pool (subsampling by `strides`) followed by a 1x1
        convolution. The second path first shifts the input by one pixel. The
        two results are concatenated on the channel axis and batch-normalized
        (no relu).

        Args:
          x: Input to the module. Should have shape
            [batch_size, dim, dim, features] where dim is the resolution
            (width and height if the input is an image).
          channels: How many channels the output should have.
          strides: Strides for the 1x1 average pooling of both paths.
          train: If False, will use the moving average for batch norm
            statistics. Else, will use statistics computed on the batch.

        Returns:
          The shortcut tensor. Will have shape
            [batch_size, dim, dim, channels] if strides = (1, 1) or
            [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
        """
        # Identity only when the input already has the promised output shape.
        # A channel match alone is not enough if the block also subsamples;
        # otherwise the shortcut would keep the full resolution and could not
        # be added to the strided residual branch.
        if x.shape[-1] == channels and tuple(strides) == (1, 1):
            return x

        # Split the output channels between the two paths. The second path
        # takes the remainder so that an odd `channels` still concatenates to
        # exactly `channels` (for even values both halves are equal).
        h1_channels = channels // 2
        h2_channels = channels - h1_channels

        # Skip path 1
        h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID')
        h1 = nn.Conv(h1,
                     h1_channels, (1, 1),
                     strides=(1, 1),
                     padding='SAME',
                     bias=False,
                     kernel_init=utils.conv_kernel_init_fn,
                     name='conv_h1')

        # Skip path 2
        # The next two lines offset the "image" by one pixel on the right and
        # one down (see Shake-Shake regularization, Xavier Gastaldi for
        # details).
        pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
        h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
        h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID')
        h2 = nn.Conv(h2,
                     h2_channels, (1, 1),
                     strides=(1, 1),
                     padding='SAME',
                     bias=False,
                     kernel_init=utils.conv_kernel_init_fn,
                     name='conv_h2')
        merged_branches = jnp.concatenate([h1, h2], axis=3)
        return utils.activation(merged_branches,
                                apply_relu=False,
                                train=train,
                                name='bn_residual')
Exemplo n.º 4
0
    def apply(self,
              x: jnp.ndarray,
              blocks_per_group: int,
              channel_multiplier: int,
              num_outputs: int,
              train: bool = True) -> jnp.ndarray:
        """Builds a WideResnet.

        Stem conv, three WideResnetGroup stages (16/32/64 times
        `channel_multiplier` filters; the last two use strides (2, 2)), an
        optional long skip connection from the input, a final activation,
        global average pooling and a dense head.

        Args:
          x: Input to the module. Should have shape [batch_size, dim, dim, 3]
            where dim is the resolution of the image.
          blocks_per_group: How many resnet blocks to add to each group (should
            be 4 blocks for a WRN28, and 6 for a WRN40).
          channel_multiplier: The multiplier to apply to the number of filters
            in the model (1 is classical resnet, 10 for WRN28-10, etc...).
          num_outputs: Dimension of the output of the model (ie number of
            classes for a classification problem).
          train: If False, will use the moving average for batch norm
            statistics. Else, will use statistics computed on the batch.

        Returns:
          The output of the WideResnet, a tensor of shape
          [batch_size, num_outputs].
        """
        model_input = x
        features = nn.Conv(model_input,
                           16, (3, 3),
                           padding='SAME',
                           name='init_conv',
                           kernel_init=utils.conv_kernel_init_fn,
                           bias=False)
        # No activation after the stem: the first stage is built with
        # activate_before_residual=True (presumably forwarded to its first
        # block, which then activates its input itself).
        features = WideResnetGroup(features,
                                   blocks_per_group,
                                   16 * channel_multiplier,
                                   activate_before_residual=True,
                                   train=train)
        for base_width in (32, 64):
            features = WideResnetGroup(features,
                                       blocks_per_group,
                                       base_width * channel_multiplier,
                                       (2, 2),
                                       train=train)

        if FLAGS.use_additional_skip_connections:
            # Long skip from the raw model input to the last feature map.
            features = _output_add(features, model_input)

        features = utils.activation(features, train=train, name='pre-pool-bn')
        # Window covers the full spatial extent -> global average pooling.
        features = nn.avg_pool(features, features.shape[1:3])
        features = features.reshape((features.shape[0], -1))
        return nn.Dense(features, num_outputs,
                        kernel_init=utils.dense_layer_init_fn)
Exemplo n.º 5
0
  def apply(self,
            x: jnp.ndarray,
            channels: int,
            strides: Tuple[int, int],
            prob: float,
            alpha_min: float,
            alpha_max: float,
            beta_min: float,
            beta_max: float,
            train: bool = True,
            true_gradient: bool = False) -> jnp.ndarray:
    """Runs a ShakeDrop bottleneck block.

    The residual branch is BN -> 1x1 conv (contract) -> BN/relu -> 3x3 conv
    (strided) -> BN/relu -> 1x1 conv (expand to 4 * channels) -> BN. It is
    then perturbed by ShakeDrop: the stochastic version during training (and
    not at init), the deterministic eval version otherwise. Finally it is
    added to `_shortcut(x, 4 * channels, strides)`.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: Number of filters of the 1x1 contract and 3x3 convolutions;
        the block outputs 4 * channels.
      strides: Strides of the 3x3 convolution (also forwarded to `_shortcut`).
      prob: Probability of dropping the block (see paper for details).
      alpha_min: See paper.
      alpha_max: See paper.
      beta_min: See paper.
      beta_max: See paper.
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.
      true_gradient: If true, the same mixing parameter will be used for the
        forward and backward pass (see paper for more details).

    Returns:
      The output of the bottleneck block.
    """
    expanded_channels = channels * 4
    conv_kwargs = dict(padding='SAME',
                       bias=False,
                       kernel_init=utils.conv_kernel_init_fn)

    branch = utils.activation(x, apply_relu=False, train=train,
                              name='bn_1_pre')
    branch = nn.Conv(branch, channels, (1, 1),
                     name='1x1_conv_contract', **conv_kwargs)
    branch = utils.activation(branch, train=train, name='bn_1_post')
    branch = nn.Conv(branch, channels, (3, 3), strides,
                     name='3x3', **conv_kwargs)
    branch = utils.activation(branch, train=train, name='bn_2')
    branch = nn.Conv(branch, expanded_channels, (1, 1),
                     name='1x1_conv_expand', **conv_kwargs)
    branch = utils.activation(branch, apply_relu=False, train=train,
                              name='bn_3')

    # The deterministic eval form is used at evaluation time and while the
    # module is initializing; the stochastic one only for real training steps.
    if not train or self.is_initializing():
      branch = utils.shake_drop_eval(branch, prob, alpha_min, alpha_max)
    else:
      branch = utils.shake_drop_train(branch, prob, alpha_min, alpha_max,
                                      beta_min, beta_max,
                                      true_gradient=true_gradient)

    skip = _shortcut(x, expanded_channels, strides)
    return skip + branch
Exemplo n.º 6
0
  def apply(self,
            x: jnp.ndarray,
            num_outputs: int,
            pyramid_alpha: int = 200,
            pyramid_depth: int = 272,
            train: bool = True,
            true_gradient: bool = False) -> jnp.ndarray:
    """Implements the forward pass of a PyramidNet with ShakeDrop.

    Stem conv + batch norm, then three groups of BottleneckShakeDrop blocks.
    The channel count grows linearly by `pyramid_alpha / total_blocks` per
    block, and the first block of the second and third groups uses strides
    (2, 2). A final activation, global average pooling and a dense head
    follow.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, 3]
        where dim is the resolution of the image.
      num_outputs: Dimension of the output of the model (ie number of classes
        for a classification problem).
      pyramid_alpha: Total channel increase spread evenly over all blocks
        (see paper).
      pyramid_depth: Network depth (see paper). Must be greater than 2 and
        satisfy (pyramid_depth - 2) % 9 == 0; each group then has
        (pyramid_depth - 2) // 9 blocks.
      train: If False, will use the moving average for batch norm statistics.
        Else, will use statistics computed on the batch.
      true_gradient: If true, the same mixing parameter will be used for the
        forward and backward pass (see paper for more details).

    Returns:
      The output of the PyramidNet model, a tensor of shape
        [batch_size, num_outputs].
    """
    # pyramid_depth == 2 would give zero blocks and divide by zero below.
    assert pyramid_depth > 2 and (pyramid_depth - 2) % 9 == 0, (
        'pyramid_depth must be > 2 with (pyramid_depth - 2) % 9 == 0, got '
        f'{pyramid_depth}')

    # Shake-drop hyper-params
    mask_prob = 0.5
    alpha_min, alpha_max = (-1.0, 1.0)
    beta_min, beta_max = (0.0, 1.0)

    # Bottleneck network size
    num_groups = 3
    blocks_per_group = (pyramid_depth - 2) // 9
    # See Eqn 2 in https://arxiv.org/abs/1610.02915
    num_channels = 16
    # N in https://arxiv.org/abs/1610.02915
    total_blocks = blocks_per_group * num_groups
    delta_channels = pyramid_alpha / total_blocks

    x = nn.Conv(
        x,
        16, (3, 3),
        padding='SAME',
        name='init_conv',
        bias=False,
        kernel_init=utils.conv_kernel_init_fn)
    x = utils.activation(x, apply_relu=False, train=train, name='init_bn')

    # 1-based index of the current block over the whole network; it drives
    # the per-layer ShakeDrop mask probability.
    layer_num = 1
    for group_i in range(num_groups):
      for block_i in range(blocks_per_group):
        num_channels += delta_channels
        layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                    mask_prob)
        # Subsample at the entry of every group except the first.
        downsample = group_i > 0 and block_i == 0
        x = BottleneckShakeDrop(x,
                                int(round(num_channels)),
                                (2, 2) if downsample else (1, 1),
                                layer_mask_prob,
                                alpha_min, alpha_max, beta_min, beta_max,
                                train=train,
                                true_gradient=true_gradient)
        layer_num += 1

    x = utils.activation(x, train=train, name='final_bn')
    # Global average pooling over the whole remaining spatial extent (8x8 for
    # 32x32 inputs, matching the former hard-coded window).
    x = nn.avg_pool(x, x.shape[1:3])
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
    return x
Exemplo n.º 7
0
    def apply(self,
              x: jnp.ndarray,
              channels: int,
              strides: Tuple[int, int] = (1, 1),
              train: bool = True,
              true_gradient: bool = False) -> jnp.ndarray:
        """Runs one shake-shake residual block.

        Two parallel branches, each relu -> 3x3 conv (strided) -> BN/relu ->
        3x3 conv -> BN. They are mixed with the stochastic shake-shake op
        during training (and not at init), or the eval op otherwise. The mix
        is added to the output of `Shortcut`.

        Args:
          x: Input to the module. Should have shape
            [batch_size, dim, dim, features] where dim is the resolution
            (width and height if the input is an image).
          channels: How many channels to use in the convolutional layers.
          strides: Strides of the first convolution of each branch (also
            forwarded to `Shortcut`).
          train: If False, will use the moving average for batch norm
            statistics. Else, will use statistics computed on the batch.
          true_gradient: If true, the same mixing parameter will be used for
            the forward and backward pass (see paper for more details).

        Returns:
          The output of the resnet block. Will have shape
            [batch_size, dim, dim, channels] if strides = (1, 1) or
            [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
        """

        def residual_branch(conv1_name, bn1_name, conv2_name, bn2_name):
            # One of the two identical shake-shake branches; only the
            # submodule names differ.
            h = jax.nn.relu(x)
            h = nn.Conv(h,
                        channels, (3, 3),
                        strides,
                        padding='SAME',
                        bias=False,
                        kernel_init=utils.conv_kernel_init_fn,
                        name=conv1_name)
            h = utils.activation(h, train=train, name=bn1_name)
            h = nn.Conv(h,
                        channels, (3, 3),
                        padding='SAME',
                        bias=False,
                        kernel_init=utils.conv_kernel_init_fn,
                        name=conv2_name)
            return utils.activation(h, apply_relu=False, train=train,
                                    name=bn2_name)

        branch_a = residual_branch('conv_a_1', 'bn_a_1', 'conv_a_2', 'bn_a_2')
        branch_b = residual_branch('conv_b_1', 'bn_b_1', 'conv_b_2', 'bn_b_2')

        # Eval-mode mixing is used at evaluation and during initialization.
        if not train or self.is_initializing():
            mixed = utils.shake_shake_eval(branch_a, branch_b)
        else:
            mixed = utils.shake_shake_train(branch_a, branch_b,
                                            true_gradient=true_gradient)

        # Apply an up projection in case of channel mismatch.
        skip = Shortcut(x, channels, strides, train)
        return skip + mixed