Example #1
  def regression_loss(self, hidden, labels, is_training, scope,
                      reuse=tf.AUTO_REUSE, return_logits=False):
    """Get regression loss."""
    net_config = self.net_config
    initializer = self.get_initializer()

    with tf.variable_scope(scope, reuse=reuse):
      hidden = ops.dropout_op(hidden, net_config.dropout, training=is_training)
      logits = ops.dense(
          hidden,
          1,
          initializer=initializer,
          scope="logit")

      logits = tf.squeeze(logits, axis=-1)
      # Always cast to float32 for loss
      if logits.dtype != tf.float32:
        logits = tf.cast(logits, tf.float32)

      loss = tf.square(logits - tf.cast(labels, logits.dtype))

      if return_logits:
        return loss, logits

      return loss
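
The per-example squared error above can be reproduced with stock TensorFlow ops. The following is a minimal standalone sketch with made-up logits and labels (not part of the original module); note that the reduction to a scalar is left to the caller:

import tensorflow as tf

logits = tf.constant([[0.8], [-1.2], [2.5]])    # [B, 1] output of the dense "logit" layer
labels = tf.constant([1.0, -1.0, 2.0])          # [B] float regression targets
logits = tf.cast(tf.squeeze(logits, axis=-1), tf.float32)
per_example_loss = tf.square(logits - tf.cast(labels, logits.dtype))  # [B]
total_loss = tf.reduce_mean(per_example_loss)   # scalar; reduction is done by the caller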
Example #2
  def classification_loss(self, hidden, labels, n_class, is_training, scope,
                          reuse=tf.AUTO_REUSE, return_logits=False):
    """Get classification loss."""
    net_config = self.net_config
    initializer = self.get_initializer()

    with tf.variable_scope(scope, reuse=reuse):
      hidden = ops.dropout_op(hidden, net_config.dropout, training=is_training)
      logits = ops.dense(
          hidden,
          n_class,
          initializer=initializer,
          scope="logit")

      # Always cast to float32 for softmax & loss
      if logits.dtype != tf.float32:
        logits = tf.cast(logits, tf.float32)

      one_hot_target = tf.one_hot(labels, n_class, dtype=logits.dtype)
      loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)

      if return_logits:
        return loss, logits

      return loss
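
The hand-rolled cross entropy above (negative sum of the log-softmax times a one-hot target) is numerically equivalent to TensorFlow's built-in sparse softmax cross entropy. A small standalone check with made-up logits and labels, assuming nothing beyond stock TensorFlow:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 3.0]])          # [B, n_class]
labels = tf.constant([0, 2])                     # [B] integer class ids

one_hot_target = tf.one_hot(labels, 3, dtype=logits.dtype)
manual = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
builtin = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# `manual` and `builtin` agree up to floating-point error.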
Example #3
  def get_race_loss(self, labels, inputs, is_training, seg_id=None,
                    input_mask=None, use_tpu=False, use_bfloat16=False):
    """RACE loss."""
    net_config = self.net_config
    initializer = self.get_initializer()

    bsz_per_core = tf.shape(inputs)[0]
    # Flatten the 4 answer choices into the batch dimension so each choice is
    # encoded independently.
    inputs = tf.reshape(inputs, [bsz_per_core * 4, -1])
    labels = tf.reshape(labels, [bsz_per_core])

    if seg_id is not None:
      seg_id = tf.reshape(seg_id, [bsz_per_core * 4, -1])
    if input_mask is not None:
      input_mask = tf.reshape(input_mask, [bsz_per_core * 4, -1])

    summary, _ = self.get_pooled_output(inputs,
                                        is_training,
                                        seg_id=seg_id,
                                        input_mask=input_mask,
                                        use_tpu=use_tpu,
                                        use_bfloat16=use_bfloat16)

    with tf.variable_scope("race"):
      summary = ops.dropout_op(summary, net_config.dropout,
                               training=is_training)
      logits = ops.dense(
          summary,
          1,
          initializer=initializer,
          scope="logits")

      logits = tf.reshape(logits, [bsz_per_core, 4])
      logits = tf.cast(logits, tf.float32)
      one_hot_target = tf.one_hot(labels, 4, dtype=logits.dtype)
      per_example_loss = -tf.reduce_sum(
          tf.nn.log_softmax(logits) * one_hot_target, -1)

    return per_example_loss, logits
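
The shape gymnastics in the RACE head can be illustrated in isolation. The sketch below assumes the inputs arrive as [batch, 4, seq_len] (one row per answer choice); the values and shapes are made up and the encoder is replaced by a dummy scorer:

import tensorflow as tf

bsz, seq_len = 2, 8
inputs = tf.zeros([bsz, 4, seq_len], dtype=tf.int32)   # [B, 4, L] token ids
inputs = tf.reshape(inputs, [bsz * 4, -1])             # [B*4, L]: each choice scored independently
choice_scores = tf.zeros([bsz * 4, 1])                 # one scalar logit per (example, choice) pair
logits = tf.reshape(choice_scores, [bsz, 4])           # [B, 4]: softmax over the 4 choices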
Example #4
  def input_embedding(self, inputs, is_training, seg_id=None, pos_id=None,
                      word_embed_table=None, use_tpu=False, scope="input",
                      reuse=tf.AUTO_REUSE, dtype=tf.float32):
    """Turn input ids to input embedding."""

    net_config = self.net_config
    initializer = self.get_initializer()
    ret_dict = {}

    ##### Embedding
    def embed_func(x, pos_id, seg_id):
      """Word embed + Position embed + Segment embed (if provided)."""
      # Word embedding
      embed, word_embed_table = ops.embedding_lookup(
          x=x,
          n_embed=net_config.vocab_size,
          d_embed=net_config.d_embed,
          initializer=initializer,
          use_tpu=use_tpu,
          dtype=dtype,
          scope="word_embedding")

      if net_config.rel_attn_type == "null":
        # Position embedding
        if pos_id is None:
          pos_id = tf.cast(tf.range(tf.shape(x)[-1]), x.dtype)
        pos_emb, _ = ops.embedding_lookup(
            x=pos_id,
            n_embed=512,
            d_embed=net_config.d_embed,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=dtype,
            scope="position_embedding")
        embed += pos_emb

        # Segment embedding
        if seg_id is not None:
          seg_emb, _ = ops.embedding_lookup(
              x=seg_id % 2,
              n_embed=2,
              d_embed=net_config.d_embed,
              initializer=initializer,
              use_tpu=use_tpu,
              dtype=dtype,
              scope="segment_embedding")
          embed += seg_emb

      return embed, word_embed_table

    with tf.variable_scope(scope, reuse=reuse):
      ##### Input embedding layer normalization and dropout
      word_emb, word_embed_table = embed_func(x=inputs,
                                              pos_id=pos_id,
                                              seg_id=seg_id)
      word_emb = ops.layer_norm_op(word_emb, norm_shape=[net_config.d_embed])

      output = ops.dropout_op(word_emb,
                              net_config.dropout,
                              training=is_training)

    return output, word_embed_table, ret_dict
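
ops.embedding_lookup is project-specific, but the additive scheme it implements (word embedding, plus absolute position and binary segment embeddings when rel_attn_type is "null") can be sketched with stock TensorFlow lookups. Sizes and tables below are made up for illustration:

import tensorflow as tf

vocab_size, d_embed, max_len = 100, 16, 512
inputs = tf.constant([[5, 7, 9]])                      # [B, L] token ids
seg_id = tf.constant([[0, 0, 1]])                      # [B, L] segment ids

word_table = tf.random.normal([vocab_size, d_embed])
pos_table = tf.random.normal([max_len, d_embed])
seg_table = tf.random.normal([2, d_embed])

embed = tf.nn.embedding_lookup(word_table, inputs)      # word embedding, [B, L, D]
pos_id = tf.range(tf.shape(inputs)[-1])
embed += tf.nn.embedding_lookup(pos_table, pos_id)      # absolute positions, broadcast over the batch
embed += tf.nn.embedding_lookup(seg_table, seg_id % 2)  # binary segment embedding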
Example #5
  def get_squad_loss(self, inputs, cls_index, para_mask, is_training,
                     seg_id=None, input_mask=None, start_positions=None,
                     use_tpu=False, use_bfloat16=False):
    """SQuAD loss."""
    net_config = self.net_config
    initializer = self.get_initializer()

    seq_len = tf.shape(inputs)[1]
    output, _, _ = self.extract_hiddens(
        inputs,
        is_training,
        seg_id=seg_id,
        input_mask=input_mask,
        use_decoder=True,
        use_tpu=use_tpu,
        use_bfloat16=use_bfloat16)

    with tf.variable_scope("start_logits"):
      # [B x L x D] -> [B x L x 1]
      start_logits = ops.dense(
          output,
          1,
          initializer=initializer)
      # [B x L x 1] -> [B x L]
      start_logits = tf.squeeze(start_logits, -1)
      # Push non-paragraph positions towards -inf so the softmax ignores them.
      start_logits_masked = start_logits * (1 - para_mask) - 1e30 * para_mask
      # [B x L]
      start_log_probs = tf.nn.log_softmax(
          tf.cast(start_logits_masked, tf.float32), -1)

    with tf.variable_scope("end_logits"):
      if FLAGS.conditional_end:
        if is_training:
          assert start_positions is not None
          start_index = tf.one_hot(start_positions, depth=seq_len, axis=-1,
                                   dtype=output.dtype)
          start_features = tf.einsum("blh,bl->bh", output, start_index)
          start_features = tf.tile(start_features[:, None], [1, seq_len, 1])
          end_logits = ops.dense(
              tf.concat([output, start_features], axis=-1),
              net_config.d_model,
              initializer=initializer,
              activation=tf.tanh,
              scope="dense_0")
          end_logits = ops.layer_norm_op(end_logits, begin_norm_axis=-1)

          end_logits = ops.dense(
              end_logits, 1,
              initializer=initializer,
              scope="dense_1")
          end_logits = tf.squeeze(end_logits, -1)
          end_logits_masked = end_logits * (1 - para_mask) - 1e30 * para_mask
          # [B x L]
          end_log_probs = tf.nn.log_softmax(
              tf.cast(end_logits_masked, tf.float32), -1)
        else:
          start_top_log_probs, start_top_index = tf.nn.top_k(
              start_log_probs, k=FLAGS.start_n_top)
          start_index = tf.one_hot(start_top_index,
                                   depth=seq_len, axis=-1, dtype=output.dtype)
          # [B x L x D] + [B x K x L] -> [B x K x D]
          start_features = tf.einsum("blh,bkl->bkh", output, start_index)
          # [B x L x D] -> [B x 1 x L x D] -> [B x K x L x D]
          end_input = tf.tile(output[:, None],
                              [1, FLAGS.start_n_top, 1, 1])
          # [B x K x D] -> [B x K x 1 x D] -> [B x K x L x D]
          start_features = tf.tile(start_features[:, :, None],
                                   [1, 1, seq_len, 1])
          # [B x K x L x 2D]
          end_input = tf.concat([end_input, start_features], axis=-1)
          end_logits = ops.dense(
              end_input,
              net_config.d_model,
              initializer=initializer,
              activation=tf.tanh,
              scope="dense_0")
          end_logits = ops.layer_norm_op(end_logits, begin_norm_axis=-1)
          # [B x K x L x 1]
          end_logits = ops.dense(
              end_logits,
              1,
              initializer=initializer,
              scope="dense_1")

          # [B x K x L]
          end_logits = tf.squeeze(end_logits, -1)
          if FLAGS.use_masked_loss:
            end_logits_masked = end_logits * (
                1 - para_mask[:, None]) - 1e30 * para_mask[:, None]
          else:
            end_logits_masked = end_logits
          # [B x K x L]
          end_log_probs = tf.nn.log_softmax(
              tf.cast(end_logits_masked, tf.float32), -1)
          # [B x K x K']
          end_top_log_probs, end_top_index = tf.nn.top_k(
              end_log_probs, k=FLAGS.end_n_top)
          # [B x K*K']
          end_top_log_probs = tf.reshape(
              end_top_log_probs,
              [-1, FLAGS.start_n_top * FLAGS.end_n_top])
          end_top_index = tf.reshape(
              end_top_index,
              [-1, FLAGS.start_n_top * FLAGS.end_n_top])
      else:
        end_logits = ops.dense(
            output,
            1,
            initializer=initializer)
        end_logits = tf.squeeze(end_logits, -1)
        end_logits_masked = end_logits * (1 - para_mask) - 1e30 * para_mask
        end_log_probs = tf.nn.log_softmax(
            tf.cast(end_logits_masked, tf.float32), -1)
        if not is_training:
          start_top_log_probs, start_top_index = tf.nn.top_k(
              start_log_probs, k=FLAGS.start_n_top)
          end_top_log_probs, end_top_index = tf.nn.top_k(
              end_log_probs, k=FLAGS.end_n_top)

    return_dict = {}
    if is_training:
      return_dict["start_log_probs"] = start_log_probs
      return_dict["end_log_probs"] = end_log_probs
    else:
      return_dict["start_top_log_probs"] = start_top_log_probs
      return_dict["start_top_index"] = start_top_index
      return_dict["end_top_log_probs"] = end_top_log_probs
      return_dict["end_top_index"] = end_top_index

    if FLAGS.use_answer_class:
      with tf.variable_scope("answer_class"):
        cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=output.dtype)
        cls_feature = tf.einsum("blh,bl->bh", output, cls_index)

        start_p = tf.nn.softmax(start_logits_masked, axis=-1,
                                name="softmax_start")
        start_feature = tf.einsum("blh,bl->bh", output, start_p)

        ans_feature = tf.concat([start_feature, cls_feature], -1)
        ans_feature = ops.dense(
            ans_feature,
            FLAGS.d_model,
            activation=tf.tanh,
            initializer=initializer,
            scope="dense_0")
        ans_feature = ops.dropout_op(ans_feature, net_config.dropout,
                                     training=is_training)
        cls_logits = ops.dense(
            ans_feature,
            1,
            initializer=initializer,
            scope="dense_1",
            use_bias=False)
        cls_logits = tf.squeeze(cls_logits, -1)

        return_dict["cls_logits"] = tf.cast(cls_logits, tf.float32)
    else:
      cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
      cls_logits = tf.einsum("bl,bl->b", start_log_probs, cls_index)

      return_dict["cls_logits"] = tf.cast(cls_logits, tf.float32)

    return return_dict
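
At training time this function returns per-position log-probabilities rather than a loss; turning them into a span loss is left to the caller. A hypothetical caller-side sketch (names such as span_nll, start_positions, and end_positions are not part of the snippet above) of the usual negative log-likelihood over the gold start/end positions:

import tensorflow as tf

def span_nll(log_probs, positions, seq_len):
  """Negative log-likelihood of the gold position under a [B, L] log-softmax."""
  one_hot = tf.one_hot(positions, depth=seq_len, dtype=tf.float32)   # [B, L]
  return -tf.reduce_sum(log_probs * one_hot, axis=-1)                # [B]

# total_loss = tf.reduce_mean(
#     span_nll(return_dict["start_log_probs"], start_positions, seq_len) +
#     span_nll(return_dict["end_log_probs"], end_positions, seq_len))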