Example #1
def scatter_update(sequence, updates, positions):
    '''Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
      updates: A tensor of size batch_size*seq_len(*depth)
      positions: A [batch_size, n_positions] tensor

    Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates'. Updates to index 0 are
      ignored. If there are duplicated positions the update is only applied
      once. Second is a [batch_size, seq_len] mask tensor of which inputs were
      updated.
    '''
    shape = util.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        B, L, D = shape
    else:
        B, L = shape
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = util.get_shape_list(positions)[1]

    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, D])
    updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
    updates = tf.reshape(updates, [B, L, D])

    flat_updates_mask = tf.ones([B * N], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
    updates_mask = tf.reshape(updates_mask, [B, L])
    not_first_token = tf.concat(
        [tf.zeros((B, 1), tf.int32),
         tf.ones((B, L - 1), tf.int32)], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.divide(updates, tf.maximum(1, updates_mask_3d))
        updates = tf.cast(updates, tf.int32)
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
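
The key trick above is offsetting each example's positions by `L * batch_index` so a single `tf.scatter_nd` over the flattened [batch_size * seq_len] sequence updates the whole batch at once. A minimal standalone sketch of that trick, assuming TensorFlow 1.x:

import tensorflow as tf

B, L = 2, 5
positions = tf.constant([[1, 3], [0, 2]])            # [B, n_positions]
updates = tf.constant([[10, 30], [40, 60]])          # [B, n_positions]
shift = tf.expand_dims(L * tf.range(B), -1)          # [[0], [5]]
flat_positions = tf.reshape(positions + shift, [-1, 1])
scattered = tf.scatter_nd(flat_positions, tf.reshape(updates, [-1]), [B * L])
scattered = tf.reshape(scattered, [B, L])

with tf.Session() as sess:
    print(sess.run(scattered))
    # [[ 0 10  0 30  0]
    #  [40  0 60  0  0]]
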
Example #2
def dot_product_attention(q, k, v, bias, dropout_rate=0.0):
    """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, length_kv]
    logits = tf.multiply(logits,
                         1.0 / math.sqrt(float(util.get_shape_list(q)[-1])))
    if bias is not None:
        # `attention_mask` = [B, T]
        from_shape = util.get_shape_list(q)
        if len(from_shape) == 4:
            broadcast_ones = tf.ones([from_shape[0], 1, from_shape[2], 1],
                                     tf.float32)
        elif len(from_shape) == 5:
            # from_shape = [B, N, block_num, block_size, depth]
            broadcast_ones = tf.ones(
                [from_shape[0], 1, from_shape[2], from_shape[3], 1],
                tf.float32)
        else:
            raise ValueError(
                'Expected rank-4 or rank-5 query tensor, got rank %d' %
                len(from_shape))

        bias = tf.matmul(broadcast_ones,
                         tf.cast(bias, tf.float32),
                         transpose_b=True)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - bias) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        logits += adder
    else:
        adder = 0.0

    attention_probs = tf.nn.softmax(logits, name="attention_probs")
    attention_probs = util.dropout(attention_probs, dropout_rate)
    return tf.matmul(attention_probs, v)
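
The bias handling above follows the standard additive-mask pattern: a 1.0/0.0 visibility mask becomes a 0.0/-10000.0 adder, which drives masked logits to effectively zero probability after the softmax. A minimal numeric sketch, assuming TensorFlow 1.x:

import tensorflow as tf

mask = tf.constant([[1.0, 1.0, 0.0]])     # 1.0 = attend, 0.0 = masked
logits = tf.constant([[2.0, 1.0, 3.0]])   # raw attention scores
adder = (1.0 - mask) * -10000.0
probs = tf.nn.softmax(logits + adder)

with tf.Session() as sess:
    print(sess.run(probs))   # third position gets ~0 probability
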
Example #3
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    '''Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  '''
    from_shape = util.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = util.get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
                      tf.float32)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                             dtype=tf.float32)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask
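
A worked example of the broadcast, assuming TensorFlow 1.x: a [batch_size, 1, to_seq_length] mask times a [batch_size, from_seq_length, 1] tensor of ones yields the full [batch_size, from_seq_length, to_seq_length] mask, with the same to-mask row repeated for every from-position.

import tensorflow as tf

to_mask = tf.cast(tf.constant([[1, 1, 0]]), tf.float32)   # [1, 3]
to_mask = tf.reshape(to_mask, [1, 1, 3])
broadcast_ones = tf.ones([1, 2, 1], tf.float32)           # from_seq_length = 2
mask = broadcast_ones * to_mask                           # [1, 2, 3]

with tf.Session() as sess:
    print(sess.run(mask))
    # [[[1. 1. 0.]
    #   [1. 1. 0.]]]
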
Example #4
def create_attention_mask_from_input_mask(to_mask,
                                          batch_size,
                                          max_seq_length,
                                          dtype=tf.float32):
    # Same broadcasting pattern as above, with batch_size and max_seq_length
    # supplied explicitly rather than inferred from a from_tensor.
    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, max_seq_length]),
                      dtype=dtype)
    broadcast_ones = tf.ones(shape=[batch_size, max_seq_length, 1],
                             dtype=dtype)
    mask = broadcast_ones * to_mask
    return mask
Example #5
    def _forward(self, is_training, split_placeholders, **kwargs):

        if not is_training:
            return super()._forward(is_training, split_placeholders, **kwargs)

        # tf.boolean_mask expects a boolean mask; select the unsupervised rows
        unsup_mask = tf.cast(1.0 - split_placeholders['is_supervised'], tf.bool)
        aug_input_ids = tf.boolean_mask(
            split_placeholders['aug_input_ids'], mask=unsup_mask, axis=0)
        aug_input_mask = tf.boolean_mask(
            split_placeholders['aug_input_mask'], mask=unsup_mask, axis=0)
        aug_segment_ids = tf.boolean_mask(
            split_placeholders['aug_segment_ids'], mask=unsup_mask, axis=0)
        input_ids = tf.concat([split_placeholders['input_ids'], aug_input_ids],
                              axis=0)
        input_mask = tf.concat(
            [split_placeholders['input_mask'], aug_input_mask], axis=0)
        segment_ids = tf.concat(
            [split_placeholders['segment_ids'], aug_segment_ids], axis=0)
        encoder = BERTEncoder(bert_config=self.bert_config,
                              is_training=is_training,
                              input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              scope='bert',
                              drop_pooler=self._drop_pooler,
                              **kwargs)
        encoder_output = encoder.get_pooled_output()

        label_ids = split_placeholders['label_ids']
        is_expanded = tf.zeros_like(label_ids, dtype=tf.float32)
        batch_size = util.get_shape_list(aug_input_ids)[0]
        aug_is_expanded = tf.ones([batch_size], dtype=tf.float32)
        is_expanded = tf.concat([is_expanded, aug_is_expanded], axis=0)
        decoder = UDADecoder(
            is_training=is_training,
            input_tensor=encoder_output,
            is_supervised=split_placeholders['is_supervised'],
            is_expanded=is_expanded,
            label_ids=label_ids,
            label_size=self.label_size,
            sample_weight=split_placeholders.get('sample_weight'),
            scope='cls/seq_relationship',
            global_step=self._global_step,
            num_train_steps=self.total_steps,
            uda_softmax_temp=self._uda_softmax_temp,
            uda_confidence_thresh=self._uda_confidence_thresh,
            tsa_schedule=self._tsa_schedule,
            **kwargs)
        (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
        return (total_loss, losses, probs, preds)
Example #6
def noncausal_denominator(qs, ks):
    '''Computes FAVOR normalizer in noncausal attention.
  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
  Returns:
    FAVOR normalizer in noncausal attention.
  '''
    all_ones = tf.ones([ks.shape[0]])
    ks_sum = tf.einsum('lbhm,l->bhm', ks, all_ones)
    return tf.einsum('lbhm,bhm->lbh', qs, ks_sum)
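
The einsum against an all-ones vector is just a sum over the length axis, so the normalizer can be checked against tf.reduce_sum. A small self-check, assuming TensorFlow 1.x:

import tensorflow as tf

qs = tf.random_normal([3, 2, 2, 4], seed=0)   # [L, B, H, M]
ks = tf.random_normal([3, 2, 2, 4], seed=1)

ks_sum = tf.einsum('lbhm,l->bhm', ks, tf.ones([3]))
denom = tf.einsum('lbhm,bhm->lbh', qs, ks_sum)
denom_ref = tf.einsum('lbhm,bhm->lbh', qs, tf.reduce_sum(ks, axis=0))

with tf.Session() as sess:
    a, b = sess.run([denom, denom_ref])
    print(abs(a - b).max())   # ~0.0
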
Example #7
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
    '''Create a causal attention mask.'''
    attn_mask = tf.ones([qlen, qlen], dtype=dtype)
    mask_u = tf.matrix_band_part(attn_mask, 0, -1)
    mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
    attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
    ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
    if same_length:
        mask_l = tf.matrix_band_part(attn_mask, -1, 0)
        ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)

    return ret
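
A quick check of the output, assuming TensorFlow 1.x and the _create_mask defined above: 1.0 marks positions that may not be attended (strictly future tokens), while the mlen memory columns on the left are fully visible.

import tensorflow as tf

with tf.Session() as sess:
    print(sess.run(_create_mask(qlen=3, mlen=2)))
    # [[0. 0. 0. 1. 1.]
    #  [0. 0. 0. 0. 1.]
    #  [0. 0. 0. 0. 0.]]
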
Example #8
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    r'''Constructs the matrix of random projections.
  Constructs a matrix of random orthogonal projections. Each projection vector
  has direction chosen uniformly at random and either deterministic length
  \sqrt{d} or length taken from the \chi(d) distribution (in the latter case
  marginal distributions of the projections are d-dimensional Gaussian vectors
  with associated identity covariance matrix).
  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed used to construct projections.
    scaling: 1 if all the random projections need to be renormalized to have
      length \sqrt{d}, 0 if the lengths of random projections should follow
      \chi(d) distribution.
    struct_mode: if True then products of Givens rotations will be used to
      construct random orthogonal matrix. This bypasses Gram-Schmidt
      orthogonalization.
  Returns:
    The matrix of random projections of the shape [m, d].
  '''
    nb_full_blocks = int(m / d)
    block_list = []
    current_seed = seed
    for _ in range(nb_full_blocks):
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q)
        current_seed += 1
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        if struct_mode:
            q = create_products_of_givens_rotations(d, seed)
        else:
            unstructured_block = tf.random_normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q[0:remaining_rows])
    final_matrix = tf.concat(block_list, axis=0)
    current_seed += 1

    if scaling == 0:
        multiplier = tf.norm(tf.random_normal((m, d), seed=current_seed),
                             axis=1)
    elif scaling == 1:
        multiplier = tf.math.sqrt(float(d)) * tf.ones([m])
    else:
        raise ValueError('Scaling must be one of {0, 1}. Was %s' % scaling)

    return tf.matmul(tf.linalg.diag(multiplier), final_matrix)
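
A usage sketch, assuming TensorFlow 1.x and the function above: with m == d and scaling=1 the result is a single orthogonal block whose rows have length sqrt(d), so proj @ proj^T / d should be close to the identity.

import tensorflow as tf

proj = create_projection_matrix(m=4, d=4, seed=0, scaling=1)
gram = tf.matmul(proj, proj, transpose_b=True) / 4.0

with tf.Session() as sess:
    print(sess.run(gram))   # ~ 4x4 identity
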
Example #9
    def create_attention_mask_from_input_mask(self,
                                              input_mask,
                                              batch_size,
                                              max_seq_length,
                                              dtype=tf.float32):
        if self._mode == 'bi':
            to_mask = tf.cast(tf.reshape(
                input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
            broadcast_ones = tf.ones(
                shape=[batch_size, max_seq_length, 1], dtype=dtype)
            mask = broadcast_ones * to_mask

        elif self._mode == 'l2r':
            arange = tf.range(max_seq_length) + 1
            to_mask = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
            to_mask = tf.reshape(to_mask, [1, max_seq_length, max_seq_length])
            mask = tf.tile(to_mask, [batch_size, 1, 1])

        elif self._mode == 'r2l':
            to_mask = tf.cast(tf.reshape(
                input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
            broadcast_ones = tf.ones(
                shape=[batch_size, max_seq_length, 1], dtype=dtype)
            cover_mask = broadcast_ones * to_mask

            arange = tf.range(max_seq_length)
            reverse = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
            reverse = tf.reshape(reverse, [1, max_seq_length, max_seq_length])
            reverse_mask = tf.tile(reverse, [batch_size, 1, 1])

            mask = (1 - reverse_mask) * cover_mask

        elif self._mode == 's2s':
            # each entry of input_mask is interpreted as a visible length
            mask = tf.cast(
                tf.sequence_mask(input_mask, max_seq_length), dtype)

        else:
            raise ValueError('Unknown mode: %s' % self._mode)

        return mask
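
For intuition, the 'l2r' branch builds a lower-triangular visibility pattern (position i sees tokens 0..i). A standalone sketch of just that branch, assuming TensorFlow 1.x:

import tensorflow as tf

max_seq_length = 4
arange = tf.range(max_seq_length) + 1
to_mask = tf.cast(tf.sequence_mask(arange, max_seq_length), tf.float32)

with tf.Session() as sess:
    print(sess.run(to_mask))
    # [[1. 0. 0. 0.]
    #  [1. 1. 0. 0.]
    #  [1. 1. 1. 0.]
    #  [1. 1. 1. 1.]]
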
Example #10
    def create_attention_mask_from_input_mask(self,
                                              input_mask,
                                              batch_size,
                                              max_seq_length,
                                              dtype=tf.float32):
        to_mask = tf.cast(tf.reshape(input_mask,
                                     [batch_size, 1, max_seq_length]),
                          dtype=dtype)
        broadcast_ones = tf.ones(shape=[batch_size, max_seq_length, 1],
                                 dtype=dtype)
        mask = broadcast_ones * to_mask

        broadcast_eye = tf.tile(
            tf.reshape(tf.eye(max_seq_length),
                       [1, max_seq_length, max_seq_length]),
            [batch_size, 1, 1])
        mask += broadcast_eye
        mask = tf.cast(tf.greater(mask, 0), dtype)
        return mask
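
Adding the identity before thresholding guarantees every position, padding included, attends at least to itself. A minimal numeric sketch, assuming TensorFlow 1.x:

import tensorflow as tf

input_mask = tf.constant([[1, 1, 0]])
to_mask = tf.cast(tf.reshape(input_mask, [1, 1, 3]), tf.float32)
mask = tf.ones([1, 3, 1]) * to_mask + tf.reshape(tf.eye(3), [1, 3, 3])
mask = tf.cast(tf.greater(mask, 0), tf.float32)

with tf.Session() as sess:
    print(sess.run(mask))
    # [[[1. 1. 0.]
    #   [1. 1. 0.]
    #   [1. 1. 1.]]]   # the padded position gains its diagonal entry
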
Example #11
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
  Sample a permutation of the factorization order, and create an
  attention mask accordingly.

  Args:
    inputs: int64 Tensor in shape [batch_size, seq_len], input ids.
    targets: int64 Tensor in shape [batch_size, seq_len], target ids.
    is_masked: bool Tensor in shape [batch_size, seq_len]. True means being
      selected for partial prediction.
    perm_size: the length of longest permutation. Could be set to be reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
    seq_len: int, sequence length.
  '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to the
    # smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
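
The reshape/transpose/shuffle dance at the top samples one permutation of size perm_size and applies the same pattern to every perm_size-sized block of the sequence. A standalone sketch, assuming TensorFlow 1.x:

import tensorflow as tf

seq_len, perm_size = 6, 3
index = tf.range(seq_len, dtype=tf.int64)
index = tf.reshape(index, [-1, perm_size])   # [[0 1 2], [3 4 5]]
index = tf.transpose(index)                  # rows pair up block offsets
index = tf.random_shuffle(index)             # one shuffle shared by all blocks
index = tf.transpose(index)
index = tf.reshape(index, [-1])

with tf.Session() as sess:
    print(sess.run(index))   # e.g. [2 0 1 5 3 4]: same pattern in both blocks
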
Example #12
def _expand_features(module, split_placeholders):

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length
            and module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ], axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ], axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        #TODO(geying): convert tensors from 1-D to 2-D

        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        # `indices` is 1-D after tf.boolean_mask, so count along axis 0
        actual_num_predict = tf.shape(indices)[0]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(indices,
                                    module.max_seq_length,
                                    dtype=tf.float32)
        paddings = tf.zeros([pad_len, module.max_seq_length],
                            dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, module._num_predict, module.max_seq_length])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module._num_predict])

        ##### target mask
        target_mask = tf.concat([
            tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
            tf.zeros([batch_size, pad_len], dtype=tf.float32)
        ], axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(input_k,
                                               [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(input_q,
                                               [-1, module.max_seq_length])

    return split_placeholders
Example #13
    def __init__(self,
                 bert_config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 use_one_hot_embeddings=True,
                 scope=None,
                 embedding_size=None,
                 input_embeddings=None,
                 input_reprs=None,
                 update_embeddings=True,
                 untied_embeddings=False):
        '''Constructor for BertModel.

        Args:
          bert_config: `BertConfig` instance.
          is_training: bool. true for training model, false for eval model.
            Controls whether dropout will be applied.
          input_ids: int32 Tensor of shape [batch_size, seq_length].
          input_mask: (optional) int32 Tensor of shape [batch_size,
            seq_length].
          token_type_ids: (optional) int32 Tensor of shape [batch_size,
            seq_length].
          use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
            embeddings or tf.embedding_lookup() for the word embeddings. On
            the TPU, it is much faster if this is True, on the CPU or GPU,
            it is faster if this is False.
          scope: (optional) variable scope. Defaults to 'electra'.

        Raises:
          ValueError: The config is invalid or one of the input tensor shapes
            is invalid.
        '''
        bert_config = copy.deepcopy(bert_config)
        if not is_training:
            bert_config.hidden_dropout_prob = 0.0
            bert_config.attention_probs_dropout_prob = 0.0

        assert token_type_ids is not None

        input_shape = util.get_shape_list(token_type_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        if input_reprs is None:
            with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') + '/embeddings'),
                    reuse=tf.AUTO_REUSE):
                # Perform embedding lookup on the word ids
                if embedding_size is None:
                    embedding_size = bert_config.hidden_size
                (token_embeddings, self.embedding_table) = \
                    embedding_lookup(
                        input_ids=input_ids,
                        vocab_size=bert_config.vocab_size,
                        embedding_size=embedding_size,
                        initializer_range=bert_config.initializer_range,
                        word_embedding_name='word_embeddings',
                        use_one_hot_embeddings=use_one_hot_embeddings)

            with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') + '/embeddings'),
                    reuse=tf.AUTO_REUSE):
                # Add positional embeddings and token type embeddings, then
                # layer normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=token_embeddings,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=bert_config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=bert_config.initializer_range,
                    max_position_embeddings=\
                        bert_config.max_position_embeddings,
                    dropout_prob=bert_config.hidden_dropout_prob)
        else:
            self.embedding_output = input_reprs
        if not update_embeddings:
            self.embedding_output = tf.stop_gradient(self.embedding_output)

        with tf.variable_scope(scope, default_name='electra'):
            if self.embedding_output.shape[-1] != bert_config.hidden_size:
                self.embedding_output = tf.layers.dense(
                    self.embedding_output,
                    bert_config.hidden_size,
                    name='embeddings_project')

            with tf.variable_scope('encoder'):
                # This converts a 2D mask of shape [batch_size, seq_length]
                # to a 3D mask of shape [batch_size, seq_length, seq_length]
                # which is used for the attention scores.
                attention_mask = create_attention_mask_from_input_mask(
                    token_type_ids, input_mask)

                # Run the stacked transformer. Output shapes
                # attn_maps:
                #   [n_layers, batch_size, n_heads, seq_length, seq_length]
                (self.all_layer_outputs, self.attn_maps) = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=bert_config.hidden_size,
                    num_hidden_layers=bert_config.num_hidden_layers,
                    num_attention_heads=bert_config.num_attention_heads,
                    intermediate_size=bert_config.intermediate_size,
                    intermediate_act_fn=util.get_activation(
                        bert_config.hidden_act),
                    hidden_dropout_prob=bert_config.hidden_dropout_prob,
                    attention_probs_dropout_prob=bert_config.
                    attention_probs_dropout_prob,
                    initializer_range=bert_config.initializer_range,
                    do_return_all_layers=True)
                self.sequence_output = self.all_layer_outputs[-1]
                self.pooled_output = self.sequence_output[:, 0]
Example #14
    def __init__(self,
                 albert_config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 segment_ids=None,
                 scope='bert',
                 drop_pooler=False,
                 trainable=True,
                 **kwargs):
        """Constructor for AlbertModel.

    Args:
      albert_config: `AlbertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      segment_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
        albert_config = copy.deepcopy(albert_config)
        if not is_training:
            albert_config.hidden_dropout_prob = 0.0
            albert_config.attention_probs_dropout_prob = 0.0

        input_shape = util.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        if segment_ids is None:
            segment_ids = tf.zeros(shape=[batch_size, seq_length],
                                   dtype=tf.int32)

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        use_tilda_embedding = kwargs.get('use_tilda_embedding')
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                (self.word_embedding_output,
                 self.output_embedding_table) = embedding_lookup(
                     input_ids=input_ids,
                     vocab_size=albert_config.vocab_size,
                     embedding_size=albert_config.embedding_size,
                     initializer_range=albert_config.initializer_range,
                     word_embedding_name="word_embeddings",
                     tilda_embeddings=tilda_embeddings,
                     trainable=trainable)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=self.word_embedding_output,
                    use_token_type=True,
                    segment_ids=segment_ids,
                    token_type_vocab_size=albert_config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=albert_config.initializer_range,
                    max_position_embeddings=albert_config.
                    max_position_embeddings,
                    dropout_prob=albert_config.hidden_dropout_prob,
                    trainable=trainable)

            with tf.variable_scope("encoder"):
                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].
                self.all_encoder_layers = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=input_mask,
                    hidden_size=albert_config.hidden_size,
                    num_hidden_layers=albert_config.num_hidden_layers,
                    num_hidden_groups=albert_config.num_hidden_groups,
                    num_attention_heads=albert_config.num_attention_heads,
                    intermediate_size=albert_config.intermediate_size,
                    inner_group_num=albert_config.inner_group_num,
                    intermediate_act_fn=util.get_activation(
                        albert_config.hidden_act),
                    hidden_dropout_prob=albert_config.hidden_dropout_prob,
                    attention_probs_dropout_prob=albert_config.
                    attention_probs_dropout_prob,
                    initializer_range=albert_config.initializer_range,
                    do_return_all_layers=True,
                    use_einsum=False,
                    trainable=trainable)

            self.sequence_output = self.all_encoder_layers[-1]
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                first_token_tensor = tf.squeeze(
                    self.sequence_output[:, 0:1, :], axis=1)

                # trick: ignore the fully connected layer
                if drop_pooler:
                    self.pooled_output = first_token_tensor
                else:
                    self.pooled_output = tf.layers.dense(
                        first_token_tensor,
                        albert_config.hidden_size,
                        activation=tf.tanh,
                        kernel_initializer=util.create_initializer(
                            albert_config.initializer_range),
                        trainable=trainable)
Example #15
    def __init__(self,
                 vocab_size,
                 is_training,
                 source_ids,
                 target_ids,
                 sos_id,
                 sample_weight=None,
                 hidden_size=768,
                 num_blocks=6,
                 num_attention_heads=12,
                 scope='transformer',
                 use_label_smoothing=False,
                 use_tilda_embedding=False,
                 trainable=True,
                 **kwargs):
        super().__init__()

        dropout_rate = 0.0
        if is_training:
            dropout_rate = 0.1

        source_shape = util.get_shape_list(source_ids, expected_rank=2)
        target_shape = util.get_shape_list(target_ids, expected_rank=2)
        batch_size = source_shape[0]
        source_max_seq_length = source_shape[1]
        target_max_seq_length = target_shape[1]

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            source_mask = tf.math.equal(source_ids, 0)

            # embedding
            with tf.variable_scope('embeddings'):
                (enc, embedding_table) = embedding_lookup(
                    input_ids=source_ids,
                    vocab_size=vocab_size,
                    batch_size=batch_size,
                    max_seq_length=source_max_seq_length,
                    embedding_size=hidden_size,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings)
                enc *= hidden_size ** 0.5  # scale
                enc += positional_encoding(enc, source_max_seq_length)
                enc = util.dropout(enc, dropout_rate)

            with tf.variable_scope('encoder'):

                # stacked multi-attention layers
                for i in range(num_blocks):
                    with tf.variable_scope('block_%s' % i):

                        # self-attention
                        enc = multihead_attention(
                            queries=enc,
                            keys=enc,
                            values=enc,
                            key_masks=source_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=False,
                            scope='self_attention')

                        # feed forward
                        enc = ff(enc, num_units=[hidden_size * 4, hidden_size])
                memory = enc

            def _forward(target_ids, target_mask, target_max_seq_length):

                with tf.variable_scope('decoder'):

                    # shared embedding
                    dec = tf.nn.embedding_lookup(embedding_table, target_ids)
                    dec *= hidden_size ** 0.5  # scale
                    dec += positional_encoding(dec, target_max_seq_length)
                    dec = util.dropout(dec, dropout_rate)

                    # blocks
                    for i in range(num_blocks):
                        with tf.variable_scope('block_%s' % i):

                            # masked self-attention
                            dec = multihead_attention(
                                queries=dec,
                                keys=dec,
                                values=dec,
                                key_masks=target_mask,
                                num_heads=num_attention_heads,
                                dropout_rate=dropout_rate,
                                training=is_training,
                                causality=True,
                                scope='masked_self_attention')

                            # vanilla attention
                            dec = multihead_attention(
                                queries=dec,
                                keys=memory,
                                values=memory,
                                key_masks=source_mask,
                                num_heads=num_attention_heads,
                                dropout_rate=dropout_rate,
                                training=is_training,
                                causality=False,
                                scope='vanilla_attention')

                            # feed forward
                            dec = ff(
                                dec, num_units=[4 * hidden_size, hidden_size])

                # final linear projection (embedding weights are shared)
                with tf.variable_scope('cls'):
                    output_bias = tf.get_variable(
                        'output_bias', shape=[vocab_size],
                        initializer=tf.zeros_initializer())
                    dec = tf.reshape(dec, [-1, hidden_size])
                    logits = tf.matmul(dec, embedding_table, transpose_b=True)
                    logits = tf.reshape(
                        logits, [-1, target_max_seq_length, vocab_size])
                    logits = tf.nn.bias_add(logits, output_bias)

                return logits

            # convert to labels
            label_ids = tf.concat(
                [target_ids[:, 1:],
                 tf.zeros([batch_size, 1], dtype=tf.int32)], axis=-1)

            # forward once
            if is_training:
                target_mask = tf.math.equal(target_ids, 0)  # (N, T2)
                logits = _forward(
                    target_ids, target_mask, target_max_seq_length)

                self.preds['MT'] = tf.argmax(logits, axis=-1)

            # forward loop
            else:
                # bool mask matching the training branch (True marks padding)
                target_mask_base = tf.zeros([batch_size, 1], dtype=tf.bool)
                target_ids = tf.ones([batch_size, 1], dtype=tf.int32) * sos_id

                for cur_length in range(1, target_max_seq_length + 1):
                    target_mask = tf.tile(target_mask_base, [1, cur_length])
                    logits = _forward(target_ids, target_mask, cur_length)

                    pred_ids = tf.argmax(
                        logits[:, cur_length-1:cur_length, :],
                        axis=-1)
                    pred_ids = tf.cast(pred_ids, tf.int32)
                    target_ids = tf.concat([target_ids, pred_ids], axis=-1)

                self.preds['MT'] = target_ids[:, 1:]

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=vocab_size)
            if use_label_smoothing:
                one_hot_labels = label_smoothing(one_hot_labels)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)
            label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                # per_example_loss is [batch_size]; weight each example
                per_example_loss *= tf.cast(sample_weight, tf.float32)

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['MT'] = per_example_loss
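
The loss block averages per-token cross-entropy over non-padding positions (label id 0 marks padding). A minimal numeric sketch of that reduction, assuming TensorFlow 1.x:

import tensorflow as tf

logits = tf.constant([[[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]]])   # [1, T=2, V=3]
label_ids = tf.constant([[1, 0]])        # second position is padding
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot = tf.one_hot(label_ids, depth=3)
per_token = -tf.reduce_sum(one_hot * log_probs, axis=-1)     # [1, T]
mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
per_example = tf.reduce_sum(per_token * mask, -1) / tf.reduce_sum(mask, -1)

with tf.Session() as sess:
    print(sess.run(per_example))   # loss over the one real token only
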