Example #1
  def test_unflatten(self):
    with tf.Graph().as_default():
      tensor = tf.placeholder(tf.float32, [4, 7, 6, 3])
      w = tf.placeholder(tf.float32, [3, 9])

      flat_tensor, unflatten = tensor_utils.flatten(tensor)
      self.assertAllEqual(tensor_utils.shape(flat_tensor), [4 * 7 * 6, 3])

      flat_projected_tensor = tf.matmul(flat_tensor, w)
      projected_tensor = unflatten(flat_projected_tensor)
      self.assertAllEqual(tensor_utils.shape(projected_tensor), [4, 7, 6, 9])
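For reference, here is a minimal sketch of what `tensor_utils.flatten` appears to do, inferred only from the test above (the real implementation may differ): it collapses every leading dimension into one, and returns an `unflatten` callback that restores those dimensions while keeping whatever trailing dimension the flat tensor has at that point.

import tensorflow.compat.v1 as tf

def flatten(tensor):
  """Assumed sketch of tensor_utils.flatten, not the actual implementation."""
  # Leading dimensions are taken statically here for simplicity.
  leading_dims = tensor.shape.as_list()[:-1]                 # e.g. [4, 7, 6]
  flat = tf.reshape(tensor, [-1, tensor.shape.as_list()[-1]])

  def unflatten(flat_tensor):
    # Restore the leading dims, keeping whatever trailing dim `flat_tensor`
    # now has (e.g. 9 after the matmul in the test above).
    return tf.reshape(flat_tensor, leading_dims + [tf.shape(flat_tensor)[-1]])

  return flat, unflatten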
Example #2
def span_candidates(masks, max_span_width):
    """Generate span candidates.

  Args:
    masks: <int32> [num_retrievals, max_sequence_len]
    max_span_width: int

  Returns:
    starts: <int32> [num_spans]
    ends: <int32> [num_spans]
    span_masks: <int32> [num_retrievals, num_spans]
  """
    _, max_sequence_len = tensor_utils.shape(masks)

    def _spans_given_width(width):
        current_starts = tf.range(max_sequence_len - width + 1)
        current_ends = tf.range(width - 1, max_sequence_len)
        return current_starts, current_ends

    starts, ends = zip(*(_spans_given_width(w + 1)
                         for w in range(max_span_width)))

    # [num_spans]
    starts = tf.concat(starts, 0)
    ends = tf.concat(ends, 0)

    # [num_retrievals, num_spans]
    start_masks = tf.gather(masks, starts, axis=-1)
    end_masks = tf.gather(masks, ends, axis=-1)
    span_masks = start_masks * end_masks

    return starts, ends, span_masks
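A small worked example of the enumeration above (illustrative only, not part of the original code):

# With max_sequence_len = 4 and max_span_width = 2:
#   width 1: starts [0, 1, 2, 3], ends [0, 1, 2, 3]
#   width 2: starts [0, 1, 2],    ends [1, 2, 3]
# After concatenation:
#   starts = [0, 1, 2, 3, 0, 1, 2]
#   ends   = [0, 1, 2, 3, 1, 2, 3]
# and span_masks[r, s] = masks[r, starts[s]] * masks[r, ends[s]], so a span
# survives only if both of its endpoint tokens are unmasked.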
Example #3
    def test_shape_static(self):
        with tf.Graph().as_default():
            tensor = tf.placeholder(tf.int64, [4, 7])
            d0_single = tensor_utils.shape(tensor, 0)
            d1_single = tensor_utils.shape(tensor, 1)
            d0_full, d1_full = tensor_utils.shape(tensor)

            self.assertIsInstance(d0_single, int)
            self.assertIsInstance(d1_single, int)
            self.assertIsInstance(d0_full, int)
            self.assertIsInstance(d1_full, int)

            self.assertEqual(d0_single, 4)
            self.assertEqual(d1_single, 7)
            self.assertEqual(d0_full, 4)
            self.assertEqual(d1_full, 7)
Example #4
def mask_attention(attention, seq_len1, seq_len2):
    """Masks an attention matrix.

  Args:
    attention: <tf.float32>[batch, seq_len1, seq_len2]
    seq_len1: <tf.int32>[batch]
    seq_len2: <tf.int32>[batch]

  Returns:
    the masked scores <tf.float32>[batch, seq_len1, seq_len2]
  """
    dim1 = tensor_utils.shape(attention, 1)
    dim2 = tensor_utils.shape(attention, 2)
    m1 = tf.sequence_mask(seq_len1, dim1)
    m2 = tf.sequence_mask(seq_len2, dim2)
    joint_mask = tf.logical_and(tf.expand_dims(m1, 2), tf.expand_dims(m2, 1))
    return ops.mask_logits(attention, joint_mask)
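`ops.mask_logits` is not shown in these examples. A plausible sketch, assuming the common convention of pushing masked positions to a very negative value so that a later softmax effectively ignores them (this is an assumption, not the actual implementation):

import tensorflow.compat.v1 as tf

def mask_logits(logits, mask):
  """Assumed sketch of ops.mask_logits."""
  mask = tf.cast(mask, logits.dtype)
  # Positions where the mask is False receive a large negative penalty.
  return logits + (1.0 - mask) * -1e9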
Example #5
    def test_shape_mixed(self):
        """Test for shape() with a mixture of static and dynamic sizes."""
        with tf.Graph().as_default():
            tensor = tf.placeholder(tf.int64, [4, None])
            d0_single = tensor_utils.shape(tensor, 0)
            d1_single = tensor_utils.shape(tensor, 1)
            d0_full, d1_full = tensor_utils.shape(tensor)

            self.assertIsInstance(d0_single, int)
            self.assertIsInstance(d1_single, tf.Tensor)
            self.assertIsInstance(d0_full, int)
            self.assertIsInstance(d1_full, tf.Tensor)

            self.assertEqual(d0_single, 4)
            self.assertEqual(d0_full, 4)

            with tf.Session() as sess:
                feed_dict = {tensor: np.zeros((4, 7))}

                tf_d1_single = sess.run(d1_single, feed_dict=feed_dict)
                self.assertEqual(tf_d1_single, 7)

                tf_d1_full = sess.run(d1_full, feed_dict=feed_dict)
                self.assertEqual(tf_d1_full, 7)
Example #6
    def test_shape_dynamic(self):
        with tf.Graph().as_default():
            tensor = tf.placeholder(tf.int64, [None, None])
            d0_single = tensor_utils.shape(tensor, 0)
            d1_single = tensor_utils.shape(tensor, 1)
            d0_full, d1_full = tensor_utils.shape(tensor)

            self.assertIsInstance(d0_single, tf.Tensor)
            self.assertIsInstance(d1_single, tf.Tensor)
            self.assertIsInstance(d0_full, tf.Tensor)
            self.assertIsInstance(d1_full, tf.Tensor)

            with tf.Session() as sess:
                feed_dict = {tensor: np.zeros((4, 7))}

                tf_d0_single, tf_d1_single = sess.run([d0_single, d1_single],
                                                      feed_dict=feed_dict)
                self.assertEqual(tf_d0_single, 4)
                self.assertEqual(tf_d1_single, 7)

                tf_d0_full, tf_d1_full = sess.run([d0_full, d1_full],
                                                  feed_dict=feed_dict)
                self.assertEqual(tf_d0_full, 4)
                self.assertEqual(tf_d1_full, 7)
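The three tests above pin down the contract of `tensor_utils.shape`: statically known dimensions come back as Python ints, unknown ones as scalar `tf.Tensor`s, and a single dimension can be requested by index. A minimal sketch consistent with that contract (assumed, not the actual implementation):

import tensorflow.compat.v1 as tf

def shape(tensor, dim=None):
  """Returns static dims as Python ints and dynamic dims as scalar tensors."""
  static_dims = tensor.shape.as_list()
  dynamic_dims = tf.shape(tensor)
  dims = [static_dims[i] if static_dims[i] is not None else dynamic_dims[i]
          for i in range(len(static_dims))]
  if dim is not None:
    return dims[dim]
  return dims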
Example #7
def _bilinear_score(context_emb, question_emb):
    """Compute a bilinear score between the context and question embeddings.

  Args:
    context_emb: <float32> [batch_size, max_context_len, hidden_size]
    question_emb: <float32> [batch_size, hidden_size]

  Returns:
    scores: <float32> [batch_size, max_context_len]
  """
    # [batch_size, hidden_size]
    projected_question_emb = tf.layers.dense(
        question_emb, tensor_utils.shape(context_emb, -1))

    # [batch_size, max_context_len, 1]
    scores = tf.matmul(context_emb, tf.expand_dims(projected_question_emb, -1))

    return tf.squeeze(scores, -1)
Example #8
def cross_shard_pad(input_tensor):
    """Cross shard pad.

  Assuming each TPU core holds its own `input_tensor` chunk along the zeroth
  dimension, this builds the corresponding global tensor from the local
  shard's point of view: the local `input_tensor` is written into its own
  chunk and the chunks belonging to other shards are filled with zeros. No
  cross-shard communication is actually performed.

  Args:
    input_tensor: <int32|float32> [local_batch_size, dim1, dim2, ...]

  Returns:
    padded_tensor: <int32|float32>
        [local_batch_size * num_shards, dim1, dim2, ...]
  """
    num_shards = num_tpu_shards()

    # [num_shards]
    local_mask = tf.equal(tf.range(num_shards), shard_id())
    local_mask = tf.cast(local_mask, input_tensor.dtype)

    tensor_shape = tensor_utils.shape(input_tensor)
    local_batch_size = tensor_shape[0]
    global_batch_size = num_shards * local_batch_size

    # [num_shards, 1, 1, ...]
    for _ in tensor_shape:
        local_mask = tf.expand_dims(local_mask, -1)

    # [num_shards, local_batch_size, input_tensor_dim1, ...]
    padded_tensor = local_mask * tf.expand_dims(input_tensor, 0)

    # [global_batch_size, input_tensor_dim1, ...]
    padded_tensor = tf.reshape(padded_tensor,
                               [global_batch_size] + tensor_shape[1:])
    return padded_tensor
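A concrete illustration of the padding above (not part of the original code):

# With num_shards = 2 and input_tensor = [[1, 2], [3, 4]] on shard 1:
#   local_mask selects only the shard-1 chunk after casting and expanding, so
#   padded_tensor = [[0, 0], [0, 0], [1, 2], [3, 4]]   # shape [4, 2]
# Shard 0 would instead produce [[1, 2], [3, 4], [0, 0], [0, 0]]. Summing the
# per-shard results (e.g. with a cross-replica sum) recovers the global batch.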
Example #9
def batch_word_to_char_ids(words, word_length):
    """Batched version of word_to_char_ids.

  This is a deterministic function that should be computed during preprocessing.
  We pin this op to the CPU anyway to be safe, since it is slower on GPUs.

  Args:
    words: <string> [...]
    word_length: Number of bytes to include per word.

  Returns:
    char_ids: <int32> [..., word_length]
  """
    with tf.device("/cpu:0"):
        flat_words = tf.reshape(words, [-1])
        flat_char_ids = tf.map_fn(fn=partial(word_to_char_ids,
                                             word_length=word_length),
                                  elems=flat_words,
                                  dtype=tf.int32,
                                  back_prop=False)

    char_ids = tf.reshape(flat_char_ids,
                          tensor_utils.shape(words) + [word_length])
    return char_ids
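The per-word helper `word_to_char_ids` is not reproduced here. One hedged sketch of what it plausibly does, namely mapping a scalar string to its byte values, padded or truncated to `word_length` (an assumption, not the repository's actual code):

import tensorflow.compat.v1 as tf

def word_to_char_ids(word, word_length):
  """Assumed sketch: byte values of `word`, padded/truncated to word_length."""
  char_ids = tf.to_int32(tf.decode_raw(word, tf.uint8))       # [num_bytes]
  # Pad with zeros, then truncate so the result always has word_length entries.
  char_ids = tf.concat([char_ids, tf.zeros([word_length], tf.int32)], 0)
  char_ids = char_ids[:word_length]
  char_ids.set_shape([word_length])
  return char_ids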
Example #10
def model_fn(features, labels, mode, params):
    """Model function."""
    del labels

    # [local_batch_size, block_seq_len]
    block_ids = features["block_ids"]
    block_mask = features["block_mask"]
    block_segment_ids = features["block_segment_ids"]

    # [local_batch_size, query_seq_len]
    query_ids = features["query_ids"]
    query_mask = features["query_mask"]

    local_batch_size = tensor_utils.shape(block_ids, 0)
    tf.logging.info("Model batch size: %d", local_batch_size)

    ict_module = create_ict_module(params, mode)

    query_emb = ict_module(inputs=dict(input_ids=query_ids,
                                       input_mask=query_mask,
                                       segment_ids=tf.zeros_like(query_ids)),
                           signature="projected")
    block_emb = ict_module(inputs=dict(input_ids=block_ids,
                                       input_mask=block_mask,
                                       segment_ids=block_segment_ids),
                           signature="projected")

    if params["use_tpu"]:
        # [global_batch_size, hidden_size]
        block_emb = tpu_utils.cross_shard_concat(block_emb)

        # [global_batch_size, local_batch_size]
        labels = tpu_utils.cross_shard_pad(tf.eye(local_batch_size))

        # [local_batch_size]
        labels = tf.argmax(labels, 0)
    else:
        # [local_batch_size]
        labels = tf.range(local_batch_size)

    tf.logging.info("Global batch size: %s", tensor_utils.shape(block_emb, 0))

    # [batch_size, global_batch_size]
    logits = tf.matmul(query_emb, block_emb, transpose_b=True)

    # []
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=min(10000,
                             max(100, int(params["num_train_steps"] / 10))),
        use_tpu=params["use_tpu"] if "use_tpu" in params else False)

    predictions = tf.argmax(logits, -1)

    metric_args = [
        query_mask, block_mask, labels, predictions, features["mask_query"]
    ]

    def metric_fn(query_mask, block_mask, labels, predictions, mask_query):
        masked_accuracy = tf.metrics.accuracy(labels=labels,
                                              predictions=predictions,
                                              weights=mask_query)
        unmasked_accuracy = tf.metrics.accuracy(
            labels=labels,
            predictions=predictions,
            weights=tf.logical_not(mask_query))
        return dict(query_non_padding=tf.metrics.mean(query_mask),
                    block_non_padding=tf.metrics.mean(block_mask),
                    actual_mask_ratio=tf.metrics.mean(mask_query),
                    masked_accuracy=masked_accuracy,
                    unmasked_accuracy=unmasked_accuracy)

    if params["use_tpu"]:
        return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                 loss=loss,
                                                 train_op=train_op,
                                                 eval_metrics=(metric_fn,
                                                               metric_args))
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=metric_fn(*metric_args),
            predictions=predictions)
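The TPU branch above builds in-batch labels by padding a local identity matrix into its global position. A small illustration (not part of the original code):

# With local_batch_size = 2 and 2 shards, shard 1 computes
#   cross_shard_pad(tf.eye(2)) = [[0, 0], [0, 0], [1, 0], [0, 1]]   # [4, 2]
# so tf.argmax(labels, 0) = [2, 3]: each local query's positive block sits at
# rows 2 and 3 of the globally concatenated `block_emb`.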
Example #11
def variational_dropout(x, dropout_rate, is_train):
    if is_train:
        shape = tensor_utils.shape(x)
        return tf.nn.dropout(x, 1.0 - dropout_rate, [shape[0], 1, shape[2]])
    else:
        return x
def decomposable_attention(emb1,
                           len1,
                           emb2,
                           len2,
                           hidden_size,
                           hidden_layers,
                           dropout_ratio,
                           mode,
                           epsilon=1e-8):
    """See https://arxiv.org/abs/1606.01933.

  Args:
    emb1: A Tensor with shape [batch_size, max_len1, emb_size] representing the
        first input sequence.
    len1: A Tensor with shape [batch_size], indicating the true sequence length
        of `emb1`. This is required due to padding.
    emb2: A Tensor with shape [batch_size, max_len2, emb_size] representing the
        second input sequence.
    len2: A Tensor with shape [batch_size], indicating the true sequence length
        of `emb2`. This is required due to padding.
    hidden_size: An integer indicating the size of each hidden layer in the
        feed-forward neural networks.
    hidden_layers: An integer indicating the number of hidden layers in the
        feed-forward neural networks.
    dropout_ratio: The probability of dropping out each unit in the activation.
        This can be None, and is only applied during training.
    mode: One of the keys from tf.estimator.ModeKeys.
    epsilon: A small positive constant to add to masks for numerical stability.

  Returns:
    final_emb: A Tensor with shape [batch_size, hidden_size].
  """
    # [batch_size, maxlen1]
    mask1 = tf.sequence_mask(len1,
                             tensor_utils.shape(emb1, 1),
                             dtype=tf.float32)

    # [batch_size, maxlen2]
    mask2 = tf.sequence_mask(len2,
                             tensor_utils.shape(emb2, 1),
                             dtype=tf.float32)

    with tf.variable_scope("attend"):
        projected_emb1 = common_layers.ffnn(emb1,
                                            [hidden_size] * hidden_layers,
                                            dropout_ratio, mode)
    with tf.variable_scope("attend", reuse=True):
        projected_emb2 = common_layers.ffnn(emb2,
                                            [hidden_size] * hidden_layers,
                                            dropout_ratio, mode)

    # [batch_size, maxlen1, maxlen2]
    attention_scores = tf.matmul(projected_emb1,
                                 projected_emb2,
                                 transpose_b=True)
    attention_weights1 = tf.nn.softmax(
        attention_scores + tf.log(tf.expand_dims(mask2, 1) + epsilon), 2)
    attention_weights2 = tf.nn.softmax(
        attention_scores + tf.log(tf.expand_dims(mask1, 2) + epsilon), 1)

    # [batch_size, maxlen1, emb_size]
    attended_emb1 = tf.matmul(attention_weights1, emb2)

    # [batch_size, maxlen2, emb_size]
    attended_emb2 = tf.matmul(attention_weights2, emb1, transpose_a=True)

    with tf.variable_scope("compare"):
        compared_emb1 = common_layers.ffnn(
            tf.concat([emb1, attended_emb1], -1),
            [hidden_size] * hidden_layers, dropout_ratio, mode)
    with tf.variable_scope("compare", reuse=True):
        compared_emb2 = common_layers.ffnn(
            tf.concat([emb2, attended_emb2], -1),
            [hidden_size] * hidden_layers, dropout_ratio, mode)

    compared_emb1 *= tf.expand_dims(mask1, -1)
    compared_emb2 *= tf.expand_dims(mask2, -1)

    # [batch_size, hidden_size]
    aggregated_emb1 = tf.reduce_sum(compared_emb1, 1)
    aggregated_emb2 = tf.reduce_sum(compared_emb2, 1)
    with tf.variable_scope("aggregate"):
        final_emb = common_layers.ffnn(
            tf.concat([aggregated_emb1, aggregated_emb2], -1),
            [hidden_size] * hidden_layers, dropout_ratio, mode)
    return final_emb
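`common_layers.ffnn` is used throughout these examples but not shown. A plausible sketch, under the assumption that it is a stack of dense layers with dropout applied only during training (the actual activation and details may differ):

import tensorflow.compat.v1 as tf

def ffnn(inputs, hidden_sizes, dropout_ratio, mode):
  """Assumed sketch of common_layers.ffnn, not the actual implementation."""
  outputs = inputs
  for i, layer_size in enumerate(hidden_sizes):
    outputs = tf.layers.dense(outputs, layer_size, activation=tf.nn.relu,
                              name="ffnn_%d" % i)
    if dropout_ratio is not None and mode == tf.estimator.ModeKeys.TRAIN:
      outputs = tf.nn.dropout(outputs, 1.0 - dropout_ratio)
  return outputs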
Example #13
def model_function(features, labels, mode, params, embeddings):
    """A model function satisfying the tf.estimator API.

  Args:
    features: Dictionary of feature tensors with keys:
        - question_tok: <string> [batch_size, max_question_len]
        - context_tok: <string> [batch_size, max_num_context, max_context_len]
        - question_tok_len: <int32> [batch_size]
        - num_context: <int32> [batch_size]
        - context_tok_len: <int32> [batch_size, max_num_context]
        - question_tok_wid: <int32> [batch_size, max_question_len]
        - context_tok_wid: <int32> [batch_size, max_num_context,
          max_context_len]
        - long_answer_indices: <int32> [batch_size]
    labels: <int32> [batch_size] for answer index (-1 = NULL).
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Dictionary of hyperparameters.
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object.
  """
    del params  # Unused.

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Add a dummy batch dimension if we are exporting the predictor.
        features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

    embedding_weights, embedding_scaffold = embeddings.get_params(
        trainable=False)

    # Features.
    question_tok_len = features["question_tok_len"]
    question_tok_wid = features["question_tok_wid"]
    context_tok_wid = features["context_tok_wid"]
    num_context = features["num_context"]
    context_tok_len = features["context_tok_len"]

    # Truncate the contexts and labels to a certain maximum length.
    context_tok_wid, num_context, context_tok_len = (
        nq_long_utils.truncate_contexts(context_token_ids=context_tok_wid,
                                        num_contexts=num_context,
                                        context_len=context_tok_len,
                                        max_contexts=FLAGS.max_contexts,
                                        max_context_len=FLAGS.max_context_len))

    non_null_context_scores = nq_long_decatt_model.build_model(
        question_tok_wid=question_tok_wid,
        question_lens=question_tok_len,
        context_tok_wid=context_tok_wid,
        context_lens=context_tok_len,
        embedding_weights=embedding_weights,
        mode=mode)

    # Mask out contexts that are padding.
    num_context_mask = tf.log(
        tf.sequence_mask(num_context,
                         tensor_utils.shape(non_null_context_scores, 1),
                         dtype=tf.float32))
    non_null_context_scores += num_context_mask

    # <float> [batch_size, 1]
    null_score = tf.zeros([tf.shape(question_tok_wid)[0], 1])

    # Offset everything by 1 to account for null context.
    # [batch_size, 1 + max_contexts]
    context_scores = tf.concat([null_score, non_null_context_scores], 1)

    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = nq_long_utils.truncate_labels(labels, FLAGS.max_contexts)

        # In the data, NULL is given index -1 but this is not compatible with
        # softmax so shift by 1.
        labels = labels + 1

        # Reweight null examples.
        weights = nq_long_utils.compute_null_weights(labels, FLAGS.null_weight)

        # When computing the loss we take only the first label.
        loss_labels = labels[:, 0]

        # []
        loss = tf.losses.sparse_softmax_cross_entropy(labels=loss_labels,
                                                      logits=context_scores,
                                                      weights=weights)

        optimizer = tf.train.AdagradOptimizer(
            learning_rate=FLAGS.learning_rate)
        train_op = optimizer.minimize(loss=loss,
                                      global_step=tf.train.get_global_step())

        # <int32> [batch_size]
        eval_predictions = tf.to_int32(tf.argmax(context_scores, 1))

        non_null_match, non_null_gold, non_null_predictions = (
            nq_long_utils.compute_match_stats(eval_predictions, labels))

        precision, precision_op = (tf.metrics.mean(
            non_null_match, weights=non_null_predictions))
        recall, recall_op = (tf.metrics.mean(non_null_match,
                                             weights=non_null_gold))

        f1, f1_op = (nq_long_utils.f1_metric(precision=precision,
                                             precision_op=precision_op,
                                             recall=recall,
                                             recall_op=recall_op))

        # Bogus metric until we figure out how to connect Ming Wei's eval code.
        eval_metric_ops = {
            "precision": (precision, precision_op),
            "recall": (recall, recall_op),
            "f1": (f1, f1_op)
        }
    else:
        loss = None
        train_op = None
        eval_metric_ops = {}

    # In the export, we never predict NULL since the eval metric will compute the
    # best possible F1.
    export_long_answer_idx = tf.to_int32(tf.argmax(non_null_context_scores, 1))
    export_long_answer_score = tf.reduce_max(non_null_context_scores, 1)
    predictions = dict(idx=export_long_answer_idx,
                       score=export_long_answer_score)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Remove the dummy batch dimension if we are exporting the predictor.
        predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

    estimator_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        predictions=predictions,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        scaffold=embedding_scaffold)

    return estimator_spec
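One detail of the example above worth spelling out (illustrative note, not original code): how the NULL label is folded into the softmax.

# With labels = [[-1], [2]] (NULL = -1), shifting by 1 gives [[0], [3]], so
# index 0 of `context_scores` is the zero-scored NULL context and the real
# contexts occupy indices 1..max_contexts.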
Example #14
def score_endpoints(question_emb,
                    question_len,
                    context_emb,
                    context_len,
                    hidden_size,
                    num_layers,
                    dropout_ratio,
                    mode,
                    use_cudnn=None):
    """Compute two scores over context words based on the input embeddings.

  Args:
    question_emb: <float32> [batch_size, max_question_len, hidden_size]
    question_len: <int32> [batch_size]
    context_emb: <float32>[batch_size, max_context_len, hidden_size]
    context_len: <int32> [batch_size]
    hidden_size: Size of hidden layers.
    num_layers: Number of LSTM layers.
    dropout_ratio: The probability of dropping out hidden units.
    mode: Object of type tf.estimator.ModeKeys.
    use_cudnn: Specify the use of cudnn. `None` denotes automatic selection.

  Returns:
    start_scores: <float32> [batch_size, max_context_words]
    end_scores: <float32> [batch_size, max_context_words]
  """
    # [batch_size, max_question_len]
    question_mask = tf.sequence_mask(question_len,
                                     tensor_utils.shape(question_emb, 1),
                                     dtype=tf.float32)

    # [batch_size, max_context_len, hidden_size]
    attended_emb = _attend_to_question(context_emb=context_emb,
                                       question_emb=question_emb,
                                       question_mask=question_mask,
                                       hidden_size=hidden_size)

    # [batch_size, max_context_len, hidden_size * 2]
    context_emb = tf.concat([context_emb, attended_emb], -1)

    with tf.variable_scope("contextualize_context"):
        # [batch_size, max_context_len, hidden_size]
        contextualized_context_emb = cudnn_layers.stacked_bilstm(
            input_emb=context_emb,
            input_len=context_len,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout_ratio=dropout_ratio,
            mode=mode,
            use_cudnn=use_cudnn)
    with tf.variable_scope("contextualize_question"):
        # [batch_size, max_question_len, hidden_size]
        contextualized_question_emb = cudnn_layers.stacked_bilstm(
            input_emb=question_emb,
            input_len=question_len,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout_ratio=dropout_ratio,
            mode=mode,
            use_cudnn=use_cudnn)
    if mode == tf_estimator.ModeKeys.TRAIN:
        contextualized_context_emb = tf.nn.dropout(contextualized_context_emb,
                                                   1.0 - dropout_ratio)
        contextualized_question_emb = tf.nn.dropout(
            contextualized_question_emb, 1.0 - dropout_ratio)

    # [batch_size, hidden_size]
    pooled_question_emb = _attention_pool(contextualized_question_emb,
                                          question_mask)

    if mode == tf_estimator.ModeKeys.TRAIN:
        pooled_question_emb = tf.nn.dropout(pooled_question_emb,
                                            1.0 - dropout_ratio)

    # [batch_size, max_context_len]
    with tf.variable_scope("start_scores"):
        start_scores = _bilinear_score(contextualized_context_emb,
                                       pooled_question_emb)
    with tf.variable_scope("end_scores"):
        end_scores = _bilinear_score(contextualized_context_emb,
                                     pooled_question_emb)
    context_log_mask = tf.log(
        tf.sequence_mask(context_len,
                         tensor_utils.shape(context_emb, 1),
                         dtype=tf.float32))
    start_scores += context_log_mask
    end_scores += context_log_mask
    return start_scores, end_scores
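`_attend_to_question` and `_attention_pool` are helpers that do not appear in these examples. As one hedged sketch, `_attention_pool` could be a learned, mask-aware weighted sum over the sequence; this is an assumption, not the actual implementation:

import tensorflow.compat.v1 as tf

def _attention_pool(emb, mask):
  """Assumed sketch: masked, attention-weighted sum over the time dimension."""
  # [batch_size, seq_len, 1]
  scores = tf.layers.dense(emb, 1)
  scores += tf.expand_dims(tf.log(mask + 1e-8), -1)
  weights = tf.nn.softmax(scores, 1)
  # [batch_size, hidden_size]
  return tf.reduce_sum(weights * emb, 1)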
Example #15
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="bert",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "bert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(
      candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidates_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)

    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
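Restating the loss computed above for reference (no new functionality):

# For each masked position m, the code marginalizes over candidates c:
#   log p(y_m) = logsumexp_c [ log p(c | q) + log p(y_m | c, q) ]
# i.e. `candidate_log_probs` plus `gold_log_probs`, reduced with logsumexp into
# `marginal_gold_log_probs`; the loss is the negative mean of log p(y_m) over
# the unmasked target positions. `retrieval_utility` then measures how much the
# marginal improves over conditioning only on the first candidate.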
Example #16
def model_function(features, labels, mode, params, embeddings):
  """A model function satisfying the tf.estimator API.

  Args:
    features: Dictionary of feature tensors with keys:
        - question: <string> [batch_size, max_question_len]
        - question_len: <int32> [batch_size]
        - question_cid: <int32> [batch_size, max_question_len, max_chars]
        - question_wid: <int32> [batch_size, max_question_len]
        - context: <string> [batch_size, max_context_len]
        - context_len: <int32> [batch_size]
        - context_cid: <int32> [batch_size, max_context_len, max_chars]
        - context_wid: <int32> [batch_size, max_context_len]
        - answer_start: <int32> [batch_size]
        - answer_end: <int32> [batch_size]
    labels: Pair of tensors containing the answer start and answer end.
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Unused parameter dictionary.
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object.
  """
  del params

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Add a dummy batch dimension if we are exporting the predictor.
    features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

  embedding_weights, embedding_scaffold = embeddings.get_params(trainable=False)

  def _embed(prefix):
    """Embed the input text based and word and character IDs."""
    word_emb = tf.nn.embedding_lookup(embedding_weights,
                                      features[prefix + "_wid"])
    char_emb = common_layers.character_cnn(
        char_ids=features[prefix + "_cid"],
        emb_size=FLAGS.char_emb_size,
        kernel_width=FLAGS.char_kernel_width,
        num_filters=FLAGS.num_char_filters)
    concat_emb = tf.concat([word_emb, char_emb], -1)

    if mode == tf.estimator.ModeKeys.TRAIN:
      concat_emb = tf.nn.dropout(concat_emb, 1.0 - FLAGS.dropout_ratio)
    return concat_emb

  with tf.variable_scope("embed"):
    # [batch_size, max_question_len, hidden_size]
    question_emb = _embed("question")

  with tf.variable_scope("embed", reuse=True):
    # [batch_size, max_context_len, hidden_size]
    context_emb = _embed("context")

  # [batch_size, max_context_len]
  start_logits, end_logits = document_reader.score_endpoints(
      question_emb=question_emb,
      question_len=features["question_len"],
      context_emb=context_emb,
      context_len=features["context_len"],
      hidden_size=FLAGS.hidden_size,
      num_layers=FLAGS.num_layers,
      dropout_ratio=FLAGS.dropout_ratio,
      mode=mode,
      use_cudnn=False if mode == tf.estimator.ModeKeys.PREDICT else None)

  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size]
    start_labels, end_labels = labels

    # Since we truncate long contexts, some of the labels will not be
    # recoverable. In that case, we mask these invalid labels.
    valid_start_labels = tf.less(start_labels, features["context_len"])
    valid_end_labels = tf.less(end_labels, features["context_len"])
    tf.summary.histogram("valid_start_labels", tf.to_float(valid_start_labels))
    tf.summary.histogram("valid_end_labels", tf.to_float(valid_end_labels))

    dummy_labels = tf.zeros_like(start_labels)

    # []
    start_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.where(valid_start_labels, start_labels, dummy_labels),
        logits=start_logits,
        weights=tf.to_float(valid_start_labels),
        reduction=tf.losses.Reduction.MEAN)
    end_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.where(valid_end_labels, end_labels, dummy_labels),
        logits=end_logits,
        weights=tf.to_float(valid_end_labels),
        reduction=tf.losses.Reduction.MEAN)
    loss = start_loss + end_loss
  else:
    loss = None

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer()
    gradients, variables = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    train_op = optimizer.apply_gradients(
        grads_and_vars=zip(gradients, variables),
        global_step=tf.train.get_global_step())
  else:
    # Don't build the train_op unnecessarily, since the ADAM variables can cause
    # problems with loading checkpoints on CPUs.
    train_op = None

  batch_size, max_context_len = tensor_utils.shape(features["context_wid"])
  tf.summary.histogram("batch_size", batch_size)
  tf.summary.histogram("non_padding", features["context_len"] / max_context_len)

  # [batch_size], [batch_size]
  start_predictions, end_predictions, predicted_score = (
      span_utils.max_scoring_span(start_logits, end_logits))

  # [batch_size, 2]
  predictions = dict(
      start_idx=start_predictions,
      end_idx=(end_predictions + 1),
      score=predicted_score)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Remove the dummy batch dimension if we are exporting the predictor.
    predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

  if mode == tf.estimator.ModeKeys.EVAL:
    text_summary = get_text_summary(
        question=features["question"],
        context=features["context"],
        start_predictions=start_predictions,
        end_predictions=end_predictions)

    # TODO(kentonl): Replace this with @mingweichang's official eval script.
    exact_match = tf.logical_and(
        tf.equal(start_predictions, start_labels),
        tf.equal(end_predictions, end_labels))

    eval_metric_ops = dict(
        exact_match=tf.metrics.mean(exact_match),
        text_summary=(text_summary, tf.no_op()))
  else:
    eval_metric_ops = None

  estimator_spec = tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      predictions=predictions,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops,
      scaffold=embedding_scaffold)

  return estimator_spec
def create_de_model(bert_config, is_training, input_ids_1, input_mask_1,
                    segment_ids_1, input_ids_2, input_masks_2, segment_ids_2,
                    num_candidates, labels, use_one_hot_embeddings):
    """Creates a ranking model using cosine and dual encoder representations."""

    sequence_length_query = FLAGS.max_seq_length_query
    sequence_length_passage = FLAGS.max_seq_length - FLAGS.max_seq_length_query

    input_ids_1 = tf.reshape(input_ids_1, [-1, sequence_length_query])
    segment_ids_1 = tf.reshape(segment_ids_1, [-1, sequence_length_query])
    input_masks_1 = tf.reshape(input_mask_1, [-1, sequence_length_query])
    batch_size = tf.shape(input_masks_1)[0]

    input_ids_2 = tf.reshape(input_ids_2, [-1, sequence_length_passage])
    segment_ids_2 = tf.reshape(segment_ids_2, [-1, sequence_length_passage])
    input_masks_2 = tf.reshape(input_masks_2, [-1, sequence_length_passage])

    # [batch_size, num_candidates]
    labels = tf.dtypes.cast(labels, tf.float32)

    # [batch_size, num_vec_query, hidden_size], [batch_size, num_vec_query]
    output_layer_1, mask_1 = encode_block(bert_config, input_ids_1,
                                          input_masks_1, segment_ids_1,
                                          use_one_hot_embeddings,
                                          FLAGS.num_vec_query, is_training)

    output_layer_2, mask_2 = encode_block(bert_config, input_ids_2,
                                          input_masks_2, segment_ids_2,
                                          use_one_hot_embeddings,
                                          FLAGS.num_vec_passage, is_training)

    label_mask = tf.expand_dims(tf.eye(batch_size), axis=2)
    label_mask = tf.tile(label_mask, [1, 1, num_candidates])
    label_mask = tf.reshape(label_mask, [batch_size, -1])
    label_mask = tf.cast(label_mask, tf.float32)

    labels = tf.tile(labels, [1, batch_size])
    labels = tf.multiply(labels, label_mask)
    output_layer_2_logits = tf.reshape(
        output_layer_2,
        [batch_size, num_candidates, FLAGS.num_vec_passage, -1])
    mask_2_logits = tf.reshape(
        mask_2, [batch_size, num_candidates, FLAGS.num_vec_passage])
    mask_logits = tf.einsum("BQ,BCP->BCQP", tf.cast(mask_1, tf.float32),
                            tf.cast(mask_2_logits, tf.float32))

    logits = tf.einsum("BQH,BCPH->BCQP", output_layer_1, output_layer_2_logits)
    logits = tf.multiply(logits, mask_logits)
    logits = tf.reduce_max(logits, axis=-1)
    logits = tf.reduce_sum(logits, axis=-1)

    if FLAGS.use_tpu and is_training:
        num_shards = tpu_utils.num_tpu_shards()
        output_layer_2 = tpu_utils.cross_shard_concat(output_layer_2)
        mask_2 = tpu_utils.cross_shard_concat(tf.cast(mask_2, tf.float32))
        mask_2 = tf.cast(mask_2, tf.bool)
        labels = tpu_utils.cross_shard_pad(labels)
        tf.logging.info("Global batch size: %s", tensor_utils.shape(labels, 0))
        tf.logging.info("Num shards: %s", num_shards)
        tf.logging.info("Number of candidates in batch: %s",
                        tensor_utils.shape(output_layer_2, 0))
        labels = tf.reshape(labels, [num_shards, batch_size, -1])
        labels = tf.transpose(labels, perm=[1, 0, 2])
        labels = tf.reshape(labels, [batch_size, -1])

    with tf.variable_scope("loss"):
        if is_training:
            output_layer_1 = tf.nn.dropout(output_layer_1,
                                           keep_prob=FLAGS.dropout)
            output_layer_2 = tf.nn.dropout(output_layer_2,
                                           keep_prob=FLAGS.dropout)
        cosine_similarity = tf.einsum("AQH,BPH->ABQP", output_layer_1,
                                      output_layer_2)
        mask = tf.cast(
            tf.logical_and(tf.expand_dims(tf.expand_dims(mask_1, 2), 1),
                           tf.expand_dims(tf.expand_dims(mask_2, 1), 0)),
            tf.float32)
        cosine_similarity = tf.multiply(cosine_similarity, mask)
        cosine_similarity = tf.reduce_max(cosine_similarity, axis=-1)
        cosine_similarity = tf.reduce_sum(cosine_similarity, axis=-1)
        per_example_loss = tf.losses.softmax_cross_entropy(
            labels, cosine_similarity)

        return (per_example_loss, logits)
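A shape walk-through of the multi-vector scoring above (an illustrative restatement, not part of the original code):

# output_layer_1:        [B, Q, H]     query vectors
# output_layer_2_logits: [B, C, P, H]  passage vectors per candidate
# einsum("BQH,BCPH->BCQP"): every query-vector / passage-vector dot product
# multiply by mask_logits:  zero out padded vector pairs
# reduce_max over P:        best-matching passage vector per query vector
# reduce_sum over Q:        final score per (example, candidate), i.e. a
#                           max-sim style late-interaction score.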
Example #18
def build_model(question_tok_wid, question_lens, context_tok_wid, context_lens,
                embedding_weights, mode):
    """Wrapper around for Decomposable Attention model for NQ long answer scoring.

  Args:
    question_tok_wid: <int32> [batch_size, question_len]
    question_lens: <int32> [batch_size]
    context_tok_wid: <int32> [batch_size, num_context, context_len]
    context_lens: <int32> [batch_size, num_context]
    embedding_weights: <float> [vocab_size, embed_dim]
    mode: One of the keys from tf.estimator.ModeKeys.

  Returns:
    context_scores: <float> [batch_size, num_context]
  """
    # <float> [batch_size, question_len, embed_dim]
    question_emb = tf.nn.embedding_lookup(embedding_weights, question_tok_wid)
    # <float> [batch_size, num_context, context_len, embed_dim]
    context_emb = tf.nn.embedding_lookup(embedding_weights, context_tok_wid)

    question_emb = tf.layers.dense(inputs=question_emb,
                                   units=FLAGS.hidden_size,
                                   activation=None,
                                   name="reduce_emb",
                                   reuse=False)

    context_emb = tf.layers.dense(inputs=context_emb,
                                  units=FLAGS.hidden_size,
                                  activation=None,
                                  name="reduce_emb",
                                  reuse=True)

    batch_size, num_contexts, max_context_len, embed_dim = (
        tensor_utils.shape(context_emb))
    _, max_question_len, _ = tensor_utils.shape(question_emb)

    # <float> [batch_size * num_context, context_len, embed_dim]
    flat_context_emb = tf.reshape(context_emb,
                                  [-1, max_context_len, embed_dim])

    # <int32> [batch_size * num_context]
    flat_context_lens = tf.reshape(context_lens, [-1])

    # <float> [batch_size * num_context, question_len, embed_dim]
    question_emb_tiled = tf.tile(tf.expand_dims(question_emb, 1),
                                 [1, num_contexts, 1, 1])
    flat_question_emb_tiled = tf.reshape(question_emb_tiled,
                                         [-1, max_question_len, embed_dim])

    # <int32> [batch_size * num_context]
    question_lens_tiled = tf.tile(tf.expand_dims(question_lens, 1),
                                  [1, num_contexts])
    flat_question_lens_tiled = tf.reshape(question_lens_tiled, [-1])

    # <float> [batch_size * num_context, hidden_size]
    flat_decatt_emb = decatt.decomposable_attention(
        emb1=flat_question_emb_tiled,
        len1=flat_question_lens_tiled,
        emb2=flat_context_emb,
        len2=flat_context_lens,
        hidden_size=FLAGS.hidden_size,
        hidden_layers=FLAGS.hidden_layers,
        dropout_ratio=FLAGS.dropout_ratio,
        mode=mode)

    # <float> [batch_size, num_context, hidden_size]
    decatt_emb = tf.reshape(flat_decatt_emb,
                            [batch_size, num_contexts, FLAGS.hidden_size])

    weighted_num_overlap, unweighted_num_overlap, pos_embs = (
        _get_non_neural_features(question_tok_wid=question_tok_wid,
                                 question_lens=question_lens,
                                 context_tok_wid=context_tok_wid,
                                 context_lens=context_lens))

    final_emb = tf.concat(
        [decatt_emb, weighted_num_overlap, unweighted_num_overlap, pos_embs],
        -1)

    # Final linear layer to get score.
    # <float> [batch_size, num_context]
    context_scores = tf.layers.dense(inputs=final_emb,
                                     units=1,
                                     activation=None)
    context_scores = tf.squeeze(context_scores, -1)

    return context_scores
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, use_one_hot_embeddings, use_tpu):
    """Creates a classification model."""
    tpu_split = FLAGS.tpu_split if use_tpu else 1
    model = modeling.BertModel(config=bert_config,
                               is_training=is_training,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=segment_ids,
                               use_one_hot_embeddings=use_one_hot_embeddings)

    output_final_layer = model.get_sequence_output()

    # shape: [batch_size, max_seq_len, hidden]
    if FLAGS.emb_rep == "cls":
        embedding = tf.squeeze(output_final_layer[:, 0:1, :], axis=1)
    elif FLAGS.emb_rep == "mean":
        embedding = tf.reduce_mean(output_final_layer, axis=1)

    tf.logging.info("per tpu slice")
    tf.logging.info("emebdding size: %s", embedding.shape)
    tf.logging.info("label size: %s", labels.shape)
    tf.logging.info("=======" * 10)

    if use_tpu:
        # for tpu usage: combine embeddings after splitting 8 ways
        # [global_batch_size]
        labels = tpu_utils.cross_shard_concat(labels)
        tf.logging.info("label size: %s", labels.shape)
        tf.logging.info("=======" * 10)

        # [global_batch_size, hidden_size]
        embedding = tpu_utils.cross_shard_concat(embedding)

    tf.logging.info("Global batch size: %s", tensor_utils.shape(embedding, 0))

    tf.logging.info("emebdding size: %s", embedding.shape)
    tf.logging.info("label size: %s", labels.shape)
    tf.logging.info("num tpu shards: %s", tpu_utils.num_tpu_shards())
    tf.logging.info("=======" * 10)

    num_known_classes = FLAGS.num_domains * FLAGS.num_labels_per_domain
    num_unknown_classes = NUM_CLASSES - num_known_classes
    if FLAGS.continual_learning == "pretrain":
        num_classes = num_known_classes
        n_examples = FLAGS.known_num_shots
    elif FLAGS.continual_learning == "few_shot":
        num_classes = num_unknown_classes
        n_examples = FLAGS.few_shot

    if FLAGS.few_shot_known_neg:
        num_classes = NUM_CLASSES
        real_num_classes = num_unknown_classes

    # remove padding in each batch
    if use_tpu:
        real_shift = math.ceil(
            num_classes / FLAGS.batch_size) * FLAGS.batch_size
        # If using TPUs, embedding.shape[0] will be (num_classes + pad_num) * 8.
        real_indices = tf.range(num_classes)
        for i in range(1, tpu_split):
            real_indices = tf.concat(
                [real_indices,
                 tf.range(num_classes) + real_shift * i], axis=0)

        embedding = tf.gather(embedding, real_indices)
        labels = tf.gather(labels, real_indices)

        tf.logging.info("emebdding size after removing padding in batch: %s",
                        embedding.shape)
        tf.logging.info("label size after removing padding in batch: %s",
                        labels.shape)

        # remove padded batch
        if n_examples < tpu_split:
            real_batch_total = n_examples * num_classes
            embedding = embedding[:real_batch_total]
            labels = labels[:real_batch_total]
            real_num = n_examples
        else:
            real_num = tpu_split
    else:
        # Not using TPUs.
        if n_examples < tpu_split:
            real_num = n_examples
        else:
            real_num = tpu_split
        real_batch_total = real_num * num_classes
        embedding = embedding[:real_batch_total]
        labels = labels[:real_batch_total]

    tf.logging.info("real emebdding size: %s", embedding.shape)
    tf.logging.info("real label size: %s", labels.shape)

    n = embedding.shape[0].value

    assert n == real_num * num_classes, "n: %d; real_num: %d: num_classes: %d" % (
        n, real_num, num_classes)

    with tf.variable_scope("loss", reuse=tf.AUTO_REUSE):
        if is_training:
            # I.e., 0.1 dropout
            embedding = tf.nn.dropout(embedding, keep_prob=1 - DROPOUT_PROB)

        logits = tf.matmul(embedding, embedding, transpose_b=True)

        diagonal_matrix = tf.eye(n, n)
        logits = logits - diagonal_matrix * logits

        logits_reshape = tf.reshape(logits, [n, real_num, num_classes])

        if FLAGS.reduce_method == "mean":
            all_logits_sum = tf.reduce_sum(logits_reshape, 1)
            num_counts = tf.ones([n, num_classes]) * real_num
            label_diagonal = tf.eye(num_classes, num_classes)
            label_diagonal = tf.tile(label_diagonal, tf.constant([real_num,
                                                                  1]))
            num_counts = num_counts - label_diagonal
            mean_logits = tf.divide(all_logits_sum, num_counts)
            if FLAGS.few_shot_known_neg:
                real_logits_indices = tf.range(real_num_classes)
                for i in range(1, n_examples):
                    real_logits_indices = tf.concat([
                        real_logits_indices,
                        tf.range(real_num_classes) + num_classes * i
                    ],
                                                    axis=0)
                mean_logits = tf.gather(mean_logits, real_logits_indices)

                label_diagonal = tf.eye(real_num_classes, num_classes)
                label_diagonal = tf.tile(label_diagonal,
                                         tf.constant([real_num, 1]))

            probabilities = tf.nn.softmax(mean_logits, axis=-1)
            log_probs = tf.nn.log_softmax(mean_logits, axis=-1)
            return_logits = mean_logits

        elif FLAGS.reduce_method == "max":
            max_logits = tf.reduce_max(logits_reshape, 1)

            if FLAGS.min_max:
                # Because the diagonal is 0, we need to assign a large number to get the
                # true min.
                large_number = 50000
                added_logits = logits + diagonal_matrix * large_number
                added_reshape_logits = tf.reshape(added_logits,
                                                  [n, real_num, num_classes])
                min_logits = tf.reduce_min(added_reshape_logits,
                                           1)  # n * num_classes
                masks = tf.tile(tf.eye(num_classes, num_classes),
                                tf.constant([real_num, 1]))
                max_logits = masks * min_logits + (1 - masks) * max_logits

            label_diagonal = tf.eye(num_classes, num_classes)

            if FLAGS.few_shot_known_neg:
                real_logits_indices = tf.range(real_num_classes)
                # WARNING: the current implementation may not be correct for
                # few_shot > 8 on TPUs; in the following for loop it should be
                # `for i in range(1, real_num)` instead of `range(1, n_examples)`.
                assert n_examples < 8, (
                    "current implementation may not be correct for "
                    "few_shot > 8 on tpus. Need to check")
                # Note: n_examples here is 2 or 5, which is less than tpu_split.
                for i in range(1, n_examples):
                    real_logits_indices = tf.concat([
                        real_logits_indices,
                        tf.range(real_num_classes) + num_classes * i
                    ],
                                                    axis=0)
                max_logits = tf.gather(max_logits, real_logits_indices)
                label_diagonal = label_diagonal[:real_num_classes]

            label_diagonal = tf.tile(label_diagonal, tf.constant([real_num,
                                                                  1]))

            probabilities = tf.nn.softmax(max_logits, axis=-1)
            log_probs = tf.nn.log_softmax(max_logits, axis=-1)
            return_logits = max_logits

        elif FLAGS.reduce_method == "random":
            indice_0 = tf.expand_dims(tf.range(n), axis=1)  # n x 1
            indice_1 = tf.random.uniform([n, 1],
                                         minval=0,
                                         maxval=real_num,
                                         dtype=tf.dtypes.int32)
            random_indices = tf.concat([indice_0, indice_1], axis=1)
            random_logits = tf.gather_nd(logits_reshape, random_indices)

            label_diagonal = tf.eye(num_classes, num_classes)

            if FLAGS.few_shot_known_neg:
                real_logits_indices = tf.range(real_num_classes)
                for i in range(1, n_examples):
                    real_logits_indices = tf.concat([
                        real_logits_indices,
                        tf.range(real_num_classes) + num_classes * i
                    ],
                                                    axis=0)
                random_logits = tf.gather(random_logits, real_logits_indices)
                label_diagonal = label_diagonal[:real_num_classes]

            label_diagonal = tf.tile(label_diagonal, tf.constant([real_num,
                                                                  1]))

            probabilities = tf.nn.softmax(random_logits, axis=-1)
            log_probs = tf.nn.log_softmax(random_logits, axis=-1)
            return_logits = random_logits

        per_example_loss = -tf.reduce_sum(label_diagonal * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)

        return (loss, per_example_loss, return_logits, probabilities)
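To summarize the loss construction above (a restatement for clarity, not part of the original code):

# logits[i, j] = <embedding_i, embedding_j>, with the diagonal zeroed so an
# example never matches itself. Reshaping to [n, real_num, num_classes] groups
# the columns by class; the mean/max/random reduction collapses them into one
# score per class, and the target for each anchor is its own class (the tiled
# identity `label_diagonal`), trained with softmax cross-entropy.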