def apply(self, x, reduction=16):
    """Gate the channels of `x` with a squeeze-and-excitation bottleneck.

    The spatial mean (axes 1 and 2) of `x` goes through two bias-free Dense
    layers (ReLU, then sigmoid). The resulting per-channel gate rescales `x`.

    Args:
      x: Input array. The `[:, None, None, :]` broadcast below implies a 4-D,
        channels-last layout.
      reduction: Bottleneck ratio; the hidden layer has
        `channels // reduction` units.

    Returns:
      `x` multiplied by a gate in (0, 1), one value per example and channel.
    """
    channels = x.shape[-1]
    squeezed = x.mean(axis=(1, 2))
    # The inner Dense is created before the outer one, as in the original
    # statement order (flax.nn auto-names modules in call order).
    bottleneck = nn.relu(
        nn.Dense(squeezed, features=channels // reduction, bias=False))
    gate = nn.sigmoid(nn.Dense(bottleneck, features=channels, bias=False))
    # Broadcast the (batch, channels) gate back over the two spatial axes.
    return x * gate[:, None, None, :]
def get_logits(code_embeddings, length):
    """Encode a padded sequence with a scanned RNN, then decode to logits.

    Free names `encoder`, `encoder_cells`, `emb_dim`, `output_length`,
    `decoder`, `config` and `output_token_vocabulary_size` come from the
    enclosing scope.

    Args:
      code_embeddings: Sequence scanned along its leading axis. The original
        comment describes its shape as (length, emb_dim).
      length: Number of valid leading steps. At later steps the encoder
        carry (including the step counter) is left unchanged.

    Returns:
      Output of the 'output_layer' Dense applied to the `output_length`
      decoder outputs.
    """
    def encoder_step(carry, token_embedding):
        rnn_carry, step = carry
        new_rnn_carry, step_output = encoder(rnn_carry, token_embedding)

        def keep_if_padding(updated, previous):
            # Only accept the update while still inside the valid prefix.
            return jnp.where(step < length, updated, previous)

        return jax.tree_multimap(
            keep_if_padding,
            ((new_rnn_carry, step + 1), step_output),
            (carry, token_embedding))

    start_carry = encoder.initialize_carry(
        jax.random.PRNGKey(0), encoder_cells, (), emb_dim)
    (encoder_state, _), _ = jax.lax.scan(
        encoder_step, (start_carry, 0), code_embeddings)

    # The decoder is driven by all-zero inputs; its conditioning comes
    # entirely from the final encoder state.
    decoder_inputs = jnp.zeros((output_length, emb_dim))
    _, decoder_outputs = jax.lax.scan(decoder, encoder_state, decoder_inputs)

    return nn.Dense(
        decoder_outputs,
        output_token_vocabulary_size,
        kernel_init=nn.initializers.normal(
            stddev=config.initialization.maxval),
        bias_init=nn.initializers.zeros,
        name='output_layer')
def apply(self,
          x,
          act,
          normalize,
          temb=None,
          out_ch=None,
          conv_shortcut=False,
          dropout=0.1,
          train=True,
          skip_rescale=False,
          init_scale=0.):
    """Residual block: (norm -> act -> conv3x3) twice, plus a shortcut.

    Args:
      x: 4-D input (the shape unpacking below requires rank 4); channels last.
      act: Activation function.
      normalize: Normalization module, called with `num_groups`.
      temb: Optional embedding; when given, a Dense projection of
        `act(temb)` is added to every spatial position of the first conv
        output.
      out_ch: Output channels; falsy means "same as input".
      conv_shortcut: When channels change, use conv3x3 (True) or NIN (False)
        on the shortcut path.
      dropout: Dropout rate before the second convolution.
      train: Dropout is deterministic when False.
      skip_rescale: Divide the residual sum by sqrt(2) when True.
      init_scale: Passed to the second conv3x3.

    Returns:
      `x + h`, or `(x + h) / sqrt(2)` when `skip_rescale`.
    """
    _, _, _, in_ch = x.shape
    out_ch = out_ch or in_ch

    def norm_act(t):
        return act(normalize(t, num_groups=min(t.shape[-1] // 4, 32)))

    h = conv3x3(norm_act(x), out_ch)
    if temb is not None:
        # Per-channel bias from the embedding, broadcast over H and W.
        h += nn.Dense(
            act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
    h = nn.dropout(norm_act(h), dropout, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=init_scale)

    if in_ch != out_ch:
        project = conv3x3 if conv_shortcut else NIN
        x = project(x, out_ch)

    summed = x + h
    return summed / np.sqrt(2.) if skip_rescale else summed
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          dropout_rate=0.0,
          train=True):
    """Wide-ResNet classifier.

    Structure: 3x3 stem conv, then three `WideResnetGroup`s of width
    16/32/64 x `channel_multiplier`. The last two groups receive (2, 2) as
    their fourth positional argument (presumably the strides). Then
    BatchNorm, ReLU, 8x8 average pool, flatten and a Dense head.

    Returns:
      Array of shape (batch, num_outputs).
    """
    x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
    group_kwargs = dict(dropout_rate=dropout_rate, train=train)
    x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                        **group_kwargs)
    for base_width in (32, 64):
        x = WideResnetGroup(x, blocks_per_group,
                            base_width * channel_multiplier, (2, 2),
                            **group_kwargs)
    x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                     epsilon=1e-5)
    x = jax.nn.relu(x)
    x = nn.avg_pool(x, (8, 8))
    flat = x.reshape((x.shape[0], -1))
    return nn.Dense(flat, num_outputs)
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          train=True):
    """Wide-ResNet with shake-shake groups.

    Structure: 3x3 stem conv, then three `WideResnetShakeShakeGroup`s of
    width 16/32/64 x `channel_multiplier`. The last two groups receive
    (2, 2) as their fourth positional argument. Then ReLU, 8x8 average
    pool, flatten and a Dense head.

    Returns:
      Array of shape (batch, num_outputs).
    """
    x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
    x = WideResnetShakeShakeGroup(x, blocks_per_group,
                                  16 * channel_multiplier, train=train)
    for base_width in (32, 64):
        x = WideResnetShakeShakeGroup(x, blocks_per_group,
                                      base_width * channel_multiplier,
                                      (2, 2), train=train)
    x = nn.avg_pool(jax.nn.relu(x), (8, 8))
    flat = x.reshape((x.shape[0], -1))
    return nn.Dense(flat, num_outputs)
def apply(self, x):
    """Conv(32, 3x3) -> ReLU -> 2x2 average pool -> flatten -> Dense(128)."""
    feature_map = nn.relu(
        nn.Conv(x, features=32, kernel_size=(3, 3), name="conv"))
    pooled = nn.avg_pool(feature_map, window_shape=(2, 2), strides=(2, 2))
    flattened = pooled.reshape((pooled.shape[0], -1))
    return nn.Dense(flattened, 128, name="fc")
def apply(self,
          x,
          act,
          normalize,
          temb=None,
          out_ch=None,
          conv_shortcut=False,
          dropout=0.5,
          train=True):
    """Residual block: (normalize -> act -> ddpm_conv3x3) twice, plus shortcut.

    Args:
      x: 4-D input (rank enforced by the shape unpacking); channels last.
      act: Activation function.
      normalize: Normalization module.
      temb: Optional embedding added (via Dense) as a per-channel bias after
        the first convolution.
      out_ch: Output channels; falsy means "same as input".
      conv_shortcut: When channels change, use ddpm_conv3x3 (True) or NIN
        (False) on the shortcut.
      dropout: Dropout rate before the second convolution.
      train: Dropout is deterministic when False.

    Returns:
      `x + h`.
    """
    _, _, _, in_ch = x.shape
    out_ch = out_ch or in_ch

    h = ddpm_conv3x3(act(normalize(x)), out_ch)
    if temb is not None:
        # Per-channel bias from the embedding, broadcast over H and W.
        h += nn.Dense(
            act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
    h = nn.dropout(act(normalize(h)), dropout, deterministic=not train)
    # The second conv uses init_scale=0.
    h = ddpm_conv3x3(h, out_ch, init_scale=0.)

    if in_ch != out_ch:
        project = ddpm_conv3x3 if conv_shortcut else NIN
        x = project(x, out_ch)
    return x + h
def apply(self,
          inputs,
          mlp_dim,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6)):
    """Applies Transformer MlpBlock module.

    Dense(mlp_dim) -> GELU -> dropout -> Dense(out_dim) -> dropout.

    Args:
      inputs: Input array; its last axis is the feature axis.
      mlp_dim: Hidden width.
      out_dim: Output width; defaults to `inputs.shape[-1]`.
      dropout_rate: Rate for both dropout layers.
      deterministic: Disables dropout when True.
      kernel_init: Kernel initializer for both Dense layers.
      bias_init: Bias initializer for both Dense layers.

    Returns:
      Array with last axis of size `out_dim` (or the input width).
    """
    actual_out_dim = out_dim if out_dim is not None else inputs.shape[-1]
    dense_kwargs = dict(kernel_init=kernel_init, bias_init=bias_init)
    dropout_kwargs = dict(rate=dropout_rate, deterministic=deterministic)

    hidden = nn.gelu(nn.Dense(inputs, mlp_dim, **dense_kwargs))
    hidden = nn.dropout(hidden, **dropout_kwargs)
    projected = nn.Dense(hidden, actual_out_dim, **dense_kwargs)
    return nn.dropout(projected, **dropout_kwargs)
def apply(self,
          x,
          act,
          normalize,
          up=False,
          down=False,
          temb=None,
          out_ch=None,
          dropout=0.1,
          fir=False,
          fir_kernel=(1, 3, 3, 1),
          train=True,
          skip_rescale=True,
          init_scale=0.):
    """Residual block with optional 2x up- or down-sampling.

    Both the residual branch and the shortcut are resampled by a factor of 2.
    FIR resampling is used when `fir` is True, naive resampling otherwise.
    `up` takes precedence over `down`.

    Args:
      x: 4-D input (rank enforced by the shape unpacking); channels last.
      act: Activation function.
      normalize: Normalization module, called with `num_groups`.
      up: Upsample by 2.
      down: Downsample by 2 (ignored when `up` is set).
      temb: Optional embedding added (via Dense) as a per-channel bias after
        the first convolution.
      out_ch: Output channels; falsy means "same as input".
      dropout: Dropout rate before the second convolution.
      fir: Use FIR-kernel resampling instead of naive resampling.
      fir_kernel: 1-D FIR taps. A tuple by default so that one default object
        shared across calls cannot be mutated; any sequence accepted by
        `up_or_down_sampling` may be passed.
      train: Dropout is deterministic when False.
      skip_rescale: Divide the residual sum by sqrt(2) when True.
      init_scale: Passed to the second conv3x3.

    Returns:
      `(x + h) / sqrt(2)` when `skip_rescale`, else `x + h`.
    """
    _, _, _, in_ch = x.shape
    out_ch = out_ch if out_ch else in_ch
    h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))

    # Resample the residual branch and the shortcut identically.
    if up:
        if fir:
            h = up_or_down_sampling.upsample_2d(h, fir_kernel, factor=2)
            x = up_or_down_sampling.upsample_2d(x, fir_kernel, factor=2)
        else:
            h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
            x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
    elif down:
        if fir:
            h = up_or_down_sampling.downsample_2d(h, fir_kernel, factor=2)
            x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
        else:
            h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
            x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding.
    if temb is not None:
        h += nn.Dense(
            act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
    h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
    h = nn.dropout(h, dropout, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=init_scale)

    # A 1x1 projection is applied whenever channels change or the block
    # resamples.
    if in_ch != out_ch or up or down:
        x = conv1x1(x, out_ch)

    if not skip_rescale:
        return x + h
    return (x + h) / np.sqrt(2.)
def apply(self, x, num_outputs, train=True):
    """Bottleneck-block classifier configured by class attributes.

    Reads `NUM_FILTERS`, `BLOCK_SIZES`, `GROUPS` and `WIDTH_PER_GROUP` from
    the class. Stage `i` uses `NUM_FILTERS * 2**i` filters. The first block
    of every stage after the first uses strides (2, 2).

    Returns:
      Array of shape (batch, num_outputs) from the 'clf' Dense layer.
    """
    # Stem: strided 7x7 conv, batch norm, 3x3/2 max pool.
    # NOTE(review): the usual ResNet/ResNeXt stem applies a ReLU between
    # 'init_bn' and the max pool; none is applied here. Confirm this is
    # intentional before relying on it.
    x = nn.Conv(x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False,
                name='init_conv')
    x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                     epsilon=1e-5, name='init_bn')
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for stage, num_blocks in enumerate(self.BLOCK_SIZES):
        stage_filters = self.NUM_FILTERS * 2**stage
        for block in range(num_blocks):
            downsample = stage > 0 and block == 0
            x = BottleneckBlock(x, stage_filters,
                                strides=(2, 2) if downsample else (1, 1),
                                groups=self.GROUPS,
                                base_width=self.WIDTH_PER_GROUP,
                                train=train)

    pooled = jnp.mean(x, axis=(1, 2))
    return nn.Dense(pooled, num_outputs, name='clf')
def apply(self,
          x,
          *,
          self_attention_module,
          dim_intermediate,
          is_training,
          dropout_rate=0.1,
          use_pre_layernorm=False,
          layernorm_epsilon=1e-6,
          with_aux_outputs=True):
    """Compute self-attention with a feed-forward network on top.

    Each sub-block computes `x = x + dropout(f(LN?(x)))`. With pre-LN, only
    the input of `f` is normalized and the residual carries the
    un-normalized `x`. With post-LN (default), the sum is normalized
    afterwards.

    Args:
      x: Input representations.
      self_attention_module: Self-Attention layer.
      dim_intermediate: Size of the intermediate layer of the feed forward.
      is_training: Whether to enable dropout.
      dropout_rate: Dropout probability.
      use_pre_layernorm: Use pre layer norm from
        https://arxiv.org/abs/2002.04745.
      layernorm_epsilon: Epsilon parameter for all the layer norms.
      with_aux_outputs: Whether the self_attention_module has an aux output.

    Returns:
      New representations in a jnp.array of same shape as `x`, paired with
      the attention module's aux output when `with_aux_outputs`.
    """
    dim_hidden = x.shape[-1]
    use_pre_ln = use_pre_layernorm
    use_post_ln = not use_pre_ln

    def apply_ln_if(pred, x, name):
        if pred:
            return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
        return x

    # Attention sub-block. Only the attention *input* is pre-normalized;
    # the residual keeps `x`, matching the feed-forward sub-block below.
    # (Previously `x` itself was replaced by LN(x) in pre-LN mode.)
    x_att = self_attention_module(apply_ln_if(use_pre_ln, x, "ln_pre_att"))
    if with_aux_outputs:
        x_att, output_aux = x_att
    x_att = nn.dropout(x_att, dropout_rate, deterministic=not is_training)
    x = apply_ln_if(use_post_ln, x + x_att, "ln_post_att")

    # Feed-forward sub-block.
    x_ffn = apply_ln_if(use_pre_ln, x, "ln_pre_ffn")
    x_ffn = nn.Dense(x_ffn, dim_intermediate, name="ff_1")
    x_ffn = jax.nn.relu(x_ffn)
    x_ffn = nn.Dense(x_ffn, dim_hidden, name="ff_2")
    x_ffn = nn.dropout(x_ffn, dropout_rate, deterministic=not is_training)
    x = apply_ln_if(use_post_ln, x + x_ffn, "ln_post_ffn")

    if with_aux_outputs:
        return x, output_aux
    return x
def apply(self, x, labels, y=None, config=None, train=True):
    """Noise/time-conditional image network with skip connections via `hs`.

    Steps:
    1. Build an embedding `temb` from `labels` (Gaussian Fourier features
       of log(sigma), or sinusoidal timestep features). When `y` is given,
       add a class embedding.
    2. Downsampling path: residual and optional attention blocks per level.
       Every output is pushed on `hs`.
    3. Middle ResNet-Attention-ResNet.
    4. Upsampling path that pops and concatenates the stored activations.
    5. Optional progressive input/output pyramids.

    Args:
      x: Input images. Channels are on the last axis (concatenation on
        axis -1). `h.shape[1]` is compared against `attn_resolutions`, so
        axis 1 is presumably spatial height. TODO confirm layout.
      labels: Indices into `sigmas` ('gaussian' embedding, `scale_by_sigma`),
        or timesteps ('positional' embedding).
      y: Optional class labels for class-conditional generation.
      config: Config object; reads the `config.model.*` and `config.data.*`
        fields below.
      train: Forwarded to the residual blocks.

    Returns:
      Array with `x.shape[-1]` channels. It is divided by the per-example
      sigma when `config.model.scale_by_sigma` is set.
    """
    # config parsing
    nf = config.model.nf
    act = get_act(config)
    normalize = get_normalization(config)
    sigmas = utils.get_sigmas(config)
    nf = config.model.nf  # redundant re-read of the same field
    ch_mult = config.model.ch_mult
    num_res_blocks = config.model.num_res_blocks
    attn_resolutions = config.model.attn_resolutions
    dropout = config.model.dropout
    resamp_with_conv = config.model.resamp_with_conv
    num_resolutions = len(ch_mult)
    conditional = config.model.conditional  # noise-conditional
    fir = config.model.fir
    fir_kernel = config.model.fir_kernel
    skip_rescale = config.model.skip_rescale
    resblock_type = config.model.resblock_type
    progressive = config.model.progressive
    progressive_input = config.model.progressive_input
    init_scale = config.model.init_scale
    # NOTE(review): these asserts lower-case the values, but the branches
    # below compare case-sensitively. A value like 'Output_Skip' passes the
    # assert and then matches no branch.
    # `progressive_input` is not validated at all.
    assert progressive.lower() in ['none', 'output_skip', 'residual']
    assert config.model.embedding_type.lower() in [
        'gaussian', 'positional'
    ]
    combine_method = config.model.progressive_combine
    combiner = Combine.partial(method=combine_method)

    # timestep/noise_level embedding
    if config.model.embedding_type == 'gaussian':
        # Gaussian Fourier features embeddings.
        used_sigmas = sigmas[labels]
        temb = layersv3.GaussianFourierProjection(
            jnp.log(used_sigmas),
            embedding_size=nf,
            scale=config.model.fourier_scale)
    elif config.model.embedding_type == 'positional':
        # Sinusoidal positional embeddings.
        timesteps = labels
        temb = layers.get_timestep_embedding(timesteps, nf)
    else:
        raise ValueError(
            f'embedding type {config.model.embedding_type} unknown.')

    # Two-layer MLP on the embedding.
    temb = nn.Dense(temb, nf * 4, kernel_init=default_initializer())
    temb = nn.Dense(act(temb), nf * 4, kernel_init=default_initializer())
    if y is not None:  # class-conditional image generation
        class_embed = nn.Embed(y, config.data.num_classes, nf * 4)
        class_embed = nn.Dense(class_embed, nf * 4, kernel_init=default_initializer())
        class_embed = nn.Dense(act(class_embed), nf * 4, kernel_init=default_initializer())
        # NOTE(review): `temb` only reaches the residual blocks when
        # `conditional` is True. With conditional=False the class embedding
        # is computed but has no effect.
        temb += class_embed

    AttnBlock = layersv3.AttnBlockv3.partial(normalize=normalize, init_scale=init_scale, skip_rescale=skip_rescale)
    Upsample = layersv3.Upsample.partial(with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
    # `pyramid_upsample` / `pyramid_downsample` are only defined (and only
    # used) for the non-'none' progressive modes.
    if progressive == 'output_skip':
        pyramid_upsample = layersv3.Upsample.partial(fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive == 'residual':
        pyramid_upsample = layersv3.Upsample.partial(fir=fir, fir_kernel=fir_kernel, with_conv=True)
    Downsample = layersv3.Downsample.partial(with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
    if progressive_input == 'input_skip':
        pyramid_downsample = layersv3.Downsample.partial(
            fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive_input == 'residual':
        pyramid_downsample = layersv3.Downsample.partial(
            fir=fir, fir_kernel=fir_kernel, with_conv=True)

    # The residual blocks see the embedding only when noise-conditional.
    if resblock_type == 'ddpm':
        ResnetBlock = ResnetBlockDDPM.partial(act=act, normalize=normalize, dropout=dropout, temb=temb if conditional else None, train=train, init_scale=init_scale, skip_rescale=skip_rescale)
    elif resblock_type == 'biggan':
        ResnetBlock = ResnetBlockBigGAN.partial(act=act, normalize=normalize, temb=temb if conditional else None, train=train, dropout=dropout, fir=fir, fir_kernel=fir_kernel, init_scale=init_scale, skip_rescale=skip_rescale)
    else:
        raise ValueError(f'resblock_type {resblock_type} unrecognized.')

    if not config.data.centered:
        # If input data is in [0, 1]
        x = 2 * x - 1.

    # Downsampling block
    input_pyramid = None
    if progressive_input != 'none':
        input_pyramid = x

    # `hs` collects every activation that the upsampling path will pop.
    hs = [conv3x3(x, nf)]
    for i_level in range(num_resolutions):
        # Residual blocks for this resolution
        for i_block in range(num_res_blocks):
            h = ResnetBlock(hs[-1], out_ch=nf * ch_mult[i_level])
            if h.shape[1] in attn_resolutions:
                h = AttnBlock(h)
            hs.append(h)

        if i_level != num_resolutions - 1:
            # Resample: a plain Downsample for 'ddpm', a down-sampling
            # residual block otherwise.
            if resblock_type == 'ddpm':
                h = Downsample(hs[-1])
            else:
                h = ResnetBlock(hs[-1], down=True)

            if progressive_input == 'input_skip':
                input_pyramid = pyramid_downsample(input_pyramid)
                h = combiner(input_pyramid, h)

            elif progressive_input == 'residual':
                input_pyramid = pyramid_downsample(input_pyramid, out_ch=h.shape[-1])
                if skip_rescale:
                    input_pyramid = (input_pyramid + h) / np.sqrt(2.)
                else:
                    input_pyramid = input_pyramid + h
                h = input_pyramid

            hs.append(h)

    # Middle: ResNet -> Attention -> ResNet at the lowest resolution.
    h = hs[-1]
    h = ResnetBlock(h)
    h = AttnBlock(h)
    h = ResnetBlock(h)

    pyramid = None

    # Upsampling block
    for i_level in reversed(range(num_resolutions)):
        # One more block than on the way down: consumes the resample output
        # pushed at this level as well.
        for i_block in range(num_res_blocks + 1):
            h = ResnetBlock(jnp.concatenate([h, hs.pop()], axis=-1), out_ch=nf * ch_mult[i_level])

        if h.shape[1] in attn_resolutions:
            h = AttnBlock(h)

        if progressive != 'none':
            if i_level == num_resolutions - 1:
                # Start the output pyramid at the coarsest level.
                if progressive == 'output_skip':
                    pyramid = conv3x3(act(
                        normalize(h, num_groups=min(h.shape[-1] // 4, 32))),
                        x.shape[-1], bias=True, init_scale=init_scale)
                elif progressive == 'residual':
                    pyramid = conv3x3(act(
                        normalize(h, num_groups=min(h.shape[-1] // 4, 32))),
                        h.shape[-1], bias=True)
                else:
                    raise ValueError(f'{progressive} is not a valid name.')
            else:
                # Upsample the running pyramid and merge this level into it.
                if progressive == 'output_skip':
                    pyramid = pyramid_upsample(pyramid)
                    pyramid = pyramid + conv3x3(act(
                        normalize(h, num_groups=min(h.shape[-1] // 4, 32))),
                        x.shape[-1], bias=True, init_scale=init_scale)
                elif progressive == 'residual':
                    pyramid = pyramid_upsample(pyramid, out_ch=h.shape[-1])
                    if skip_rescale:
                        pyramid = (pyramid + h) / np.sqrt(2.)
                    else:
                        pyramid = pyramid + h
                    h = pyramid
                else:
                    raise ValueError(f'{progressive} is not a valid name')

        if i_level != 0:
            if resblock_type == 'ddpm':
                h = Upsample(h)
            else:
                h = ResnetBlock(h, up=True)

    # Every stored activation must have been consumed.
    assert not hs

    if progressive == 'output_skip':
        h = pyramid
    else:
        h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
        h = conv3x3(h, x.shape[-1], init_scale=init_scale)

    if config.model.scale_by_sigma:
        # Reshape to (batch, 1, ..., 1) so it broadcasts over all other axes.
        used_sigmas = sigmas[labels].reshape(
            (x.shape[0], *([1] * len(x.shape[1:]))))
        h = h / used_sigmas

    return h
def apply(self,
          x,
          *,
          train,
          num_classes,
          block_class=BottleneckResNetImageNetBlock,
          stage_sizes,
          width_factor=1,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          bias_scale=0.0,
          weight_norm='none',
          compensate_padding=True,
          softplus_scale=None,
          no_head=False,
          zero_inits=True):
    """Construct ResNet V1 with `num_classes` outputs.

    Structure:
    - Stem: 7x7/2 conv and norm, then a 2x2/2 average pool when
      `compensate_padding`, else a 3x3/2 max pool.
    - One `ResNetStage` per entry of `stage_sizes`, with width
      64 * width_factor * 2**i. Stages after the first start with (2, 2)
      strides.
    - Unless `no_head`: spatial mean and a Dense head, zero-initialized
      when `zero_inits`.

    Returns:
      Tuple `(x, 0, {})`: output, a zero penalty and empty summaries.

    Raises:
      NotImplementedError: if `std_penalty_mult > 0`.
    """
    self._stage_sizes = stage_sizes
    if std_penalty_mult > 0:
        raise NotImplementedError(
            'std_penalty_mult not supported for ResNetImageNet')
    base_width = 64 * width_factor

    activation_f = get_activation_f(activation_f, train, softplus_scale,
                                    bias_scale)
    norm = get_norm(activation_f, normalization, train)
    conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                    normalization)

    # Root block.
    x = conv(x, base_width, kernel_size=(7, 7), strides=(2, 2),
             name='init_conv')
    x = norm(x, name='init_bn')
    if compensate_padding:
        # NOTE: this leads to lower performance.
        x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')
    else:
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for stage_index, stage_size in enumerate(stage_sizes):
        first_strides = (2, 2) if stage_index else (1, 1)
        x = ResNetStage(
            x,
            stage_size,
            filters=base_width * 2**stage_index,
            block_class=block_class,
            first_block_strides=first_strides,
            train=train,
            name=f'stage{stage_index + 1}',
            conv=conv,
            norm=norm,
            activation_f=activation_f,
            use_residual=use_residual,
            zero_inits=zero_inits,
        )

    if no_head:
        return x, 0, {}

    x = jnp.mean(x, axis=(1, 2))
    head_init = (nn.initializers.zeros
                 if zero_inits else nn.initializers.lecun_normal())
    x = nn.Dense(x, num_classes, kernel_init=head_init, name='head')
    return x, 0, {}
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          no_head=False,
          compensate_padding=True,
          softplus_scale=None):
    """Wide-ResNet that also returns an accumulated penalty.

    Structure:
    - 3x3 stem conv.
    - Three `WideResnetGroup`s (width 16/32/64 x channel_multiplier). Each
      returns `(x, penalty)` and the penalties are summed. The last two
      groups receive (2, 2) as their fourth positional argument.
    - 'final_norm', plus `std_penalty` when `std_penalty_mult > 0`.
    - Unless `no_head`: activation, 8x8 average pool, flatten and Dense.

    Returns:
      Tuple `(x, penalty, {})`.
    """
    penalty = 0
    activation_f = get_activation_f(activation_f, train, softplus_scale,
                                    bias_scale)
    norm = get_norm(activation_f, normalization, train)
    conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                    normalization)

    x = conv(x, 16 * channel_multiplier, (3, 3), padding='SAME',
             name='init_conv')

    group_kwargs = dict(dropout_rate=dropout_rate,
                        normalization=normalization,
                        activation_f=activation_f,
                        std_penalty_mult=std_penalty_mult,
                        use_residual=use_residual,
                        train=train,
                        bias_scale=bias_scale,
                        weight_norm=weight_norm)
    for group_index, base_width in enumerate((16, 32, 64)):
        # The first group keeps WideResnetGroup's default for this argument;
        # later groups pass (2, 2).
        stride_args = () if group_index == 0 else ((2, 2),)
        x, group_penalty = WideResnetGroup(x, blocks_per_group,
                                           base_width * channel_multiplier,
                                           *stride_args, **group_kwargs)
        penalty += group_penalty

    x = norm(x, name='final_norm')
    if std_penalty_mult > 0:
        penalty += std_penalty(x)

    if not no_head:
        x = activation_f(x, features=x.shape[-1])
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
    return x, penalty, {}
def apply(self,
          x,
          depth,
          num_outputs,
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          filters=16,
          no_head=False,
          report_metrics=False,
          benchmark='cifar10',
          compensate_padding=True,
          softplus_scale=None):
    """ResNet V1 of depth 6n+2 with optional penalty and activation summaries.

    Args:
      x: Input images, channels last.
      depth: Total depth; must satisfy (depth - 2) % 6 == 0.
      num_outputs: Width of the Dense head.
      dropout_rate: Accepted for interface compatibility; unused here.
      normalization, activation_f, bias_scale, weight_norm,
        compensate_padding, softplus_scale: Forwarded to the
        get_activation_f / get_norm / get_conv factories.
      std_penalty_mult: When > 0, `std_penalty` of every normalized conv
        output is added to the returned penalty.
      use_residual: 1 -> x + y; 2 -> (x + y) / sqrt(2); otherwise no
        residual (x = y).
      train: Forwarded to the factories.
      filters: Width of the first stack; doubled after each stack.
      no_head: Skip pooling and the Dense head.
      report_metrics: Record mean/std statistics of intermediate tensors.
      benchmark: 'cifar10'/'cifar100' (3x3 stem) or 'imagenet' (7x7/2 stem
        plus max pool, zero-initialized head).

    Returns:
      Tuple `(x, penalty, summaries)`.

    Raises:
      ValueError: on an invalid `depth` or unknown `benchmark`.
    """
    bn_index = iter(range(1000))
    conv_index = iter(range(1000))
    summaries = {}
    summary_ind = [0]

    def add_summary(name, val):
        """Summarize statistics of tensor.

        Reduces over every axis except the last (channel) axis. This accepts
        4-D channels-last maps (height must equal width) as well as the 2-D
        (batch, channels) pooled features. The previous 4-D-only assertion
        failed on 'postpool' whenever `report_metrics` was set.
        """
        if report_metrics:
            assert val.ndim in (2, 4), (
                'Assuming 2D or 4D inputs with channels last, got %s' %
                str(val.shape))
            if val.ndim == 4:
                assert val.shape[1] == val.shape[
                    2], 'Assuming 4D inputs with channels last'
            reduce_axes = tuple(range(val.ndim - 1))
            summaries['%s_%d_mean_abs' % (name, summary_ind[0] // 2)] = jnp.mean(
                jnp.abs(jnp.mean(val, axis=reduce_axes)))
            summaries['%s_%d_mean_std' % (name, summary_ind[0] // 2)] = jnp.mean(
                jnp.std(val, axis=reduce_axes))
            summary_ind[0] += 1

    penalty = 0
    activation_f = get_activation_f(activation_f, train, softplus_scale,
                                    bias_scale)
    norm = get_norm(activation_f, normalization, train)
    conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                    normalization)

    def resnet_layer(
        inputs,
        penalty,
        filters,
        kernel_size=3,
        strides=1,
        activation=None,
    ):
        """2D Convolution-Batch Normalization-Activation stack builder."""
        x = inputs
        x = conv(x, filters, (kernel_size, kernel_size),
                 strides=(strides, strides), padding='SAME',
                 name='conv%d' % next(conv_index))
        x = norm(x, name='norm%d' % next(bn_index))
        add_summary('postnorm', x)
        if std_penalty_mult > 0:
            penalty += std_penalty(x)
        if activation:
            x = activation_f(x, features=x.shape[-1])
            add_summary('postact', x)
        return x, penalty

    # Main network code.
    num_res_blocks = (depth - 2) // 6
    if (depth - 2) % 6 != 0:
        raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')

    inputs = x
    add_summary('input', x)
    add_summary('inputb', x)
    if benchmark in ['cifar10', 'cifar100']:
        x, penalty = resnet_layer(inputs, penalty, filters=filters,
                                  activation=True)
        head_kernel_init = nn.initializers.lecun_normal()
    elif benchmark in ['imagenet']:
        head_kernel_init = nn.initializers.zeros
        x, penalty = resnet_layer(inputs, penalty, filters=filters,
                                  activation=False, kernel_size=7, strides=2)
        # TODO(basv): evaluate max pool v/s avg_pool (2x2/2, VALID, when
        # compensate_padding) in an experiment?
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
        raise ValueError('Model def not prepared for benchmark %s' % benchmark)

    for stack in range(3):
        for res_block in range(num_res_blocks):
            # First layer of every stack but the first downsamples.
            is_transition = stack > 0 and res_block == 0
            strides = 2 if is_transition else 1
            y, penalty = resnet_layer(
                x,
                penalty,
                filters=filters,
                strides=strides,
                activation=True,
            )
            y, penalty = resnet_layer(
                y,
                penalty,
                filters=filters,
                activation=False,
            )
            if is_transition:
                # Linear projection residual shortcut to match changed dims.
                x, penalty = resnet_layer(
                    x,
                    penalty,
                    filters=filters,
                    kernel_size=1,
                    strides=strides,
                    activation=False,
                )
            if use_residual == 1:
                x = x + y
            elif use_residual == 2:
                # Sum of independent normals.
                x = (x + y) / jnp.sqrt(1**2 + 1**2)
            else:
                x = y
            add_summary('postres', x)
            x = activation_f(x, features=x.shape[-1])
            add_summary('postresact', x)
        filters *= 2

    # V1 does not use BN after last shortcut connection-ReLU.
    if not no_head:
        x = jnp.mean(x, axis=(1, 2))
        add_summary('postpool', x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs, kernel_init=head_kernel_init)
    return x, penalty, summaries
def apply(self, x):
    """Two Dense(256) + ReLU layers followed by a linear Dense(2) output."""
    for _ in range(2):
        x = nn.relu(nn.Dense(x, features=256))
    return nn.Dense(x, features=2)
def apply(self, x, labels, config, train=True):
    """Timestep-conditioned network with `hs`-stacked skip connections.

    Args:
      x: Input images, channels last (outputs are concatenated on axis -1).
        Treated as [0, 1] unless `config.data.centered`.
      labels: Timesteps for the sinusoidal embedding; also indexes `sigmas`
        when `scale_by_sigma` is set.
      config: Config object with `model.*` and `data.*` fields.
      train: Forwarded to the residual blocks.

    Returns:
      Array with `x.shape[-1]` channels. It is divided by the per-example
      sigma when `config.model.scale_by_sigma`.
    """
    model_cfg = config.model
    act = get_act(config)
    normalize = get_normalization(config)
    sigmas = utils.get_sigmas(config)
    nf = model_cfg.nf
    ch_mult = model_cfg.ch_mult
    num_res_blocks = model_cfg.num_res_blocks
    attn_resolutions = model_cfg.attn_resolutions
    dropout = model_cfg.dropout
    resamp_with_conv = model_cfg.resamp_with_conv
    num_resolutions = len(ch_mult)

    # Sinusoidal timestep embedding, then a two-layer MLP.
    temb = layers.get_timestep_embedding(labels, nf)
    temb = nn.Dense(temb, nf * 4, kernel_init=default_initializer())
    temb = nn.Dense(act(temb), nf * 4, kernel_init=default_initializer())

    AttnBlock = layers.AttnBlock.partial(normalize=normalize)
    # Residual blocks only receive the embedding when conditioning on noise
    # levels.
    ResnetBlock = ResnetBlockDDPM.partial(
        act=act,
        normalize=normalize,
        dropout=dropout,
        temb=temb if model_cfg.conditional else None,
        train=train)

    # Map [0, 1] inputs to [-1, 1] unless they are already centered.
    h = x if config.data.centered else 2 * x - 1.

    # Downsampling path: every activation is stored for the way back up.
    hs = [conv3x3(h, nf)]
    for level, multiplier in enumerate(ch_mult):
        for _ in range(num_res_blocks):
            h = ResnetBlock(hs[-1], out_ch=nf * multiplier)
            if h.shape[1] in attn_resolutions:
                h = AttnBlock(h)
            hs.append(h)
        if level != num_resolutions - 1:
            hs.append(Downsample(hs[-1], with_conv=resamp_with_conv))

    # Middle: ResNet -> Attention -> ResNet.
    h = ResnetBlock(hs[-1])
    h = AttnBlock(h)
    h = ResnetBlock(h)

    # Upsampling path: pop one stored activation per residual block.
    for level in reversed(range(num_resolutions)):
        for _ in range(num_res_blocks + 1):
            skip = hs.pop()
            h = ResnetBlock(jnp.concatenate([h, skip], axis=-1),
                            out_ch=nf * ch_mult[level])
            if h.shape[1] in attn_resolutions:
                h = AttnBlock(h)
        if level != 0:
            h = Upsample(h, with_conv=resamp_with_conv)

    assert not hs

    h = conv3x3(act(normalize(h)), x.shape[-1], init_scale=0.)

    if model_cfg.scale_by_sigma:
        # Divide the output by sigmas. Useful for training with the NCSN loss.
        # The DDPM loss scales the network output by sigma in the loss function,
        # so no need of doing it here.
        used_sigmas = sigmas[labels].reshape(
            (x.shape[0], *([1] * len(x.shape[1:]))))
        h = h / used_sigmas
    return h
def apply(self,
          x,
          *,
          num_layers,
          num_heads,
          dim_hidden=None,
          pooling=Pooling.NONE,
          is_training):
    """Transformer.

    Args:
      x: Input tensor of shape (batch, sequence_length, dim).
      num_layers: Number of layers.
      num_heads: Number of attention heads.
      dim_hidden: Dimension of the representations, default to last dimension
        of `x`.
      pooling: Optional pooling of the output tokens representations to
        obtain a single representation of the sequence.
      is_training: Whether the model is being trained. Controls attention
        determinism and is forwarded to every TransformerLayer (dropout).

    Returns:
      The sequences representations (batch, sequence_length, dim_hidden),
      with one extra leading token when pooling is CLS. With MEAN, SUM, MAX
      or CLS pooling, returns (batch, dim_hidden) instead.

    Raises:
      ValueError: if `pooling` is not one of the handled Pooling values.
    """
    pooling = Pooling(pooling)
    dim_hidden = dim_hidden or x.shape[-1]
    if dim_hidden != x.shape[-1]:
        x = nn.Dense(x, features=dim_hidden)

    if pooling == Pooling.CLS:
        # Learned token prepended to every sequence; its output row becomes
        # the pooled representation.
        cls_token = self.param("cls_token",
                               shape=(1, 1, dim_hidden),
                               initializer=jax.nn.initializers.normal(
                                   1. / dim_hidden))
        batch_size = x.shape[0]
        cls_token = cls_token.repeat(batch_size, axis=0)
        x = jnp.concatenate([cls_token, x], axis=1)

    self_attention = nn.MultiHeadDotProductAttention.partial(
        num_heads=num_heads, deterministic=not is_training, inputs_kv=None)
    # `is_training` is bound here because the stack below is invoked with
    # `x` alone. The TransformerLayer.apply in this file declares it as a
    # required keyword-only argument.
    layer = TransformerLayer.partial(self_attention_module=self_attention,
                                     dim_intermediate=2 * dim_hidden,
                                     is_training=is_training,
                                     with_aux_outputs=False)
    transformer = stacked_layers.StackedLayers.partial(
        layer=layer, num_layers=num_layers, with_aux_outputs=False)
    representations = transformer(x)

    if pooling == Pooling.NONE:
        return representations
    if pooling == Pooling.MEAN:
        return representations.mean(axis=1)
    if pooling == Pooling.SUM:
        return representations.sum(axis=1)
    if pooling == Pooling.MAX:
        return representations.max(axis=1)
    if pooling == Pooling.CLS:
        return representations[:, 0, :]
    raise ValueError(f"Unsupported pooling: {pooling}")
def apply(self, x):
    """Dense('l1') -> ReLU -> Dense('l2'), both of width `hidden_reps_dim`.

    `hidden_reps_dim` is a free name from the enclosing scope.
    """
    hidden = nn.relu(nn.Dense(x, hidden_reps_dim, bias=True, name='l1'))
    return nn.Dense(hidden, hidden_reps_dim, bias=True, name='l2')
def apply(self, x):
    """Single linear projection 'l1' to `hidden_reps_dim` features.

    `hidden_reps_dim` is a free name from the enclosing scope.
    """
    return nn.Dense(x, hidden_reps_dim, bias=True, name='l1')
def apply(self,
          inputs,
          num_outputs,
          num_filters=64,
          num_layers=50,
          dropout_rate=0.0,
          input_dropout_rate=0.0,
          train=True,
          dtype=jnp.float32,
          head_bias_init=jnp.zeros,
          return_activations=False,
          input_layer_key='input',
          has_discriminator=False,
          discriminator=False):
    """Apply a ResNet network on the input.

    Args:
      inputs: jnp array; Inputs.
      num_outputs: int; Number of output units.
      num_filters: int; Determines base number of filters. Number of filters in
        block i is num_filters * 2 ** i.
      num_layers: int; Number of layers (should be one of the predefined ones.)
      dropout_rate: float; Rate of dropping out the output of different hidden
        layers.
      input_dropout_rate: float; Rate of dropping out the input units.
      train: bool; Is train?
      dtype: jnp type; Type of the outputs.
      head_bias_init: fn(rng_key, shape)--> jnp array; Initializer for head bias
        parameters.
      return_activations: bool; If True hidden activation are also returned.
      input_layer_key: str; Where to plug in `inputs`. `inputs` are treated
        as the activations of that layer and fed to the following layers.
        Valid keys are 'input', 'init_conv', 'block_{stage}+{index}',
        'avg_pool' and 'head'. Block keys are 1-based for the stage and
        0-based for the index, and use '+' even though the block modules
        are named 'block_{stage}_{index}'.
      has_discriminator: bool; Whether the model should have discriminator
        layer.
      discriminator: bool; Whether we should return discriminator logits.

    Returns:
      Unnormalized Logits with shape `[bs, num_outputs]`,
      if return_activations:
      Logits, dict of hidden activations and the key to the representation(s)
      which will be used in as ``The Representation'', e.g., for computing
      losses.
      When `discriminator`, the discriminator logits are appended.

    Raises:
      ValueError: for an unknown `num_layers` or `input_layer_key`, or when
        `discriminator` is requested without `has_discriminator`.
    """
    if num_layers not in ResNet._block_size_options:
        raise ValueError('Please provide a valid number of layers')
    block_sizes = ResNet._block_size_options[num_layers]

    # Validate everything up front, before any layer is built. An unknown
    # key previously surfaced only as an UnboundLocalError on `x`.
    valid_input_keys = ['input', 'init_conv']
    valid_input_keys += [
        f'block_{i + 1}+{j}'
        for i, block_size in enumerate(block_sizes)
        for j in range(block_size)
    ]
    valid_input_keys += ['avg_pool', 'head']
    if input_layer_key not in valid_input_keys:
        raise ValueError(
            f'Unknown input_layer_key {input_layer_key!r}; expected one of '
            f'{valid_input_keys}')
    if discriminator and (not has_discriminator):
        raise ValueError(
            'Incosistent values passed for discriminator and has_discriminator'
        )

    layer_activations = collections.OrderedDict()
    input_is_set = False

    current_rep_key = 'input'
    if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
    if input_is_set:
        # Input dropout
        x = nn.dropout(x, input_dropout_rate, deterministic=not train)
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

    current_rep_key = 'init_conv'
    if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
    elif input_is_set:
        # First block
        x = nn.Conv(x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
                    bias=False, dtype=dtype, name='init_conv')
        x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                         epsilon=1e-5, dtype=dtype, name='init_bn')
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

    # Residual blocks
    for i, block_size in enumerate(block_sizes):
        # Stage i (each stage contains blocks of the same size).
        for j in range(block_size):
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            current_rep_key = f'block_{i + 1}+{j}'
            if input_layer_key == current_rep_key:
                x = inputs
                input_is_set = True
                layer_activations[current_rep_key] = x
                rep_key = current_rep_key
            elif input_is_set:
                x = ResidualBlock(x, num_filters * 2**i, strides=strides,
                                  dropout_rate=dropout_rate, train=train,
                                  dtype=dtype, name=f'block_{i + 1}_{j}')
                layer_activations[current_rep_key] = x
                rep_key = current_rep_key

    current_rep_key = 'avg_pool'
    if input_layer_key == current_rep_key:
        x = inputs
        input_is_set = True
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
    elif input_is_set:
        # Global Average Pool
        x = jnp.mean(x, axis=(1, 2))
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key

    # DANN module: gradient-flipped copy of the pooled features.
    if has_discriminator:
        z = dann_utils.flip_grad_identity(x)
        z = nn.Dense(z, 2, name='disc_l1', bias=True)
        z = nn.relu(z)
        z = nn.Dense(z, 2, name='disc_l2', bias=True)

    current_rep_key = 'head'
    if input_layer_key == current_rep_key:
        x = inputs
        layer_activations[current_rep_key] = x
        rep_key = current_rep_key
        logging.warning('Input was never used')
    elif input_is_set:
        x = nn.Dense(x, num_outputs, dtype=dtype, bias_init=head_bias_init,
                     name='head')

    # Make sure that the output is float32, even if our previous computations
    # are in float16, or other types.
    x = jnp.asarray(x, jnp.float32)

    outputs = x
    if return_activations:
        outputs = (x, layer_activations, rep_key)
        if discriminator and has_discriminator:
            outputs = outputs + (z,)
    else:
        del layer_activations
        if discriminator and has_discriminator:
            outputs = (x, z)
    return outputs
def apply(self,
          x,
          *,
          patch_size,
          k,
          downscale,
          scorer_has_se,
          normalization_str="identity",
          selection_method,
          selection_method_kwargs=None,
          selection_method_inference=None,
          patch_dropout=0.,
          hard_topk_probability=0.,
          random_patch_probability=0.,
          use_iterative_extraction,
          append_position_to_input,
          feature_network,
          aggregation_method,
          aggregation_method_kwargs=None,
          train):
  """Process a high resolution image by selecting a subset of useful patches.

  This model processes the input as follows:
  1. Compute scores per patch on a downscaled version of the input.
  2. Select "important" patches using sampling or top-k methods.
  3. Extract the patches from the high-resolution image.
  4. Compute representation vector for each patch with a feature network.
  5. Aggregate the patch representation to obtain an image representation.

  Args:
    x: Input tensor of shape (batch, height, width, channels).
    patch_size: Size of the (squared) patches to extract.
    k: Number of patches to extract per image.
    downscale: Downscale multiplier for the input of the scorer network.
    scorer_has_se: Whether scorer network has Squeeze-excite layers.
    normalization_str: String specifying the normalization of the scores.
      If it contains "softmax", the normalized scores are used directly as
      probabilities for the entropy statistic.
    selection_method: Method that selects which patches should be extracted,
      based on their scores. Either returns indices (hard selection) or
      indicators vectors (which could yield interpolated patches).
    selection_method_kwargs: Keyword args for the selection_method. For
      PERTURBED_TOPK it must contain "sigma" and "num_samples".
    selection_method_inference: Selection method used at inference
      (`train=False`); overrides `selection_method` after scoring.
    patch_dropout: Probability to replace a patch by 0 values. Only forwarded
      to indicator-based extraction; ignored when patches are extracted by
      indices (HARD_TOPK / RANDOM).
    hard_topk_probability: Probability to use the true topk on the scores to
      select the patches. This operation has no gradient so scorer's weights
      won't be trained.
    random_patch_probability: Probability to replace each patch by a random
      patch in the image during training.
    use_iterative_extraction: If True, uses a for loop instead of patch
      indexing for memory efficiency.
    append_position_to_input: Append normalized (height, width) position to
      the channels of the input.
    feature_network: Network to be applied on each patch individually to
      obtain patch representation vectors.
    aggregation_method: Method to aggregate the representations of the k
      patches of each image to obtain the image representation.
    aggregation_method_kwargs: Keyword arguments for aggregation_method.
    train: If the model is being trained. Disable dropout otherwise.

  Returns:
    A tuple `(representations, stats)`: a representation vector for each
    image in the batch, and a dict of diagnostics (scores, entropies,
    extracted patches, patch representations, and sigma for PERTURBED_TOPK).
  """
  selection_method = SelectionMethod(selection_method)
  aggregation_method = AggregationMethod(aggregation_method)
  if selection_method_inference:
    selection_method_inference = SelectionMethod(
        selection_method_inference)

  selection_method_kwargs = selection_method_kwargs or {}
  aggregation_method_kwargs = aggregation_method_kwargs or {}

  stats = {}

  # Compute new dimension of the scoring image.
  b, h, w, c = x.shape
  scoring_shape = (b, h // downscale, w // downscale, c)

  # === Compute the scores with a small CNN.
  if selection_method == SelectionMethod.RANDOM:
    # RANDOM needs no scores, only the size of the patch grid.
    # NOTE(review): `flatten_scores` is never defined on this path, so a
    # score-based `selection_method_inference` combined with RANDOM training
    # would fail with a NameError at inference — confirm this is unsupported.
    scores_h, scores_w = Scorer.compute_output_size(
        h // downscale, w // downscale)
    num_patches = scores_h * scores_w
  else:
    # Downscale input to run scorer on.
    scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
    scores = Scorer(scoring_x, use_squeeze_excite=scorer_has_se,
                    name="scorer")
    flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
    num_patches = flatten_scores.shape[-1]
    scores_h, scores_w = scores.shape[1:3]

    # Compute entropy before normalization
    prob_scores = jax.nn.softmax(flatten_scores)
    stats["entropy_before_normalization"] = jax.scipy.special.entr(
        prob_scores).sum(axis=1).mean(axis=0)

    # Normalize the flatten scores
    normalization_fn = create_normalization_fn(normalization_str)
    flatten_scores = normalization_fn(flatten_scores)
    scores = flatten_scores.reshape(scores.shape)
    stats["scores"] = scores[Ellipsis, None]

  # Concatenate height and width position to the input channels.
  # This happens after scoring, so the scorer never sees these channels.
  if append_position_to_input:
    coords = utils.create_grid([h, w], value_range=(0., 1.))
    x = jnp.concatenate(
        [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
    c += 2

  # Overwrite the selection method at inference
  if selection_method_inference and not train:
    selection_method = selection_method_inference

  # === Patch selection

  # Select the patches by sampling or top-k. Some methods return the indices
  # of the selected patches, other methods return indicator vectors.
  extract_by_indices = selection_method in [
      SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
  ]
  if selection_method is SelectionMethod.SINKHORN_TOPK:
    indicators = select_patches_sinkhorn_topk(
        flatten_scores, k=k, **selection_method_kwargs)
  elif selection_method is SelectionMethod.PERTURBED_TOPK:
    sigma = selection_method_kwargs["sigma"]
    num_samples = selection_method_kwargs["num_samples"]
    # Learnable multiplier held in module state (initialized to 1); the
    # "mutiplier" spelling is the state key, so it is left unchanged.
    sigma *= self.state("sigma_mutiplier", shape=(),
                        initializer=nn.initializers.ones).value
    stats["sigma"] = sigma
    indicators = select_patches_perturbed_topk(flatten_scores,
                                               k=k,
                                               sigma=sigma,
                                               num_samples=num_samples)
  elif selection_method is SelectionMethod.HARD_TOPK:
    indices = select_patches_hard_topk(flatten_scores, k=k)
  elif selection_method is SelectionMethod.RANDOM:
    # k distinct patch indices per example, one RNG key per example.
    batch_random_indices_fn = jax.vmap(
        functools.partial(jax.random.choice,
                          a=num_patches,
                          shape=(k, ),
                          replace=False))
    indices = batch_random_indices_fn(
        jax.random.split(nn.make_rng(), b))

  # Compute scores entropy for regularization
  if selection_method not in [SelectionMethod.RANDOM]:
    prob_scores = flatten_scores
    # Normalize the scores if it is not already done.
    if "softmax" not in normalization_str:
      prob_scores = jax.nn.softmax(prob_scores)
    stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
        axis=1).mean(axis=0)

  # Randomly use hard topk at training.
  if (train and hard_topk_probability > 0 and
      selection_method not in [SelectionMethod.HARD_TOPK,
                               SelectionMethod.RANDOM]):
    true_indices = select_patches_hard_topk(flatten_scores, k=k)
    # One Bernoulli draw per example decides hard top-k vs. the method's own
    # selection.
    random_values = jax.random.uniform(nn.make_rng(), (b, ))
    use_hard = random_values < hard_topk_probability
    # NOTE(review): the guard above excludes HARD_TOPK and RANDOM, which are
    # exactly the methods for which `extract_by_indices` is True, so this
    # first branch is never taken.
    if extract_by_indices:
      indices = jnp.where(use_hard[:, None], true_indices, indices)
    else:
      true_indicators = make_indicators(true_indices, num_patches)
      indicators = jnp.where(use_hard[:, None, None], true_indicators,
                             indicators)

  # Sample some random patches during training with random_patch_probability.
  if (train and random_patch_probability > 0 and
      selection_method is not SelectionMethod.RANDOM):
    single_random_patches = functools.partial(jax.random.choice,
                                              a=num_patches,
                                              shape=(k, ),
                                              replace=False)
    random_indices = jax.vmap(single_random_patches)(jax.random.split(
        nn.make_rng(), b))
    # Independent draw per (example, patch slot).
    random_values = jax.random.uniform(nn.make_rng(), (b, k))
    use_random = random_values < random_patch_probability
    if extract_by_indices:
      # Only reachable for HARD_TOPK here (RANDOM is excluded above).
      indices = jnp.where(use_random, random_indices, indices)
    else:
      random_indicators = make_indicators(random_indices, num_patches)
      indicators = jnp.where(use_random[:, None, :], random_indicators,
                             indicators)

  # === Patch extraction
  if extract_by_indices:
    patches = extract_patches_from_indices(x,
                                           indices,
                                           patch_size=patch_size,
                                           grid_shape=(scores_h, scores_w))
    # Indicators are still needed by the TRANSFORMER aggregation below.
    indicators = make_indicators(indices, num_patches)
  else:
    patches = extract_patches_from_indicators(
        x,
        indicators,
        patch_size,
        grid_shape=(scores_h, scores_w),
        iterative=use_iterative_extraction,
        patch_dropout=patch_dropout,
        train=train)

  chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

  # Patches laid side by side along the width axis, for visualization.
  stats["extracted_patches"] = einops.rearrange(
      patches, "b k i j c -> b i (k j) c")
  # Remove position channels for plotting.
  if append_position_to_input:
    stats["extracted_patches"] = (
        stats["extracted_patches"][Ellipsis, :-2])

  # === Compute patch features
  flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
  representations = feature_network(flatten_patches, train=train)
  # Average over any axes between the batch and feature axes, to get one
  # vector per patch.
  if representations.ndim > 2:
    collapse_axis = tuple(range(1, representations.ndim - 1))
    representations = representations.mean(axis=collapse_axis)
  representations = einops.rearrange(representations,
                                     "(b k) d -> b k d",
                                     k=k)

  stats["patch_representations"] = representations

  # === Aggregate the k patches

  # - for sampling we are forced to take an expectation
  # - for topk we have multiple options: mean, max, transformer.
  if aggregation_method is AggregationMethod.TRANSFORMER:
    # A Dense projection of each patch's indicator vector serves as its
    # position encoding.
    patch_pos_encoding = nn.Dense(einops.rearrange(
        indicators, "b d k -> b k d"),
                                  features=representations.shape[-1])

    chex.assert_equal_shape([representations, patch_pos_encoding])
    representations += patch_pos_encoding
    representations = transformer.Transformer(
        representations, **aggregation_method_kwargs, is_training=train)

  elif aggregation_method is AggregationMethod.MEANPOOLING:
    representations = representations.mean(axis=1)
  elif aggregation_method is AggregationMethod.MAXPOOLING:
    representations = representations.max(axis=1)
  elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
    representations = representations.sum(axis=1)
    representations = nn.LayerNorm(representations)

  representations = nn.Dense(representations,
                             features=representations.shape[-1],
                             name="classification_dense1")
  representations = nn.swish(representations)

  return representations, stats
def apply(self, x, config, num_classes, train=True):
  """Builds the classifier selected by `config.model` and applies it to `x`.

  Args:
    x: Input batch. When `config.append_position_to_input` is set, it is
      unpacked as (batch, height, width, channels) and two normalized
      position channels are concatenated on the last axis.
    config: Experiment config. `config.model` (case-insensitive) picks the
      backbone: "cnn", "resnet", "resnet18", "resnet50", "ats-traffic" or
      "patchnet". The other fields read depend on that choice.
    num_classes: Number of units of the final Dense layer ("final").
    train: Training-mode flag forwarded to the backbones that accept it.

  Returns:
    `(logits, stats)`. `stats` is None for every backbone except "patchnet",
    where it is the dict returned by `sample_patches.PatchNet` with the
    (possibly position-augmented) input stored under "x".

  Raises:
    RuntimeError: If `config.model` is not a supported model name.
  """
  model_name = config.model.lower()

  if config.get("append_position_to_input", False):
    batch, height, width, _ = x.shape
    grid = utils.create_grid([height, width], value_range=(0., 1.))
    grid = grid[jnp.newaxis, Ellipsis].repeat(batch, axis=0)
    # NOTE(review): for "patchnet", PatchNet also receives
    # append_position_to_input from this same config key and appends the
    # position channels a second time, while its plotting code strips only
    # the last 2 channels. Confirm that this double append is intended.
    x = jnp.concatenate([x, grid], axis=-1)

  stats = None
  if model_name == "cnn":
    hidden = nn.relu(models.SimpleCNNImageClassifier(x))
  elif model_name == "resnet":
    small_inputs = config.get("resnet.small_inputs", False)
    block_sizes = config.get("resnet.blocks", [3, 4, 6, 3])
    feature_map = models.ResNet(x,
                                train=train,
                                block_sizes=block_sizes,
                                small_inputs=small_inputs)
    hidden = jnp.mean(feature_map, axis=[1, 2])  # Global average pool.
  elif model_name in ("resnet18", "resnet50"):
    backbone = (models.ResNet18 if model_name == "resnet18"
                else models.ResNet50)
    hidden = jnp.mean(backbone(x, train=train), axis=[1, 2])  # Global pool.
  elif model_name == "ats-traffic":
    hidden = models.ATSFeatureNetwork(x, train=train)
  elif model_name == "patchnet":
    feature_network = {
        "resnet18": models.ResNet18,
        "resnet18-fourth": models.ResNet.partial(
            num_filters=16,
            block_sizes=(2, 2, 2, 2),
            block=models.BasicBlock),
        "resnet50": models.ResNet50,
        "ats-traffic": models.ATSFeatureNetwork,
    }[config.feature_network.lower()]

    # Only the two differentiable top-k methods take extra keyword args.
    method = sample_patches.SelectionMethod(config.selection_method)
    if method is sample_patches.SelectionMethod.SINKHORN_TOPK:
      selection_method_kwargs = config.sinkhorn_topk_kwargs
    elif method is sample_patches.SelectionMethod.PERTURBED_TOPK:
      selection_method_kwargs = config.perturbed_topk_kwargs
    else:
      selection_method_kwargs = {}

    hidden, stats = sample_patches.PatchNet(
        x,
        patch_size=config.patch_size,
        k=config.k,
        downscale=config.downscale,
        scorer_has_se=config.get("scorer_has_se", False),
        selection_method=config.selection_method,
        selection_method_kwargs=selection_method_kwargs,
        selection_method_inference=config.get(
            "selection_method_inference", None),
        normalization_str=config.normalization_str,
        aggregation_method=config.aggregation_method,
        aggregation_method_kwargs=config.get(
            "aggregation_method_kwargs", {}),
        append_position_to_input=config.get(
            "append_position_to_input", False),
        feature_network=feature_network,
        use_iterative_extraction=config.use_iterative_extraction,
        hard_topk_probability=config.get("hard_topk_probability", 0.),
        random_patch_probability=config.get(
            "random_patch_probability", 0.),
        train=train)
    stats["x"] = x
  else:
    raise RuntimeError("Unknown classification model type: %s" % model_name)

  return nn.Dense(hidden, num_classes, name="final"), stats
def apply(self, x, num_outputs, pyramid_alpha=200, pyramid_depth=272,
          train=True):
  """Applies a bottleneck PyramidNet with ShakeDrop regularization.

  The network has three groups of `(pyramid_depth - 2) // 9`
  BottleneckShakeDrop blocks. Width starts at 16 and grows by
  `pyramid_alpha / total_blocks` before every block (truncated with `int`).
  The first block of the second and third groups uses stride 2.

  Args:
    x: Input image batch.
    num_outputs: Number of units of the final Dense layer.
    pyramid_alpha: Total width added across all blocks.
    pyramid_depth: Network depth; `pyramid_depth - 2` must be divisible by 9.
    train: Training-mode flag for BatchNorm and ShakeDrop.

  Returns:
    Output of the final Dense layer, shaped (batch, num_outputs).
  """
  assert (pyramid_depth - 2) % 9 == 0

  # ShakeDrop hyper-parameters.
  mask_prob = 0.5
  alpha_min, alpha_max = -1.0, 1.0
  beta_min, beta_max = 0.0, 1.0

  # Widening rule: see Eqn 2 in https://arxiv.org/abs/1610.02915.
  num_groups = 3
  blocks_per_group = (pyramid_depth - 2) // 9
  total_blocks = blocks_per_group * num_groups  # N in the paper.
  channel_step = pyramid_alpha / total_blocks
  width = 16

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = nn.BatchNorm(x,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   name='init_bn')

  layer_num = 0
  for group_idx in range(num_groups):
    for block_idx in range(blocks_per_group):
      layer_num += 1
      width += channel_step
      # Downsample on entry to every group except the first.
      block_strides = (2, 2) if group_idx and not block_idx else (1, 1)
      x = BottleneckShakeDrop(
          x, int(width), block_strides,
          _calc_shakedrop_mask_prob(layer_num, total_blocks, mask_prob),
          alpha_min, alpha_max, beta_min, beta_max,
          train=train)
  assert layer_num == total_blocks

  x = nn.BatchNorm(x,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   name='final_bn')
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  return nn.Dense(x, num_outputs)
def apply(
    self,
    inputs,
    blocks_per_group,
    channel_multiplier,
    num_outputs,
    kernel_size=(3, 3),
    strides=None,
    maxpool=False,
    dropout_rate=0.0,
    dtype=jnp.float32,
    norm_layer='group_norm',
    train=True,
    return_activations=False,
    input_layer_key='input',
    has_discriminator=False,
    discriminator=False,
):
  """Applies a Wide ResNet, optionally entering it at an intermediate layer.

  The network is a fixed sequence of named stages: 'input', 'init_conv',
  'l1', 'l2', 'l3', 'l4' and finally 'head'. `inputs` is injected at the
  stage named by `input_layer_key`. That stage's own computation is skipped,
  every later stage is computed, and earlier stages are not built.

  Args:
    inputs: Network input, or the activation of stage `input_layer_key`.
    blocks_per_group: Number of blocks in each of the three WideResnetGroups.
    channel_multiplier: Width multiplier; the groups have 16x, 32x and 64x
      this many channels.
    num_outputs: Number of units of the 'head' Dense layer.
    kernel_size: Kernel size of the initial convolution.
    strides: Strides of the initial convolution (forwarded to nn.Conv).
    maxpool: Whether to apply a 3x3, stride-2 max-pool after 'init_conv'.
    dropout_rate: Dropout rate forwarded to every WideResnetGroup.
    dtype: dtype of the 'head' Dense layer.
    norm_layer: 'batch_norm', 'group_norm' (16 groups), or a non-string
      normalization module, which is used as-is (its final layer is then
      named '').
    train: Training-mode flag; also selects BatchNorm running averages.
    return_activations: If True, also return the recorded stage activations
      and the key of the last recorded stage.
    input_layer_key: Name of the stage at which `inputs` is injected.
    has_discriminator: Whether to build the DANN discriminator on the
      features leaving 'l4'.
    discriminator: Whether to return the discriminator output.

  Returns:
    The head output `x`, or `(x, z)` when `discriminator` is set. If
    `return_activations` is True, returns `(x, activations, rep_key)`, with
    `z` appended when `discriminator` is set.

  Raises:
    ValueError: If `discriminator` is set without `has_discriminator`; if
      `norm_layer` is an unknown string; or if `has_discriminator` is set but
      `input_layer_key` names no stage before the head.
  """
  # Validate cheap argument combinations before building any layer.
  if discriminator and not has_discriminator:
    raise ValueError(
        'Incosistent values passed for discriminator and has_discriminator')

  norm_layer_name = ''
  if norm_layer == 'batch_norm':
    norm_layer = nn.BatchNorm.partial(use_running_average=not train)
    norm_layer_name = 'bn'
  elif norm_layer == 'group_norm':
    norm_layer = nn.GroupNorm.partial(num_groups=16)
    norm_layer_name = 'gn'
  elif isinstance(norm_layer, str):
    # Otherwise the string reaches WideResnetGroup and fails with an opaque
    # "'str' object is not callable".
    raise ValueError(f'Unknown norm_layer: {norm_layer!r}')

  def _init_conv(h):
    """Stem: convolution followed by an optional max-pool."""
    h = nn.Conv(
        h,
        16,
        kernel_size=kernel_size,
        strides=strides,
        padding='SAME',
        name='init_conv')
    if maxpool:
      h = nn.max_pool(h, (3, 3), strides=(2, 2), padding='SAME')
    return h

  def _group(width, name, *group_strides):
    """Returns a stage applying one WideResnetGroup of `width` x multiplier."""
    # 'l1' passes no strides (WideResnetGroup's default applies); 'l2' and
    # 'l3' pass (2, 2).
    return lambda h: WideResnetGroup(
        h,
        blocks_per_group,
        width * channel_multiplier,
        *group_strides,
        dropout_rate=dropout_rate,
        norm_layer=norm_layer,
        train=train,
        name=name)

  def _pool_features(h):
    """Final normalization, ReLU, 8x8 average pool and flatten."""
    h = norm_layer(h, name=norm_layer_name)
    h = jax.nn.relu(h)
    h = nn.avg_pool(h, (8, 8))
    return h.reshape((h.shape[0], -1))

  # (stage name, stage computation), in network order. 'input' has no
  # computation of its own: it can only be the injection point.
  stages = (
      ('input', None),
      ('init_conv', _init_conv),
      ('l1', _group(16, 'l1')),
      ('l2', _group(32, 'l2', (2, 2))),
      ('l3', _group(64, 'l3', (2, 2))),
      ('l4', _pool_features),
  )

  layer_activations = collections.OrderedDict()
  input_is_set = False
  rep_key = None
  for key, stage_fn in stages:
    if input_is_set:
      x = stage_fn(x)
    elif input_layer_key == key:
      x = inputs
      input_is_set = True
    else:
      continue  # Stage lies before the injection point.
    layer_activations[key] = x
    rep_key = key

  # DANN module.
  z = None
  if has_discriminator:
    if not input_is_set:
      raise ValueError(
          'has_discriminator requires input_layer_key to name a layer '
          'before the head')
    # flip_grad_identity presumably reverses gradients (DANN); confirm in
    # dann_utils.
    z = dann_utils.flip_grad_identity(x)
    z = nn.Dense(z, 2, name='disc_l1', bias=True)
    z = nn.relu(z)
    z = nn.Dense(z, 2, name='disc_l2', bias=True)

  if input_is_set:
    x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
  else:
    x = inputs
    layer_activations['head'] = x
    rep_key = 'head'
    logging.warning('Input was never used')

  with_discriminator = discriminator and has_discriminator
  if return_activations:
    outputs = (x, layer_activations, rep_key)
    if with_discriminator:
      outputs += (z,)
    return outputs
  if with_discriminator:
    return x, z
  return x
def apply(self, x, communication=Communication.NONE, train=True):
  """Forward pass: produces normalized anchor score maps at three scales.

  Three 3x3, 128-channel ReLU conv levels ("down1" stride 1, "down2" and
  "down3" stride 2) are built on `x`. Depending on `communication`, the
  levels can exchange information:
  - a squeeze-excite layer on the input;
  - a joint squeeze-excite gate over all levels;
  - a transformer over all level positions.
  1x1 "tidy" convs then emit 6, 6 and 9 score channels. Each channel is one
  score map, and all 21 are rescaled with per-example min/max via `zeroone`.

  Args:
    x: Input feature map; assumed (batch, height, width, channels) given the
      "b h w c" rearranges below — TODO confirm with callers.
    communication: A `Communication` member selecting cross-scale mixing.
    train: Forwarded to the transformer as `is_training`.

  Returns:
    `(scores, stats)`. `scores` is a list of 21 normalized score maps with
    the trailing channel axis squeezed out. `stats["scores"]` holds the same
    maps before squeezing; `stats["raw_scores"]` holds the un-normalized
    tidy outputs flattened to (batch, -1).
  """
  batch_size = x.shape[0]

  if communication is Communication.SQUEEZE_EXCITE_X:
    x = sample_patches.SqueezeExciteLayer(x)

  levels = []
  features = x
  for conv_name, conv_strides in (("down1", (1, 1)), ("down2", (2, 2)),
                                  ("down3", (2, 2))):
    features = nn.relu(
        nn.Conv(features, 128, kernel_size=(3, 3), strides=conv_strides,
                bias=True, name=conv_name))
    levels.append(features)

  def _merge(maps):
    """Flattens each level spatially and concatenates along positions."""
    return jnp.concatenate(
        [einops.rearrange(m, "b h w c -> b (h w) c") for m in maps], axis=1)

  def _unmerge(merged, like):
    """Inverse of `_merge`: slices positions back into the shapes of `like`."""
    pieces = []
    start = 0
    for ref in like:
      stop = start + ref.shape[1] * ref.shape[2]
      pieces.append(merged[:, start:stop].reshape(ref.shape))
      start = stop
    return pieces

  if communication is Communication.SQUEEZE_EXCITE_D:
    merged = _merge(levels)
    num_channels = merged.shape[-1]
    gate = merged.mean(axis=1)
    gate = nn.relu(nn.Dense(gate, features=num_channels // 4, bias=False))
    gate = nn.sigmoid(nn.Dense(gate, features=num_channels, bias=False))
    levels = _unmerge(merged * gate[:, None, :], levels)
  elif communication is Communication.TRANSFORMER:
    merged = _merge(levels)
    positional_encodings = self.param(
        "scale_ratio_position_encodings",
        shape=(1, ) + merged.shape[1:],
        initializer=jax.nn.initializers.normal(1. / merged.shape[-1]))
    merged = transformer.Transformer(merged + positional_encodings,
                                     num_layers=2,
                                     num_heads=8,
                                     is_training=train)
    levels = _unmerge(merged, levels)

  tidy = [
      nn.Conv(level, out_channels, kernel_size=(1, 1), strides=(1, 1),
              bias=True, name=tidy_name)
      for level, out_channels, tidy_name in zip(
          levels, (6, 6, 9), ("tidy1", "tidy2", "tidy3"))
  ]
  tidy1, tidy2, tidy3 = tidy

  # One single-channel score map per output channel.
  raw_scores = (jnp.split(tidy1, 6, axis=-1) + jnp.split(tidy2, 6, axis=-1) +
                jnp.split(tidy3, 9, axis=-1))

  # Per-example min/max over all scores, used for normalization.
  flat_scores = jnp.concatenate(
      [jnp.reshape(t_map, [batch_size, -1]) for t_map in tidy], axis=1)
  t_min = jnp.reshape(jnp.min(flat_scores, axis=-1), [batch_size, 1, 1, 1])
  t_max = jnp.reshape(jnp.max(flat_scores, axis=-1), [batch_size, 1, 1, 1])
  normalized_scores = zeroone(raw_scores, t_min, t_max)

  stats = {
      "scores": normalized_scores,
      "raw_scores": flat_scores,
  }
  # Drop the split channel axis from each score map.
  return [s.squeeze(-1) for s in normalized_scores], stats
def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          causal=True,
          cache=None,
          positional_encoding_module=AddLearnedPositionalEncodings,
          self_attention_module=nn.SelfAttention,
          attention_fn=None,
          pad_token=None,
          output_head='logits'):
  """Runs the Transformer language model and returns the requested heads.

  Args:
    inputs: Either int tokens of shape (batch_size, length), or next-token
      distributions over the vocabulary of shape
      (batch_size, length, vocab_size), e.g. one-hot vectors.
    vocab_size: Size of the vocabulary.
    emb_dim: Token embedding dimension.
    num_heads: Number of attention heads.
    num_layers: Number of Transformer1DBlock layers.
    qkv_dim: Dimension of the query/key/value vectors.
    mlp_dim: Inner dimension of each block's feed-forward network.
    max_len: Maximum sequence length for the positional encodings.
    train: Whether we are training (enables dropout).
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout rate on attention weights.
    causal: Whether to apply causal masking.
    cache: Cache for decoding.
    positional_encoding_module: Module adding positional encodings.
      `AddLearnedPositionalEncodings` additionally receives `cache` and a
      sinusoidal initializer.
    self_attention_module: Self-attention module used inside each block.
    attention_fn: Replacement for dot-product attention.
    pad_token: Token id to mask out of attention. Only used for 2-D inputs.
    output_head: A head name, or a tuple/list of head names. Available heads
      are 'input_emb', 'layer_<i>', 'output_emb', and — when requested —
      'logits' and 'regression'.

  Returns:
    The single requested head's output if `output_head` is a string.
    Otherwise, a dict mapping each requested head name to its output.

  Raises:
    ValueError: If `inputs` is neither 2- nor 3-dimensional.
  """
  if inputs.ndim not in (2, 3):
    raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

  # Attention padding mask with a trailing feature axis.
  if inputs.ndim == 3:
    padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
  elif pad_token is not None:
    padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
  else:
    padding_mask = jnp.ones_like(inputs)
  padding_mask = padding_mask[Ellipsis, None]

  # Token ids are cast to int32; distributions are embedded unchanged.
  token_input = inputs.astype('int32') if inputs.ndim == 2 else inputs
  x = Embed(token_input,
            num_embeddings=vocab_size,
            num_features=emb_dim,
            name='embed')
  if positional_encoding_module != AddLearnedPositionalEncodings:
    x = positional_encoding_module(x, max_len=max_len)
  else:
    x = positional_encoding_module(
        x,
        max_len=max_len,
        cache=cache,
        posemb_init=sinusoidal_init(max_len=max_len))
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

  heads = {'input_emb': x}
  for layer_idx in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=causal,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        self_attention_module=self_attention_module,
        deterministic=not train,
        attention_fn=attention_fn,
        cache=cache,
    )
    heads[f'layer_{layer_idx}'] = x

  x = nn.LayerNorm(x)
  heads['output_emb'] = x * padding_mask  # Zero out PAD positions.

  def _dense_head(features):
    """Dense projection of the (unmasked) final embeddings."""
    return nn.Dense(x,
                    features,
                    kernel_init=nn.initializers.xavier_uniform(),
                    bias_init=nn.initializers.normal(stddev=1e-6))

  # `in` works for both a single head name and a collection of names.
  if 'logits' in output_head:
    heads['logits'] = _dense_head(vocab_size)
  if 'regression' in output_head:
    heads['regression'] = jnp.squeeze(_dense_head(1), axis=-1)

  if not isinstance(output_head, (tuple, list)):
    return heads[output_head]
  return {name: heads[name] for name in output_head}
def apply(self, x, config, num_classes, train=True):
  """Applies the part-based classifier built around one shared ResNet50.

  Steps:
  1. A shared ResNet50 trunk embeds the whole image.
  2. A ProposalNet scores anchors on the stop-gradient trunk features.
  3. Perturbed top-k picks `k` anchors; `weighted_anchor_aggregator` turns
     them into 224x224 part crops.
  4. The same trunk embeds the crops.
  5. Three logit sets are produced: whole image, each part, and the mean
     part feature plus the whole-image feature.

  Args:
    x: Image batch; channels are read from axis 3, so presumably
      (batch, height, width, channels) — TODO confirm with callers.
    config: Config providing `k`, `ptopk_sigma`, `ptopk_num_samples` and
      `communication`.
    num_classes: Number of logits produced by each head.
    train: Enables dropout and training mode in the ResNet50 trunk.

  Returns:
    `(all_logits, stats)`. `all_logits` has keys "raw_logits",
    "concat_logits" and "part_logits"; the last is shaped
    (batch, k, num_classes). `stats` holds:
    - the input and the scaled sigma;
    - the ProposalNet stats;
    - the part crops, stacked along the height axis;
    - the entropy of the proposal scores.
  """
  b, c = x.shape[0], x.shape[3]
  k = config.k
  sigma = config.ptopk_sigma
  num_samples = config.ptopk_num_samples

  # Learnable multiplier held in module state (initialized to 1). The
  # "mutiplier" spelling is the state key and is kept as-is.
  sigma *= self.state("sigma_mutiplier", shape=(),
                      initializer=nn.initializers.ones).value

  stats = {"x": x, "sigma": sigma}

  # A single ResNet50 instance, reused for the whole image and the parts.
  feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

  rpn_feature = feature_extractor(x)
  # Proposal scoring does not back-propagate into the shared trunk.
  rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                      communication=Communication(
                                          config.communication),
                                      train=train)
  stats.update(rpn_stats)

  # rpn_scores is a list of score images. We keep track of the structure
  # because it is used in the aggregation step later on.
  rpn_scores_shapes = [s.shape for s in rpn_scores]
  rpn_scores_flat = jnp.concatenate(
      [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
  top_k_indicators = sample_patches.select_patches_perturbed_topk(
      rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
  top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])

  # Cut the flat indicators back into one (b, k, h, w) map per score image.
  offset = 0
  weights = []
  for sh in rpn_scores_shapes:
    size = sh[1] * sh[2]
    cur = top_k_indicators[:, :, offset:offset + size]
    weights.append(jnp.reshape(cur, [b, k, sh[1], sh[2]]))
    offset += size
  chex.assert_equal(offset, top_k_indicators.shape[-1])

  part_imgs = weighted_anchor_aggregator(x, weights)
  chex.assert_shape(part_imgs, (b * k, 224, 224, c))
  stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

  part_features = feature_extractor(part_imgs)
  part_features = jnp.mean(part_features, axis=[1, 2])  # GAP the spatial dims
  # Take the width from the trunk instead of hard-coding ResNet50's 2048.
  feature_dim = part_features.shape[-1]

  # nn.dropout is left to draw its own RNG, as the other dropout calls do.
  # The old eager `rng=nn.make_rng()` requested an RNG even when dropout was
  # deterministic (inference).
  part_features = nn.dropout(  # features from parts
      jnp.reshape(part_features, [b * k, feature_dim]),
      0.5,
      deterministic=not train)
  features = nn.dropout(  # features from whole image
      jnp.reshape(jnp.mean(rpn_feature, axis=[1, 2]), [b, -1]),
      0.5,
      deterministic=not train)

  # Mean pool all part features, add it to features and predict logits.
  concat_out = jnp.mean(
      jnp.reshape(part_features, [b, k, feature_dim]), axis=1) + features
  concat_logits = nn.Dense(concat_out, num_classes)
  raw_logits = nn.Dense(features, num_classes)
  part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

  all_logits = {
      "raw_logits": raw_logits,
      "concat_logits": concat_logits,
      "part_logits": part_logits,
  }
  # Entropy of the proposal scores, for entropy regularization.
  stats["rpn_scores_entropy"] = jax.scipy.special.entr(
      jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)

  return all_logits, stats