def apply(self, input_ids, input_mask, type_ids, *, config, deterministic=False): """Applies BERT model on the inputs.""" word_embeddings = nn.Embed(input_ids, num_embeddings=config.vocab_size, features=config.d_emb, embedding_init=kernel_initializer, name="word_embeddings") position_embeddings = layers.PositionalEncoding( word_embeddings, max_len=config.max_len, posemb_init=kernel_initializer, name="position_embeddings") type_embeddings = nn.Embed(type_ids, num_embeddings=config.type_vocab_size, features=config.d_emb, embedding_init=kernel_initializer, name="type_embeddings") embeddings = word_embeddings + position_embeddings + type_embeddings embeddings = nn.LayerNorm(embeddings, epsilon=LAYER_NORM_EPSILON, name="embeddings_layer_norm") embeddings = nn.Dense(embeddings, config.d_model, name="embedding_hidden_mapping_in") embeddings = nn.dropout(embeddings, rate=config.dropout_rate, deterministic=deterministic) # Transformer blocks feed_forward = layers.FeedForward.partial( d_ff=config.d_ff, dropout_rate=config.dropout_rate, intermediate_activation=hidden_activation, kernel_init=kernel_initializer) self_attention = efficient_attention.BertSelfAttention.partial( num_heads=config.num_heads, num_parallel_heads=config.num_parallel_heads, d_qkv=config.d_model // config.num_heads, attention_dropout_rate=config.attention_dropout_rate, output_dropout_rate=config.dropout_rate, kernel_init=kernel_initializer, output_kernel_init=kernel_initializer) hidden_states = embeddings mask = input_mask.astype(jnp.int32) shared_encoder_layer = layers.TransformerBlock.shared( feed_forward=feed_forward, attention=self_attention, deterministic=deterministic, name="encoder_layer_0") for _ in range(config.num_layers): hidden_states = shared_encoder_layer(hidden_states, mask) pooled_output = nn.Dense(hidden_states[:, 0], config.d_model, kernel_init=kernel_initializer, name="pooler") pooled_output = jnp.tanh(pooled_output) return hidden_states, pooled_output
def apply(self, input_ids, input_mask, type_ids, *, config, deterministic=False): """Applies BERT model on the inputs.""" word_embeddings = nn.Embed(input_ids, num_embeddings=config.vocab_size, features=config.hidden_size, embedding_init=get_kernel_init(config), name='word_embeddings') position_embeddings = layers.PositionalEncoding( word_embeddings, max_len=config.max_position_embeddings, posemb_init=get_kernel_init(config), name='position_embeddings') type_embeddings = nn.Embed(type_ids, num_embeddings=config.type_vocab_size, features=config.hidden_size, embedding_init=get_kernel_init(config), name='type_embeddings') embeddings = word_embeddings + position_embeddings + type_embeddings embeddings = nn.LayerNorm(embeddings, epsilon=LAYER_NORM_EPSILON, name='embeddings_layer_norm') embeddings = nn.dropout(embeddings, rate=config.hidden_dropout_prob, deterministic=deterministic) # Transformer blocks feed_forward = layers.FeedForward.partial( d_ff=config.intermediate_size, dropout_rate=config.hidden_dropout_prob, intermediate_activation=get_hidden_activation(config), kernel_init=get_kernel_init(config)) attention = efficient_attention.BertSelfAttention.partial( num_heads=config.num_attention_heads, num_parallel_heads=None, d_qkv=config.hidden_size // config.num_attention_heads, attention_dropout_rate=config.attention_probs_dropout_prob, output_dropout_rate=config.hidden_dropout_prob, kernel_init=get_kernel_init(config), output_kernel_init=get_kernel_init(config)) hidden_states = embeddings mask = input_mask.astype(jnp.int32) for layer_num in range(config.num_hidden_layers): hidden_states = layers.TransformerBlock( hidden_states, mask, feed_forward=feed_forward, attention=attention, deterministic=deterministic, name=f'encoder_layer_{layer_num}') pooled_output = nn.Dense(hidden_states[:, 0], config.hidden_size, kernel_init=get_kernel_init(config), name='pooler') pooled_output = jnp.tanh(pooled_output) return hidden_states, pooled_output
def apply(self, x, labels, y=None, config=None, train=True): # config parsing nf = config.model.nf act = get_act(config) normalize = get_normalization(config) sigmas = utils.get_sigmas(config) nf = config.model.nf ch_mult = config.model.ch_mult num_res_blocks = config.model.num_res_blocks attn_resolutions = config.model.attn_resolutions dropout = config.model.dropout resamp_with_conv = config.model.resamp_with_conv num_resolutions = len(ch_mult) conditional = config.model.conditional # noise-conditional fir = config.model.fir fir_kernel = config.model.fir_kernel skip_rescale = config.model.skip_rescale resblock_type = config.model.resblock_type progressive = config.model.progressive progressive_input = config.model.progressive_input init_scale = config.model.init_scale assert progressive.lower() in ['none', 'output_skip', 'residual'] assert config.model.embedding_type.lower() in ['gaussian', 'positional'] combine_method = config.model.progressive_combine combiner = Combine.partial(method=combine_method) # timestep/noise_level embedding if config.model.embedding_type == 'gaussian': # Gaussian Fourier features embeddings. used_sigmas = sigmas[labels] temb = layersv3.GaussianFourierProjection( jnp.log(used_sigmas), embedding_size=nf, scale=config.model.fourier_scale) elif config.model.embedding_type == 'positional': # Sinusoidal positional embeddings. timesteps = labels temb = layers.get_timestep_embedding(timesteps, nf) else: raise ValueError(f'embedding type {config.model.embedding_type} unknown.') temb = nn.Dense(temb, nf * 4, kernel_init=default_initializer()) temb = nn.Dense(act(temb), nf * 4, kernel_init=default_initializer()) if y is not None: # class-conditional image generation class_embed = nn.Embed(y, config.data.num_classes, nf * 4) class_embed = nn.Dense( class_embed, nf * 4, kernel_init=default_initializer()) class_embed = nn.Dense( act(class_embed), nf * 4, kernel_init=default_initializer()) temb += class_embed AttnBlock = layersv3.AttnBlockv3.partial( normalize=normalize, init_scale=init_scale, skip_rescale=skip_rescale) Upsample = layersv3.Upsample.partial( with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) if progressive == 'output_skip': pyramid_upsample = layersv3.Upsample.partial( fir=fir, fir_kernel=fir_kernel, with_conv=False) elif progressive == 'residual': pyramid_upsample = layersv3.Upsample.partial( fir=fir, fir_kernel=fir_kernel, with_conv=True) Downsample = layersv3.Downsample.partial( with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) if progressive_input == 'input_skip': pyramid_downsample = layersv3.Downsample.partial( fir=fir, fir_kernel=fir_kernel, with_conv=False) elif progressive_input == 'residual': pyramid_downsample = layersv3.Downsample.partial( fir=fir, fir_kernel=fir_kernel, with_conv=True) if resblock_type == 'ddpm': ResnetBlock = ResnetBlockDDPM.partial( act=act, normalize=normalize, dropout=dropout, temb=temb if conditional else None, train=train, init_scale=init_scale, skip_rescale=skip_rescale) elif resblock_type == 'biggan': ResnetBlock = ResnetBlockBigGAN.partial( act=act, normalize=normalize, temb=temb if conditional else None, train=train, dropout=dropout, fir=fir, fir_kernel=fir_kernel, init_scale=init_scale, skip_rescale=skip_rescale) else: raise ValueError(f'resblock_type {resblock_type} unrecognized.') if not config.data.centered: # If input data is in [0, 1] x = 2 * x - 1. # Downsampling block input_pyramid = None if progressive_input != 'none': input_pyramid = x hs = [conv3x3(x, nf)] for i_level in range(num_resolutions): # Residual blocks for this resolution for i_block in range(num_res_blocks): h = ResnetBlock(hs[-1], out_ch=nf * ch_mult[i_level]) if h.shape[1] in attn_resolutions: h = AttnBlock(h) hs.append(h) if i_level != num_resolutions - 1: if resblock_type == 'ddpm': h = Downsample(hs[-1]) else: h = ResnetBlock(hs[-1], down=True) if progressive_input == 'input_skip': input_pyramid = pyramid_downsample(input_pyramid) h = combiner(input_pyramid, h) elif progressive_input == 'residual': input_pyramid = pyramid_downsample( input_pyramid, out_ch=h.shape[-1]) if skip_rescale: input_pyramid = (input_pyramid + h) / np.sqrt(2.) else: input_pyramid = input_pyramid + h h = input_pyramid hs.append(h) h = hs[-1] h = ResnetBlock(h) h = AttnBlock(h) h = ResnetBlock(h) pyramid = None # Upsampling block for i_level in reversed(range(num_resolutions)): for i_block in range(num_res_blocks + 1): h = ResnetBlock( jnp.concatenate([h, hs.pop()], axis=-1), out_ch=nf * ch_mult[i_level]) if h.shape[1] in attn_resolutions: h = AttnBlock(h) if progressive != 'none': if i_level == num_resolutions - 1: if progressive == 'output_skip': pyramid = conv3x3( act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))), x.shape[-1], bias=True, init_scale=init_scale) elif progressive == 'residual': pyramid = conv3x3( act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))), h.shape[-1], bias=True) else: raise ValueError(f'{progressive} is not a valid name.') else: if progressive == 'output_skip': pyramid = pyramid_upsample(pyramid) pyramid = pyramid + conv3x3( act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))), x.shape[-1], bias=True, init_scale=init_scale) elif progressive == 'residual': pyramid = pyramid_upsample(pyramid, out_ch=h.shape[-1]) if skip_rescale: pyramid = (pyramid + h) / np.sqrt(2.) else: pyramid = pyramid + h h = pyramid else: raise ValueError(f'{progressive} is not a valid name') if i_level != 0: if resblock_type == 'ddpm': h = Upsample(h) else: h = ResnetBlock(h, up=True) assert not hs if progressive == 'output_skip': h = pyramid else: h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))) h = conv3x3(h, x.shape[-1], init_scale=init_scale) if config.model.scale_by_sigma: used_sigmas = sigmas[labels].reshape((x.shape[0], *([1] * len(x.shape[1:])))) h = h / used_sigmas return h