Example #1
    def apply(self,
              x,
              num_filters=64,
              block_sizes=(3, 4, 6, 3),
              train=True,
              block=BottleneckBlock,
              small_inputs=False):
        if small_inputs:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        bias=False,
                        name="init_conv")
        else:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(7, 7),
                        strides=(2, 2),
                        bias=False,
                        name="init_conv")
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         epsilon=1e-5,
                         name="init_bn")
        if not small_inputs:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = block(x, num_filters * 2**i, strides=strides, train=train)

        return x
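A minimal usage sketch (not part of the original source) for the legacy (pre-Linen) flax.nn module API that these examples use. The class name ResNetBackbone, the import path, and the input shape are assumptions for illustration only.

    # Hypothetical usage, assuming the apply above is defined on
    # `class ResNetBackbone(nn.Module)` in the legacy flax.nn API.
    import jax
    import jax.numpy as jnp
    from flax import nn  # on newer Flax versions: from flax.deprecated import nn

    rng = jax.random.PRNGKey(0)
    x = jnp.ones((1, 224, 224, 3), jnp.float32)

    # BatchNorm keeps running averages in a state collection, so init and apply
    # are wrapped in nn.stateful().
    with nn.stateful() as init_state:
        _, params = ResNetBackbone.init(rng, x, train=True)

    model = nn.Model(ResNetBackbone, params)
    with nn.stateful(init_state, mutable=False):
        features = model(x, train=False)  # final feature map; no classifier head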
Example #2
 def apply(self, x, features, n_stages, act=nn.relu):
     x = act(x)
     path = x
     for _ in range(n_stages):
         path = nn.max_pool(path,
                            window_shape=(5, 5),
                            strides=(1, 1),
                            padding='SAME')
         path = ncsn_conv3x3(path, features, stride=1, bias=False)
         x = path + x
     return x
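The accumulation x = path + x above is shape-compatible because the stride-1 SAME max pool and ncsn_conv3x3 (assumed to use SAME padding, with features matching the input channel count) both preserve the spatial dimensions. A quick check of the pooling half, which is parameter-free and can be called outside a module:

    import jax.numpy as jnp
    from flax import nn  # on newer Flax versions: from flax.deprecated import nn

    x = jnp.ones((1, 32, 32, 64))
    pooled = nn.max_pool(x, window_shape=(5, 5), strides=(1, 1), padding='SAME')
    assert pooled.shape == x.shape  # stride-1 SAME pooling keeps height and width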
Example #3
 def apply(self, x, use_squeeze_excite=False):
   x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   if use_squeeze_excite:
     x = SqueezeExciteLayer(x)
   x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   if use_squeeze_excite:
     x = SqueezeExciteLayer(x)
   x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
   scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
   return scores
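SqueezeExciteLayer is defined elsewhere in the source. For readers unfamiliar with the pattern, here is a minimal sketch of a conventional squeeze-and-excite gate written in the same legacy flax.nn style; the actual layer used above may differ (e.g. in its reduction ratio or initializers).

    # Illustrative only: a conventional squeeze-and-excite gate, not the
    # source's SqueezeExciteLayer.
    import jax
    import jax.numpy as jnp
    from flax import nn  # on newer Flax versions: from flax.deprecated import nn

    class SqueezeExciteSketch(nn.Module):

      def apply(self, x, reduction=4):
        # Squeeze: global average over spatial dims -> (batch, channels).
        pooled = jnp.mean(x, axis=(1, 2))
        # Excite: bottleneck MLP producing a per-channel gate in (0, 1).
        h = nn.relu(nn.Dense(pooled, x.shape[-1] // reduction))
        gate = jax.nn.sigmoid(nn.Dense(h, x.shape[-1]))
        # Rescale each channel of the input feature map.
        return x * gate[:, None, None, :]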
Example #4
 def apply(self, x, num_outputs, train=True):
     x = nn.Conv(x,
                 self.NUM_FILTERS, (7, 7), (2, 2),
                 bias=False,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(self.BLOCK_SIZES):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = BottleneckBlock(x,
                                 self.NUM_FILTERS * 2**i,
                                 strides=strides,
                                 groups=self.GROUPS,
                                 base_width=self.WIDTH_PER_GROUP,
                                 train=train)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_outputs, name='clf')
     return x
Example #5
    def apply(self,
              x,
              *,
              train,
              num_classes,
              block_class=BottleneckResNetImageNetBlock,
              stage_sizes,
              width_factor=1,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True,
              softplus_scale=None,
              no_head=False,
              zero_inits=True):
        """Construct ResNet V1 with `num_classes` outputs."""
        self._stage_sizes = stage_sizes
        if std_penalty_mult > 0:
            raise NotImplementedError(
                'std_penalty_mult not supported for ResNetImageNet')

        width = 64 * width_factor

        # Root block.
        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)
        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        x = conv(x,
                 width,
                 kernel_size=(7, 7),
                 strides=(2, 2),
                 name='init_conv')
        x = norm(x, name='init_bn')

        if compensate_padding:
            # NOTE: this leads to lower performance.
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
        else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Stages.
        for i, stage_size in enumerate(stage_sizes):
            x = ResNetStage(
                x,
                stage_size,
                filters=width * 2**i,
                block_class=block_class,
                first_block_strides=(1, 1) if i == 0 else (2, 2),
                train=train,
                name=f'stage{i + 1}',
                conv=conv,
                norm=norm,
                activation_f=activation_f,
                use_residual=use_residual,
                zero_inits=zero_inits,
            )

        if not no_head:
            # Head.
            x = jnp.mean(x, axis=(1, 2))
            x = nn.Dense(x,
                         num_classes,
                         kernel_init=nn.initializers.zeros
                         if zero_inits else nn.initializers.lecun_normal(),
                         name='head')
        return x, 0, {}
Example #6
    def apply(self,
              x,
              depth,
              num_outputs,
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              filters=16,
              no_head=False,
              report_metrics=False,
              benchmark='cifar10',
              compensate_padding=True,
              softplus_scale=None):

        bn_index = iter(range(1000))
        conv_index = iter(range(1000))
        summaries = {}
        summary_ind = [0]

        def add_summary(name, val):
            """Summarize statistics of tensor."""
            if report_metrics:
                assert val.ndim == 4, (
                    'Assuming 4D inputs with channels last, got %s' %
                    str(val.shape))
                assert val.shape[1] == val.shape[2], (
                    'Assuming 4D inputs with channels last')
                prefix = '%s_%d' % (name, summary_ind[0] // 2)
                summaries[prefix + '_mean_abs'] = jnp.mean(
                    jnp.abs(jnp.mean(val, axis=(0, 1, 2))))
                summaries[prefix + '_mean_std'] = jnp.mean(
                    jnp.std(val, axis=(0, 1, 2)))
                summary_ind[0] += 1

        penalty = 0

        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)

        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)

        def resnet_layer(
            inputs,
            penalty,
            filters,
            kernel_size=3,
            strides=1,
            activation=None,
        ):
            """2D Convolution-Batch Normalization-Activation stack builder."""
            x = inputs
            x = conv(x,
                     filters, (kernel_size, kernel_size),
                     strides=(strides, strides),
                     padding='SAME',
                     name='conv%d' % next(conv_index))
            x = norm(x, name='norm%d' % next(bn_index))
            add_summary('postnorm', x)
            if std_penalty_mult > 0:
                penalty += std_penalty(x)

            if activation:
                x = activation_f(x, features=x.shape[-1])
            add_summary('postact', x)
            return x, penalty

        # Main network code.
        num_res_blocks = (depth - 2) // 6

        if (depth - 2) % 6 != 0:
            raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')

        inputs = x
        add_summary('input', x)
        add_summary('inputb', x)
        if benchmark in ['cifar10', 'cifar100']:
            x, penalty = resnet_layer(inputs,
                                      penalty,
                                      filters=filters,
                                      activation=True)
            head_kernel_init = nn.initializers.lecun_normal()
        elif benchmark in ['imagenet']:
            head_kernel_init = nn.initializers.zeros
            x, penalty = resnet_layer(inputs,
                                      penalty,
                                      filters=filters,
                                      activation=False,
                                      kernel_size=7,
                                      strides=2)
            # TODO(basv): evaluate max pool v/s avg_pool in an experiment?
            # if compensate_padding:
            #   x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding="VALID")
            # else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        else:
            raise ValueError('Model def not prepared for benchmark %s' %
                             benchmark)

        for stack in range(3):
            for res_block in range(num_res_blocks):
                strides = 1
                if stack > 0 and res_block == 0:  # First layer but not first stack.
                    strides = 2  # Downsample.
                y, penalty = resnet_layer(
                    x,
                    penalty,
                    filters=filters,
                    strides=strides,
                    activation=True,
                )
                y, penalty = resnet_layer(
                    y,
                    penalty,
                    filters=filters,
                    activation=False,
                )
                if stack > 0 and res_block == 0:  # First layer but not first stack.
                    # Linear projection residual shortcut to match changed dims.
                    x, penalty = resnet_layer(
                        x,
                        penalty,
                        filters=filters,
                        kernel_size=1,
                        strides=strides,
                        activation=False,
                    )

                if use_residual == 1:
                    # Apply an up projection in case of channel mismatch
                    x = x + y
                elif use_residual == 2:
                    x = (x + y) / jnp.sqrt(1**2 + 1**2)  # Sum of independent normals.
                else:
                    x = y

                add_summary('postres', x)
                x = activation_f(x, features=x.shape[-1])
                add_summary('postresact', x)
            filters *= 2

        # V1 does not use BN after last shortcut connection-ReLU.
        if not no_head:
            x = jnp.mean(x, axis=(1, 2))
            add_summary('postpool', x)
            x = x.reshape((x.shape[0], -1))

            x = nn.Dense(x, num_outputs, kernel_init=head_kernel_init)
        return x, penalty, summaries
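The "depth must be 6n+2" check above follows from the layer count: each of the 3 stacks contains num_res_blocks residual blocks with 2 convolutions each, plus the initial convolution and the dense head (projection shortcuts are not counted, per the usual ResNet convention). A quick sanity check:

    # Sanity check of the depth formula used above: depth = 6 * n + 2.
    for depth in (20, 32, 44):
        n = (depth - 2) // 6                 # residual blocks per stack
        conv_layers = 3 * n * 2              # 3 stacks, 2 conv layers per block
        assert conv_layers + 1 + 1 == depth  # + initial conv + dense head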
Example #7
  def apply(
      self,
      inputs,
      blocks_per_group,
      channel_multiplier,
      num_outputs,
      kernel_size=(3, 3),
      strides=None,
      maxpool=False,
      dropout_rate=0.0,
      dtype=jnp.float32,
      norm_layer='group_norm',
      train=True,
      return_activations=False,
      input_layer_key='input',
      has_discriminator=False,
      discriminator=False,
  ):

    norm_layer_name = ''
    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'

    layer_activations = collections.OrderedDict()
    input_is_set = False
    current_rep_key = 'input'
    if input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'init_conv'
    if input_is_set:
      x = nn.Conv(
          x,
          16,
          kernel_size=kernel_size,
          strides=strides,
          padding='SAME',
          name='init_conv')
      if maxpool:
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l1'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          16 * channel_multiplier,
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l1')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l2'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          32 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l2')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l3'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          64 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l3')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l4'
    if input_is_set:
      x = norm_layer(x, name=norm_layer_name)
      x = jax.nn.relu(x)
      x = nn.avg_pool(x, (8, 8))
      x = x.reshape((x.shape[0], -1))
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    # DANN module
    if has_discriminator:
      z = dann_utils.flip_grad_identity(x)
      z = nn.Dense(z, 2, name='disc_l1', bias=True)
      z = nn.relu(z)
      z = nn.Dense(z, 2, name='disc_l2', bias=True)

    current_rep_key = 'head'
    if input_is_set:
      x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
    else:
      x = inputs
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

      logging.warning('Input was never used')

    outputs = x
    if return_activations:
      outputs = (x, layer_activations, rep_key)
      if discriminator and has_discriminator:
        outputs = outputs + (z,)
    else:
      del layer_activations
      if discriminator and has_discriminator:
        outputs = (x, z)
    if discriminator and (not has_discriminator):
      raise ValueError(
          'Inconsistent values passed for discriminator and has_discriminator')
    return outputs
Example #8
    def apply(self,
              inputs,
              num_outputs,
              num_filters=64,
              num_layers=50,
              dropout_rate=0.0,
              input_dropout_rate=0.0,
              train=True,
              dtype=jnp.float32,
              head_bias_init=jnp.zeros,
              return_activations=False,
              input_layer_key='input',
              has_discriminator=False,
              discriminator=False):
        """Apply a ResNet network on the input.

    Args:
      inputs: jnp array; Inputs.
      num_outputs: int; Number of output units.
      num_filters: int; Determines base number of filters. Number of filters in
        block i is num_filters * 2 ** i.
      num_layers: int; Number of layers (should be one of the predefined ones.)
      dropout_rate: float; Rate of dropping out the output of different hidden
        layers.
      input_dropout_rate: float; Rate of dropping out the input units.
      train: bool; Whether the model is in training mode.
      dtype: jnp type; Type of the outputs.
      head_bias_init: fn(rng_key, shape)--> jnp array; Initializer for head bias
        parameters.
      return_activations: bool; If True, hidden activations are also returned.
      input_layer_key: str; Determines where to plug in the input (this is to
        enable providing inputs to slices of the model). If `input_layer_key` is
        `layer_i` we assume the inputs are the activations of `layer_i` and pass
        them to `layer_{i+1}`.
      has_discriminator: bool; Whether the model should have discriminator
        layer.
      discriminator: bool; Whether we should return discriminator logits.

    Returns:
      Unnormalized logits with shape `[bs, num_outputs]`;
      if return_activations:
        logits, a dict of hidden activations, and the key to the
        representation(s) that will be used as ``The Representation'', e.g.,
        for computing losses.
    """
        if num_layers not in ResNet._block_size_options:
            raise ValueError('Please provide a valid number of layers')

        block_sizes = ResNet._block_size_options[num_layers]

        layer_activations = collections.OrderedDict()
        input_is_set = False
        current_rep_key = 'input'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True

        if input_is_set:
            # Input dropout
            x = nn.dropout(x, input_dropout_rate, deterministic=not train)
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        current_rep_key = 'init_conv'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # First block
            x = nn.Conv(x,
                        num_filters, (7, 7), (2, 2),
                        padding=[(3, 3), (3, 3)],
                        bias=False,
                        dtype=dtype,
                        name='init_conv')
            x = nn.BatchNorm(x,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=dtype,
                             name='init_bn')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # Residual blocks
        for i, block_size in enumerate(block_sizes):

            # Stage i (each stage contains blocks of the same size).
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                current_rep_key = f'block_{i + 1}+{j}'
                if input_layer_key == current_rep_key:
                    x = inputs
                    input_is_set = True
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key
                elif input_is_set:
                    x = ResidualBlock(x,
                                      num_filters * 2**i,
                                      strides=strides,
                                      dropout_rate=dropout_rate,
                                      train=train,
                                      dtype=dtype,
                                      name=f'block_{i + 1}_{j}')
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key

        current_rep_key = 'avg_pool'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # Global Average Pool
            x = jnp.mean(x, axis=(1, 2))
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # DANN module
        if has_discriminator:
            z = dann_utils.flip_grad_identity(x)
            z = nn.Dense(z, 2, name='disc_l1', bias=True)
            z = nn.relu(z)
            z = nn.Dense(z, 2, name='disc_l2', bias=True)

        current_rep_key = 'head'
        if input_layer_key == current_rep_key:
            x = inputs
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

            logging.warning('Input was never used')
        elif input_is_set:
            x = nn.Dense(x,
                         num_outputs,
                         dtype=dtype,
                         bias_init=head_bias_init,
                         name='head')

        # Make sure that the output is float32, even if our previous computations
        # are in float16, or other types.
        x = jnp.asarray(x, jnp.float32)

        outputs = x
        if return_activations:
            outputs = (x, layer_activations, rep_key)
            if discriminator and has_discriminator:
                outputs = outputs + (z, )
        else:
            del layer_activations
            if discriminator and has_discriminator:
                outputs = (x, z)
        if discriminator and (not has_discriminator):
            raise ValueError(
                'Inconsistent values passed for discriminator and has_discriminator'
            )
        return outputs
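A minimal usage sketch (not from the original source) of the slicing behaviour described in the docstring: run the model once with return_activations=True, then feed a stored activation back in through input_layer_key. The class name ResNet, the import path, and the input shape are assumptions; train=False keeps dropout and BatchNorm deterministic.

    # Hypothetical usage, assuming the apply above belongs to
    # `class ResNet(nn.Module)` in the legacy flax.nn API.
    import jax
    import jax.numpy as jnp
    from flax import nn  # on newer Flax versions: from flax.deprecated import nn

    rng = jax.random.PRNGKey(0)
    images = jnp.ones((2, 224, 224, 3), jnp.float32)

    resnet_def = ResNet.partial(num_outputs=10, train=False,
                                return_activations=True)
    with nn.stateful() as state:
        _, params = resnet_def.init(rng, images)
    model = nn.Model(resnet_def, params)

    # Full forward pass, also collecting hidden activations.
    with nn.stateful(state, mutable=False):
        logits, activations, rep_key = model(images)

    # Re-run only the tail of the network, feeding the stored `init_conv`
    # activation back in via `input_layer_key`.
    with nn.stateful(state, mutable=False):
        tail_logits, _, _ = model(activations['init_conv'],
                                  input_layer_key='init_conv')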