def _testUnstack(self, inputs, **kwargs):
  """Checks that unstack() recovers the original inputs after fprop."""
  params = linears.StackingOverTime.Params().Set(
      name='stackingOverTime', **kwargs)
  stacker = params.Instantiate()
  stacker_vars = None
  stacked, _ = test_utils.apply(stacker, stacker_vars, stacker.fprop, inputs)
  unstacked = test_utils.apply(stacker, stacker_vars, stacker.unstack, stacked)
  print(f'{unstacked}')
  batch, input_length, depth = inputs.shape
  stacked_length = stacked.shape[1]
  stride = stacker.params.stride
  right_context = stacker.params.right_context
  self.assertAllClose(
      unstacked.shape,
      [batch, (stacked_length - 1) * stride + right_context + 1, depth])
  if right_context + 1 >= stride:
    self.assertGreaterEqual(unstacked.shape[1], input_length)
    self.assertAllClose(inputs, unstacked[:, :input_length])
  else:
    self.assertLessEqual(unstacked.shape[1], input_length)
    # The final up to stride - right_context - 1 values are missing.
    self.assertLessEqual(input_length - unstacked.shape[1],
                         stride - right_context - 1)
    self.assertAllClose(inputs[:, :unstacked.shape[1]], unstacked)
def test_rotary_position_embedding_layer_no_prefix(self, min_timescale,
                                                   max_timescale):
  embedding_dims = 32
  p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
  output = test_utils.apply(
      pos_layer, initial_vars, pos_layer.fprop, inputs=inputs)
  # Test whether extend_step returns same output.
  for i in range(inputs.shape[1]):
    jax_extend_step_out = test_utils.apply(
        pos_layer,
        initial_vars,
        pos_layer.extend_step,
        inputs[:, i, :, :],
        time_step=i)
    jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
    jax_fprop_slice = output[:, i, :, :]
    self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
def test_ngrammer_layer_exact_bigram_2d(self, unigram_vocab_size,
                                        ngram_emb_dim, num_heads,
                                        dim_per_head, concat_ngrams):
  batch_size = 2
  seq_len = 8
  inputs = np.random.randint(
      unigram_vocab_size, size=[batch_size, seq_len], dtype=np.int32)
  paddings = np.random.randint(1, size=[batch_size, seq_len])
  input_embs = np.random.normal(
      1.5, 2.0, (batch_size, seq_len, num_heads * dim_per_head))
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  ngrammer_layer_p = ngrammer.Ngrammer.Params().Set(
      name='jax_ngrammer_layer',
      unigram_vocab_size=unigram_vocab_size,
      ngram_vocab_size=num_heads * unigram_vocab_size**2,
      ngram_emb_dim=ngram_emb_dim,
      num_heads=num_heads,
      dim_per_head=dim_per_head,
      concat_ngrams=concat_ngrams,
  )
  ngrammer_layer = ngrammer_layer_p.Instantiate()
  initial_vars = ngrammer_layer.instantiate_variables(init_key)
  ngram_embs = test_utils.apply(ngrammer_layer, initial_vars,
                                ngrammer_layer.fprop, inputs, input_embs,
                                paddings)
  ngram_embs = np.reshape(ngram_embs,
                          [batch_size, seq_len, num_heads, dim_per_head])
  input_embs = np.reshape(input_embs,
                          [batch_size, seq_len, num_heads, dim_per_head])
  for i in range(num_heads):
    input_ids_per_head = inputs
    ngram_ids_per_head = ngrammer.get_bigram_ids(input_ids_per_head,
                                                 unigram_vocab_size)
    ngram_ids_per_head *= (i + 1)
    ngram_ids_per_head += (i + 1)
    ngram_embs_expected = test_utils.apply(
        ngrammer_layer.ngram_table[i], initial_vars.ngram_table[i],
        ngrammer_layer.ngram_table[i].fprop,
        np.reshape(ngram_ids_per_head, [-1]))
    ngram_embs_expected = test_utils.apply(
        ngrammer_layer.ngram_layer_norm[i], initial_vars.ngram_layer_norm[i],
        ngrammer_layer.ngram_layer_norm[i].fprop, ngram_embs_expected)
    ngram_embs_expected = jnp.reshape(ngram_embs_expected,
                                      [batch_size, seq_len, ngram_emb_dim])
    ngram_embs_expected *= (1 - paddings[:, :, np.newaxis])
    if concat_ngrams:
      ngram_embs_slice = ngram_embs[:, :, i, -ngram_emb_dim:]
    else:
      input_embs_ln = test_utils.apply(
          ngrammer_layer.emb_layer_norm[i], initial_vars.emb_layer_norm[i],
          ngrammer_layer.emb_layer_norm[i].fprop, input_embs[:, :, i, :])
      ngram_embs_slice = ngram_embs[:, :, i, :] - input_embs_ln
    self.assertAllClose(to_np(ngram_embs_slice), to_np(ngram_embs_expected))
def test_group_norm(self, dim, num_groups, cumulative, input_rank, epsilon,
                    input_shape, input_dtype, paddings, fprop_dtype):
  p = normalizations.GroupNorm.Params().Set(
      name='jax_gn',
      dim=dim,
      num_groups=num_groups,
      cumulative=cumulative,
      input_rank=input_rank,
      epsilon=epsilon,
      fprop_dtype=fprop_dtype)
  group_norm = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123456)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = group_norm.instantiate_variables(init_key)
  npy_input = np.random.normal(1.0, 0.5, input_shape).astype(np.float32)
  inputs = jnp.asarray(npy_input, dtype=input_dtype)
  if paddings is None:
    output = test_utils.apply(
        group_norm, initial_vars, group_norm.fprop, inputs, paddings=None)
  else:
    output, output_paddings = test_utils.apply(
        group_norm,
        initial_vars,
        group_norm.fprop,
        inputs,
        paddings=jnp.asarray(paddings, dtype=input_dtype))
  # Now test whether the tf group norm layer returns the same output.
  tf_p = bn_layers.GroupNormLayer.Params().Set(
      name='tf_gn',
      dim=dim,
      num_groups=num_groups,
      cumulative=cumulative,
      input_rank=input_rank,
      epsilon=epsilon,
      fprop_dtype=_JaxToTfDtype(fprop_dtype))
  tf_group_norm = tf_p.Instantiate()
  tf_inputs = tf.constant(inputs, dtype=_JaxToTfDtype(input_dtype))
  if paddings is None:
    tf_output = tf_group_norm.FProp(initial_vars, tf_inputs, paddings=None)
  else:
    tf_output, tf_output_paddings = tf_group_norm.FProp(
        initial_vars,
        tf_inputs,
        paddings=tf.convert_to_tensor(
            paddings, dtype=_JaxToTfDtype(input_dtype)))
  self.assertAllClose(to_np(tf_output), to_np(output))
  if paddings is not None:
    self.assertAllClose(to_np(tf_output_paddings), to_np(output_paddings))
def test_simple_softmax_layer_class_probs(self, batch_size, num_classes):
  class_probabilities = np.random.normal(1.5, 2.0, [batch_size, num_classes])
  # Normalize class probabilities to be a probability distribution.
  class_probabilities /= np.sum(class_probabilities, axis=-1, keepdims=True)
  p = embedding_softmax.SingleShardFullSoftmax.Params().Set(
      name='jax_softmax', num_classes=num_classes, input_dims=40)
  softmax_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = softmax_layer.instantiate_variables(prng_key)
  npy_input = np.random.normal(1.5, 2.0, [batch_size, p.input_dims])
  inputs = jnp.asarray(npy_input)
  class_weights = np.random.normal(1.5, 2.0, [batch_size, 1])
  logits = test_utils.apply(softmax_layer, initial_vars,
                            softmax_layer.get_logits, inputs)
  outputs = test_utils.apply(
      softmax_layer,
      initial_vars,
      softmax_layer.fprop,
      inputs,
      class_weights,
      class_ids=None,
      class_probabilities=class_probabilities)
  # Test whether tf Softmax layer returns same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = test_utils.replace_jax_simple_full_softmax_vars_to_tf(
      initial_vars)
  # Convert all the values to TF tensor.
  tf_initial_vars = tf.nest.map_structure(tf.convert_to_tensor,
                                          tf_initial_vars)
  tf_p = lingvo_layers.SimpleFullSoftmax.Params().Set(
      name='tf_softmax', num_classes=p.num_classes, input_dim=p.input_dims)
  tf_softmax_layer = tf_p.Instantiate()
  tf_logits = tf_softmax_layer.Logits(
      tf_initial_vars, tf.constant(inputs, dtype=tf.float32))
  tf_output = tf_softmax_layer.FProp(
      tf_initial_vars,
      tf.constant(inputs, dtype=tf.float32),
      class_weights,
      class_ids=None,
      class_probabilities=class_probabilities)
  # Check all entries in the NestedMap and ensure they match TF.
  np_get_logits = to_np(logits)
  tf_np_get_logits = to_np(tf_logits)
  self.assertAllClose(np_get_logits, tf_np_get_logits)
  for k in outputs.keys():
    self.assertAllClose(to_np(outputs[k]), to_np(tf_output[k]))
def test_combine_qkv_with_attention_combine_dims(self):
  input_dim = 64
  dim_per_head = 8
  num_heads = 8
  # Reference combined qkv projection layer.
  ref_proj_p = attentions.CombinedQKVProjectionLayer.Params().Set(
      name='ref',
      input_dim=input_dim,
      dim_per_head=dim_per_head,
      num_heads=num_heads)
  proj = ref_proj_p.Instantiate()
  # Combined qkv projection layer with attention dims combined.
  combine_proj_p = attentions.CombinedQKVProjectionLayer.Params().Set(
      name='combine',
      input_dim=input_dim,
      dim_per_head=dim_per_head,
      num_heads=num_heads,
      attention_combine_dims=True)
  combine_proj = combine_proj_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = proj.instantiate_variables(init_key)
  # Set up initial vars for the projection with combined attention dims.
  combine_initial_vars = combine_proj.instantiate_variables(init_key)
  combine_initial_vars.w = np.reshape(
      initial_vars.w, (3, input_dim, num_heads * dim_per_head))
  combine_initial_vars.b = np.reshape(initial_vars.b,
                                      (3, num_heads * dim_per_head))
  batch_size = 3
  inputs = np.random.normal(size=[batch_size, input_dim]).astype(np.float32)
  prng_key, compute_key = jax.random.split(prng_key)
  global_step = jnp.array(0, dtype=jnp.uint64)
  with base_layer.JaxContext.new_context(
      prng_key=compute_key, global_step=global_step):
    q_proj_ref, k_proj_ref, v_proj_ref = test_utils.apply(
        proj, initial_vars, proj.fprop, inputs)
    q_proj_combine, k_proj_combine, v_proj_combine = test_utils.apply(
        combine_proj, combine_initial_vars, combine_proj.fprop, inputs)
    self.assertAllClose(q_proj_ref, q_proj_combine)
    self.assertAllClose(k_proj_ref, k_proj_combine)
    self.assertAllClose(v_proj_ref, v_proj_combine)
def test_feedforward_layer_no_bias(self, activation):
  p = linears.FeedForward.Params().Set(
      name='jax_ffn',
      input_dims=3,
      output_dims=20,
      has_bias=False,
      activation=activation)
  ffn = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = ffn.instantiate_variables(prng_key)
  npy_input = np.random.normal(1.0, 0.5,
                               [10, 10, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_input)
  outputs = test_utils.apply(ffn, initial_vars, ffn.fprop, inputs)
  logging.info('initial_vars in ffn = %s', initial_vars)
  # Test whether tf projection layer returns same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.w = initial_vars.linear.w
  tf_initial_vars = to_tf_nmap(tf_initial_vars)
  tf_p = lingvo_layers.ProjectionLayer.Params().Set(
      name='tf_ffn',
      input_dim=p.input_dims,
      output_dim=p.output_dims,
      batch_norm=False,
      has_bias=False,
      activation=activation)
  tf_ffn = tf_p.Instantiate()
  tf_output = tf_ffn.FProp(tf_initial_vars,
                           tf.constant(inputs, dtype=tf.float32))
  np_outputs = to_np(outputs)
  tf_np_outputs = to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-6)
def testStackingOverTimeFProp2(self):
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 0
  params.right_context = 1
  params.stride = 2
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 2)
  inputs = np.random.normal(size=[2, 21, 16])
  # poor man's tf.sequence_mask in np.
  mask = np.zeros([2, 21]).astype(np.float32)
  mask[0, :9] = 1.
  mask[1, :14] = 1.
  paddings = 1.0 - mask
  paddings = jnp.expand_dims(paddings, -1)
  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)
  # length
  self.assertAllClose(
      np.array([5, 7], dtype=np.float32),
      np.sum(1.0 - output_paddings, (1, 2)))
  # input and output sums are equal
  self.assertAllClose(np.sum(inputs, (1, 2)), np.sum(outputs, (1, 2)))
def test_trainable_positional_embedding_layer(self, lookup_style):
  p = embedding_softmax.TrainablePositionalEmbedding.Params().Set(
      name='jax_pos_emb',
      max_seq_length=10,
      embedding_dims=40,
      lookup_style=lookup_style)
  emb_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = emb_layer.instantiate_variables(prng_key)
  npy_input = np.random.randint(0, p.max_seq_length,
                                [10, p.max_seq_length]).astype('int32')
  inputs = jnp.asarray(npy_input)
  outputs = test_utils.apply(emb_layer, initial_vars, emb_layer.fprop,
                             p.max_seq_length, inputs)
  # Test whether tf Embedding layer returns same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_p = lingvo_layers.SingleShardEmbeddingLayer.Params().Set(
      name='tf_pos_emb',
      vocab_size=p.max_seq_length,
      embedding_dim=p.embedding_dims)
  tf_emb_layer = tf_p.Instantiate()
  tf_output = tf_emb_layer.FProp(tf_initial_vars,
                                 tf.constant(inputs, dtype=tf.int32))
  np_outputs = to_np(outputs)
  tf_np_outputs = to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs)
def test_vit_transformer_layers(self):
  batch_size, num_tokens, input_dims, hidden_dims = 3, 8, 12, 48
  num_heads, num_layers = 4, 2
  residual_dropout_prob, activation_dropout_prob = 0.2, 0.2
  atten_dropout_prob = 0.2
  atten_logit_cap = 50.0
  p_middle = vit.VitTransformerLayers.Params().Set(
      name='middle',
      input_dims=input_dims,
      hidden_dims=hidden_dims,
      num_heads=num_heads,
      num_layers=num_layers,
      atten_logit_cap=atten_logit_cap,
      residual_dropout_prob=residual_dropout_prob,
      activation_dropout_prob=activation_dropout_prob,
      atten_dropout_prob=atten_dropout_prob)
  middle = p_middle.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = middle.instantiate_variables(prng_key)
  inputs_np = np.random.normal(size=[batch_size, num_tokens, input_dims])
  inputs = jnp.asarray(inputs_np)
  features = test_utils.apply(middle, initial_vars, middle.fprop, inputs)
  self.assertEqual(features.shape, (batch_size, num_tokens, input_dims))
def test_causal_depthwise_conv1d(self, shape, kernel_size, axis, hidden_dims):
  """With all-ones kernels, the output equals the sum of causally shifted inputs."""
  inputs = np.random.normal(1.5, 2.0, shape).astype(np.float32)
  p = attentions.CausalDepthwiseConv1D.Params().Set(
      name='causal_dconv', kernel_size=kernel_size, hidden_dims=hidden_dims)
  causal_dconv_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = causal_dconv_layer.instantiate_variables(init_key)
  if isinstance(hidden_dims, list):
    kernel_shape = hidden_dims
  else:
    kernel_shape = [hidden_dims]
  for k in range(kernel_size):
    initial_vars[f'dconv_{k}'] = np.ones(kernel_shape)
  jax_dconv_out = test_utils.apply(
      causal_dconv_layer,
      initial_vars,
      causal_dconv_layer.fprop,
      inputs,
      axis=axis)
  jax_np_out = test_utils.to_np(jax_dconv_out)
  outputs = inputs
  for _ in range(1, kernel_size):
    inputs = attentions.shift_1d(inputs, offset=1, axis=axis)
    outputs += inputs
  self.assertArraysEqual(jax_np_out, outputs)
def testBase(self):
  num_classes = 4
  latent_dim = 4
  b, t = 2, 4
  np.random.seed(2021)
  z = np.random.rand(b, t, latent_dim).astype(np.float32)
  paddings = np.zeros((b, t)).astype(np.float32)
  vq_p = self._GetParams(num_classes, latent_dim)
  vq = vq_p.Instantiate()
  vq_theta = vq.instantiate_variables(jax.random.PRNGKey(1))
  vq_theta.w = jnp.expand_dims(self.w, 1)
  out = test_utils.apply(vq, vq_theta, vq.fprop, z, paddings)
  with self.subTest('test_shape'):
    self.assertEqual((b, t, latent_dim), out.z_q.shape)
    self.assertEqual((b, t, 1), out.z_codes.shape)
    self.assertEqual((b, t, 1, num_classes), out.z_onehot.shape)
  with self.subTest('test_z_q'):
    self.assertAllClose(15.861525, np.sum(out.z_q))
  with self.subTest('test_z_codes'):
    self.assertEqual(24, np.sum(out.z_codes))
  with self.subTest('test_codebook_coverage'):
    self.assertEqual(0.25, np.sum(out.codebook_coverage))
  with self.subTest('test_pplx'):
    self.assertEqual(1.0, out.pplx)
  with self.subTest('test_entropy'):
    self.assertAllClose(0., out.entropy)
def test_rms_norm(self, scale):
  input_dims = 3
  p = normalizations.RmsNorm.Params().Set(
      name='jax_rmsn', input_dims=input_dims, direct_scale=False)
  rms_norm = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123456)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = rms_norm.instantiate_variables(init_key)
  initial_vars.scale = scale
  npy_input = np.random.normal(1.0, 0.5,
                               [10, 10, 10, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_input)
  outputs = test_utils.apply(rms_norm, initial_vars, rms_norm.fprop, inputs)
  # Now test whether tf RMS norm returns same output.
  tf_p = lingvo_layers.LayerNorm.Params().Set(
      name='tf_rmsn', input_dim=p.input_dims, bias=False, center=False)
  tf_layer_norm = tf_p.Instantiate()
  tf_output = tf_layer_norm.FProp(initial_vars,
                                  tf.constant(inputs, dtype=tf.float32))
  np_outputs = to_np(outputs)
  tf_np_outputs = to_np(tf_output)
  np_norms = np.linalg.norm(np_outputs / np.sqrt(float(input_dims)), axis=-1)
  self.assertAllClose(
      (1.0 + scale) * np.ones_like(np_norms), np_norms, atol=5e-3)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=6e-5)
def test_vq_layer_equivalence_with_tf(self, num_clusters, num_heads,
                                      dim_per_head):
  inputs = np.random.normal(1.5, 2.0, (2, 32, num_heads, dim_per_head))
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  vq_layer_p = ngrammer.VectorQuantization.Params().Set(
      name='jax_vq_layer',
      num_clusters=num_clusters,
      num_heads=num_heads,
      dim_per_head=dim_per_head,
  )
  vq_layer = vq_layer_p.Instantiate()
  initial_vars = vq_layer.instantiate_variables(init_key)
  jax_dists, _ = test_utils.apply(vq_layer, initial_vars, vq_layer.fprop,
                                  inputs)
  # Now run TF based computation.
  tf_vq_layer_p = attention_util.KMeansClusteringForAtten.Params().Set(
      name='tf_vq_layer',
      num_clusters=num_clusters,
      num_heads=num_heads,
      dim_per_head=dim_per_head,
      apply_layer_norm=False)
  tf_vq_layer = tf_vq_layer_p.Instantiate()
  tf_dists, _ = tf_vq_layer.FProp(initial_vars, tf.constant(inputs))
  self.assertAllClose(to_np(jax_dists), to_np(tf_dists), atol=1e-5)
def test_per_dim_scale(self):
  test_layer_p = attentions.PerDimScale.Params().Set(name='scale', dim=4)
  layer = test_layer_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  prng_key, init_key = jax.random.split(prng_key)
  initial_vars = layer.instantiate_variables(init_key)
  initial_vars.per_dim_scale = jnp.array([-0.5, 0.5, 1.0, 0.0],
                                         dtype=jnp.float32)
  logging.info('initial_vars: %s', initial_vars)
  inputs = np.random.normal(1.5, 2.0, [5, 4]).astype(np.float32)
  jax_out = test_utils.apply(layer, initial_vars, layer.fprop, inputs)
  logging.info('jax_output: %s', jax_out)
  # Now run TF based computation.
  tf_layer_p = batch_major_attention.PerDimScaleLayer.Params().Set(
      name='scale', dim=4)
  tf_layer = tf_layer_p.Instantiate()
  tf_output1 = tf_layer.FProp(tf_layer.theta, inputs)
  logging.info('tf_output1: %s', tf_output1)
  tf_output2 = tf_layer.FProp(initial_vars, inputs)
  logging.info('tf_output2: %s', tf_output2)
  self.assertAllClose(
      test_utils.to_np(jax_out), test_utils.to_np(tf_output2))
def test_rotary_position_embedding_layer_2d(self, position):
  embedding_dims = 2
  min_timescale = 1
  max_timescale = 1e4
  p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  inputs = np.random.normal(1.5, 2.5, (1, 4, 1, embedding_dims))
  if position is None:
    position = jnp.arange(4, dtype=jnp.float32)
  position = jnp.array(position)
  output = test_utils.apply(
      pos_layer,
      initial_vars,
      pos_layer.fprop,
      inputs=inputs,
      position=position[jnp.newaxis, :])
  np_output = test_utils.to_np(output)
  sinusoid_inp = position
  sin = jnp.sin(sinusoid_inp)
  cos = jnp.cos(sinusoid_inp)
  first_part = inputs[0, :, 0, 0] * cos - inputs[0, :, 0, 1] * sin
  second_part = inputs[0, :, 0, 1] * cos + inputs[0, :, 0, 0] * sin
  expected_output = np.stack([first_part, second_part], axis=-1)
  self.assertArraysEqual(np_output[0, :, 0, :], expected_output)
def testStackingOverTimeFPropReduceMaxPadding(self):
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 2
  params.right_context = 0
  params.stride = 2
  params.padding_reduce_option = 'reduce_max'
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 3)
  inputs = jnp.array(
      [[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
       [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0], [0, 0]]],
      dtype=jnp.float32)
  paddings = jnp.array(
      [[[0], [0], [0], [0], [0], [0]], [[0], [0], [1], [1], [1], [1]]],
      dtype=jnp.float32)
  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)
  print(f'{outputs}')
  expected_outputs = jnp.array(
      [
          [[0, 0, 0, 0, 1, 1], [1, 1, 2, 2, 3, 3], [3, 3, 4, 4, 5, 5]],
          [[0, 0, 0, 0, 7, 7], [7, 7, 8, 8, 0, 0], [0, 0, 0, 0, 0, 0]],
      ],
      dtype=jnp.float32)
  self.assertAllClose(expected_outputs, outputs)
  expected_output_paddings = jnp.array(
      [[[1], [0], [0]], [[1], [1], [1]]], dtype=jnp.float32)
  self.assertAllClose(expected_output_paddings, output_paddings)
def test_single_sharded_embedding_layer(self, lookup_style,
                                        scale_sqrt_depth):
  p = embedding_softmax.SingleShardEmbedding.Params().Set(
      name='jax_emb_lookup',
      vocab_size=10,
      embedding_dims=40,
      lookup_style=lookup_style,
      scale_sqrt_depth=scale_sqrt_depth)
  emb_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = emb_layer.instantiate_variables(prng_key)
  npy_input = np.random.randint(0, p.vocab_size, [10, 20]).astype('int32')
  inputs = jnp.asarray(npy_input)
  outputs = test_utils.apply(emb_layer, initial_vars, emb_layer.fprop, inputs)
  # Test whether tf Embedding layer returns same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_p = lingvo_layers.SingleShardEmbeddingLayer.Params().Set(
      name='tf_emb_lookup',
      vocab_size=p.vocab_size,
      embedding_dim=p.embedding_dims,
      scale_sqrt_depth=scale_sqrt_depth)
  tf_emb_layer = tf_p.Instantiate()
  tf_output = tf_emb_layer.FProp(tf_initial_vars,
                                 tf.constant(inputs, dtype=tf.int32))
  np_outputs = to_np(outputs)
  tf_np_outputs = to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-6)
def test_stacked_conformer_layer(self, batch_size, seq_len, num_layers,
                                 kernel_size, input_dims, model_dims,
                                 atten_num_heads, dropout_prob):
  p = conformers.StackedConformer.Params().Set(
      name='conformer',
      input_dims=input_dims,
      model_dims=model_dims,
      num_layers=num_layers)
  p.conformer_tpl.atten_num_heads = atten_num_heads
  p.conformer_tpl.kernel_size = kernel_size
  p.conformer_tpl.dropout_prob = dropout_prob
  stacked_conformer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = stacked_conformer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 2,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  context_p = base_layer.JaxContext.Params().Set(do_eval=True)
  with cluster_factory.SetEval(True):
    output = test_utils.apply(
        stacked_conformer,
        initial_vars,
        stacked_conformer.fprop,
        inputs,
        paddings,
        context_p=context_p,
    )
  self.assertEqual(output.shape, (batch_size, seq_len, model_dims))
def test_position_embedding_layer(self, min_timescale, max_timescale):
  p = embedding_softmax.PositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=50,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  seq_length = np.random.randint(100, 1000)
  output = test_utils.apply(pos_layer, initial_vars, pos_layer.fprop,
                            seq_length)
  output = jnp.squeeze(output, axis=0)
  # Test whether tf PositionalEmbedding layer returns same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_p = lingvo_layers.PositionalEmbeddingLayer.Params().Set(
      name='tf_pos',
      embedding_dim=p.embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  tf_pos_layer = tf_p.Instantiate()
  tf_output = tf_pos_layer.FProp(tf_initial_vars, seq_length)
  np_pos = to_np(output)
  tf_np_pos = to_np(tf_output)
  self.assertAllClose(tf_np_pos, np_pos, atol=1e-3)
def test_position_embedding_layer_with_position(self, min_timescale,
                                                max_timescale):
  p = embedding_softmax.PositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=50,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  position = np.array([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
                       [0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
                       [0, 1, 2, 3, 4, 5, 6, 0, 1, 2],
                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
  output = test_utils.apply(pos_layer, initial_vars, pos_layer.fprop,
                            position=position)
  # Test whether tf PositionalEmbedding layer returns same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_p = lingvo_layers.PositionalEmbeddingLayer.Params().Set(
      name='tf_pos',
      embedding_dim=p.embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  tf_pos_layer = tf_p.Instantiate()
  tf_output = tf_pos_layer.FPropWithPosition(tf_initial_vars, position)
  np_pos = to_np(output)
  tf_np_pos = to_np(tf_output)
  self.assertAllClose(tf_np_pos, np_pos, atol=1e-3)
def test_rotary_position_embedding_layer_prefix(self, min_timescale,
                                                max_timescale, window_size):
  """Checks that fprop and windowed extend_step produce the same embeddings."""
  embedding_dims = 32
  p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
      name='jax_pos',
      embedding_dims=embedding_dims,
      min_timescale=min_timescale,
      max_timescale=max_timescale)
  pos_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pos_layer.instantiate_variables(prng_key)
  inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
  output = test_utils.apply(
      pos_layer, initial_vars, pos_layer.fprop, inputs=inputs)
  # Test whether extend_step returns same output.
  for i in range(inputs.shape[1]):
    start = max(0, i + 1 - window_size)
    end = i + 1
    inputs_prefix = inputs[:, start:end, :, :]
    pad_width = window_size - end + start
    paddings = [(0, 0), (pad_width, 0), (0, 0), (0, 0)]
    inputs_prefix = jnp.pad(inputs_prefix, paddings)
    jax_extend_step_out = test_utils.apply(
        pos_layer,
        initial_vars,
        pos_layer.extend_step,
        inputs_prefix,
        time_step=i)
    jax_extend_step_out = jax.lax.dynamic_slice_in_dim(
        jax_extend_step_out,
        start_index=window_size - 1,
        slice_size=1,
        axis=1)
    jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
    jax_fprop_slice = jax.lax.dynamic_slice_in_dim(
        output, start_index=i, slice_size=1, axis=1)
    self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
def testSpectrumAugmenterWithTimeMask(self):
  batch_size = 5
  inputs = jnp.ones([batch_size, 20, 2], dtype=jnp.float32)
  paddings = []
  for i in range(batch_size):
    paddings.append(
        jnp.concatenate([jnp.zeros([1, i + 12]),
                         jnp.ones([1, 8 - i])], axis=1))
  paddings = jnp.concatenate(paddings, axis=0)
  p = spectrum_augmenter.SpectrumAugmenter.Params()
  p.name = 'specAug_layers'
  p.freq_mask_max_bins = 0
  p.time_mask_max_frames = 5
  p.time_mask_count = 2
  p.time_mask_max_ratio = 1.
  specaug_layer = p.Instantiate()
  expected_output = np.array(
      [[[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.],
        [1., 1.], [0., 0.], [0., 0.], [0., 0.], [0., 0.], [1., 1.], [1., 1.],
        [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
       [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.], [1., 1.], [0., 0.],
        [0., 0.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.],
        [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
       [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.],
        [0., 0.], [0., 0.], [0., 0.], [1., 1.], [1., 1.], [1., 1.], [1., 1.],
        [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
       [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.],
        [0., 0.], [1., 1.], [1., 1.], [0., 0.], [0., 0.], [0., 0.], [0., 0.],
        [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
       [[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.],
        [1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.], [0., 0.], [0., 0.],
        [0., 0.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]]])
  context_p = base_layer.JaxContext.Params().Set(do_eval=False)
  prng_key = jax.random.PRNGKey(seed=23456)
  theta = specaug_layer.instantiate_variables(prng_key)
  actual_layer_output, _ = test_utils.apply(
      specaug_layer,
      theta,
      specaug_layer.fprop,
      inputs,
      paddings,
      context_p=context_p)
  self.assertAllClose(actual_layer_output, expected_output)
def test_pooling_layer_with_paddings(self, window_shape, window_stride,
                                     padding, pooling_type, input_shape,
                                     int_inputs, paddings_all_ones):
  p = poolings.Pooling.Params().Set(
      name='jax_pooling',
      window_shape=window_shape,
      window_stride=window_stride,
      pooling_type=pooling_type,
      padding=padding)
  pooling_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = pooling_layer.instantiate_variables(prng_key)
  if int_inputs:
    npy_inputs = np.random.randint(0, 100, input_shape).astype('int32')
  else:
    npy_inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  if paddings_all_ones:
    npy_paddings = np.ones([input_shape[0],
                            input_shape[1]]).astype(npy_inputs.dtype)
  else:
    npy_paddings = np.random.randint(
        0, 2, [input_shape[0], input_shape[1]]).astype(npy_inputs.dtype)
  paddings = jnp.asarray(npy_paddings)
  tf_paddings = tf.constant(npy_paddings, dtype=tf.float32)
  output, out_paddings = test_utils.apply(pooling_layer, initial_vars,
                                          pooling_layer.fprop, inputs,
                                          paddings)
  # Test whether tf Pooling layer returns the same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_p = lingvo_layers.PoolingLayer.Params().Set(
      name='tf_pooling',
      window_shape=window_shape,
      window_stride=window_stride,
      pooling_type=pooling_type,
      padding_algorithm=padding)
  tf_pooling_layer = tf_p.Instantiate()
  tf_input = tf.constant(npy_inputs, dtype=tf.float32)
  tf_output = tf_pooling_layer.FProp(tf_initial_vars, tf_input, tf_paddings)
  # Check the actual output.
  np_output = to_np(output)
  tf_np_output = to_np(tf_output[0])
  np_paddings = to_np(out_paddings)
  tf_np_paddings = to_np(tf_output[1])
  # Check the paddings.
  self.assertAllClose(np_paddings, tf_np_paddings)
  self.assertAllClose(tf_np_output, np_output)
def _run_decode(self, decoder_p, logits, input_batch):
  p = base_model.LanguageModel.Params()
  p.name = 'mock_lm'
  p.decoder = decoder_p.Copy()
  p.lm = MockLM.Params()
  p.lm.logits = logits
  lang_model = p.Instantiate()
  theta = NestedMap(lm=NestedMap())
  # We fix seed to 1027 to get the desired prefix lengths below.
  _, results = test_utils.apply(
      lang_model, theta, lang_model.decode, input_batch, seed=1027)
  return results
def testBase(self, b, t, latent_dim, projection_dim, num_classes):
  np.random.seed(2022)
  z = np.random.rand(b, t, latent_dim).astype(np.float32)
  paddings = np.zeros((b, t)).astype(np.float32)
  rq = quantizer.RandomVectorQuantizer.Params().Set(
      name='vq',
      num_latent_classes=num_classes,
      latent_dim=latent_dim,
      projection_dim=projection_dim)
  rq = rq.Instantiate()
  rq_theta = rq.instantiate_variables(jax.random.PRNGKey(1))
  out = test_utils.apply(rq, rq_theta, rq.fprop, z, paddings)
  self.assertEqual((b, t, projection_dim), out.z_q.shape)
  self.assertEqual((b, t), out.z_codes.shape)
  self.assertEqual((b, t, num_classes), out.z_onehot.shape)
def test_vit(self):
  batch_size = 3
  p_vit = self._vit_params()
  vit_model = p_vit.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = vit_model.instantiate_variables(prng_key)
  inputs_np = np.random.normal(
      size=[batch_size, p_vit.image_size, p_vit.image_size, 3])
  inputs = jnp.asarray(inputs_np)
  features = test_utils.apply(vit_model, initial_vars, vit_model.fprop,
                              inputs)
  self.assertEqual(features.shape, (batch_size, p_vit.hidden_dim))
def testStackingOverTimePadWithRightFrameFProp(self, pad_with_right_frame):
  params = linears.StackingOverTime.Params()
  params.name = 'stackingOverTime'
  params.left_context = 0
  params.right_context = 1
  params.stride = 2
  params.pad_with_right_frame = pad_with_right_frame
  stacker = linears.StackingOverTime(params)
  stacker_vars = None
  self.assertEqual(stacker.window_size, 2)
  # input shape [2, 5, 2]
  inputs = jnp.array([[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
                      [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0]]],
                     dtype=jnp.float32)
  paddings = jnp.array(
      [[[0], [0], [0], [0], [0]], [[0], [0], [1], [1], [1]]],
      dtype=jnp.float32)
  outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                              stacker.fprop, inputs, paddings)
  print(f'{outputs}')
  if pad_with_right_frame:
    # output shape [2, 3, 4]
    # [5, 5] is duplication of the last input frame.
    expected_outputs = jnp.array(
        [
            [[1, 1, 2, 2], [3, 3, 4, 4], [5, 5, 5, 5]],
            [[7, 7, 8, 8], [0, 0, 0, 0], [0, 0, 0, 0]],
        ],
        dtype=jnp.float32)
  else:
    expected_outputs = jnp.array(
        [
            [[1, 1, 2, 2], [3, 3, 4, 4], [5, 5, 0, 0]],
            [[7, 7, 8, 8], [0, 0, 0, 0], [0, 0, 0, 0]],
        ],
        dtype=jnp.float32)
  self.assertAllClose(expected_outputs, outputs)
  expected_output_paddings = jnp.array(
      [[[0], [0], [0]], [[0], [1], [1]]], dtype=jnp.float32)
  self.assertAllClose(expected_output_paddings, output_paddings)
def testVitSkipExitLayers(self):
  batch_size = 3
  p_vit = self._vit_params().Set(exit_layers_tpl=None)
  vit_model = p_vit.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = vit_model.instantiate_variables(prng_key)
  inputs_np = np.random.normal(
      size=[batch_size, p_vit.image_size, p_vit.image_size, 3])
  inputs = jnp.asarray(inputs_np)
  features = test_utils.apply(vit_model, initial_vars, vit_model.fprop,
                              inputs)
  patch_count = p_vit.image_size // p_vit.patch_size
  self.assertEqual(features.shape,
                   (batch_size, patch_count**2, p_vit.hidden_dim))
def test_transformer_bert(self, trainable_position_emb):
  seq_len = 512
  if trainable_position_emb:
    position_emb_tpl = embedding_softmax.TrainablePositionalEmbedding.Params()
    position_emb_tpl.max_seq_length = seq_len
  else:
    position_emb_tpl = embedding_softmax.PositionalEmbedding.Params()
  p = transformer_models.TransformerLm.Params().Set(
      name='bert_lm',
      model_dims=32,
      vocab_size=52,
      position_emb_tpl=position_emb_tpl)
  stacked_transformer_tpl = p.stacked_transformer_tpl
  stacked_transformer_tpl.model_dims = 32
  stacked_transformer_tpl.hidden_dims = 4 * 32
  stacked_transformer_tpl.num_heads = 4
  stacked_transformer_tpl.num_layers = 1
  p.softmax_tpl.scale_sqrt_depth = True
  batch_size = 8
  bert_lm = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = bert_lm.instantiate_variables(prng_key)
  input_ids = jax.random.randint(
      jax.random.PRNGKey(1234), [batch_size, seq_len], 0, 51)
  input_paddings = jnp.zeros([batch_size, seq_len])
  input_weights = jnp.ones([batch_size, seq_len])
  input_segment_ids = jnp.ones([batch_size, seq_len])
  input_segment_pos = jnp.tile(
      jnp.arange(0, seq_len)[jnp.newaxis, :], [batch_size, 1])
  labels = py_utils.NestedMap()
  labels.class_ids = input_ids
  labels.class_weights = input_weights
  outputs = test_utils.apply(
      bert_lm,
      initial_vars,
      bert_lm.fprop,
      input_ids,
      input_paddings,
      labels=labels,
      segment_ids=input_segment_ids,
      segment_pos=input_segment_pos)
  logging.info('outputs: %s', outputs)