def apply(self,
          x,
          channels,
          strides=(1, 1),
          dropout_rate=0.0,
          norm_layer='group_norm',
          train=True):
  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'

  y = norm_layer(x, name=f'{norm_layer_name}1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
  y = norm_layer(y, name=f'{norm_layer_name}2')
  y = jax.nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

  # Project the shortcut when the channel count or stride changes.
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
  return x + y
def apply(self,
          x,
          act,
          normalize,
          temb=None,
          out_ch=None,
          conv_shortcut=False,
          dropout=0.1,
          train=True,
          skip_rescale=False,
          init_scale=0.):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))
  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding.
  if temb is not None:
    h += nn.Dense(
        act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
  h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=init_scale)
  if C != out_ch:
    if conv_shortcut:
      x = conv3x3(x, out_ch)
    else:
      x = NIN(x, out_ch)
  if not skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
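# A minimal numerical check (illustrative only, not part of the module above)
# of the `skip_rescale` convention: if `x` and `h` are independent signals with
# unit variance, their sum has variance ~2, so dividing by sqrt(2) restores
# unit variance.
import numpy as np

_rng = np.random.default_rng(0)
_x = _rng.standard_normal(100000)
_h = _rng.standard_normal(100000)
print(np.var(_x + _h))                  # ~2.0
print(np.var((_x + _h) / np.sqrt(2.)))  # ~1.0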
def apply(self,
          x,
          act,
          normalize,
          temb=None,
          out_ch=None,
          conv_shortcut=False,
          dropout=0.5,
          train=True):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  h = act(normalize(x))
  h = ddpm_conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding.
  if temb is not None:
    h += nn.Dense(
        act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
  h = act(normalize(h))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = ddpm_conv3x3(h, out_ch, init_scale=0.)
  if C != out_ch:
    if conv_shortcut:
      x = ddpm_conv3x3(x, out_ch)
    else:
      x = NIN(x, out_ch)
  return x + h
def apply(self,
          inputs,
          mlp_dim,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6)):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  x = nn.Dense(inputs, mlp_dim, kernel_init=kernel_init, bias_init=bias_init)
  x = nn.gelu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(
      x, actual_out_dim, kernel_init=kernel_init, bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  return output
def apply(self,
          x,
          channels,
          strides=(1, 1),
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          compensate_padding=True):
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)

  penalty = 0
  y = x
  y = norm(y, name='norm1')
  if std_penalty_mult > 0:
    penalty += std_penalty(y)
  y = activation_f(y, features=y.shape[-1])
  y = conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
  y = norm(y, name='norm2')
  if std_penalty_mult > 0:
    penalty += std_penalty(y)
  y = activation_f(y, features=y.shape[-1])
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = conv(y, channels, (3, 3), padding='SAME', name='conv2')

  if use_residual == 1:
    # Project the shortcut when the channel count or stride changes.
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')
    result = x + y
  elif use_residual == 2:
    # Unit-variance-preserving residual.
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')
    result = (x + y) / jnp.sqrt(1**2 + 1**2)  # Sum of independent normals.
  else:
    result = y

  return result, penalty
def apply(self,
          x,
          filters,
          strides=(1, 1),
          dropout_rate=0.0,
          epsilon=1e-5,
          momentum=0.9,
          norm_layer='batch_norm',
          train=True,
          dtype=jnp.float32):
  # TODO(samirabnar): Make 4 a parameter.
  needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)

  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=momentum,
        epsilon=epsilon,
        dtype=dtype)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16, dtype=dtype)
    norm_layer_name = 'gn'

  conv = nn.Conv.partial(bias=False, dtype=dtype)

  residual = x
  if needs_projection:
    residual = conv(residual, filters * 4, (1, 1), strides, name='proj_conv')
    residual = norm_layer(residual, name=f'proj_{norm_layer_name}')

  y = conv(x, filters, (1, 1), name='conv1')
  y = norm_layer(y, name=f'{norm_layer_name}1')
  y = nn.relu(y)
  y = conv(y, filters, (3, 3), strides, name='conv2')
  y = norm_layer(y, name=f'{norm_layer_name}2')
  y = nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = conv(y, filters * 4, (1, 1), name='conv3')
  y = norm_layer(
      y, name=f'{norm_layer_name}3', scale_init=nn.initializers.zeros)
  y = nn.relu(residual + y)
  return y
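# Hedged shape sketch for the bottleneck block above (sizes are illustrative
# assumptions, not taken from the source): with 'SAME' padding a strided
# convolution keeps ceil(in / stride) spatial positions, and the block widens
# channels to filters * 4, e.g. an input of (8, 56, 56, 64) with filters=64
# and strides=(2, 2) yields an output of (8, 28, 28, 256).
import math

def _same_padding_out_size(size, stride):
  return math.ceil(size / stride)

assert _same_padding_out_size(56, 2) == 28
assert _same_padding_out_size(7, 2) == 4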
def apply(self,
          x,
          act,
          normalize,
          up=False,
          down=False,
          temb=None,
          out_ch=None,
          dropout=0.1,
          fir=False,
          fir_kernel=[1, 3, 3, 1],
          train=True,
          skip_rescale=True,
          init_scale=0.):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))

  if up:
    if fir:
      h = up_or_down_sampling.upsample_2d(h, fir_kernel, factor=2)
      x = up_or_down_sampling.upsample_2d(x, fir_kernel, factor=2)
    else:
      h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
      x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
  elif down:
    if fir:
      h = up_or_down_sampling.downsample_2d(h, fir_kernel, factor=2)
      x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
    else:
      h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
      x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding.
  if temb is not None:
    h += nn.Dense(
        act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
  h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=init_scale)
  if C != out_ch or up or down:
    x = conv1x1(x, out_ch)

  if not skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
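# A minimal sketch (an assumption about behaviour, not the library code) of
# what the non-FIR branches above compute for factor=2: nearest-neighbour
# upsampling and 2x2 average-pool downsampling on NHWC tensors.
import jax.numpy as jnp

def _naive_upsample_2d_sketch(x, factor=2):
  return jnp.repeat(jnp.repeat(x, factor, axis=1), factor, axis=2)

def _naive_downsample_2d_sketch(x, factor=2):
  b, h, w, c = x.shape
  x = jnp.reshape(x, (b, h // factor, factor, w // factor, factor, c))
  return jnp.mean(x, axis=(2, 4))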
def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
  batch_norm = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5)

  y = batch_norm(x, name='bn1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
  y = batch_norm(y, name='bn2')
  y = jax.nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

  # Project the shortcut when the channel count or stride changes.
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
  return x + y
def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          causal=True,
          cache=None,
          positional_encoding_module=AddLearnedPositionalEncodings,
          self_attention_module=nn.SelfAttention,
          attention_fn=None,
          pad_token=None,
          output_head='logits'):
  """Applies Transformer model on the inputs.

  Args:
    inputs: An array of shape (batch_size, length) or (batch_size, length,
      vocab_size) with the input sequences. When 2-dimensional, the array
      contains sequences of int tokens. Otherwise, the array contains
      next-token distributions over tokens (e.g. one-hot representations).
    vocab_size: An int with the size of the vocabulary.
    emb_dim: An int with the token embedding dimension.
    num_heads: An int with the number of attention heads.
    num_layers: An int with the number of transformer encoder layers.
    qkv_dim: An int with the dimension of the query/key/value vectors.
    mlp_dim: An int with the inner dimension of the feed-forward network which
      follows the attention block.
    max_len: An int with the maximum training sequence length.
    train: A bool denoting whether we are currently training.
    dropout_rate: A float with the dropout rate.
    attention_dropout_rate: A float with a dropout rate for attention weights.
    causal: Whether to apply causal masking.
    cache: Cache for decoding.
    positional_encoding_module: A module used for adding positional encodings.
    self_attention_module: Self attention module.
    attention_fn: Method to use in place of dot product attention.
    pad_token: Token to ignore in attention.
    output_head: String or iterable over strings containing the model's output
      head(s) to return.

  Returns:
    Output of a transformer decoder. If output_head is a string, we return a
    single output head output; if output_head is an iterable, we return a dict
    with (output head name, output head output) key-value pairs.
  """
  if inputs.ndim != 2 and inputs.ndim != 3:
    raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

  if inputs.ndim == 3:
    padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
  elif pad_token is None:
    padding_mask = jnp.ones_like(inputs)
  else:
    # Mask out padding tokens.
    padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
  padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

  heads = dict()
  x = inputs
  if inputs.ndim == 2:
    x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

  if positional_encoding_module == AddLearnedPositionalEncodings:
    x = positional_encoding_module(
        x,
        max_len=max_len,
        cache=cache,
        posemb_init=sinusoidal_init(max_len=max_len))
  else:
    x = positional_encoding_module(x, max_len=max_len)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  heads['input_emb'] = x

  for i in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=causal,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        self_attention_module=self_attention_module,
        deterministic=not train,
        attention_fn=attention_fn,
        cache=cache,
    )
    heads['layer_%s' % i] = x

  x = nn.LayerNorm(x)
  heads['output_emb'] = x * padding_mask  # Zero out PAD positions.

  if 'logits' in output_head:
    logits = nn.Dense(
        x,
        vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    heads['logits'] = logits

  if 'regression' in output_head:
    regression = nn.Dense(
        x,
        1,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    regression = jnp.squeeze(regression, axis=-1)
    heads['regression'] = regression

  if isinstance(output_head, (tuple, list)):
    return {head: heads[head] for head in output_head}
  return heads[output_head]
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          self_attention_module=nn.SelfAttention,
          attention_fn=None,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    self_attention_module: Self attention module.
    attention_fn: dot product function to use inside attention.
    cache: Cache for decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  if attention_fn is not None:
    self_attention_module = self_attention_module.partial(
        attention_fn=attention_fn)
  x = self_attention_module(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  return x + y
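# Structurally, the block above is the standard pre-LayerNorm residual
# pattern; a sketch with placeholder callables (`attn`, `mlp`, `ln`, `drop`
# are hypothetical stand-ins for the modules used above):
def _pre_ln_block_sketch(inputs, attn, mlp, ln, drop):
  x = inputs + drop(attn(ln(inputs)))  # attention sub-block with residual
  return x + mlp(ln(x))                # feed-forward sub-block with residual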
def apply(self,
          x,
          *,
          self_attention_module,
          dim_intermediate,
          is_training,
          dropout_rate=0.1,
          use_pre_layernorm=False,
          layernorm_epsilon=1e-6,
          with_aux_outputs=True):
  """Compute self-attention with a feed-forward network on top.

  Args:
    x: Input representations.
    self_attention_module: Self-attention layer.
    dim_intermediate: Size of the intermediate layer of the feed forward.
    is_training: Whether to enable dropout.
    dropout_rate: Dropout probability.
    use_pre_layernorm: Use pre layer norm from
      https://arxiv.org/abs/2002.04745.
    layernorm_epsilon: Epsilon parameter for all the layer norms.
    with_aux_outputs: Whether the self_attention_module has an aux output.

  Returns:
    New representations in a jnp.array of same shape as `x`.
  """
  dim_hidden = x.shape[-1]
  use_pre_ln = use_pre_layernorm
  use_post_ln = not use_pre_ln

  def apply_ln_if(pred, x, name):
    if pred:
      return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
    else:
      return x

  # Attention.
  x = apply_ln_if(use_pre_ln, x, "ln_pre_att")
  x_att = self_attention_module(x)
  if with_aux_outputs:
    x_att, output_aux = x_att

  # Dropout, residual add and (post) layer norm.
  x_att = nn.dropout(x_att, dropout_rate, deterministic=not is_training)
  x = x + x_att
  x = apply_ln_if(use_post_ln, x, "ln_post_att")

  # Feed forward.
  x_ffn = x
  x_ffn = apply_ln_if(use_pre_ln, x, "ln_pre_ffn")
  x_ffn = nn.Dense(x_ffn, dim_intermediate, name="ff_1")
  x_ffn = jax.nn.relu(x_ffn)
  x_ffn = nn.Dense(x_ffn, dim_hidden, name="ff_2")

  # Dropout, residual add and (post) layer norm.
  x_ffn = nn.dropout(x_ffn, dropout_rate, deterministic=not is_training)
  x = x + x_ffn
  x = apply_ln_if(use_post_ln, x, "ln_post_ffn")

  if with_aux_outputs:
    output = x, output_aux
  else:
    output = x
  return output
def apply(self, x, config, num_classes, train=True):
  """Creates a model definition."""
  b, c = x.shape[0], x.shape[3]
  k = config.k
  sigma = config.ptopk_sigma
  num_samples = config.ptopk_num_samples

  sigma *= self.state(
      "sigma_mutiplier", shape=(), initializer=nn.initializers.ones).value

  stats = {"x": x, "sigma": sigma}

  feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

  rpn_feature = feature_extractor(x)
  rpn_scores, rpn_stats = ProposalNet(
      jax.lax.stop_gradient(rpn_feature),
      communication=Communication(config.communication),
      train=train)
  stats.update(rpn_stats)

  # rpn_scores are a list of score images. We keep track of the structure
  # because it is used in the aggregation step later on.
  rpn_scores_shapes = [s.shape for s in rpn_scores]
  rpn_scores_flat = jnp.concatenate(
      [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
  top_k_indicators = sample_patches.select_patches_perturbed_topk(
      rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
  top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])

  offset = 0
  weights = []
  for sh in rpn_scores_shapes:
    cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
    cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
    weights.append(cur)
    offset += sh[1] * sh[2]
  chex.assert_equal(offset, top_k_indicators.shape[-1])

  part_imgs = weighted_anchor_aggregator(x, weights)
  chex.assert_shape(part_imgs, (b * k, 224, 224, c))
  stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

  part_features = feature_extractor(part_imgs)
  part_features = jnp.mean(part_features, axis=[1, 2])  # GAP the spatial dims.

  part_features = nn.dropout(  # Features from the parts.
      jnp.reshape(part_features, [b * k, 2048]),
      0.5,
      deterministic=not train,
      rng=nn.make_rng())
  features = nn.dropout(  # Features from the whole image.
      jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
      0.5,
      deterministic=not train,
      rng=nn.make_rng())

  # Mean pool all part features, add it to features and predict logits.
  concat_out = jnp.mean(
      jnp.reshape(part_features, [b, k, 2048]), axis=1) + features
  concat_logits = nn.Dense(concat_out, num_classes)
  raw_logits = nn.Dense(features, num_classes)
  part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

  all_logits = {
      "raw_logits": raw_logits,
      "concat_logits": concat_logits,
      "part_logits": part_logits,
  }
  # Add the entropy of the RPN scores for entropy regularization.
  stats["rpn_scores_entropy"] = jax.scipy.special.entr(
      jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)

  return all_logits, stats
def extract_patches_from_indicators(x,
                                    indicators,
                                    patch_size,
                                    patch_dropout,
                                    grid_shape,
                                    train,
                                    iterative=False):
  """Extracts patches from a batch of images.

  Args:
    x: The batch of images of shape (batch, height, width, channels).
    indicators: The one hot indicators of shape (batch, num_patches, k).
    patch_size: The size of the (squared) patches to extract.
    patch_dropout: Probability to replace a patch by 0 values.
    grid_shape: Pair of height, width of the disposition of the num_patches
      patches.
    train: If the model is being trained. Disable dropout if not.
    iterative: If True, extracts the patches with a for loop rather than
      instantiating the "all patches" tensor and extracting by dot product
      with the indicators. `iterative` is more memory efficient.

  Returns:
    The patches extracted from x with shape
    (batch, k, patch_size, patch_size, channels).
  """
  batch_size, height, width, channels = x.shape
  scores_h, scores_w = grid_shape
  k = indicators.shape[-1]
  indicators = einops.rearrange(
      indicators, "b (h w) k -> b k h w", h=scores_h, w=scores_w)

  scale_height = height // scores_h
  scale_width = width // scores_w
  padded_height = scale_height * scores_h + patch_size - 1
  padded_width = scale_width * scores_w + patch_size - 1
  top_pad = (patch_size - scale_height) // 2
  left_pad = (patch_size - scale_width) // 2
  bottom_pad = padded_height - top_pad - height
  right_pad = padded_width - left_pad - width
  # TODO(jbcdnr): assert padding is positive.
  padded_x = jnp.pad(
      x, [(0, 0), (top_pad, bottom_pad), (left_pad, right_pad), (0, 0)])

  # Extract the patches. Iterative fits better in memory as it does not
  # instantiate the "all patches" tensor but iterates over them to compute the
  # weighted sum with the indicator variables from topk.
  if not iterative:
    assert patch_dropout == 0., "Patch dropout not implemented."
    patches = utils.extract_images_patches(
        padded_x,
        window_size=(patch_size, patch_size),
        stride=(scale_height, scale_width))

    shape = (batch_size, scores_h, scores_w, patch_size, patch_size, channels)
    chex.assert_shape(patches, shape)

    patches = jnp.einsum("b k h w, b h w i j c -> b k i j c", indicators,
                         patches)
  else:
    mask = jnp.ones((batch_size, scores_h, scores_w))
    mask = nn.dropout(mask, patch_dropout, deterministic=not train)

    def accumulate_patches(acc, index_i_j):
      i, j = index_i_j
      patch = jax.lax.dynamic_slice(
          padded_x, (0, i * scale_height, j * scale_width, 0),
          (batch_size, patch_size, patch_size, channels))
      weights = indicators[:, :, i, j]
      is_masked = mask[:, i, j]
      weighted_patch = jnp.einsum("b, bk, bijc -> bkijc", is_masked, weights,
                                  patch)
      chex.assert_equal_shape([acc, weighted_patch])
      acc += weighted_patch
      return acc, None

    indices = jnp.stack(
        jnp.meshgrid(jnp.arange(scores_h), jnp.arange(scores_w), indexing="ij"),
        axis=-1)
    indices = indices.reshape((-1, 2))
    init_patches = jnp.zeros(
        (batch_size, k, patch_size, patch_size, channels))
    patches, _ = jax.lax.scan(accumulate_patches, init_patches, indices)

  return patches
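# Worked example (with hypothetical sizes) of the padding arithmetic above:
# height=224 with a 14x14 score grid and patch_size=32 gives scale_height=16,
# so every grid cell can be expanded into a full 32x32 patch inside the
# padded image.
height, scores_h, patch_size = 224, 14, 32
scale_height = height // scores_h                         # 16
padded_height = scale_height * scores_h + patch_size - 1  # 255
top_pad = (patch_size - scale_height) // 2                # 8
bottom_pad = padded_height - top_pad - height             # 23
assert top_pad >= 0 and bottom_pad >= 0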
def apply(self,
          inputs,
          num_outputs,
          num_filters=64,
          num_layers=50,
          dropout_rate=0.0,
          input_dropout_rate=0.0,
          train=True,
          dtype=jnp.float32,
          head_bias_init=jnp.zeros,
          return_activations=False,
          input_layer_key='input',
          has_discriminator=False,
          discriminator=False):
  """Apply a ResNet network on the input.

  Args:
    inputs: jnp array; Inputs.
    num_outputs: int; Number of output units.
    num_filters: int; Determines base number of filters. Number of filters in
      block i is num_filters * 2 ** i.
    num_layers: int; Number of layers (should be one of the predefined ones).
    dropout_rate: float; Rate of dropping out the output of different hidden
      layers.
    input_dropout_rate: float; Rate of dropping out the input units.
    train: bool; Whether the model is in training mode.
    dtype: jnp type; Type of the outputs.
    head_bias_init: fn(rng_key, shape) --> jnp array; Initializer for the head
      bias parameters.
    return_activations: bool; If True, hidden activations are also returned.
    input_layer_key: str; Determines where to plug in the input (this is to
      enable providing inputs to slices of the model). If `input_layer_key` is
      `layer_i`, we assume the inputs are the activations of `layer_i` and pass
      them to `layer_{i+1}`.
    has_discriminator: bool; Whether the model should have a discriminator
      layer.
    discriminator: bool; Whether we should return discriminator logits.

  Returns:
    Unnormalized logits with shape `[bs, num_outputs]`;
    if return_activations: logits, a dict of hidden activations, and the key to
    the representation(s) which will be used as ``the representation'', e.g.,
    for computing losses.
  """
  if num_layers not in ResNet._block_size_options:
    raise ValueError('Please provide a valid number of layers')

  block_sizes = ResNet._block_size_options[num_layers]

  layer_activations = collections.OrderedDict()
  input_is_set = False
  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True

  if input_is_set:
    # Input dropout.
    x = nn.dropout(x, input_dropout_rate, deterministic=not train)
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # First block.
    x = nn.Conv(
        x,
        num_filters, (7, 7), (2, 2),
        padding=[(3, 3), (3, 3)],
        bias=False,
        dtype=dtype,
        name='init_conv')
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # Residual blocks.
  for i, block_size in enumerate(block_sizes):
    # Stage i (each stage contains blocks of the same size).
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      current_rep_key = f'block_{i + 1}+{j}'
      if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
      elif input_is_set:
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            strides=strides,
            dropout_rate=dropout_rate,
            train=train,
            dtype=dtype,
            name=f'block_{i + 1}_{j}')
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

  current_rep_key = 'avg_pool'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # Global average pool.
    x = jnp.mean(x, axis=(1, 2))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module.
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_layer_key == current_rep_key:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warn('Input was never used')
  elif input_is_set:
    x = nn.Dense(
        x, num_outputs, dtype=dtype, bias_init=head_bias_init, name='head')

  # Make sure that the output is float32, even if our previous computations
  # are in float16, or other types.
  x = jnp.asarray(x, jnp.float32)

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)

  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')

  return outputs
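# `ResNet._block_size_options` is referenced above but not shown in this
# excerpt. For reference, the canonical ResNet family uses the following
# block counts per stage (an assumption about that lookup table, not a copy
# of it):
_BLOCK_SIZE_OPTIONS_SKETCH = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}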