def test_mha_01(self, combine_qkv, dconv_qkv, dconv_kernel_size,
                use_rotary_position_emb):
  mdl_dim = 16
  hidden_dim = 32
  num_heads = 4
  test_layer_p = attentions.DotProductAttention.Params().Set(
      name='mh',
      input_dim=mdl_dim,
      hidden_dim=hidden_dim,
      num_heads=num_heads,
      dim_per_head=16 if use_rotary_position_emb else None,
      atten_logit_cap=20.0,
      combine_qkv=combine_qkv,
      dconv_qkv=dconv_qkv,
      dconv_kernel_size=dconv_kernel_size,
      use_rotary_position_emb=use_rotary_position_emb)
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)
  logging.info('initial_vars: %s', initial_vars)
  target_batch_size = 3
  source_max_length = 16
  target_max_length = 16
  initial_states = layer.init_states(target_batch_size, target_max_length)
  query_vec = np.random.normal(
      size=[target_batch_size, source_max_length, mdl_dim]).astype(np.float32)
  key_vec = query_vec
  value_vec = query_vec
  atten_mask = attentions.causal_mask(query_vec)
  prng_key, compute_key = jax.random.split(prng_key)
  global_step = jnp.array(0, dtype=jnp.uint64)

  with base_layer.JaxContext.new_context(
      prng_key=compute_key, global_step=global_step) as jax_context:
    jax_context.bind(layer, layer.vars_to_flax_vars(initial_vars))
    fprop_out, _ = layer.fprop(query_vec, key_vec, value_vec, atten_mask)

    decoder_output = jnp.zeros(
        shape=[target_max_length, target_batch_size, mdl_dim])
    atten_states = initial_states
    for t in range(target_max_length):
      atten_states, encoded = layer.extend_step(
          atten_states,
          query_vec=query_vec[:, t, :],
          atten_mask=atten_mask[:, :, t, :],
          time_step=t)
      decoder_output = decoder_output.at[t].set(encoded)

  decoder_out_transposed = jnp.transpose(decoder_output, [1, 0, 2])
  logging.info('fprop_out: %s', fprop_out)
  logging.info('decoder_out: %s', decoder_output)
  self.assertAllClose(fprop_out, decoder_out_transposed)
def test_transformer_relative_bias(self, use_relative_bias):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=True,
      packed_input=True,
      cross_attention=False)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  if use_relative_bias:
    p.tr_atten_tpl.relative_bias_tpl = attentions.RelativeBias.Params().Set(
        relative_attention_num_buckets=2, relative_attention_max_distance=8)
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(attention_mask, causal_mask)
  segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
  segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  attention_mask = jnp.minimum(attention_mask, segment_mask)

  if use_relative_bias:
    segment_pos = np.random.randint(
        0, seq_len, [batch_size, seq_len]).astype('int32')
    segment_pos = jnp.asarray(segment_pos)
  else:
    segment_pos = None

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    outputs, _ = transformer_layer.fprop(
        inputs,
        paddings,
        attention_mask=attention_mask,
        segment_pos=segment_pos)

  logging.info('initial_vars in transformer layer = %s', initial_vars)
  np_outputs = test_utils.to_np(outputs)
  logging.info('np_outputs: %s', np_outputs)
  if use_relative_bias:
    self.assertAlmostEqual(np_outputs[0, 0, 1], 0.79015386, places=5)
    self.assertAlmostEqual(np_outputs[0, 1, 0], 0.48336178, places=5)
  # Plumbing test.
  self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
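# Illustrative sketch only, not part of the layers under test. The tests in
# this file combine padding, causal, and segment masks with `jnp.minimum`,
# which assumes additive masks: 0.0 where attention is allowed and a large
# negative value where it is not. The toy helpers below (hypothetical names,
# not the library's API) show that convention on plain jnp arrays.
def _toy_paddings_to_mask(paddings):
  """[b, t] paddings (1.0 = padded) -> [b, 1, 1, t] additive mask."""
  return paddings[:, jnp.newaxis, jnp.newaxis, :] * (-2.0**30)


def _toy_causal_mask(seq_len):
  """[1, 1, t, t] additive mask that hides future positions."""
  row = jnp.arange(seq_len)[:, jnp.newaxis]
  col = jnp.arange(seq_len)[jnp.newaxis, :]
  allowed = (col <= row).astype(jnp.float32)
  return ((1.0 - allowed) * (-2.0**30))[jnp.newaxis, jnp.newaxis, :, :]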
def test_relative_bias_extend_step(self, num_buckets, max_distance,
                                   attention_extra_logit):
  mdl_dim = 16
  hidden_dim = 32
  num_heads = 4
  test_layer_p = attentions.DotProductAttention.Params().Set(
      name='relative_attn',
      input_dim=mdl_dim,
      hidden_dim=hidden_dim,
      attention_extra_logit=attention_extra_logit,
      num_heads=num_heads)
  test_layer_p.relative_bias_tpl = attentions.RelativeBias.Params().Set(
      relative_attention_num_buckets=num_buckets,
      relative_attention_max_distance=max_distance)
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)
  target_batch_size = 2
  source_max_length = 8
  target_max_length = 8
  inputs = np.random.normal(
      size=[target_batch_size, source_max_length, mdl_dim]).astype(np.float32)
  atten_mask = attentions.causal_mask(inputs)
  initial_states = layer.init_states(target_batch_size, target_max_length)
  time_step = 2

  _, atten_output = test_utils.apply(
      layer,
      initial_vars,
      layer.extend_step,
      initial_states,
      inputs[:, time_step, :],
      atten_mask=atten_mask[:, :, time_step, :],
      time_step=time_step)

  self.assertEqual(atten_output.shape, (target_batch_size, mdl_dim))
def test_transformer_layer_norm_policies(self, norm_policy):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=True,
      packed_input=True,
      cross_attention=False,
      norm_policy=norm_policy)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(attention_mask, causal_mask)
  segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
  segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  attention_mask = jnp.minimum(attention_mask, segment_mask)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    outputs, _ = transformer_layer.fprop(
        inputs, paddings, attention_mask=attention_mask)

  logging.info('initial_vars in transformer layer = %s', initial_vars)
  np_outputs = test_utils.to_np(outputs)
  # Plumbing test.
  self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
def test_relative_bias(self, num_buckets, max_distance):
  mdl_dim = 16
  hidden_dim = 32
  num_heads = 4
  test_layer_p = attentions.DotProductAttention.Params().Set(
      name='relative_attn',
      input_dim=mdl_dim,
      hidden_dim=hidden_dim,
      num_heads=num_heads)
  test_layer_p.relative_bias_tpl = attentions.RelativeBias.Params().Set(
      relative_attention_num_buckets=num_buckets,
      relative_attention_max_distance=max_distance)
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)
  target_batch_size = 3
  source_max_length = 16
  query_vec = np.random.normal(
      size=[target_batch_size, source_max_length, mdl_dim]).astype(np.float32)
  key_vec = query_vec
  value_vec = query_vec
  segment_pos = np.random.randint(
      0, source_max_length,
      [target_batch_size, source_max_length]).astype('int32')
  atten_mask = attentions.causal_mask(query_vec)

  atten_output, _ = test_utils.apply(
      layer,
      initial_vars,
      layer.fprop,
      query_vec,
      key_vec,
      value_vec,
      atten_mask=atten_mask,
      query_segment_pos=segment_pos)

  self.assertEqual(atten_output.shape,
                   (target_batch_size, source_max_length, mdl_dim))
def test_transformer_layer(self, mask_self_attention, packed_input,
                           cross_attention):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      packed_input=packed_input,
      cross_attention=cross_attention)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  causal_mask = None
  segment_mask = None
  tf_segment_mask = None
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  if mask_self_attention:
    causal_mask = attentions.causal_mask(inputs)
    attention_mask = jnp.minimum(attention_mask, causal_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)
    if mask_self_attention:
      tf_segment_mask = batch_major_attention.CausalSegmentMask(
          segment_ids, tf.float32)
    else:
      tf_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids)

  cross_inputs = None
  cross_attention_mask = None
  tf_cross_inputs = None
  tf_cross_paddings = None
  tf_cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 128)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.input_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    cross_attention_mask = attentions.convert_paddings_to_mask(cross_paddings)
    tf_cross_paddings = tf.constant(npy_cross_paddings, dtype=tf.float32)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      cross_attention_mask = jnp.minimum(cross_attention_mask,
                                         cross_segment_mask)
      tf_cross_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, source_segment_ids)

  outputs, _ = test_utils.apply(
      transformer_layer,
      initial_vars,
      transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      attention_mask=attention_mask,
      cross_inputs=cross_inputs,
      cross_attention_mask=cross_attention_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  # Test whether the TF Transformer layer returns the same output.
  # Modify initial_vars to use TF-compatible params.
  tf_initial_vars = test_utils.replace_jax_attention_vars_to_tf(
      initial_vars, cross_attention)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer layer = %s', initial_vars)
  tf_p = batch_major_attention.TransformerLayer.Params().Set(
      name='tf_transformer_layer',
      input_dim=p.input_dims,
      num_heads=p.num_heads,
      mask_self_atten=mask_self_attention,
      packed_input=packed_input,
      has_aux_atten=cross_attention)
  tf_p.tr_fflayer_tpl.hidden_dim = p.hidden_dims
  tf_p.tr_fflayer_tpl.fflayer_tpl.batch_norm = False
  tf_p.tr_fflayer_tpl.fflayer_tpl.has_bias = True
  tf_transformer_layer = tf_p.Instantiate()
  tf_output, _ = tf_transformer_layer.FProp(
      tf_initial_vars,
      tf.constant(npy_inputs, dtype=tf.float32),
      paddings=test_utils.to_tf_nmap(npy_paddings),
      segment_mask=tf_segment_mask,
      aux_vec=tf_cross_inputs,
      aux_paddings=tf_cross_paddings,
      aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
def test_transformer_layer_cross_attention_ln(self, packed_input):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=8,
      hidden_dims=32,
      num_heads=4,
      mask_self_attention=True,
      packed_input=packed_input,
      cross_attention=True)
  seq_len = 5
  batch_size = 4
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  # Change the self attention initial vars.
  initial_vars.layer_norm.scale = 0.5
  initial_vars.layer_norm.bias = 5.0
  # Change the cross attention initial vars.
  initial_vars.cross_layer_norm.scale = 15
  initial_vars.cross_layer_norm.bias = 1.5
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(causal_mask, attention_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    inputs_normalized = transformer_layer.layer_norm.fprop(inputs)
    # Compute self-attention; key/value vectors are the input itself.
    atten_output, _ = transformer_layer.self_attention.fprop(
        inputs_normalized,
        inputs_normalized,
        inputs_normalized,
        atten_mask=attention_mask)
    # Residual dropout and connection.
    atten_output = transformer_layer.residual_dropout.fprop(atten_output)
    atten_output += inputs
    # Normalize atten outputs using the cross attention layer norm.
    atten_output_normalized = transformer_layer.cross_layer_norm.fprop(
        atten_output)

  inputs_normalized = test_utils.to_np(inputs_normalized)
  atten_output_normalized = test_utils.to_np(atten_output_normalized)
  self.assertAllClose(
      initial_vars.layer_norm.bias, inputs_normalized.mean(), atol=1e-3)
  self.assertAllClose(
      (1.0 + initial_vars.layer_norm.scale)**2,
      np.var(inputs_normalized),
      atol=5e-3)
  self.assertAllClose(
      initial_vars.cross_layer_norm.bias,
      atten_output_normalized.mean(),
      atol=1e-3)
  self.assertAllClose(
      (1.0 + initial_vars.cross_layer_norm.scale)**2,
      np.var(atten_output_normalized),
      atol=5e-3)
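# Sketch of why the assertions above hold (illustration only, assuming the
# layer norm computes `(1 + scale) * normalize(x) + bias`, with `normalize`
# producing zero mean and unit variance over the last dimension, as the
# assertions on `(1.0 + scale)**2` imply): the normalized output then has
# mean close to `bias` and variance close to `(1 + scale)**2`. The helper
# name `_toy_layer_norm` is hypothetical, not the library's API.
def _toy_layer_norm(x, scale, bias, epsilon=1e-6):
  mean = np.mean(x, axis=-1, keepdims=True)
  var = np.var(x, axis=-1, keepdims=True)
  normalized = (x - mean) / np.sqrt(var + epsilon)
  return (1.0 + scale) * normalized + bias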
def test_transformer_layer_extendstep(self, packed_input, cross_attention,
                                      dconv_qkv, use_rotary_position_emb):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=8,
      hidden_dims=32,
      num_heads=4,
      mask_self_attention=True,
      packed_input=packed_input,
      cross_attention=cross_attention)
  p.tr_atten_tpl.dconv_qkv = dconv_qkv
  p.tr_atten_tpl.use_rotary_position_emb = use_rotary_position_emb
  if cross_attention:
    p.cross_atten_tpl = p.tr_atten_tpl.Copy()
    # Cross attention should not have depth-wise convolution.
    p.cross_atten_tpl.dconv_qkv = False
    # Cross attention should not have rotary position embedding.
    p.cross_atten_tpl.use_rotary_position_emb = False
  p.tr_atten_tpl.dconv_kernel_size = 2
  seq_len = 4
  batch_size = 4
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  initial_states = transformer_layer.init_states(batch_size, seq_len)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(
      0, 1, [batch_size, seq_len]).astype('float32')
  # npy_paddings = np.zeros([batch_size, seq_len])
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  segment_mask = None
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(causal_mask, attention_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)

  cross_inputs = None
  cross_paddings = None
  cross_attention_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 32)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.input_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    cross_attention_mask = attentions.convert_paddings_to_mask(cross_paddings)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      cross_attention_mask = jnp.minimum(cross_attention_mask,
                                         cross_segment_mask)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    fprop_outputs, _ = transformer_layer.fprop(
        inputs,
        paddings,
        attention_mask=attention_mask,
        cross_inputs=cross_inputs,
        cross_attention_mask=cross_attention_mask)
    decoder_outputs = jnp.zeros(shape=[seq_len, batch_size, p.input_dims])
    atten_states = initial_states
    for t in range(seq_len):
      attention_mask_t = attention_mask[:, :, t, :]
      cross_attention_mask_t = cross_attention_mask
      if cross_attention:
        cross_attention_mask_t = cross_attention_mask[:, :, t, :]
        cross_attention_mask_t = np.expand_dims(
            cross_attention_mask_t, axis=2)
      atten_states, encoded = transformer_layer.extend_step(
          atten_states,
          inputs=inputs[:, t, :],
          time_step=t,
          attention_mask=attention_mask_t,
          cross_inputs=cross_inputs,
          cross_attention_mask=cross_attention_mask_t)
      decoder_outputs = decoder_outputs.at[t].set(encoded)

  decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 0, 2])
  logging.info('initial_vars in transformer layer = %s', initial_vars)
  np_fprop_outputs = test_utils.to_np(fprop_outputs)
  np_decoder_outputs = test_utils.to_np(decoder_out_transposed)
  self.assertAllClose(np_fprop_outputs, np_decoder_outputs, atol=1e-5)