def test_per_dim_scale(self):
  test_layer_p = attentions.PerDimScale.Params().Set(name='scale', dim=4)
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)
  initial_vars.per_dim_scale = jnp.array([-0.5, 0.5, 1.0, 0.0],
                                         dtype=jnp.float32)
  logging.info('initial_vars: %s', initial_vars)

  inputs = np.random.normal(1.5, 2.0, [5, 4]).astype(np.float32)
  jax_out = test_utils.apply(layer, initial_vars, layer.fprop, inputs)
  logging.info('jax_output: %s', jax_out)

  # Now run TF based computation.
  tf_layer_p = batch_major_attention.PerDimScaleLayer.Params().Set(
      name='scale', dim=4)
  tf_layer = tf_layer_p.Instantiate()
  tf_output1 = tf_layer.FProp(tf_layer.theta, inputs)
  logging.info('tf_output1: %s', tf_output1)
  tf_output2 = tf_layer.FProp(initial_vars, inputs)
  logging.info('tf_output2: %s', tf_output2)
  self.assertAllClose(
      test_utils.to_np(jax_out), test_utils.to_np(tf_output2))
def have_similar_stats(x, y):
  mean1, std1 = var_stats(test_utils.to_np(x))
  mean2, std2 = var_stats(test_utils.to_np(y))
  delta_mean = np.abs(mean1 - mean2)
  delta_std = np.abs(std1 - std2)
  logging.info('mean1: %s, mean2: %s', mean1, mean2)
  logging.info('std1: %s, std2: %s', std1, std2)
  test_case.assertLess(delta_mean, 0.0002)
  test_case.assertLess(delta_std, 0.0002)
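# Note: `var_stats` and `test_case` above are assumed to be defined in the
# enclosing scope of this helper; they are not shown in this excerpt. The
# sketch below (with a hypothetical name, `_var_stats_sketch`) only
# illustrates what `var_stats` is assumed to compute: the mean and standard
# deviation of a flattened array. It is a reference sketch, not part of the
# test file.
def _var_stats_sketch(x):
  # Flatten and compute simple summary statistics.
  x = np.asarray(x)
  return np.mean(x), np.std(x)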
def test_transformer_feedforward(self, activation_function):
  p = transformers.TransformerFeedForward.Params().Set(
      name='ffwd',
      input_dims=8,
      hidden_dims=32,
      activation=activation_function)
  batch_size = 8
  seq_len = 512
  ffwd = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = ffwd.instantiate_variables(prng_key)

  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.zeros([batch_size, seq_len], dtype=np.float32)
  input_paddings = jnp.asarray(npy_paddings)

  with base_layer.JaxContext.new_context(
      prng_key=jax.random.PRNGKey(seed=1234),
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(ffwd, ffwd.vars_to_flax_vars(initial_vars))
    outputs = ffwd.fprop(inputs, input_paddings)
    logging.info('outputs: %s', outputs)

  if activation_function.startswith('GATED_'):
    # The default lingvo layers_with_attention.TransformerFeedForwardLayer
    # does not support gating.
    return

  # Test whether the TF TransformerFeedForwardLayer returns the same output.
  # Modify `initial_vars` to use TF-compatible params.
  tf_initial_vars = test_utils.replace_jax_transformer_ffwd_vars_to_tf(
      initial_vars)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer feedforward layer = %s',
               tf_initial_vars)
  tf_p = layers_with_attention.TransformerFeedForwardLayer.Params().Set(
      name='tf_ffwd',
      input_dim=p.input_dims,
      hidden_dim=p.hidden_dims,
      activation=p.activation)
  tf_ffwd = tf_p.Instantiate()
  tf_output = tf_ffwd.FProp(
      tf_initial_vars,
      tf.constant(npy_inputs, dtype=tf.float32),
      paddings=test_utils.to_tf_nmap(npy_paddings))
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
def test_causal_depthwise_conv1d(self, shape, kernel_size, axis, hidden_dims):
  inputs = np.random.normal(1.5, 2.0, shape).astype(np.float32)
  p = attentions.CausalDepthwiseConv1D.Params().Set(
      name='causal_dconv', kernel_size=kernel_size, hidden_dims=hidden_dims)
  causal_dconv_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = causal_dconv_layer.instantiate_variables(init_key)
  if isinstance(hidden_dims, list):
    kernel_shape = hidden_dims
  else:
    kernel_shape = [hidden_dims]
  for k in range(kernel_size):
    initial_vars[f'dconv_{k}'] = np.ones(kernel_shape)
  jax_dconv_out = test_utils.apply(
      causal_dconv_layer,
      initial_vars,
      causal_dconv_layer.fprop,
      inputs,
      axis=axis)
  jax_np_out = test_utils.to_np(jax_dconv_out)
  # With all-ones kernels, the output equals the sum of the input and its
  # causally shifted copies along `axis`.
  outputs = inputs
  for _ in range(1, kernel_size):
    inputs = attentions.shift_1d(inputs, offset=1, axis=axis)
    outputs += inputs
  self.assertArraysEqual(jax_np_out, outputs)
def test_limited_context_mask_from_padding(self, batch_size, max_length,
                                           left_context, right_context):

  def get_padding_from_length(length):
    idx = np.tile(np.arange(max_length), [batch_size, 1])
    return (idx >= np.expand_dims(length, -1)).astype('float32')

  length = np.random.randint(max_length // 2, max_length, [batch_size])
  padding = jnp.asarray(get_padding_from_length(length))
  result = attentions.limited_context_mask_from_padding(
      padding, left_context, right_context)
  # Build the expected mask by brute force: a non-padded position t1 may
  # attend to non-padded positions in [t1 - left_context + 1,
  # t1 + right_context].
  expect = np.zeros((batch_size, 1, max_length, max_length))
  for b in range(batch_size):
    for t1 in range(max_length):
      if t1 >= length[b]:
        continue
      start_p, end_p = 0, length[b]
      if left_context is not None:
        start_p = max(0, t1 - left_context + 1)
      if right_context is not None:
        end_p = min(length[b], t1 + right_context + 1)
      expect[b, 0, t1, start_p:end_p] = 1.0
  self.assertAllClose(
      test_utils.to_np(result),
      (1.0 - expect) * attentions._get_large_negative_number(jnp.float32))
def test_rotary_position_embedding_layer_no_prefix(self, min_timescale,
                                                   max_timescale):
  embedding_dims = 32
  p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
  output = test_utils.apply(
      pos_layer, initial_vars, pos_layer.fprop, inputs=inputs)
  # Test whether extend_step returns same output.
  for i in range(inputs.shape[1]):
    jax_extend_step_out = test_utils.apply(
        pos_layer,
        initial_vars,
        pos_layer.extend_step,
        inputs[:, i, :, :],
        time_step=i)
    jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
    jax_fprop_slice = output[:, i, :, :]
    self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
def test_rotary_position_embedding_layer_2d(self, position):
  embedding_dims = 2
  min_timescale = 1
  max_timescale = 1e4
  p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  inputs = np.random.normal(1.5, 2.5, (1, 4, 1, embedding_dims))
  if position is None:
    position = jnp.arange(4, dtype=jnp.float32)
  position = jnp.array(position)
  output = test_utils.apply(
      pos_layer,
      initial_vars,
      pos_layer.fprop,
      inputs=inputs,
      position=position[jnp.newaxis, :])
  np_output = test_utils.to_np(output)
  # With embedding_dims=2 the rotary embedding reduces to a plain 2-D
  # rotation of each input pair by the position angle.
  sinusoid_inp = position
  sin = jnp.sin(sinusoid_inp)
  cos = jnp.cos(sinusoid_inp)
  first_part = inputs[0, :, 0, 0] * cos - inputs[0, :, 0, 1] * sin
  second_part = inputs[0, :, 0, 1] * cos + inputs[0, :, 0, 0] * sin
  expected_output = np.stack([first_part, second_part], axis=-1)
  self.assertArraysEqual(np_output[0, :, 0, :], expected_output)
def test_mhd_projection_02(self, use_nhd_shape):
  test_layer_p = attentions.AttentionProjection.Params().Set(
      name='mh',
      input_dim=16,
      num_heads=2,
      dim_per_head=5,
      is_output_projection=True,
      use_nhd_shape=use_nhd_shape,
  )
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)
  logging.info('initial_vars: %s', initial_vars)

  inputs = np.random.normal(1.5, 2.0, [5, 2, 5]).astype(np.float32)
  jax_out = test_utils.apply(layer, initial_vars, layer.fprop, inputs)
  logging.info('jax_output: %s', jax_out)

  if use_nhd_shape:
    initial_vars.w = np.einsum('ABC->CAB', initial_vars.w)

  # Now run TF based computation.
  tf_layer_p = batch_major_attention.MultiHeadedProjectionLayer.Params().Set(
      name='mh',
      input_dim=16,
      num_heads=2,
      dim_per_head=5,
      is_output_projection=True)
  tf_layer = tf_layer_p.Instantiate()
  tf_output1 = tf_layer.FProp(tf_layer.theta, inputs)
  logging.info('tf_output1: %s', tf_output1)
  tf_output2 = tf_layer.FProp(initial_vars, inputs)
  logging.info('tf_output2: %s', tf_output2)
  # The randomly initialized TF weights should give a different output than
  # the copied JAX weights, while the copied weights should match JAX.
  self.assertGreater(
      np.sum(
          np.abs(
              test_utils.to_np(tf_output1) - test_utils.to_np(tf_output2))),
      0.1)
  self.assertAllClose(
      test_utils.to_np(jax_out), test_utils.to_np(tf_output2))
def test_transformer_relative_bias(self, use_relative_bias):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=True,
      packed_input=True,
      cross_attention=False)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  if use_relative_bias:
    p.tr_atten_tpl.relative_bias_tpl = attentions.RelativeBias.Params().Set(
        relative_attention_num_buckets=2, relative_attention_max_distance=8)
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(attention_mask, causal_mask)
  segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
  segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  attention_mask = jnp.minimum(attention_mask, segment_mask)

  if use_relative_bias:
    segment_pos = np.random.randint(0, seq_len,
                                    [batch_size, seq_len]).astype('int32')
    segment_pos = jnp.asarray(segment_pos)
  else:
    segment_pos = None

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    outputs, _ = transformer_layer.fprop(
        inputs,
        paddings,
        attention_mask=attention_mask,
        segment_pos=segment_pos)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  np_outputs = test_utils.to_np(outputs)
  logging.info('np_outputs: %s', np_outputs)
  if use_relative_bias:
    self.assertAlmostEqual(np_outputs[0, 0, 1], 0.79015386, places=5)
    self.assertAlmostEqual(np_outputs[0, 1, 0], 0.48336178, places=5)
  # Plumbing test.
  self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
def test_causal_depthwise_conv1d_extend_step(self, shape, kernel_size, axis,
                                             hidden_dims):
  inputs = np.random.normal(1.5, 2.0, shape).astype(np.float32)
  p = attentions.CausalDepthwiseConv1D.Params().Set(
      name='causal_dconv', kernel_size=kernel_size, hidden_dims=hidden_dims)
  causal_dconv_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = causal_dconv_layer.instantiate_variables(init_key)
  prng_key, compute_key = jax.random.split(prng_key)
  global_step = jnp.array(0, dtype=jnp.uint64)
  with base_layer.JaxContext.new_context(
      prng_key=compute_key, global_step=global_step) as jax_context:
    jax_context.bind(causal_dconv_layer,
                     causal_dconv_layer.vars_to_flax_vars(initial_vars))
    jax_dconv_out = causal_dconv_layer.fprop(inputs, axis=axis)
    jax_np_out = test_utils.to_np(jax_dconv_out)
    jax_extend_step_out = jnp.zeros_like(jax_dconv_out)
    for i in range(shape[1]):
      jax_extend_step_out = causal_dconv_layer.extend_step(
          inputs, axis=axis, step=i)
      jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
      jax_extend_step_out_tensor = causal_dconv_layer.extend_step(
          inputs, axis=axis, step=jnp.array(i))
      jax_np_extend_step_out_tensor = test_utils.to_np(
          jax_extend_step_out_tensor)
      jax_fprop_slice = jax.lax.dynamic_slice_in_dim(
          jax_np_out, start_index=i, slice_size=1, axis=axis)
      jax_fprop_slice = jnp.squeeze(jax_fprop_slice, axis)
      self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
      self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out_tensor)
def test_rotary_position_embedding_layer_prefix(self, min_timescale,
                                                max_timescale, window_size):
  embedding_dims = 32
  p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
  output = test_utils.apply(
      pos_layer, initial_vars, pos_layer.fprop, inputs=inputs)
  # Test whether extend_step returns same output.
  for i in range(inputs.shape[1]):
    start = max(0, i + 1 - window_size)
    end = i + 1
    inputs_prefix = inputs[:, start:end, :, :]
    pad_width = window_size - end + start
    paddings = [(0, 0), (pad_width, 0), (0, 0), (0, 0)]
    inputs_prefix = jnp.pad(inputs_prefix, paddings)
    jax_extend_step_out = test_utils.apply(
        pos_layer,
        initial_vars,
        pos_layer.extend_step,
        inputs_prefix,
        time_step=i)
    jax_extend_step_out = jax.lax.dynamic_slice_in_dim(
        jax_extend_step_out,
        start_index=window_size - 1,
        slice_size=1,
        axis=1)
    jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
    jax_fprop_slice = jax.lax.dynamic_slice_in_dim(
        output, start_index=i, slice_size=1, axis=1)
    self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
def test_transformer_layer_norm_policies(self, norm_policy):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=True,
      packed_input=True,
      cross_attention=False,
      norm_policy=norm_policy)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(attention_mask, causal_mask)
  segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
  segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  attention_mask = jnp.minimum(attention_mask, segment_mask)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    outputs, _ = transformer_layer.fprop(
        inputs, paddings, attention_mask=attention_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  np_outputs = test_utils.to_np(outputs)
  # Plumbing test.
  self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
def test_glam_unitransformer(self):
  batch = 2
  length = 3
  d_model = 6
  num_heads = 2
  vocab_size = 16
  ff_dim = 8
  c_dim = 3
  e_dim = 2
  num_layers = 4
  # Build jax layer.
  jax_p = transformer_models.TransformerLm.GLaMUniTransformerParams(
      name='model',
      vocab_size=vocab_size,
      num_transformer_layers=num_layers,
      moe=True,
      model_dim=d_model,
      ff_dim=ff_dim,
      moe_hidden_dim=ff_dim,
      attention_num_heads=num_heads,
      attention_key_value_dim=d_model // num_heads,
      attention_extra_logit=0.0,
      use_tgt_labels_size_as_loss_denominator=True,
      moe_load_balance_loss_weight=0.01,
      z_loss_weight=1e-4,
      c_dim=c_dim,
      e_dim=e_dim)
  assert jax_p.packed_input
  jax_layer = jax_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=42)
  jax_vars = jax_layer.instantiate_variables(prng_key)

  # Build the equivalent TF GShard UniTransformer.
  builder_p = gshard_builder.DenseBuilder.Params().Set(
      num_groups=1,
      second_expert_policy='all',
      relative_attention_type='bias',
      model_dim=d_model,
      attention_key_value_dim=d_model // num_heads,
      attention_num_heads=num_heads,
      attention_combine_dims=True,
      c_dim=c_dim,
      capacity_factor=None,
      attention_extra_logit=0.0,
      e_dim=e_dim,
      moe_hidden_dim=ff_dim,
      ff_dim=ff_dim)
  tf_layer = gshard_builder.UniTransformer.Params().Set(
      name='model',
      num_transformer_layers=num_layers,
      builder=builder_p,
      vocab_size=vocab_size,
      sequence_length=length,
      label_smoothing=0,
      aux_loss_coef=0.01,
      z_loss=1e-4,
      use_tgt_labels_size_as_loss_denominator=True,
      positional_embedding=False,
      gated_gelu=True,
      moe=True).Instantiate()

  # Build Jax inputs.
  np.random.seed(42)
  npy_ids = np.random.randint(0, vocab_size - 1, [batch, length])
  jax_ids = jnp.asarray(npy_ids)
  npy_paddings = np.array([[0, 0, 1], [0, 0, 1]], dtype=np.float32)
  jax_paddings = jnp.asarray(npy_paddings)
  npy_segment_ids = np.array([[1, 2, 0], [1, 1, 0]], dtype=np.int32)
  npy_segment_pos = np.array([[0, 0, 0], [0, 1, 0]], dtype=np.int32)
  npy_labels = np.roll(npy_ids, -1, axis=1)
  jax_labels = jnp.asarray(npy_labels)
  jax_seg_ids = jnp.asarray(npy_segment_ids)
  jax_seg_pos = jnp.asarray(npy_segment_pos)
  jax_label_weights = jnp.asarray([[1, 1, 0], [1, 1, 0]])

  # Build TF inputs.
  tf_tgt_inputs = py_utils.NestedMap(
      ids=tf.convert_to_tensor(npy_ids, dtype=tf.int32),
      labels=tf.convert_to_tensor(npy_labels, dtype=tf.int32),
      segment_ids=tf.convert_to_tensor(npy_segment_ids, dtype=tf.int32),
      segment_pos=tf.convert_to_tensor(npy_segment_pos, dtype=tf.int32))
  tf_inputs = py_utils.NestedMap(tgt=tf_tgt_inputs)

  # Compute jax outputs.
  jax_outputs = test_utils.apply(
      jax_layer,
      jax_vars,
      jax_layer.fprop,
      jax_ids,
      jax_paddings,
      context_p=None,
      labels=py_utils.NestedMap(
          class_ids=jax_labels,
          class_weights=jax_label_weights,
      ),
      segment_ids=jax_seg_ids,
      segment_pos=jax_seg_pos)

  # Copy jax vars to tf ones.
  tf_theta = tf_layer.theta.DeepCopy()

  # GShardBuilder softmax weights use self.vars rather than theta.
  tf_layer.vars.dec_emb.w.embedding.assign(jax_vars.softmax.embedding.w)
  tf_theta.dec_emb.w.embedding = jax_vars.softmax.embedding.w
  tf_theta.dec.final_layer_norm.w.scale = jax_vars.final_ln.scale

  # The JAX repeated layers are stacked along a leading dimension; split each
  # stacked var in two and copy the slices into the corresponding TF layers.
  jax_layer_0_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[0])
  tf_theta.dec.layer_000.ln.w.scale = jax_layer_0_var.layer_norm.scale
  jax_atten_var = jax_layer_0_var.self_attention
  tf_atten_var = tf_theta.dec.layer_000.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_moe_var = jax_layer_0_var.ff_layer
  tf_theta.dec.layer_001.ln.w.scale = jax_moe_var.layer_norm.scale
  tf_theta.dec.layer_001.moe.ffw.top_2_gating.w = jax_moe_var.gate
  tf_theta.dec.layer_001.moe.moe.wi = jax_moe_var.wi_0
  tf_theta.dec.layer_001.moe.moe.wo = jax_moe_var.wo_0

  jax_layer_1_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[1])
  tf_theta.dec.layer_002.ln.w.scale = jax_layer_1_var.layer_norm.scale
  jax_atten_var = jax_layer_1_var.self_attention
  tf_atten_var = tf_theta.dec.layer_002.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_ffn_var = jax_layer_1_var.ff_layer
  tf_ffn_var = tf_theta.dec.layer_003.dense_relu_dense
  tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
  tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
  tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
  tf_theta.dec.layer_003.ln.w.scale = jax_ffn_var.layer_norm.scale

  jax_layer_2_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[0])
  tf_theta.dec.layer_004.ln.w.scale = jax_layer_2_var.layer_norm.scale
  jax_atten_var = jax_layer_2_var.self_attention
  tf_atten_var = tf_theta.dec.layer_004.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_moe_var = jax_layer_2_var.ff_layer
  tf_theta.dec.layer_005.ln.w.scale = jax_moe_var.layer_norm.scale
  tf_theta.dec.layer_005.moe.ffw.top_2_gating.w = jax_moe_var.gate
  tf_theta.dec.layer_005.moe.moe.wi = jax_moe_var.wi_0
  tf_theta.dec.layer_005.moe.moe.wo = jax_moe_var.wo_0

  jax_layer_3_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[1])
  tf_theta.dec.layer_006.ln.w.scale = jax_layer_3_var.layer_norm.scale
  jax_atten_var = jax_layer_3_var.self_attention
  tf_atten_var = tf_theta.dec.layer_006.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_ffn_var = jax_layer_3_var.ff_layer
  tf_ffn_var = tf_theta.dec.layer_007.dense_relu_dense
  tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
  tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
  tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
  tf_theta.dec.layer_007.ln.w.scale = jax_ffn_var.layer_norm.scale

  tf_theta = test_utils.to_tf_nmap(tf_theta)

  # Compute TF outputs.
  tf_out, _ = tf_layer.FProp(tf_theta, tf_inputs)
  self.assertAllClose(
      test_utils.to_np(jax_outputs.total_loss),
      test_utils.to_np(tf_out['loss'][0]))
def test_transformer_moe_dense_layer(self, mask_self_attention, packed_input,
                                     cross_attention):
  # Compare a repeated (scanned) block of layers against the regular
  # layer-by-layer stack.
  block_p = transformers.StackedTransformer.Params().Set(
      name='transformer_block',
      num_layers=2,
      model_dims=3,
      hidden_dims=6,
      num_heads=1,
      mask_self_attention=mask_self_attention,
      packed_input=packed_input,
      cross_attention=cross_attention,
      num_experts=4,
      num_groups=1,
      moe_layers=[0])
  block_p_repeated = transformers.StackedTransformerRepeated.Params().Set(
      name='stacked_transformer_layer_repeated',
      block=block_p.Copy(),
      x_times=1)
  stack_p = transformers.StackedTransformer.Params().Set(
      name='transformer_stack',
      num_layers=2,  # moe + dense
      model_dims=block_p.model_dims,
      hidden_dims=block_p.hidden_dims,
      num_heads=block_p.num_heads,
      mask_self_attention=block_p.mask_self_attention,
      packed_input=block_p.packed_input,
      cross_attention=block_p.cross_attention,
      num_experts=block_p.num_experts,
      num_groups=block_p.num_groups,
      moe_layers=[0])
  moe_p = stack_p.moe_layer_tpl
  moe_p.expert_capacity_dim = 2
  moe_p.expert_capacity_factor = 0
  moe_p = block_p.moe_layer_tpl
  moe_p.expert_capacity_dim = 2
  moe_p.expert_capacity_factor = 0

  transformer_block = block_p_repeated.Instantiate()
  transformer_stack = stack_p.Instantiate()

  seq_len = 4
  batch_size = 3
  prng_key = jax.random.PRNGKey(seed=123)
  block_initial_vars = transformer_block.instantiate_variables(prng_key)
  stack_initial_vars = transformer_stack.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, block_p.model_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  segment_mask = None
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  cross_inputs = None
  cross_paddings = None
  cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 64)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5,
        [batch_size, cross_seq_len, block_p.model_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)

  block_outputs = test_utils.apply(
      transformer_block,
      block_initial_vars,
      transformer_block.fprop,
      inputs,
      paddings,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  stack_outputs = test_utils.apply(
      transformer_stack,
      stack_initial_vars,
      transformer_stack.fprop,
      inputs,
      paddings,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  _ = test_utils.to_np(block_outputs)
  _ = test_utils.to_np(stack_outputs)
def test_transformer_layer(self, mask_self_attention, packed_input,
                           cross_attention):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=32,
      hidden_dims=128,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      packed_input=packed_input,
      cross_attention=cross_attention)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  causal_mask = None
  segment_mask = None
  tf_segment_mask = None
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  if mask_self_attention:
    causal_mask = attentions.causal_mask(inputs)
    attention_mask = jnp.minimum(attention_mask, causal_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)
    if mask_self_attention:
      tf_segment_mask = batch_major_attention.CausalSegmentMask(
          segment_ids, tf.float32)
    else:
      tf_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids)

  cross_inputs = None
  cross_attention_mask = None
  tf_cross_inputs = None
  tf_cross_paddings = None
  tf_cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 128)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.input_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    cross_attention_mask = attentions.convert_paddings_to_mask(cross_paddings)
    tf_cross_paddings = tf.constant(npy_cross_paddings, dtype=tf.float32)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      cross_attention_mask = jnp.minimum(cross_attention_mask,
                                         cross_segment_mask)
      tf_cross_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, source_segment_ids)

  outputs, _ = test_utils.apply(
      transformer_layer,
      initial_vars,
      transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      attention_mask=attention_mask,
      cross_inputs=cross_inputs,
      cross_attention_mask=cross_attention_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  # Test whether the TF Transformer layer returns the same output.
  # Modify initial_vars to use TF-compatible params.
  tf_initial_vars = test_utils.replace_jax_attention_vars_to_tf(
      initial_vars, cross_attention)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer layer = %s', tf_initial_vars)
  tf_p = batch_major_attention.TransformerLayer.Params().Set(
      name='tf_transformer_layer',
      input_dim=p.input_dims,
      num_heads=p.num_heads,
      mask_self_atten=mask_self_attention,
      packed_input=packed_input,
      has_aux_atten=cross_attention)
  tf_p.tr_fflayer_tpl.hidden_dim = p.hidden_dims
  tf_p.tr_fflayer_tpl.fflayer_tpl.batch_norm = False
  tf_p.tr_fflayer_tpl.fflayer_tpl.has_bias = True
  tf_transformer_layer = tf_p.Instantiate()
  tf_output, _ = tf_transformer_layer.FProp(
      tf_initial_vars,
      tf.constant(npy_inputs, dtype=tf.float32),
      paddings=test_utils.to_tf_nmap(npy_paddings),
      segment_mask=tf_segment_mask,
      aux_vec=tf_cross_inputs,
      aux_paddings=tf_cross_paddings,
      aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
def test_stacked_transformer_layer(self, mask_self_attention, packed_input,
                                   cross_attention):
  p = transformers.StackedTransformer.Params().Set(
      name='jax_stacked_transformer_layer',
      model_dims=16,
      hidden_dims=64,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      num_layers=4,
      packed_input=packed_input,
      cross_attention=cross_attention)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  stacked_transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = stacked_transformer_layer.instantiate_variables(prng_key)

  # Test conversion between vars and flax vars.
  pax_vars = stacked_transformer_layer.vars
  flax_vars = stacked_transformer_layer.flax_vars
  tf.nest.assert_same_structure(
      pax_vars, stacked_transformer_layer.flax_vars_to_vars(flax_vars))
  tf.nest.assert_same_structure(
      flax_vars, stacked_transformer_layer.vars_to_flax_vars(pax_vars))

  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.model_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  segment_mask = None
  tf_segment_mask = None
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    if mask_self_attention:
      tf_segment_mask = batch_major_attention.CausalSegmentMask(
          segment_ids, tf.float32)
    else:
      tf_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids)

  cross_inputs = None
  cross_paddings = None
  cross_segment_mask = None
  tf_cross_inputs = None
  tf_cross_paddings = None
  tf_cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 64)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.model_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    tf_cross_paddings = tf.constant(npy_cross_paddings, dtype=tf.float32)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      tf_cross_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, source_segment_ids)

  outputs = test_utils.apply(
      stacked_transformer_layer,
      initial_vars,
      stacked_transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  # Test whether the TF StackedTransformerLayers returns the same output.
  # Modify initial_vars to use TF-compatible params.
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.x_layers = []
  for jax_initial_vars in initial_vars.x_layers:
    tf_layer_vars = test_utils.replace_jax_attention_vars_to_tf(
        jax_initial_vars, cross_attention)
    tf_initial_vars.x_layers.append(tf_layer_vars)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer layer = %s', tf_initial_vars)
  tf_p = batch_major_attention.StackedTransformerLayers.Params().Set(
      name='tf_transformer_layer',
      mdl_dim=p.model_dims,
      hidden_dim=p.hidden_dims,
      num_atten_heads=p.num_heads,
      mask_self_atten=mask_self_attention,
      num_layers=p.num_layers,
      packed_input=packed_input,
      has_aux_atten=cross_attention)
  tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.batch_norm = (
      False)
  tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.has_bias = True
  tf_stacked_transformer_layer = tf_p.Instantiate()
  tf_output, _ = tf_stacked_transformer_layer.FProp(
      tf_initial_vars,
      test_utils.to_tf_nmap(npy_inputs),
      paddings=test_utils.to_tf_nmap(npy_paddings),
      segment_mask=test_utils.to_tf_nmap(tf_segment_mask),
      aux_vec=test_utils.to_tf_nmap(tf_cross_inputs),
      aux_paddings=test_utils.to_tf_nmap(tf_cross_paddings),
      aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
def test_stacked_transformer_layer_extendstep(self, packed_input,
                                              cross_attention, combine_qkv,
                                              dconv_qkv,
                                              use_rotary_position_emb):
  if cross_attention and combine_qkv:
    self.skipTest('combine_qkv optimization only works for self-attention.')
  layer_params = transformers.StackedTransformer.Params()

  num_layers = 2
  model_dims = 8
  p = layer_params.Set(
      name='jax_transformer_layer',
      model_dims=model_dims,
      hidden_dims=32,
      num_heads=2,
      mask_self_attention=True,
      packed_input=packed_input,
      cross_attention=cross_attention,
      num_layers=num_layers)
  p.transformer_layer_params_tpl.tr_atten_tpl.combine_qkv = combine_qkv
  p.transformer_layer_params_tpl.tr_atten_tpl.dconv_qkv = dconv_qkv
  p.transformer_layer_params_tpl.tr_atten_tpl.use_rotary_position_emb = (
      use_rotary_position_emb)
  if cross_attention:
    p.transformer_layer_params_tpl.cross_atten_tpl = (
        p.transformer_layer_params_tpl.tr_atten_tpl.Copy())
    # Cross attention should not have depth-wise convolution.
    p.transformer_layer_params_tpl.cross_atten_tpl.dconv_qkv = False
    # Cross attention should not have rotary position embedding.
    p.transformer_layer_params_tpl.cross_atten_tpl.use_rotary_position_emb = (
        False)

  p_copy = p.Copy()
  p_copy.num_layers = 1
  p = transformers.StackedTransformerRepeated.Params()
  p.name = 'jax_transformer_repeated_layer'
  p.block = p_copy
  p.x_times = num_layers

  seq_len = 4
  batch_size = 4
  stacked_transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = stacked_transformer_layer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, model_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  segment_mask = None
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  cross_inputs = None
  cross_paddings = None
  cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 32)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, model_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)

  prng_key = jax.random.PRNGKey(seed=123)
  global_step = jnp.array(0, dtype=jnp.uint64)
  with base_layer.JaxContext.new_context(
      prng_key=prng_key, global_step=global_step) as jax_context:
    jax_context.bind(
        stacked_transformer_layer,
        stacked_transformer_layer.vars_to_flax_vars(initial_vars))
    fprop_outputs = stacked_transformer_layer.fprop(
        inputs,
        paddings,
        segment_mask=segment_mask,
        cross_inputs=cross_inputs,
        cross_paddings=cross_paddings,
        cross_segment_mask=cross_segment_mask)
    decoder_outputs = jnp.zeros(shape=[seq_len, batch_size, model_dims])
    initial_states = stacked_transformer_layer.init_states(
        batch_size, seq_len)
    atten_states = initial_states
    for t in range(seq_len):
      segment_mask_t = attention_mask[:, :, t, :]
      cross_segment_mask_t = cross_segment_mask
      if segment_mask is not None:
        segment_mask_t = jnp.minimum(segment_mask_t, segment_mask[:, :, t, :])
      if cross_segment_mask is not None:
        cross_segment_mask_t = cross_segment_mask[:, :, t, :]
      atten_states, encoded = stacked_transformer_layer.extend_step(
          atten_states,
          inputs=inputs[:, t, :],
          time_step=t,
          segment_mask=segment_mask_t,
          cross_inputs=cross_inputs,
          cross_paddings=cross_paddings,
          cross_segment_mask=cross_segment_mask_t)
      decoder_outputs = decoder_outputs.at[t].set(encoded)

  decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 0, 2])
  # TODO(lepikhin): remove noisy test logging
  # logging.info('initial_vars in transformer layer = %s', initial_vars)
  np_fprop_outputs = test_utils.to_np(fprop_outputs)
  np_decoder_outputs = test_utils.to_np(decoder_out_transposed)
  self.assertAllClose(np_fprop_outputs, np_decoder_outputs, atol=1e-5)
def test_mha_02(self):
  mdl_dim = 16
  hidden_dim = 32
  num_heads = 4
  test_layer_p = attentions.DotProductAttention.Params().Set(
      name='mh',
      input_dim=mdl_dim,
      hidden_dim=hidden_dim,
      num_heads=num_heads,
      atten_logit_cap=20.0,
  )
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)

  target_batch_size = 3
  source_max_length = 8
  target_max_length = 8
  query_vec = np.random.normal(
      size=[target_batch_size, source_max_length, mdl_dim]).astype(np.float32)
  key_vec = np.random.normal(
      size=[target_batch_size, source_max_length, mdl_dim]).astype(np.float32)
  value_vec = np.random.normal(
      size=[target_batch_size, source_max_length, mdl_dim]).astype(np.float32)
  segment_ids = np.random.random_integers(
      0, 1, size=[target_batch_size, target_max_length]).astype(np.int32)
  atten_mask = attentions.causal_segment_mask(segment_ids, np.float32)

  jax_fprop_out, jax_atten_prob = test_utils.apply(
      layer, initial_vars, layer.fprop, query_vec, key_vec, value_vec,
      atten_mask)

  tf_layer_p = batch_major_attention.MultiHeadedAttention.Params().Set(
      name='mh',
      input_dim=mdl_dim,
      hidden_dim=hidden_dim,
      num_heads=num_heads,
      atten_logit_cap=20.0,
      packed_input=True)
  tf_layer = tf_layer_p.Instantiate()
  tf_out, tf_atten_prob = tf_layer.FProp(
      initial_vars,
      query_vec,
      key_vec,
      value_vec,
      paddings=tf.zeros([target_batch_size, source_max_length]),
      segment_mask=atten_mask)

  logging.info('jax_layer_out: %s', jax_fprop_out)
  logging.info('jax_atten_probs: %s', jax_atten_prob)
  logging.info('tf_layer_out: %s', tf_out)
  logging.info('tf_atten_probs: %s', tf_atten_prob)
  self.assertAllClose(
      test_utils.to_np(jax_fprop_out), test_utils.to_np(tf_out))
  self.assertAllClose(
      test_utils.to_np(jax_atten_prob), test_utils.to_np(tf_atten_prob))
def test_transformer_layer_cross_attention_ln(self, packed_input):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=8,
      hidden_dims=32,
      num_heads=4,
      mask_self_attention=True,
      packed_input=packed_input,
      cross_attention=True)
  seq_len = 5
  batch_size = 4
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  # Change the self attention initial vars.
  initial_vars.layer_norm.scale = 0.5
  initial_vars.layer_norm.bias = 5.0
  # Change the cross attention initial vars.
  initial_vars.cross_layer_norm.scale = 15
  initial_vars.cross_layer_norm.bias = 1.5
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(causal_mask, attention_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    inputs_normalized = transformer_layer.layer_norm.fprop(inputs)
    # Compute self-attention; key/value vectors are the input itself.
    atten_output, _ = transformer_layer.self_attention.fprop(
        inputs_normalized,
        inputs_normalized,
        inputs_normalized,
        atten_mask=attention_mask)
    # Residual dropout and connection.
    atten_output = transformer_layer.residual_dropout.fprop(atten_output)
    atten_output += inputs
    # Normalize atten outputs using cross attention.
    atten_output_normalized = transformer_layer.cross_layer_norm.fprop(
        atten_output)

  inputs_normalized = test_utils.to_np(inputs_normalized)
  atten_output_normalized = test_utils.to_np(atten_output_normalized)
  self.assertAllClose(
      initial_vars.layer_norm.bias, inputs_normalized.mean(), atol=1e-3)
  self.assertAllClose(
      (1.0 + initial_vars.layer_norm.scale)**2,
      np.var(inputs_normalized),
      atol=5e-3)
  self.assertAllClose(
      initial_vars.cross_layer_norm.bias,
      atten_output_normalized.mean(),
      atol=1e-3)
  self.assertAllClose(
      (1.0 + initial_vars.cross_layer_norm.scale)**2,
      np.var(atten_output_normalized),
      atol=5e-3)
def test_mask(self):
  a = np.random.random_integers(0, 5, size=[2, 50])
  jax_mask = attentions.causal_segment_mask(a, jnp.float32)
  tf_mask = batch_major_attention.CausalSegmentMask(a, tf.float32)
  self.assertAllClose(test_utils.to_np(jax_mask), test_utils.to_np(tf_mask))
def test_transformer_layer_extendstep(self, packed_input, cross_attention,
                                      dconv_qkv, use_rotary_position_emb):
  p = transformers.Transformer.Params().Set(
      name='jax_transformer_layer',
      input_dims=8,
      hidden_dims=32,
      num_heads=4,
      mask_self_attention=True,
      packed_input=packed_input,
      cross_attention=cross_attention)
  p.tr_atten_tpl.dconv_qkv = dconv_qkv
  p.tr_atten_tpl.use_rotary_position_emb = use_rotary_position_emb
  if cross_attention:
    p.cross_atten_tpl = p.tr_atten_tpl.Copy()
    # Cross attention should not have depth-wise convolution.
    p.cross_atten_tpl.dconv_qkv = False
    # Cross attention should not have rotary position embedding.
    p.cross_atten_tpl.use_rotary_position_emb = False
  p.tr_atten_tpl.dconv_kernel_size = 2
  seq_len = 4
  batch_size = 4
  transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = transformer_layer.instantiate_variables(prng_key)
  initial_states = transformer_layer.init_states(batch_size, seq_len)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  # npy_paddings = np.zeros([batch_size, seq_len])
  paddings = jnp.asarray(npy_paddings)
  attention_mask = attentions.convert_paddings_to_mask(paddings)
  segment_mask = None
  causal_mask = attentions.causal_mask(inputs)
  attention_mask = jnp.minimum(causal_mask, attention_mask)
  if packed_input:
    segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    attention_mask = jnp.minimum(attention_mask, segment_mask)
  cross_inputs = None
  cross_paddings = None
  cross_attention_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 32)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.input_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    cross_attention_mask = attentions.convert_paddings_to_mask(cross_paddings)
    if packed_input:
      source_segment_ids = np.random.random_integers(
          0, 2, [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      cross_attention_mask = jnp.minimum(cross_attention_mask,
                                         cross_segment_mask)

  with base_layer.JaxContext.new_context(
      prng_key=prng_key,
      global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
    jax_context.bind(transformer_layer,
                     transformer_layer.vars_to_flax_vars(initial_vars))
    fprop_outputs, _ = transformer_layer.fprop(
        inputs,
        paddings,
        attention_mask=attention_mask,
        cross_inputs=cross_inputs,
        cross_attention_mask=cross_attention_mask)
    decoder_outputs = jnp.zeros(shape=[seq_len, batch_size, p.input_dims])
    atten_states = initial_states
    for t in range(seq_len):
      attention_mask_t = attention_mask[:, :, t, :]
      cross_attention_mask_t = cross_attention_mask
      if cross_attention:
        cross_attention_mask_t = cross_attention_mask[:, :, t, :]
        cross_attention_mask_t = np.expand_dims(
            cross_attention_mask_t, axis=2)
      atten_states, encoded = transformer_layer.extend_step(
          atten_states,
          inputs=inputs[:, t, :],
          time_step=t,
          attention_mask=attention_mask_t,
          cross_inputs=cross_inputs,
          cross_attention_mask=cross_attention_mask_t)
      decoder_outputs = decoder_outputs.at[t].set(encoded)

  decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 0, 2])
  logging.info('initial_vars in transformer layer = %s', initial_vars)
  np_fprop_outputs = test_utils.to_np(fprop_outputs)
  np_decoder_outputs = test_utils.to_np(decoder_out_transposed)
  self.assertAllClose(np_fprop_outputs, np_decoder_outputs, atol=1e-5)