Example #1
0
    def apply(self,
              x,
              num_filters=64,
              block_sizes=(3, 4, 6, 3),
              train=True,
              block=BottleneckBlock,
              small_inputs=False):
        """Run a ResNet-style trunk and return the final feature map.

        Args:
          x: input batch (assumed NHWC images -- TODO confirm).
          num_filters: width of the stem; stage ``i`` uses
            ``num_filters * 2**i`` filters.
          block_sizes: number of ``block`` applications in each stage.
          train: if False, BatchNorm uses its running averages; also forwarded
            to every ``block``.
          block: residual block module applied for every unit of every stage.
          small_inputs: if True, use a 3x3 / stride-1 stem with no max-pool;
            otherwise a 7x7 / stride-2 stem followed by a 3x3 / stride-2
            max-pool.

        Returns:
          The output of the last block (no pooling or classifier head).
        """
        # Only the stem geometry depends on `small_inputs`.
        if small_inputs:
            stem_kernel, stem_strides = (3, 3), (1, 1)
        else:
            stem_kernel, stem_strides = (7, 7), (2, 2)
        x = nn.Conv(x,
                    num_filters,
                    kernel_size=stem_kernel,
                    strides=stem_strides,
                    bias=False,
                    name="init_conv")
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         epsilon=1e-5,
                         name="init_bn")
        # NOTE(review): there is no ReLU between init_bn and the pooling /
        # first block (a conv-bn-relu stem is the usual ResNet v1 layout);
        # confirm this is intended for the `block` being used.
        if not small_inputs:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for stage, num_units in enumerate(block_sizes):
            stage_filters = num_filters * 2**stage
            for unit in range(num_units):
                # Downsample once, at the first unit of every stage but the first.
                downsample = stage > 0 and unit == 0
                x = block(x,
                          stage_filters,
                          strides=(2, 2) if downsample else (1, 1),
                          train=train)
        return x
Example #2
0
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):
        """Classify ``x`` with a three-group Wide ResNet.

        Args:
          x: input batch (assumed NHWC, 32x32 spatial so that the final 8x8
            average pool covers the whole map -- TODO confirm).
          blocks_per_group: number of blocks in each ``WideResnetGroup``.
          channel_multiplier: widening factor applied to the 16/32/64 base
            widths of the three groups.
          num_outputs: number of units in the final dense layer.
          dropout_rate: forwarded to every group.
          train: if False, BatchNorm uses its running averages; also forwarded
            to every group.

        Returns:
          Unnormalized logits of shape ``[x.shape[0], num_outputs]``.
        """
        def run_group(h, base_width, *stride_args):
            # All three groups share these arguments; only the width and, for
            # the second and third groups, the positional strides differ.
            return WideResnetGroup(h,
                                   blocks_per_group,
                                   base_width * channel_multiplier,
                                   *stride_args,
                                   dropout_rate=dropout_rate,
                                   train=train)

        out = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        out = run_group(out, 16)
        out = run_group(out, 32, (2, 2))
        out = run_group(out, 64, (2, 2))
        out = nn.BatchNorm(out,
                           use_running_average=not train,
                           momentum=0.9,
                           epsilon=1e-5)
        out = jax.nn.relu(out)
        out = nn.avg_pool(out, (8, 8))
        flat = out.reshape((out.shape[0], -1))
        return nn.Dense(flat, num_outputs)
Example #3
0
 def apply(self, x, num_outputs, train=True):
     """Classify ``x`` with a grouped-bottleneck (ResNeXt-style) ResNet.

     Architecture hyper-parameters come from class attributes:
     ``NUM_FILTERS`` (stem width; stage ``i`` uses ``NUM_FILTERS * 2**i``),
     ``BLOCK_SIZES`` (blocks per stage), and ``GROUPS`` / ``WIDTH_PER_GROUP``
     (forwarded to every ``BottleneckBlock``).

     Args:
       x: input batch (assumed NHWC images -- TODO confirm).
       num_outputs: number of units in the ``clf`` dense head.
       train: if False, BatchNorm uses its running averages; also forwarded
         to every block.

     Returns:
       Unnormalized logits produced by the ``clf`` dense layer.
     """
     base_filters = self.NUM_FILTERS
     x = nn.Conv(x, base_filters, (7, 7), (2, 2), bias=False, name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn')
     # NOTE(review): no ReLU between init_bn and the max-pool, unlike a
     # conventional conv-bn-relu stem; confirm this is intended.
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for stage, depth in enumerate(self.BLOCK_SIZES):
         stage_width = base_filters * 2**stage
         for unit in range(depth):
             # Every stage after the first downsamples in its first block.
             is_transition = stage > 0 and unit == 0
             x = BottleneckBlock(x,
                                 stage_width,
                                 strides=(2, 2) if is_transition else (1, 1),
                                 groups=self.GROUPS,
                                 base_width=self.WIDTH_PER_GROUP,
                                 train=train)
     # Average over axes 1 and 2 (the spatial axes, if the layout is NHWC).
     pooled = jnp.mean(x, axis=(1, 2))
     return nn.Dense(pooled, num_outputs, name='clf')
Example #4
0
    def apply(self,
              inputs,
              num_outputs,
              num_filters=64,
              num_layers=50,
              dropout_rate=0.0,
              input_dropout_rate=0.0,
              train=True,
              dtype=jnp.float32,
              head_bias_init=jnp.zeros,
              return_activations=False,
              input_layer_key='input',
              has_discriminator=False,
              discriminator=False):
        """Apply a ResNet network on the input.

        Args:
          inputs: jnp array; Inputs.
          num_outputs: int; Number of output units.
          num_filters: int; Determines base number of filters. Number of
            filters in block i is num_filters * 2 ** i.
          num_layers: int; Number of layers (should be one of the predefined
            ones).
          dropout_rate: float; Rate of dropping out the output of different
            hidden layers.
          input_dropout_rate: float; Rate of dropping out the input units.
          train: bool; If True, dropout is active and BatchNorm uses batch
            statistics (it is also forwarded to every residual block).
          dtype: jnp type; Type used for the intermediate computations. The
            returned logits are always cast to float32.
          head_bias_init: fn(rng_key, shape)--> jnp array; Initializer for head
            bias parameters.
          return_activations: bool; If True hidden activations are also
            returned.
          input_layer_key: str; Determines where to plug in the input (this is
            to enable providing inputs to slices of the model). Valid keys, in
            network order, are 'input', 'init_conv', 'block_{i}+{j}' (stage i
            counted from 1, block j counted from 0), 'avg_pool' and 'head'. If
            `input_layer_key` is `layer_i` we assume the inputs are the
            activations of `layer_i` and pass them to `layer_{i+1}`.
          has_discriminator: bool; Whether the model should have discriminator
            layers (applied to the pooled features).
          discriminator: bool; Whether we should return discriminator logits.
            Requires `has_discriminator`.

        Returns:
          Unnormalized logits with shape `[bs, num_outputs]`. If
          `return_activations`, the tuple `(logits, activations, rep_key)`
          where `activations` is an OrderedDict of the recorded hidden
          activations and `rep_key` is the key of the last recorded one (the
          representation used, e.g., for computing losses). If `discriminator`,
          the discriminator logits are appended as the last tuple element
          (giving `(logits, disc_logits)` when not returning activations).

        Raises:
          ValueError: if `discriminator` is set without `has_discriminator`, if
            `num_layers` is not a predefined depth, if `input_layer_key` is
            unknown, or if `has_discriminator` is combined with
            `input_layer_key='head'`.
        """
        # Fail fast on inconsistent flags instead of after building the model.
        if discriminator and not has_discriminator:
            raise ValueError(
                'Inconsistent values passed for discriminator and has_discriminator'
            )

        if num_layers not in ResNet._block_size_options:
            raise ValueError('Please provide a valid number of layers')

        block_sizes = ResNet._block_size_options[num_layers]

        # An unknown key would otherwise leave `x` / `rep_key` unbound and fail
        # later with an UnboundLocalError.
        block_keys = [
            f'block_{i + 1}+{j}'
            for i, block_size in enumerate(block_sizes)
            for j in range(block_size)
        ]
        valid_keys = ['input', 'init_conv', *block_keys, 'avg_pool', 'head']
        if input_layer_key not in valid_keys:
            raise ValueError(f'Unknown input_layer_key {input_layer_key!r}; '
                             f'expected one of {valid_keys}')
        if has_discriminator and input_layer_key == 'head':
            # The discriminator consumes the pooled features, which are never
            # computed when the input is plugged in after the head.
            raise ValueError(
                "has_discriminator cannot be used with input_layer_key='head'")

        layer_activations = collections.OrderedDict()
        input_is_set = False

        current_rep_key = 'input'
        if input_layer_key == current_rep_key:
            # Input dropout
            x = nn.dropout(inputs, input_dropout_rate, deterministic=not train)
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        current_rep_key = 'init_conv'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # First block
            x = nn.Conv(x,
                        num_filters, (7, 7), (2, 2),
                        padding=[(3, 3), (3, 3)],
                        bias=False,
                        dtype=dtype,
                        name='init_conv')
            x = nn.BatchNorm(x,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=dtype,
                             name='init_bn')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # Residual blocks
        for i, block_size in enumerate(block_sizes):

            # Stage i (each stage contains blocks of the same size).
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                # NOTE: the activation key uses '+' while the module name uses
                # '_'; both are part of the external interface, keep them as-is.
                current_rep_key = f'block_{i + 1}+{j}'
                if input_layer_key == current_rep_key:
                    x = inputs
                    input_is_set = True
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key
                elif input_is_set:
                    x = ResidualBlock(x,
                                      num_filters * 2**i,
                                      strides=strides,
                                      dropout_rate=dropout_rate,
                                      train=train,
                                      dtype=dtype,
                                      name=f'block_{i + 1}_{j}')
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key

        current_rep_key = 'avg_pool'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # Global Average Pool
            x = jnp.mean(x, axis=(1, 2))
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # DANN module: gradient-flipped features feed a small 2-way classifier.
        if has_discriminator:
            z = dann_utils.flip_grad_identity(x)
            z = nn.Dense(z, 2, name='disc_l1', bias=True)
            z = nn.relu(z)
            z = nn.Dense(z, 2, name='disc_l2', bias=True)

        current_rep_key = 'head'
        if input_layer_key == current_rep_key:
            # Inputs are taken as the head's output, so no layer consumes them.
            x = inputs
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

            logging.warning('Input was never used')
        elif input_is_set:
            x = nn.Dense(x,
                         num_outputs,
                         dtype=dtype,
                         bias_init=head_bias_init,
                         name='head')

        # Make sure that the output is float32, even if our previous computations
        # are in float16, or other types.
        x = jnp.asarray(x, jnp.float32)

        # `discriminator` implies `has_discriminator` (checked above), so `z`
        # is always defined when it is read here.
        outputs = x
        if return_activations:
            outputs = (x, layer_activations, rep_key)
            if discriminator:
                outputs = outputs + (z, )
        elif discriminator:
            outputs = (x, z)
        return outputs
Example #5
0
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        """Apply a bottleneck PyramidNet with ShakeDrop regularization.

        Args:
          x: input batch (assumed NHWC images -- TODO confirm).
          num_outputs: number of units in the final dense layer.
          pyramid_alpha: total number of channels added across all blocks; each
            of the N blocks widens by pyramid_alpha / N (see Eqn 2 in
            https://arxiv.org/abs/1610.02915).
          pyramid_depth: network depth; must be 9 * n + 2 with n >= 1, where n
            is the number of bottleneck blocks in each of the three groups.
          train: if False, BatchNorm uses its running averages; also forwarded
            to every ``BottleneckShakeDrop`` block.

        Returns:
          Unnormalized logits of shape ``[x.shape[0], num_outputs]``.

        Raises:
          ValueError: if ``pyramid_depth`` is not of the form ``9 * n + 2``
            with ``n >= 1``.
        """
        # Explicit check rather than `assert` (stripped under `python -O`).
        # Depth 2 satisfies the modulus test but gives zero blocks, which
        # would divide by zero when computing `delta_channels` below.
        if pyramid_depth < 11 or (pyramid_depth - 2) % 9 != 0:
            raise ValueError(
                f'pyramid_depth must be 9 * n + 2 with n >= 1, got {pyramid_depth}')

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='init_bn')

        layer_num = 1

        # Three groups of `blocks_per_group` blocks. The channel count grows
        # additively by `delta_channels` per block (truncated to int when
        # used); every group except the first uses stride 2 in its first block.
        for group_i in range(3):
            for block_i in range(blocks_per_group):
                num_channels += delta_channels
                layer_mask_prob = _calc_shakedrop_mask_prob(
                    layer_num, total_blocks, mask_prob)
                strides = (2, 2) if group_i > 0 and block_i == 0 else (1, 1)
                x = BottleneckShakeDrop(x,
                                        int(num_channels),
                                        strides,
                                        layer_mask_prob,
                                        alpha_min,
                                        alpha_max,
                                        beta_min,
                                        beta_max,
                                        train=train)
                layer_num += 1

        # Internal invariant (not input validation): one layer_num per block.
        assert layer_num - 1 == total_blocks
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='final_bn')
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x