Example 1
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      16 * channel_multiplier,
                                      train=train)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      32 * channel_multiplier, (2, 2),
                                      train=train)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      64 * channel_multiplier, (2, 2),
                                      train=train)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
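A quick shape check of the (8, 8) average pool used before the classifier head in this and the later Wide ResNet examples. This is a minimal sketch; it assumes an old Flax release where the pre-Linen flax.nn API is still importable.

import jax.numpy as jnp
from flax import nn  # pre-Linen flax.nn API (assumption about the Flax version)

# For 32x32 inputs, the two stride-(2, 2) groups leave an 8x8 feature map, so
# nn.avg_pool with an (8, 8) window (default strides=(1, 1), padding='VALID')
# produces a single 1x1 output window per channel.
x = jnp.ones((4, 8, 8, 64))
y = nn.avg_pool(x, (8, 8))
assert y.shape == (4, 1, 1, 64)
y = y.reshape((y.shape[0], -1))  # (4, 64), ready for nn.Dense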
Example 2
    def apply(self,
              x,
              out_ch=None,
              with_conv=False,
              fir=False,
              fir_kernel=(1, 3, 3, 1)):
        B, H, W, C = x.shape
        out_ch = out_ch if out_ch else C
        if not fir:
            if with_conv:
                x = conv3x3(x, out_ch, stride=2)
            else:
                x = nn.avg_pool(x,
                                window_shape=(2, 2),
                                strides=(2, 2),
                                padding='SAME')
        else:
            if not with_conv:
                x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
            else:
                x = up_or_down_sampling.Conv2d(x,
                                               out_ch,
                                               kernel=3,
                                               down=True,
                                               resample_kernel=fir_kernel,
                                               bias=True,
                                               kernel_init=default_init())

        assert x.shape == (B, H // 2, W // 2, out_ch)
        return x
Example 3
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten.
     x = nn.Dense(x, 128, name="fc")
     return x
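For context, a minimal sketch of how a snippet like this might be wrapped and instantiated under the deprecated flax.nn (pre-Linen) module API. The class name CNN, the import path, and the input shape are assumptions for illustration only.

import jax
import jax.numpy as jnp
from flax import nn  # pre-Linen flax.nn API (assumption about the Flax version)

class CNN(nn.Module):
    def apply(self, x):
        x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(x, 128, name="fc")
        return x

rng = jax.random.PRNGKey(0)
# init_by_shape builds parameters from an input shape spec without real data.
_, params = CNN.init_by_shape(rng, [((1, 28, 28, 1), jnp.float32)])
model = nn.Model(CNN, params)
features = model(jnp.ones((8, 28, 28, 1)))  # shape (8, 128)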
Example 4
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
Example 5
 def apply(self, x, with_conv=False):
     B, H, W, C = x.shape
     if with_conv:
         x = ddpm_conv3x3(x, C, stride=2)
     else:
         x = nn.avg_pool(x,
                         window_shape=(2, 2),
                         strides=(2, 2),
                         padding='SAME')
     assert x.shape == (B, H // 2, W // 2, C)
     return x
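The shape contract asserted above can be checked in isolation. A minimal sketch of the pooling branch, again assuming the pre-Linen flax.nn API is available:

import jax.numpy as jnp
from flax import nn  # pre-Linen flax.nn API (assumption about the Flax version)

x = jnp.ones((4, 32, 32, 16))   # (B, H, W, C)
y = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
assert y.shape == (4, 16, 16, 16)  # spatial dims halved, channels unchanged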
Example 6
 def apply(self, x, y, features, n_stages, normalizer, act=nn.relu):
     x = act(x)
     path = x
     for _ in range(n_stages):
         path = normalizer(path, y)
         path = nn.avg_pool(path,
                            window_shape=(5, 5),
                            strides=(1, 1),
                            padding='SAME')
         path = ncsn_conv3x3(path, features, stride=1, bias=False)
         x = path + x
     return x
Example 7
def shortcut(x, chn_out, strides):
    """Pyramid Net Shortcut.

  Use Average pooling to downsample
  Use zero-padding to increase channels

  Args:
    x: input
    chn_out: expected number of output channels
    strides: striding applied by block

  Returns:
    shortcut
  """
    chn_in = x.shape[3]
    if strides != (1, 1):
        x = nn.avg_pool(x, strides, strides)
    if chn_out != chn_in:
        diff = chn_out - chn_in
        x = jnp.pad(x, [[0, 0], [0, 0], [0, 0], [0, diff]])
    return x
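The zero-padding step only appends chn_out - chn_in all-zero feature maps along the channel axis. A minimal standalone check of that behaviour:

import jax.numpy as jnp

x = jnp.ones((2, 8, 8, 16))
diff = 32 - x.shape[3]                               # grow 16 -> 32 channels
y = jnp.pad(x, [[0, 0], [0, 0], [0, 0], [0, diff]])  # pad the channel axis only
assert y.shape == (2, 8, 8, 32)
assert bool((y[..., 16:] == 0).all())                # new channels are all zeros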
Example 8
    def apply(self,
              x,
              *,
              train,
              num_classes,
              block_class=BottleneckResNetImageNetBlock,
              stage_sizes,
              width_factor=1,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True,
              softplus_scale=None,
              no_head=False,
              zero_inits=True):
        """Construct ResNet V1 with `num_classes` outputs."""
        self._stage_sizes = stage_sizes
        if std_penalty_mult > 0:
            raise NotImplementedError(
                'std_penalty_mult not supported for ResNetImageNet')

        width = 64 * width_factor

        # Root block.
        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)
        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        x = conv(x,
                 width,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 name='init_conv')
        x = norm(x, name='init_bn')

        if compensate_padding:
            # NOTE: this leads to lower performance.
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
        else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Stages.
        for i, stage_size in enumerate(stage_sizes):
            x = ResNetStage(
                x,
                stage_size,
                filters=width * 2**i,
                block_class=block_class,
                first_block_strides=(1, 1) if i == 0 else (2, 2),
                train=train,
                name=f'stage{i + 1}',
                conv=conv,
                norm=norm,
                activation_f=activation_f,
                use_residual=use_residual,
                zero_inits=zero_inits,
            )

        if not no_head:
            # Head.
            x = jnp.mean(x, axis=(1, 2))
            x = nn.Dense(x,
                         num_classes,
                         kernel_init=nn.initializers.zeros
                         if zero_inits else nn.initializers.lecun_normal(),
                         name='head')
        return x, 0, {}
Example 9
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              no_head=False,
              compensate_padding=True,
              softplus_scale=None):

        penalty = 0

        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)
        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        x = conv(x,
                 16 * channel_multiplier, (3, 3),
                 padding='SAME',
                 name='init_conv')
        x, g_penalty = WideResnetGroup(x,
                                       blocks_per_group,
                                       16 * channel_multiplier,
                                       dropout_rate=dropout_rate,
                                       normalization=normalization,
                                       activation_f=activation_f,
                                       std_penalty_mult=std_penalty_mult,
                                       use_residual=use_residual,
                                       train=train,
                                       bias_scale=bias_scale,
                                       weight_norm=weight_norm)
        penalty += g_penalty
        x, g_penalty = WideResnetGroup(x,
                                       blocks_per_group,
                                       32 * channel_multiplier, (2, 2),
                                       dropout_rate=dropout_rate,
                                       normalization=normalization,
                                       activation_f=activation_f,
                                       std_penalty_mult=std_penalty_mult,
                                       use_residual=use_residual,
                                       train=train,
                                       bias_scale=bias_scale,
                                       weight_norm=weight_norm)
        penalty += g_penalty
        x, g_penalty = WideResnetGroup(x,
                                       blocks_per_group,
                                       64 * channel_multiplier, (2, 2),
                                       dropout_rate=dropout_rate,
                                       normalization=normalization,
                                       activation_f=activation_f,
                                       std_penalty_mult=std_penalty_mult,
                                       use_residual=use_residual,
                                       train=train,
                                       bias_scale=bias_scale,
                                       weight_norm=weight_norm)
        penalty += g_penalty

        x = norm(x, name='final_norm')
        if std_penalty_mult > 0:
            penalty += std_penalty(x)
        if not no_head:
            x = activation_f(x, features=x.shape[-1])
            x = nn.avg_pool(x, (8, 8))
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(x, num_outputs)
        return x, penalty, {}
Example 10
  def apply(
      self,
      inputs,
      blocks_per_group,
      channel_multiplier,
      num_outputs,
      kernel_size=(3, 3),
      strides=None,
      maxpool=False,
      dropout_rate=0.0,
      dtype=jnp.float32,
      norm_layer='group_norm',
      train=True,
      return_activations=False,
      input_layer_key='input',
      has_discriminator=False,
      discriminator=False,
  ):

    norm_layer_name = ''
    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'

    layer_activations = collections.OrderedDict()
    input_is_set = False
    current_rep_key = 'input'
    if input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'init_conv'
    if input_is_set:
      x = nn.Conv(
          x,
          16,
          kernel_size=kernel_size,
          strides=strides,
          padding='SAME',
          name='init_conv')
      if maxpool:
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l1'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          16 * channel_multiplier,
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l1')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l2'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          32 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l2')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l3'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          64 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l3')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l4'
    if input_is_set:
      x = norm_layer(x, name=norm_layer_name)
      x = jax.nn.relu(x)
      x = nn.avg_pool(x, (8, 8))
      x = x.reshape((x.shape[0], -1))
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    # DANN module
    if has_discriminator:
      z = dann_utils.flip_grad_identity(x)
      z = nn.Dense(z, 2, name='disc_l1', bias=True)
      z = nn.relu(z)
      z = nn.Dense(z, 2, name='disc_l2', bias=True)

    current_rep_key = 'head'
    if input_is_set:
      x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
    else:
      x = inputs
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

      logging.warning('Input was never used')

    outputs = x
    if return_activations:
      outputs = (x, layer_activations, rep_key)
      if discriminator and has_discriminator:
        outputs = outputs + (z,)
    else:
      del layer_activations
      if discriminator and has_discriminator:
        outputs = (x, z)
    if discriminator and (not has_discriminator):
      raise ValueError(
          'Inconsistent values passed for discriminator and has_discriminator')
    return outputs
Example 11
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        assert (pyramid_depth - 2) % 9 == 0

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='init_bn')

        layer_num = 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels), (1, 1),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        assert layer_num - 1 == total_blocks
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='final_bn')
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
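Following the arithmetic in this snippet, the default hyper-parameters give 30 bottleneck blocks per group and a per-block channel increment of pyramid_alpha / N (see Eqn. 2 of the PyramidNet paper cited above):

pyramid_depth, pyramid_alpha = 272, 200
blocks_per_group = (pyramid_depth - 2) // 9     # 30 blocks per group
total_blocks = blocks_per_group * 3             # N = 90
delta_channels = pyramid_alpha / total_blocks   # ~2.22 channels added per block
print(16 + delta_channels * total_blocks)       # 216.0 channels before the head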