Example #1
    def call(self, inputs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        word_embeddings = unpacked_inputs[0]
        token_type_ids = unpacked_inputs[1]
        input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = word_embeddings
        if self.use_type_embeddings:
            flat_token_type_ids = tf.reshape(token_type_ids, [-1])
            token_type_embeddings = tf.gather(self.type_embeddings,
                                              flat_token_type_ids)
            token_type_embeddings = tf.reshape(token_type_embeddings,
                                               [batch_size, seq_length, width])
            output += token_type_embeddings

        if self.use_position_embeddings:
            position_embeddings = tf.expand_dims(tf.slice(
                self.position_embeddings, [0, 0], [seq_length, width]),
                                                 axis=0)

            output += position_embeddings

        output = self.output_layer_norm(output)
        output = self.output_dropout(output)

        return output
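For reference, a minimal standalone sketch of the same idea, with made-up shapes and freshly created tables standing in for the layer's own type and position embedding weights:

import tensorflow as tf

batch_size, seq_length, width, type_vocab, max_len = 2, 4, 8, 2, 16
word_embeddings = tf.random.normal([batch_size, seq_length, width])
token_type_ids = tf.constant([[0, 0, 1, 1], [0, 1, 1, 1]])

type_table = tf.random.normal([type_vocab, width])    # stand-in for self.type_embeddings
position_table = tf.random.normal([max_len, width])   # stand-in for self.position_embeddings

# Token-type embeddings: flatten the ids, gather rows, reshape back to [B, S, W].
flat_ids = tf.reshape(token_type_ids, [-1])
type_emb = tf.reshape(tf.gather(type_table, flat_ids),
                      [batch_size, seq_length, width])

# Position embeddings: slice the first seq_length rows and broadcast over the batch.
pos_emb = tf.expand_dims(position_table[:seq_length, :], axis=0)

output = word_embeddings + type_emb + pos_emb  # shape (2, 4, 8)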
Example #2
    def call(self, inputs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        lm_output = unpacked_inputs[0]
        sentence_output = unpacked_inputs[1]
        lm_label_ids = unpacked_inputs[2]
        lm_label_ids = tf.keras.backend.reshape(lm_label_ids, [-1])
        lm_label_ids_one_hot = tf.keras.backend.one_hot(
            lm_label_ids, self.config.vocab_size)
        lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3],
                                                 tf.float32)
        lm_label_weights = tf.keras.backend.reshape(lm_label_weights, [-1])
        lm_per_example_loss = -tf.keras.backend.sum(
            lm_output * lm_label_ids_one_hot, axis=[-1])
        numerator = tf.keras.backend.sum(lm_label_weights *
                                         lm_per_example_loss)
        denominator = tf.keras.backend.sum(lm_label_weights) + 1e-5
        mask_label_loss = numerator / denominator

        sentence_labels = unpacked_inputs[4]
        sentence_labels = tf.keras.backend.reshape(sentence_labels, [-1])
        sentence_label_one_hot = tf.keras.backend.one_hot(sentence_labels, 2)
        per_example_loss_sentence = -tf.keras.backend.sum(
            sentence_label_one_hot * sentence_output, axis=-1)
        sentence_loss = tf.keras.backend.mean(per_example_loss_sentence)
        loss = mask_label_loss + sentence_loss
        # TODO(hongkuny): Avoids the hack and switches add_loss.
        final_loss = tf.fill(tf.keras.backend.shape(per_example_loss_sentence),
                             loss)

        self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                          lm_per_example_loss, sentence_output,
                          sentence_labels, per_example_loss_sentence)
        return final_loss
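A small numeric sketch of the weighted masked-LM part of this loss, assuming lm_output already holds log-probabilities over the vocabulary; shapes and values are illustrative:

import tensorflow as tf

vocab_size = 5
lm_output = tf.nn.log_softmax(tf.random.normal([3, vocab_size]), axis=-1)  # [num_predictions, vocab]
lm_label_ids = tf.constant([1, 4, 0])
lm_label_weights = tf.constant([1.0, 1.0, 0.0])  # padded prediction slots get weight 0

one_hot = tf.one_hot(lm_label_ids, vocab_size)
per_example_loss = -tf.reduce_sum(lm_output * one_hot, axis=-1)

# Weighted mean; the small epsilon guards against an all-zero weight vector.
numerator = tf.reduce_sum(lm_label_weights * per_example_loss)
denominator = tf.reduce_sum(lm_label_weights) + 1e-5
mask_label_loss = numerator / denominator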
Example #3
  def call(self, inputs, return_all_layers=False):
    """Implements call() for the layer.

    Args:
      inputs: packed inputs.
      return_all_layers: bool, whether to return outputs of all layers inside
        encoders.

    Returns:
      Output tensor of the last layer or a list of output tensors.
    """
    unpacked_inputs = tf_utils.unpack_inputs(inputs)
    input_tensor = unpacked_inputs[0]
    attention_mask = unpacked_inputs[1]
    output_tensor = input_tensor

    all_layer_outputs = []
    for layer in self.layers:
      output_tensor = layer(output_tensor, attention_mask)
      all_layer_outputs.append(output_tensor)

    if return_all_layers:
      return all_layer_outputs

    return all_layer_outputs[-1]
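A toy sketch of the same stacking pattern, with plain Dense layers standing in for the transformer blocks and no attention mask:

import tensorflow as tf

blocks = [tf.keras.layers.Dense(8, activation="relu") for _ in range(3)]
output = tf.random.normal([2, 8])

all_layer_outputs = []
for block in blocks:
    output = block(output)                 # each block consumes the previous block's output
    all_layer_outputs.append(output)

last_layer_output = all_layer_outputs[-1]  # what the layer returns by default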
Example #4
 def call(self, inputs):
   """Implements call() for the layer."""
   (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)
   attention_output = self.attention_layer(
       from_tensor=input_tensor,
       to_tensor=input_tensor,
       attention_mask=attention_mask)
   attention_output = self.attention_output_dense(attention_output)
   attention_output = self.attention_dropout(attention_output)
   # Use float32 in keras layer norm and the gelu activation in the
   # intermediate dense layer for numeric stability
   attention_output = self.attention_layer_norm(input_tensor +
                                                attention_output)
   if self.float_type == tf.float16:
     attention_output = tf.cast(attention_output, tf.float16)
   intermediate_output = self.intermediate_dense(attention_output)
   if self.float_type == tf.float16:
     intermediate_output = tf.cast(intermediate_output, tf.float16)
   layer_output = self.output_dense(intermediate_output)
   layer_output = self.output_dropout(layer_output)
   # Use float32 in keras layer norm for numeric stability
   layer_output = self.output_layer_norm(layer_output + attention_output)
   if self.float_type == tf.float16:
     layer_output = tf.cast(layer_output, tf.float16)
   return layer_output
Example #5
 def call(self, inputs):
     """Implements call() for the layer."""
     (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)
     attention_output = self.attention_layer(from_tensor=input_tensor,
                                             to_tensor=input_tensor,
                                             attention_mask=attention_mask)
     attention_output = self.attention_output_dense(attention_output)
     attention_output = self.attention_dropout(attention_output)
     # Use float32 in keras layer norm and the gelu activation in the
     # intermediate dense layer for numeric stability
     # TODO(reedwm): These casts are probably unnecessary, as we passed
     # dtype=tf.float32 to the layer norm constructor, so it will cast its inputs
     # to float32 automatically. These manual casts additionally do the "+"
     # operator in float32, but "+" is numerically stable in float16.
     if self.float_type == tf.float16:
         input_tensor = tf.cast(input_tensor, tf.float32)
         attention_output = tf.cast(attention_output, tf.float32)
     attention_output = self.attention_layer_norm(input_tensor +
                                                  attention_output)
     intermediate_output = self.intermediate_dense(attention_output)
     if self.float_type == tf.float16:
         intermediate_output = tf.cast(intermediate_output, tf.float16)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
     # Use float32 in keras layer norm for numeric stability
     if self.float_type == tf.float16:
         layer_output = tf.cast(layer_output, tf.float32)
     layer_output = self.output_layer_norm(layer_output + attention_output)
     if self.float_type == tf.float16:
         layer_output = tf.cast(layer_output, tf.float16)
     return layer_output
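A minimal sketch of the cast pattern discussed in the comments above: do the residual add and layer norm in float32, then cast back to the compute dtype. The LayerNormalization below is a stand-in for the layer's own norm:

import tensorflow as tf

compute_dtype = tf.float16
layer_norm = tf.keras.layers.LayerNormalization(dtype=tf.float32)

x = tf.random.normal([2, 4, 8], dtype=compute_dtype)
residual = tf.random.normal([2, 4, 8], dtype=compute_dtype)

# Cast to float32 before the residual add + layer norm ...
normed = layer_norm(tf.cast(x, tf.float32) + tf.cast(residual, tf.float32))

# ... then cast back to float16 for the next layer.
normed = tf.cast(normed, compute_dtype)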
Example #6
    def call(self, inputs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        lm_output = unpacked_inputs[0]
        sentence_output = unpacked_inputs[1]
        lm_label_ids = unpacked_inputs[2]
        lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3],
                                                 tf.float32)
        sentence_labels = unpacked_inputs[4]

        mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=lm_label_ids,
            predictions=lm_output,
            weights=lm_label_weights)
        sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
            labels=sentence_labels, predictions=sentence_output)
        loss = mask_label_loss + sentence_loss
        batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0],
                               [1])
        # TODO(hongkuny): Avoids the hack and switches add_loss.
        final_loss = tf.fill(batch_shape, loss)

        self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                          mask_label_loss, sentence_output, sentence_labels,
                          sentence_loss)
        return final_loss
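A small sketch of the tf.fill workaround mentioned in the TODO: the scalar loss is tiled to the batch shape so the layer returns one value per example. Shapes and values are illustrative:

import tensorflow as tf

sentence_labels = tf.constant([0, 1, 1])
loss = tf.constant(2.5)

batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])  # -> [3]
final_loss = tf.fill(batch_shape, loss)                      # -> [2.5, 2.5, 2.5]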
Example #7
  def call(self, inputs, mode="bert"):
    """Implements call() for the layer.

    Args:
      inputs: packed input tensors.
      mode: string, `bert` or `encoder`.

    Returns:
      A tuple (pooled_output, sequence_output) for BERT training (mode=`bert`),
      where sequence_output is a float Tensor of shape
      [batch_size, seq_length, hidden_size], or a list of encoder layer output
      tensors for encoder usage (mode=`encoder`).
    """
    unpacked_inputs = tf_utils.unpack_inputs(inputs)
    input_word_ids = unpacked_inputs[0]
    input_mask = unpacked_inputs[1]
    input_type_ids = unpacked_inputs[2]

    word_embeddings = self.embedding_lookup(input_word_ids)
    embedding_tensor = self.embedding_postprocessor(
        word_embeddings=word_embeddings, token_type_ids=input_type_ids)
    if self.float_type == tf.float16:
      embedding_tensor = tf.cast(embedding_tensor, tf.float16)
    attention_mask = None
    if input_mask is not None:
      attention_mask = create_attention_mask_from_input_mask(
          input_word_ids, input_mask)

    if mode == "encoder":
      return self.encoder(
          embedding_tensor, attention_mask, return_all_layers=True)

    sequence_output = self.encoder(embedding_tensor, attention_mask)
    first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
    pooled_output = self.pooler_transform(first_token_tensor)

    return (pooled_output, sequence_output)
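A short sketch of the pooling step on the `bert` path: take the first ([CLS]) token of the sequence output and project it. The Dense layer here stands in for self.pooler_transform:

import tensorflow as tf

hidden_size = 8
sequence_output = tf.random.normal([2, 5, hidden_size])         # [batch, seq, hidden]
pooler = tf.keras.layers.Dense(hidden_size, activation="tanh")  # stand-in for self.pooler_transform

first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)  # [batch, hidden]
pooled_output = pooler(first_token_tensor)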
Example #8
  def call(self, inputs):
    """Implements call() for the layer."""
    (hidden, labels) = tf_utils.unpack_inputs(inputs)

    logits = self.proj_layer(hidden)
    one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype)  # pytype: disable=attribute-error
    loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)

    return loss, logits
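A toy version of this classification loss, with a freshly created Dense layer standing in for self.proj_layer and illustrative shapes:

import tensorflow as tf

n_class = 3
hidden = tf.random.normal([4, 8])
labels = tf.constant([0, 2, 1, 2])
proj_layer = tf.keras.layers.Dense(n_class)  # stand-in for self.proj_layer

logits = proj_layer(hidden)
one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype)
# Negative log-probability of the true class, one value per example.
loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)  # shape [4]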
Example #9
    def call(self, inputs):
        """Implements call() for the layer."""
        (from_tensor, to_tensor,
         attention_mask) = tf_utils.unpack_inputs(inputs)

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_attention_heads`
        #   H = `size_per_head`
        # `query_tensor` = [B, F, N, H]
        query_tensor = self.query_dense(from_tensor)

        # `key_tensor` = [B, T, N, H]
        key_tensor = self.key_dense(to_tensor)

        # `value_tensor` = [B, T, N, H]
        value_tensor = self.value_dense(to_tensor)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor,
                                     query_tensor)
        attention_scores = tf.multiply(
            attention_scores, 1.0 / math.sqrt(float(self.size_per_head)))

        if attention_mask is not None:
            # `attention_mask` = [B, 1, F, T]
            attention_mask = tf.expand_dims(attention_mask, axis=[1])

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            adder = (1.0 - tf.cast(attention_mask,
                                   attention_scores.dtype)) * -10000.0

            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_scores += adder

        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]
        attention_probs = tf.nn.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_probs_dropout(attention_probs)

        # `context_layer` = [B, F, N, H]
        context_tensor = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
                                   value_tensor)

        return context_tensor
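A toy sketch of the attention-score einsums and the mask adder above, with tiny illustrative dimensions (B=1, F=T=3, N=2, H=4) and random placeholder tensors:

import math

import tensorflow as tf

B, F, T, N, H = 1, 3, 3, 2, 4
query_tensor = tf.random.normal([B, F, N, H])
key_tensor = tf.random.normal([B, T, N, H])
value_tensor = tf.random.normal([B, T, N, H])
attention_mask = tf.constant([[[1.0, 1.0, 0.0],
                               [1.0, 1.0, 0.0],
                               [1.0, 1.0, 0.0]]])  # [B, F, T]; 1.0 = attend, 0.0 = masked

attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores *= 1.0 / math.sqrt(float(H))

# Masked positions get a large negative bias and vanish after the softmax.
adder = (1.0 - tf.expand_dims(attention_mask, 1)) * -10000.0
attention_scores += adder

attention_probs = tf.nn.softmax(attention_scores)                     # [B, N, F, T]
context_tensor = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)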
Example #10
  def call(self, inputs):
    """Implements call() for the layer."""
    (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
     attn_mask_g, mems, target_mapping) = tf_utils.unpack_inputs(inputs)

    if mems is not None and mems.shape.ndims > 1:
      cat = tf.concat([mems, h], 0)
    else:
      cat = h

    # content heads
    q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
    k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
    v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)

    # positional heads
    k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)

    # core attention ops
    attn_vec_h = self.relative_attention_layer(
        q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
        r_r_bias, r_s_bias, attn_mask_h)

    # post processing
    output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
    output_h = self.attention_dropout(output_h)
    output_h = self.output_layer_norm(output_h + h)

    output_g = None
    if g is not None:  # enable two-stream attention
      # g-stream
      q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
      if target_mapping is not None:
        q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
        attn_vec_g = self.relative_attention_layer(
            q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
            r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
        attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)

      else:
        attn_vec_g = self.relative_attention_layer(
            q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
            r_w_bias, r_r_bias, r_s_bias, attn_mask_g)

      # post processing
      output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
      output_g = self.attention_dropout(output_g)
      output_g = self.output_layer_norm(output_g + g)

    return (output_h, output_g)
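A small sketch of the 'ibh,hnd->ibnd' projection einsum used for the query/key/value heads above; all dimensions and tensors are illustrative placeholders:

import tensorflow as tf

seq_len, batch, hidden, n_head, d_head = 5, 2, 8, 2, 4
h = tf.random.normal([seq_len, batch, hidden])
projection = tf.random.normal([hidden, n_head, d_head])  # stand-in for self.qh_projection_layer

# [seq_len, batch, hidden] x [hidden, n_head, d_head] -> [seq_len, batch, n_head, d_head]
q_head_h = tf.einsum('ibh,hnd->ibnd', h, projection)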
Example #11
  def call(self, inputs):
    """Implements call() for the layer."""
    unpacked_inputs = tf_utils.unpack_inputs(inputs)
    pooled_output = unpacked_inputs[0]
    sequence_output = unpacked_inputs[1]
    masked_lm_positions = unpacked_inputs[2]

    mask_lm_input_tensor = gather_indexes(sequence_output, masked_lm_positions)
    lm_output = self.lm_dense(mask_lm_input_tensor)
    lm_output = self.lm_layer_norm(lm_output)
    lm_output = tf.matmul(lm_output, self.embedding_table, transpose_b=True)
    lm_output = tf.nn.bias_add(lm_output, self.output_bias)
    lm_output = tf.nn.log_softmax(lm_output, axis=-1)

    logits = tf.matmul(pooled_output, self.next_seq_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, self.next_seq_bias)
    sentence_output = tf.nn.log_softmax(logits, axis=-1)
    return (lm_output, sentence_output)
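A tiny sketch of the weight-tying step above: masked-position hidden states are projected back onto the vocabulary by multiplying with the transposed embedding table. All tensors below are random placeholders:

import tensorflow as tf

vocab_size, hidden_size = 10, 8
embedding_table = tf.random.normal([vocab_size, hidden_size])  # stand-in for self.embedding_table
output_bias = tf.zeros([vocab_size])                           # stand-in for self.output_bias
lm_hidden = tf.random.normal([6, hidden_size])                 # gathered masked positions

lm_output = tf.matmul(lm_hidden, embedding_table, transpose_b=True)
lm_output = tf.nn.bias_add(lm_output, output_bias)
lm_output = tf.nn.log_softmax(lm_output, axis=-1)              # [6, vocab_size] log-probs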
Example #12
  def call(self, inputs):
    """Implements call() for the layer."""
    (hidden, target, lookup_table, tgt_mask) = tf_utils.unpack_inputs(inputs)
    if self.use_proj:
      hidden = self.proj_layer_norm(self.proj_layer(hidden))
    if self.tie_weight:
      logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + self.softmax_b
    else:
      logits = tf.einsum('ibd,nd->ibn', hidden, self.softmax_w) + self.softmax_b

    if self.use_tpu:
      one_hot_target = tf.one_hot(target, self.n_token, dtype=logits.dtype)
      loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
    else:
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=target, logits=logits)

    total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)

    return total_loss, logits
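A numeric sketch of the masked mean at the end: per-token losses are averaged only over positions where tgt_mask is 1, so padding contributes nothing. Values are illustrative:

import tensorflow as tf

loss = tf.constant([[0.5, 1.0, 2.0]])      # per-token losses
tgt_mask = tf.constant([[1.0, 1.0, 0.0]])  # 1.0 = real token, 0.0 = padding

total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)  # -> 0.75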
Example #13
  def call(self, inputs):
    """Implements call() for the layer."""
    (q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
     r_r_bias, r_s_bias, attn_mask) = tf_utils.unpack_inputs(inputs)

    # content based attention score
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

    # position based attention score
    bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
    bd = rel_shift(bd, klen=tf.shape(ac)[1])

    # segment-based attention score
    if seg_mat is None:
      ef = 0
    else:
      ef = tf.einsum('ibnd,snd->isbn', q_head + r_s_bias, seg_embed)
      tgt_shape = tf.shape(bd)
      ef = tf.where(
          tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
          tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
          tf.broadcast_to(ef[:, :1, :, :], tgt_shape))

    # merges attention scores and performs masking
    attn_score = (ac + bd + ef) * self.scale
    if attn_mask is not None:
      attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = self.attention_probs_dropout(attn_prob)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

    return attn_vec
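A minimal sketch of the additive masking step above, reduced to a single row of scores: positions where attn_mask is 1 receive a large negative score and effectively vanish after the softmax:

import tensorflow as tf

attn_score = tf.constant([[2.0, 1.0, 0.5]])
attn_mask = tf.constant([[0.0, 0.0, 1.0]])      # 1.0 marks positions to block

attn_score = attn_score - 1e30 * attn_mask
attn_prob = tf.nn.softmax(attn_score, axis=-1)  # last position gets ~0 probability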