Example #1
0
    def __init__(self,
                 is_training,
                 input_tensor,
                 input_mask,
                 label_ids,
                 label_size=2,
                 sample_weight=None,
                 scope='cls/sequence',
                 name='',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 **kwargs):
        '''Builds a per-position classification head and its masked loss.

        Results are stored under the key `name` in `self.preds`, `self.probs`
        and `self.losses`; the scalar batch loss goes to `self.total_loss`.

        Args:
          is_training: bool. Dropout is applied to `input_tensor` only when
            this is true.
          input_tensor: Tensor whose last two static dimensions are read as
            (seq_length, hidden_size).
          input_mask: Tensor indexed as [batch, position]. Only the columns
            from index 2 onwards are used (see the mask construction below).
          label_ids: integer Tensor of class ids, one-hot encoded with depth
            `label_size`.
          label_size: int, number of classes.
          sample_weight: optional Tensor multiplied into the per-example loss.
          scope: variable scope holding `output_weights` and `output_bias`.
          name: key used when storing predictions, probabilities and losses.
          hidden_dropout_prob: float, dropout rate used while training.
          initializer_range: float, forwarded to `util.create_initializer`.
          trainable: bool, whether the two head variables are trainable.
          **kwargs: forwarded unchanged to the parent constructor.
        '''
        super().__init__(**kwargs)

        batch_size = tf.shape(input_tensor)[0]
        static_dims = input_tensor.shape.as_list()
        seq_length = static_dims[-2]
        hidden_size = static_dims[-1]
        dropout_rate = hidden_dropout_prob if is_training else 0.0

        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[label_size],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            # Project every position onto the label space with one matmul
            # over the flattened [batch * seq_length, hidden_size] matrix.
            dropped = util.dropout(input_tensor, dropout_rate)
            flat_hidden = tf.reshape(dropped, [-1, hidden_size])
            flat_logits = tf.nn.bias_add(
                tf.matmul(flat_hidden, output_weights, transpose_b=True),
                output_bias)
            logits = tf.reshape(flat_logits, [-1, seq_length, label_size])

            self.preds[name] = tf.argmax(logits, axis=-1)
            self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

            # Cross-entropy per position. Note that the label axis is
            # averaged (reduce_mean), not summed.
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(
                label_ids, depth=label_size, dtype=tf.float32)
            per_token_losses = -tf.reduce_mean(
                one_hot_labels * log_probs, axis=-1)

            # The loss mask is a shifted copy of `input_mask`: the first and
            # last columns are forced to zero and the columns in between come
            # from input_mask[:, 2:]. Presumably this excludes the special
            # tokens at both ends of each sequence -- TODO confirm.
            leading_zero = tf.zeros((batch_size, 1), dtype=tf.float32)
            shifted_mask = tf.cast(input_mask[:, 2:], dtype=tf.float32)
            trailing_zero = tf.zeros((batch_size, 1), dtype=tf.float32)
            loss_mask = tf.concat(
                [leading_zero, shifted_mask, trailing_zero], axis=-1)

            # The mean runs over all seq_length positions, masked ones
            # included (they contribute zero).
            per_example_loss = tf.reduce_mean(
                per_token_losses * loss_mask, axis=-1)
            if sample_weight is not None:
                per_example_loss = per_example_loss * tf.cast(
                    sample_weight, dtype=tf.float32)

            self.losses[name] = per_example_loss
            self.total_loss = tf.reduce_mean(per_example_loss)
Example #2
0
def embedding_lookup(input_ids,
                     vocab_size,
                     batch_size,
                     max_seq_length,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     zero_pad=True,
                     dtype=tf.float32,
                     trainable=True,
                     tilda_embeddings=None):
    '''Looks up word embeddings for a batch of token ids.

    Args:
      input_ids: integer Tensor of token ids. It is flattened before the
        lookup, so only its total size (batch_size * max_seq_length) matters.
      vocab_size: int, number of rows of the embedding table that is created
        when `tilda_embeddings` is not given.
      batch_size: first dimension of the returned embedding tensor.
      max_seq_length: second dimension of the returned embedding tensor.
      embedding_size: int, width of each embedding vector.
      initializer_range: float, forwarded to `util.create_initializer`.
      word_embedding_name: name of the embedding variable.
      zero_pad: bool. If True (default), row 0 of the table used for the
        lookup is replaced by zeros, so id 0 always maps to a zero vector.
        If False, the table is used as is.
      dtype: dtype of the embedding variable that is created.
      trainable: bool, whether the created variable is trainable.
      tilda_embeddings: optional existing embedding table to use instead of
        creating a variable.

    Returns:
      A tuple `(output, embedding_table)`: `output` has shape
      [batch_size, max_seq_length, embedding_size]; `embedding_table` is the
      table that was actually used for the lookup (zero-padded if requested).
    '''
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    if tilda_embeddings is not None:
        embedding_table = tilda_embeddings
    else:
        embedding_table = tf.get_variable(
            name=word_embedding_name,
            shape=[vocab_size, embedding_size],
            initializer=util.create_initializer(initializer_range),
            dtype=dtype,
            trainable=trainable)

    if zero_pad:
        # The zero row must share the table's dtype, otherwise tf.concat
        # rejects the mix (e.g. a float16 table with a float32 zero row).
        # base_dtype strips the `_ref` marker that TF1 variables carry.
        zero_row = tf.zeros(shape=[1, embedding_size],
                            dtype=embedding_table.dtype.base_dtype)
        embedding_table = tf.concat(
            (zero_row, embedding_table[1:, :]), axis=0)

    flat_input_ids = tf.reshape(input_ids, [-1])
    output = tf.gather(
        embedding_table, flat_input_ids, name='embedding_look_up')
    output = tf.reshape(
        output, [batch_size, max_seq_length, embedding_size])

    return (output, embedding_table)
Example #3
0
def scatter_update(sequence, updates, positions):
    '''Writes `updates` into `sequence` at the given `positions`.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
        tensor of dtype float32 or int32.
      updates: A tensor that is reshaped to [-1, depth]; `tf.scatter_nd`
        needs one row per entry of `positions`, i.e.
        batch_size * n_positions rows.
      positions: A [batch_size, n_positions] tensor of indices into the
        seq_len axis.

    Returns: A tuple of two tensors. First is a tensor shaped like
      `sequence` in which the elements at `positions` are replaced by the
      values from `updates`. Updates to index 0 are ignored. When a position
      occurs several times, its updates are summed by the scatter and then
      divided by the number of occurrences (for int32 the quotient is cast
      back to int32). Second is a [batch_size, seq_len] int32 mask holding 1
      where the sequence was updated and 0 elsewhere.
    '''
    dims = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        batch, seq_len, depth = dims
    else:
        batch, seq_len = dims
        depth = 1
        # work on a rank-3 tensor throughout; squeezed again before return
        sequence = tf.expand_dims(sequence, -1)
    n_positions = util.get_shape_list(positions)[1]

    # Turn per-row positions into indices of the flattened
    # [batch * seq_len] axis by adding each row's offset.
    row_offset = tf.expand_dims(seq_len * tf.range(batch), -1)
    flat_index = tf.reshape(positions + row_offset, [-1, 1])
    flat_values = tf.reshape(updates, [-1, depth])
    scattered = tf.scatter_nd(flat_index, flat_values,
                              [batch * seq_len, depth])
    scattered = tf.reshape(scattered, [batch, seq_len, depth])

    # Count how many updates landed on every position.
    hit_count = tf.scatter_nd(flat_index,
                              tf.ones([batch * n_positions], tf.int32),
                              [batch * seq_len])
    hit_count = tf.reshape(hit_count, [batch, seq_len])
    # Position 0 of every row is never updated.
    skip_first = tf.concat(
        [tf.zeros((batch, 1), tf.int32),
         tf.ones((batch, seq_len - 1), tf.int32)], -1)
    hit_count = hit_count * skip_first
    hit_count_3d = tf.expand_dims(hit_count, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        hit_count_3d = tf.cast(hit_count_3d, tf.float32)
        scattered = scattered / tf.maximum(1.0, hit_count_3d)
    else:
        assert sequence.dtype == tf.int32
        scattered = tf.cast(
            tf.divide(scattered, tf.maximum(1, hit_count_3d)), tf.int32)

    # Collapse the counts into 0/1 masks.
    updates_mask = tf.minimum(hit_count, 1)
    select = tf.minimum(hit_count_3d, 1)

    kept = (1 - select) * sequence
    replaced = select * scattered
    updated_sequence = kept + replaced
    if not has_depth:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
Example #4
0
def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
    '''Builds a [qlen, mlen + qlen] causal attention mask.

    The first `mlen` columns (memory) are zero. The remaining
    [qlen, qlen] block holds ones strictly above the diagonal and zeros
    elsewhere. With `same_length`, ones strictly below the diagonal of a
    [qlen, qlen] matrix are additionally added onto the first `qlen`
    columns of the result.
    '''
    all_ones = tf.ones([qlen, qlen], dtype=dtype)
    upper_with_diag = tf.matrix_band_part(all_ones, 0, -1)
    diag_only = tf.matrix_band_part(all_ones, 0, 0)
    memory_block = tf.zeros([qlen, mlen], dtype=dtype)
    mask = tf.concat([memory_block, upper_with_diag - diag_only], 1)
    if not same_length:
        return mask

    lower_with_diag = tf.matrix_band_part(all_ones, -1, 0)
    head_cols = mask[:, :qlen] + lower_with_diag - diag_only
    tail_cols = mask[:, qlen:]
    return tf.concat([head_cols, tail_cols], 1)
Example #5
0
def get_token_embeddings(vocab_size, num_units, zero_pad=True):
    '''Creates the token embedding matrix.

    The variable `weight_mat` lives in the `shared_weight_matrix` variable
    scope and is created with `xavier_initializer()`.

    vocab_size: scalar. V.
    num_units: embedding dimensionality. E.
    zero_pad: Boolean. If True, the first row (id = 0) of the returned
      matrix is replaced by constant zeros.

    Returns
    weight matrix: (V, E)
    '''
    with tf.variable_scope('shared_weight_matrix'):
        weight_mat = tf.get_variable('weight_mat',
                                     dtype=tf.float32,
                                     shape=(vocab_size, num_units),
                                     initializer=xavier_initializer())
        if not zero_pad:
            return weight_mat

        # Swap row 0 for zeros; rows 1..V-1 are still backed by the variable.
        zero_row = tf.zeros(shape=[1, num_units])
        return tf.concat((zero_row, weight_mat[1:, :]), 0)
Example #6
0
    def __init__(self,
                 hparams,
                 is_training,
                 input_ids,
                 sample_weight=None,
                 scope='model',
                 given=1,
                 use_tilda_embedding=False,
                 **kwargs):
        '''Builds the language-model graph, its predictions and its loss.

        In training mode the model is run once over `input_ids`. Otherwise
        the first `given` tokens are kept as a prompt and the model is run
        repeatedly, appending the arg-max token each time, until the
        sequence is `hparams.n_predict + 1` tokens long.

        Results are written to `self.preds['LM']`, `self.losses['LM']` and
        `self.total_loss`.

        Args:
          hparams: object providing `n_predict`, `n_vocab`, `n_embed`,
            `n_ctx` and `n_layer` (it is also passed on to `block`).
          is_training: bool, selects the single forward pass (True) or the
            greedy generation loop (False).
          input_ids: int32 Tensor of shape [batch_size, seq_length].
          sample_weight: optional per-example weights. Any shape holding one
            value per example (or a single value) is accepted; it is
            flattened before being multiplied into the per-example loss.
          scope: variable scope that holds the model variables.
          given: int, number of leading tokens used as prompt when not
            training.
          use_tilda_embedding: bool. If True, the existing root-scope
            variable `tilda_embeddings` is used as the word embedding table.
          **kwargs: accepted for interface compatibility; not used here.
        '''
        super().__init__()

        batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
        max_seq_length = hparams.n_predict

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):

            def _forward(input_ids, past=None):
                '''Runs the transformer stack; returns (logits, present).'''
                batch, sequence = shape_list(input_ids)

                if tilda_embeddings is None:
                    wte = tf.get_variable(
                        'word_embeddings', [hparams.n_vocab, hparams.n_embed],
                        initializer=tf.random_normal_initializer(stddev=0.02))
                else:
                    wte = tilda_embeddings
                wpe = tf.get_variable(
                    'wpe', [hparams.n_ctx, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.01))
                past_length = 0 if past is None else tf.shape(past)[-2]
                h = (tf.gather(wte, input_ids) +
                     tf.gather(wpe, positions_for(input_ids, past_length)))

                # stacked transformer layers
                presents = []
                pasts = tf.unstack(past, axis=1) if past is not None else \
                    [None] * hparams.n_layer
                assert len(pasts) == hparams.n_layer
                for layer, past in enumerate(pasts):
                    h, present = block(h,
                                       'h%d' % layer,
                                       past=past,
                                       hparams=hparams)
                    presents.append(present)
                present = tf.stack(presents, axis=1)
                h = norm(h, 'ln_f')

                # Language model loss.  Do tokens <n predict token n?
                # The output projection reuses the word embedding table.
                h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
                logits = tf.matmul(h_flat, wte, transpose_b=True)
                logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])

                return logits, present

            # convert to labels: the target at position i is token i + 1;
            # the last position gets id 0, which the loss mask ignores.
            label_ids = tf.concat(
                [input_ids[:, 1:],
                 tf.zeros([batch_size, 1], dtype=tf.int32)],
                axis=-1)

            # forward once
            if is_training:
                (logits, _) = _forward(input_ids)

                self.preds['LM'] = tf.argmax(logits, axis=-1)

            # forward loop
            else:
                input_ids = input_ids[:, 0:given]

                # Each pass re-runs the model on the whole prefix and
                # appends the arg-max token of the last position.
                for cur_length in range(given, max_seq_length + 1):
                    (logits, _) = _forward(input_ids)

                    pred_ids = tf.argmax(logits[:,
                                                cur_length - 1:cur_length, :],
                                         axis=-1)
                    pred_ids = tf.cast(pred_ids, tf.int32)
                    input_ids = tf.concat([input_ids, pred_ids], axis=-1)

                self.preds['LM'] = input_ids

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
            per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                            axis=-1)
            # Positions whose label is id 0 do not contribute to the loss.
            label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
            # NOTE(review): a row whose labels are all 0 divides 0 by 0 and
            # yields NaN -- confirm that such rows cannot occur.
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                # per_example_loss has shape [batch]. The weights must be
                # rank 1 as well: a [batch, 1] factor would broadcast the
                # product to [batch, batch] and detach every weight from
                # its own example.
                per_example_loss *= tf.reshape(
                    tf.cast(sample_weight, dtype=tf.float32), [-1])

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['LM'] = per_example_loss
Example #7
0
def _expand_features(module, split_placeholders):
    '''Derives permutation-LM features and stores them in `split_placeholders`.

    The sequence axis is split at `module.reuse_seq_length`; `_local_perm`
    is called on each part and the results are concatenated again.

    Args:
      module: object providing `max_seq_length`, `reuse_seq_length`,
        `perm_size` and `_num_predict`.
      split_placeholders: dict with the entries 'input', 'target' and
        'is_masked'; they are sliced as [batch, position].

    Returns:
      The same dict, with 'target', 'target_mask', 'perm_mask', 'input_k'
      and 'input_q' written, plus 'target_mapping' when
      `module._num_predict` is not None.

    Raises:
      AssertionError: if `module.perm_size` is larger than either part of
        the split sequence.
    '''

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    # length of the part of the sequence after the reused prefix
    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length
            and module.perm_size <= non_reuse_len)

    # features for the first `reuse_seq_length` positions
    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    # features for the remaining `non_reuse_len` positions
    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    # Extend both permutation masks to the full key length: rows of the
    # first part get ones for the columns of the second part, rows of the
    # second part get zeros for the columns of the first part.
    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ],
                            axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ],
                            axis=2)
    # stitch the two parts back together along the sequence axis
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        #TODO(geying): convert tensors from 1-D to 2-D

        # position indices 0..max_seq_length-1, repeated for every example
        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        # keep only the positions selected by the target mask
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        # NOTE(review): `indices` and `bool_target_mask` look like rank-2
        # tensors here, in which case tf.boolean_mask returns a rank-1
        # tensor and `tf.shape(indices)[1]` is out of range. This branch
        # appears unfinished (see the TODO above) -- confirm before use.
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        # one-hot row per selected position, padded with all-zero rows
        target_mapping = tf.one_hot(indices,
                                    module.max_seq_length,
                                    dtype=tf.float32)
        paddings = tf.zeros([pad_len, module.max_seq_length],
                            dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, module._num_predict, module.max_seq_length])

        ##### target
        # selected target ids, padded with zeros
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module._num_predict])

        ##### target mask
        # ones for the selected targets, zeros for the padding
        target_mask = tf.concat([
            tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
            tf.zeros([batch_size, pad_len], dtype=tf.float32)
        ],
                                axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        # no partial prediction: keep targets and mask at full length
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(input_k,
                                               [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(input_q,
                                               [-1, module.max_seq_length])

    return split_placeholders
Example #8
0
    def __init__(self,
                 albert_config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 segment_ids=None,
                 scope='bert',
                 drop_pooler=False,
                 trainable=True,
                 **kwargs):
        """Constructor for AlbertModel.

        Builds the embedding layer, the transformer encoder and the pooler,
        and exposes `word_embedding_output`, `output_embedding_table`,
        `embedding_output`, `all_encoder_layers`, `sequence_output` and
        `pooled_output` as attributes.

        Args:
          albert_config: `AlbertConfig` instance. It is deep-copied, so the
            caller's object is never modified.
          is_training: bool. true for training model, false for eval model.
            When false, both dropout probabilities of the copied config are
            set to 0.0.
          input_ids: int32 Tensor of shape [batch_size, seq_length].
          input_mask: (optional) int32 Tensor of shape
            [batch_size, seq_length]. Defaults to all ones.
          segment_ids: (optional) int32 Tensor of shape
            [batch_size, seq_length]. Defaults to all zeros.
          scope: (optional) variable scope. Defaults to "bert".
          drop_pooler: (optional) bool. If true, the dense pooler layer is
            skipped and `pooled_output` is the hidden state of the first
            token.
          trainable: (optional) bool, forwarded to the embedding, encoder
            and pooler layers.
          **kwargs: only `use_tilda_embedding` is read. If truthy, the
            existing root-scope variable `tilda_embeddings` is handed to
            `embedding_lookup`.
        """
        config = copy.deepcopy(albert_config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        id_dims = util.get_shape_list(input_ids, expected_rank=2)
        batch_size, seq_length = id_dims[0], id_dims[1]

        # Fall back to "attend to everything" and "single segment".
        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)
        if segment_ids is None:
            segment_ids = tf.zeros(shape=[batch_size, seq_length],
                                   dtype=tf.int32)

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if kwargs.get('use_tilda_embedding'):
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            with tf.variable_scope("embeddings"):
                # Word-id lookup.
                word_output, embedding_table = embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    embedding_size=config.embedding_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name="word_embeddings",
                    tilda_embeddings=tilda_embeddings,
                    trainable=trainable)
                self.word_embedding_output = word_output
                self.output_embedding_table = embedding_table

                # Position and token-type embeddings are added by the
                # post-processor, which also receives the dropout rate.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=self.word_embedding_output,
                    use_token_type=True,
                    segment_ids=segment_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob,
                    trainable=trainable)

            with tf.variable_scope("encoder"):
                # Stacked transformer; every layer's output is returned.
                hidden_act_fn = util.get_activation(config.hidden_act)
                self.all_encoder_layers = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=input_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_hidden_groups=config.num_hidden_groups,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    inner_group_num=config.inner_group_num,
                    intermediate_act_fn=hidden_act_fn,
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=(
                        config.attention_probs_dropout_prob),
                    initializer_range=config.initializer_range,
                    do_return_all_layers=True,
                    use_einsum=False,
                    trainable=trainable)

            # Output of the last encoder layer.
            self.sequence_output = self.all_encoder_layers[-1]

            # The pooler reduces the sequence output to one vector per
            # example by taking the hidden state of the first token,
            # optionally followed by a tanh dense layer.
            with tf.variable_scope("pooler"):
                pooled = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)

                # trick: ignore the fully connected layer
                if not drop_pooler:
                    pooled = tf.layers.dense(
                        pooled,
                        config.hidden_size,
                        activation=tf.tanh,
                        kernel_initializer=util.create_initializer(
                            config.initializer_range),
                        trainable=trainable)
                self.pooled_output = pooled
Example #9
0
def transformer_xl(inp_k,
                   n_token,
                   n_layer,
                   d_model,
                   n_head,
                   d_head,
                   d_inner,
                   dropout,
                   dropatt,
                   attn_type,
                   bi_data,
                   initializer,
                   is_training,
                   mem_len=None,
                   inp_q=None,
                   mems=None,
                   same_length=False,
                   clamp_len=-1,
                   untie_r=False,
                   use_tpu=True,
                   input_mask=None,
                   perm_mask=None,
                   seg_id=None,
                   reuse_len=None,
                   ff_activation='relu',
                   target_mapping=None,
                   use_bfloat16=False,
                   scope='transformer',
                   tilda_embeddings=None,
                   **kwargs):
    '''
    Defines a Transformer-XL computation graph with additional
    support for XLNet.

      Args:

      inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
          0 for real tokens and 1 for padding.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
          from previous batches. The length of the list equals n_layer.
          If None, no memory is used.
      perm_mask: float32 Tensor in shape [len, len, bsz].
          If perm_mask[i, j, k] = 0, i attend to j in batch k;
          if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
          If None, each position attends to all the others.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
          If target_mapping[i, j, k] = 1, the i-th predict in batch k is
          on the j-th token.
          Only used during pretraining for partial prediction.
          Set to None during finetuning.
      inp_q: float32 Tensor in shape [len, bsz].
          1 for tokens with losses and 0 for tokens without losses.
          Only used during pretraining for two-stream attention.
          Set to None during finetuning.

      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, 'relu' or 'gelu'.
      untie_r: bool, whether to untie the biases in attention.
      n_token: int, the vocab size.

      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      init: str, the initialization scheme, either 'normal' or 'uniform'.
      init_range: float, initialize the parameters with a uniform distribution
          in [-init_range, init_range]. Only effective when init='uniform'.
      init_std: float, initialize the parameters with a normal distribution
          with mean 0 and stddev init_std. Only effective when init='normal'.
      mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the currect batch to be cached
          and reused in the future.
      bi_data: bool, whether to use bidirectional input pipeline.
          Usually set to True during pretraining and False during finetuning.
      clamp_len: int, clamp all relative distances larger than clamp_len.
          -1 means no clamping.
      same_length: bool, whether to use the same attention length for each token.
      summary_type: str, 'last', 'first', 'mean', or 'attn'. The method
          to pool the input to get a vector representation.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
    '''
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        bsz = tf.shape(inp_k)[1]
        qlen = tf.shape(inp_k)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: %s' % attn_type)

        # data mask: input mask & perm mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                                 dtype=tf_float)
            data_mask = tf.cast(data_mask, dtype=tf.float32)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        ##### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=inp_k,
            n_token=n_token,
            d_embed=d_model,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=tf_float,
            scope='word_embedding',
            tilda_embeddings=tilda_embeddings)

        if inp_q is not None:
            with tf.variable_scope('mask_emb'):
                mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
                                           dtype=tf_float)
                if target_mapping is not None:
                    word_emb_q = tf.tile(mask_emb,
                                         [tf.shape(target_mapping)[0], bsz, 1])
                else:
                    inp_q_ext = inp_q[:, :, None]
                    word_emb_q = \
                        inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
        output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
        if inp_q is not None:
            output_g = tf.layers.dropout(word_emb_q,
                                         dropout,
                                         training=is_training)

        ##### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = relative_positional_encoding(qlen,
                                               klen,
                                               d_model,
                                               clamp_len,
                                               attn_type,
                                               bi_data,
                                               bsz=bsz,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        ##### Attention layers
        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                if inp_q is not None:
                    output_h, output_g = two_stream_rel_attn(
                        h=output_h,
                        g=output_g,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask_h=non_tgt_mask,
                        attn_mask_g=attn_mask,
                        mems=mems[i],
                        target_mapping=target_mapping,
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer)
                    reuse = True
                else:
                    reuse = False

                    output_h = rel_multihead_attn(
                        h=output_h,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask=non_tgt_mask,
                        mems=mems[i],
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer,
                        reuse=reuse)

                if inp_q is not None:
                    output_g = positionwise_ffn(inp=output_g,
                                                d_model=d_model,
                                                d_inner=d_inner,
                                                dropout=dropout,
                                                kernel_initializer=initializer,
                                                activation_type=ff_activation,
                                                is_training=is_training)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        if inp_q is not None:
            output = tf.layers.dropout(output_g, dropout, training=is_training)
        else:
            output = tf.layers.dropout(output_h, dropout, training=is_training)

        return output, new_mems, lookup_table
Example #10
0
    def __init__(self,
                 vocab_size,
                 is_training,
                 source_ids,
                 target_ids,
                 sos_id,
                 sample_weight=None,
                 hidden_size=768,
                 num_blocks=6,
                 num_attention_heads=12,
                 scope='transformer',
                 use_label_smoothing=False,
                 use_tilda_embedding=False,
                 trainable=True,
                 **kwargs):
        """Build the encoder-decoder graph, predictions and loss.

        The source sequence is embedded and encoded with `num_blocks`
        self-attention blocks. The decoder shares the embedding table with
        the encoder and is run either once over `target_ids` (training) or
        step by step, feeding back the arg-max token (inference).

        Args:
            vocab_size: Vocabulary size; used for the embedding lookup, the
                output bias shape and the one-hot label depth.
            is_training: If true, dropout rate is 0.1 and the decoder is run
                once on `target_ids`. Otherwise dropout is 0.0 and the
                decoder is unrolled token by token starting from `sos_id`.
            source_ids: Rank-2 tensor of source token ids. Id 0 is treated
                as padding when building the key mask.
            target_ids: Rank-2 tensor of target token ids. Id 0 is treated
                as padding. Labels are `target_ids` shifted left by one
                position and padded with 0.
            sos_id: Token id used to seed decoding at inference time.
            sample_weight: Optional per-example loss weights, one value per
                example (flattened to rank 1 before use).
            hidden_size: Embedding / hidden dimension.
            num_blocks: Number of encoder blocks and of decoder blocks.
            num_attention_heads: Number of heads passed to every attention
                call.
            scope: Variable scope wrapping the whole model.
            use_label_smoothing: If true, `label_smoothing` is applied to
                the one-hot labels before computing the loss.
            use_tilda_embedding: If true, the already existing root-scope
                variable 'tilda_embeddings' is fetched and forwarded to
                `embedding_lookup`.
            trainable: Accepted but not referenced in this constructor.
            **kwargs: Accepted but not referenced in this constructor.

        Sets:
            self.preds['MT']: Predicted token ids.
            self.losses['MT']: Per-example loss (rank 1).
            self.total_loss: Mean of the per-example loss.
        """
        super().__init__()

        dropout_rate = 0.0
        if is_training:
            dropout_rate = 0.1

        source_shape = util.get_shape_list(source_ids, expected_rank=2)
        target_shape = util.get_shape_list(target_ids, expected_rank=2)
        batch_size = source_shape[0]
        source_max_seq_length = source_shape[1]
        target_max_seq_length = target_shape[1]

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            # True where the source token is padding (id 0)
            source_mask = tf.math.equal(source_ids, 0)

            # embedding
            with tf.variable_scope('embeddings'):
                (enc, embedding_table) = embedding_lookup(
                    input_ids=source_ids,
                    vocab_size=vocab_size,
                    batch_size=batch_size,
                    max_seq_length=source_max_seq_length,
                    embedding_size=hidden_size,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings)
                enc *= hidden_size ** 0.5  # scale
                enc += positional_encoding(enc, source_max_seq_length)
                enc = util.dropout(enc, dropout_rate)

            with tf.variable_scope('encoder'):

                # stacked multi-attention layers
                for i in range(num_blocks):
                    with tf.variable_scope('block_%s' % i):

                        # self-attention
                        enc = multihead_attention(
                            queries=enc,
                            keys=enc,
                            values=enc,
                            key_masks=source_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=False,
                            scope='self_attention')

                        # feed forward
                        enc = ff(enc, num_units=[hidden_size * 4, hidden_size])
                memory = enc

            def _forward(target_ids, target_mask, target_max_seq_length):
                """Run the decoder stack and project onto the vocabulary.

                Returns logits reshaped to
                [-1, target_max_seq_length, vocab_size].
                """

                with tf.variable_scope('decoder'):

                    # shared embedding
                    dec = tf.nn.embedding_lookup(embedding_table, target_ids)
                    dec *= hidden_size ** 0.5  # scale
                    dec += positional_encoding(dec, target_max_seq_length)
                    dec = util.dropout(dec, dropout_rate)

                    # blocks
                    for i in range(num_blocks):
                        with tf.variable_scope('block_%s' % i):

                            # masked self-attention
                            dec = multihead_attention(
                                queries=dec,
                                keys=dec,
                                values=dec,
                                key_masks=target_mask,
                                num_heads=num_attention_heads,
                                dropout_rate=dropout_rate,
                                training=is_training,
                                causality=True,
                                scope='masked_self_attention')

                            # vanilla attention
                            dec = multihead_attention(
                                queries=dec,
                                keys=memory,
                                values=memory,
                                key_masks=source_mask,
                                num_heads=num_attention_heads,
                                dropout_rate=dropout_rate,
                                training=is_training,
                                causality=False,
                                scope='vanilla_attention')

                            # feed forward
                            dec = ff(
                                dec, num_units=[4 * hidden_size, hidden_size])

                # final linear projection (embedding weights are shared)
                # NOTE(review): `_forward` is called once per decoding step
                # at inference; this presumably relies on variable reuse
                # being enabled by the caller for 'output_bias' -- confirm.
                with tf.variable_scope('cls'):
                    output_bias = tf.get_variable(
                        'output_bias', shape=[vocab_size],
                        initializer=tf.zeros_initializer())
                    dec = tf.reshape(dec, [-1, hidden_size])
                    logits = tf.matmul(dec, embedding_table, transpose_b=True)
                    logits = tf.reshape(
                        logits, [-1, target_max_seq_length, vocab_size])
                    logits = tf.nn.bias_add(logits, output_bias)

                return logits

            # convert to labels: shift targets left by one position and pad
            # the last position with 0. Computed before `target_ids` is
            # rebound in the inference branch below.
            label_ids = tf.concat(
                [target_ids[:, 1:],
                 tf.zeros([batch_size, 1], dtype=tf.int32)], axis=-1)

            # forward once
            if is_training:
                target_mask = tf.math.equal(target_ids, 0)  # (N, T2)
                logits = _forward(
                    target_ids, target_mask, target_max_seq_length)

                self.preds['MT'] = tf.argmax(logits, axis=-1)

            # forward loop
            else:
                target_mask_base = tf.zeros([batch_size, 1], dtype=tf.int32)
                target_ids = tf.ones([batch_size, 1], dtype=tf.int32) * sos_id

                # Greedy decoding: the Python `range` requires
                # `target_max_seq_length` to be a static integer here.
                for cur_length in range(1, target_max_seq_length + 1):
                    target_mask = tf.tile(target_mask_base, [1, cur_length])
                    logits = _forward(target_ids, target_mask, cur_length)

                    # keep only the prediction for the newest position
                    pred_ids = tf.argmax(
                        logits[:, cur_length-1:cur_length, :],
                        axis=-1)
                    pred_ids = tf.cast(pred_ids, tf.int32)
                    target_ids = tf.concat([target_ids, pred_ids], axis=-1)

                # drop the leading `sos_id` column
                self.preds['MT'] = target_ids[:, 1:]

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=vocab_size)
            if use_label_smoothing:
                one_hot_labels = label_smoothing(one_hot_labels)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)
            label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)

            # A row without any non-padding label has a masked sum of 0;
            # clamping the count to 1 yields a loss of 0 instead of NaN
            # (0 / 0). Rows with at least one label are unaffected.
            num_labels = tf.maximum(
                tf.reduce_sum(label_mask, axis=-1), 1.0)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                num_labels
            if sample_weight is not None:
                # `per_example_loss` is rank 1, so the weights must be rank 1
                # as well; multiplying by a [batch, 1] tensor would broadcast
                # to [batch, batch] and mix weights across examples.
                per_example_loss *= tf.reshape(
                    tf.cast(sample_weight, tf.float32), [-1])

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['MT'] = per_example_loss