Example #1
  def _testUnstack(self, inputs, **kwargs):
    params = linears.StackingOverTime.Params().Set(
        name='stackingOverTime', **kwargs)

    stacker = params.Instantiate()
    stacker_vars = None
    stacked, _ = test_utils.apply(stacker, stacker_vars, stacker.fprop, inputs)
    unstacked = test_utils.apply(stacker, stacker_vars, stacker.unstack,
                                 stacked)
    print(f'{unstacked}')

    batch, input_length, depth = inputs.shape
    stacked_length = stacked.shape[1]
    stride = stacker.params.stride
    right_context = stacker.params.right_context

    self.assertAllClose(
        unstacked.shape,
        [batch, (stacked_length - 1) * stride + right_context + 1, depth])
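    # No trailing frames are dropped when right_context + 1 >= stride, so
    # unstack recovers at least the original sequence.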
    if right_context + 1 >= stride:
      self.assertGreaterEqual(unstacked.shape[1], input_length)
      self.assertAllClose(inputs, unstacked[:, :input_length])
    else:
      self.assertLessEqual(unstacked.shape[1], input_length)
      # Up to stride - right_context - 1 trailing frames may be missing.
      self.assertLessEqual(input_length - unstacked.shape[1],
                           stride - right_context - 1)
      self.assertAllClose(inputs[:, :unstacked.shape[1]], unstacked)
Example #2
 def test_rotary_position_embedding_layer_no_prefix(self, min_timescale,
                                                    max_timescale):
     embedding_dims = 32
     p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               inputs=inputs)
     # Test whether extend_step returns same output.
     for i in range(inputs.shape[1]):
         jax_extend_step_out = test_utils.apply(pos_layer,
                                                initial_vars,
                                                pos_layer.extend_step,
                                                inputs[:, i, :, :],
                                                time_step=i)
         jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
         jax_fprop_slice = output[:, i, :, :]
         self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
Example #3
    def test_ngrammer_layer_exact_bigram_2d(self, unigram_vocab_size,
                                            ngram_emb_dim, num_heads,
                                            dim_per_head, concat_ngrams):
        batch_size = 2
        seq_len = 8
        inputs = np.random.randint(unigram_vocab_size,
                                   size=[batch_size, seq_len],
                                   dtype=np.int32)
        paddings = np.random.randint(1, size=[batch_size, seq_len])
        input_embs = np.random.normal(
            1.5, 2.0, (batch_size, seq_len, num_heads * dim_per_head))
        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        ngrammer_layer_p = ngrammer.Ngrammer.Params().Set(
            name='jax_ngrammer_layer',
            unigram_vocab_size=unigram_vocab_size,
            ngram_vocab_size=num_heads * unigram_vocab_size**2,
            ngram_emb_dim=ngram_emb_dim,
            num_heads=num_heads,
            dim_per_head=dim_per_head,
            concat_ngrams=concat_ngrams,
        )
        ngrammer_layer = ngrammer_layer_p.Instantiate()
        initial_vars = ngrammer_layer.instantiate_variables(init_key)

        ngram_embs = test_utils.apply(ngrammer_layer, initial_vars,
                                      ngrammer_layer.fprop, inputs, input_embs,
                                      paddings)
        ngram_embs = np.reshape(ngram_embs,
                                [batch_size, seq_len, num_heads, dim_per_head])
        input_embs = np.reshape(input_embs,
                                [batch_size, seq_len, num_heads, dim_per_head])
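        # Recompute the expected per-head n-gram embeddings from the layer's
        # sub-layers: bigram ids, per-head embedding lookup, then layer norm.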
        for i in range(num_heads):
            input_ids_per_head = inputs
            ngram_ids_per_head = ngrammer.get_bigram_ids(
                input_ids_per_head, unigram_vocab_size)
            ngram_ids_per_head *= (i + 1)
            ngram_ids_per_head += (i + 1)
            ngram_embs_expected = test_utils.apply(
                ngrammer_layer.ngram_table[i], initial_vars.ngram_table[i],
                ngrammer_layer.ngram_table[i].fprop,
                np.reshape(ngram_ids_per_head, [-1]))
            ngram_embs_expected = test_utils.apply(
                ngrammer_layer.ngram_layer_norm[i],
                initial_vars.ngram_layer_norm[i],
                ngrammer_layer.ngram_layer_norm[i].fprop, ngram_embs_expected)
            ngram_embs_expected = jnp.reshape(
                ngram_embs_expected, [batch_size, seq_len, ngram_emb_dim])
            ngram_embs_expected *= (1 - paddings[:, :, np.newaxis])
            if concat_ngrams:
                ngram_embs_slice = ngram_embs[:, :, i, -ngram_emb_dim:]
            else:
                input_embs_ln = test_utils.apply(
                    ngrammer_layer.emb_layer_norm[i],
                    initial_vars.emb_layer_norm[i],
                    ngrammer_layer.emb_layer_norm[i].fprop, input_embs[:, :,
                                                                       i, :])
                ngram_embs_slice = ngram_embs[:, :, i, :] - input_embs_ln
            self.assertAllClose(to_np(ngram_embs_slice),
                                to_np(ngram_embs_expected))
Example #4
    def test_group_norm(self, dim, num_groups, cumulative, input_rank, epsilon,
                        input_shape, input_dtype, paddings, fprop_dtype):
        p = normalizations.GroupNorm.Params().Set(name='jax_gn',
                                                  dim=dim,
                                                  num_groups=num_groups,
                                                  cumulative=cumulative,
                                                  input_rank=input_rank,
                                                  epsilon=epsilon,
                                                  fprop_dtype=fprop_dtype)
        group_norm = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123456)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = group_norm.instantiate_variables(init_key)
        npy_input = np.random.normal(1.0, 0.5, input_shape).astype(np.float32)
        inputs = jnp.asarray(npy_input, dtype=input_dtype)
        if paddings is None:
            output = test_utils.apply(group_norm,
                                      initial_vars,
                                      group_norm.fprop,
                                      inputs,
                                      paddings=None)
        else:
            output, output_paddings = test_utils.apply(group_norm,
                                                       initial_vars,
                                                       group_norm.fprop,
                                                       inputs,
                                                       paddings=jnp.asarray(
                                                           paddings,
                                                           dtype=input_dtype))

        # Now test whether the TF GroupNorm layer returns the same output.
        tf_p = bn_layers.GroupNormLayer.Params().Set(
            name='tf_gn',
            dim=dim,
            num_groups=num_groups,
            cumulative=cumulative,
            input_rank=input_rank,
            epsilon=epsilon,
            fprop_dtype=_JaxToTfDtype(fprop_dtype))
        tf_group_norm = tf_p.Instantiate()
        tf_inputs = tf.constant(inputs, dtype=_JaxToTfDtype(input_dtype))
        if paddings is None:
            tf_output = tf_group_norm.FProp(initial_vars,
                                            tf_inputs,
                                            paddings=None)
        else:
            tf_output, tf_output_paddings = tf_group_norm.FProp(
                initial_vars,
                tf_inputs,
                paddings=tf.convert_to_tensor(
                    paddings, dtype=_JaxToTfDtype(input_dtype)))

        self.assertAllClose(to_np(tf_output), to_np(output))
        if paddings is not None:
            self.assertAllClose(to_np(tf_output_paddings),
                                to_np(output_paddings))
Example #5
    def test_simple_softmax_layer_class_probs(self, batch_size, num_classes):
        batch_size = 8
        num_classes = 1001
        class_probabilities = np.random.normal(1.5, 2.0,
                                               [batch_size, num_classes])
        # Normalize class probabilities to be a probability distribution.
        class_probabilities /= np.sum(class_probabilities,
                                      axis=-1,
                                      keepdims=True)
        p = embedding_softmax.SingleShardFullSoftmax.Params().Set(
            name='jax_softmax', num_classes=num_classes, input_dims=40)
        softmax_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = softmax_layer.instantiate_variables(prng_key)
        npy_input = np.random.normal(1.5, 2.0, [batch_size, p.input_dims])
        inputs = jnp.asarray(npy_input)
        class_weights = np.random.normal(1.5, 2.0, [batch_size, 1])
        logits = test_utils.apply(softmax_layer, initial_vars,
                                  softmax_layer.get_logits, inputs)
        outputs = test_utils.apply(softmax_layer,
                                   initial_vars,
                                   softmax_layer.fprop,
                                   inputs,
                                   class_weights,
                                   class_ids=None,
                                   class_probabilities=class_probabilities)
        # Test whether tf Softmax layer returns same output.
        # Modify initial_vars to use TF compatible params.
        tf_initial_vars = test_utils.replace_jax_simple_full_softmax_vars_to_tf(
            initial_vars)
        # Convert all the values to TF tensor.
        tf_initial_vars = tf.nest.map_structure(tf.convert_to_tensor,
                                                tf_initial_vars)

        tf_p = lingvo_layers.SimpleFullSoftmax.Params().Set(
            name='tf_softmax',
            num_classes=p.num_classes,
            input_dim=p.input_dims)
        tf_softmax_layer = tf_p.Instantiate()
        tf_logits = tf_softmax_layer.Logits(
            tf_initial_vars, tf.constant(inputs, dtype=tf.float32))
        tf_output = tf_softmax_layer.FProp(
            tf_initial_vars,
            tf.constant(inputs, dtype=tf.float32),
            class_weights,
            class_ids=None,
            class_probabilities=class_probabilities)
        # Check all entries in the NestedMap and ensure it matches TF.
        np_get_logits = to_np(logits)
        tf_np_get_logits = to_np(tf_logits)
        self.assertAllClose(np_get_logits, tf_np_get_logits)
        for k in outputs.keys():
            self.assertAllClose(to_np(outputs[k]), to_np(tf_output[k]))
Example #6
    def test_combine_qkv_with_attention_combine_dims(self):
        input_dim = 64
        dim_per_head = 8
        num_heads = 8
        # Reference combine qkv projection layer.
        ref_proj_p = attentions.CombinedQKVProjectionLayer.Params().Set(
            name='ref',
            input_dim=input_dim,
            dim_per_head=dim_per_head,
            num_heads=num_heads)
        proj = ref_proj_p.Instantiate()

        # Combined QKV projection layer with attention_combine_dims enabled.
        combine_proj_p = attentions.CombinedQKVProjectionLayer.Params().Set(
            name='combine',
            input_dim=input_dim,
            dim_per_head=dim_per_head,
            num_heads=num_heads,
            attention_combine_dims=True)
        combine_proj = combine_proj_p.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = proj.instantiate_variables(init_key)

        # Set up initial vars for combine attention dim projection.
        combine_initial_vars = combine_proj.instantiate_variables(init_key)
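        # Reshape the reference weights and biases into the flattened
        # [..., num_heads * dim_per_head] layout used when
        # attention_combine_dims=True, so both layers compute identical
        # projections.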
        combine_initial_vars.w = np.reshape(
            initial_vars.w, (3, input_dim, num_heads * dim_per_head))
        combine_initial_vars.b = np.reshape(initial_vars.b,
                                            (3, num_heads * dim_per_head))

        batch_size = 3
        inputs = np.random.normal(size=[batch_size, input_dim]).astype(
            np.float32)

        prng_key, compute_key = jax.random.split(prng_key)
        global_step = jnp.array(0, dtype=jnp.uint64)

        with base_layer.JaxContext.new_context(prng_key=compute_key,
                                               global_step=global_step):
            q_proj_ref, k_proj_ref, v_proj_ref = test_utils.apply(
                proj, initial_vars, proj.fprop, inputs)
            q_proj_combine, k_proj_combine, v_proj_combine = test_utils.apply(
                combine_proj, combine_initial_vars, combine_proj.fprop, inputs)

        self.assertAllClose(q_proj_ref, q_proj_combine)
        self.assertAllClose(k_proj_ref, k_proj_combine)
        self.assertAllClose(v_proj_ref, v_proj_combine)
Example #7
 def test_feedforward_layer_no_bias(self, activation):
   p = linears.FeedForward.Params().Set(
       name='jax_ffn',
       input_dims=3,
       output_dims=20,
       has_bias=False,
       activation=activation)
   ffn = p.Instantiate()
   prng_key = jax.random.PRNGKey(seed=123)
   initial_vars = ffn.instantiate_variables(prng_key)
   npy_input = np.random.normal(1.0, 0.5,
                                [10, 10, p.input_dims]).astype('float32')
   inputs = jnp.asarray(npy_input)
   outputs = test_utils.apply(ffn, initial_vars, ffn.fprop, inputs)
   logging.info('initial_vars in ffn = %s', initial_vars)
   # Test whether tf projection layer returns same output
   # Modify initial_vars to use TF compatible params
   tf_initial_vars = py_utils.NestedMap()
   tf_initial_vars.w = initial_vars.linear.w
   tf_initial_vars = to_tf_nmap(tf_initial_vars)
   tf_p = lingvo_layers.ProjectionLayer.Params().Set(
       name='tf_ffn',
       input_dim=p.input_dims,
       output_dim=p.output_dims,
       batch_norm=False,
       has_bias=False,
       activation=activation)
   tf_ffn = tf_p.Instantiate()
   tf_output = tf_ffn.FProp(tf_initial_vars,
                            tf.constant(inputs, dtype=tf.float32))
   np_outputs = to_np(outputs)
   tf_np_outputs = to_np(tf_output)
   self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-6)
Example #8
  def testStackingOverTimeFProp2(self):
    params = linears.StackingOverTime.Params()
    params.name = 'stackingOverTime'
    params.left_context = 0
    params.right_context = 1
    params.stride = 2

    stacker = linears.StackingOverTime(params)
    stacker_vars = None
    self.assertEqual(stacker.window_size, 2)

    inputs = np.random.normal(size=[2, 21, 16])
    # poor man's tf.sequence_mask in np.
    mask = np.zeros([2, 21]).astype(np.float32)
    mask[0, :9] = 1.
    mask[1, :14] = 1.

    paddings = 1.0 - mask
    paddings = jnp.expand_dims(paddings, -1)
    outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                                stacker.fprop, inputs, paddings)

    # length
    self.assertAllClose(
        np.array([5, 7], dtype=np.float32), np.sum(1.0 - output_paddings,
                                                   (1, 2)))
    # input and output sums are equal
    self.assertAllClose(np.sum(inputs, (1, 2)), np.sum(outputs, (1, 2)))
Example #9
 def test_trainable_positional_embedding_layer(self, lookup_style):
     p = embedding_softmax.TrainablePositionalEmbedding.Params().Set(
         name='jax_pos_emb',
         max_seq_length=10,
         embedding_dims=40,
         lookup_style=lookup_style)
     emb_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = emb_layer.instantiate_variables(prng_key)
     npy_input = np.random.randint(0, p.max_seq_length,
                                   [10, p.max_seq_length]).astype('int32')
     inputs = jnp.asarray(npy_input)
     outputs = test_utils.apply(emb_layer, initial_vars, emb_layer.fprop,
                                p.max_seq_length, inputs)
     # Test whether tf Embedding layer returns same output
     # Modify initial_vars to use TF compatible params
     tf_initial_vars = initial_vars
     tf_p = lingvo_layers.SingleShardEmbeddingLayer.Params().Set(
         name='tf_pos_emb',
         vocab_size=p.max_seq_length,
         embedding_dim=p.embedding_dims)
     tf_emb_layer = tf_p.Instantiate()
     tf_output = tf_emb_layer.FProp(tf_initial_vars,
                                    tf.constant(inputs, dtype=tf.int32))
     np_outputs = to_np(outputs)
     tf_np_outputs = to_np(tf_output)
     self.assertAllClose(tf_np_outputs, np_outputs)
Example #10
    def test_vit_transformer_layers(self):
        batch_size, num_tokens, input_dims, hidden_dims = 3, 8, 12, 48
        num_heads, num_layers = 4, 2
        residual_dropout_prob, activation_dropout_prob = 0.2, 0.2
        atten_dropout_prob = 0.2
        atten_logit_cap = 50.0

        p_middle = vit.VitTransformerLayers.Params().Set(
            name='middle',
            input_dims=input_dims,
            hidden_dims=hidden_dims,
            num_heads=num_heads,
            num_layers=num_layers,
            atten_logit_cap=atten_logit_cap,
            residual_dropout_prob=residual_dropout_prob,
            activation_dropout_prob=activation_dropout_prob,
            atten_dropout_prob=atten_dropout_prob)

        middle = p_middle.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = middle.instantiate_variables(prng_key)

        inputs_np = np.random.normal(size=[batch_size, num_tokens, input_dims])
        inputs = jnp.asarray(inputs_np)

        features = test_utils.apply(middle, initial_vars, middle.fprop, inputs)

        self.assertEqual(features.shape, (batch_size, num_tokens, input_dims))
Example #11
    def test_causal_depthwise_conv1d(self, shape, kernel_size, axis,
                                     hidden_dims):
        inputs = np.random.normal(1.5, 2.0, shape).astype(np.float32)
        p = attentions.CausalDepthwiseConv1D.Params().Set(
            name='causal_dconv',
            kernel_size=kernel_size,
            hidden_dims=hidden_dims)
        causal_dconv_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = causal_dconv_layer.instantiate_variables(init_key)
        if isinstance(hidden_dims, list):
            kernel_shape = hidden_dims
        else:
            kernel_shape = [hidden_dims]
        for k in range(kernel_size):
            initial_vars[f'dconv_{k}'] = np.ones(kernel_shape)

        jax_dconv_out = test_utils.apply(causal_dconv_layer,
                                         initial_vars,
                                         causal_dconv_layer.fprop,
                                         inputs,
                                         axis=axis)
        jax_np_out = test_utils.to_np(jax_dconv_out)
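        # With every kernel weight set to one, the causal depthwise conv reduces
        # to summing the input with its first (kernel_size - 1) causal shifts.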
        outputs = inputs
        for _ in range(1, kernel_size):
            inputs = attentions.shift_1d(inputs, offset=1, axis=axis)
            outputs += inputs
        self.assertArraysEqual(jax_np_out, outputs)
Example #12
    def testBase(self):
        num_classes = 4
        latent_dim = 4

        b, t = 2, 4
        np.random.seed(2021)
        z = np.random.rand(b, t, latent_dim).astype(np.float32)
        paddings = np.zeros((b, t)).astype(np.float32)

        vq_p = self._GetParams(num_classes, latent_dim)
        vq = vq_p.Instantiate()
        vq_theta = vq.instantiate_variables(jax.random.PRNGKey(1))
        vq_theta.w = jnp.expand_dims(self.w, 1)
        out = test_utils.apply(vq, vq_theta, vq.fprop, z, paddings)

        with self.subTest('test_shape'):
            self.assertEqual((b, t, latent_dim), out.z_q.shape)
            self.assertEqual((b, t, 1), out.z_codes.shape)
            self.assertEqual((b, t, 1, num_classes), out.z_onehot.shape)
        with self.subTest('test_z_q'):
            self.assertAllClose(15.861525, np.sum(out.z_q))
        with self.subTest('test_z_codes'):
            self.assertEqual(24, np.sum(out.z_codes))
        with self.subTest('test_codebook_coverage'):
            self.assertEqual(0.25, np.sum(out.codebook_coverage))
        with self.subTest('test_pplx'):
            self.assertEqual(1.0, out.pplx)
        with self.subTest('test_entropy'):
            self.assertAllClose(0., out.entropy)
Example #13
 def test_rms_norm(self, scale):
     input_dims = 3
     p = normalizations.RmsNorm.Params().Set(name='jax_rmsn',
                                             input_dims=input_dims,
                                             direct_scale=False)
     rms_norm = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123456)
     prng_key, init_key = jax.random.split(prng_key)
     initial_vars = rms_norm.instantiate_variables(init_key)
     initial_vars.scale = scale
     npy_input = np.random.normal(
         1.0, 0.5, [10, 10, 10, p.input_dims]).astype('float32')
     inputs = jnp.asarray(npy_input)
     outputs = test_utils.apply(rms_norm, initial_vars, rms_norm.fprop,
                                inputs)
     # Now test whether tf RMS norm returns same output.
     tf_p = lingvo_layers.LayerNorm.Params().Set(name='tf_rmsn',
                                                 input_dim=p.input_dims,
                                                 bias=False,
                                                 center=False)
     tf_layer_norm = tf_p.Instantiate()
     tf_output = tf_layer_norm.FProp(initial_vars,
                                     tf.constant(inputs, dtype=tf.float32))
     np_outputs = to_np(outputs)
     tf_np_outputs = to_np(tf_output)
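      # With direct_scale=False the layer scales outputs by (1 + scale), so the
      # per-vector RMS (norm / sqrt(input_dims)) should be close to 1 + scale.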
     np_norms = np.linalg.norm(np_outputs / np.sqrt(float(input_dims)),
                               axis=-1)
     self.assertAllClose((1.0 + scale) * np.ones_like(np_norms),
                         np_norms,
                         atol=5e-3)
     self.assertAllClose(tf_np_outputs, np_outputs, atol=6e-5)
Example #14
  def test_vq_layer_equivalence_with_tf(self, num_clusters, num_heads,
                                        dim_per_head):
    inputs = np.random.normal(1.5, 2.0, (2, 32, num_heads, dim_per_head))
    prng_key = jax.random.PRNGKey(seed=123)
    prng_key, init_key = jax.random.split(prng_key)
    vq_layer_p = ngrammer.VectorQuantization.Params().Set(
        name='jax_vq_layer',
        num_clusters=num_clusters,
        num_heads=num_heads,
        dim_per_head=dim_per_head,
    )
    vq_layer = vq_layer_p.Instantiate()
    initial_vars = vq_layer.instantiate_variables(init_key)

    jax_dists, _ = test_utils.apply(vq_layer, initial_vars, vq_layer.fprop,
                                    inputs)

    # Now run TF based computation.
    tf_vq_layer_p = attention_util.KMeansClusteringForAtten.Params().Set(
        name='tf_vq_layer',
        num_clusters=num_clusters,
        num_heads=num_heads,
        dim_per_head=dim_per_head,
        apply_layer_norm=False)
    tf_vq_layer = tf_vq_layer_p.Instantiate()
    tf_dists, _ = tf_vq_layer.FProp(initial_vars, tf.constant(inputs))
    self.assertAllClose(to_np(jax_dists), to_np(tf_dists), atol=1e-5)
Example #15
    def test_per_dim_scale(self):
        test_layer_p = attentions.PerDimScale.Params().Set(name='scale', dim=4)
        layer = test_layer_p.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)
        initial_vars.per_dim_scale = jnp.array([-0.5, 0.5, 1.0, 0.0],
                                               dtype=jnp.float32)
        logging.info('initial_vars: %s', initial_vars)

        inputs = np.random.normal(1.5, 2.0, [5, 4]).astype(np.float32)

        jax_out = test_utils.apply(layer, initial_vars, layer.fprop, inputs)
        logging.info('jax_output: %s', jax_out)

        # Now run TF based computation.
        tf_layer_p = batch_major_attention.PerDimScaleLayer.Params().Set(
            name='scale', dim=4)
        tf_layer = tf_layer_p.Instantiate()
        tf_output1 = tf_layer.FProp(tf_layer.theta, inputs)
        logging.info('tf_output1: %s', tf_output1)
        tf_output2 = tf_layer.FProp(initial_vars, inputs)
        logging.info('tf_output2: %s', tf_output2)
        self.assertAllClose(test_utils.to_np(jax_out),
                            test_utils.to_np(tf_output2))
Example #16
 def test_rotary_position_embedding_layer_2d(self, position):
     embedding_dims = 2
     min_timescale = 1
     max_timescale = 1e4
     p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     inputs = np.random.normal(1.5, 2.5, (1, 4, 1, embedding_dims))
     if position is None:
         position = jnp.arange(4, dtype=jnp.float32)
     position = jnp.array(position)
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               inputs=inputs,
                               position=position[jnp.newaxis, :])
     np_output = test_utils.to_np(output)
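      # With embedding_dims=2 there is a single rotary frequency of
      # 1 / min_timescale = 1, so the expected output is the input pair rotated
      # by an angle equal to the position.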
     sinusoid_inp = position
     sin = jnp.sin(sinusoid_inp)
     cos = jnp.cos(sinusoid_inp)
     first_part = inputs[0, :, 0, 0] * cos - inputs[0, :, 0, 1] * sin
     second_part = inputs[0, :, 0, 1] * cos + inputs[0, :, 0, 0] * sin
     expected_output = np.stack([first_part, second_part], axis=-1)
     self.assertArraysEqual(np_output[0, :, 0, :], expected_output)
Example #17
  def testStackingOverTimeFPropReduceMaxPadding(self):
    params = linears.StackingOverTime.Params()
    params.name = 'stackingOverTime'
    params.left_context = 2
    params.right_context = 0
    params.stride = 2
    params.padding_reduce_option = 'reduce_max'

    stacker = linears.StackingOverTime(params)
    stacker_vars = None
    self.assertEqual(stacker.window_size, 3)

    inputs = jnp.array([[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]],
                        [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0], [0, 0]]],
                       dtype=jnp.float32)
    paddings = jnp.array(
        [[[0], [0], [0], [0], [0], [0]], [[0], [0], [1], [1], [1], [1]]],
        dtype=jnp.float32)

    outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                                stacker.fprop, inputs, paddings)
    print(f'{outputs}')
    expected_outputs = jnp.array([
        [[0, 0, 0, 0, 1, 1], [1, 1, 2, 2, 3, 3], [3, 3, 4, 4, 5, 5]],
        [[0, 0, 0, 0, 7, 7], [7, 7, 8, 8, 0, 0], [0, 0, 0, 0, 0, 0]],
    ],
                                 dtype=jnp.float32)

    self.assertAllClose(expected_outputs, outputs)

    expected_output_paddings = jnp.array([[[1], [0], [0]], [[1], [1], [1]]],
                                         dtype=jnp.float32)
    self.assertAllClose(expected_output_paddings, output_paddings)
Example #18
 def test_single_sharded_embedding_layer(self, lookup_style,
                                         scale_sqrt_depth):
     p = embedding_softmax.SingleShardEmbedding.Params().Set(
         name='jax_emb_lookup',
         vocab_size=10,
         embedding_dims=40,
         lookup_style=lookup_style,
         scale_sqrt_depth=scale_sqrt_depth)
     emb_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = emb_layer.instantiate_variables(prng_key)
     npy_input = np.random.randint(0, p.vocab_size,
                                   [10, 20]).astype('int32')
     inputs = jnp.asarray(npy_input)
     outputs = test_utils.apply(emb_layer, initial_vars, emb_layer.fprop,
                                inputs)
     # Test whether tf Embedding layer returns same output
     # Modify initial_vars to use TF compatible params
     tf_initial_vars = initial_vars
     tf_p = lingvo_layers.SingleShardEmbeddingLayer.Params().Set(
         name='tf_emb_lookup',
         vocab_size=p.vocab_size,
         embedding_dim=p.embedding_dims,
         scale_sqrt_depth=scale_sqrt_depth)
     tf_emb_layer = tf_p.Instantiate()
     tf_output = tf_emb_layer.FProp(tf_initial_vars,
                                    tf.constant(inputs, dtype=tf.int32))
     np_outputs = to_np(outputs)
     tf_np_outputs = to_np(tf_output)
     self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-6)
Example #19
    def test_stacked_conformer_layer(self, batch_size, seq_len, num_layers,
                                     kernel_size, input_dims, model_dims,
                                     atten_num_heads, dropout_prob):
        p = conformers.StackedConformer.Params().Set(name='conformer',
                                                     input_dims=input_dims,
                                                     model_dims=model_dims,
                                                     num_layers=num_layers)
        p.conformer_tpl.atten_num_heads = atten_num_heads
        p.conformer_tpl.kernel_size = kernel_size
        p.conformer_tpl.dropout_prob = dropout_prob

        stacked_conformer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = stacked_conformer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 2, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)

        context_p = base_layer.JaxContext.Params().Set(do_eval=True)

        with cluster_factory.SetEval(True):
            output = test_utils.apply(
                stacked_conformer,
                initial_vars,
                stacked_conformer.fprop,
                inputs,
                paddings,
                context_p=context_p,
            )

        self.assertEqual(output.shape, (batch_size, seq_len, model_dims))
Example #20
 def test_position_embedding_layer(self, min_timescale, max_timescale):
     p = embedding_softmax.PositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=50,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     seq_length = np.random.randint(100, 1000)
     output = test_utils.apply(pos_layer, initial_vars, pos_layer.fprop,
                               seq_length)
     output = jnp.squeeze(output, axis=0)
     # Test whether tf PositionalEmbedding layer returns same output
     # Modify initial_vars to use TF compatible params
     tf_initial_vars = initial_vars
     tf_p = lingvo_layers.PositionalEmbeddingLayer.Params().Set(
         name='tf_pos',
         embedding_dim=p.embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     tf_pos_layer = tf_p.Instantiate()
     tf_output = tf_pos_layer.FProp(tf_initial_vars, seq_length)
     np_pos = to_np(output)
     tf_np_pos = to_np(tf_output)
     self.assertAllClose(tf_np_pos, np_pos, atol=1e-3)
Example #21
 def test_position_embedding_layer_with_position(self, min_timescale,
                                                 max_timescale):
     p = embedding_softmax.PositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=50,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     position = np.array([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
                          [0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
                          [0, 1, 2, 3, 4, 5, 6, 0, 1, 2],
                          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               position=position)
     # Test whether tf PositionalEmbedding layer returns same output
     # Modify initial_vars to use TF compatible params
     tf_initial_vars = initial_vars
     tf_p = lingvo_layers.PositionalEmbeddingLayer.Params().Set(
         name='tf_pos',
         embedding_dim=p.embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     tf_pos_layer = tf_p.Instantiate()
     tf_output = tf_pos_layer.FPropWithPosition(tf_initial_vars, position)
     np_pos = to_np(output)
     tf_np_pos = to_np(tf_output)
     self.assertAllClose(tf_np_pos, np_pos, atol=1e-3)
Example #22
 def test_rotary_position_embedding_layer_prefix(self, min_timescale,
                                                 max_timescale,
                                                 window_size):
     embedding_dims = 32
     p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               inputs=inputs)
     # Test whether extend_step returns same output.
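      # extend_step is fed a fixed-size window of the most recent inputs
      # (left-padded when shorter); only its last position is compared against
      # the fprop output at step i.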
     for i in range(inputs.shape[1]):
         start = max(0, i + 1 - window_size)
         end = i + 1
         inputs_prefix = inputs[:, start:end, :, :]
         pad_width = window_size - end + start
         paddings = [(0, 0), (pad_width, 0), (0, 0), (0, 0)]
         inputs_prefix = jnp.pad(inputs_prefix, paddings)
         jax_extend_step_out = test_utils.apply(pos_layer,
                                                initial_vars,
                                                pos_layer.extend_step,
                                                inputs_prefix,
                                                time_step=i)
         jax_extend_step_out = jax.lax.dynamic_slice_in_dim(
             jax_extend_step_out,
             start_index=window_size - 1,
             slice_size=1,
             axis=1)
         jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
         jax_fprop_slice = jax.lax.dynamic_slice_in_dim(output,
                                                        start_index=i,
                                                        slice_size=1,
                                                        axis=1)
         self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
Example #23
    def testSpectrumAugmenterWithTimeMask(self):
        batch_size = 5
        inputs = jnp.ones([batch_size, 20, 2], dtype=jnp.float32)
        paddings = []
        for i in range(batch_size):
            paddings.append(
                jnp.concatenate([jnp.zeros([1, i + 12]),
                                 jnp.ones([1, 8 - i])],
                                axis=1))
        paddings = jnp.concatenate(paddings, axis=0)

        p = spectrum_augmenter.SpectrumAugmenter.Params()
        p.name = 'specAug_layers'
        p.freq_mask_max_bins = 0
        p.time_mask_max_frames = 5
        p.time_mask_count = 2
        p.time_mask_max_ratio = 1.
        specaug_layer = p.Instantiate()
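        # With time_mask_count=2 and time_mask_max_frames=5, up to two time
        # spans per example are zeroed; the exact positions follow from the
        # fixed PRNG seed.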
        expected_output = np.array([[[1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [0., 0.], [0., 0.], [0., 0.], [0., 0.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
                                    [[1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [0., 0.], [1., 1.], [0., 0.], [0., 0.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
                                    [[1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [0., 0.],
                                     [0., 0.], [0., 0.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
                                    [[1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [0., 0.], [0., 0.],
                                     [1., 1.], [1., 1.], [0., 0.], [0., 0.],
                                     [0., 0.], [0., 0.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.]],
                                    [[1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [0., 0.],
                                     [0., 0.], [0., 0.], [0., 0.], [1., 1.],
                                     [1., 1.], [1., 1.], [1., 1.], [1., 1.]]])
        context_p = base_layer.JaxContext.Params().Set(do_eval=False)
        prng_key = jax.random.PRNGKey(seed=23456)
        theta = specaug_layer.instantiate_variables(prng_key)
        actual_layer_output, _ = test_utils.apply(specaug_layer,
                                                  theta,
                                                  specaug_layer.fprop,
                                                  inputs,
                                                  paddings,
                                                  context_p=context_p)
        self.assertAllClose(actual_layer_output, expected_output)
Example #24
 def test_pooling_layer_with_paddings(self, window_shape, window_stride,
                                      padding, pooling_type, input_shape,
                                      int_inputs, paddings_all_ones):
     p = poolings.Pooling.Params().Set(name='jax_pooling',
                                       window_shape=window_shape,
                                       window_stride=window_stride,
                                       pooling_type=pooling_type,
                                       padding=padding)
     pooling_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pooling_layer.instantiate_variables(prng_key)
     if int_inputs:
         npy_inputs = np.random.randint(0, 100, input_shape).astype('int32')
     else:
         npy_inputs = np.random.normal(1.0, 0.5,
                                       input_shape).astype('float32')
     inputs = jnp.asarray(npy_inputs)
     paddings = None
     tf_paddings = None
     if paddings_all_ones:
         npy_paddings = np.ones([input_shape[0],
                                 input_shape[1]]).astype(npy_inputs.dtype)
     else:
         npy_paddings = np.random.randint(
             0, 2,
             [input_shape[0], input_shape[1]]).astype(npy_inputs.dtype)
     paddings = jnp.asarray(npy_paddings)
     tf_paddings = tf.constant(npy_paddings, dtype=tf.float32)
     output, out_paddings = test_utils.apply(pooling_layer, initial_vars,
                                             pooling_layer.fprop, inputs,
                                             paddings)
     # Test whether tf Pooling layer returns the same output.
     # Modify initial_vars to use TF compatible params.
     tf_initial_vars = initial_vars
     tf_p = lingvo_layers.PoolingLayer.Params().Set(
         name='tf_pooling',
         window_shape=window_shape,
         window_stride=window_stride,
         pooling_type=pooling_type,
         padding_algorithm=padding)
     tf_pooling_layer = tf_p.Instantiate()
     tf_input = tf.constant(npy_inputs, dtype=tf.float32)
     tf_output = tf_pooling_layer.FProp(tf_initial_vars, tf_input,
                                        tf_paddings)
     # Check the actual output.
     np_output = to_np(output)
     tf_np_output = to_np(tf_output[0])
     np_paddings = to_np(out_paddings)
     tf_np_paddings = to_np(tf_output[1])
     # Check the paddings.
     self.assertAllClose(np_paddings, tf_np_paddings)
     self.assertAllClose(tf_np_output, np_output)
Example #25
 def _run_decode(self, decoder_p, logits, input_batch):
     p = base_model.LanguageModel.Params()
     p.name = 'mock_lm'
     p.decoder = decoder_p.Copy()
     p.lm = MockLM.Params()
     p.lm.logits = logits
     lang_model = p.Instantiate()
     theta = NestedMap(lm=NestedMap())
     # We fix seed to 1027 to get the desired prefix lengths below.
     _, results = test_utils.apply(lang_model,
                                   theta,
                                   lang_model.decode,
                                   input_batch,
                                   seed=1027)
     return results
Example #26
    def testBase(self, b, t, latent_dim, projection_dim, num_classes):
        np.random.seed(2022)
        z = np.random.rand(b, t, latent_dim).astype(np.float32)
        paddings = np.zeros((b, t)).astype(np.float32)

        rq = quantizer.RandomVectorQuantizer.Params().Set(
            name='vq',
            num_latent_classes=num_classes,
            latent_dim=latent_dim,
            projection_dim=projection_dim)
        rq = rq.Instantiate()
        rq_theta = rq.instantiate_variables(jax.random.PRNGKey(1))
        out = test_utils.apply(rq, rq_theta, rq.fprop, z, paddings)
        self.assertEqual((b, t, projection_dim), out.z_q.shape)
        self.assertEqual((b, t), out.z_codes.shape)
        self.assertEqual((b, t, num_classes), out.z_onehot.shape)
Example #27
    def test_vit(self):
        batch_size = 3

        p_vit = self._vit_params()
        vit_model = p_vit.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = vit_model.instantiate_variables(prng_key)

        inputs_np = np.random.normal(
            size=[batch_size, p_vit.image_size, p_vit.image_size, 3])
        inputs = jnp.asarray(inputs_np)

        features = test_utils.apply(vit_model, initial_vars, vit_model.fprop,
                                    inputs)

        self.assertEqual(features.shape, (batch_size, p_vit.hidden_dim))
Example #28
    def testStackingOverTimePadWithRightFrameFProp(self, pad_with_right_frame):
        params = linears.StackingOverTime.Params()
        params.name = 'stackingOverTime'
        params.left_context = 0
        params.right_context = 1
        params.stride = 2
        params.pad_with_right_frame = pad_with_right_frame

        stacker = linears.StackingOverTime(params)
        stacker_vars = None
        self.assertEqual(stacker.window_size, 2)

        # input shape [2, 5, 2]
        inputs = jnp.array([[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]],
                            [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0]]],
                           dtype=jnp.float32)
        paddings = jnp.array(
            [[[0], [0], [0], [0], [0]], [[0], [0], [1], [1], [1]]],
            dtype=jnp.float32)
        outputs, output_paddings = test_utils.apply(stacker, stacker_vars,
                                                    stacker.fprop, inputs,
                                                    paddings)
        print(f'{outputs}')

        if pad_with_right_frame:
            # output shape [2, 3, 4]
            # [5, 5] is duplication of the last input frame.
            expected_outputs = jnp.array([
                [[1, 1, 2, 2], [3, 3, 4, 4], [5, 5, 5, 5]],
                [[7, 7, 8, 8], [0, 0, 0, 0], [0, 0, 0, 0]],
            ],
                                         dtype=jnp.float32)
        else:
            expected_outputs = jnp.array([
                [[1, 1, 2, 2], [3, 3, 4, 4], [5, 5, 0, 0]],
                [[7, 7, 8, 8], [0, 0, 0, 0], [0, 0, 0, 0]],
            ],
                                         dtype=jnp.float32)

        self.assertAllClose(expected_outputs, outputs)

        expected_output_paddings = jnp.array(
            [[[0], [0], [0]], [[0], [1], [1]]], dtype=jnp.float32)
        self.assertAllClose(expected_output_paddings, output_paddings)
Example #29
    def testVitSkipExitLayers(self):
        batch_size = 3

        p_vit = self._vit_params().Set(exit_layers_tpl=None)
        vit_model = p_vit.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = vit_model.instantiate_variables(prng_key)

        inputs_np = np.random.normal(
            size=[batch_size, p_vit.image_size, p_vit.image_size, 3])
        inputs = jnp.asarray(inputs_np)

        features = test_utils.apply(vit_model, initial_vars, vit_model.fprop,
                                    inputs)

        patch_count = p_vit.image_size // p_vit.patch_size
        self.assertEqual(features.shape,
                         (batch_size, patch_count**2, p_vit.hidden_dim))
Example #30
    def test_transformer_bert(self, trainable_position_emb):
        seq_len = 512
        if trainable_position_emb:
            position_emb_tpl = embedding_softmax.TrainablePositionalEmbedding.Params(
            )
            position_emb_tpl.max_seq_length = seq_len
        else:
            position_emb_tpl = embedding_softmax.PositionalEmbedding.Params()
        p = transformer_models.TransformerLm.Params().Set(
            name='bert_lm',
            model_dims=32,
            vocab_size=52,
            position_emb_tpl=position_emb_tpl)
        stacked_transformer_tpl = p.stacked_transformer_tpl
        stacked_transformer_tpl.model_dims = 32
        stacked_transformer_tpl.hidden_dims = 4 * 32
        stacked_transformer_tpl.num_heads = 4
        stacked_transformer_tpl.num_layers = 1
        p.softmax_tpl.scale_sqrt_depth = True
        batch_size = 8
        bert_lm = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = bert_lm.instantiate_variables(prng_key)
        input_ids = jax.random.randint(jax.random.PRNGKey(1234),
                                       [batch_size, seq_len], 0, 51)
        input_paddings = jnp.zeros([batch_size, seq_len])
        input_weights = jnp.ones([batch_size, seq_len])
        input_segment_ids = jnp.ones([batch_size, seq_len])
        input_segment_pos = jnp.tile(
            jnp.arange(0, seq_len)[jnp.newaxis, :], [batch_size, 1])

        labels = py_utils.NestedMap()
        labels.class_ids = input_ids
        labels.class_weights = input_weights
        outputs = test_utils.apply(bert_lm,
                                   initial_vars,
                                   bert_lm.fprop,
                                   input_ids,
                                   input_paddings,
                                   labels=labels,
                                   segment_ids=input_segment_ids,
                                   segment_pos=input_segment_pos)
        logging.info('outputs: %s', outputs)