def apply(self,
          x,
          num_filters=64,
          block_sizes=(3, 4, 6, 3),
          train=True,
          block=BottleneckBlock,
          small_inputs=False):
  if small_inputs:
    x = nn.Conv(
        x,
        num_filters,
        kernel_size=(3, 3),
        strides=(1, 1),
        bias=False,
        name="init_conv")
  else:
    x = nn.Conv(
        x,
        num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        bias=False,
        name="init_conv")
  x = nn.BatchNorm(
      x, use_running_average=not train, epsilon=1e-5, name="init_bn")
  if not small_inputs:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block(x, num_filters * 2**i, strides=strides, train=train)
  return x
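# Hedged sketch of the `BottleneckBlock` default used above. The real block
# is defined elsewhere in the repo; this shows the standard ResNet-v1
# bottleneck (1x1 reduce, 3x3, 1x1 expand, projection shortcut on shape
# change) against the same deprecated flax.nn interface. Layer names and
# init details here are assumptions, not the source's.
class BottleneckBlockSketch(nn.Module):

  def apply(self, x, filters, strides=(1, 1), train=True):
    batch_norm = nn.BatchNorm.partial(
        use_running_average=not train, epsilon=1e-5)
    needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
    residual = x
    if needs_projection:
      # Projection shortcut to match the changed channel count / stride.
      residual = nn.Conv(
          residual, filters * 4, (1, 1), strides, bias=False, name='proj_conv')
      residual = batch_norm(residual, name='proj_bn')
    y = nn.Conv(x, filters, (1, 1), bias=False, name='conv1')
    y = nn.relu(batch_norm(y, name='bn1'))
    y = nn.Conv(y, filters, (3, 3), strides, bias=False, name='conv2')
    y = nn.relu(batch_norm(y, name='bn2'))
    y = nn.Conv(y, filters * 4, (1, 1), bias=False, name='conv3')
    # Zero-init the last BN scale so the block starts as an identity map.
    y = batch_norm(y, name='bn3', scale_init=nn.initializers.zeros)
    return nn.relu(residual + y)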
def apply(self, x, features, n_stages, act=nn.relu):
  x = act(x)
  path = x
  for _ in range(n_stages):
    path = nn.max_pool(
        path, window_shape=(5, 5), strides=(1, 1), padding='SAME')
    path = ncsn_conv3x3(path, features, stride=1, bias=False)
    x = path + x
  return x
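# Hedged sketch of the `ncsn_conv3x3` helper called above: in the NCSN
# codebases this is a plain 3x3 SAME-padded convolution with optional bias
# and a scaled initializer. Only the 3x3 conv is implied by the call site;
# the dilation and init arguments here are assumptions.
def ncsn_conv3x3_sketch(x, features, stride=1, bias=True, dilation=1,
                        init_scale=1.):
  kernel_init = nn.initializers.variance_scaling(
      1e-10 if init_scale == 0 else init_scale, 'fan_avg', 'uniform')
  return nn.Conv(
      x,
      features,
      kernel_size=(3, 3),
      strides=(stride, stride),
      padding='SAME',
      kernel_dilation=(dilation, dilation),
      bias=bias,
      kernel_init=kernel_init)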
def apply(self, x, use_squeeze_excite=False):
  x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
  scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
  return scores
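# Hedged sketch of `SqueezeExciteLayer` as called above: a standard
# squeeze-and-excitation gate (global average pool, bottleneck MLP,
# sigmoid, channel-wise rescale). The reduction factor of 4 is an
# assumption; the source's value may differ.
class SqueezeExciteLayerSketch(nn.Module):

  def apply(self, x, reduction=4):
    pooled = jnp.mean(x, axis=(1, 2), keepdims=True)  # squeeze: NHWC -> N11C
    squeezed = nn.relu(nn.Dense(pooled, x.shape[-1] // reduction))
    excited = nn.sigmoid(nn.Dense(squeezed, x.shape[-1]))
    return x * excited  # channel-wise gating, broadcast over H and W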
def apply(self, x, num_outputs, train=True):
  x = nn.Conv(
      x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False, name='init_conv')
  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(self.BLOCK_SIZES):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(
          x,
          self.NUM_FILTERS * 2**i,
          strides=strides,
          groups=self.GROUPS,
          base_width=self.WIDTH_PER_GROUP,
          train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, name='clf')
  return x
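# Minimal usage sketch for the deprecated flax.nn interface (flax < 0.4)
# that these `apply` methods are written against. `model_def_class` stands
# in for whichever nn.Module subclass owns the `apply` above; the input
# shape and flag values are illustrative assumptions, not the source's.
import jax
import jax.numpy as jnp
from flax import nn  # deprecated flax.nn, not flax.linen

def init_model_sketch(model_def_class, rng, num_outputs=1000):
  model_def = model_def_class.partial(num_outputs=num_outputs, train=True)
  dummy = jnp.zeros((1, 224, 224, 3), jnp.float32)
  with nn.stateful() as init_state:  # collects BatchNorm running averages
    _, params = model_def.init(rng, dummy)
  return nn.Model(model_def, params), init_state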
def apply(self,
          x,
          *,
          train,
          num_classes,
          block_class=BottleneckResNetImageNetBlock,
          stage_sizes,
          width_factor=1,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          bias_scale=0.0,
          weight_norm='none',
          compensate_padding=True,
          softplus_scale=None,
          no_head=False,
          zero_inits=True):
  """Construct ResNet V1 with `num_classes` outputs."""
  self._stage_sizes = stage_sizes
  if std_penalty_mult > 0:
    raise NotImplementedError(
        'std_penalty_mult not supported for ResNetImageNet')

  width = 64 * width_factor

  # Root block.
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  x = conv(x, width, kernel_size=(7, 7), strides=(2, 2), name='init_conv')
  x = norm(x, name='init_bn')
  if compensate_padding:
    # NOTE: this leads to lower performance.
    x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
  else:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  # Stages.
  for i, stage_size in enumerate(stage_sizes):
    x = ResNetStage(
        x,
        stage_size,
        filters=width * 2**i,
        block_class=block_class,
        first_block_strides=(1, 1) if i == 0 else (2, 2),
        train=train,
        name=f'stage{i + 1}',
        conv=conv,
        norm=norm,
        activation_f=activation_f,
        use_residual=use_residual,
        zero_inits=zero_inits,
    )

  if not no_head:
    # Head.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        x,
        num_classes,
        kernel_init=nn.initializers.zeros
        if zero_inits else nn.initializers.lecun_normal(),
        name='head')
  return x, 0, {}
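# Hedged sketch of the `ResNetStage` helper called above: a stack of
# `stage_size` blocks in which only the first block may downsample.
# Argument names mirror the call site; the body is an assumption about
# the source, not a copy of it (`block_kwargs` forwards train/conv/norm/
# activation_f/use_residual/zero_inits and similar block options).
def resnet_stage_sketch(x, stage_size, filters, block_class,
                        first_block_strides, **block_kwargs):
  for b in range(stage_size):
    x = block_class(
        x,
        filters=filters,
        strides=first_block_strides if b == 0 else (1, 1),
        **block_kwargs)
  return x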
def apply(self,
          x,
          depth,
          num_outputs,
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          filters=16,
          no_head=False,
          report_metrics=False,
          benchmark='cifar10',
          compensate_padding=True,
          softplus_scale=None):
  bn_index = iter(range(1000))
  conv_index = iter(range(1000))
  summaries = {}
  summary_ind = [0]

  def add_summary(name, val):
    """Summarize statistics of tensor."""
    if report_metrics:
      assert val.ndim == 4, (
          'Assuming 4D inputs with channels last, got %s' % str(val.shape))
      assert val.shape[1] == val.shape[2], (
          'Assuming square spatial dims (H == W)')
      summaries['%s_%d_mean_abs' % (name, summary_ind[0] // 2)] = jnp.mean(
          jnp.abs(jnp.mean(val, axis=(0, 1, 2))))
      summaries['%s_%d_mean_std' % (name, summary_ind[0] // 2)] = jnp.mean(
          jnp.std(val, axis=(0, 1, 2)))
      summary_ind[0] += 1

  penalty = 0
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)

  def resnet_layer(
      inputs,
      penalty,
      filters,
      kernel_size=3,
      strides=1,
      activation=None,
  ):
    """2D Convolution-Batch Normalization-Activation stack builder."""
    x = inputs
    x = conv(
        x,
        filters, (kernel_size, kernel_size),
        strides=(strides, strides),
        padding='SAME',
        name='conv%d' % next(conv_index))
    x = norm(x, name='norm%d' % next(bn_index))
    add_summary('postnorm', x)
    if std_penalty_mult > 0:
      penalty += std_penalty(x)
    if activation:
      x = activation_f(x, features=x.shape[-1])
    add_summary('postact', x)
    return x, penalty

  # Main network code.
  num_res_blocks = (depth - 2) // 6
  if (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')
  inputs = x
  add_summary('input', x)
  add_summary('inputb', x)
  if benchmark in ['cifar10', 'cifar100']:
    x, penalty = resnet_layer(
        inputs, penalty, filters=filters, activation=True)
    head_kernel_init = nn.initializers.lecun_normal()
  elif benchmark in ['imagenet']:
    head_kernel_init = nn.initializers.zeros
    x, penalty = resnet_layer(
        inputs,
        penalty,
        filters=filters,
        activation=False,
        kernel_size=7,
        strides=2)
    # TODO(basv): evaluate max_pool vs. avg_pool in an experiment?
    # if compensate_padding:
    #   x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding="VALID")
    # else:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  else:
    raise ValueError('Model def not prepared for benchmark %s' % benchmark)

  for stack in range(3):
    for res_block in range(num_res_blocks):
      strides = 1
      if stack > 0 and res_block == 0:  # First layer but not first stack.
        strides = 2  # Downsample.
      y, penalty = resnet_layer(
          x,
          penalty,
          filters=filters,
          strides=strides,
          activation=True,
      )
      y, penalty = resnet_layer(
          y,
          penalty,
          filters=filters,
          activation=False,
      )
      if stack > 0 and res_block == 0:  # First layer but not first stack.
        # Linear projection residual shortcut to match changed dims.
        x, penalty = resnet_layer(
            x,
            penalty,
            filters=filters,
            kernel_size=1,
            strides=strides,
            activation=False,
        )
      if use_residual == 1:
        # Apply an up projection in case of channel mismatch.
        x = x + y
      elif use_residual == 2:
        x = (x + y) / jnp.sqrt(1**2 + 1**2)  # Sum of independent normals.
      else:
        x = y
      add_summary('postres', x)
      x = activation_f(x, features=x.shape[-1])
      add_summary('postresact', x)
    filters *= 2

  # V1 does not use BN after last shortcut connection-ReLU.
  if not no_head:
    x = jnp.mean(x, axis=(1, 2))
    add_summary('postpool', x)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs, kernel_init=head_kernel_init)

  return x, penalty, summaries
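# Hedged sketch of the `std_penalty` regularizer referenced above (not
# defined in this file). A natural reading, given the `_mean_std`
# summaries computed alongside it, is a penalty on per-channel activation
# std deviating from 1; the exact form in the source may differ.
def std_penalty_sketch(x):
  # Per-channel std over batch and spatial dims, pushed toward 1.
  std = jnp.std(x, axis=(0, 1, 2))
  return jnp.mean((std - 1.0)**2)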
def apply(
    self,
    inputs,
    blocks_per_group,
    channel_multiplier,
    num_outputs,
    kernel_size=(3, 3),
    strides=None,
    maxpool=False,
    dropout_rate=0.0,
    dtype=jnp.float32,
    norm_layer='group_norm',
    train=True,
    return_activations=False,
    input_layer_key='input',
    has_discriminator=False,
    discriminator=False,
):
  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'

  layer_activations = collections.OrderedDict()
  input_is_set = False

  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_is_set:
    x = nn.Conv(
        x,
        16,
        kernel_size=kernel_size,
        strides=strides,
        padding='SAME',
        name='init_conv')
    if maxpool:
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l1'
  if input_is_set:
    x = WideResnetGroup(
        x,
        blocks_per_group,
        16 * channel_multiplier,
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name='l1')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l2'
  if input_is_set:
    x = WideResnetGroup(
        x,
        blocks_per_group,
        32 * channel_multiplier, (2, 2),
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name='l2')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l3'
  if input_is_set:
    x = WideResnetGroup(
        x,
        blocks_per_group,
        64 * channel_multiplier, (2, 2),
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name='l3')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l4'
  if input_is_set:
    x = norm_layer(x, name=f'{norm_layer_name}')
    x = jax.nn.relu(x)
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module.
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_is_set:
    x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
  else:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warning('Input was never used')

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)

  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')

  return outputs
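# Hedged sketch of the gradient-reversal identity used by the DANN head
# above. The real `dann_utils.flip_grad_identity` is not shown in this
# file; this is the standard construction (identity forward, negated
# gradient backward), assuming a flip factor of -1.
import jax

@jax.custom_vjp
def flip_grad_identity_sketch(x):
  return x  # forward pass: identity

def _flip_fwd(x):
  return x, None  # no residuals needed for the backward pass

def _flip_bwd(_, g):
  return (-g,)  # backward pass: reverse the gradient

flip_grad_identity_sketch.defvjp(_flip_fwd, _flip_bwd)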
def apply(self,
          inputs,
          num_outputs,
          num_filters=64,
          num_layers=50,
          dropout_rate=0.0,
          input_dropout_rate=0.0,
          train=True,
          dtype=jnp.float32,
          head_bias_init=jnp.zeros,
          return_activations=False,
          input_layer_key='input',
          has_discriminator=False,
          discriminator=False):
  """Apply a ResNet network on the input.

  Args:
    inputs: jnp array; Inputs.
    num_outputs: int; Number of output units.
    num_filters: int; Determines base number of filters. Number of filters
      in block i is num_filters * 2 ** i.
    num_layers: int; Number of layers (should be one of the predefined
      options).
    dropout_rate: float; Rate of dropping out the output of different
      hidden layers.
    input_dropout_rate: float; Rate of dropping out the input units.
    train: bool; Whether the model is being trained.
    dtype: jnp type; Type of the outputs.
    head_bias_init: fn(rng_key, shape) --> jnp array; Initializer for the
      head bias parameters.
    return_activations: bool; If True, hidden activations are also
      returned.
    input_layer_key: str; Determines where to plug in the input (this is
      to enable providing inputs to slices of the model). If
      `input_layer_key` is `layer_i`, we assume the inputs are the
      activations of `layer_i` and pass them to `layer_{i+1}`.
    has_discriminator: bool; Whether the model should have a discriminator
      layer.
    discriminator: bool; Whether we should return discriminator logits.

  Returns:
    Unnormalized logits with shape `[bs, num_outputs]`;
    if return_activations: logits, a dict of hidden activations, and the
    key to the representation(s) which will be used as ``the
    representation'', e.g., for computing losses.
  """
  if num_layers not in ResNet._block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = ResNet._block_size_options[num_layers]

  layer_activations = collections.OrderedDict()
  input_is_set = False

  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True

  if input_is_set:
    # Input dropout.
    x = nn.dropout(x, input_dropout_rate, deterministic=not train)
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # First block.
    x = nn.Conv(
        x,
        num_filters, (7, 7), (2, 2),
        padding=[(3, 3), (3, 3)],
        bias=False,
        dtype=dtype,
        name='init_conv')
    x = nn.BatchNorm(
        x,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # Residual blocks.
  for i, block_size in enumerate(block_sizes):
    # Stage i (each stage contains blocks of the same size).
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      current_rep_key = f'block_{i + 1}_{j}'
      if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
      elif input_is_set:
        x = ResidualBlock(
            x,
            num_filters * 2**i,
            strides=strides,
            dropout_rate=dropout_rate,
            train=train,
            dtype=dtype,
            name=f'block_{i + 1}_{j}')
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

  current_rep_key = 'avg_pool'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_is_set:
    # Global average pool.
    x = jnp.mean(x, axis=(1, 2))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module.
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_layer_key == current_rep_key:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warning('Input was never used')
  elif input_is_set:
    x = nn.Dense(
        x, num_outputs, dtype=dtype, bias_init=head_bias_init, name='head')

  # Make sure that the output is float32, even if our previous computations
  # are in float16 or other types.
  x = jnp.asarray(x, jnp.float32)

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)

  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')

  return outputs
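# Hedged sketch of the `_block_size_options` table referenced above. The
# canonical ResNet depth-to-stage-sizes mapping is shown; the exact set of
# depths supported by the source may differ.
_BLOCK_SIZE_OPTIONS_SKETCH = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}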