Example #1
    def grad(res_grad):

        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums, res_grad[index])[None,
                                                                     Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads
Example #2
                def _forward(dilated_ids, dilated_mask):

                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    output_ids = tf.argmax(logits, axis=-1)
                    output_ids = tf.cast(output_ids, dtype=tf.int32)

                    # special padding (using `spad` token)
                    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                    right_pad = spad_id * tf.sequence_mask(
                        equal_zero, dilated_seq_length, dtype=tf.int32)
                    padded = tf.concat([output_ids, right_pad], axis=-1)

                    # extract ids of length `max_seq_length`
                    flattened_padded = tf.reshape(padded, [-1])
                    is_valid = tf.greater(flattened_padded, 0)
                    flattened_valid = tf.boolean_mask(flattened_padded,
                                                      is_valid)
                    valid = tf.reshape(flattened_valid,
                                       [batch_size, dilated_seq_length])
                    cutted_valid = valid[:, :max_seq_length]

                    # replace `spad` token with `pad`
                    non_spad_mask = tf.cast(
                        tf.not_equal(cutted_valid, spad_id), dtype=tf.int32)
                    output_ids = cutted_valid * non_spad_mask
                    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

                    # dilate
                    reshaped_ids = tf.reshape(output_ids,
                                              [batch_size, max_seq_length, 1])
                    reshaped_mask = tf.reshape(
                        tf.sequence_mask(output_length,
                                         max_seq_length,
                                         dtype=tf.int32),
                        [batch_size, max_seq_length, 1])
                    concat_ids = tf.concat(
                        [reshaped_ids,
                         tf.zeros_like(reshaped_ids)], axis=-1)
                    concat_mask = tf.concat(
                        [reshaped_mask,
                         tf.zeros_like(reshaped_mask, dtype=tf.int32)],
                        axis=-1)
                    dilated_ids = tf.reshape(concat_ids,
                                             [batch_size, max_seq_length * 2])
                    dilated_mask = tf.reshape(concat_mask,
                                              [batch_size, max_seq_length * 2])

                    return dilated_ids, dilated_mask
Example #3
    def _forward(self, is_training, split_placeholders, **kwargs):

        if not is_training:
            return super()._forward(is_training, split_placeholders, **kwargs)

        # cast the unsupervised indicator to bool for tf.boolean_mask
        unsup_mask = tf.cast(
            1.0 - split_placeholders['is_supervised'], tf.bool)
        aug_input_ids = tf.boolean_mask(
            split_placeholders['aug_input_ids'], mask=unsup_mask, axis=0)
        aug_input_mask = tf.boolean_mask(
            split_placeholders['aug_input_mask'], mask=unsup_mask, axis=0)
        aug_segment_ids = tf.boolean_mask(
            split_placeholders['aug_segment_ids'], mask=unsup_mask, axis=0)
        input_ids = tf.concat([split_placeholders['input_ids'], aug_input_ids],
                              axis=0)
        input_mask = tf.concat(
            [split_placeholders['input_mask'], aug_input_mask], axis=0)
        segment_ids = tf.concat(
            [split_placeholders['segment_ids'], aug_segment_ids], axis=0)
        encoder = BERTEncoder(bert_config=self.bert_config,
                              is_training=is_training,
                              input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              scope='bert',
                              drop_pooler=self._drop_pooler,
                              **kwargs)
        encoder_output = encoder.get_pooled_output()

        label_ids = split_placeholders['label_ids']
        is_expanded = tf.zeros_like(label_ids, dtype=tf.float32)
        batch_size = util.get_shape_list(aug_input_ids)[0]
        aug_is_expanded = tf.ones([batch_size], dtype=tf.float32)
        is_expanded = tf.concat([is_expanded, aug_is_expanded], axis=0)
        decoder = UDADecoder(
            is_training=is_training,
            input_tensor=encoder_output,
            is_supervised=split_placeholders['is_supervised'],
            is_expanded=is_expanded,
            label_ids=label_ids,
            label_size=self.label_size,
            sample_weight=split_placeholders.get('sample_weight'),
            scope='cls/seq_relationship',
            global_step=self._global_step,
            num_train_steps=self.total_steps,
            uda_softmax_temp=self._uda_softmax_temp,
            uda_confidence_thresh=self._uda_confidence_thresh,
            tsa_schedule=self._tsa_schedule,
            **kwargs)
        (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
        return (total_loss, losses, probs, preds)
Example #4
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
    '''create causal attention mask.'''
    attn_mask = tf.ones([qlen, qlen], dtype=dtype)
    mask_u = tf.matrix_band_part(attn_mask, 0, -1)
    mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
    attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
    ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
    if same_length:
        mask_l = tf.matrix_band_part(attn_mask, -1, 0)
        ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)

    return ret
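
A brief usage sketch (hypothetical sizes, TF 1.x): the returned matrix has shape [qlen, qlen + mlen], where 1.0 marks positions a query must not attend to and the mlen memory columns stay visible.

# 4 query positions attending over 2 cached memory steps plus themselves
causal_mask = _create_mask(qlen=4, mlen=2)   # shape [4, 6]
# with tf.Session() as sess:
#     print(sess.run(causal_mask))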
Example #5
def multihead_attention(queries, keys, values, key_masks,
                        num_heads=8,
                        dropout_rate=0,
                        training=True,
                        causality=False,
                        scope='multihead_attention'):
    '''Applies multihead attention. See section 3.2.2 of 'Attention Is All You Need'.
    queries: A 3d tensor with shape of [N, T_q, d_model].
    keys: A 3d tensor with shape of [N, T_k, d_model].
    values: A 3d tensor with shape of [N, T_k, d_model].
    key_masks: A 2d tensor with shape of [N, key_seqlen].
    num_heads: An int. Number of heads.
    dropout_rate: A floating point number.
    training: Boolean. Whether dropout is applied (training mode).
    causality: Boolean. If true, units that reference the future are masked.
    scope: Optional scope for `variable_scope`.

    Returns
      A 3d tensor with shape of [N, T_q, d_model]
    '''
    d_model = queries.get_shape().as_list()[-1]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Linear projections
        Q = tf.layers.dense(
            queries, d_model, use_bias=True) # (N, T_q, d_model)
        K = tf.layers.dense(
            keys, d_model, use_bias=True) # (N, T_k, d_model)
        V = tf.layers.dense(
            values, d_model, use_bias=True) # (N, T_k, d_model)

        # Split and concat
        Q_ = tf.concat(
            tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h)
        K_ = tf.concat(
            tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)
        V_ = tf.concat(
            tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)

        # Attention
        outputs = scaled_dot_product_attention(
            Q_, K_, V_, key_masks, causality, dropout_rate, training)

        # Restore shape
        outputs = tf.concat(
            tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, d_model)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = ln(outputs)

    return outputs
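
A hedged usage sketch with illustrative tensors; it assumes the companion helpers `ln` and `scaled_dot_product_attention` from the same codebase are in scope.

# batch of 2, query length 5, key length 7, d_model 64 (divisible by 8 heads)
queries = tf.random_normal([2, 5, 64])
keys = tf.random_normal([2, 7, 64])
values = tf.random_normal([2, 7, 64])
key_masks = tf.ones([2, 7], dtype=tf.float32)
dec_out = multihead_attention(queries, keys, values, key_masks,
                              num_heads=8, dropout_rate=0.1,
                              training=True, causality=False)
# dec_out: [2, 5, 64]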
Example #6
    def __init__(self,
                 is_training,
                 input_tensor,
                 label_ids,
                 sample_weight=None,
                 scope='mrc',
                 name='',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 **kwargs):
        super().__init__(**kwargs)

        seq_length = input_tensor.shape.as_list()[-2]
        hidden_size = input_tensor.shape.as_list()[-1]
        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[2, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[2],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            output_layer = util.dropout(
                input_tensor, hidden_dropout_prob if is_training else 0.0)

            output_layer = tf.reshape(output_layer, [-1, hidden_size])
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [-1, seq_length, 2])
            logits = tf.transpose(logits, [0, 2, 1])
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            self.probs[name] = probs

            start_one_hot_labels = tf.one_hot(label_ids[:, 0],
                                              depth=seq_length,
                                              dtype=tf.float32)
            end_one_hot_labels = tf.one_hot(label_ids[:, 1],
                                            depth=seq_length,
                                            dtype=tf.float32)
            start_log_probs = tf.nn.log_softmax(logits[:, 0, :], axis=-1)
            end_log_probs = tf.nn.log_softmax(logits[:, 1, :], axis=-1)
            per_example_loss = (
                -0.5 * tf.reduce_sum(start_one_hot_labels * start_log_probs,
                                     axis=-1) - 0.5 *
                tf.reduce_sum(end_one_hot_labels * end_log_probs, axis=-1))
            if sample_weight is not None:
                per_example_loss *= sample_weight

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses[name] = per_example_loss

            start_preds = tf.expand_dims(tf.argmax(logits[:, 0, :], axis=-1),
                                         axis=-1)
            end_preds = tf.expand_dims(tf.argmax(logits[:, 1, :], axis=-1),
                                       axis=-1)
            self.preds[name] = tf.concat([start_preds, end_preds], axis=-1)
Example #7
def embedding_lookup(input_ids,
                     vocab_size,
                     batch_size,
                     max_seq_length,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     zero_pad=True,
                     dtype=tf.float32,
                     trainable=True,
                     tilda_embeddings=None):
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    if tilda_embeddings is not None:
        embedding_table = tilda_embeddings
    else:
        embedding_table = tf.get_variable(
            name=word_embedding_name,
            shape=[vocab_size, embedding_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)

    if zero_pad:
        embedding_table = tf.concat(
            (tf.zeros(shape=[1, embedding_size]),
             embedding_table[1:, :]), axis=0)

    flat_input_ids = tf.reshape(input_ids, [-1])
    output = tf.gather(
        embedding_table, flat_input_ids, name='embedding_look_up')
    output = tf.reshape(
        output, [batch_size, max_seq_length, embedding_size])

    return (output, embedding_table)
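
A minimal usage sketch (hypothetical sizes; assumes the `util` helpers used above are importable):

input_ids = tf.zeros([8, 128], dtype=tf.int32)   # illustrative token ids
word_embeddings, table = embedding_lookup(
    input_ids, vocab_size=30522, batch_size=8, max_seq_length=128,
    embedding_size=128)
# word_embeddings: [8, 128, 128]; row 0 of `table` is zeroed when zero_pad=True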
Example #8
def get_timing_signal_1d_given_position(channels,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
    """Get sinusoids of diff frequencies, with timing position given.

  Adapted from add_timing_signal_1d_given_position in
  //third_party/py/tensor2tensor/layers/common_attention.py

  Args:
    channels: scalar, size of timing embeddings to create. The number of
        different timescales is equal to channels / 2.
    position: a Tensor with shape [batch, seq_len]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor of timing signals [batch, seq_len, channels]
  """
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.to_float(num_timescales) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    scaled_time = (tf.expand_dims(tf.to_float(position), 2) *
                   tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
    signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
    return signal
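
A short sketch of the expected shapes (hypothetical values):

position = tf.constant([[0, 1, 2, 3]], dtype=tf.int32)   # [batch=1, seq_len=4]
signal = get_timing_signal_1d_given_position(channels=8, position=position)
# signal: [1, 4, 8] -- sin on the first 4 channels, cos on the remaining 4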
Example #9
    def embedding_preprocessor(self,
                               input_values,
                               batch_size=None,
                               embedding_size=128,
                               initializer_range=0.02,
                               name='cls_embedding',
                               dtype=tf.float32,
                               trainable=True):

        with tf.variable_scope(name):
            input_values = util.layer_norm(input_values, trainable=trainable)
            linear_output = tf.layers.dense(
                input_values,
                embedding_size,
                activation=None,
                name='dense',
                kernel_initializer=util.create_initializer(initializer_range),
                trainable=trainable)

            cls_embedding = tf.get_variable(
                name='cls',
                shape=[1, 1, embedding_size],
                initializer=util.create_initializer(initializer_range),
                dtype=dtype,
                trainable=trainable)
            cls_output = tf.tile(cls_embedding, [batch_size, 1, 1])

        output = tf.concat([cls_output, linear_output], axis=1)
        return output
Example #10
def attn(x, scope, n_state, *, past, hparams):
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence,
        # features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads,
        # sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where
        # information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w * b - tf.cast(1e10, w.dtype) * (1 - b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))

        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state * 3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present
Example #11
    def __init__(self,
                 is_training,
                 input_tensor,
                 input_mask,
                 label_ids,
                 label_size=2,
                 sample_weight=None,
                 scope='cls/sequence',
                 name='',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 **kwargs):
        super().__init__(**kwargs)

        batch_size = tf.shape(input_tensor)[0]
        seq_length = input_tensor.shape.as_list()[-2]
        hidden_size = input_tensor.shape.as_list()[-1]
        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[label_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            output_layer = util.dropout(
                input_tensor, hidden_dropout_prob if is_training else 0.0)

            output_layer = tf.reshape(output_layer, [-1, hidden_size])
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [-1, seq_length, label_size])

            self.preds[name] = tf.argmax(logits, axis=-1)
            self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids,
                                        depth=label_size,
                                        dtype=tf.float32)
            per_token_losses = -tf.reduce_mean(one_hot_labels * log_probs,
                                               axis=-1)
            input_mask = tf.concat(
                [tf.zeros((batch_size, 1), dtype=tf.float32),
                 tf.cast(input_mask[:, 2:], dtype=tf.float32),
                 tf.zeros((batch_size, 1), dtype=tf.float32)],
                axis=-1)
            per_token_losses *= input_mask
            per_example_loss = tf.reduce_mean(per_token_losses, axis=-1)
            if sample_weight is not None:
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            self.losses[name] = per_example_loss
            self.total_loss = tf.reduce_mean(per_example_loss)
Example #12
def scatter_update(sequence, updates, positions):
    '''Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
      updates: A tensor of size batch_size*seq_len(*depth)
      positions: A [batch_size, n_positions] tensor

    Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates.' Updates to index 0 are
      ignored. If there are duplicated positions the update is only applied
      once. Second is a [batch_size, seq_len] mask tensor of which inputs were
      updated.
    '''
    shape = util.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        B, L, D = shape
    else:
        B, L = shape
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = util.get_shape_list(positions)[1]

    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, D])
    updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
    updates = tf.reshape(updates, [B, L, D])

    flat_updates_mask = tf.ones([B * N], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
    updates_mask = tf.reshape(updates_mask, [B, L])
    not_first_token = tf.concat(
        [tf.zeros((B, 1), tf.int32),
         tf.ones((B, L - 1), tf.int32)], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.divide(updates, tf.maximum(1, updates_mask_3d))
        updates = tf.cast(updates, tf.int32)
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
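
A hedged usage sketch (ELECTRA-style masking; the [MASK] id 103 and the tensors are illustrative):

sequence = tf.constant([[11, 12, 13, 14, 15]], dtype=tf.int32)   # [1, 5]
positions = tf.constant([[1, 3]], dtype=tf.int32)                # [1, 2]
mask_tokens = tf.fill(tf.shape(positions), 103)                  # one update per position
masked_sequence, update_mask = scatter_update(sequence, mask_tokens, positions)
# masked_sequence -> [[11, 103, 13, 103, 15]], update_mask -> [[0, 1, 0, 1, 0]]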
Example #13
def positional_embedding(pos_seq, inv_freq, bsz=None):
    sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
    pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
    pos_emb = pos_emb[:, None, :]

    if bsz is not None:
        pos_emb = tf.tile(pos_emb, [1, bsz, 1])

    return pos_emb
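
A short sketch of how this is typically driven (Transformer-XL style; the frequency vector mirrors the one built in Example #17):

d_model = 8
freq_seq = tf.range(0, d_model, 2.0)
inv_freq = 1 / (10000 ** (freq_seq / d_model))        # [d_model / 2]
pos_seq = tf.range(16.0, -1.0, -1.0)                  # relative positions 16 .. 0
pos_emb = positional_embedding(pos_seq, inv_freq)     # [17, 1, d_model]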
Example #14
                def _forward(dilated_ids, dilated_mask):

                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    output_ids = tf.argmax(logits, axis=-1)
                    output_ids = tf.cast(output_ids, dtype=tf.int32)

                    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                    right_pad = spad_id * tf.sequence_mask(
                        equal_zero, dilated_seq_length, dtype=tf.int32)

                    padded = tf.concat([output_ids, right_pad], axis=-1)
                    flattened_padded = tf.reshape(padded, [-1])
                    is_valid = tf.greater(flattened_padded, 0)
                    flattened_valid = tf.boolean_mask(flattened_padded,
                                                      is_valid)
                    valid = tf.reshape(flattened_valid,
                                       [batch_size, dilated_seq_length])
                    cutted_valid = valid[:, :max_seq_length]

                    nonpad_mask = tf.cast(tf.not_equal(cutted_valid, spad_id),
                                          dtype=tf.int32)
                    output_ids = cutted_valid * nonpad_mask

                    reshaped = tf.reshape(output_ids,
                                          [batch_size, max_seq_length, 1])
                    concatenated = tf.concat(
                        [reshaped, tf.zeros_like(reshaped)], axis=-1)
                    dilated_ids = tf.reshape(concatenated,
                                             [batch_size, max_seq_length * 2])

                    input_mask = tf.reduce_sum(nonpad_mask, axis=-1)
                    dilated_mask = tf.sequence_mask(input_mask,
                                                    dilated_seq_length,
                                                    dtype=tf.int32)

                    return dilated_ids, dilated_mask
Example #15
    def _cls_self_attention(self,
                            prev_output,
                            batch_size,
                            max_seq_length,
                            label_size,
                            attention_mask=None,
                            cls_hidden_size=128,
                            cls_num_attention_heads=2,
                            attention_probs_dropout_prob=0.1,
                            initializer_range=0.02,
                            dtype=tf.float32,
                            trainable=True):
        if cls_hidden_size % cls_num_attention_heads != 0:
            raise ValueError(
                '`cls_hidden_size` (%d) is not a multiple of the number of '
                '`cls_num_attention_heads` (%d)' %
                (cls_hidden_size, cls_num_attention_heads))
        cls_attention_head_size = int(cls_hidden_size /
                                      cls_num_attention_heads)

        with tf.variable_scope('attention'):
            attention_heads = []
            with tf.variable_scope('self'):
                attention_head, _ = self.attention_layer(
                    from_tensor=prev_output,
                    to_tensor=prev_output,
                    attention_mask=attention_mask,
                    num_attention_heads=cls_num_attention_heads,
                    size_per_head=cls_attention_head_size,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    initializer_range=initializer_range,
                    do_return_2d_tensor=False,
                    batch_size=batch_size,
                    from_max_seq_length=max_seq_length,
                    to_max_seq_length=max_seq_length,
                    dtype=dtype,
                    trainable=trainable)
                attention_heads.append(attention_head)

            attention_output = None
            if len(attention_heads) == 1:
                attention_output = attention_heads[0]
            else:
                attention_output = tf.concat(attention_heads, axis=-1)
            attention_output = util.layer_norm(attention_output[:, 0, :],
                                               trainable=trainable)

        with tf.variable_scope('output'):
            cls_output = tf.layers.dense(
                attention_output,
                label_size,
                kernel_initializer=util.create_initializer(initializer_range),
                trainable=trainable)

        return cls_output
Example #16
 def _single_seq_fn():
     batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
     example_inds = tf.reshape(
         tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
     sequence_scores = tf.gather_nd(
         tf.squeeze(inputs, [1]),
         tf.concat([example_inds, tag_indices], axis=1))
     sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                                tf.zeros_like(sequence_scores),
                                sequence_scores)
     return sequence_scores
Example #17
def relative_positional_encoding(qlen,
                                 klen,
                                 d_model,
                                 clamp_len,
                                 attn_type,
                                 bi_data,
                                 bsz=None,
                                 dtype=None):
    '''create relative positional encoding.'''
    freq_seq = tf.range(0, d_model, 2.0)
    if dtype is not None and dtype != tf.float32:
        freq_seq = tf.cast(freq_seq, dtype=dtype)
    inv_freq = 1 / (10000**(freq_seq / d_model))

    if attn_type == 'bi':
        # beg, end = klen - 1, -qlen
        beg, end = klen, -qlen
    elif attn_type == 'uni':
        # beg, end = klen - 1, -1
        beg, end = klen, -1
    else:
        raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

    if bi_data:
        fwd_pos_seq = tf.range(beg, end, -1.0)
        bwd_pos_seq = tf.range(-beg, -end, 1.0)

        if dtype is not None and dtype != tf.float32:
            fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
            bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)

        if clamp_len > 0:
            fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
            bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len)

        if bsz is not None:
            # With bi_data, the batch size should be divisible by 2.
            assert bsz % 2 == 0
            fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
            bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
        else:
            fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
            bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)

        pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
    else:
        fwd_pos_seq = tf.range(beg, end, -1.0)
        if dtype is not None and dtype != tf.float32:
            fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
        if clamp_len > 0:
            fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
        pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)

    return pos_emb
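
A usage sketch with hypothetical sizes (assumes `positional_embedding` from Example #13):

pos_emb = relative_positional_encoding(
    qlen=128, klen=512, d_model=1024, clamp_len=-1,
    attn_type='bi', bi_data=False, bsz=None)
# pos_emb: [qlen + klen, 1, d_model]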
Example #18
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    r'''Constructs the matrix of random projections.
  Constructs a matrix of random orthogonal projections. Each projection vector
  has direction chosen uniformly at random and either deterministic length
  \sqrt{d} or length taken from the \chi(d) distribution (in the latter case
  marginal distributions of the projections are d-dimensional Gaussian vectors
  with associated identity covariance matrix).
  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed used to construct projections.
    scaling: 1 if all the random projections need to be renormalized to have
      length \sqrt{d}, 0 if the lengths of random projections should follow
      \chi(d) distribution.
    struct_mode: if True then products of Givens rotations will be used to
      construct random orthogonal matrix. This bypasses Gram-Schmidt
      orthogonalization.
  Returns:
    The matrix of random projections of the shape [m, d].
  '''
    nb_full_blocks = int(m / d)
    block_list = []
    current_seed = seed
    for _ in range(nb_full_blocks):
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q)
        current_seed += 1
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q[0:remaining_rows])
    final_matrix = tf.concat(block_list, axis=0)
    current_seed += 1

    if scaling == 0:
        multiplier = tf.norm(tf.random_normal((m, d), seed=current_seed),
                             axis=1)
    elif scaling == 1:
        multiplier = 1 / tf.math.rsqrt(float(d)) * tf.ones((m))
    else:
        raise ValueError('Scaling must be one of {0, 1}. Was %s' % scaling)

    return tf.matmul(tf.linalg.diag(multiplier), final_matrix)
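
A brief sketch (hypothetical sizes): the rows are approximately orthogonal random directions used as the projection for Performer-style random features.

projection_matrix = create_projection_matrix(m=256, d=64, seed=0, scaling=0)
# projection_matrix: [256, 64]; with scaling=1 every row has length sqrt(64)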
Example #19
def causal_numerator(qs, ks, vs):
    '''Computes not-normalized FAVOR causal attention A_{masked}V.
  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
    vs: value tensor of the shape [L,B,H,D].
  Returns:
    Not-normalized FAVOR causal attention A_{masked}V.
  '''

    result = []
    sums = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

    for index in range(qs.shape[0]):
        sums = sums + tf.einsum('ijk,ijl->ijkl', ks[index], vs[index])
        result.append(
            tf.einsum('ijkl,ijk->ijl', sums, qs[index])[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums, res_grad[index])[None,
                                                                     Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad
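
The returned `(result, grad)` pair matches the contract of `tf.custom_gradient`; a minimal sketch of that wiring, assuming the function above:

causal_numerator_op = tf.custom_gradient(causal_numerator)
# qs, ks: [L, B, H, M]; vs: [L, B, H, D]
# num = causal_numerator_op(qs, ks, vs)   # [L, B, H, D], backward pass uses `grad`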
Example #20
    def grad(res_grad):

        k_grad = tf.zeros_like(ks[0])

        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijk,ij->ijk', gr_sums, res_grad[index])[None,
                                                                   Ellipsis])
            k_grad = k_grad + tf.einsum('ijk,ij->ijk', qs[index],
                                        res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            gr_sums = gr_sums - ks[index]

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads
Example #21
def rel_multihead_attn(h,
                       r,
                       r_w_bias,
                       r_r_bias,
                       seg_mat,
                       r_s_bias,
                       seg_embed,
                       attn_mask,
                       mems,
                       d_model,
                       n_head,
                       d_head,
                       dropout,
                       dropatt,
                       is_training,
                       kernel_initializer,
                       scope='rel_attn',
                       reuse=None):
    '''Multi-head attention with relative positional encoding.'''

    scale = 1 / (d_head**0.5)
    with tf.variable_scope(scope, reuse=reuse):
        if mems is not None and mems.shape.ndims > 1:
            cat = tf.concat([mems, h], 0)
        else:
            cat = h

        # content heads
        q_head_h = head_projection(h, d_model, n_head, d_head,
                                   kernel_initializer, 'q')
        k_head_h = head_projection(cat, d_model, n_head, d_head,
                                   kernel_initializer, 'k')
        v_head_h = head_projection(cat, d_model, n_head, d_head,
                                   kernel_initializer, 'v')

        # positional heads
        k_head_r = head_projection(r, d_model, n_head, d_head,
                                   kernel_initializer, 'r')

        # core attention ops
        attn_vec = rel_attn_core(q_head_h, k_head_h, v_head_h, k_head_r,
                                 seg_embed, seg_mat, r_w_bias, r_r_bias,
                                 r_s_bias, attn_mask, dropatt, is_training,
                                 scale)

        # post processing
        output = post_attention(h, attn_vec, d_model, n_head, d_head, dropout,
                                is_training, kernel_initializer)

    return output
Example #22
def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
    '''cache hidden states into memory.'''
    if mem_len is None or mem_len == 0:
        return None
    else:
        if reuse_len is not None and reuse_len > 0:
            curr_out = curr_out[:reuse_len]

        if prev_mem is None:
            new_mem = curr_out[-mem_len:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]

    return tf.stop_gradient(new_mem)
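
A usage sketch with illustrative tensors (Transformer-XL/XLNet-style memory):

curr_out = tf.random_normal([128, 2, 1024])   # [seq_len, batch, d_model]
prev_mem = None                               # no memory on the first segment
new_mem = _cache_mem(curr_out, prev_mem, mem_len=384, reuse_len=None)
# new_mem: the cached (gradient-stopped) last `mem_len` steps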
Example #23
def causal_denominator(qs, ks):
    '''Computes FAVOR normalizer in causal attention.
  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
  Returns:
    FAVOR normalizer in causal attention.
  '''

    result = []
    sums = tf.zeros_like(ks[0])

    for index in range(qs.shape[0]):
        sums = sums + ks[index]
        result.append(tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        k_grad = tf.zeros_like(ks[0])

        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijk,ij->ijk', gr_sums, res_grad[index])[None,
                                                                   Ellipsis])
            k_grad = k_grad + tf.einsum('ijk,ij->ijk', qs[index],
                                        res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            gr_sums = gr_sums - ks[index]

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads

    return result, grad
Example #24
def get_token_embeddings(vocab_size, num_units, zero_pad=True):
    '''Constructs token embedding matrix.
    Note that the row at index 0 is set to all zeros.
    vocab_size: scalar. V.
    num_units: embedding dimensionality. E.
    zero_pad: Boolean. If True, all values of the first row (id = 0) are constant
      zeros, so that query/key masks can be applied easily.

    Returns
    weight variable: (V, E)
    '''
    with tf.variable_scope('shared_weight_matrix'):
        embeddings = tf.get_variable('weight_mat',
                                   dtype=tf.float32,
                                   shape=(vocab_size, num_units),
                                   initializer=xavier_initializer())
        if zero_pad:
            embeddings = tf.concat((tf.zeros(shape=[1, num_units]),
                                    embeddings[1:, :]), 0)
    return embeddings
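
A small usage sketch (hypothetical sizes; assumes `xavier_initializer` is imported as above):

embeddings = get_token_embeddings(vocab_size=32000, num_units=512, zero_pad=True)
token_ids = tf.constant([[5, 17, 0, 0]])                   # 0 is the padding id
token_emb = tf.nn.embedding_lookup(embeddings, token_ids)  # pad rows are all zeros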
Example #25
    def __init__(self,
                 hparams,
                 is_training,
                 input_ids,
                 sample_weight=None,
                 scope='model',
                 given=1,
                 use_tilda_embedding=False,
                 **kwargs):
        super().__init__()

        batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
        max_seq_length = hparams.n_predict

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):

            def _forward(input_ids, past=None):
                batch, sequence = shape_list(input_ids)

                if tilda_embeddings is None:
                    wte = tf.get_variable(
                        'word_embeddings', [hparams.n_vocab, hparams.n_embed],
                        initializer=tf.random_normal_initializer(stddev=0.02))
                else:
                    wte = tilda_embeddings
                wpe = tf.get_variable(
                    'wpe', [hparams.n_ctx, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.01))
                past_length = 0 if past is None else tf.shape(past)[-2]
                h = (tf.gather(wte, input_ids) +
                     tf.gather(wpe, positions_for(input_ids, past_length)))

                # stacked transformer layers
                presents = []
                pasts = tf.unstack(past, axis=1) if past is not None else \
                    [None] * hparams.n_layer
                assert len(pasts) == hparams.n_layer
                for layer, past in enumerate(pasts):
                    h, present = block(h,
                                       'h%d' % layer,
                                       past=past,
                                       hparams=hparams)
                    presents.append(present)
                present = tf.stack(presents, axis=1)
                h = norm(h, 'ln_f')

                # Language model loss.  Do tokens <n predict token n?
                h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
                logits = tf.matmul(h_flat, wte, transpose_b=True)
                logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])

                return logits, present

            # convert to labels
            label_ids = tf.concat(
                [input_ids[:, 1:],
                 tf.zeros([batch_size, 1], dtype=tf.int32)],
                axis=-1)

            # forward once
            if is_training:
                (logits, _) = _forward(input_ids)

                self.preds['LM'] = tf.argmax(logits, axis=-1)

            # forward loop
            else:
                input_ids = input_ids[:, 0:given]

                for cur_length in range(given, max_seq_length + 1):
                    (logits, _) = _forward(input_ids)

                    pred_ids = tf.argmax(logits[:,
                                                cur_length - 1:cur_length, :],
                                         axis=-1)
                    pred_ids = tf.cast(pred_ids, tf.int32)
                    input_ids = tf.concat([input_ids, pred_ids], axis=-1)

                self.preds['LM'] = input_ids

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
            per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                            axis=-1)
            label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['LM'] = per_example_loss
Example #26
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
  Sample a permutation of the factorization order, and create an
  attention mask accordingly.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of the longest permutation. Could be set to reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
    seq_len: int, sequence length.
  '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to the
    # smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
Example #27
def _expand_features(module, split_placeholders):

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length
            and module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    perm_mask_0 = tf.concat(
        [tf.cast(perm_mask_0, dtype=tf.float32),
         tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])],
        axis=2)
    perm_mask_1 = tf.concat(
        [tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
         tf.cast(perm_mask_1, dtype=tf.float32)],
        axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        #TODO(geying): convert tensors from 1-D to 2-D

        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(indices,
                                    module.max_seq_length,
                                    dtype=tf.float32)
        paddings = tf.zeros([pad_len, module.max_seq_length],
                            dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, module._num_predict, module.max_seq_length])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module._num_predict])

        ##### target mask
        target_mask = tf.concat(
            [tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
             tf.zeros([batch_size, pad_len], dtype=tf.float32)],
            axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(input_k,
                                               [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(input_q,
                                               [-1, module.max_seq_length])

    return split_placeholders
Example #28
                def _build_forward(layer_input):
                    with tf.variable_scope('attention'):
                        attention_heads = []
                        with tf.variable_scope('self'):
                            (attention_head, attention_scores) = \
                                self.attention_layer(
                                    from_tensor=layer_input,
                                    to_tensor=layer_input,
                                    attention_mask=attention_mask,
                                    num_attention_heads=num_attention_heads,
                                    size_per_head=attention_head_size,
                                    attention_probs_dropout_prob=\
                                        attention_probs_dropout_prob,
                                    initializer_range=initializer_range,
                                    do_return_2d_tensor=True,
                                    batch_size=batch_size,
                                    from_max_seq_length=max_seq_length,
                                    to_max_seq_length=max_seq_length,
                                    dtype=dtype,
                                    trainable=trainable)
                            attention_heads.append(attention_head)
                            self.attention_scores.append(attention_scores)

                        attention_output = None
                        if len(attention_heads) == 1:
                            attention_output = attention_heads[0]
                        else:
                            attention_output = tf.concat(attention_heads,
                                                         axis=-1)

                        with tf.variable_scope('output'):
                            attention_output = tf.layers.dense(
                                attention_output,
                                hidden_size,
                                kernel_initializer=util.create_initializer(
                                    initializer_range),
                                trainable=trainable)
                            attention_output = util.dropout(
                                attention_output, hidden_dropout_prob)
                            attention_output = util.layer_norm(
                                attention_output + layer_input,
                                trainable=trainable)

                    # The activation is only applied to the `intermediate`
                    # hidden layer.
                    with tf.variable_scope('intermediate'):
                        intermediate_output = tf.layers.dense(
                            attention_output,
                            intermediate_size,
                            activation=intermediate_act_fn,
                            kernel_initializer=util.create_initializer(
                                initializer_range),
                            trainable=trainable)

                    # Down-project back to hidden_size then add the residual.
                    with tf.variable_scope('output'):
                        layer_output = tf.layers.dense(
                            intermediate_output,
                            hidden_size,
                            kernel_initializer=util.create_initializer(
                                initializer_range),
                            trainable=trainable)
                        layer_output = util.dropout(layer_output,
                                                    hidden_dropout_prob)
                        layer_output = util.layer_norm(layer_output +
                                                       attention_output,
                                                       trainable=trainable)

                    return layer_output
Example #29
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=util.gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    '''Multi-headed, multi-layer Transformer from 'Attention is All You Need'.

  This is almost an exact implementation of the original Transformer encoder.

  See the original paper:
  https://arxiv.org/abs/1706.03762

  Also see:
  https://github.com/tensorflow/tensor2tensor/blob/master/
    tensor2tensor/models/transformer.py

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the 'intermediate' (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    do_return_all_layers: Whether to also return all layers or just the final
      layer.

  Returns:
    A tuple `(output, attn_maps)`: a float Tensor of shape [batch_size,
    seq_length, hidden_size] holding the final hidden layer of the
    Transformer (or all layers stacked, if `do_return_all_layers` is True),
    and the attention probabilities of every layer stacked along axis 0.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  '''
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            'The hidden size (%d) is not a multiple of the number of attention '
            'heads (%d)' % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = util.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer adds residual connections on every layer, so the input
    # width needs to match the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            'The width of the input tensor (%d) != hidden size (%d)' %
            (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = util.reshape_to_matrix(input_tensor)

    attn_maps = []
    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope('layer_%d' % layer_idx):
            with tf.variable_scope('attention'):
                attention_heads = []
                with tf.variable_scope('self'):
                    attention_head, probs = attention_layer(
                        from_tensor=prev_output,
                        to_tensor=prev_output,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=
                        attention_probs_dropout_prob,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length)
                    attention_heads.append(attention_head)
                    attn_maps.append(probs)

                attention_output = None
                if len(attention_heads) == 1:
                    attention_output = attention_heads[0]
                else:
                    # In the case where we have other sequences, we just concatenate
                    # them to the self-attention head before the projection.
                    attention_output = tf.concat(attention_heads, axis=-1)

                # Run a linear projection of `hidden_size` then add a residual
                # with `layer_input`.
                with tf.variable_scope('output'):
                    attention_output = tf.layers.dense(
                        attention_output,
                        hidden_size,
                        kernel_initializer=util.create_initializer(
                            initializer_range))
                    attention_output = util.dropout(attention_output,
                                                    hidden_dropout_prob)
                    attention_output = util.layer_norm(attention_output +
                                                       prev_output)

            # The activation is only applied to the 'intermediate' hidden layer.
            with tf.variable_scope('intermediate'):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    kernel_initializer=util.create_initializer(
                        initializer_range))

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope('output'):
                prev_output = tf.layers.dense(
                    intermediate_output,
                    hidden_size,
                    kernel_initializer=util.create_initializer(
                        initializer_range))
                prev_output = util.dropout(prev_output, hidden_dropout_prob)
                prev_output = util.layer_norm(prev_output + attention_output)
                all_layer_outputs.append(prev_output)

    attn_maps = tf.stack(attn_maps, 0)
    if do_return_all_layers:
        return tf.stack([
            util.reshape_from_matrix(layer, input_shape)
            for layer in all_layer_outputs
        ], 0), attn_maps
    else:
        return util.reshape_from_matrix(prev_output, input_shape), attn_maps
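
A minimal usage sketch for `transformer_model` above (TF 1.x graph mode; the shapes are illustrative, and it assumes the same `util` helpers, e.g. `util.gelu`, are importable):

import tensorflow as tf

# Illustrative shapes only.
batch_size, seq_length, hidden_size = 8, 128, 768

# [batch_size, seq_length, hidden_size] embeddings, e.g. from an embedding lookup.
input_tensor = tf.placeholder(tf.float32, [batch_size, seq_length, hidden_size])
# [batch_size, seq_length, seq_length] mask: 1 = may attend, 0 = may not.
attention_mask = tf.placeholder(tf.int32, [batch_size, seq_length, seq_length])

sequence_output, attn_maps = transformer_model(
    input_tensor=input_tensor,
    attention_mask=attention_mask,
    hidden_size=hidden_size,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    do_return_all_layers=False)

# sequence_output: [batch_size, seq_length, hidden_size]
# attn_maps: per-layer attention probabilities stacked along axis 0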
Example #30
0
    def __init__(self,
                 bert_config,
                 is_training,
                 input_ids,
                 input_mask,
                 segment_ids,
                 sample_weight=None,
                 scope='bert',
                 dtype=tf.float32,
                 drop_pooler=False,
                 cls_model='self-attention',
                 label_size=2,
                 speed=0.1,
                 ignore_cls='0',
                 **kwargs):
        super(FastBERTCLSDistillor, self).__init__()

        if not ignore_cls:
            ignore_cls = []
        if isinstance(ignore_cls, str):
            ignore_cls = ignore_cls.replace(' ', '').split(',')
            ignore_cls = list(map(int, ignore_cls))
        elif isinstance(ignore_cls, list):
            ignore_cls = list(map(int, ignore_cls))
        else:
            raise ValueError(
                '`ignore_cls` should be a list of child-classifier ids or '
                'a string separated by commas.')

        if not speed:
            raise ValueError(
                '`speed` should be a float between `0` and `1`.')

        bert_config = copy.deepcopy(bert_config)
        bert_config.hidden_dropout_prob = 0.0
        bert_config.attention_probs_dropout_prob = 0.0

        input_shape = util.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        max_seq_length = input_shape[1]

        with tf.variable_scope(scope):
            with tf.variable_scope('embeddings'):

                (self.embedding_output, self.embedding_table) = \
                    self.embedding_lookup(
                        input_ids=input_ids,
                        vocab_size=bert_config.vocab_size,
                        batch_size=batch_size,
                        max_seq_length=max_seq_length,
                        embedding_size=bert_config.hidden_size,
                        initializer_range=bert_config.initializer_range,
                        word_embedding_name='word_embeddings',
                        dtype=dtype,
                        trainable=False,
                        tilda_embeddings=None)

                # Add positional embeddings and token type embeddings, then
                # layer normalize and perform dropout.
                self.embedding_output = self.embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    hidden_size=bert_config.hidden_size,
                    use_token_type=True,
                    segment_ids=segment_ids,
                    token_type_vocab_size=bert_config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=bert_config.initializer_range,
                    max_position_embeddings=\
                        bert_config.max_position_embeddings,
                    dropout_prob=bert_config.hidden_dropout_prob,
                    dtype=dtype,
                    trainable=False)

            with tf.variable_scope('encoder'):
                attention_mask = self.create_attention_mask_from_input_mask(
                    input_mask, batch_size, max_seq_length, dtype=dtype)

                # stacked transformers
                (self.all_encoder_layers, self.all_cls_layers) = \
                    self.dynamic_transformer_model(
                        is_training,
                        input_tensor=self.embedding_output,
                        input_mask=input_mask,
                        batch_size=batch_size,
                        max_seq_length=max_seq_length,
                        label_size=label_size,
                        attention_mask=attention_mask,
                        hidden_size=bert_config.hidden_size,
                        num_hidden_layers=bert_config.num_hidden_layers,
                        num_attention_heads=bert_config.num_attention_heads,
                        intermediate_size=bert_config.intermediate_size,
                        intermediate_act_fn=util.get_activation(
                            bert_config.hidden_act),
                        hidden_dropout_prob=bert_config.hidden_dropout_prob,
                        attention_probs_dropout_prob=\
                            bert_config.attention_probs_dropout_prob,
                        initializer_range=bert_config.initializer_range,
                        dtype=dtype,
                        cls_model=cls_model,
                        speed=speed,
                        ignore_cls=ignore_cls)

            self.sequence_output = self.all_encoder_layers[-1]
            with tf.variable_scope('pooler'):
                first_token_tensor = self.sequence_output[:, 0, :]

                # trick: ignore the fully connected layer
                if drop_pooler:
                    self.pooled_output = first_token_tensor
                else:
                    self.pooled_output = tf.layers.dense(
                        first_token_tensor,
                        bert_config.hidden_size,
                        activation=tf.tanh,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=False)

        # teacher classifier
        if bert_config.num_hidden_layers not in ignore_cls:
            with tf.variable_scope('cls/seq_relationship'):
                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[label_size, bert_config.hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=False)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[label_size],
                    initializer=tf.zeros_initializer(),
                    trainable=False)

                logits = tf.matmul(self.pooled_output,
                                   output_weights,
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                probs = tf.nn.softmax(logits, axis=-1)

        # distillation
        if is_training:
            losses = []
            for cls_probs in self.all_cls_layers.values():

                # KL divergence between the student branch distribution and
                # the teacher distribution: KL(student || teacher)
                per_example_loss = tf.reduce_sum(
                    cls_probs * (tf.log(cls_probs) - tf.log(probs)), axis=-1)
                if sample_weight is not None:
                    per_example_loss *= tf.cast(sample_weight,
                                                dtype=tf.float32)
                loss = tf.reduce_mean(per_example_loss)
                losses.append(loss)

            distill_loss = tf.add_n(losses)
            self.total_loss = distill_loss
            self.losses['losses'] = distill_loss

        else:
            if bert_config.num_hidden_layers not in ignore_cls:
                self.all_cls_layers[bert_config.num_hidden_layers] = probs
            self.probs['probs'] = tf.concat(list(self.all_cls_layers.values()),
                                            axis=0,
                                            name='probs')
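
The distillation term above is, for each student branch, the KL divergence between the branch's class distribution and the teacher classifier's distribution. A standalone sketch of that term (TF 1.x; the function name is hypothetical):

import tensorflow as tf

def kl_distill_loss(student_probs, teacher_probs, sample_weight=None):
    # KL(student || teacher) = sum_c p_s(c) * (log p_s(c) - log p_t(c)),
    # summed over the class axis, optionally re-weighted per example,
    # then averaged over the batch.
    per_example_loss = tf.reduce_sum(
        student_probs * (tf.log(student_probs) - tf.log(teacher_probs)),
        axis=-1)
    if sample_weight is not None:
        per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)
    return tf.reduce_mean(per_example_loss)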