Example #1
    def test_per_dim_scale(self):
        test_layer_p = attentions.PerDimScale.Params().Set(name='scale', dim=4)
        layer = test_layer_p.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)
        initial_vars.per_dim_scale = jnp.array([-0.5, 0.5, 1.0, 0.0],
                                               dtype=jnp.float32)
        logging.info('initial_vars: %s', initial_vars)

        inputs = np.random.normal(1.5, 2.0, [5, 4]).astype(np.float32)

        jax_out = test_utils.apply(layer, initial_vars, layer.fprop, inputs)
        logging.info('jax_output: %s', jax_out)

        # Now run TF based computation.
        tf_layer_p = batch_major_attention.PerDimScaleLayer.Params().Set(
            name='scale', dim=4)
        tf_layer = tf_layer_p.Instantiate()
        tf_output1 = tf_layer.FProp(tf_layer.theta, inputs)
        logging.info('tf_output1: %s', tf_output1)
        tf_output2 = tf_layer.FProp(initial_vars, inputs)
        logging.info('tf_output2: %s', tf_output2)
        self.assertAllClose(test_utils.to_np(jax_out),
                            test_utils.to_np(tf_output2))
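These snippets all rely on the same set of names; a hedged sketch of the imports they appear to assume is given below (the exact module paths are a guess from the identifiers used and may differ across lingvo versions):

from absl import logging
import jax
from jax import numpy as jnp
import numpy as np
import tensorflow as tf

# JAX-side layers and helpers (paths assumed, not shown in the snippets).
from lingvo.jax import base_layer
from lingvo.jax import py_utils
from lingvo.jax import test_utils
from lingvo.jax.layers import attentions
from lingvo.jax.layers import embedding_softmax
from lingvo.jax.layers import transformer_models
from lingvo.jax.layers import transformers
# TF reference layers used for the numerical comparisons.
from lingvo.core import batch_major_attention
from lingvo.core import gshard_builder
from lingvo.core import layers_with_attention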
Example #2
 def have_similar_stats(x, y):
     mean1, std1 = var_stats(test_utils.to_np(x))
     mean2, std2 = var_stats(test_utils.to_np(y))
     delta_mean = np.abs(mean1 - mean2)
     delta_std = np.abs(std1 - std2)
     logging.info('mean1: %s, mean2: %s', mean1, mean2)
     logging.info('std1: %s, std2: %s', std1, std2)
     test_case.assertLess(delta_mean, 0.0002)
     test_case.assertLess(delta_std, 0.0002)
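The helper above closes over var_stats and test_case, which are defined outside the snippet; a minimal sketch of the assumed var_stats, matching how it is used (it must return the mean and standard deviation of an array):

def var_stats(x):
    # Hypothetical helper matching the usage above.
    x = np.asarray(x)
    return np.mean(x), np.std(x)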
Example #3
    def test_transformer_feedforward(self, activation_function):
        p = transformers.TransformerFeedForward.Params().Set(
            name='ffwd',
            input_dims=8,
            hidden_dims=32,
            activation=activation_function)
        batch_size = 8
        seq_len = 512
        ffwd = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = ffwd.instantiate_variables(prng_key)

        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.zeros([batch_size, seq_len], dtype=np.float32)
        input_paddings = jnp.asarray(npy_paddings)

        with base_layer.JaxContext.new_context(
                prng_key=jax.random.PRNGKey(seed=1234),
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(ffwd, ffwd.vars_to_flax_vars(initial_vars))
            outputs = ffwd.fprop(inputs, input_paddings)
            logging.info('outputs: %s', outputs)

        if activation_function.startswith('GATED_'):
            # Default lingvo layers_with_attention.TransformerFeedForwardLayer does
            # not support gating.
            return

        # Test whether Tensorflow TransformerFeedForwardLayer returns the same
        # output. Modify `initial_vars` to use TF compatible params.
        tf_initial_vars = test_utils.replace_jax_transformer_ffwd_vars_to_tf(
            initial_vars)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        logging.info('tf_initial_vars in transformer feedforward layer = %s',
                     tf_initial_vars)
        tf_p = layers_with_attention.TransformerFeedForwardLayer.Params().Set(
            name='tf_ffwd',
            input_dim=p.input_dims,
            hidden_dim=p.hidden_dims,
            activation=p.activation)
        tf_ffwd = tf_p.Instantiate()
        tf_output = tf_ffwd.FProp(tf_initial_vars,
                                  tf.constant(npy_inputs, dtype=tf.float32),
                                  paddings=test_utils.to_tf_nmap(npy_paddings))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
Example #4
    def test_causal_depthwise_conv1d(self, shape, kernel_size, axis,
                                     hidden_dims):
        inputs = np.random.normal(1.5, 2.0, shape).astype(np.float32)
        p = attentions.CausalDepthwiseConv1D.Params().Set(
            name='causal_dconv',
            kernel_size=kernel_size,
            hidden_dims=hidden_dims)
        causal_dconv_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = causal_dconv_layer.instantiate_variables(init_key)
        if isinstance(hidden_dims, list):
            kernel_shape = hidden_dims
        else:
            kernel_shape = [hidden_dims]
        for k in range(kernel_size):
            initial_vars[f'dconv_{k}'] = np.ones(kernel_shape)

        jax_dconv_out = test_utils.apply(causal_dconv_layer,
                                         initial_vars,
                                         causal_dconv_layer.fprop,
                                         inputs,
                                         axis=axis)
        jax_np_out = test_utils.to_np(jax_dconv_out)
        outputs = inputs
        for _ in range(1, kernel_size):
            inputs = attentions.shift_1d(inputs, offset=1, axis=axis)
            outputs += inputs
        self.assertArraysEqual(jax_np_out, outputs)
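The expected output above is built by summing kernel_size causally shifted copies of the input, i.e. a causal depthwise convolution whose per-tap weights are all ones, which is exactly what the test forces by setting every dconv_{k} variable to ones. A standalone NumPy sketch of that reference computation:

def causal_dconv_all_ones(x, kernel_size, axis=1):
    # Sum of the current frame and the kernel_size - 1 previous frames along `axis`,
    # zero-padded at the start; mirrors the shift-and-add loop in the test above.
    out = x.copy()
    shifted = x.copy()
    for _ in range(1, kernel_size):
        shifted = np.roll(shifted, 1, axis=axis)
        head = [slice(None)] * shifted.ndim
        head[axis] = slice(0, 1)
        shifted[tuple(head)] = 0.0  # drop the wrapped-around frame
        out = out + shifted
    return out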
Example #5
    def test_limited_context_mask_from_padding(self, batch_size, max_length,
                                               left_context, right_context):
        def get_padding_from_length(length):
            idx = np.tile(np.arange(max_length), [batch_size, 1])
            return (idx >= np.expand_dims(length, -1)).astype('float32')

        length = np.random.randint(max_length // 2, max_length, [
            batch_size,
        ])
        padding = jnp.asarray(get_padding_from_length(length))

        result = attentions.limited_context_mask_from_padding(
            padding, left_context, right_context)
        expect = np.zeros((batch_size, 1, max_length, max_length))
        for b in range(batch_size):
            for t1 in range(max_length):
                if t1 >= length[b]:
                    continue
                start_p, end_p = 0, length[b]
                if left_context is not None:
                    start_p = max(0, t1 - left_context + 1)
                if right_context is not None:
                    end_p = min(length[b], t1 + right_context + 1)
                expect[b, 0, t1, start_p:end_p] = 1.0
        self.assertAllClose(test_utils.to_np(result), (1.0 - expect) *
                            attentions._get_large_negative_number(jnp.float32))
Example #6
 def test_rotary_position_embedding_layer_no_prefix(self, min_timescale,
                                                    max_timescale):
     embedding_dims = 32
     p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               inputs=inputs)
     # Test whether extend_step returns same output.
     for i in range(inputs.shape[1]):
         jax_extend_step_out = test_utils.apply(pos_layer,
                                                initial_vars,
                                                pos_layer.extend_step,
                                                inputs[:, i, :, :],
                                                time_step=i)
         jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
         jax_fprop_slice = output[:, i, :, :]
         self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
Example #7
 def test_rotary_position_embedding_layer_2d(self, position):
     embedding_dims = 2
     min_timescale = 1
     max_timescale = 1e4
     p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     inputs = np.random.normal(1.5, 2.5, (1, 4, 1, embedding_dims))
     if position is None:
         position = jnp.arange(4, dtype=jnp.float32)
     position = jnp.array(position)
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               inputs=inputs,
                               position=position[jnp.newaxis, :])
     np_output = test_utils.to_np(output)
     sinusoid_inp = position
     sin = jnp.sin(sinusoid_inp)
     cos = jnp.cos(sinusoid_inp)
     first_part = inputs[0, :, 0, 0] * cos - inputs[0, :, 0, 1] * sin
     second_part = inputs[0, :, 0, 1] * cos + inputs[0, :, 0, 0] * sin
     expected_output = np.stack([first_part, second_part], axis=-1)
     self.assertArraysEqual(np_output[0, :, 0, :], expected_output)
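For embedding_dims == 2 the rotary embedding reduces to a plain 2-D rotation of each (x0, x1) pair by the position angle, which is what the first_part/second_part expressions above spell out. The same expectation written with an explicit rotation matrix (a sketch, not library code):

angles = np.arange(4, dtype=np.float32)                  # one rotation angle per step
pairs = np.random.normal(1.5, 2.5, [4, 2]).astype(np.float32)
rot = np.stack([np.stack([np.cos(angles), -np.sin(angles)], axis=-1),
                np.stack([np.sin(angles), np.cos(angles)], axis=-1)], axis=-2)  # [4, 2, 2]
expected = np.einsum('tij,tj->ti', rot, pairs)           # matches first/second_part above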
Example #8
    def test_mhd_projection_02(self, use_nhd_shape):
        test_layer_p = attentions.AttentionProjection.Params().Set(
            name='mh',
            input_dim=16,
            num_heads=2,
            dim_per_head=5,
            is_output_projection=True,
            use_nhd_shape=use_nhd_shape,
        )
        layer = test_layer_p.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)
        logging.info('initial_vars: %s', initial_vars)

        inputs = np.random.normal(1.5, 2.0, [5, 2, 5]).astype(np.float32)

        jax_out = test_utils.apply(layer, initial_vars, layer.fprop, inputs)
        logging.info('jax_output: %s', jax_out)

        if use_nhd_shape:
            initial_vars.w = np.einsum('ABC->CAB', initial_vars.w)

        # Now run TF based computation.
        tf_layer_p = batch_major_attention.MultiHeadedProjectionLayer.Params(
        ).Set(name='mh',
              input_dim=16,
              num_heads=2,
              dim_per_head=5,
              is_output_projection=True)
        tf_layer = tf_layer_p.Instantiate()
        tf_output1 = tf_layer.FProp(tf_layer.theta, inputs)
        logging.info('tf_output1: %s', tf_output1)
        tf_output2 = tf_layer.FProp(initial_vars, inputs)
        logging.info('tf_output2: %s', tf_output2)
        self.assertGreater(
            np.sum(
                np.abs(
                    test_utils.to_np(tf_output1) -
                    test_utils.to_np(tf_output2))), 0.1)
        self.assertAllClose(test_utils.to_np(jax_out),
                            test_utils.to_np(tf_output2))
Example #9
    def test_transformer_relative_bias(self, use_relative_bias):
        p = transformers.Transformer.Params().Set(name='jax_transformer_layer',
                                                  input_dims=32,
                                                  hidden_dims=128,
                                                  num_heads=8,
                                                  mask_self_attention=True,
                                                  packed_input=True,
                                                  cross_attention=False)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        if use_relative_bias:
            p.tr_atten_tpl.relative_bias_tpl = attentions.RelativeBias.Params(
            ).Set(relative_attention_num_buckets=2,
                  relative_attention_max_distance=8)
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(attention_mask, causal_mask)
        segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
        segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
        attention_mask = jnp.minimum(attention_mask, segment_mask)

        if use_relative_bias:
            segment_pos = np.random.randint(
                0, seq_len, [batch_size, seq_len]).astype('int32')
            segment_pos = jnp.asarray(segment_pos)
        else:
            segment_pos = None

        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            outputs, _ = transformer_layer.fprop(inputs,
                                                 paddings,
                                                 attention_mask=attention_mask,
                                                 segment_pos=segment_pos)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        np_outputs = test_utils.to_np(outputs)
        logging.info('np_outputs: %s', np_outputs)
        if use_relative_bias:
            self.assertAlmostEqual(np_outputs[0, 0, 1], 0.79015386, places=5)
            self.assertAlmostEqual(np_outputs[0, 1, 0], 0.48336178, places=5)
        # Plumbing test.
        self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
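The mask plumbing above (and in the later transformer examples) combines padding, causal, and segment masks with jnp.minimum. That works because every mask is additive in logit space: 0 where attention is allowed and a large negative number where it is not, so the elementwise minimum keeps a position masked whenever any single constraint masks it. A small sketch of the convention, with a stand-in constant for attentions._get_large_negative_number:

NEG = -1e9  # stand-in; the real code uses attentions._get_large_negative_number(...)
padding_mask = jnp.array([[[[0.0, 0.0, NEG]]]])    # last key position is padded
causal_mask = jnp.array([[[[0.0, NEG, NEG]]]])     # current query may only see key 0
combined = jnp.minimum(padding_mask, causal_mask)  # masked wherever either input is NEG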
Example #10
 def test_causal_depthwise_conv1d_extend_step(self, shape, kernel_size,
                                              axis, hidden_dims):
     inputs = np.random.normal(1.5, 2.0, shape).astype(np.float32)
     p = attentions.CausalDepthwiseConv1D.Params().Set(
         name='causal_dconv',
         kernel_size=kernel_size,
         hidden_dims=hidden_dims)
     causal_dconv_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     prng_key, init_key = jax.random.split(prng_key)
     initial_vars = causal_dconv_layer.instantiate_variables(init_key)
     prng_key, compute_key = jax.random.split(prng_key)
     global_step = jnp.array(0, dtype=jnp.uint64)
     with base_layer.JaxContext.new_context(
             prng_key=compute_key, global_step=global_step) as jax_context:
         jax_context.bind(
             causal_dconv_layer,
             causal_dconv_layer.vars_to_flax_vars(initial_vars))
         jax_dconv_out = causal_dconv_layer.fprop(inputs, axis=axis)
         jax_np_out = test_utils.to_np(jax_dconv_out)
         jax_extend_step_out = jnp.zeros_like(jax_dconv_out)
         for i in range(shape[1]):
             jax_extend_step_out = causal_dconv_layer.extend_step(inputs,
                                                                  axis=axis,
                                                                  step=i)
             jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
             jax_extend_step_out_tensor = causal_dconv_layer.extend_step(
                 inputs, axis=axis, step=jnp.array(i))
             jax_np_extend_step_out_tensor = test_utils.to_np(
                 jax_extend_step_out_tensor)
             jax_fprop_slice = jax.lax.dynamic_slice_in_dim(jax_np_out,
                                                            start_index=i,
                                                            slice_size=1,
                                                            axis=axis)
             jax_fprop_slice = jnp.squeeze(jax_fprop_slice, axis)
             self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
             self.assertArraysEqual(jax_fprop_slice,
                                    jax_np_extend_step_out_tensor)
Example #11
 def test_rotary_position_embedding_layer_prefix(self, min_timescale,
                                                 max_timescale,
                                                 window_size):
     embedding_dims = 32
     p = embedding_softmax.RotaryPositionalEmbedding.Params().Set(
         name='jax_pos',
         embedding_dims=embedding_dims,
         min_timescale=min_timescale,
         max_timescale=max_timescale)
     pos_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = pos_layer.instantiate_variables(prng_key)
     inputs = np.random.normal(1.5, 2.5, (2, 8, 4, embedding_dims))
     output = test_utils.apply(pos_layer,
                               initial_vars,
                               pos_layer.fprop,
                               inputs=inputs)
     # Test whether extend_step returns same output.
     for i in range(inputs.shape[1]):
         start = max(0, i + 1 - window_size)
         end = i + 1
         inputs_prefix = inputs[:, start:end, :, :]
         pad_width = window_size - end + start
         paddings = [(0, 0), (pad_width, 0), (0, 0), (0, 0)]
         inputs_prefix = jnp.pad(inputs_prefix, paddings)
         jax_extend_step_out = test_utils.apply(pos_layer,
                                                initial_vars,
                                                pos_layer.extend_step,
                                                inputs_prefix,
                                                time_step=i)
         jax_extend_step_out = jax.lax.dynamic_slice_in_dim(
             jax_extend_step_out,
             start_index=window_size - 1,
             slice_size=1,
             axis=1)
         jax_np_extend_step_out = test_utils.to_np(jax_extend_step_out)
         jax_fprop_slice = jax.lax.dynamic_slice_in_dim(output,
                                                        start_index=i,
                                                        slice_size=1,
                                                        axis=1)
         self.assertArraysEqual(jax_fprop_slice, jax_np_extend_step_out)
Example #12
    def test_transformer_layer_norm_policies(self, norm_policy):
        p = transformers.Transformer.Params().Set(name='jax_transformer_layer',
                                                  input_dims=32,
                                                  hidden_dims=128,
                                                  num_heads=8,
                                                  mask_self_attention=True,
                                                  packed_input=True,
                                                  cross_attention=False,
                                                  norm_policy=norm_policy)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(attention_mask, causal_mask)
        segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
        segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
        attention_mask = jnp.minimum(attention_mask, segment_mask)

        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            outputs, _ = transformer_layer.fprop(inputs,
                                                 paddings,
                                                 attention_mask=attention_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        np_outputs = test_utils.to_np(outputs)
        # Plumbing test.
        self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
Example #13
    def test_glam_unitransformer(self):
        batch = 2
        length = 3
        d_model = 6
        num_heads = 2
        vocab_size = 16
        ff_dim = 8
        c_dim = 3
        e_dim = 2
        num_layers = 4
        # Build jax layer
        jax_p = transformer_models.TransformerLm.GLaMUniTransformerParams(
            name='model',
            vocab_size=vocab_size,
            num_transformer_layers=num_layers,
            moe=True,
            model_dim=d_model,
            ff_dim=ff_dim,
            moe_hidden_dim=ff_dim,
            attention_num_heads=num_heads,
            attention_key_value_dim=d_model // num_heads,
            attention_extra_logit=0.0,
            use_tgt_labels_size_as_loss_denominator=True,
            moe_load_balance_loss_weight=0.01,
            z_loss_weight=1e-4,
            c_dim=c_dim,
            e_dim=e_dim)
        assert jax_p.packed_input
        jax_layer = jax_p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=42)
        jax_vars = jax_layer.instantiate_variables(prng_key)

        builder_p = gshard_builder.DenseBuilder.Params().Set(
            num_groups=1,
            second_expert_policy='all',
            relative_attention_type='bias',
            model_dim=d_model,
            attention_key_value_dim=d_model // num_heads,
            attention_num_heads=num_heads,
            attention_combine_dims=True,
            c_dim=c_dim,
            capacity_factor=None,
            attention_extra_logit=0.0,
            e_dim=e_dim,
            moe_hidden_dim=ff_dim,
            ff_dim=ff_dim)
        tf_layer = gshard_builder.UniTransformer.Params().Set(
            name='model',
            num_transformer_layers=num_layers,
            builder=builder_p,
            vocab_size=vocab_size,
            sequence_length=length,
            label_smoothing=0,
            aux_loss_coef=0.01,
            z_loss=1e-4,
            use_tgt_labels_size_as_loss_denominator=True,
            positional_embedding=False,
            gated_gelu=True,
            moe=True).Instantiate()

        # Build Jax Inputs
        np.random.seed(42)
        npy_ids = np.random.randint(0, vocab_size - 1, [batch, length])
        jax_ids = jnp.asarray(npy_ids)
        npy_paddings = np.array([[0, 0, 1], [0, 0, 1]], dtype=np.float32)

        jax_paddings = jnp.asarray(npy_paddings)
        npy_segment_ids = np.array([[1, 2, 0], [1, 1, 0]], dtype=np.int32)
        npy_segment_pos = np.array([[0, 0, 0], [0, 1, 0]], dtype=np.int32)
        npy_labels = np.roll(npy_ids, -1, axis=1)
        jax_labels = jnp.asarray(npy_labels)
        jax_seg_ids = jnp.asarray(npy_segment_ids)
        jax_seg_pos = jnp.asarray(npy_segment_pos)
        jax_label_weighs = jnp.asarray([[1, 1, 0], [1, 1, 0]])

        # Build TF Inputs
        tf_tgt_inputs = py_utils.NestedMap(
            ids=tf.convert_to_tensor(npy_ids, dtype=tf.int32),
            labels=tf.convert_to_tensor(npy_labels, dtype=tf.int32),
            segment_ids=tf.convert_to_tensor(npy_segment_ids, dtype=tf.int32),
            segment_pos=tf.convert_to_tensor(npy_segment_pos, dtype=tf.int32))
        tf_inputs = py_utils.NestedMap(tgt=tf_tgt_inputs)

        # Compute jax outputs
        jax_outputs = test_utils.apply(jax_layer,
                                       jax_vars,
                                       jax_layer.fprop,
                                       jax_ids,
                                       jax_paddings,
                                       context_p=None,
                                       labels=py_utils.NestedMap(
                                           class_ids=jax_labels,
                                           class_weights=jax_label_weighs,
                                       ),
                                       segment_ids=jax_seg_ids,
                                       segment_pos=jax_seg_pos)

        # Copy jax vars to tf ones.
        tf_theta = tf_layer.theta.DeepCopy()

        # GShardBuilder softmax weight use self.vars rather than theta.
        tf_layer.vars.dec_emb.w.embedding.assign(jax_vars.softmax.embedding.w)
        tf_theta.dec_emb.w.embedding = jax_vars.softmax.embedding.w
        tf_theta.dec.final_layer_norm.w.scale = jax_vars.final_ln.scale
        jax_layer_0_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[0])
        tf_theta.dec.layer_000.ln.w.scale = jax_layer_0_var.layer_norm.scale
        jax_atten_var = jax_layer_0_var.self_attention
        tf_atten_var = tf_theta.dec.layer_000.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_moe_var = jax_layer_0_var.ff_layer
        tf_theta.dec.layer_001.ln.w.scale = jax_moe_var.layer_norm.scale
        tf_theta.dec.layer_001.moe.ffw.top_2_gating.w = jax_moe_var.gate
        tf_theta.dec.layer_001.moe.moe.wi = jax_moe_var.wi_0
        tf_theta.dec.layer_001.moe.moe.wo = jax_moe_var.wo_0

        jax_layer_1_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[1])
        tf_theta.dec.layer_002.ln.w.scale = jax_layer_1_var.layer_norm.scale
        jax_atten_var = jax_layer_1_var.self_attention
        tf_atten_var = tf_theta.dec.layer_002.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_ffn_var = jax_layer_1_var.ff_layer
        tf_ffn_var = tf_theta.dec.layer_003.dense_relu_dense
        tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
        tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
        tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
        tf_theta.dec.layer_003.ln.w.scale = jax_ffn_var.layer_norm.scale

        jax_layer_2_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[0])
        tf_theta.dec.layer_004.ln.w.scale = jax_layer_2_var.layer_norm.scale
        jax_atten_var = jax_layer_2_var.self_attention
        tf_atten_var = tf_theta.dec.layer_004.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_moe_var = jax_layer_2_var.ff_layer
        tf_theta.dec.layer_005.ln.w.scale = jax_moe_var.layer_norm.scale
        tf_theta.dec.layer_005.moe.ffw.top_2_gating.w = jax_moe_var.gate
        tf_theta.dec.layer_005.moe.moe.wi = jax_moe_var.wi_0
        tf_theta.dec.layer_005.moe.moe.wo = jax_moe_var.wo_0

        jax_layer_3_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[1])
        tf_theta.dec.layer_006.ln.w.scale = jax_layer_3_var.layer_norm.scale
        jax_atten_var = jax_layer_3_var.self_attention
        tf_atten_var = tf_theta.dec.layer_006.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_ffn_var = jax_layer_3_var.ff_layer
        tf_ffn_var = tf_theta.dec.layer_007.dense_relu_dense
        tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
        tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
        tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
        tf_theta.dec.layer_007.ln.w.scale = jax_ffn_var.layer_norm.scale

        tf_theta = test_utils.to_tf_nmap(tf_theta)

        # Compute TF outputs
        tf_out, _ = tf_layer.FProp(tf_theta, tf_inputs)
        self.assertAllClose(test_utils.to_np(jax_outputs.total_loss),
                            test_utils.to_np(tf_out['loss'][0]))
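The per-layer copies above repeat the same self-attention field mapping four times; a small helper capturing that mapping could shorten the test (a sketch built only from the field names already used above, not a library utility):

def copy_self_attention_vars(tf_atten_var, jax_atten_var):
    # Map the JAX self-attention weights onto the GShard TF attention layout.
    tf_atten_var.w.wq = jax_atten_var.query.w
    tf_atten_var.w.wk = jax_atten_var.key.w
    tf_atten_var.w.wv = jax_atten_var.value.w
    tf_atten_var.w.wo = jax_atten_var.post.w
    tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb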
Example #14
    def test_transformer_moe_dense_layer(self, mask_self_attention,
                                         packed_input, cross_attention):
        # Comparing scan over blocks of layers and regular loop
        block_p = transformers.StackedTransformer.Params().Set(
            name='transformer_block',
            num_layers=2,
            model_dims=3,
            hidden_dims=6,
            num_heads=1,
            mask_self_attention=mask_self_attention,
            packed_input=packed_input,
            cross_attention=cross_attention,
            num_experts=4,
            num_groups=1,
            moe_layers=[0])

        block_p_repeated = transformers.StackedTransformerRepeated.Params(
        ).Set(name='stacked_transformer_layer_repeated',
              block=block_p.Copy(),
              x_times=1)

        stack_p = transformers.StackedTransformer.Params().Set(
            name='transformer_stack',
            num_layers=2,  # moe + dense
            model_dims=block_p.model_dims,
            hidden_dims=block_p.hidden_dims,
            num_heads=block_p.num_heads,
            mask_self_attention=block_p.mask_self_attention,
            packed_input=block_p.packed_input,
            cross_attention=block_p.cross_attention,
            num_experts=block_p.num_experts,
            num_groups=block_p.num_groups,
            moe_layers=[0])

        moe_p = stack_p.moe_layer_tpl
        moe_p.expert_capacity_dim = 2
        moe_p.expert_capacity_factor = 0

        moe_p = block_p.moe_layer_tpl
        moe_p.expert_capacity_dim = 2
        moe_p.expert_capacity_factor = 0

        transformer_block = block_p_repeated.Instantiate()
        transformer_stack = stack_p.Instantiate()

        seq_len = 4
        batch_size = 3
        prng_key = jax.random.PRNGKey(seed=123)
        block_initial_vars = transformer_block.instantiate_variables(prng_key)
        stack_initial_vars = transformer_stack.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5,
            [batch_size, seq_len, block_p.model_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        segment_mask = None
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)

        cross_inputs = None
        cross_paddings = None
        cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 64)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5, [batch_size, cross_seq_len, block_p.model_dims
                           ]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)

        block_outputs = test_utils.apply(transformer_block,
                                         block_initial_vars,
                                         transformer_block.fprop,
                                         inputs,
                                         paddings,
                                         segment_mask=segment_mask,
                                         cross_inputs=cross_inputs,
                                         cross_paddings=cross_paddings,
                                         cross_segment_mask=cross_segment_mask)

        stack_outputs = test_utils.apply(transformer_stack,
                                         stack_initial_vars,
                                         transformer_stack.fprop,
                                         inputs,
                                         paddings,
                                         segment_mask=segment_mask,
                                         cross_inputs=cross_inputs,
                                         cross_paddings=cross_paddings,
                                         cross_segment_mask=cross_segment_mask)
        _ = test_utils.to_np(block_outputs)
        _ = test_utils.to_np(stack_outputs)
Example #15
    def test_transformer_layer(self, mask_self_attention, packed_input,
                               cross_attention):
        p = transformers.Transformer.Params().Set(
            name='jax_transformer_layer',
            input_dims=32,
            hidden_dims=128,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            packed_input=packed_input,
            cross_attention=cross_attention)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        causal_mask = None
        segment_mask = None
        tf_segment_mask = None
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        if mask_self_attention:
            causal_mask = attentions.causal_mask(inputs)
            attention_mask = jnp.minimum(attention_mask, causal_mask)
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            attention_mask = jnp.minimum(attention_mask, segment_mask)
            if mask_self_attention:
                tf_segment_mask = batch_major_attention.CausalSegmentMask(
                    segment_ids, tf.float32)
            else:
                tf_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids)

        cross_inputs = None
        cross_attention_mask = None
        tf_cross_inputs = None
        tf_cross_paddings = None
        tf_cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 128)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.input_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            cross_attention_mask = attentions.convert_paddings_to_mask(
                cross_paddings)
            tf_cross_paddings = tf.constant(npy_cross_paddings,
                                            dtype=tf.float32)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                cross_attention_mask = jnp.minimum(cross_attention_mask,
                                                   cross_segment_mask)
                tf_cross_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, source_segment_ids)

        outputs, _ = test_utils.apply(
            transformer_layer,
            initial_vars,
            transformer_layer.fprop,
            inputs,
            paddings,
            context_p=None,
            attention_mask=attention_mask,
            cross_inputs=cross_inputs,
            cross_attention_mask=cross_attention_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        # Test whether tf Transformer layer returns same output
        # Modify initial_vars to use TF compatible params
        tf_initial_vars = test_utils.replace_jax_attention_vars_to_tf(
            initial_vars, cross_attention)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        logging.info('tf_initial_vars in transformer layer = %s',
                     tf_initial_vars)
        tf_p = batch_major_attention.TransformerLayer.Params().Set(
            name='tf_transformer_layer',
            input_dim=p.input_dims,
            num_heads=p.num_heads,
            mask_self_atten=mask_self_attention,
            packed_input=packed_input,
            has_aux_atten=cross_attention)
        tf_p.tr_fflayer_tpl.hidden_dim = p.hidden_dims
        tf_p.tr_fflayer_tpl.fflayer_tpl.batch_norm = False
        tf_p.tr_fflayer_tpl.fflayer_tpl.has_bias = True
        tf_transformer_layer = tf_p.Instantiate()
        tf_output, _ = tf_transformer_layer.FProp(
            tf_initial_vars,
            tf.constant(npy_inputs, dtype=tf.float32),
            paddings=test_utils.to_tf_nmap(npy_paddings),
            segment_mask=tf_segment_mask,
            aux_vec=tf_cross_inputs,
            aux_paddings=tf_cross_paddings,
            aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
Example #16
    def test_stacked_transformer_layer(self, mask_self_attention, packed_input,
                                       cross_attention):
        p = transformers.StackedTransformer.Params().Set(
            name='jax_stacked_transformer_layer',
            model_dims=16,
            hidden_dims=64,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            num_layers=4,
            packed_input=packed_input,
            cross_attention=cross_attention)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        stacked_transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = stacked_transformer_layer.instantiate_variables(
            prng_key)

        # test conversion between vars and flax vars.
        pax_vars = stacked_transformer_layer.vars
        flax_vars = stacked_transformer_layer.flax_vars
        tf.nest.assert_same_structure(
            pax_vars, stacked_transformer_layer.flax_vars_to_vars(flax_vars))
        tf.nest.assert_same_structure(
            flax_vars, stacked_transformer_layer.vars_to_flax_vars(pax_vars))

        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.model_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        segment_mask = None
        tf_segment_mask = None
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            if mask_self_attention:
                tf_segment_mask = batch_major_attention.CausalSegmentMask(
                    segment_ids, tf.float32)
            else:
                tf_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids)

        cross_inputs = None
        cross_paddings = None
        cross_segment_mask = None
        tf_cross_inputs = None
        tf_cross_paddings = None
        tf_cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 64)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.model_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            tf_cross_paddings = tf.constant(npy_cross_paddings,
                                            dtype=tf.float32)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                tf_cross_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, source_segment_ids)

        outputs = test_utils.apply(stacked_transformer_layer,
                                   initial_vars,
                                   stacked_transformer_layer.fprop,
                                   inputs,
                                   paddings,
                                   context_p=None,
                                   segment_mask=segment_mask,
                                   cross_inputs=cross_inputs,
                                   cross_paddings=cross_paddings,
                                   cross_segment_mask=cross_segment_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        # Test whether tf Transformer layer returns same output
        # Modify initial_vars to use TF compatible params
        tf_initial_vars = py_utils.NestedMap()
        tf_initial_vars.x_layers = []
        for jax_initial_vars in initial_vars.x_layers:
            tf_layer_vars = test_utils.replace_jax_attention_vars_to_tf(
                jax_initial_vars, cross_attention)
            tf_initial_vars.x_layers.append(tf_layer_vars)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        logging.info('tf_initial_vars in transformer layer = %s',
                     tf_initial_vars)
        tf_p = batch_major_attention.StackedTransformerLayers.Params().Set(
            name='tf_transformer_layer',
            mdl_dim=p.model_dims,
            hidden_dim=p.hidden_dims,
            num_atten_heads=p.num_heads,
            mask_self_atten=mask_self_attention,
            num_layers=p.num_layers,
            packed_input=packed_input,
            has_aux_atten=cross_attention)
        tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.batch_norm = (
            False)
        tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.has_bias = True
        tf_stacked_transformer_layer = tf_p.Instantiate()
        tf_output, _ = tf_stacked_transformer_layer.FProp(
            tf_initial_vars,
            test_utils.to_tf_nmap(npy_inputs),
            paddings=test_utils.to_tf_nmap(npy_paddings),
            segment_mask=test_utils.to_tf_nmap(tf_segment_mask),
            aux_vec=test_utils.to_tf_nmap(tf_cross_inputs),
            aux_paddings=test_utils.to_tf_nmap(tf_cross_paddings),
            aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
Example #17
    def test_stacked_transformer_layer_extendstep(self, packed_input,
                                                  cross_attention, combine_qkv,
                                                  dconv_qkv,
                                                  use_rotary_position_emb):
        if cross_attention and combine_qkv:
            self.skipTest(
                'combine_qkv optimization only works for self-attention.')
        layer_params = transformers.StackedTransformer.Params()

        num_layers = 2
        model_dims = 8
        p = layer_params.Set(name='jax_transformer_layer',
                             model_dims=model_dims,
                             hidden_dims=32,
                             num_heads=2,
                             mask_self_attention=True,
                             packed_input=packed_input,
                             cross_attention=cross_attention,
                             num_layers=num_layers)
        p.transformer_layer_params_tpl.tr_atten_tpl.combine_qkv = combine_qkv
        p.transformer_layer_params_tpl.tr_atten_tpl.dconv_qkv = dconv_qkv
        p.transformer_layer_params_tpl.tr_atten_tpl.use_rotary_position_emb = (
            use_rotary_position_emb)
        if cross_attention:
            p.transformer_layer_params_tpl.cross_atten_tpl = (
                p.transformer_layer_params_tpl.tr_atten_tpl.Copy())
            # Cross attention should not have depth-wise convolution.
            p.transformer_layer_params_tpl.cross_atten_tpl.dconv_qkv = False
            # Cross attention should not have rotary position embedding.
            p.transformer_layer_params_tpl.cross_atten_tpl.use_rotary_position_emb = (
                False)

        p_copy = p.Copy()
        p_copy.num_layers = 1
        p = transformers.StackedTransformerRepeated.Params()
        p.name = 'jax_transformer_repeated_layer'
        p.block = p_copy
        p.x_times = num_layers

        seq_len = 4
        batch_size = 4
        stacked_transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = stacked_transformer_layer.instantiate_variables(
            prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, model_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        segment_mask = None
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)

        cross_inputs = None
        cross_paddings = None
        cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 32)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, model_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)

        prng_key = jax.random.PRNGKey(seed=123)
        global_step = jnp.array(0, dtype=jnp.uint64)
        with base_layer.JaxContext.new_context(
                prng_key=prng_key, global_step=global_step) as jax_context:
            jax_context.bind(
                stacked_transformer_layer,
                stacked_transformer_layer.vars_to_flax_vars(initial_vars))
            fprop_outputs = stacked_transformer_layer.fprop(
                inputs,
                paddings,
                segment_mask=segment_mask,
                cross_inputs=cross_inputs,
                cross_paddings=cross_paddings,
                cross_segment_mask=cross_segment_mask)
            decoder_outputs = jnp.zeros(
                shape=[seq_len, batch_size, model_dims])
            initial_states = stacked_transformer_layer.init_states(
                batch_size, seq_len)
            atten_states = initial_states
            for t in range(seq_len):
                segment_mask_t = attention_mask[:, :, t, :]
                cross_segment_mask_t = cross_segment_mask
                if segment_mask is not None:
                    segment_mask_t = jnp.minimum(segment_mask_t,
                                                 segment_mask[:, :, t, :])
                if cross_segment_mask is not None:
                    cross_segment_mask_t = cross_segment_mask[:, :, t, :]
                atten_states, encoded = stacked_transformer_layer.extend_step(
                    atten_states,
                    inputs=inputs[:, t, :],
                    time_step=t,
                    segment_mask=segment_mask_t,
                    cross_inputs=cross_inputs,
                    cross_paddings=cross_paddings,
                    cross_segment_mask=cross_segment_mask_t)
                decoder_outputs = decoder_outputs.at[t].set(encoded)

        decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 0, 2])
        # TODO(lepikhin): remove noisy test logging
        # logging.info('initial_vars in transformer layer = %s', initial_vars)
        np_fprop_outputs = test_utils.to_np(fprop_outputs)
        np_decoder_outputs = test_utils.to_np(decoder_out_transposed)
        self.assertAllClose(np_fprop_outputs, np_decoder_outputs, atol=1e-5)
Example #18
    def test_mha_02(self):
        mdl_dim = 16
        hidden_dim = 32
        num_heads = 4
        test_layer_p = attentions.DotProductAttention.Params().Set(
            name='mh',
            input_dim=mdl_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            atten_logit_cap=20.0,
        )
        layer = test_layer_p.Instantiate()

        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)

        target_batch_size = 3
        source_max_length = 8
        target_max_length = 8

        query_vec = np.random.normal(
            size=[target_batch_size, source_max_length, mdl_dim]).astype(
                np.float32)
        key_vec = np.random.normal(
            size=[target_batch_size, source_max_length, mdl_dim]).astype(
                np.float32)
        value_vec = np.random.normal(
            size=[target_batch_size, source_max_length, mdl_dim]).astype(
                np.float32)
        segment_ids = np.random.random_integers(
            0, 1, size=[target_batch_size, target_max_length]).astype(np.int32)
        atten_mask = attentions.causal_segment_mask(segment_ids, np.float32)

        jax_fprop_out, jax_atten_prob = test_utils.apply(
            layer, initial_vars, layer.fprop, query_vec, key_vec, value_vec,
            atten_mask)

        tf_layer_p = batch_major_attention.MultiHeadedAttention.Params().Set(
            name='mh',
            input_dim=mdl_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            atten_logit_cap=20.0,
            packed_input=True)
        tf_layer = tf_layer_p.Instantiate()
        tf_out, tf_atten_prob = tf_layer.FProp(
            initial_vars,
            query_vec,
            key_vec,
            value_vec,
            paddings=tf.zeros([target_batch_size, source_max_length]),
            segment_mask=atten_mask)

        logging.info('jax_layer_out: %s', jax_fprop_out)
        logging.info('jax_atten_probs: %s', jax_atten_prob)
        logging.info('tf_layer_out: %s', tf_out)
        logging.info('tf_atten_probs: %s', tf_atten_prob)
        self.assertAllClose(test_utils.to_np(jax_fprop_out),
                            test_utils.to_np(tf_out))
        self.assertAllClose(test_utils.to_np(jax_atten_prob),
                            test_utils.to_np(tf_atten_prob))
Example #19
 def test_transformer_layer_cross_attention_ln(self, packed_input):
     p = transformers.Transformer.Params().Set(name='jax_transformer_layer',
                                               input_dims=8,
                                               hidden_dims=32,
                                               num_heads=4,
                                               mask_self_attention=True,
                                               packed_input=packed_input,
                                               cross_attention=True)
     seq_len = 5
     batch_size = 4
     transformer_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=123)
     initial_vars = transformer_layer.instantiate_variables(prng_key)
     # Change the self attention initial vars.
     initial_vars.layer_norm.scale = 0.5
     initial_vars.layer_norm.bias = 5.0
     # Change the cross attention initial vars.
     initial_vars.cross_layer_norm.scale = 15
     initial_vars.cross_layer_norm.bias = 1.5
     npy_inputs = np.random.normal(
         1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
     inputs = jnp.asarray(npy_inputs)
     npy_paddings = np.random.randint(
         0, 1, [batch_size, seq_len]).astype('float32')
     paddings = jnp.asarray(npy_paddings)
     attention_mask = attentions.convert_paddings_to_mask(paddings)
     causal_mask = attentions.causal_mask(inputs)
     attention_mask = jnp.minimum(causal_mask, attention_mask)
     if packed_input:
         segment_ids = np.random.random_integers(0, 2,
                                                 [batch_size, seq_len])
         segment_mask = attentions.segment_mask(segment_ids,
                                                dtype=np.float32)
         attention_mask = jnp.minimum(attention_mask, segment_mask)
     with base_layer.JaxContext.new_context(
             prng_key=prng_key,
             global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
         jax_context.bind(transformer_layer,
                          transformer_layer.vars_to_flax_vars(initial_vars))
         inputs_normalized = transformer_layer.layer_norm.fprop(inputs)
         # Compute self-attention, key/value vectors are the input itself
         atten_output, _ = transformer_layer.self_attention.fprop(
             inputs_normalized,
             inputs_normalized,
             inputs_normalized,
             atten_mask=attention_mask)
         # Residual dropout and connection.
         atten_output = transformer_layer.residual_dropout.fprop(
             atten_output)
         atten_output += inputs
         # Normalize atten outputs using cross attention.
         atten_output_normalized = transformer_layer.cross_layer_norm.fprop(
             atten_output)
         inputs_normalized = test_utils.to_np(inputs_normalized)
         atten_output_normalized = test_utils.to_np(atten_output_normalized)
     self.assertAllClose(initial_vars.layer_norm.bias,
                         inputs_normalized.mean(),
                         atol=1e-3)
     self.assertAllClose((1.0 + initial_vars.layer_norm.scale)**2,
                         np.var(inputs_normalized),
                         atol=5e-3)
     self.assertAllClose(initial_vars.cross_layer_norm.bias,
                         atten_output_normalized.mean(),
                         atol=1e-3)
     self.assertAllClose((1.0 + initial_vars.cross_layer_norm.scale)**2,
                         np.var(atten_output_normalized),
                         atol=5e-3)
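The four assertions above follow from the layer-norm parameterization this test assumes: inputs are normalized to zero mean and unit variance over the feature dimension and then transformed as (1 + scale) * normalized + bias, so the output mean approaches bias and its variance approaches (1 + scale) ** 2. A standalone sketch of that arithmetic:

x = np.random.normal(1.0, 0.5, [4, 5, 8]).astype('float32')
x_norm = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-6)
y = (1.0 + 0.5) * x_norm + 5.0              # scale=0.5, bias=5.0 as in the test
print(np.abs(y.mean() - 5.0) < 1e-3, np.abs(np.var(y) - (1.0 + 0.5)**2) < 5e-3)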
Example #20
 def test_mask(self):
     a = np.random.random_integers(0, 5, size=[2, 50])
     jax_mask = attentions.causal_segment_mask(a, jnp.float32)
     tf_mask = batch_major_attention.CausalSegmentMask(a, tf.float32)
     self.assertAllClose(test_utils.to_np(jax_mask),
                         test_utils.to_np(tf_mask))
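A NumPy sketch of the semantics both implementations being compared should share: query position t may attend to key position s only if s <= t and both positions carry the same segment id; allowed entries get 0 and everything else a large negative logit bias (the constant and the [batch, 1, time, time] output shape are assumptions here):

def np_causal_segment_mask(segment_ids, big_neg=-1e9):
    b, t = segment_ids.shape
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]  # [b, t, t]
    causal = np.tril(np.ones([t, t], dtype=bool))                      # keys s <= query t
    allowed = np.logical_and(same_segment, causal)
    return np.where(allowed, 0.0, big_neg)[:, None, :, :].astype(np.float32)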
Example #21
    def test_transformer_layer_extendstep(self, packed_input, cross_attention,
                                          dconv_qkv, use_rotary_position_emb):
        p = transformers.Transformer.Params().Set(
            name='jax_transformer_layer',
            input_dims=8,
            hidden_dims=32,
            num_heads=4,
            mask_self_attention=True,
            packed_input=packed_input,
            cross_attention=cross_attention)
        p.tr_atten_tpl.dconv_qkv = dconv_qkv
        p.tr_atten_tpl.use_rotary_position_emb = use_rotary_position_emb
        if cross_attention:
            p.cross_atten_tpl = p.tr_atten_tpl.Copy()
            # Cross attention should not have depth-wise convolution.
            p.cross_atten_tpl.dconv_qkv = False
            # Cross attention should not have rotary position embedding.
            p.cross_atten_tpl.use_rotary_position_emb = False

        p.tr_atten_tpl.dconv_kernel_size = 2
        seq_len = 4
        batch_size = 4
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        initial_states = transformer_layer.init_states(batch_size, seq_len)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        # npy_paddings = np.zeros([batch_size, seq_len])
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        segment_mask = None
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(causal_mask, attention_mask)
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            attention_mask = jnp.minimum(attention_mask, segment_mask)
        cross_inputs = None
        cross_paddings = None
        cross_attention_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 32)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.input_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            cross_attention_mask = attentions.convert_paddings_to_mask(
                cross_paddings)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                cross_attention_mask = jnp.minimum(cross_attention_mask,
                                                   cross_segment_mask)

        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            fprop_outputs, _ = transformer_layer.fprop(
                inputs,
                paddings,
                attention_mask=attention_mask,
                cross_inputs=cross_inputs,
                cross_attention_mask=cross_attention_mask)
            decoder_outputs = jnp.zeros(
                shape=[seq_len, batch_size, p.input_dims])
            atten_states = initial_states
            for t in range(seq_len):
                attention_mask_t = attention_mask[:, :, t, :]
                cross_attention_mask_t = cross_attention_mask
                if cross_attention:
                    cross_attention_mask_t = cross_attention_mask[:, :, t, :]
                    cross_attention_mask_t = np.expand_dims(
                        cross_attention_mask_t, axis=2)
                atten_states, encoded = transformer_layer.extend_step(
                    atten_states,
                    inputs=inputs[:, t, :],
                    time_step=t,
                    attention_mask=attention_mask_t,
                    cross_inputs=cross_inputs,
                    cross_attention_mask=cross_attention_mask_t)
                decoder_outputs = decoder_outputs.at[t].set(encoded)

        decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 0, 2])
        logging.info('initial_vars in transformer layer = %s', initial_vars)
        np_fprop_outputs = test_utils.to_np(fprop_outputs)
        np_decoder_outputs = test_utils.to_np(decoder_out_transposed)
        self.assertAllClose(np_fprop_outputs, np_decoder_outputs, atol=1e-5)