def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          causal=True,
          cache=None,
          positional_encoding_module=AddLearnedPositionalEncodings,
          self_attention_module=nn.SelfAttention,
          attention_fn=None,
          pad_token=None,
          output_head='logits'):
  """Applies Transformer model on the inputs.

  Args:
    inputs: An array of shape (batch_size, length) or
      (batch_size, length, vocab_size) with the input sequences. When
      2-dimensional, the array contains sequences of int tokens. Otherwise,
      the array contains next-token distributions over tokens (e.g. one-hot
      representations).
    vocab_size: An int with the size of the vocabulary.
    emb_dim: An int with the token embedding dimension.
    num_heads: An int with the number of attention heads.
    num_layers: An int with the number of transformer encoder layers.
    qkv_dim: An int with the dimension of the query/key/value vectors.
    mlp_dim: An int with the inner dimension of the feed-forward network
      which follows the attention block.
    max_len: An int with the maximum training sequence length.
    train: A bool denoting whether we are currently training.
    dropout_rate: A float with the dropout rate.
    attention_dropout_rate: A float with a dropout rate for attention
      weights.
    causal: Whether to apply causal masking.
    cache: Cache for decoding.
    positional_encoding_module: A module used for adding positional
      encodings.
    self_attention_module: Self attention module.
    attention_fn: Method to use in place of dot product attention.
    pad_token: Token to ignore in attention.
    output_head: String or iterable over strings containing the model's
      output head(s) to return.

  Returns:
    Output of a transformer decoder. If output_head is a string, we return a
    single output head output; if output_head is an iterable, we return a
    dict with (output head name, output head output) key-value pairs.
  """
  if inputs.ndim != 2 and inputs.ndim != 3:
    raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

  if inputs.ndim == 3:
    padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
  elif pad_token is None:
    padding_mask = jnp.ones_like(inputs)
  else:
    # Mask out padding tokens.
    padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
  padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

  heads = dict()
  x = inputs
  if inputs.ndim == 2:
    x = x.astype('int32')
  x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

  if positional_encoding_module == AddLearnedPositionalEncodings:
    x = positional_encoding_module(
        x,
        max_len=max_len,
        cache=cache,
        posemb_init=sinusoidal_init(max_len=max_len))
  else:
    x = positional_encoding_module(x, max_len=max_len)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  heads['input_emb'] = x

  for i in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=causal,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        self_attention_module=self_attention_module,
        deterministic=not train,
        attention_fn=attention_fn,
        cache=cache,
    )
    heads['layer_%s' % i] = x
  x = nn.LayerNorm(x)
  heads['output_emb'] = x * padding_mask  # Zero out PAD positions.

  if 'logits' in output_head:
    logits = nn.Dense(
        x,
        vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    heads['logits'] = logits

  if 'regression' in output_head:
    regression = nn.Dense(
        x,
        1,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    regression = jnp.squeeze(regression, axis=-1)
    heads['regression'] = regression

  if isinstance(output_head, (tuple, list)):
    return {head: heads[head] for head in output_head}
  return heads[output_head]
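
# Usage sketch (illustrative only, not part of the original code): calls the
# decoder `apply` above through the deprecated flax.nn Module API.
# `TransformerDecoder` is a hypothetical name for the enclosing Module, which
# is defined outside this excerpt; `jax`, `jnp` and `nn` are the module-level
# imports already used above.
def _example_decoder_forward(rng, vocab_size=32):
  """Runs the (hypothetical) TransformerDecoder on dummy int tokens."""
  tokens = jnp.zeros((2, 16), dtype=jnp.int32)  # (batch_size, length)
  with nn.stochastic(rng):
    # Module.init returns (output, params) in the deprecated flax.nn API.
    logits, params = TransformerDecoder.init(
        nn.make_rng(), tokens, vocab_size=vocab_size, train=False)
  # logits: (batch_size, length, vocab_size). Passing
  # output_head=('logits', 'output_emb') would return a dict of heads instead.
  return logits, params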
def apply_ln_if(pred, x, name):
  if pred:
    return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
  else:
    return x
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          self_attention_module=nn.SelfAttention,
          attention_fn=None,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    self_attention_module: Self attention module.
    attention_fn: dot product function to use inside attention.
    cache: Cache for decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  if attention_fn is not None:
    self_attention_module = self_attention_module.partial(
        attention_fn=attention_fn)
  x = self_attention_module(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  return x + y
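
# Usage sketch (illustrative only, not part of the original code): the block
# maps activations of shape (batch, length, emb_dim) to the same shape. It is
# normally invoked as a sub-module inside the decoder loop above, but it can
# also be initialized standalone through the deprecated flax.nn Module API.
def _example_transformer_block(rng):
  """Initializes Transformer1DBlock on dummy activations."""
  x = jnp.zeros((2, 16, 512))  # (batch, length, emb_dim)
  with nn.stochastic(rng):
    y, params = Transformer1DBlock.init(
        nn.make_rng(), x, qkv_dim=512, mlp_dim=2048, num_heads=8,
        causal_mask=True, deterministic=True)
  return y, params  # y has the same shape as x.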
def apply(self,
          x,
          *,
          patch_size,
          k,
          downscale,
          scorer_has_se,
          normalization_str="identity",
          selection_method,
          selection_method_kwargs=None,
          selection_method_inference=None,
          patch_dropout=0.,
          hard_topk_probability=0.,
          random_patch_probability=0.,
          use_iterative_extraction,
          append_position_to_input,
          feature_network,
          aggregation_method,
          aggregation_method_kwargs=None,
          train):
  """Processes a high-resolution image by selecting a subset of useful patches.

  This model processes the input as follows:
  1. Compute scores per patch on a downscaled version of the input.
  2. Select "important" patches using sampling or top-k methods.
  3. Extract the patches from the high-resolution image.
  4. Compute a representation vector for each patch with a feature network.
  5. Aggregate the patch representations to obtain an image representation.

  Args:
    x: Input tensor of shape (batch, height, width, channels).
    patch_size: Size of the (squared) patches to extract.
    k: Number of patches to extract per image.
    downscale: Downscale multiplier for the input of the scorer network.
    scorer_has_se: Whether scorer network has Squeeze-excite layers.
    normalization_str: String specifying the normalization of the scores.
    selection_method: Method that selects which patches should be extracted,
      based on their scores. Either returns indices (hard selection) or
      indicator vectors (which could yield interpolated patches).
    selection_method_kwargs: Keyword args for the selection_method.
    selection_method_inference: Selection method used at inference.
    patch_dropout: Probability to replace a patch by 0 values.
    hard_topk_probability: Probability to use the true topk on the scores to
      select the patches. This operation has no gradient, so the scorer's
      weights won't be trained.
    random_patch_probability: Probability to replace each patch by a random
      patch in the image during training.
    use_iterative_extraction: If True, uses a for loop instead of patch
      indexing for memory efficiency.
    append_position_to_input: Append normalized (height, width) position to
      the channels of the input.
    feature_network: Network to be applied on each patch individually to
      obtain patch representation vectors.
    aggregation_method: Method to aggregate the representations of the k
      patches of each image to obtain the image representation.
    aggregation_method_kwargs: Keyword arguments for aggregation_method.
    train: If the model is being trained. Disable dropout otherwise.

  Returns:
    A representation vector for each image in the batch, together with a dict
    of intermediate statistics.
  """
  selection_method = SelectionMethod(selection_method)
  aggregation_method = AggregationMethod(aggregation_method)
  if selection_method_inference:
    selection_method_inference = SelectionMethod(selection_method_inference)

  selection_method_kwargs = selection_method_kwargs or {}
  aggregation_method_kwargs = aggregation_method_kwargs or {}

  stats = {}

  # Compute new dimension of the scoring image.
  b, h, w, c = x.shape
  scoring_shape = (b, h // downscale, w // downscale, c)

  # === Compute the scores with a small CNN.
  if selection_method == SelectionMethod.RANDOM:
    scores_h, scores_w = Scorer.compute_output_size(h // downscale,
                                                    w // downscale)
    num_patches = scores_h * scores_w
  else:
    # Downscale input to run scorer on.
    scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
    scores = Scorer(scoring_x, use_squeeze_excite=scorer_has_se,
                    name="scorer")
    flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
    num_patches = flatten_scores.shape[-1]
    scores_h, scores_w = scores.shape[1:3]

    # Compute entropy before normalization.
    prob_scores = jax.nn.softmax(flatten_scores)
    stats["entropy_before_normalization"] = jax.scipy.special.entr(
        prob_scores).sum(axis=1).mean(axis=0)

    # Normalize the flattened scores.
    normalization_fn = create_normalization_fn(normalization_str)
    flatten_scores = normalization_fn(flatten_scores)
    scores = flatten_scores.reshape(scores.shape)
    stats["scores"] = scores[Ellipsis, None]

  # Concatenate height and width position to the input channels.
  if append_position_to_input:
    coords = utils.create_grid([h, w], value_range=(0., 1.))
    x = jnp.concatenate([x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)],
                        axis=-1)
    c += 2

  # Overwrite the selection method at inference.
  if selection_method_inference and not train:
    selection_method = selection_method_inference

  # === Patch selection

  # Select the patches by sampling or top-k. Some methods return the indices
  # of the selected patches; other methods return indicator vectors.
  extract_by_indices = selection_method in [
      SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
  ]
  if selection_method is SelectionMethod.SINKHORN_TOPK:
    indicators = select_patches_sinkhorn_topk(
        flatten_scores, k=k, **selection_method_kwargs)
  elif selection_method is SelectionMethod.PERTURBED_TOPK:
    sigma = selection_method_kwargs["sigma"]
    num_samples = selection_method_kwargs["num_samples"]
    sigma *= self.state("sigma_mutiplier",
                        shape=(),
                        initializer=nn.initializers.ones).value
    stats["sigma"] = sigma
    indicators = select_patches_perturbed_topk(
        flatten_scores, k=k, sigma=sigma, num_samples=num_samples)
  elif selection_method is SelectionMethod.HARD_TOPK:
    indices = select_patches_hard_topk(flatten_scores, k=k)
  elif selection_method is SelectionMethod.RANDOM:
    batch_random_indices_fn = jax.vmap(
        functools.partial(
            jax.random.choice, a=num_patches, shape=(k,), replace=False))
    indices = batch_random_indices_fn(jax.random.split(nn.make_rng(), b))

  # Compute the entropy of the scores for regularization.
  if selection_method not in [SelectionMethod.RANDOM]:
    prob_scores = flatten_scores
    # Normalize the scores if it is not already done.
    if "softmax" not in normalization_str:
      prob_scores = jax.nn.softmax(prob_scores)
    stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(axis=1).mean(
        axis=0)

  # Randomly use hard topk at training.
  if (train and hard_topk_probability > 0 and selection_method not in
      [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
    true_indices = select_patches_hard_topk(flatten_scores, k=k)
    random_values = jax.random.uniform(nn.make_rng(), (b,))
    use_hard = random_values < hard_topk_probability
    if extract_by_indices:
      indices = jnp.where(use_hard[:, None], true_indices, indices)
    else:
      true_indicators = make_indicators(true_indices, num_patches)
      indicators = jnp.where(use_hard[:, None, None], true_indicators,
                             indicators)

  # Sample some random patches during training with random_patch_probability.
  if (train and random_patch_probability > 0 and
      selection_method is not SelectionMethod.RANDOM):
    single_random_patches = functools.partial(
        jax.random.choice, a=num_patches, shape=(k,), replace=False)
    random_indices = jax.vmap(single_random_patches)(
        jax.random.split(nn.make_rng(), b))
    random_values = jax.random.uniform(nn.make_rng(), (b, k))
    use_random = random_values < random_patch_probability
    if extract_by_indices:
      indices = jnp.where(use_random, random_indices, indices)
    else:
      random_indicators = make_indicators(random_indices, num_patches)
      indicators = jnp.where(use_random[:, None, :], random_indicators,
                             indicators)

  # === Patch extraction
  if extract_by_indices:
    patches = extract_patches_from_indices(
        x, indices, patch_size=patch_size, grid_shape=(scores_h, scores_w))
    indicators = make_indicators(indices, num_patches)
  else:
    patches = extract_patches_from_indicators(
        x,
        indicators,
        patch_size,
        grid_shape=(scores_h, scores_w),
        iterative=use_iterative_extraction,
        patch_dropout=patch_dropout,
        train=train)

  chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

  stats["extracted_patches"] = einops.rearrange(
      patches, "b k i j c -> b i (k j) c")
  # Remove position channels for plotting.
  if append_position_to_input:
    stats["extracted_patches"] = stats["extracted_patches"][Ellipsis, :-2]

  # === Compute patch features
  flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
  representations = feature_network(flatten_patches, train=train)
  if representations.ndim > 2:
    collapse_axis = tuple(range(1, representations.ndim - 1))
    representations = representations.mean(axis=collapse_axis)
  representations = einops.rearrange(
      representations, "(b k) d -> b k d", k=k)

  stats["patch_representations"] = representations

  # === Aggregate the k patches
  # - for sampling we are forced to take an expectation
  # - for topk we have multiple options: mean, max, transformer.
  if aggregation_method is AggregationMethod.TRANSFORMER:
    patch_pos_encoding = nn.Dense(
        einops.rearrange(indicators, "b d k -> b k d"),
        features=representations.shape[-1])
    chex.assert_equal_shape([representations, patch_pos_encoding])
    representations += patch_pos_encoding
    representations = transformer.Transformer(
        representations, **aggregation_method_kwargs, is_training=train)
  elif aggregation_method is AggregationMethod.MEANPOOLING:
    representations = representations.mean(axis=1)
  elif aggregation_method is AggregationMethod.MAXPOOLING:
    representations = representations.max(axis=1)
  elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
    representations = representations.sum(axis=1)
    representations = nn.LayerNorm(representations)

  representations = nn.Dense(
      representations,
      features=representations.shape[-1],
      name="classification_dense1")
  representations = nn.swish(representations)

  return representations, stats
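
# Illustrative sketch (not the codebase's implementation): how hard top-k
# selection can turn per-patch scores into the (batch, num_patches, k)
# indicator tensors consumed above. The real code uses
# select_patches_hard_topk / make_indicators, which are defined elsewhere;
# this standalone version only shows the shape contract.
def _example_topk_indicators(flatten_scores, k):
  """Returns one-hot indicators of the k highest-scoring patches."""
  _, indices = jax.lax.top_k(flatten_scores, k)    # (batch, k)
  num_patches = flatten_scores.shape[-1]
  one_hot = jax.nn.one_hot(indices, num_patches)   # (batch, k, num_patches)
  return jnp.swapaxes(one_hot, 1, 2)               # (batch, num_patches, k)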