def test_stacked_transformer_layer(self, mask_self_attention, packed_input,
                                   cross_attention):
  p = transformers.StackedTransformer.Params().Set(
      name='jax_stacked_transformer_layer',
      model_dims=16,
      hidden_dims=64,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      num_layers=4,
      packed_input=packed_input,
      cross_attention=cross_attention)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  stacked_transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = stacked_transformer_layer.instantiate_variables(prng_key)

  # Test conversion between vars and flax vars.
  pax_vars = stacked_transformer_layer.vars
  flax_vars = stacked_transformer_layer.flax_vars
  tf.nest.assert_same_structure(
      pax_vars, stacked_transformer_layer.flax_vars_to_vars(flax_vars))
  tf.nest.assert_same_structure(
      flax_vars, stacked_transformer_layer.vars_to_flax_vars(pax_vars))

  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.model_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  segment_mask = None
  tf_segment_mask = None
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    if mask_self_attention:
      tf_segment_mask = batch_major_attention.CausalSegmentMask(
          segment_ids, tf.float32)
    else:
      tf_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids)

  cross_inputs = None
  cross_paddings = None
  cross_segment_mask = None
  tf_cross_inputs = None
  tf_cross_paddings = None
  tf_cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 64)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.model_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    tf_cross_paddings = tf.constant(npy_cross_paddings, dtype=tf.float32)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      tf_cross_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, source_segment_ids)

  # Run the JAX layer.
  outputs = test_utils.apply(
      stacked_transformer_layer,
      initial_vars,
      stacked_transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  # Test whether the TF Transformer layer returns the same output.
  # Modify initial_vars to use TF-compatible params.
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.x_layers = []
  for jax_initial_vars in initial_vars.x_layers:
    tf_layer_vars = test_utils.replace_jax_attention_vars_to_tf(
        jax_initial_vars, cross_attention)
    tf_initial_vars.x_layers.append(tf_layer_vars)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer layer = %s', tf_initial_vars)
  tf_p = batch_major_attention.StackedTransformerLayers.Params().Set(
      name='tf_transformer_layer',
      mdl_dim=p.model_dims,
      hidden_dim=p.hidden_dims,
      num_atten_heads=p.num_heads,
      mask_self_atten=mask_self_attention,
      num_layers=p.num_layers,
      packed_input=packed_input,
      has_aux_atten=cross_attention)
  tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.batch_norm = (
      False)
  tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.has_bias = True
  tf_stacked_transformer_layer = tf_p.Instantiate()
  tf_output, _ = tf_stacked_transformer_layer.FProp(
      tf_initial_vars,
      test_utils.to_tf_nmap(npy_inputs),
      paddings=test_utils.to_tf_nmap(npy_paddings),
      segment_mask=test_utils.to_tf_nmap(tf_segment_mask),
      aux_vec=test_utils.to_tf_nmap(tf_cross_inputs),
      aux_paddings=test_utils.to_tf_nmap(tf_cross_paddings),
      aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))

  # The JAX and TF implementations should agree numerically.
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` object containing:
      ids - The inputs tensor of shape [batch, time].
      paddings - The ids' paddings of shape [batch, time].

  Returns:
    A `.NestedMap` object containing:
      encoded - The encoded features of shape [time, batch, dim] or
        [batch, time, dim], depending on p.output_data_format.
      padding - The encoded features' padding of shape [time, batch] or
        [batch, time].
      segment_id - The segmentation of packed inputs of shape [time, batch] or
        [batch, time] if it is supported by the model, or None otherwise.
      embedded_inputs - The embedded input tokens without positional
        encodings, of shape [time, batch, dim] or [batch, time, dim].
  """
  p = self.params
  with tf.name_scope(p.name):
    # [batch, time]
    input_ids = input_batch.ids
    # [batch, time]
    paddings = input_batch.paddings
    # [batch, time]
    segment_ids = input_batch.segment_ids if p.packed_input else None

    batch = py_utils.GetShape(input_ids)[0]
    time = py_utils.GetShape(input_ids)[1]

    # Embedding layer.
    # [batch, time, dim]
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
    orig_input_embs = input_embs

    # [1, time, dim]
    if p.packed_input:
      positions = input_batch.segment_pos
      position_embs = tf.expand_dims(
          self.position_emb.FPropWithPosition(theta.position_emb, positions),
          0)
    else:
      position_embs = tf.expand_dims(
          self.position_emb.FProp(theta.position_emb, time), 0)

    # [batch, time, dim]
    input_embs += position_embs

    if p.input_dropout_tpl.fprop_dtype:
      input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
      paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [batch, time, dim]
    transformer_input = input_embs
    # Explicitly set the input shape of the Transformer layers, to avoid
    # unknown-shape errors from tf.einsum on non-TPU devices.
    transformer_input = tf.reshape(transformer_input,
                                   [batch, time, p.model_dim])

    # Compute the self-attention segment mask once.
    if p.packed_input:
      segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids, dtype=transformer_input.dtype)
    else:
      segment_mask = tf.zeros([batch, 1, time, time])

    encoded, padding = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, segment_mask)
    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)
    seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32)

    if p.output_data_format == 'TBC':
      encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
      padding = tf.transpose(padding)  # [time, batch]
      segment_ids = tf.transpose(segment_ids) if p.packed_input else None
      orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        seq_lengths=seq_lengths,  # Used by beam_search_helper.
        segment_id=segment_ids,
        embedded_inputs=orig_input_embs)
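# A minimal usage sketch of the FProp contract above (illustrative only, not
# part of this file): `encoder`, the batch/time sizes, and the dtype choices
# are assumptions. It shows the NestedMap fields FProp reads for unpacked
# inputs and the fields it returns.
#
#   input_batch = py_utils.NestedMap(
#       ids=tf.zeros([2, 8], dtype=tf.int32),         # [batch, time] token ids
#       paddings=tf.zeros([2, 8], dtype=tf.float32))  # 1.0 marks padded steps
#   out = encoder.FProp(encoder.theta, input_batch)
#   # out.encoded is [batch, time, dim], or [time, batch, dim] when
#   # p.output_data_format == 'TBC'; out.padding follows the same layout, and
#   # out.embedded_inputs holds the token embeddings before positional
#   # encodings were added.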
def test_transformer_layer(self, mask_self_attention, packed_input,
                           cross_attention):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      packed_input=packed_input,
      cross_attention=cross_attention)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  causal_mask = None
  segment_mask = None
  tf_segment_mask = None
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  if mask_self_attention:
    causal_mask = attentions.causal_mask(inputs)
    attention_mask = jnp.minimum(attention_mask, causal_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)
    if mask_self_attention:
      tf_segment_mask = batch_major_attention.CausalSegmentMask(
          segment_ids, tf.float32)
    else:
      tf_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids)

  cross_inputs = None
  cross_attention_mask = None
  tf_cross_inputs = None
  tf_cross_paddings = None
  tf_cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 128)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.input_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    cross_attention_mask = attentions.convert_paddings_to_mask(cross_paddings)
    tf_cross_paddings = tf.constant(npy_cross_paddings, dtype=tf.float32)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      cross_attention_mask = jnp.minimum(cross_attention_mask,
                                         cross_segment_mask)
      tf_cross_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, source_segment_ids)

  # Run the JAX layer.
  outputs, _ = test_utils.apply(
      transformer_layer,
      initial_vars,
      transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      attention_mask=attention_mask,
      cross_inputs=cross_inputs,
      cross_attention_mask=cross_attention_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  # Test whether the TF Transformer layer returns the same output.
  # Modify initial_vars to use TF-compatible params.
  tf_initial_vars = test_utils.replace_jax_attention_vars_to_tf(
      initial_vars, cross_attention)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer layer = %s', tf_initial_vars)
  tf_p = batch_major_attention.TransformerLayer.Params().Set(
      name='tf_transformer_layer',
      input_dim=p.input_dims,
      num_heads=p.num_heads,
      mask_self_atten=mask_self_attention,
      packed_input=packed_input,
      has_aux_atten=cross_attention)
  tf_p.tr_fflayer_tpl.hidden_dim = p.hidden_dims
  tf_p.tr_fflayer_tpl.fflayer_tpl.batch_norm = False
  tf_p.tr_fflayer_tpl.fflayer_tpl.has_bias = True
  tf_transformer_layer = tf_p.Instantiate()
  tf_output, _ = tf_transformer_layer.FProp(
      tf_initial_vars,
      tf.constant(npy_inputs, dtype=tf.float32),
      paddings=test_utils.to_tf_nmap(npy_paddings),
      segment_mask=tf_segment_mask,
      aux_vec=tf_cross_inputs,
      aux_paddings=tf_cross_paddings,
      aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))

  # The JAX and TF implementations should agree numerically.
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)