Beispiel #1
0
        def _call(input_embeddings, input_shape, padding_mask, past,
                  attention_mask, training):
            """Run the block stack on already-computed input embeddings.

            Args:
                input_embeddings: embeddings fed (after dropout) to the first
                    block of ``self.h``.
                input_shape: list with the shape of the original inputs; the
                    hidden size is appended to it to reshape the outputs.
                padding_mask: mask used as the attention mask when
                    ``attention_mask`` is None, otherwise multiplied into it.
                past: one entry per block of ``self.h``; iterated in lockstep
                    with the blocks and handed to each of them.
                attention_mask: optional mask. When given, two new axes are
                    inserted in the middle, it is cast to float32 and mapped
                    through ``(1 - mask) * -10000``.
                training: forwarded to the dropout layer and to every block.

            Returns:
                Tuple ``(last_hidden_state, presents)``, followed by the tuple
                of all hidden states if ``self.output_hidden_states`` is set
                and by the tuple of attentions if ``self.output_attentions``
                is set.
            """
            # NOTE(review): the -10000-scaled mask is multiplied by
            # padding_mask; how the blocks consume the result is not visible
            # from this function - confirm the intended semantics.
            if attention_mask is None:
                attention_mask = padding_mask
            else:
                extended_mask = tf.cast(
                    attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)
                attention_mask = (1.0 - extended_mask) * -10000.0
                attention_mask = attention_mask * padding_mask

            hidden_states = self.drop(input_embeddings, training=training)
            hidden_size = shape_list(hidden_states)[-1]
            output_shape = input_shape + [hidden_size]

            layer_presents = []
            layer_attentions = []
            layer_hidden_states = []
            for block, layer_past in zip(self.h, past):
                if self.output_hidden_states:
                    # Record the input of each block, reshaped like the output.
                    layer_hidden_states.append(
                        tf.reshape(hidden_states, output_shape))

                block_outputs = block(
                    [hidden_states, layer_past, attention_mask],
                    training=training)
                hidden_states, present = block_outputs[:2]
                layer_presents.append(present)

                if self.output_attentions:
                    layer_attentions.append(block_outputs[2])

            hidden_states = tf.reshape(self.ln_f(hidden_states), output_shape)

            outputs = (hidden_states, tuple(layer_presents))

            if self.output_hidden_states:
                # The final (normalized) hidden state closes the list.
                layer_hidden_states.append(hidden_states)
                outputs = outputs + (tuple(layer_hidden_states), )

            if self.output_attentions:
                # Leave the number of heads free (-1) so attentions can be
                # extracted even after head pruning.
                attention_output_shape = (
                    input_shape[:-1] + [-1] +
                    shape_list(layer_attentions[0])[-2:])
                reshaped_attentions = tuple(
                    tf.reshape(attention, attention_output_shape)
                    for attention in layer_attentions)
                outputs = outputs + (reshaped_attentions, )

            if self.output_embeddings:
                # NOTE(review): the flag is read but nothing is appended; the
                # statement here was a self-assignment of `outputs`. Callers
                # already hold `input_embeddings` - confirm whether it should
                # be part of the returned tuple.
                pass

            # last hidden state, presents, (all hidden_states), (attentions)
            return outputs
Beispiel #2
0
    def _attn(self, inputs, training=False):
        """Compute attention for ``(q, k, v, attention_mask)``.

        Args:
            inputs: sequence ``(q, k, v, attention_mask)``. ``q``, ``k`` and
                ``v`` have shape [batch, heads, sequence, features];
                ``attention_mask`` may be None, otherwise it is added to the
                scores before the softmax.
            training: forwarded to the attention dropout layer.

        Returns:
            List whose first element is ``matmul(weights, v)``. When
            ``self.output_attentions`` is set, the attention weights (after
            softmax and dropout) are appended as a second element.
        """
        q, k, v, attention_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            # Scale the attention scores. The factor is cast to the dtype of
            # the scores rather than a fixed float32 so the division also
            # works when q/k are float16 or bfloat16.
            dk = tf.cast(tf.shape(k)[-1], w.dtype)
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where
        # information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        if self.casual_masking:
            b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
            b = tf.reshape(b, [1, 1, nd, ns])
            # Keep the scores where b is 1 and replace them by -1e4 where b
            # is 0.
            w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs
Beispiel #3
0
    def call(self, x, **kwargs):
        """Apply ``x @ self.weight + self.bias`` along the last axis of ``x``.

        ``x`` is flattened to rows of width ``self.nx``, projected, and then
        reshaped to ``[d0, d1, self.nf]`` where ``d0`` and ``d1`` are the
        first two dimensions of the input (presumably batch size and
        sequence length). Extra keyword arguments are accepted and ignored.
        """
        batch_size, seq_len = shape_list(x)[:2]

        flat = tf.reshape(x, [-1, self.nx])
        projected = tf.matmul(flat, self.weight) + self.bias

        return tf.reshape(projected, [batch_size, seq_len, self.nf])
Beispiel #4
0
    def get_input_embeddings(self,
                             inputs,
                             past=None,
                             attention_mask=None,
                             token_type_ids=None,
                             position_ids=None,
                             training=False):
        """Build the summed input embeddings and a padding mask for ``inputs``.

        Args:
            inputs: token id tensor; positions whose id equals 0 get a 0 in
                the padding mask. The mask indexing assumes a rank-2 tensor
                (presumably [batch, sequence]) - TODO confirm.
            past: optional cached states. When None, it is replaced by a list
                of ``None`` with one entry per block in ``self.h`` and the
                past length is 0; otherwise the past length is read from
                ``past[0][0]``.
            attention_mask: accepted but not used by this method.
            token_type_ids: optional ids embedded with ``self.wte`` and added
                to the other embeddings.
            position_ids: optional position ids; by default a range starting
                at the past length and covering the last axis of ``inputs``.
            training: accepted but not used by this method.

        Returns:
            Tuple ``(input_embeddings, input_shape, padding_mask, past)``.
        """
        if past is None:
            past = [None] * len(self.h)
            past_length = 0
        else:
            past_length = shape_list(past[0][0])[-2]

        input_shape = shape_list(inputs)
        seq_len = input_shape[-1]

        if position_ids is None:
            position_ids = tf.range(past_length,
                                    seq_len + past_length,
                                    dtype=tf.int32)[tf.newaxis, :]

        # Flatten every id tensor to 2-D, keeping only the last axis.
        flat_input_ids = tf.reshape(inputs, [-1, seq_len])
        flat_position_ids = tf.reshape(
            position_ids, [-1, shape_list(position_ids)[-1]])

        inputs_embeds = self.wte(flat_input_ids, mode='embedding')
        position_embeds = self.wpe(flat_position_ids)

        token_type_embeds = 0
        if token_type_ids is not None:
            flat_token_type_ids = tf.reshape(
                token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            # Token type ids share the token embedding matrix.
            token_type_embeds = self.wte(flat_token_type_ids,
                                         mode='embedding')

        input_embeddings = inputs_embeds + position_embeds + token_type_embeds

        # 1.0 where the token id is non-zero, 0.0 elsewhere.
        is_not_pad = tf.not_equal(inputs, tf.zeros_like(inputs))
        padding_mask = tf.cast(is_not_pad[:, tf.newaxis, :, tf.newaxis],
                               dtype=tf.float32)

        return input_embeddings, input_shape, padding_mask, past
Beispiel #5
0
 def split_heads(self, x):
     """Split the last axis of ``x`` into ``self.n_head`` heads.

     The last axis of width ``f`` becomes two axes
     ``[self.n_head, f // self.n_head]``; the result is then transposed
     with permutation (0, 2, 1, 3), which moves the head axis to position 1.
     """
     dims = shape_list(x)
     head_features = dims[-1] // self.n_head
     per_head = tf.reshape(x, dims[:-1] + [self.n_head, head_features])
     # (batch, head, seq_length, head_features)
     return tf.transpose(per_head, (0, 2, 1, 3))
Beispiel #6
0
 def merge_heads(self, x):
     """Swap axes 1 and 2 of ``x`` and fuse its last two axes into one."""
     swapped = tf.transpose(x, [0, 2, 1, 3])
     dims = shape_list(swapped)
     merged_width = dims[-2] * dims[-1]
     return tf.reshape(swapped, dims[:-2] + [merged_width])
Beispiel #7
0
        def _call(inputs, past, attention_mask, token_type_ids, position_ids,
                  training):
            """Embed ``inputs`` and run them through every block of ``self.h``.

            Args:
                inputs: token id tensor; positions whose id equals 0 get a 0
                    in the padding mask. The mask indexing assumes a rank-2
                    tensor (presumably [batch, sequence]) - TODO confirm.
                past: optional cached states, one entry per block. When None,
                    a list of ``None`` is used and the past length is 0;
                    otherwise the past length is read from ``past[0][0]``.
                attention_mask: optional mask. When given, two new axes are
                    inserted in the middle, it is cast to float32, mapped
                    through ``(1 - mask) * -10000`` and multiplied by the
                    padding mask.
                token_type_ids: optional ids embedded with ``self.wte`` and
                    added to the other embeddings.
                position_ids: optional position ids; by default a range
                    starting at the past length and covering the last axis of
                    ``inputs``.
                training: forwarded to the dropout layer and to every block.

            Returns:
                Tuple ``(last_hidden_state, presents)``, followed - depending
                on the ``self.output_hidden_states``,
                ``self.output_attentions`` and ``self.output_embeddings``
                flags - by the tuple of all hidden states, the tuple of
                attentions and the token embeddings of ``inputs``.
            """
            if past is None:
                past = [None] * len(self.h)
                past_length = 0
            else:
                past_length = shape_list(past[0][0])[-2]

            input_shape = shape_list(inputs)
            seq_len = input_shape[-1]

            if position_ids is None:
                position_ids = tf.range(past_length,
                                        seq_len + past_length,
                                        dtype=tf.int32)[tf.newaxis, :]

            # 1.0 where the token id is non-zero, 0.0 elsewhere.
            is_not_pad = tf.not_equal(inputs, tf.zeros_like(inputs))
            padding_mask = tf.cast(is_not_pad[:, tf.newaxis, :, tf.newaxis],
                                   dtype=tf.float32)

            # NOTE(review): the -10000-scaled mask is multiplied by
            # padding_mask; how the blocks consume the result is not visible
            # from this function - confirm the intended semantics.
            if attention_mask is None:
                attention_mask = padding_mask
            else:
                extended_mask = tf.cast(
                    attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)
                attention_mask = (1.0 - extended_mask) * -10000.0
                attention_mask = attention_mask * padding_mask

            # Flatten every id tensor to 2-D, keeping only the last axis.
            flat_input_ids = tf.reshape(inputs, [-1, seq_len])
            flat_position_ids = tf.reshape(
                position_ids, [-1, shape_list(position_ids)[-1]])

            inputs_embeds = self.wte(flat_input_ids, mode='embedding')
            position_embeds = self.wpe(flat_position_ids)

            token_type_embeds = 0
            if token_type_ids is not None:
                flat_token_type_ids = tf.reshape(
                    token_type_ids, [-1, shape_list(token_type_ids)[-1]])
                # Token type ids share the token embedding matrix.
                token_type_embeds = self.wte(flat_token_type_ids,
                                             mode='embedding')

            hidden_states = inputs_embeds + position_embeds + token_type_embeds
            hidden_states = self.drop(hidden_states, training=training)

            hidden_size = shape_list(hidden_states)[-1]
            output_shape = input_shape + [hidden_size]

            layer_presents = []
            layer_attentions = []
            layer_hidden_states = []
            for block, layer_past in zip(self.h, past):
                if self.output_hidden_states:
                    # Record the input of each block, reshaped like the output.
                    layer_hidden_states.append(
                        tf.reshape(hidden_states, output_shape))

                block_outputs = block(
                    [hidden_states, layer_past, attention_mask],
                    training=training)
                hidden_states, present = block_outputs[:2]
                layer_presents.append(present)

                if self.output_attentions:
                    layer_attentions.append(block_outputs[2])

            hidden_states = tf.reshape(self.ln_f(hidden_states), output_shape)

            outputs = (hidden_states, tuple(layer_presents))

            if self.output_hidden_states:
                # The final (normalized) hidden state closes the list.
                layer_hidden_states.append(hidden_states)
                outputs = outputs + (tuple(layer_hidden_states), )

            if self.output_attentions:
                # Leave the number of heads free (-1) so attentions can be
                # extracted even after head pruning.
                attention_output_shape = (
                    input_shape[:-1] + [-1] +
                    shape_list(layer_attentions[0])[-2:])
                reshaped_attentions = tuple(
                    tf.reshape(attention, attention_output_shape)
                    for attention in layer_attentions)
                outputs = outputs + (reshaped_attentions, )

            if self.output_embeddings:
                # Only the token embeddings are returned, not the sum with
                # the position / token-type embeddings.
                outputs = outputs + (inputs_embeds, )

            # last hidden state, presents, (all hidden_states), (attentions),
            # (input embeddings)
            return outputs