Example #1
0
    def test_stacked_transformer_layer(self, mask_self_attention, packed_input,
                                       cross_attention):
        """Checks the JAX StackedTransformer against the TF reference stack.

        Builds a 4-layer JAX stacked transformer, verifies the pax<->flax
        variable conversion round-trips structurally, runs fprop on random
        inputs, then runs the equivalent TF batch-major stack with converted
        variables and asserts the outputs are numerically close.

        Args:
          mask_self_attention: Whether self-attention uses a causal mask.
          packed_input: Whether packed inputs (segment ids/masks) are used.
          cross_attention: Whether a cross-attention branch is exercised.
        """
        p = transformers.StackedTransformer.Params().Set(
            name='jax_stacked_transformer_layer',
            model_dims=16,
            hidden_dims=64,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            num_layers=4,
            packed_input=packed_input,
            cross_attention=cross_attention)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        stacked_transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = stacked_transformer_layer.instantiate_variables(
            prng_key)

        # Test conversion between vars and flax vars: converting one
        # representation to the other must reproduce the same tree structure.
        pax_vars = stacked_transformer_layer.vars
        flax_vars = stacked_transformer_layer.flax_vars
        tf.nest.assert_same_structure(
            pax_vars, stacked_transformer_layer.flax_vars_to_vars(flax_vars))
        tf.nest.assert_same_structure(
            flax_vars, stacked_transformer_layer.vars_to_flax_vars(pax_vars))

        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.model_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        # NOTE(review): randint's high bound is exclusive, so these paddings
        # are all zeros (no padded positions). Both implementations see the
        # same paddings, so the comparison is still valid — confirm whether
        # nontrivial paddings were intended.
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        segment_mask = None
        tf_segment_mask = None
        if packed_input:
            # Segment ids uniform in {0, 1, 2}. randint's high is exclusive;
            # this replaces the removed np.random.random_integers(0, 2, ...),
            # which was inclusive on both ends.
            segment_ids = np.random.randint(0, 3, [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            if mask_self_attention:
                tf_segment_mask = batch_major_attention.CausalSegmentMask(
                    segment_ids, tf.float32)
            else:
                tf_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids)

        cross_inputs = None
        cross_paddings = None
        cross_segment_mask = None
        tf_cross_inputs = None
        tf_cross_paddings = None
        tf_cross_segment_mask = None
        if cross_attention:
            # Cross-attention source uses its own (random) sequence length.
            cross_seq_len = np.random.randint(10, 64)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.model_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            tf_cross_paddings = tf.constant(npy_cross_paddings,
                                            dtype=tf.float32)
            if packed_input:
                # Source segment ids in {0, 1, 2} (see note above on randint
                # replacing the removed np.random.random_integers).
                source_segment_ids = np.random.randint(
                    0, 3, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                tf_cross_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, source_segment_ids)

        outputs = test_utils.apply(stacked_transformer_layer,
                                   initial_vars,
                                   stacked_transformer_layer.fprop,
                                   inputs,
                                   paddings,
                                   context_p=None,
                                   segment_mask=segment_mask,
                                   cross_inputs=cross_inputs,
                                   cross_paddings=cross_paddings,
                                   cross_segment_mask=cross_segment_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        # Test whether tf Transformer layer returns same output.
        # Modify initial_vars to use TF compatible params.
        tf_initial_vars = py_utils.NestedMap()
        tf_initial_vars.x_layers = []
        for jax_initial_vars in initial_vars.x_layers:
            tf_layer_vars = test_utils.replace_jax_attention_vars_to_tf(
                jax_initial_vars, cross_attention)
            tf_initial_vars.x_layers.append(tf_layer_vars)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        # Fixed: previously logged `initial_vars` under this label.
        logging.info('tf_initial_vars in transformer layer = %s',
                     tf_initial_vars)
        tf_p = batch_major_attention.StackedTransformerLayers.Params().Set(
            name='tf_transformer_layer',
            mdl_dim=p.model_dims,
            hidden_dim=p.hidden_dims,
            num_atten_heads=p.num_heads,
            mask_self_atten=mask_self_attention,
            num_layers=p.num_layers,
            packed_input=packed_input,
            has_aux_atten=cross_attention)
        # Align the TF feed-forward layers with the JAX implementation:
        # no batch norm, biases enabled.
        tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.batch_norm = (
            False)
        tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.has_bias = True
        tf_stacked_transformer_layer = tf_p.Instantiate()
        tf_output, _ = tf_stacked_transformer_layer.FProp(
            tf_initial_vars,
            test_utils.to_tf_nmap(npy_inputs),
            paddings=test_utils.to_tf_nmap(npy_paddings),
            segment_mask=test_utils.to_tf_nmap(tf_segment_mask),
            aux_vec=test_utils.to_tf_nmap(tf_cross_inputs),
            aux_paddings=test_utils.to_tf_nmap(tf_cross_paddings),
            aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
Example #2
0
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          input_batch: A `.NestedMap` object containing: ids - The inputs
            tensor of shape [batch, time]. paddings - The ids' paddings of
            shape [batch, time]. If p.packed_input, also segment_ids and
            segment_pos of shape [batch, time].

        Returns:
          A `.NestedMap` object containing:
            encoded - The encoded features of shape [time, batch, dim] or
              [batch, time, dim], depending on p.output_data_format.
            padding - The encoded features' padding of shape [time, batch] or
              [batch, time].
            seq_lengths - The per-example unpadded lengths of shape [batch]
              (int32); used by beam_search_helper.
            segment_id - The segmentation of packed inputs of shape [time,
              batch] or [batch, time] if it is supported by the model, or None
              otherwise.
            embedded_inputs - The embedded inputs tokens without positional
              encodings of shape [time, batch, dim] or [batch, time, dim].
        """

        p = self.params
        with tf.name_scope(p.name):
            # [batch, time]
            input_ids = input_batch.ids
            # [batch, time]
            paddings = input_batch.paddings

            # [batch, time]
            segment_ids = input_batch.segment_ids if p.packed_input else None

            batch = py_utils.GetShape(input_ids)[0]
            time = py_utils.GetShape(input_ids)[1]

            # Embedding layer.
            # [batch, time, dim]
            # With shared_emb, the softmax layer's embedding table doubles as
            # the input token embedding.
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                      input_ids)
            else:
                input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
            # Keep the pre-positional-encoding embeddings for the output.
            orig_input_embs = input_embs

            # [1, time, dim]
            # Packed inputs carry explicit per-token positions (segment_pos);
            # otherwise positions are implicit 0..time-1.
            if p.packed_input:
                positions = input_batch.segment_pos
                position_embs = tf.expand_dims(
                    self.position_emb.FPropWithPosition(
                        theta.position_emb, positions), 0)
            else:
                position_embs = tf.expand_dims(
                    self.position_emb.FProp(theta.position_emb, time), 0)

            # [batch, time, dim]
            input_embs += position_embs

            # Cast to the dropout layer's fprop dtype (e.g. bfloat16) if one
            # is configured.
            if p.input_dropout_tpl.fprop_dtype:
                input_embs = tf.cast(input_embs,
                                     p.input_dropout_tpl.fprop_dtype)
                paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)
            # [batch, time, dim]
            transformer_input = input_embs
            # Explicitly set the input shape of Transformer layers, to avoid
            # unknown shape error occurred to tf.einsum on nonTPU devices.
            transformer_input = tf.reshape(transformer_input,
                                           [batch, time, p.model_dim])

            # Compute self-attention segment mask once.
            # An all-zero mask is a no-op for unpacked inputs.
            if p.packed_input:
                segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids, dtype=transformer_input.dtype)
            else:
                segment_mask = tf.zeros([batch, 1, time, time])

            encoded, padding = self.transformer_stack.FProp(
                theta.transformer_stack, transformer_input, paddings,
                segment_mask)

            if p.final_layer_norm:
                encoded = self.final_ln.FProp(theta.final_ln, encoded)

            # Unpadded length per example: count of non-padding positions.
            seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1),
                                  tf.int32)

            # Optionally convert outputs from batch-major ([batch, time, ...])
            # to time-major ([time, batch, ...]).
            if p.output_data_format == 'TBC':
                encoded = tf.transpose(encoded,
                                       [1, 0, 2])  # [time, batch, dim]
                padding = tf.transpose(padding)  # [time, batch]
                segment_ids = tf.transpose(
                    segment_ids) if p.packed_input else None
                orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

            return py_utils.NestedMap(
                encoded=encoded,
                padding=padding,
                seq_lengths=seq_lengths,  # used by beam_search_helper.
                segment_id=segment_ids,
                embedded_inputs=orig_input_embs)
Example #3
0
    def test_transformer_layer(self, mask_self_attention, packed_input,
                               cross_attention):
        """Checks a single JAX Transformer layer against the TF reference.

        Builds one JAX transformer layer, runs fprop on random inputs with
        the appropriate causal/segment/padding masks, then runs the
        equivalent TF batch-major layer with converted variables and asserts
        the outputs are numerically close.

        Args:
          mask_self_attention: Whether self-attention uses a causal mask.
          packed_input: Whether packed inputs (segment ids/masks) are used.
          cross_attention: Whether a cross-attention branch is exercised.
        """
        p = transformers.Transformer.Params().Set(
            name='jax_transformer_layer',
            input_dims=32,
            hidden_dims=128,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            packed_input=packed_input,
            cross_attention=cross_attention)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = transformer_layer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        # NOTE(review): randint's high bound is exclusive, so these paddings
        # are all zeros (no padded positions). Both implementations see the
        # same paddings, so the comparison is still valid — confirm whether
        # nontrivial paddings were intended.
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        causal_mask = None
        segment_mask = None
        tf_segment_mask = None
        # The JAX side combines padding/causal/segment masks into a single
        # attention_mask via elementwise minimum.
        attention_mask = attentions.convert_paddings_to_mask(paddings)
        if mask_self_attention:
            causal_mask = attentions.causal_mask(inputs)
            attention_mask = jnp.minimum(attention_mask, causal_mask)
        if packed_input:
            # Segment ids uniform in {0, 1, 2}. randint's high is exclusive;
            # this replaces the removed np.random.random_integers(0, 2, ...),
            # which was inclusive on both ends.
            segment_ids = np.random.randint(0, 3, [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            attention_mask = jnp.minimum(attention_mask, segment_mask)
            if mask_self_attention:
                tf_segment_mask = batch_major_attention.CausalSegmentMask(
                    segment_ids, tf.float32)
            else:
                tf_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids)

        cross_inputs = None
        cross_attention_mask = None
        tf_cross_inputs = None
        tf_cross_paddings = None
        tf_cross_segment_mask = None
        if cross_attention:
            # Cross-attention source uses its own (random) sequence length.
            cross_seq_len = np.random.randint(10, 128)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.input_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            cross_attention_mask = attentions.convert_paddings_to_mask(
                cross_paddings)
            tf_cross_paddings = tf.constant(npy_cross_paddings,
                                            dtype=tf.float32)
            if packed_input:
                # Source segment ids in {0, 1, 2} (see note above on randint
                # replacing the removed np.random.random_integers).
                source_segment_ids = np.random.randint(
                    0, 3, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                cross_attention_mask = jnp.minimum(cross_attention_mask,
                                                   cross_segment_mask)
                tf_cross_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, source_segment_ids)

        outputs, _ = test_utils.apply(
            transformer_layer,
            initial_vars,
            transformer_layer.fprop,
            inputs,
            paddings,
            context_p=None,
            attention_mask=attention_mask,
            cross_inputs=cross_inputs,
            cross_attention_mask=cross_attention_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        # Test whether tf Transformer layer returns same output.
        # Modify initial_vars to use TF compatible params.
        tf_initial_vars = test_utils.replace_jax_attention_vars_to_tf(
            initial_vars, cross_attention)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        # Fixed: previously logged `initial_vars` under this label.
        logging.info('tf_initial_vars in transformer layer = %s',
                     tf_initial_vars)
        tf_p = batch_major_attention.TransformerLayer.Params().Set(
            name='tf_transformer_layer',
            input_dim=p.input_dims,
            num_heads=p.num_heads,
            mask_self_atten=mask_self_attention,
            packed_input=packed_input,
            has_aux_atten=cross_attention)
        # Align the TF feed-forward layer with the JAX implementation:
        # no batch norm, biases enabled.
        tf_p.tr_fflayer_tpl.hidden_dim = p.hidden_dims
        tf_p.tr_fflayer_tpl.fflayer_tpl.batch_norm = False
        tf_p.tr_fflayer_tpl.fflayer_tpl.has_bias = True
        tf_transformer_layer = tf_p.Instantiate()
        tf_output, _ = tf_transformer_layer.FProp(
            tf_initial_vars,
            tf.constant(npy_inputs, dtype=tf.float32),
            paddings=test_utils.to_tf_nmap(npy_paddings),
            segment_mask=tf_segment_mask,
            aux_vec=tf_cross_inputs,
            aux_paddings=tf_cross_paddings,
            aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)