Example #1
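Builds a DotProductAttention layer (optionally with combined QKV projection, depthwise-convolved Q/K/V, and rotary position embeddings) and checks that decoding one step at a time with extend_step reproduces the full-sequence fprop output.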
    def test_mha_01(self, combine_qkv, dconv_qkv, dconv_kernel_size,
                    use_rotary_position_emb):
        mdl_dim = 16
        hidden_dim = 32
        num_heads = 4
        test_layer_p = attentions.DotProductAttention.Params().Set(
            name='mh',
            input_dim=mdl_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dim_per_head=16 if use_rotary_position_emb else None,
            atten_logit_cap=20.0,
            combine_qkv=combine_qkv,
            dconv_qkv=dconv_qkv,
            dconv_kernel_size=dconv_kernel_size,
            use_rotary_position_emb=use_rotary_position_emb)
        layer = test_layer_p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)
        logging.info('initial_vars: %s', initial_vars)
        target_batch_size = 3
        source_max_length = 16
        target_max_length = 16
        initial_states = layer.init_states(target_batch_size,
                                           target_max_length)
        query_vec = np.random.normal(
            size=[target_batch_size, source_max_length, mdl_dim]).astype(
                np.float32)
        key_vec = query_vec
        value_vec = query_vec
        atten_mask = attentions.causal_mask(query_vec)

        prng_key, compute_key = jax.random.split(prng_key)
        global_step = jnp.array(0, dtype=jnp.uint32)

        with base_layer.JaxContext.new_context(
                prng_key=compute_key, global_step=global_step) as jax_context:
            jax_context.bind(layer, layer.vars_to_flax_vars(initial_vars))
            fprop_out, _ = layer.fprop(query_vec, key_vec, value_vec,
                                       atten_mask)

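            # Decode one step at a time with extend_step and collect the
            # per-step outputs (time-major) for comparison against fprop.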
            decoder_output = jnp.zeros(
                shape=[target_max_length, target_batch_size, mdl_dim])
            atten_states = initial_states
            for t in range(target_max_length):
                atten_states, encoded = layer.extend_step(
                    atten_states,
                    query_vec=query_vec[:, t, :],
                    atten_mask=atten_mask[:, :, t, :],
                    time_step=t)
                decoder_output = decoder_output.at[t].set(encoded)

        decoder_out_transposed = jnp.transpose(decoder_output, [1, 0, 2])

        logging.info('fprop_out: %s', fprop_out)
        logging.info('decoder_out: %s', decoder_output)
        self.assertAllClose(fprop_out, decoder_out_transposed)
Example #2
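Runs a Transformer layer on packed inputs, optionally attaching a RelativeBias relative position bias to its self-attention, and spot-checks two output values when the bias is enabled.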
    def test_transformer_relative_bias(self, use_relative_bias):
        p = transformers.Transformer.Params().Set(name='jax_transformer_layer',
                                                  input_dims=32,
                                                  hidden_dims=128,
                                                  num_heads=8,
                                                  mask_self_attention=True,
                                                  packed_input=True,
                                                  cross_attention=False)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        if use_relative_bias:
            p.tr_atten_tpl.relative_bias_tpl = attentions.RelativeBias.Params(
            ).Set(relative_attention_num_buckets=2,
                  relative_attention_max_distance=8)
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
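        # Combine the padding, causal, and segment (packing) masks by taking
        # their elementwise minimum.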
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(attention_mask, causal_mask)
        segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
        segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
        attention_mask = jnp.minimum(attention_mask, segment_mask)

        if use_relative_bias:
            segment_pos = np.random.randint(
                0, seq_len, [batch_size, seq_len]).astype('int32')
            segment_pos = jnp.asarray(segment_pos)
        else:
            segment_pos = None

        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            outputs, _ = transformer_layer.fprop(inputs,
                                                 paddings,
                                                 attention_mask=attention_mask,
                                                 segment_pos=segment_pos)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        np_outputs = test_utils.to_np(outputs)
        logging.info('np_outputs: %s', np_outputs)
        if use_relative_bias:
            self.assertAlmostEqual(np_outputs[0, 0, 1], 0.79015386, places=5)
            self.assertAlmostEqual(np_outputs[0, 1, 0], 0.48336178, places=5)
        # Plumbing test.
        self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
Example #3
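Calls extend_step at a single time step on a DotProductAttention layer configured with a relative position bias and verifies the output shape.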
    def test_relative_bias_extend_step(self, num_buckets, max_distance,
                                       attention_extra_logit):
        mdl_dim = 16
        hidden_dim = 32
        num_heads = 4
        test_layer_p = attentions.DotProductAttention.Params().Set(
            name='relative_attn',
            input_dim=mdl_dim,
            hidden_dim=hidden_dim,
            attention_extra_logit=attention_extra_logit,
            num_heads=num_heads)
        test_layer_p.relative_bias_tpl = attentions.RelativeBias.Params().Set(
            relative_attention_num_buckets=num_buckets,
            relative_attention_max_distance=max_distance)
        layer = test_layer_p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)
        target_batch_size = 2
        source_max_length = 8
        target_max_length = 8
        inputs = np.random.normal(
            size=[target_batch_size, source_max_length, mdl_dim]).astype(
                np.float32)
        atten_mask = attentions.causal_mask(inputs)
        initial_states = layer.init_states(target_batch_size,
                                           target_max_length)

        time_step = 2

        _, atten_output = test_utils.apply(layer,
                                           initial_vars,
                                           layer.extend_step,
                                           initial_states,
                                           inputs[:, time_step, :],
                                           atten_mask=atten_mask[:, :,
                                                                 time_step, :],
                                           time_step=time_step)

        self.assertEqual(atten_output.shape, (target_batch_size, mdl_dim))
Example #4
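Instantiates a Transformer layer under different norm_policy settings and runs a forward pass as a plumbing check.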
    def test_transformer_layer_norm_policies(self, norm_policy):
        p = transformers.Transformer.Params().Set(name='jax_transformer_layer',
                                                  input_dims=32,
                                                  hidden_dims=128,
                                                  num_heads=8,
                                                  mask_self_attention=True,
                                                  packed_input=True,
                                                  cross_attention=False,
                                                  norm_policy=norm_policy)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(attention_mask, causal_mask)
        segment_ids = np.random.random_integers(0, 2, [batch_size, seq_len])
        segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
        attention_mask = jnp.minimum(attention_mask, segment_mask)

        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            outputs, _ = transformer_layer.fprop(inputs,
                                                 paddings,
                                                 attention_mask=attention_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        np_outputs = test_utils.to_np(outputs)
        # Plumbing test.
        self.assertAllClose(np_outputs, np_outputs, atol=1e-5)
Example #5
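Runs fprop on a DotProductAttention layer with a relative position bias and explicit query segment positions, then checks the output shape.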
    def test_relative_bias(self, num_buckets, max_distance):
        mdl_dim = 16
        hidden_dim = 32
        num_heads = 4
        test_layer_p = attentions.DotProductAttention.Params().Set(
            name='relative_attn',
            input_dim=mdl_dim,
            hidden_dim=hidden_dim,
            num_heads=num_heads)
        test_layer_p.relative_bias_tpl = attentions.RelativeBias.Params().Set(
            relative_attention_num_buckets=num_buckets,
            relative_attention_max_distance=max_distance)
        layer = test_layer_p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        prng_key, init_key = jax.random.split(prng_key)
        initial_vars = layer.instantiate_variables(init_key)
        target_batch_size = 3
        source_max_length = 16
        query_vec = np.random.normal(
            size=[target_batch_size, source_max_length, mdl_dim]).astype(
                np.float32)
        key_vec = query_vec
        value_vec = query_vec
        segment_pos = np.random.randint(
            0, source_max_length,
            [target_batch_size, source_max_length]).astype('int32')
        atten_mask = attentions.causal_mask(query_vec)

        atten_output, _ = test_utils.apply(layer,
                                           initial_vars,
                                           layer.fprop,
                                           query_vec,
                                           key_vec,
                                           value_vec,
                                           atten_mask=atten_mask,
                                           query_segment_pos=segment_pos)

        self.assertEqual(atten_output.shape,
                         (target_batch_size, source_max_length, mdl_dim))
Example #6
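Compares the JAX Transformer layer against the TF batch_major_attention.TransformerLayer: the JAX variables are converted to TF-compatible ones and the outputs of the two forward passes are asserted to match.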
    def test_transformer_layer(self, mask_self_attention, packed_input,
                               cross_attention):
        p = transformers.Transformer.Params().Set(
            name='jax_transformer_layer',
            input_dims=32,
            hidden_dims=128,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            packed_input=packed_input,
            cross_attention=cross_attention)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        causal_mask = None
        segment_mask = None
        tf_segment_mask = None
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        if mask_self_attention:
            causal_mask = attentions.causal_mask(inputs)
            attention_mask = jnp.minimum(attention_mask, causal_mask)
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            attention_mask = jnp.minimum(attention_mask, segment_mask)
            if mask_self_attention:
                tf_segment_mask = batch_major_attention.CausalSegmentMask(
                    segment_ids, tf.float32)
            else:
                tf_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids)

        cross_inputs = None
        cross_attention_mask = None
        tf_cross_inputs = None
        tf_cross_paddings = None
        tf_cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 128)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.input_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            cross_attention_mask = attentions.convert_paddings_to_mask(
                cross_paddings)
            tf_cross_paddings = tf.constant(npy_cross_paddings,
                                            dtype=tf.float32)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                cross_attention_mask = jnp.minimum(cross_attention_mask,
                                                   cross_segment_mask)
                tf_cross_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, source_segment_ids)

        outputs, _ = test_utils.apply(
            transformer_layer,
            initial_vars,
            transformer_layer.fprop,
            inputs,
            paddings,
            context_p=None,
            attention_mask=attention_mask,
            cross_inputs=cross_inputs,
            cross_attention_mask=cross_attention_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        # Test whether the TF Transformer layer returns the same output.
        # Convert initial_vars to TF-compatible params.
        tf_initial_vars = test_utils.replace_jax_attention_vars_to_tf(
            initial_vars, cross_attention)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        logging.info('tf_initial_vars in transformer layer = %s',
                     tf_initial_vars)
        tf_p = batch_major_attention.TransformerLayer.Params().Set(
            name='tf_transformer_layer',
            input_dim=p.input_dims,
            num_heads=p.num_heads,
            mask_self_atten=mask_self_attention,
            packed_input=packed_input,
            has_aux_atten=cross_attention)
        tf_p.tr_fflayer_tpl.hidden_dim = p.hidden_dims
        tf_p.tr_fflayer_tpl.fflayer_tpl.batch_norm = False
        tf_p.tr_fflayer_tpl.fflayer_tpl.has_bias = True
        tf_transformer_layer = tf_p.Instantiate()
        tf_output, _ = tf_transformer_layer.FProp(
            tf_initial_vars,
            tf.constant(npy_inputs, dtype=tf.float32),
            paddings=test_utils.to_tf_nmap(npy_paddings),
            segment_mask=tf_segment_mask,
            aux_vec=tf_cross_inputs,
            aux_paddings=tf_cross_paddings,
            aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
Example #7
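Overrides the self-attention and cross-attention layer-norm variables, then checks that the normalized activations have the expected mean and variance.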
    def test_transformer_layer_cross_attention_ln(self, packed_input):
        p = transformers.Transformer.Params().Set(name='jax_transformer_layer',
                                                  input_dims=8,
                                                  hidden_dims=32,
                                                  num_heads=4,
                                                  mask_self_attention=True,
                                                  packed_input=packed_input,
                                                  cross_attention=True)
        seq_len = 5
        batch_size = 4
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        # Override the self-attention layer-norm initial vars.
        initial_vars.layer_norm.scale = 0.5
        initial_vars.layer_norm.bias = 5.0
        # Override the cross-attention layer-norm initial vars.
        initial_vars.cross_layer_norm.scale = 15
        initial_vars.cross_layer_norm.bias = 1.5
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(causal_mask, attention_mask)
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            attention_mask = jnp.minimum(attention_mask, segment_mask)
        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            inputs_normalized = transformer_layer.layer_norm.fprop(inputs)
            # Compute self-attention; the key/value vectors are the normalized
            # inputs.
            atten_output, _ = transformer_layer.self_attention.fprop(
                inputs_normalized,
                inputs_normalized,
                inputs_normalized,
                atten_mask=attention_mask)
            # Residual dropout and connection.
            atten_output = transformer_layer.residual_dropout.fprop(
                atten_output)
            atten_output += inputs
            # Normalize atten output with the cross-attention layer norm.
            atten_output_normalized = transformer_layer.cross_layer_norm.fprop(
                atten_output)
            inputs_normalized = test_utils.to_np(inputs_normalized)
            atten_output_normalized = test_utils.to_np(atten_output_normalized)
        self.assertAllClose(initial_vars.layer_norm.bias,
                            inputs_normalized.mean(),
                            atol=1e-3)
        self.assertAllClose((1.0 + initial_vars.layer_norm.scale)**2,
                            np.var(inputs_normalized),
                            atol=5e-3)
        self.assertAllClose(initial_vars.cross_layer_norm.bias,
                            atten_output_normalized.mean(),
                            atol=1e-3)
        self.assertAllClose((1.0 + initial_vars.cross_layer_norm.scale)**2,
                            np.var(atten_output_normalized),
                            atol=5e-3)
Example #8
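Verifies that stepwise decoding with Transformer.extend_step (with optional packed input, cross attention, depthwise-convolved QKV, and rotary position embeddings) matches the full-sequence fprop output.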
    def test_transformer_layer_extendstep(self, packed_input, cross_attention,
                                          dconv_qkv, use_rotary_position_emb):
        p = transformers.Transformer.Params().Set(
            name='jax_transformer_layer',
            input_dims=8,
            hidden_dims=32,
            num_heads=4,
            mask_self_attention=True,
            packed_input=packed_input,
            cross_attention=cross_attention)
        p.tr_atten_tpl.dconv_qkv = dconv_qkv
        p.tr_atten_tpl.use_rotary_position_emb = use_rotary_position_emb
        if cross_attention:
            p.cross_atten_tpl = p.tr_atten_tpl.Copy()
            # Cross attention should not have depth-wise convolution.
            p.cross_atten_tpl.dconv_qkv = False
            # Cross attention should not have rotary position embedding.
            p.cross_atten_tpl.use_rotary_position_emb = False

        p.tr_atten_tpl.dconv_kernel_size = 2
        seq_len = 4
        batch_size = 4
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        initial_states = transformer_layer.init_states(batch_size, seq_len)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        # npy_paddings = np.zeros([batch_size, seq_len])
        paddings = jnp.asarray(npy_paddings)
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        segment_mask = None
        causal_mask = attentions.causal_mask(inputs)
        attention_mask = jnp.minimum(causal_mask, attention_mask)
        if packed_input:
            segment_ids = np.random.random_integers(0, 2,
                                                    [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            attention_mask = jnp.minimum(attention_mask, segment_mask)
        cross_inputs = None
        cross_paddings = None
        cross_attention_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 32)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.input_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            cross_attention_mask = attentions.convert_paddings_to_mask(
                cross_paddings)
            if packed_input:
                source_segment_ids = np.random.random_integers(
                    0, 2, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                cross_attention_mask = jnp.minimum(cross_attention_mask,
                                                   cross_segment_mask)

        with base_layer.JaxContext.new_context(
                prng_key=prng_key,
                global_step=jnp.array(0, dtype=jnp.uint32)) as jax_context:
            jax_context.bind(transformer_layer,
                             transformer_layer.vars_to_flax_vars(initial_vars))
            fprop_outputs, _ = transformer_layer.fprop(
                inputs,
                paddings,
                attention_mask=attention_mask,
                cross_inputs=cross_inputs,
                cross_attention_mask=cross_attention_mask)
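            # Decode step by step; the stacked per-step outputs are compared
            # against the full-sequence fprop output below.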
            decoder_outputs = jnp.zeros(
                shape=[seq_len, batch_size, p.input_dims])
            atten_states = initial_states
            for t in range(seq_len):
                attention_mask_t = attention_mask[:, :, t, :]
                cross_attention_mask_t = cross_attention_mask
                if cross_attention:
                    cross_attention_mask_t = cross_attention_mask[:, :, t, :]
                    cross_attention_mask_t = np.expand_dims(
                        cross_attention_mask_t, axis=2)
                atten_states, encoded = transformer_layer.extend_step(
                    atten_states,
                    inputs=inputs[:, t, :],
                    time_step=t,
                    attention_mask=attention_mask_t,
                    cross_inputs=cross_inputs,
                    cross_attention_mask=cross_attention_mask_t)
                decoder_outputs = decoder_outputs.at[t].set(encoded)

        decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 0, 2])
        logging.info('initial_vars in transformer layer = %s', initial_vars)
        np_fprop_outputs = test_utils.to_np(fprop_outputs)
        np_decoder_outputs = test_utils.to_np(decoder_out_transposed)
        self.assertAllClose(np_fprop_outputs, np_decoder_outputs, atol=1e-5)