예제 #1
0
    def apply(self,
              inputs,
              vocab_size,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              causal=True,
              cache=None,
              positional_encoding_module=AddLearnedPositionalEncodings,
              self_attention_module=nn.SelfAttention,
              attention_fn=None,
              pad_token=None,
              output_head='logits'):
        """Applies Transformer model on the inputs.

        The inputs are embedded, positional encodings are added, and the
        result is passed through `num_layers` `Transformer1DBlock`s followed by
        a final LayerNorm. Intermediate and final activations are collected
        under the following head names:

          * 'input_emb': embeddings after positional encoding and dropout.
          * 'layer_<i>': output of the i-th transformer block (0-based).
          * 'output_emb': final LayerNorm output, zeroed at padded positions.
          * 'logits': dense projection to `vocab_size` (only computed when
            requested).
          * 'regression': dense projection to a scalar per position (only
            computed when requested).

        Args:
          inputs: An array of shape (batch_size, length) or (batch_size,
            length, vocab_size) with the input sequences. When 2-dimensional,
            the array contains sequences of int tokens. Otherwise, the array
            contains next-token distributions over tokens (e.g. one-hot
            representations).
          vocab_size: An int with the size of the vocabulary.
          emb_dim: An int with the token embedding dimension.
          num_heads: An int with the number of attention heads.
          num_layers: An int with the number of transformer blocks.
          qkv_dim: An int with the dimension of the query/key/value vectors.
          mlp_dim: An int with the inner dimension of the feed-forward network
            which follows the attention block.
          max_len: An int with the maximum training sequence length.
          train: A bool denoting whether we are currently training. Dropout is
            only active when this is True.
          dropout_rate: A float with the dropout rate.
          attention_dropout_rate: A float with a dropout rate for attention
            weights.
          causal: Whether to apply causal masking.
          cache: Cache for decoding.
          positional_encoding_module: A module used for adding positional
            encodings.
          self_attention_module: Self attention module.
          attention_fn: Method to use in place of dot product attention.
          pad_token: Token to ignore in attention. Only consulted for
            2-dimensional inputs; 3-dimensional inputs are never masked.
          output_head: String or iterable over strings containing the model's
            output head(s) to return.

        Returns:
          Output of a transformer decoder. If output_head is a string, we
          return a single output head output; if output_head is an iterable,
          we return a dict with (output head name, output head output)
          key-value pairs.

        Raises:
          ValueError: If `inputs` is neither 2- nor 3-dimensional.
          KeyError: If a requested head name was not produced.
        """
        if inputs.ndim not in (2, 3):
            raise ValueError('Expected 2 or 3 dimensions, found %d.' %
                             inputs.ndim)

        # Normalize `output_head` once. Testing `'logits' in output_head`
        # directly would be a substring match for a plain string, and only
        # tuples/lists used to be recognised as collections even though any
        # iterable of names is documented as valid.
        single_head = isinstance(output_head, str)
        if single_head:
            requested_heads = (output_head, )
        else:
            # Materialize so generators can be scanned more than once.
            requested_heads = tuple(output_head)

        if inputs.ndim == 3:
            padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
        elif pad_token is None:
            padding_mask = jnp.ones_like(inputs)
        else:
            # Mask out padding tokens.
            padding_mask = jnp.where(inputs != pad_token, 1,
                                     0).astype(jnp.float32)
        padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

        heads = dict()
        x = inputs
        if inputs.ndim == 2:
            x = x.astype('int32')
        x = Embed(x,
                  num_embeddings=vocab_size,
                  num_features=emb_dim,
                  name='embed')

        # Only the learned positional encoding accepts the decoding cache and
        # a custom initializer; other modules just receive `max_len`.
        if positional_encoding_module == AddLearnedPositionalEncodings:
            x = positional_encoding_module(
                x,
                max_len=max_len,
                cache=cache,
                posemb_init=sinusoidal_init(max_len=max_len))
        else:
            x = positional_encoding_module(x, max_len=max_len)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
        heads['input_emb'] = x
        for i in range(num_layers):
            x = Transformer1DBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                causal_mask=causal,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                self_attention_module=self_attention_module,
                deterministic=not train,
                attention_fn=attention_fn,
                cache=cache,
            )
            heads['layer_%s' % i] = x
        x = nn.LayerNorm(x)
        heads['output_emb'] = x * padding_mask  # Zero out PAD positions.

        # The projection heads below read the unmasked LayerNorm output and
        # are only instantiated when explicitly requested.
        if 'logits' in requested_heads:
            logits = nn.Dense(x,
                              vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6))
            heads['logits'] = logits

        if 'regression' in requested_heads:
            regression = nn.Dense(
                x,
                1,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
            regression = jnp.squeeze(regression, axis=-1)
            heads['regression'] = regression

        if single_head:
            return heads[output_head]
        return {head: heads[head] for head in requested_heads}
예제 #2
0
 def apply_ln_if(pred, x, name):
     """Conditionally layer-normalizes `x`.

     Args:
       pred: Truthy to apply LayerNorm, falsy to pass `x` through untouched.
       x: Value to (optionally) normalize.
       name: Name given to the LayerNorm module when it is applied.

     Returns:
       `nn.LayerNorm(x, ...)` when `pred` is truthy, otherwise `x` itself.
     """
     if not pred:
         return x
     # `layernorm_epsilon` is not defined in this function; it is read from
     # the surrounding scope.
     return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
예제 #3
0
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              self_attention_module=nn.SelfAttention,
              attention_fn=None,
              cache=None):
        """Runs a single transformer block over a 3-dimensional input.

        The block is made of two residual sub-layers:
          1. LayerNorm -> self-attention -> dropout, added back to `inputs`.
          2. LayerNorm -> `MlpBlock`, added back to the output of step 1.

        Args:
          inputs: Input array; must have exactly 3 dimensions.
          qkv_dim: Dimension of the query/key/value projections.
          mlp_dim: Dimension of the MLP applied after the attention block.
          num_heads: Number of attention heads.
          causal_mask: Whether future positions are masked; forwarded to the
            attention module.
          padding_mask: Padding mask forwarded unchanged to the attention
            module.
          dropout_rate: Dropout rate used after attention and inside the MLP.
          attention_dropout_rate: Dropout rate for the attention weights.
          deterministic: If True, dropout is disabled.
          self_attention_module: Self attention module to instantiate.
          attention_fn: Optional replacement for the attention function used
            inside `self_attention_module`.
          cache: Cache for decoding, forwarded to the attention module.

        Returns:
          The output of the transformer block.
        """
        assert inputs.ndim == 3

        # --- Attention sub-layer ---
        normed_inputs = nn.LayerNorm(inputs)

        attention_cls = self_attention_module
        if attention_fn is not None:
            # Bind the custom attention function into the module class.
            attention_cls = attention_cls.partial(attention_fn=attention_fn)

        attention_kwargs = dict(
            num_heads=num_heads,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache,
        )
        attended = attention_cls(normed_inputs, **attention_kwargs)
        attended = nn.dropout(attended,
                              rate=dropout_rate,
                              deterministic=deterministic)
        attention_out = attended + inputs  # First residual connection.

        # --- MLP sub-layer ---
        mlp_out = MlpBlock(nn.LayerNorm(attention_out),
                           mlp_dim=mlp_dim,
                           dropout_rate=dropout_rate,
                           deterministic=deterministic)

        # Second residual connection.
        return attention_out + mlp_out
예제 #4
0
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

        This model processes the input as follow:
        1. Compute scores per patch on a downscaled version of the input.
        2. Select "important" patches using sampling or top-k methods.
        3. Extract the patches from the high-resolution image.
        4. Compute representation vector for each patch with a feature network.
        5. Aggregate the patch representation to obtain an image representation.

        Args:
          x: Input tensor of shape (batch, height, width, channels).
          patch_size: Size of the (squared) patches to extract.
          k: Number of patches to extract per image.
          downscale: Downscale multiplier for the input of the scorer network.
          scorer_has_se: Whether scorer network has Squeeze-excite layers.
          normalization_str: String specifying the normalization of the scores.
          selection_method: Method that selects which patches should be
            extracted, based on their scores. Either returns indices (hard
            selection) or indicators vectors (which could yield interpolated
            patches). The scorer network is skipped when this is the random
            method.
          selection_method_kwargs: Keyword args for the selection_method.
          selection_method_inference: Selection method used at inference (i.e.
            when `train` is False) in place of `selection_method`.
          patch_dropout: Probability to replace a patch by 0 values. Only
            forwarded to the indicator-based extraction.
          hard_topk_probability: Probability to use the true topk on the scores
            to select the patches. This operation has no gradient so scorer's
            weights won't be trained.
          random_patch_probability: Probability to replace each patch by a
            random patch in the image during training.
          use_iterative_extraction: If True, uses a for loop instead of patch
            indexing for memory efficiency.
          append_position_to_input: Append normalized (height, width) position
            to the channels of the input.
          feature_network: Network to be applied on each patch individually to
            obtain patch representation vectors.
          aggregation_method: Method to aggregate the representations of the k
            patches of each image to obtain the image representation.
          aggregation_method_kwargs: Keywords arguments for aggregation_method.
          train: If the model is being trained. Disable dropout otherwise.

        Returns:
          A `(representations, stats)` tuple. `representations` is the
          aggregated patch representation after a final dense layer and swish
          activation. `stats` is a dict of auxiliary values collected along the
          way; depending on the configuration it holds
          "entropy_before_normalization", "scores", "sigma", "entropy",
          "extracted_patches" and "patch_representations".

        Raises:
          ValueError: If `selection_method` is the random method (so no scores
            are computed) while a score-based `selection_method_inference` is
            requested at inference time.
        """
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        if selection_method == SelectionMethod.RANDOM:
            # Random selection needs no scores, only the size of the grid the
            # scorer would have produced.
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flatten scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            # The scorer only ran if the configured method was not RANDOM.
            # Every other method reads `flatten_scores` below, so switching to
            # one of them here would otherwise fail with an opaque
            # UnboundLocalError.
            if (selection_method == SelectionMethod.RANDOM
                    and selection_method_inference != SelectionMethod.RANDOM):
                raise ValueError(
                    f"selection_method_inference={selection_method_inference}"
                    " requires patch scores, but the scorer is skipped when"
                    f" selection_method={selection_method}.")
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods returns the indices
        # of the selected patches, other methods return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            # Scale the configured sigma by a multiplier kept in module state
            # (initialized to one).
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            # Draw k distinct patch ids per image, one RNG key per batch item.
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if it is not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training.
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            # Both index-based methods are excluded by the condition above, so
            # the current selection is always expressed as indicators here.
            true_indicators = make_indicators(true_indices, num_patches)
            indicators = jnp.where(use_hard[:, None, None], true_indicators,
                                   indicators)

        # Sample some random patches during training with random_patch_probability.
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            # Indicators are still needed by the transformer aggregation.
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        # Lay the k patches of each image side by side along the width axis.
        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        if representations.ndim > 2:
            # Average over every axis between the batch and feature axes.
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            # Project each patch's indicator vector to the feature dimension
            # and add it to the patch representation as a positional encoding.
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats