def apply(self,
            x,
            channels,
            strides=(1, 1),
            dropout_rate=0.0,
            norm_layer='group_norm',
            train=True):
    norm_layer_name = ''
    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'

    y = norm_layer(x, name=f'{norm_layer_name}1')
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
    y = norm_layer(y, name=f'{norm_layer_name}2')
    y = jax.nn.relu(y)
    if dropout_rate > 0.0:
      y = nn.dropout(y, dropout_rate, deterministic=not train)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

    # Project the shortcut when the channel count or spatial resolution changes
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
    return x + y
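A minimal pure-JAX sketch of the same pre-activation residual pattern, with explicit HWIO kernels standing in for the nn.Conv layers and the normalization omitted (names and shapes below are illustrative, not the module's actual parameters):

import jax
import jax.numpy as jnp

def conv3x3(x, kernel, strides=(1, 1)):
  # 3x3 SAME convolution over an NHWC tensor.
  return jax.lax.conv_general_dilated(
      x, kernel, window_strides=strides, padding='SAME',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jnp.ones((2, 32, 32, 16))                       # (batch, height, width, channels)
w1 = jax.random.normal(k1, (3, 3, 16, 32)) * 0.1    # 16 -> 32 channels
w2 = jax.random.normal(k2, (3, 3, 32, 32)) * 0.1
w_proj = jax.random.normal(k3, (3, 3, 16, 32)) * 0.1

strides = (2, 2)
y = conv3x3(jax.nn.relu(x), w1, strides)            # main branch
y = conv3x3(jax.nn.relu(y), w2)
shortcut = conv3x3(x, w_proj, strides)              # projected shortcut (channels/stride change)
out = shortcut + y
print(out.shape)                                    # (2, 16, 16, 32)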
Example #2
  def apply(self,
            x,
            act,
            normalize,
            temb=None,
            out_ch=None,
            conv_shortcut=False,
            dropout=0.1,
            train=True,
            skip_rescale=False,
            init_scale=0.):
    B, H, W, C = x.shape
    out_ch = out_ch if out_ch else C
    h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))
    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += nn.Dense(
          act(temb), out_ch, kernel_init=default_init())[:, None, None, :]

    h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
    h = nn.dropout(h, dropout, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=init_scale)
    if C != out_ch:
      if conv_shortcut:
        x = conv3x3(x, out_ch)
      else:
        x = NIN(x, out_ch)

    if not skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)
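The two distinctive pieces here are the time-embedding bias and the skip rescale. A small stand-alone sketch with plain jnp ops standing in for nn.Dense and conv3x3 (the weight matrix below is a made-up stand-in, not the module's parameters):

import numpy as np
import jax.numpy as jnp

B, H, W, C = 4, 8, 8, 64
h = jnp.zeros((B, H, W, C))
temb = jnp.ones((B, 128))
w_temb = jnp.zeros((128, C))            # stand-in for nn.Dense(act(temb), C)

bias = temb @ w_temb                    # (B, C), one bias per feature map
h = h + bias[:, None, None, :]          # broadcast over the spatial dimensions
x = jnp.zeros((B, H, W, C))
out = (x + h) / np.sqrt(2.)             # skip_rescale: keeps variance ~1 for unit-variance branches
print(out.shape)                        # (4, 8, 8, 64)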
Example #3
 def apply(self,
           x,
           act,
           normalize,
           temb=None,
           out_ch=None,
           conv_shortcut=False,
           dropout=0.5,
           train=True):
     B, H, W, C = x.shape
     out_ch = out_ch if out_ch else C
     h = act(normalize(x))
     h = ddpm_conv3x3(h, out_ch)
     # Add bias to each feature map conditioned on the time embedding
     if temb is not None:
         h += nn.Dense(act(temb), out_ch,
                       kernel_init=default_init())[:, None, None, :]
     h = act(normalize(h))
     h = nn.dropout(h, dropout, deterministic=not train)
     h = ddpm_conv3x3(h, out_ch, init_scale=0.)
     if C != out_ch:
         if conv_shortcut:
             x = ddpm_conv3x3(x, out_ch)
         else:
             x = NIN(x, out_ch)
     return x + h
Example #4
 def apply(self,
           inputs,
           mlp_dim,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6)):
   """Applies Transformer MlpBlock module."""
   actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
   x = nn.Dense(inputs, mlp_dim, kernel_init=kernel_init, bias_init=bias_init)
   x = nn.gelu(x)
   x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
   output = nn.Dense(
       x, actual_out_dim, kernel_init=kernel_init, bias_init=bias_init)
   output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
   return output
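A functional sketch of the same MLP block with explicit weight matrices in place of nn.Dense and dropout omitted (shapes are illustrative only):

import jax
import jax.numpy as jnp

def mlp_block(x, w1, b1, w2, b2):
  h = jax.nn.gelu(x @ w1 + b1)          # expand to mlp_dim
  return h @ w2 + b2                    # project back to the input dimension

x = jnp.ones((2, 16, 512))
w1, b1 = jnp.zeros((512, 2048)), jnp.zeros((2048,))
w2, b2 = jnp.zeros((2048, 512)), jnp.zeros((512,))
print(mlp_block(x, w1, b1, w2, b2).shape)   # (2, 16, 512)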
Example #5
    def apply(self,
              x,
              channels,
              strides=(1, 1),
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True):
        norm = get_norm(activation_f, normalization, train)

        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        penalty = 0
        y = x
        y = norm(y, name='norm1')
        if std_penalty_mult > 0:
            penalty += std_penalty(y)
        y = activation_f(y, features=y.shape[-1])
        y = conv(
            y,
            channels,
            (3, 3),
            strides,
            padding='SAME',
            name='conv1',
        )
        y = norm(y, name='norm2')
        if std_penalty_mult > 0:
            penalty += std_penalty(y)
        y = activation_f(y, features=y.shape[-1])
        if dropout_rate > 0.0:
            y = nn.dropout(y, dropout_rate, deterministic=not train)
        y = conv(y, channels, (3, 3), padding='SAME', name='conv2')

        if use_residual == 1:
            # Project the shortcut when the channel count or spatial resolution changes
            if (x.shape[-1] != channels) or strides != (1, 1):
                x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')
            result = x + y
        elif use_residual == 2:
            # Unit variance preserving residual.
            if (x.shape[-1] != channels) or strides != (1, 1):
                x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')

            result = (x + y) / jnp.sqrt(
                1**2 + 1**2)  # Sum of independent normals.
        else:
            result = y

        return result, penalty
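A quick numerical check of the "sum of independent normals" comment: adding two independent unit-variance signals and dividing by sqrt(1**2 + 1**2) brings the variance back to roughly 1 (toy data, not the block above):

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key1, (100000,))
b = jax.random.normal(key2, (100000,))
print(jnp.var(a + b))                              # ~2.0
print(jnp.var((a + b) / jnp.sqrt(1**2 + 1**2)))    # ~1.0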
Example #6
    def apply(self,
              x,
              filters,
              strides=(1, 1),
              dropout_rate=0.0,
              epsilon=1e-5,
              momentum=0.9,
              norm_layer='batch_norm',
              train=True,
              dtype=jnp.float32):

        # TODO(samirabnar): Make 4 a parameter.
        needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
        norm_layer_name = ''
        if norm_layer == 'batch_norm':
            norm_layer = nn.BatchNorm.partial(use_running_average=not train,
                                              momentum=momentum,
                                              epsilon=epsilon,
                                              dtype=dtype)
            norm_layer_name = 'bn'
        elif norm_layer == 'group_norm':
            norm_layer = nn.GroupNorm.partial(num_groups=16, dtype=dtype)
            norm_layer_name = 'gn'

        conv = nn.Conv.partial(bias=False, dtype=dtype)

        residual = x
        if needs_projection:
            residual = conv(residual,
                            filters * 4, (1, 1),
                            strides,
                            name='proj_conv')
            residual = norm_layer(residual, name=f'proj_{norm_layer_name}')

        y = conv(x, filters, (1, 1), name='conv1')
        y = norm_layer(y, name=f'{norm_layer_name}1')
        y = nn.relu(y)

        y = conv(y, filters, (3, 3), strides, name='conv2')
        y = norm_layer(y, name=f'{norm_layer_name}2')
        y = nn.relu(y)

        if dropout_rate > 0.0:
            y = nn.dropout(y, dropout_rate, deterministic=not train)
        y = conv(y, filters * 4, (1, 1), name='conv3')
        y = norm_layer(y,
                       name=f'{norm_layer_name}3',
                       scale_init=nn.initializers.zeros)
        y = nn.relu(residual + y)

        return y
Example #7
  def apply(self,
            x,
            act,
            normalize,
            up=False,
            down=False,
            temb=None,
            out_ch=None,
            dropout=0.1,
            fir=False,
            fir_kernel=[1, 3, 3, 1],
            train=True,
            skip_rescale=True,
            init_scale=0.):
    B, H, W, C = x.shape
    out_ch = out_ch if out_ch else C
    h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))

    if up:
      if fir:
        h = up_or_down_sampling.upsample_2d(h, fir_kernel, factor=2)
        x = up_or_down_sampling.upsample_2d(x, fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
    elif down:
      if fir:
        h = up_or_down_sampling.downsample_2d(h, fir_kernel, factor=2)
        x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += nn.Dense(
          act(temb), out_ch, kernel_init=default_init())[:, None, None, :]

    h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
    h = nn.dropout(h, dropout, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=init_scale)
    if C != out_ch or up or down:
      x = conv1x1(x, out_ch)

    if not skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)
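For the up/down branch, a rough sketch of what a "naive" 2x upsample does to an NHWC tensor, assuming up_or_down_sampling.naive_upsample_2d is a nearest-neighbour resize (the helper's real implementation may differ; only the shape behaviour is shown):

import jax.numpy as jnp

def naive_upsample_2d(x, factor=2):
  b, h, w, c = x.shape
  x = jnp.reshape(x, (b, h, 1, w, 1, c))
  x = jnp.tile(x, (1, 1, factor, 1, factor, 1))    # repeat each pixel factor x factor times
  return jnp.reshape(x, (b, h * factor, w * factor, c))

print(naive_upsample_2d(jnp.ones((2, 8, 8, 64))).shape)   # (2, 16, 16, 64)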
Example #8
  def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
    batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                      momentum=0.9, epsilon=1e-5)

    y = batch_norm(x, name='bn1')
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
    y = batch_norm(y, name='bn2')
    y = jax.nn.relu(y)
    if dropout_rate > 0.0:
      y = nn.dropout(y, dropout_rate, deterministic=not train)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

    # Project the shortcut when the channel count or spatial resolution changes
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
    return x + y
Example #9
    def apply(self,
              inputs,
              vocab_size,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              causal=True,
              cache=None,
              positional_encoding_module=AddLearnedPositionalEncodings,
              self_attention_module=nn.SelfAttention,
              attention_fn=None,
              pad_token=None,
              output_head='logits'):
        """Applies Transformer model on the inputs.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      vocab_size: An int with the size of the vocabulary.
      emb_dim: An int with the token embedding dimension.
      num_heads: An int with the number of attention heads.
      num_layers: An int with the number of transformer encoder layers.
      qkv_dim: An int with the dimension of the query/key/value vectors.
      mlp_dim: An int with the inner dimension of the feed-forward network which
        follows the attention block.
      max_len: An int with the maximum training sequence length.
      train: A bool denoting whether we are currently training.
      dropout_rate: A float with the dropout rate.
      attention_dropout_rate: A float with a dropout rate for attention weights.
      causal: Whether to apply causal masking.
      cache: Cache for decoding.
      positional_encoding_module: A module used for adding positional encodings.
      self_attention_module: Self attention module.
      attention_fn: Method to use in place of dot product attention.
      pad_token: Token to ignore in attention.
      output_head: String or iterable over strings containing the model's output
        head(s) to return.

    Returns:
      Output of a transformer decoder. If output_head is a string, we return a
        single output head output; if output_head is an iterable, we return a
        dict with (output head name, output head output) key-value pairs.
    """
        if inputs.ndim != 2 and inputs.ndim != 3:
            raise ValueError('Expected 2 or 3 dimensions, found %d.' %
                             inputs.ndim)

        if inputs.ndim == 3:
            padding_mask = jnp.ones_like(inputs[..., 0])
        elif pad_token is None:
            padding_mask = jnp.ones_like(inputs)
        else:
            # Mask out padding tokens.
            padding_mask = jnp.where(inputs != pad_token, 1,
                                     0).astype(jnp.float32)
        padding_mask = padding_mask[..., None]  # Add embedding dimension.

        heads = dict()
        x = inputs
        if inputs.ndim == 2:
            x = x.astype('int32')
        x = Embed(x,
                  num_embeddings=vocab_size,
                  num_features=emb_dim,
                  name='embed')

        if positional_encoding_module == AddLearnedPositionalEncodings:
            x = positional_encoding_module(
                x,
                max_len=max_len,
                cache=cache,
                posemb_init=sinusoidal_init(max_len=max_len))
        else:
            x = positional_encoding_module(x, max_len=max_len)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
        heads['input_emb'] = x
        for i in range(num_layers):
            x = Transformer1DBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                causal_mask=causal,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                self_attention_module=self_attention_module,
                deterministic=not train,
                attention_fn=attention_fn,
                cache=cache,
            )
            heads['layer_%s' % i] = x
        x = nn.LayerNorm(x)
        heads['output_emb'] = x * padding_mask  # Zero out PAD positions.
        if 'logits' in output_head:
            logits = nn.Dense(x,
                              vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6))
            heads['logits'] = logits

        if 'regression' in output_head:
            regression = nn.Dense(
                x,
                1,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
            regression = jnp.squeeze(regression, axis=-1)
            heads['regression'] = regression

        if isinstance(output_head, (tuple, list)):
            return {head: heads[head] for head in output_head}
        return heads[output_head]
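A quick sketch of the padding-mask construction used above, assuming pad_token=0 (toy inputs):

import jax.numpy as jnp

inputs = jnp.array([[5, 7, 2, 0, 0]])                      # (batch_size, length)
padding_mask = jnp.where(inputs != 0, 1, 0).astype(jnp.float32)
padding_mask = padding_mask[..., None]                     # add embedding dimension
print(padding_mask.shape)                                  # (1, 5, 1)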
Example #10
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              self_attention_module=nn.SelfAttention,
              attention_fn=None,
              cache=None):
        """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      self_attention_module: Self attention module.
      attention_fn: dot product function to use inside attention.
      cache: Cache for decoding.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        if attention_fn is not None:
            self_attention_module = self_attention_module.partial(
                attention_fn=attention_fn)
        x = self_attention_module(
            x,
            num_heads=num_heads,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
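The wiring here is the standard pre-LN residual layout. A functional sketch with stand-in callables for the attention and MLP sub-modules (the hand-rolled layer_norm below omits the learnable scale and bias of nn.LayerNorm; it only shows the wiring):

import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
  mean = x.mean(-1, keepdims=True)
  var = x.var(-1, keepdims=True)
  return (x - mean) / jnp.sqrt(var + eps)

def block(x, attention_fn, mlp_fn):
  x = x + attention_fn(layer_norm(x))   # attention sub-block with residual
  return x + mlp_fn(layer_norm(x))      # MLP sub-block with residual

x = jnp.ones((2, 16, 512))
print(block(x, lambda h: h * 0.0, lambda h: h * 0.0).shape)   # (2, 16, 512)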
Example #11
    def apply(self,
              x,
              *,
              self_attention_module,
              dim_intermediate,
              is_training,
              dropout_rate=0.1,
              use_pre_layernorm=False,
              layernorm_epsilon=1e-6,
              with_aux_outputs=True):
        """Compute self-attention with a feed-forward network on top.

    Args:
      x: Input representations.
      self_attention_module: Self-Attention layer.
      dim_intermediate: Size of the intermediate layer of the feed-forward network.
      is_training: Whether to enable dropout.
      dropout_rate: Dropout probability.
      use_pre_layernorm: Use pre layer norm from
        https://arxiv.org/abs/2002.04745.
      layernorm_epsilon: Epsilon parameter for all the layer norms.
      with_aux_outputs: Whether the self_attention_module has an aux output.

    Returns:
      New representations in a jnp.array of same shape as `x`.
    """
        dim_hidden = x.shape[-1]
        use_pre_ln = use_pre_layernorm
        use_post_ln = not use_pre_ln

        def apply_ln_if(pred, x, name):
            if pred:
                return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
            else:
                return x

        # attention
        x = apply_ln_if(use_pre_ln, x, "ln_pre_att")
        x_att = self_attention_module(x)
        if with_aux_outputs:
            x_att, output_aux = x_att

        # dropout norm and add
        x_att = nn.dropout(x_att, dropout_rate, deterministic=not is_training)
        x = x + x_att
        x = apply_ln_if(use_post_ln, x, "ln_post_att")

        # feed forward
        x_ffn = x
        x_ffn = apply_ln_if(use_pre_ln, x, "ln_pre_ffn")
        x_ffn = nn.Dense(x_ffn, dim_intermediate, name="ff_1")
        x_ffn = jax.nn.relu(x_ffn)
        x_ffn = nn.Dense(x_ffn, dim_hidden, name="ff_2")

        # dropout norm and add
        x_ffn = nn.dropout(x_ffn, dropout_rate, deterministic=not is_training)
        x = x + x_ffn
        x = apply_ln_if(use_post_ln, x, "ln_post_ffn")

        if with_aux_outputs:
            output = x, output_aux
        else:
            output = x
        return output
Example #12
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition."""
        b, c = x.shape[0], x.shape[3]
        k = config.k
        sigma = config.ptopk_sigma
        num_samples = config.ptopk_num_samples

        sigma *= self.state("sigma_mutiplier",
                            shape=(),
                            initializer=nn.initializers.ones).value

        stats = {"x": x, "sigma": sigma}

        feature_extractor = models.ResNet50.shared(train=train,
                                                   name="ResNet_0")

        rpn_feature = feature_extractor(x)
        rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                            communication=Communication(
                                                config.communication),
                                            train=train)
        stats.update(rpn_stats)

        # rpn_scores are a list of score images. We keep track of the structure
        # because it is used in the aggregation step later on.
        rpn_scores_shapes = [s.shape for s in rpn_scores]
        rpn_scores_flat = jnp.concatenate(
            [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
        top_k_indicators = sample_patches.select_patches_perturbed_topk(
            rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
        top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
        offset = 0
        weights = []
        for sh in rpn_scores_shapes:
            cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
            cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
            weights.append(cur)
            offset += sh[1] * sh[2]
        chex.assert_equal(offset, top_k_indicators.shape[-1])

        part_imgs = weighted_anchor_aggregator(x, weights)
        chex.assert_shape(part_imgs, (b * k, 224, 224, c))
        stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

        part_features = feature_extractor(part_imgs)
        part_features = jnp.mean(part_features,
                                 axis=[1, 2])  # GAP the spatial dims

        part_features = nn.dropout(  # features from parts
            jnp.reshape(part_features, [b * k, 2048]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())
        features = nn.dropout(  # features from whole image
            jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
            0.5,
            deterministic=not train,
            rng=nn.make_rng())

        # Mean pool all part features, add it to features and predict logits.
        concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                              axis=1) + features
        concat_logits = nn.Dense(concat_out, num_classes)
        raw_logits = nn.Dense(features, num_classes)
        part_logits = jnp.reshape(nn.Dense(part_features, num_classes),
                                  [b, k, -1])

        all_logits = {
            "raw_logits": raw_logits,
            "concat_logits": concat_logits,
            "part_logits": part_logits,
        }
        # Add the entropy of the scores to the stats for entropy regularization.
        stats["rpn_scores_entropy"] = jax.scipy.special.entr(
            jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
        return all_logits, stats
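The offset bookkeeping that splits the flat top_k_indicators back into per-score-map weight images can be seen in isolation with toy shapes (the sizes below are made up):

import jax.numpy as jnp

b, k = 2, 4
score_shapes = [(b, 14, 14, 1), (b, 7, 7, 1)]       # stand-ins for the rpn_scores shapes
total = sum(s[1] * s[2] for s in score_shapes)
indicators = jnp.zeros((b, k, total))               # (b, k, num_positions)

offset, weights = 0, []
for sh in score_shapes:
  n = sh[1] * sh[2]
  weights.append(indicators[:, :, offset:offset + n].reshape(b, k, sh[1], sh[2]))
  offset += n
assert offset == indicators.shape[-1]
print([w.shape for w in weights])                   # [(2, 4, 14, 14), (2, 4, 7, 7)]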
Example #13
def extract_patches_from_indicators(x,
                                    indicators,
                                    patch_size,
                                    patch_dropout,
                                    grid_shape,
                                    train,
                                    iterative=False):
    """Extract patches from a batch of images.

  Args:
    x: The batch of images of shape (batch, height, width, channels).
    indicators: The one hot indicators of shape (batch, num_patches, k).
    patch_size: The size of the (squared) patches to extract.
    patch_dropout: Probability to replace a patch by 0 values.
    grid_shape: Pair of height, width of the disposition of the num_patches
      patches.
    train: If the model is being trained. Dropout is disabled otherwise.
    iterative: If True, extracts the patches with a for loop rather than
      instantiating the "all patches" tensor and extracting by dot product with
      the indicators. `iterative` is more memory efficient.

  Returns:
    The patches extracted from x with shape
      (batch, k, patch_size, patch_size, channels).

  """
    batch_size, height, width, channels = x.shape
    scores_h, scores_w = grid_shape
    k = indicators.shape[-1]
    indicators = einops.rearrange(indicators,
                                  "b (h w) k -> b k h w",
                                  h=scores_h,
                                  w=scores_w)

    scale_height = height // scores_h
    scale_width = width // scores_w
    padded_height = scale_height * scores_h + patch_size - 1
    padded_width = scale_width * scores_w + patch_size - 1
    top_pad = (patch_size - scale_height) // 2
    left_pad = (patch_size - scale_width) // 2
    bottom_pad = padded_height - top_pad - height
    right_pad = padded_width - left_pad - width

    # TODO(jbcdnr): assert padding is positive.

    padded_x = jnp.pad(x, [(0, 0), (top_pad, bottom_pad),
                           (left_pad, right_pad), (0, 0)])

    # Extract the patches. The iterative version fits better in memory as it
    # does not instantiate the "all patches" tensor but iterates over the
    # patches to compute the weighted sum with the indicator variables from topk.
    if not iterative:
        assert patch_dropout == 0., "Patch dropout not implemented."
        patches = utils.extract_images_patches(padded_x,
                                               window_size=(patch_size,
                                                            patch_size),
                                               stride=(scale_height,
                                                       scale_width))

        shape = (batch_size, scores_h, scores_w, patch_size, patch_size,
                 channels)
        chex.assert_shape(patches, shape)

        patches = jnp.einsum("b k h w, b h w i j c -> b k i j c", indicators,
                             patches)

    else:
        mask = jnp.ones((batch_size, scores_h, scores_w))
        mask = nn.dropout(mask, patch_dropout, deterministic=not train)

        def accumulate_patches(acc, index_i_j):
            i, j = index_i_j
            patch = jax.lax.dynamic_slice(
                padded_x, (0, i * scale_height, j * scale_width, 0),
                (batch_size, patch_size, patch_size, channels))
            weights = indicators[:, :, i, j]

            is_masked = mask[:, i, j]
            weighted_patch = jnp.einsum("b, bk, bijc -> bkijc", is_masked,
                                        weights, patch)
            chex.assert_equal_shape([acc, weighted_patch])
            acc += weighted_patch
            return acc, None

        indices = jnp.stack(jnp.meshgrid(jnp.arange(scores_h),
                                         jnp.arange(scores_w),
                                         indexing="ij"),
                            axis=-1)
        indices = indices.reshape((-1, 2))
        init_patches = jnp.zeros(
            (batch_size, k, patch_size, patch_size, channels))
        patches, _ = jax.lax.scan(accumulate_patches, init_patches, indices)

    return patches
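A minimal sketch of one step of the iterative path: dynamic_slice reads a single patch per grid cell and the top-k indicator weights it into the accumulator (toy sizes, no padding or patch dropout):

import jax
import jax.numpy as jnp

b, h, w, c = 2, 8, 8, 3
k, patch = 4, 4
x = jnp.arange(b * h * w * c, dtype=jnp.float32).reshape(b, h, w, c)
indicators = jnp.zeros((b, k, 2, 2))                 # (b, k, grid_h, grid_w)

patch_00 = jax.lax.dynamic_slice(x, (0, 0, 0, 0), (b, patch, patch, c))
weighted = jnp.einsum("bk,bijc->bkijc", indicators[:, :, 0, 0], patch_00)
print(weighted.shape)                                # (2, 4, 4, 4, 3)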
Example #14
    def apply(self,
              inputs,
              num_outputs,
              num_filters=64,
              num_layers=50,
              dropout_rate=0.0,
              input_dropout_rate=0.0,
              train=True,
              dtype=jnp.float32,
              head_bias_init=jnp.zeros,
              return_activations=False,
              input_layer_key='input',
              has_discriminator=False,
              discriminator=False):
        """Apply a ResNet network on the input.

    Args:
      inputs: jnp array; Inputs.
      num_outputs: int; Number of output units.
      num_filters: int; Determines base number of filters. Number of filters in
        block i is  num_filters * 2 ** i.
      num_layers: int; Number of layers (should be one of the predefined ones.)
      dropout_rate: float; Rate of dropping out the output of different hidden
        layers.
      input_dropout_rate: float; Rate of dropping out the input units.
      train: bool; Whether we are in training mode.
      dtype: jnp type; Type of the outputs.
      head_bias_init: fn(rng_key, shape)--> jnp array; Initializer for head bias
        parameters.
      return_activations: bool; If True, hidden activations are also returned.
      input_layer_key: str; Determines where to plug in the input (this is to
        enable providing inputs to slices of the model). If `input_layer_key` is
        `layer_i`, we assume the inputs are the activations of `layer_i` and pass
        them to `layer_{i+1}`.
      has_discriminator: bool; Whether the model should have discriminator
        layer.
      discriminator: bool; Whether we should return discriminator logits.

    Returns:
      Unnormalized logits with shape `[bs, num_outputs]`;
      if return_activations:
        logits, a dict of hidden activations, and the key to the representation(s)
        that will be used as ``The Representation'', e.g., for computing
        losses.
    """
        if num_layers not in ResNet._block_size_options:
            raise ValueError('Please provide a valid number of layers')

        block_sizes = ResNet._block_size_options[num_layers]

        layer_activations = collections.OrderedDict()
        input_is_set = False
        current_rep_key = 'input'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True

        if input_is_set:
            # Input dropout
            x = nn.dropout(x, input_dropout_rate, deterministic=not train)
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        current_rep_key = 'init_conv'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # First block
            x = nn.Conv(x,
                        num_filters, (7, 7), (2, 2),
                        padding=[(3, 3), (3, 3)],
                        bias=False,
                        dtype=dtype,
                        name='init_conv')
            x = nn.BatchNorm(x,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=dtype,
                             name='init_bn')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # Residual blocks
        for i, block_size in enumerate(block_sizes):

            # Stage i (each stage contains blocks of the same size).
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                current_rep_key = f'block_{i + 1}+{j}'
                if input_layer_key == current_rep_key:
                    x = inputs
                    input_is_set = True
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key
                elif input_is_set:
                    x = ResidualBlock(x,
                                      num_filters * 2**i,
                                      strides=strides,
                                      dropout_rate=dropout_rate,
                                      train=train,
                                      dtype=dtype,
                                      name=f'block_{i + 1}_{j}')
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key

        current_rep_key = 'avg_pool'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # Global Average Pool
            x = jnp.mean(x, axis=(1, 2))
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # DANN module
        if has_discriminator:
            z = dann_utils.flip_grad_identity(x)
            z = nn.Dense(z, 2, name='disc_l1', bias=True)
            z = nn.relu(z)
            z = nn.Dense(z, 2, name='disc_l2', bias=True)

        current_rep_key = 'head'
        if input_layer_key == current_rep_key:
            x = inputs
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

            logging.warn('Input was never used')
        elif input_is_set:
            x = nn.Dense(x,
                         num_outputs,
                         dtype=dtype,
                         bias_init=head_bias_init,
                         name='head')

        # Make sure that the output is float32, even if our previous computations
        # are in float16, or other types.
        x = jnp.asarray(x, jnp.float32)

        outputs = x
        if return_activations:
            outputs = (x, layer_activations, rep_key)
            if discriminator and has_discriminator:
                outputs = outputs + (z, )
        else:
            del layer_activations
            if discriminator and has_discriminator:
                outputs = (x, z)
        if discriminator and (not has_discriminator):
            raise ValueError(
                'Inconsistent values passed for discriminator and has_discriminator'
            )
        return outputs
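The input_layer_key/input_is_set pattern above effectively lets callers feed activations into the middle of the network. A toy illustration of just that gating, with plain Python callables standing in for the ResNet stages (purely illustrative, not the model above):

def run_from(input_layer_key, inputs, stages):
  x, input_is_set = None, False
  activations = {}
  for name, fn in stages:
    if input_layer_key == name:
      x, input_is_set = inputs, True    # plug the input in at this stage
    elif input_is_set:
      x = fn(x)                         # later stages run on the provided activations
    if input_is_set:
      activations[name] = x
  return x, activations

stages = [('input', None), ('init_conv', lambda v: v * 2), ('head', lambda v: v + 1)]
print(run_from('init_conv', 3, stages))   # (4, {'init_conv': 3, 'head': 4})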