def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          train=True):
  # Stem convolution.
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  # Three shake-shake groups; the last two downsample with stride 2.
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 16 * channel_multiplier,
                                train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 32 * channel_multiplier,
                                (2, 2), train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 64 * channel_multiplier,
                                (2, 2), train=train)
  x = jax.nn.relu(x)
  # Global 8x8 average pool (CIFAR-sized inputs), then classify.
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
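# Hypothetical usage sketch under the pre-Linen flax.deprecated.nn API this
# code is written against; the enclosing module name `WideResnetShakeShake`
# and the hyperparameters are assumptions, not taken from this file.
#
# model_def = WideResnetShakeShake.partial(
#     blocks_per_group=4, channel_multiplier=6, num_outputs=10)
# _, params = model_def.init_by_shape(
#     jax.random.PRNGKey(0), [((1, 32, 32, 3), jnp.float32)])
# logits = model_def.call(params, images)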
def apply(self, x, out_ch=None, with_conv=False, fir=False,
          fir_kernel=(1, 3, 3, 1)):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  if not fir:
    # Plain 2x downsampling: strided conv or average pooling.
    if with_conv:
      x = conv3x3(x, out_ch, stride=2)
    else:
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
  else:
    # FIR (anti-aliased) downsampling, optionally fused with a conv.
    if not with_conv:
      x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
    else:
      x = up_or_down_sampling.Conv2d(
          x, out_ch, kernel=3, down=True, resample_kernel=fir_kernel,
          bias=True, kernel_init=default_init())
  assert x.shape == (B, H // 2, W // 2, out_ch)
  return x
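# A minimal, self-contained sketch of what the FIR path above computes,
# assuming NHWC inputs with even H and W. `fir_downsample_2d_sketch` is a
# hypothetical stand-in; the real up_or_down_sampling.downsample_2d also
# handles gain, explicit padding conventions, and arbitrary factors.
import jax
import jax.numpy as jnp

def fir_downsample_2d_sketch(x, fir_kernel=(1, 3, 3, 1), factor=2):
  """Blurs with a separable FIR kernel, then subsamples by `factor`."""
  _, _, _, C = x.shape
  k = jnp.array(fir_kernel, dtype=x.dtype)
  k = jnp.outer(k, k)
  k = k / jnp.sum(k)  # Normalize so the filter preserves the mean.
  k = jnp.tile(k[:, :, None, None], (1, 1, 1, C))  # HWIO, one tap per channel.
  # Depthwise convolution with stride = factor does blur + subsample at once.
  return jax.lax.conv_general_dilated(
      x, k, window_strides=(factor, factor), padding='SAME',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
      feature_group_count=C)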
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # Flatten.
  x = nn.Dense(x, 128, name="fc")
  return x
def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          dropout_rate=0.0, train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
def apply(self, x, with_conv=False):
  B, H, W, C = x.shape
  if with_conv:
    x = ddpm_conv3x3(x, C, stride=2)
  else:
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
  assert x.shape == (B, H // 2, W // 2, C)
  return x
def apply(self, x, y, features, n_stages, normalizer, act=nn.relu):
  x = act(x)
  path = x
  # Chained residual pooling: each stage normalizes the running path
  # (conditioned on y), pools, convolves, and adds it back to x.
  for _ in range(n_stages):
    path = normalizer(path, y)
    path = nn.avg_pool(path, window_shape=(5, 5), strides=(1, 1),
                       padding='SAME')
    path = ncsn_conv3x3(path, features, stride=1, bias=False)
    x = path + x
  return x
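# For clarity, the loop above computes, with p_0 = x_0 = act(x_in):
#   p_k = conv3x3(avg_pool_5x5(normalizer(p_{k-1}, y)))
#   x_k = x_{k-1} + p_k
# so the output is act(x_in) plus the sum of all n_stages refined paths.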
def shortcut(x, chn_out, strides):
  """PyramidNet shortcut.

  Uses average pooling to downsample and zero-padding to increase the
  channel count, so the shortcut stays parameter-free.

  Args:
    x: input
    chn_out: expected number of output channels
    strides: striding applied by the block

  Returns:
    shortcut
  """
  chn_in = x.shape[3]
  if strides != (1, 1):
    x = nn.avg_pool(x, strides, strides)
  if chn_out != chn_in:
    diff = chn_out - chn_in
    x = jnp.pad(x, [[0, 0], [0, 0], [0, 0], [0, diff]])
  return x
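# Hypothetical shape-trace example for the shortcut above: halving a
# 32x32x16 feature map while widening it to 32 channels.
#
# x = jnp.zeros((8, 32, 32, 16))
# y = shortcut(x, chn_out=32, strides=(2, 2))
# assert y.shape == (8, 16, 16, 32)  # The extra 16 channels are all-zero.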
def apply(self, x, *, train, num_classes,
          block_class=BottleneckResNetImageNetBlock, stage_sizes,
          width_factor=1, normalization='bn', activation_f=None,
          std_penalty_mult=0, use_residual=1, bias_scale=0.0,
          weight_norm='none', compensate_padding=True, softplus_scale=None,
          no_head=False, zero_inits=True):
  """Construct ResNet V1 with `num_classes` outputs."""
  self._stage_sizes = stage_sizes
  if std_penalty_mult > 0:
    raise NotImplementedError(
        'std_penalty_mult not supported for ResNetImageNet')
  width = 64 * width_factor

  # Root block.
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  x = conv(x, width, kernel_size=(7, 7), strides=(2, 2), name='init_conv')
  x = norm(x, name='init_bn')
  if compensate_padding:
    # NOTE: this leads to lower performance.
    x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
  else:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

  # Stages.
  for i, stage_size in enumerate(stage_sizes):
    x = ResNetStage(
        x,
        stage_size,
        filters=width * 2**i,
        block_class=block_class,
        first_block_strides=(1, 1) if i == 0 else (2, 2),
        train=train,
        name=f'stage{i + 1}',
        conv=conv,
        norm=norm,
        activation_f=activation_f,
        use_residual=use_residual,
        zero_inits=zero_inits,
    )

  if not no_head:
    # Head.
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(
        x,
        num_classes,
        kernel_init=nn.initializers.zeros
        if zero_inits else nn.initializers.lecun_normal(),
        name='head')
  return x, 0, {}
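# Hypothetical ResNet-50-style invocation (stage sizes [3, 4, 6, 3] from
# He et al., 2015; the enclosing module name `ResNetImageNet` is an
# assumption, not taken from this file):
#
# model_def = ResNetImageNet.partial(
#     num_classes=1000, stage_sizes=[3, 4, 6, 3], train=False)
# logits, _, _ = model_def.call(params, images)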
def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          dropout_rate=0.0, normalization='bn', activation_f=None,
          std_penalty_mult=0, use_residual=1, train=True, bias_scale=0.0,
          weight_norm='none', no_head=False, compensate_padding=True,
          softplus_scale=None):
  penalty = 0
  activation_f = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  x = conv(x, 16 * channel_multiplier, (3, 3), padding='SAME',
           name='init_conv')
  x, g_penalty = WideResnetGroup(
      x, blocks_per_group, 16 * channel_multiplier,
      dropout_rate=dropout_rate, normalization=normalization,
      activation_f=activation_f, std_penalty_mult=std_penalty_mult,
      use_residual=use_residual, train=train, bias_scale=bias_scale,
      weight_norm=weight_norm)
  penalty += g_penalty
  x, g_penalty = WideResnetGroup(
      x, blocks_per_group, 32 * channel_multiplier, (2, 2),
      dropout_rate=dropout_rate, normalization=normalization,
      activation_f=activation_f, std_penalty_mult=std_penalty_mult,
      use_residual=use_residual, train=train, bias_scale=bias_scale,
      weight_norm=weight_norm)
  penalty += g_penalty
  x, g_penalty = WideResnetGroup(
      x, blocks_per_group, 64 * channel_multiplier, (2, 2),
      dropout_rate=dropout_rate, normalization=normalization,
      activation_f=activation_f, std_penalty_mult=std_penalty_mult,
      use_residual=use_residual, train=train, bias_scale=bias_scale,
      weight_norm=weight_norm)
  penalty += g_penalty
  x = norm(x, name='final_norm')
  if std_penalty_mult > 0:
    penalty += std_penalty(x)
  if not no_head:
    x = activation_f(x, features=x.shape[-1])
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(x, num_outputs)
  return x, penalty, {}
def apply(
    self,
    inputs,
    blocks_per_group,
    channel_multiplier,
    num_outputs,
    kernel_size=(3, 3),
    strides=None,
    maxpool=False,
    dropout_rate=0.0,
    dtype=jnp.float32,
    norm_layer='group_norm',
    train=True,
    return_activations=False,
    input_layer_key='input',
    has_discriminator=False,
    discriminator=False,
):
  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'

  # Each stage below either consumes the running activation `x` or, when
  # `input_layer_key` names that stage, treats `inputs` as that stage's
  # output and starts the forward pass there.
  layer_activations = collections.OrderedDict()
  input_is_set = False
  current_rep_key = 'input'
  if input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'init_conv'
  if input_is_set:
    x = nn.Conv(
        x, 16, kernel_size=kernel_size, strides=strides, padding='SAME',
        name='init_conv')
    if maxpool:
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l1'
  if input_is_set:
    x = WideResnetGroup(
        x, blocks_per_group, 16 * channel_multiplier,
        dropout_rate=dropout_rate, norm_layer=norm_layer, train=train,
        name='l1')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l2'
  if input_is_set:
    x = WideResnetGroup(
        x, blocks_per_group, 32 * channel_multiplier, (2, 2),
        dropout_rate=dropout_rate, norm_layer=norm_layer, train=train,
        name='l2')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l3'
  if input_is_set:
    x = WideResnetGroup(
        x, blocks_per_group, 64 * channel_multiplier, (2, 2),
        dropout_rate=dropout_rate, norm_layer=norm_layer, train=train,
        name='l3')
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  current_rep_key = 'l4'
  if input_is_set:
    x = norm_layer(x, name=f'{norm_layer_name}')
    x = jax.nn.relu(x)
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
  elif input_layer_key == current_rep_key:
    x = inputs
    input_is_set = True
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key

  # DANN module.
  if has_discriminator:
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  current_rep_key = 'head'
  if input_is_set:
    x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
  else:
    x = inputs
    layer_activations[current_rep_key] = x
    rep_key = current_rep_key
    logging.warning('Input was never used')

  outputs = x
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if discriminator and has_discriminator:
      outputs = outputs + (z,)
  else:
    del layer_activations
    if discriminator and has_discriminator:
      outputs = (x, z)

  if discriminator and (not has_discriminator):
    raise ValueError(
        'Inconsistent values passed for discriminator and has_discriminator')

  return outputs
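# Hypothetical call sketch for retrieving intermediate activations (the
# bound-module variable `model` is an assumption). With
# return_activations=True the function returns (logits, layer_activations,
# rep_key), where layer_activations is an OrderedDict keyed by 'input',
# 'init_conv', 'l1', 'l2', 'l3', 'l4', and 'head'.
#
# logits, acts, rep_key = model(
#     inputs, blocks_per_group=4, channel_multiplier=10, num_outputs=10,
#     return_activations=True)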
def apply(self, x, num_outputs, pyramid_alpha=200, pyramid_depth=272,
          train=True):
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params.
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size.
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915.
  num_channels = 16  # N in https://arxiv.org/abs/1610.02915.
  total_blocks = blocks_per_group * 3
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn')

  layer_num = 1
  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels), (1, 1), layer_mask_prob,
                            alpha_min, alpha_max, beta_min, beta_max,
                            train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='final_bn')
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
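# For reference, a plausible implementation of the _calc_shakedrop_mask_prob
# helper used above (an assumption, not taken from this file), following the
# linear decay rule of ShakeDrop (https://arxiv.org/abs/1802.02375): the
# probability of keeping a block's residual branch decays linearly with
# depth, from near 1 at the first block to 1 - mask_prob at the last.
def _calc_shakedrop_mask_prob(curr_layer, total_layers, mask_prob):
  """Linearly scales a block's keep probability with its depth."""
  return 1 - (float(curr_layer) / total_layers) * mask_prob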