def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0,
          norm_layer='group_norm', train=True):
  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'

  y = norm_layer(x, name=f'{norm_layer_name}1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
  y = norm_layer(y, name=f'{norm_layer_name}2')
  y = jax.nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

  # Apply an up projection in case of channel mismatch
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
  return x + y
def apply(self, x, num_filters=64, block_sizes=(3, 4, 6, 3), train=True,
          block=BottleneckBlock, small_inputs=False):
  if small_inputs:
    x = nn.Conv(x, num_filters, kernel_size=(3, 3), strides=(1, 1),
                bias=False, name="init_conv")
  else:
    x = nn.Conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2),
                bias=False, name="init_conv")
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name="init_bn")
  if not small_inputs:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block(x, num_filters * 2**i, strides=strides, train=train)
  return x
def apply(self, x, filters, strides=(1, 1), groups=1, base_width=64,
          train=True):
  needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
  width = int(filters * (base_width / 64.)) * groups
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)

  y = nn.Conv(x, width, (1, 1), (1, 1), bias=False, name='conv1')
  y = batch_norm(y, name='bn1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, width, (3, 3), strides, bias=False,
              feature_group_count=groups, name='conv2')
  y = batch_norm(y, name='bn2')
  y = jax.nn.relu(y)
  y = nn.Conv(y, filters * 4, (1, 1), (1, 1), bias=False, name='conv3')
  y = batch_norm(y, name='bn3', scale_init=initializers.zeros)

  if needs_projection:
    x = nn.Conv(x, filters * 4, (1, 1), strides, bias=False,
                name='proj_conv')
    x = batch_norm(x, name='proj_bn')
  return jax.nn.relu(x + y)
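# Quick sanity check (illustrative only, `_resnext_width` is not part of the model):
# the grouped-bottleneck width follows the ResNeXt formula
# `width = int(filters * base_width / 64) * groups`. With the defaults above
# (groups=1, base_width=64) it reduces to a plain bottleneck of width `filters`;
# assuming the standard ResNeXt-50 (32x4d) setting it widens to 128 for the first
# stage.
def _resnext_width(filters, base_width=64, groups=1):
  return int(filters * (base_width / 64.)) * groups

assert _resnext_width(64) == 64
assert _resnext_width(64, base_width=4, groups=32) == 128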
def apply(self, x, inner_channels=8):
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  return x
def apply(self, x, num_cycles=3, inner_channels=8):
  x = Relaxation(x, inner_channels=inner_channels)
  if num_cycles > 0:
    x1 = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                 strides=(2, 2), padding='VALID')
    x1 = Cycle(x1, num_cycles=num_cycles - 1, inner_channels=inner_channels)
    x1 = nn.Conv(x1, features=1, kernel_size=(3, 3), bias=False,
                 input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
    x = x + x1
  x = Relaxation(x, inner_channels=inner_channels)
  return x
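# Shape sketch (illustrative, not part of the model): the stride-2 VALID conv and the
# input-dilated conv above are inverse resamplings for odd grid sizes H = 2m + 1
# (the usual multigrid convention), since
#   down: (H - 3) // 2 + 1 = m     and     up: (2 * m - 1) + 2 * 2 - 3 + 1 = H,
# so the coarse-grid correction `x1` can be added back to `x` without cropping.
# Quick check with plain JAX, assuming a 65x65 single-channel grid:
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((1, 65, 65, 1))
k_down = jnp.zeros((3, 3, 1, 8))   # HWIO kernel layout
k_up = jnp.zeros((3, 3, 8, 1))
coarse = lax.conv_general_dilated(x, k_down, window_strides=(2, 2),
                                  padding='VALID',
                                  dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
fine = lax.conv_general_dilated(coarse, k_up, window_strides=(1, 1),
                                padding=[(2, 2), (2, 2)], lhs_dilation=(2, 2),
                                dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
assert coarse.shape == (1, 32, 32, 8) and fine.shape == x.shape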
def apply(self, x, use_squeeze_excite=False):
  x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
  scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
  return scores
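# Hedged sketch of the squeeze-and-excite gating assumed by `SqueezeExciteLayer`
# (it mirrors the inline version used for Communication.SQUEEZE_EXCITE_D later in
# this file): global-average-pool the feature map, pass the channel vector through a
# small bottleneck MLP, and rescale each channel by a sigmoid gate. Written with
# explicit weights in plain JAX for illustration; the real layer learns w1/w2.
import jax
import jax.numpy as jnp

def squeeze_excite_sketch(x, w1, w2):
  # x: (batch, height, width, channels); w1: (c, c // r); w2: (c // r, c).
  y = x.mean(axis=(1, 2))            # squeeze: (batch, channels)
  y = jax.nn.relu(y @ w1)            # excite: channel bottleneck ...
  y = jax.nn.sigmoid(y @ w2)         # ... and per-channel gate in (0, 1)
  return x * y[:, None, None, :]     # rescale the feature map channel-wise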
def apply(self, x, num_cycles=3, inner_channels=8):
  x = NonLinearRelaxation(x, inner_channels=inner_channels)
  if num_cycles > 0:
    x1 = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                 strides=(2, 2), padding='VALID')
    x1 = NonLinearCycle(x1, num_cycles=num_cycles - 1,
                        inner_channels=inner_channels)
    x1 = nn.Conv(x1, features=1, kernel_size=(3, 3), bias=False,
                 input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
    x = x + x1
  x = NonLinearRelaxation(x, inner_channels=inner_channels)
  return x
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten.
  x = nn.Dense(x, 128, name="fc")
  return x
def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          dropout_rate=0.0, train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 16 * channel_multiplier,
                                train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 32 * channel_multiplier,
                                (2, 2), train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 64 * channel_multiplier,
                                (2, 2), train=train)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)
  y = batch_norm(x, name='bn1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
  y = batch_norm(y, name='bn2')
  y = jax.nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

  # Apply an up projection in case of channel mismatch
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
  return x + y
def apply(self, x, inner_channels=8):
  x = NonLinearCycle(x, 4, inner_channels)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), bias=False, padding='SAME')
  x = nn.relu(x)
  return x
def apply(self, x, channels, strides=(1, 1), train=True):
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)
  a = b = residual = x

  a = jax.nn.relu(a)
  a = nn.Conv(a, channels, (3, 3), strides, padding='SAME', name='conv_a_1')
  a = batch_norm(a, name='bn_a_1')
  a = jax.nn.relu(a)
  a = nn.Conv(a, channels, (3, 3), padding='SAME', name='conv_a_2')
  a = batch_norm(a, name='bn_a_2')

  b = jax.nn.relu(b)
  b = nn.Conv(b, channels, (3, 3), strides, padding='SAME', name='conv_b_1')
  b = batch_norm(b, name='bn_b_1')
  b = jax.nn.relu(b)
  b = nn.Conv(b, channels, (3, 3), padding='SAME', name='conv_b_2')
  b = batch_norm(b, name='bn_b_2')

  if train and not self.is_initializing():
    ab = shake.shake_shake_train(a, b)
  else:
    ab = shake.shake_shake_eval(a, b)

  # Apply an up projection in case of channel mismatch
  if (residual.shape[-1] != channels) or strides != (1, 1):
    residual = nn.Conv(residual, channels, (3, 3), strides, padding='SAME',
                       name='conv_residual')
    residual = batch_norm(residual, name='bn_residual')
  return residual + ab
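# Hedged sketch of the shake-shake combination used above (not the repo's `shake`
# module): at train time the two branches are mixed with a random per-example weight
# alpha in the forward pass, while gradients flow through an independently sampled
# weight beta (the stop_gradient trick below realizes exactly that); at eval time the
# branches are simply averaged.
import jax
import jax.numpy as jnp

def shake_shake_train_sketch(a, b, rng):
  alpha_rng, beta_rng = jax.random.split(rng)
  shape = (a.shape[0],) + (1,) * (a.ndim - 1)
  alpha = jax.random.uniform(alpha_rng, shape)   # forward mixing weight
  beta = jax.random.uniform(beta_rng, shape)     # backward mixing weight
  forward = alpha * a + (1.0 - alpha) * b
  backward = beta * a + (1.0 - beta) * b
  # Value of `forward`, gradient of `backward`.
  return backward + jax.lax.stop_gradient(forward - backward)

def shake_shake_eval_sketch(a, b):
  return 0.5 * (a + b)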
def apply(self, inputs, output_dim, kernel_size=3, biases=True):
  output = nn.Conv(inputs, features=output_dim,
                   kernel_size=(kernel_size, kernel_size), strides=(1, 1),
                   padding='SAME', bias=biases)
  output = sum([
      output[:, ::2, ::2, :],
      output[:, 1::2, ::2, :],
      output[:, ::2, 1::2, :],
      output[:, 1::2, 1::2, :],
  ]) / 4.
  return output
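# Illustrative check (not part of the model): summing the four stride-2 phases and
# dividing by 4 is exactly a 2x2 mean pool with stride 2 (for even spatial sizes),
# so the layer above is "conv, then mean pool" in one op.
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3))
phase_avg = (x[:, ::2, ::2, :] + x[:, 1::2, ::2, :] +
             x[:, ::2, 1::2, :] + x[:, 1::2, 1::2, :]) / 4.
window_avg = x.reshape(2, 4, 2, 4, 2, 3).mean(axis=(2, 4))
assert jnp.allclose(phase_avg, window_avg, atol=1e-6)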
def apply(self, x, channels, strides, prob, alpha_min, alpha_max, beta_min,
          beta_max, train=True):
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)

  y = batch_norm(x, name='bn_1_pre')
  y = nn.Conv(y, channels, (1, 1), padding='SAME', name='1x1_conv_contract')
  y = batch_norm(y, name='bn_1_post')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='3x3')
  y = batch_norm(y, name='bn_2')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels * 4, (1, 1), padding='SAME', name='1x1_conv_expand')
  y = batch_norm(y, name='bn_3')

  if train:
    y = shake.shake_drop_train(y, prob, alpha_min, alpha_max,
                               beta_min, beta_max)
  else:
    y = shake.shake_drop_eval(y, prob, alpha_min, alpha_max)

  x = shortcut(x, channels * 4, strides)
  return x + y
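# `shortcut` is not defined in this snippet. A minimal sketch of the usual
# parameter-free PyramidNet shortcut (assuming the same pre-Linen `nn` API used by
# the surrounding code): downsample spatially with average pooling when the block is
# strided, then zero-pad the channel dimension up to `channels_out`.
def shortcut_sketch(x, channels_out, strides):
  if strides != (1, 1):
    x = nn.avg_pool(x, strides, strides)
  channels_in = x.shape[-1]
  if channels_out > channels_in:
    x = jnp.pad(x, [(0, 0), (0, 0), (0, 0), (0, channels_out - channels_in)])
  return x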
def ddpm_conv1x1(x, out_planes, stride=1, bias=True, dilation=1,
                 init_scale=1.):
  """1x1 convolution with DDPM initialization."""
  bias_init = jnn.initializers.zeros
  output = nn.Conv(x, out_planes, kernel_size=(1, 1),
                   strides=(stride, stride), padding='SAME', bias=bias,
                   kernel_dilation=(dilation, dilation),
                   kernel_init=default_init(init_scale), bias_init=bias_init)
  return output
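# `default_init` is not shown in this snippet. A hedged sketch, assuming the
# DDPM-style initializer used in score-SDE codebases: variance scaling over the
# fan-average with a uniform distribution, where a scale of exactly 0 is replaced by
# a tiny positive value so zero-init layers start near (but not exactly at) zero.
def default_init_sketch(scale=1.):
  scale = 1e-10 if scale == 0 else scale
  return jnn.initializers.variance_scaling(scale, 'fan_avg', 'uniform')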
def ncsn_conv3x3(x, out_planes, stride=1, bias=True, dilation=1,
                 init_scale=1.):
  """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
  init_scale = 1e-10 if init_scale == 0 else init_scale
  kernel_init = jnn.initializers.variance_scaling(1 / 3 * init_scale,
                                                  'fan_in', 'uniform')
  kernel_shape = (3, 3) + (x.shape[-1], out_planes)
  bias_init = lambda key, shape: kernel_init(key, kernel_shape)[0, 0, 0, :]
  output = nn.Conv(x, out_planes, kernel_size=(3, 3),
                   strides=(stride, stride), padding='SAME', bias=bias,
                   kernel_dilation=(dilation, dilation),
                   kernel_init=kernel_init, bias_init=bias_init)
  return output
def apply(self, x, num_outputs, train=True):
  x = nn.Conv(x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False,
              name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(self.BLOCK_SIZES):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(x, self.NUM_FILTERS * 2**i, strides=strides,
                          groups=self.GROUPS,
                          base_width=self.WIDTH_PER_GROUP, train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, name='clf')
  return x
def apply(self, x, communication=Communication.NONE, train=True):
  """Forward pass."""
  batch_size = x.shape[0]

  if communication is Communication.SQUEEZE_EXCITE_X:
    x = sample_patches.SqueezeExciteLayer(x)
  # end if squeeze excite x

  d1 = nn.relu(nn.Conv(x, 128, kernel_size=(3, 3), strides=(1, 1),
                       bias=True, name="down1"))
  d2 = nn.relu(nn.Conv(d1, 128, kernel_size=(3, 3), strides=(2, 2),
                       bias=True, name="down2"))
  d3 = nn.relu(nn.Conv(d2, 128, kernel_size=(3, 3), strides=(2, 2),
                       bias=True, name="down3"))

  if communication is Communication.SQUEEZE_EXCITE_D:
    d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
    d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
    d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

    nd1 = d1_flatten.shape[1]
    nd2 = d2_flatten.shape[1]

    d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten], axis=1)

    num_channels = d_together.shape[-1]
    y = d_together.mean(axis=1)
    y = nn.Dense(y, features=num_channels // 4, bias=False)
    y = nn.relu(y)
    y = nn.Dense(y, features=num_channels, bias=False)
    y = nn.sigmoid(y)

    d_together = d_together * y[:, None, :]

    # split and reshape
    d1 = d_together[:, :nd1].reshape(d1.shape)
    d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
    d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

  elif communication is Communication.TRANSFORMER:
    d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
    d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
    d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

    nd1 = d1_flatten.shape[1]
    nd2 = d2_flatten.shape[1]

    d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten], axis=1)

    positional_encodings = self.param(
        "scale_ratio_position_encodings",
        shape=(1,) + d_together.shape[1:],
        initializer=jax.nn.initializers.normal(1. / d_together.shape[-1]))
    d_together = transformer.Transformer(d_together + positional_encodings,
                                         num_layers=2,
                                         num_heads=8,
                                         is_training=train)

    # split and reshape
    d1 = d_together[:, :nd1].reshape(d1.shape)
    d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
    d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

  t1 = nn.Conv(d1, 6, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy1")
  t2 = nn.Conv(d2, 6, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy2")
  t3 = nn.Conv(d3, 9, kernel_size=(1, 1), strides=(1, 1), bias=True,
               name="tidy3")

  raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                jnp.split(t3, 9, axis=-1))

  # The following is for normalization.
  t = jnp.concatenate((jnp.reshape(t1, [batch_size, -1]),
                       jnp.reshape(t2, [batch_size, -1]),
                       jnp.reshape(t3, [batch_size, -1])), axis=1)
  t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
  t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
  normalized_scores = zeroone(raw_scores, t_min, t_max)

  stats = {
      "scores": normalized_scores,
      "raw_scores": t,
  }
  # removes the split dimension. scores are now b x h' x w' shaped
  normalized_scores = [s.squeeze(-1) for s in normalized_scores]

  return normalized_scores, stats
def apply(
    self,
    inputs,
    blocks_per_group,
    channel_multiplier,
    num_outputs,
    kernel_size=(3, 3),
    strides=None,
    maxpool=False,
    dropout_rate=0.0,
    dtype=jnp.float32,
    norm_layer='group_norm',
    train=True,
    return_activations=False,
    input_layer_key='input',
    has_discriminator=False,
    discriminator=False,
):
  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'

  layer_activations = collections.OrderedDict()
  input_is_set = False
  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_is_set:
    x = nn.Conv(x, 16, kernel_size=kernel_size, strides=strides,
                padding='SAME', name='init_conv')
    if maxpool:
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l1'
  if input_is_set:
    x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                        dropout_rate=dropout_rate, norm_layer=norm_layer,
                        train=train, name='l1')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l2'
  if input_is_set:
    x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                        dropout_rate=dropout_rate, norm_layer=norm_layer,
                        train=train, name='l2')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l3'
  if input_is_set:
    x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                        dropout_rate=dropout_rate, norm_layer=norm_layer,
                        train=train, name='l3')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l4'
  if input_is_set:
    x = norm_layer(x, name=f'{norm_layer_name}')
    x = jax.nn.relu(x)
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_is_set:
    x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
  else:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warn('Input was never used')

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)
  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')
  return outputs
def apply(self, x, num_outputs, pyramid_alpha=200, pyramid_depth=272,
          train=True):
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915
  num_channels = 16
  # N in https://arxiv.org/abs/1610.02915
  total_blocks = blocks_per_group * 3
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn')

  layer_num = 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels), (1, 1), layer_mask_prob,
                            alpha_min, alpha_max, beta_min, beta_max,
                            train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max,
                            beta_min, beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max,
                            beta_min, beta_max, train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='final_bn')
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
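# `_calc_shakedrop_mask_prob` is not shown in this snippet. A minimal sketch,
# assuming the linear decay rule from the ShakeDrop paper
# (https://arxiv.org/abs/1802.02375): the per-layer survival probability decreases
# linearly with depth from 1 down to (1 - mask_prob), so deeper blocks are perturbed
# more often.
def _calc_shakedrop_mask_prob_sketch(curr_layer, total_layers, mask_prob):
  return 1.0 - (float(curr_layer) / total_layers) * mask_prob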
def apply(self,
          inputs,
          num_outputs,
          num_filters=64,
          num_layers=50,
          dropout_rate=0.0,
          input_dropout_rate=0.0,
          train=True,
          dtype=jnp.float32,
          head_bias_init=jnp.zeros,
          return_activations=False,
          input_layer_key='input',
          has_discriminator=False,
          discriminator=False):
  """Apply a ResNet network on the input.

  Args:
    inputs: jnp array; Inputs.
    num_outputs: int; Number of output units.
    num_filters: int; Determines base number of filters. Number of filters in
      block i is num_filters * 2 ** i.
    num_layers: int; Number of layers (should be one of the predefined ones.)
    dropout_rate: float; Rate of dropping out the output of different hidden
      layers.
    input_dropout_rate: float; Rate of dropping out the input units.
    train: bool; Is train?
    dtype: jnp type; Type of the outputs.
    head_bias_init: fn(rng_key, shape) --> jnp array; Initializer for head
      bias parameters.
    return_activations: bool; If True hidden activations are also returned.
    input_layer_key: str; Determines where to plug in the input (this is to
      enable providing inputs to slices of the model). If `input_layer_key` is
      `layer_i` we assume the inputs are the activations of `layer_i` and pass
      them to `layer_{i+1}`.
    has_discriminator: bool; Whether the model should have a discriminator
      layer.
    discriminator: bool; Whether we should return discriminator logits.

  Returns:
    Unnormalized logits with shape `[bs, num_outputs]`;
    if return_activations:
      Logits, dict of hidden activations and the key to the representation(s)
      which will be used as ``The Representation'', e.g., for computing losses.
  """
  if num_layers not in ResNet._block_size_options:
    raise ValueError('Please provide a valid number of layers')

  block_sizes = ResNet._block_size_options[num_layers]

  layer_activations = collections.OrderedDict()
  input_is_set = False
  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True

  if input_is_set:
    # Input dropout
    x = nn.dropout(x, input_dropout_rate, deterministic=not train)
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # First block
    x = nn.Conv(x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
                bias=False, dtype=dtype, name='init_conv')
    x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                     epsilon=1e-5, dtype=dtype, name='init_bn')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # Residual blocks
  for i, block_size in enumerate(block_sizes):
    # Stage i (each stage contains blocks of the same size).
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      current_rep_key = f'block_{i + 1}+{j}'
      if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
      elif input_is_set:
        x = ResidualBlock(x, num_filters * 2**i, strides=strides,
                          dropout_rate=dropout_rate, train=train,
                          dtype=dtype, name=f'block_{i + 1}_{j}')
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

  current_rep_key = 'avg_pool'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # Global Average Pool
    x = jnp.mean(x, axis=(1, 2))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_layer_key == current_rep_key:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

    logging.warn('Input was never used')
  elif input_is_set:
    x = nn.Dense(x, num_outputs, dtype=dtype, bias_init=head_bias_init,
                 name='head')

  # Make sure that the output is float32, even if our previous computations
  # are in float16, or other types.
  x = jnp.asarray(x, jnp.float32)

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)
  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')

  return outputs
def apply(self, x, inner_channels=8):
  x = nn.Conv(x, features=1, kernel_size=(3, 3), bias=False, strides=(2, 2),
              padding='VALID')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='VALID', strides=(2, 2))
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
              padding='SAME')
  x = nn.Conv(x, features=1, kernel_size=(3, 3), bias=False,
              input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
  return x