def apply(self,
          x,
          num_filters=64,
          block_sizes=(3, 4, 6, 3),
          train=True,
          block=BottleneckBlock,
          small_inputs=False):
  if small_inputs:
    x = nn.Conv(
        x,
        num_filters,
        kernel_size=(3, 3),
        strides=(1, 1),
        bias=False,
        name="init_conv")
  else:
    x = nn.Conv(
        x,
        num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        bias=False,
        name="init_conv")
  x = nn.BatchNorm(
      x, use_running_average=not train, epsilon=1e-5, name="init_bn")
  if not small_inputs:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block(x, num_filters * 2**i, strides=strides, train=train)
  return x
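# Illustrative sketch (not part of the model above): the stage loop gives
# stage i `num_filters * 2**i` channels and downsamples once at the first
# block of every stage after the first. The values below assume the default
# block_sizes=(3, 4, 6, 3) and num_filters=64, which matches the familiar
# ResNet-50 layout when `block` is a bottleneck block.
num_filters, block_sizes = 64, (3, 4, 6, 3)
for i, block_size in enumerate(block_sizes):
  for j in range(block_size):
    strides = (2, 2) if i > 0 and j == 0 else (1, 1)
    print(f'stage {i}, block {j}: filters={num_filters * 2**i}, strides={strides}')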
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          dropout_rate=0.0,
          train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetGroup(
      x,
      blocks_per_group,
      16 * channel_multiplier,
      dropout_rate=dropout_rate,
      train=train)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      32 * channel_multiplier, (2, 2),
      dropout_rate=dropout_rate,
      train=train)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      64 * channel_multiplier, (2, 2),
      dropout_rate=dropout_rate,
      train=train)
  x = nn.BatchNorm(
      x, use_running_average=not train, momentum=0.9, epsilon=1e-5)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
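# Illustrative sketch (assumption: CIFAR-sized 32x32 inputs, the usual setting
# for this WideResNet variant). The second and third WideResnetGroup calls use
# stride (2, 2), halving the spatial resolution twice, 32 -> 16 -> 8, which is
# why the global pooling window above is (8, 8).
spatial = 32
for stride in (1, 2, 2):  # strides of the three WideResnetGroup calls
  spatial //= stride
print(spatial)  # 8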
def apply(self, x, num_outputs, train=True):
  x = nn.Conv(
      x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False, name='init_conv')
  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(self.BLOCK_SIZES):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(
          x,
          self.NUM_FILTERS * 2**i,
          strides=strides,
          groups=self.GROUPS,
          base_width=self.WIDTH_PER_GROUP,
          train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, name='clf')
  return x
def apply(self,
          inputs,
          num_outputs,
          num_filters=64,
          num_layers=50,
          dropout_rate=0.0,
          input_dropout_rate=0.0,
          train=True,
          dtype=jnp.float32,
          head_bias_init=jnp.zeros,
          return_activations=False,
          input_layer_key='input',
          has_discriminator=False,
          discriminator=False):
  """Apply a ResNet network on the input.

  Args:
    inputs: jnp array; Inputs.
    num_outputs: int; Number of output units.
    num_filters: int; Determines base number of filters. Number of filters
      in block i is num_filters * 2 ** i.
    num_layers: int; Number of layers (should be one of the predefined ones).
    dropout_rate: float; Rate of dropping out the output of different
      hidden layers.
    input_dropout_rate: float; Rate of dropping out the input units.
    train: bool; Whether the model is in training mode.
    dtype: jnp type; Type of the outputs.
    head_bias_init: fn(rng_key, shape) --> jnp array; Initializer for head
      bias parameters.
    return_activations: bool; If True, hidden activations are also returned.
    input_layer_key: str; Determines where to plug in the input (this is to
      enable providing inputs to slices of the model). If `input_layer_key`
      is `layer_i`, we assume the inputs are the activations of `layer_i`
      and pass them to `layer_{i+1}`.
    has_discriminator: bool; Whether the model should have a discriminator
      layer.
    discriminator: bool; Whether we should return discriminator logits.

  Returns:
    Unnormalized logits with shape `[bs, num_outputs]`; if
    return_activations, also a dict of hidden activations and the key to
    the representation(s) which will be used as ``The Representation'',
    e.g., for computing losses.
  """
  if num_layers not in ResNet._block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = ResNet._block_size_options[num_layers]

  layer_activations = collections.OrderedDict()
  input_is_set = False

  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True

  if input_is_set:
    # Input dropout
    x = nn.dropout(x, input_dropout_rate, deterministic=not train)
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # First block
    x = nn.Conv(
        x,
        num_filters, (7, 7), (2, 2),
        padding=[(3, 3), (3, 3)],
        bias=False,
        dtype=dtype,
        name='init_conv')
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # Residual blocks
  for i, block_size in enumerate(block_sizes):
    # Stage i (each stage contains blocks of the same size).
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      current_rep_key = f'block_{i + 1}+{j}'
      if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
      elif input_is_set:
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            strides=strides,
            dropout_rate=dropout_rate,
            train=train,
            dtype=dtype,
            name=f'block_{i + 1}_{j}')
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

  current_rep_key = 'avg_pool'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # Global Average Pool
    x = jnp.mean(x, axis=(1, 2))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_layer_key == current_rep_key:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warn('Input was never used')
  elif input_is_set:
    x = nn.Dense(
        x, num_outputs, dtype=dtype, bias_init=head_bias_init, name='head')

  # Make sure that the output is float32, even if our previous computations
  # are in float16, or other types.
  x = jnp.asarray(x, jnp.float32)

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)

  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')

  return outputs
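# Illustrative sketch (hypothetical toy layers, not the model above): the
# `input_layer_key` mechanism lets callers feed activations into the middle
# of the network. Once the key matches, `input_is_set` becomes True and only
# the layers *after* that point are applied, while the activations dict
# records every computed representation and `rep_key` names the last one.
def toy_sliced_forward(inputs, input_layer_key='input'):
  layers = {'input': lambda v: v,
            'layer_1': lambda v: v + 1,
            'layer_2': lambda v: v * 2}
  activations, input_is_set, rep_key = {}, False, None
  for key, fn in layers.items():
    if input_layer_key == key:
      x, input_is_set = inputs, True   # treat inputs as this layer's output
    elif input_is_set:
      x = fn(x)                        # apply layers downstream of the input
    else:
      continue                         # layers upstream of the input are skipped
    activations[key] = x
    rep_key = key
  return x, activations, rep_key

print(toy_sliced_forward(3.0))                              # full forward pass
print(toy_sliced_forward(10.0, input_layer_key='layer_1'))  # start after layer_1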
def apply(self,
          x,
          num_outputs,
          pyramid_alpha=200,
          pyramid_depth=272,
          train=True):
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915
  num_channels = 16  # N in https://arxiv.org/abs/1610.02915
  total_blocks = blocks_per_group * 3
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = nn.BatchNorm(
      x, use_running_average=not train, momentum=0.9, epsilon=1e-5,
      name='init_bn')

  layer_num = 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(
        x,
        int(num_channels), (1, 1),
        layer_mask_prob,
        alpha_min,
        alpha_max,
        beta_min,
        beta_max,
        train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(
        x,
        int(num_channels), ((2, 2) if block_i == 0 else (1, 1)),
        layer_mask_prob,
        alpha_min,
        alpha_max,
        beta_min,
        beta_max,
        train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(
        layer_num, total_blocks, mask_prob)
    x = BottleneckShakeDrop(
        x,
        int(num_channels), ((2, 2) if block_i == 0 else (1, 1)),
        layer_mask_prob,
        alpha_min,
        alpha_max,
        beta_min,
        beta_max,
        train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = nn.BatchNorm(
      x, use_running_average=not train, momentum=0.9, epsilon=1e-5,
      name='final_bn')
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
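# Illustrative sketch (not part of the model above): per Eqn 2 of
# https://arxiv.org/abs/1610.02915, the channel width grows linearly by
# pyramid_alpha / total_blocks at every block. With the defaults
# pyramid_alpha=200 and pyramid_depth=272 there are 90 blocks and the width
# rises from 16 toward roughly 16 + pyramid_alpha = 216 (before the int()
# truncation applied at each block).
pyramid_alpha, pyramid_depth = 200, 272
blocks_per_group = (pyramid_depth - 2) // 9
total_blocks = blocks_per_group * 3
delta_channels = pyramid_alpha / total_blocks
num_channels = 16.0
widths = []
for _ in range(total_blocks):
  num_channels += delta_channels
  widths.append(int(num_channels))
print(total_blocks, widths[0], widths[-1])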