def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  # Allow for decoding without num_partial dimension for beam search.
  # programs shape == [batch_size, (num_partial), length]
  assert programs.ndim in [2, 3], ('Number of program dimensions should be '
                                   '2 or 3, but it is: %d' % programs.ndim)
  assert encoded.ndim == programs.ndim + 2

  # Collapse num_io dimension.
  num_io_axis = 1 if programs.ndim == 2 else 2
  flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask, axis=num_io_axis)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

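# --- Illustrative sketch (not taken from the snippets in this file): how the
# decoder masks used throughout these decode() functions are built with
# flax.linen. The toy token values and shapes are assumptions for illustration.
import jax.numpy as jnp
from flax import linen as nn

programs = jnp.array([[5, 3, 8, 0, 0]])         # [batch=1, length=5], 0 = pad
spec_padding = jnp.array([[1, 1, 1, 1, 0, 0]])  # [batch=1, flat_encoded_len=6]

# Padding mask AND causal mask -> decoder self-attention mask.
decoder_mask = nn.combine_masks(
    nn.make_attention_mask(programs > 0, programs > 0),
    nn.make_causal_mask(programs))
print(decoder_mask.shape)  # (1, 1, 5, 5): [batch, heads (broadcast), query, key]

# Cross-attention mask: program tokens attend to non-padded spec positions.
encoder_decoder_mask = nn.make_attention_mask(programs > 0, spec_padding > 0)
print(encoder_decoder_mask.shape)  # (1, 1, 5, 6)
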
def __call__(self, inputs, outputs):
  """Applies Transformer model to encode the IO specification.

  Args:
    inputs: input data [batch_size, num_io, length]
    outputs: output data [batch_size, num_io, length2]

  Returns:
    Encoded IO data `[batch_size, num_io, length2, dim]`
  """
  cfg = self.config

  # Inputs and outputs share embeddings.
  embed = nn.Embed(
      num_embeddings=cfg.vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0),
      name='embed')
  pos_emb = AddPositionEmbs(config=cfg, cache=False, name='posembed_io')

  x = inputs.astype('int32')
  y = outputs.astype('int32')

  # Make attention masks.
  inputs_encoder_mask = nn.make_attention_mask(x > 0, x > 0, dtype=cfg.dtype)
  outputs_encoder_mask = nn.make_attention_mask(y > 0, y > 0, dtype=cfg.dtype)
  encoder_decoder_mask = nn.make_attention_mask(y > 0, x > 0, dtype=cfg.dtype)

  # Embed inputs.
  x = embed(x)
  x = pos_emb(x)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x.astype(cfg.dtype)

  for lyr in range(cfg.num_layers):
    x = EncoderBlock(  # Attend to inputs.
        config=cfg, name=f'encoderblock_{lyr}')(x, inputs_encoder_mask)
  x = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

  # Embed outputs.
  y = embed(y)
  y = pos_emb(y)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

  encode_decoder_cfg = cfg.replace(decode=False)
  for lyr in range(cfg.num_layers):
    y = EncoderDecoderBlock(  # Double attend to inputs and outputs.
        config=encode_decoder_cfg,
        name=f'encoderdecoderblock_{lyr}')(
            y, x, outputs_encoder_mask, encoder_decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  return y

def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
  """Applies Transformer encoder-branch on the inputs.

  Args:
    inputs: input data.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    Encoded feature array from the transformer encoder.
  """
  cfg = self.config

  # Make padding attention mask.
  encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0,
                                        dtype=cfg.dtype)

  # Add segmentation block-diagonal attention mask if using segmented data.
  if inputs_segmentation is not None:
    encoder_mask = nn.combine_masks(
        encoder_mask,
        nn.make_attention_mask(inputs_segmentation, inputs_segmentation,
                               jnp.equal, dtype=cfg.dtype))

  return self.encoder(
      inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = flatten_num_io_dim(encoded_padding_mask)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def decode(self,
           encoded,
           inputs,  # only needed for masks
           targets,
           targets_positions=None,
           inputs_segmentation=None,
           targets_segmentation=None):
  """Applies Transformer decoder-branch on encoded-input and target.

  Args:
    encoded: encoded input data from encoder.
    inputs: input data (only needed for masking).
    targets: target data.
    targets_positions: target subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.

  Returns:
    Logits array from transformer decoder.
  """
  config = self.config

  # Make padding attention masks.
  if config.decode:
    # For fast autoregressive decoding only a special encoder-decoder mask is
    # used.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=config.dtype),
        nn.make_causal_mask(targets, dtype=config.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=config.dtype)

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(targets_segmentation, targets_segmentation,
                               jnp.equal, dtype=config.dtype))
    encoder_decoder_mask = nn.combine_masks(
        encoder_decoder_mask,
        nn.make_attention_mask(targets_segmentation, inputs_segmentation,
                               jnp.equal, dtype=config.dtype))

  logits = self.decoder(
      encoded,
      targets,
      targets_positions=targets_positions,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=encoder_decoder_mask)
  return logits.astype(self.config.dtype)

def __call__(self, input_qkv):
  cfg = self.config
  assert cfg.max_len % cfg.max_seg_len == 0
  bsize = input_qkv.shape[0]
  features = self.out_features or input_qkv.shape[-1]
  num_seg = cfg.max_len // cfg.max_seg_len

  x_sqr = input_qkv.reshape(
      [bsize, num_seg, cfg.max_seg_len, input_qkv.shape[-1]])
  q_row_local, key_row_local, value_row_local, head_dim = get_qkv(cfg, x_sqr)

  local_logits = jnp.einsum('...qhd,...khd->...qhk', q_row_local,
                            key_row_local)
  row_probs = jax.nn.softmax(local_logits)
  if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
    row_probs = dropatt(row_probs, dropout_rng,
                        1 - cfg.attention_dropout_rate)
  row_attn_out = jnp.einsum('...qhk,...khd->...qhd', row_probs,
                            value_row_local)

  key_row = DenseGeneral(
      features=input_qkv.shape[-1],
      axis=(-2, -1),
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      dtype=cfg.dtype)(row_attn_out)
  key_row = nn.Dropout(rate=cfg.dropout_rate)(
      key_row, deterministic=cfg.deterministic)
  key_row = key_row + x_sqr
  key_row = nn.LayerNorm(dtype=cfg.dtype)(key_row)
  key_row = DenseGeneral(
      axis=-1,
      features=(cfg.num_heads, head_dim),
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      dtype=cfg.dtype)(key_row)

  idx_cols = jnp.arange(cfg.max_seg_len)
  local_mask = nn.make_attention_mask(idx_cols, idx_cols, jnp.less,
                                      extra_batch_dims=1)
  local_mask = jnp.expand_dims(local_mask, axis=-2) * -1e10
  local_logits = local_logits + local_mask

  global_logits = jnp.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
  idx_rows = jnp.arange(num_seg)
  global_mask = nn.make_attention_mask(idx_rows, idx_rows, jnp.less_equal)
  global_mask = global_mask[:, :, jnp.newaxis, jnp.newaxis, :] * -1e10
  global_logits = global_logits + global_mask

  joint_logits = jnp.concatenate((local_logits, global_logits), axis=-1)
  attn_probs = jax.nn.softmax(joint_logits, axis=-1)
  local_att, global_att = jnp.split(attn_probs, [cfg.max_seg_len], axis=-1)
  if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
    local_att = dropatt(local_att, dropout_rng,
                        1 - cfg.attention_dropout_rate)
  local_merged = jnp.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
  global_merged = jnp.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
  joint_merged = jnp.reshape(local_merged + global_merged,
                             [bsize, cfg.max_len, cfg.num_heads, head_dim])

  x = DenseGeneral(
      features=features,
      axis=(-2, -1),
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      dtype=cfg.dtype)(joint_merged)
  return x

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  preshift_programs = programs  # Save pre-shifted programs for padding mask.
  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    # BOS tokens attend to all previous BOS tokens.
    decoder_bos_mask = nn.combine_masks(
        nn.make_attention_mask(programs == cfg.bos_token,
                               programs == cfg.bos_token,
                               dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    # Program tokens attend to all previous tokens in the partial program.
    decoder_partial_mask = nn.combine_masks(
        make_partial_program_mask(programs,
                                  bos_token=cfg.bos_token,
                                  dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(preshift_programs > 0, preshift_programs > 0,
                               dtype=cfg.dtype),
        jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def decode(self, programs, latents, encoded, latents_padding_mask,
           encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                             ' but it is: %d' % latents.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = models.flatten_num_io_dim(encoded_padding_mask)

  latents = self.latent_pos_emb(latents)
  # Concatenate the I/O encoding and latents together.
  flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    latent_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    latent_decoder_mask = nn.make_attention_mask(
        programs > 0, latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def __call__(self, input_ids, input_mask, type_ids, deterministic=False):
  """Applies model on the inputs.

  Args:
    input_ids: Tokenized inputs of shape <int>[BATCH_SIZE, MAX_SEQ_LENGTH].
    input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
      inputs from padding. Only used by BERT.
    type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] ids partitioning input into
      different types.
    deterministic: Whether or not to apply dropout in each layer.

  Returns:
    Hidden states of shape <float>[BATCH_SIZE, MAX_SEQ_LENGTH, HIDDEN_DIM],
    and pooled output <float>[BATCH_SIZE, HIDDEN_DIM] scaled to (-1, 1).
  """
  hidden_states = self.embedder(
      input_ids, type_ids, deterministic=deterministic)

  # Only used by the (BERT) self-attention sublayer.
  padding_mask = input_mask.astype(jnp.int32)
  padding_mask = nn.make_attention_mask(
      query_input=padding_mask, key_input=padding_mask)

  for encoder_block in self.encoder_blocks:
    hidden_states = encoder_block(
        hidden_states, padding_mask, deterministic=deterministic)

  pooled_output = self.pooler(hidden_states[:, 0])
  pooled_output = jnp.tanh(pooled_output)
  return hidden_states, pooled_output

def get_attention_masks(self, inputs, targets):
  """Makes decoder self-attention and encoder-decoder attention masks."""
  cfg = self.config
  if cfg.decode:
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
        nn.make_causal_mask(targets, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=cfg.dtype)
  return decoder_mask, encoder_decoder_mask

def decode(self,
           encoded,
           inputs,
           targets,
           targets_positions=None,
           inputs_segmentation=None,
           targets_segmentation=None,
           train=False):
  # Make padding attention masks.
  dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
  if self.should_decode:
    # For fast autoregressive decoding, only a special encoder-decoder mask
    # is used.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
        nn.make_causal_mask(targets, dtype=dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=dtype)

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(targets_segmentation, targets_segmentation,
                               jnp.equal, dtype=dtype))
    encoder_decoder_mask = nn.combine_masks(
        encoder_decoder_mask,
        nn.make_attention_mask(targets_segmentation, inputs_segmentation,
                               jnp.equal, dtype=dtype))

  logits = self.decoder(
      encoded,
      targets,
      targets_positions=targets_positions,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=encoder_decoder_mask,
      train=train)
  return logits

def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None):
  """Applies TransformerLM on the inputs.

  Args:
    inputs: target data.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    Logits array from transformer decoder.
  """
  config = self.config

  # Make padding attention masks.
  if config.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=config.dtype),
        nn.make_causal_mask(inputs, dtype=config.dtype))

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(inputs_segmentation, inputs_segmentation,
                               jnp.equal, dtype=config.dtype))

  logits = Decoder(config=config, shared_embedding=None, name='decoder')(
      inputs,
      inputs_positions=inputs_positions,
      inputs_segmentation=inputs_segmentation,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=None)
  return logits.astype(self.config.dtype)

def make_causal_mask(x, length_axis, extra_batch_dims=0, strict=False):
  """Makes a causal attention mask over the given length axis of `x`."""
  idxs = jnp.broadcast_to(
      jnp.arange(x.shape[length_axis], dtype=jnp.int32),
      x.shape[:length_axis + 1])
  mask = nn.make_attention_mask(
      idxs, idxs,
      jnp.greater_equal if not strict else jnp.greater,
      extra_batch_dims=extra_batch_dims,
      dtype=jnp.float32)
  return mask

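# --- Illustrative sketch (an assumption, not from the helper's source file):
# using the make_causal_mask variant above on an intermediate axis. For an
# input of shape [batch, num_seg, seg_len, dim] with length_axis=1, the index
# array is broadcast to [batch, num_seg], so the resulting mask has shape
# [batch, 1, num_seg, num_seg]; strict=True additionally excludes the diagonal.
import jax.numpy as jnp

x = jnp.zeros((2, 4, 8, 16))                         # [batch, num_seg, seg_len, dim]
mask = make_causal_mask(x, length_axis=1)            # shape (2, 1, 4, 4)
strict_mask = make_causal_mask(x, length_axis=1, strict=True)
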
def encode(self,
           inputs,
           inputs_positions=None,
           inputs_segmentation=None,
           train=False):
  # Make padding attention mask.
  encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0,
                                        dtype=inputs.dtype)

  # Add segmentation block-diagonal attention mask if using segmented data.
  if inputs_segmentation is not None:
    encoder_mask = nn.combine_masks(
        encoder_mask,
        nn.make_attention_mask(inputs_segmentation, inputs_segmentation,
                               jnp.equal, dtype=inputs_segmentation.dtype))

  encoded = self.encoder(
      inputs,
      inputs_positions=inputs_positions,
      encoder_mask=encoder_mask,
      train=train)
  return encoded

def __call__(self, targets, targets_mask=None):
  """Autoencodes program task.

  Args:
    targets: target data `[batch_size, length]`
    targets_mask: padding mask for targets.

  Returns:
    Embedding sequence.
  """
  cfg = self.config

  assert targets.ndim == 2  # (batch, len)

  if targets_mask is None:
    targets_mask = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
  encoder_mask = nn.make_attention_mask(targets_mask, targets_mask,
                                        dtype=cfg.dtype)

  output_embed = nn.Embed(
      num_embeddings=cfg.output_vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0),
      name='embed_output')

  x = targets.astype('int32')
  x = output_embed(x)
  x = models.AddPositionEmbs(config=cfg, cache=cfg.decode, name='posembed')(x)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)

  for lyr in range(cfg.num_layers):
    x = models.EncoderBlock(  # Attend to inputs.
        config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)

  y = x * targets_mask[Ellipsis, None]
  for i in range(self.c):
    # Strided convolutions to decrease length.
    y = nn.Conv(
        features=cfg.emb_dim,
        kernel_size=(2,),
        strides=(2,),
        name=f'conv_{i}')(y)
  return y

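# --- Illustrative sketch (an assumption about the downsampling step above):
# each nn.Conv with kernel_size=(2,) and strides=(2,) halves the sequence
# length, so self.c such convolutions shrink [batch, length, emb_dim] to
# roughly [batch, length / 2**c, emb_dim] for lengths divisible by 2**c.
import jax
import jax.numpy as jnp
from flax import linen as nn

conv = nn.Conv(features=16, kernel_size=(2,), strides=(2,))
y = jnp.ones((1, 8, 16))                    # [batch, length, features]
params = conv.init(jax.random.PRNGKey(0), y)
print(conv.apply(params, y).shape)          # (1, 4, 16)
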
def __call__(self, encoding: Array, attention_mask: Array,
             deterministic: bool) -> Array:
  """Self attention layer forward.

  Args:
    encoding: [bsz, seq_len, model_dim] model state.
    attention_mask: [bsz, seq_len] padding mask.
    deterministic: if true, do not apply dropout.

  Returns:
    Updated encoding.
  """
  attention_mask = nn.make_attention_mask(attention_mask, attention_mask)
  update = self.attention_layer(
      inputs_q=encoding, mask=attention_mask, deterministic=deterministic)
  update = self.dropout(update, deterministic=deterministic)
  encoding = self.layer_norm(encoding + update)
  return encoding

def __call__(self, inputs, dummy):
  """Vanilla Transformer encoder.

  Args:
    inputs: input data [batch_size, num_io, length]
    dummy: unused for SCAN dataset.

  Returns:
    Encoded inputs `[batch_size, num_io, length, dim]`
  """
  del dummy  # TODO(kshi): possibly use dummy for RobustFill.
  cfg = self.config

  # Inputs and outputs share embeddings.
  embed = nn.Embed(
      num_embeddings=cfg.vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0),
      name='embed')

  x = inputs.astype('int32')
  encoder_mask = nn.make_attention_mask(x > 0, x > 0, dtype=cfg.dtype)

  # Embed inputs.
  x = embed(x)
  if not cfg.use_relative_attention:
    pos_emb = AddPositionEmbs(config=cfg, cache=False, name='posembed_io')
    x = pos_emb(x)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)

  for lyr in range(cfg.num_layers):
    x = EncoderBlock(  # Attend to inputs.
        config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)

  y = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)
  return y

def __call__(self,
             inputs,
             train,
             inputs_positions=None,
             inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    train: bool: if model is training.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    Output of a transformer decoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  dtype = utils.dtype_from_str(self.model_dtype)

  if self.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
        nn.make_causal_mask(inputs, dtype=dtype))
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(inputs_segmentation, inputs_segmentation,
                               jnp.equal, dtype=dtype))

  y = inputs.astype('int32')
  if not self.decode:
    y = shift_inputs(y, segment_ids=inputs_segmentation)

  # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
  # indices for dataset_lib:proteins_test. This will break when jnp.take() is
  # updated to return NaNs for out-of-bounds indices.
  # Debug why this is the case.
  y = jnp.clip(y, 0, self.vocab_size - 1)

  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len,
      posemb_init=sinusoidal_init(max_len=self.max_len),
      decode=self.decode,
      name='posembed_output')(
          y, inputs_positions=inputs_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

  y = y.astype(dtype)

  for _ in range(self.num_layers):
    y = Transformer1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn,
        normalizer=self.normalizer,
        dtype=dtype)(
            inputs=y,
            train=train,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None,
            inputs_positions=None,
            inputs_segmentation=None)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train,
                                                 dtype=dtype)
    y = maybe_normalize()(y)

  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=dtype,
        name='logits_dense')(y)
  return logits.astype(dtype)

def __call__(self, input_qkv):
  cfg = self.config
  assert cfg.max_len % cfg.max_seg_len == 0
  bsize = input_qkv.shape[0]
  features = self.out_features or input_qkv.shape[-1]
  query, key, value, head_dim = get_qkv(cfg, input_qkv)

  num_seg = cfg.max_len // cfg.max_seg_len
  cur_query = query.reshape(
      [-1, cfg.max_seg_len, query.shape[-2], query.shape[-1]])
  merged_query = jnp.max(cur_query, axis=1, keepdims=True) * jnp.sqrt(head_dim)
  cur_key = key.reshape([-1, cfg.max_seg_len, key.shape[-2], key.shape[-1]])
  cur_value = value.reshape(
      [-1, cfg.max_seg_len, value.shape[-2], value.shape[-1]])
  dropout_rng = None
  if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
    dropout_rng = self.make_rng('dropout')
  s = dot_product_attention(
      merged_query,
      cur_key,
      cur_value,
      dropout_rng=dropout_rng,
      dropout_rate=cfg.attention_dropout_rate,
      broadcast_dropout=False,
      deterministic=cfg.deterministic,
      dtype=cfg.dtype)
  span_val = jnp.reshape(s, [bsize, -1, s.shape[-2], s.shape[-1]])
  span_key = jnp.max(cur_key, axis=1, keepdims=True)
  # (bsize, n_seg, n_head, dim_per_head)
  span_key = jnp.reshape(
      span_key, [bsize, -1, span_key.shape[-2], span_key.shape[-1]])

  local_mask = make_causal_mask(cur_query, length_axis=1).transpose(
      [0, 2, 1, 3])
  local_bias = lax.select(
      local_mask > 0,
      jnp.full(local_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(local_mask.shape, -1e10).astype(cfg.dtype))
  # (bsize * n_seg, seg_len, n_head, seg_len)
  local_logits = jnp.einsum('...qhd,...khd->...qhk', cur_query,
                            cur_key) + local_bias
  local_logits = jnp.reshape(local_logits,
                             [bsize, -1, cfg.num_heads, cfg.max_seg_len])

  idx = jnp.broadcast_to(
      jnp.arange(span_key.shape[1], dtype=jnp.int32), span_key.shape[:2])
  prev_mask = nn.make_attention_mask(
      idx, idx, jnp.greater, extra_batch_dims=0,
      dtype=jnp.float32).transpose([0, 2, 1, 3])
  prev_mask = jnp.repeat(prev_mask, cfg.max_seg_len, axis=-3)
  prev_bias = lax.select(
      prev_mask > 0,
      jnp.full(prev_mask.shape, 0.).astype(cfg.dtype),
      jnp.full(prev_mask.shape, -1e10).astype(cfg.dtype))
  # (bsize, max_len, n_head, num_segs)
  prev_logits = jnp.einsum('...qhd,...khd->...qhk', query,
                           span_key) + prev_bias

  joint_logits = jnp.concatenate((local_logits, prev_logits), axis=-1)
  # (bsize x max_len, n_head, seg_len + num_segs)
  attn_weights = jax.nn.softmax(joint_logits).astype(cfg.dtype)
  local_att, prev_att = jnp.split(attn_weights, [cfg.max_seg_len], axis=-1)
  local_att = local_att.reshape(
      [bsize * num_seg, cfg.max_seg_len, cfg.num_heads, cfg.max_seg_len])
  local_merged = jnp.einsum('...qhk,...khd->...qhd', local_att, cur_value)
  prev_merged = jnp.einsum('...qhk,...khd->...qhd', prev_att, span_val)
  joint_merged = jnp.reshape(local_merged, prev_merged.shape) + prev_merged
  x = DenseGeneral(
      features=features,
      axis=(-2, -1),
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      dtype=cfg.dtype)(joint_merged)
  return x

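# --- Illustrative sketch (an assumption, not from the snippets above): the
# segment-wise attention implementations above convert 0/1 masks produced by
# nn.make_attention_mask into additive biases (0 for allowed positions, -1e10
# for blocked ones) and add them to the logits before the softmax, which drives
# the blocked attention weights to ~0.
import jax
import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1., 0.], [1., 1.]])          # 1 = attend, 0 = block
bias = lax.select(mask > 0,
                  jnp.full(mask.shape, 0.),
                  jnp.full(mask.shape, -1e10))
logits = jnp.array([[0.3, 2.0], [0.3, 2.0]])
probs = jax.nn.softmax(logits + bias, axis=-1)  # blocked entries get ~0 weight
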
def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config.base_config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  decoder_mask = None
  decoder_relative_position = None  # Relative positions.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    attention_mask_type = self.config.attention_mask_type
    if attention_mask_type == 'baseline':
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
    else:
      if attention_mask_type == 'bos_to_bos':
        # BOS tokens attend to all previous BOS tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      elif attention_mask_type == 'bos_to_last':
        # BOS tokens attend to all last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift the BOS mask left to get all previous last partial program
        # tokens.
        decoder_bos_mask = shift_left(bos_mask)
      elif attention_mask_type == 'bos_to_bos_and_last':
        # BOS tokens attend to all previous BOS + last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift the BOS mask left to get all previous last partial program
        # tokens.
        decoder_bos_mask = jnp.logical_or(bos_mask, shift_left(bos_mask))
      elif attention_mask_type == 'bos_full_attention':
        # BOS tokens attend to all previous tokens, including program tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs > 0,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      else:
        raise ValueError(
            'Unhandled attention_mask_type: {}'.format(attention_mask_type))
      # Program tokens attend to all previous tokens in the partial program.
      decoder_partial_mask = nn.combine_masks(
          make_partial_program_mask(programs,
                                    bos_token=cfg.bos_token,
                                    dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

    if self.config.bos_special_attention:
      # Make custom relative positions where BOS tokens are separately indexed.
      decoder_relative_position = make_relative_position(programs)
      decoder_partial_relative_position = (
          make_partial_program_relative_position(programs,
                                                 bos_token=cfg.bos_token))
      decoder_relative_position = jnp.where(
          (programs == cfg.bos_token)[Ellipsis, None],
          decoder_partial_relative_position,
          decoder_relative_position)
    else:
      decoder_relative_position = None

    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask, decoder_relative_position)