def _get_rc_model_input(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    vocab,
):
    """Create RC module input from separate batched components.

    Args:
      question_ids: <int32> [batch_size, question_len]
      question_mask: <int32> [batch_size, question_len]
      context_ids: <int32> [batch_size, context_len]
      context_mask: <int32> [batch_size, context_len]
      vocab: Instance of text_utils.Vocab.

    Returns:
      input_ids: <int32> [batch_size, rc_input_len]
      input_mask: <int32> [batch_size, rc_input_len]
      segment_ids: <int32> [batch_size, rc_input_len]
    """
    # Get batch size.
    batch_size = tensor_utils.shape(context_ids, 0)

    # Get special tokens.
    cls = vocab.t2i(vocab.CLS)
    sep = vocab.t2i(vocab.SEP)

    # Join question, context, and special tokens.
    cls_batch = tf.fill([batch_size, 1], cls)
    sep_batch = tf.fill([batch_size, 1], sep)
    input_ids = tf.concat(
        [cls_batch, question_ids, sep_batch, context_ids, sep_batch], axis=1)

    # Create and join segment ids.
    segment_a_ids = tf.fill(
        [batch_size, tensor_utils.shape(question_ids, 1) + 2], 0)
    segment_b_ids = tf.fill(
        [batch_size, tensor_utils.shape(context_ids, 1) + 1], 1)
    segment_ids = tf.concat([segment_a_ids, segment_b_ids], axis=1)

    # Create the joined mask, accounting for the special-token gaps.
    gap_mask = tf.fill([batch_size, 1], 1)
    input_mask = tf.concat(
        [gap_mask, question_mask, gap_mask, context_mask, gap_mask], axis=1)
    bool_mask = tf.cast(input_mask, tf.bool)

    # Select unmasked items and move all padding to the end.
    # Right now this looks like this:
    #   [CLS] X X X [PAD] ... [SEP] Y Y Y [PAD] ... [SEP] [PAD] ...
    # And we want to change it to look like this:
    #   [CLS] X X X [SEP] Y Y Y [SEP] [PAD] ...
    input_ids = tensor_utils.boolean_mask(input_ids, bool_mask)
    input_mask = tensor_utils.boolean_mask(input_mask, bool_mask)
    segment_ids = tensor_utils.boolean_mask(segment_ids, bool_mask)

    return input_ids, input_mask, segment_ids
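

# Illustrative usage sketch (not part of the original API): token ids below
# are made up, and `vocab` is assumed to be a text_utils.Vocab with CLS/SEP
# entries. For a single example, the joined input comes out roughly as
#   [CLS] q1 q2 [SEP] c1 c2 c3 [SEP]
# with segment ids of 0 over the question half and 1 over the context half,
# and all padding moved to the end.
def _demo_rc_model_input(vocab):
    question_ids = tf.constant([[7, 8, 0]], dtype=tf.int32)
    question_mask = tf.constant([[1, 1, 0]], dtype=tf.int32)
    context_ids = tf.constant([[9, 10, 11, 0]], dtype=tf.int32)
    context_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)
    return _get_rc_model_input(
        question_ids=question_ids,
        question_mask=question_mask,
        context_ids=context_ids,
        context_mask=context_mask,
        vocab=vocab)
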
def exact_match(answer_ids, prediction_ids, vocab):
    """Compute exact match score between answer tokens and prediction tokens.

    Args:
      answer_ids: <int32> [batch_size, answer_length]
      prediction_ids: <int32> [batch_size, prediction_length]
      vocab: Instance of text_utils.Vocab.

    Returns:
      score: <float32> [batch_size] tensor of {0.0, 1.0}.
    """
    batch_size = tensor_utils.shape(answer_ids, 0)

    # Get ids of tokens that are removed during normalization.
    remove_ids = list(_get_normalized_set(vocab))
    remove_ids = tf.reshape(remove_ids, [1, 1, -1])
    remove_ids = tf.tile(remove_ids, [batch_size, 1, 1])

    # Clean answer: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(tf.expand_dims(answer_ids, -1),
                                             remove_ids),
                                axis=-1)
    answer_ids = tensor_utils.boolean_mask(answer_ids, should_keep)

    # Clean prediction: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(
        tf.expand_dims(prediction_ids, -1), remove_ids),
                                axis=-1)
    prediction_ids = tensor_utils.boolean_mask(prediction_ids, should_keep)

    # Cleaned lengths.
    answer_len = tensor_utils.shape(answer_ids, 1)
    prediction_len = tensor_utils.shape(prediction_ids, 1)

    # Pad the shorter one to the length of the longer.
    padding = tf.maximum(0, prediction_len - answer_len)
    answer_ids = tf.pad(answer_ids, [[0, 0], [0, padding]])
    padding = tf.maximum(0, answer_len - prediction_len)
    prediction_ids = tf.pad(prediction_ids, [[0, 0], [0, padding]])

    # Check for equality: Padded A == Padded B?
    is_equal = tf.reduce_all(tf.equal(answer_ids, prediction_ids), axis=1)
    score = tf.cast(is_equal, tf.float32)

    return score
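

# Minimal usage sketch (illustrative, not part of the original API): the score
# is 1.0 only when answer and prediction are identical after the normalized
# (removable) tokens are dropped. Token ids here are made up.
def _demo_exact_match(vocab):
    answer_ids = tf.constant([[4, 5]], dtype=tf.int32)
    prediction_ids = tf.constant([[4, 5, 6]], dtype=tf.int32)
    # Expected [1.0] if id 6 is in the normalized set, and [0.0] otherwise.
    return exact_match(answer_ids, prediction_ids, vocab)
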
def indicator_score(answer_ids, answer_mask, context_ids, vocab):
    """Compute indicator score of answer and context.

    Checks if the answer tokens are a subspan of the context.

    Args:
      answer_ids: <int32> [batch_size, answer_length]
      answer_mask: <int32> [batch_size, answer_length]
      context_ids: <int32> [batch_size, context_length]
      vocab: Instance of text_utils.Vocab.

    Returns:
      score: <float32> [batch_size] tensor of {0.0, 1.0}.
    """
    batch_size = tensor_utils.shape(answer_ids, 0)

    # Get ids of tokens that are removed during normalization.
    remove_ids = list(_get_normalized_set(vocab))
    remove_ids = tf.reshape(remove_ids, [1, 1, -1])
    remove_ids = tf.tile(remove_ids, [batch_size, 1, 1])

    # Clean answer: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(tf.expand_dims(answer_ids, -1),
                                             remove_ids),
                                axis=-1)
    answer_ids = tensor_utils.boolean_mask(answer_ids, should_keep)
    answer_mask = tensor_utils.boolean_mask(answer_mask, should_keep)

    # Clean context: remove tokens that are in the normalized set.
    should_keep = tf.reduce_all(tf.not_equal(tf.expand_dims(context_ids, -1),
                                             remove_ids),
                                axis=-1)
    context_ids = tensor_utils.boolean_mask(context_ids, should_keep)

    # Cleaned lengths.
    answer_len = tensor_utils.shape(answer_ids, 1)
    context_len = tensor_utils.shape(context_ids, 1)

    # Pad start of context (to select NULL for over-length indices).
    context_ids = tf.pad(context_ids, [[0, 0], [1, 0]])
    context_len += 1

    # Sliding window approach: take the full context of length N and gather
    # it into a tensor holding all windows of length M (an N x M tensor).
    # [context_len, answer_len]
    window_idx = tf.range(answer_len)
    window_idx = tf.tile(tf.expand_dims(window_idx, 0), [context_len, 1])
    offsets = tf.expand_dims(tf.range(context_len), 1)
    window_idx += offsets
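    # Zero out indices that run past the end of the context so they point at
    # the NULL token padded onto the front of the context above.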
    window_idx *= tf.cast(tf.less(window_idx, context_len), tf.int32)

    # [batch_size, context_len * answer_len]
    window_idx = tf.reshape(window_idx, [1, -1])
    window_idx = tf.tile(window_idx, [batch_size, 1])

    # [batch_size, context_len * answer_len]
    batch_idx = tf.range(batch_size)
    batch_idx = tf.expand_dims(batch_idx, 1)
    batch_idx = tf.tile(batch_idx, [1, context_len * answer_len])

    # [batch_size, context_len, answer_len]
    batch_idx = tf.reshape(batch_idx, [-1])
    window_idx = tf.reshape(window_idx, [-1])
    coords = tf.stack([batch_idx, window_idx], axis=1)
    window_ids = tf.gather_nd(context_ids, coords)
    window_ids = tf.reshape(window_ids, [batch_size, context_len, answer_len])

    # [batch_size, context_len, answer_len]
    answer_mask = tf.expand_dims(answer_mask, 1)
    window_ids *= answer_mask

    # Check for equality. The whole window has to match the answer, but only
    # one window needs to match for the indicator to be positive.
    answer_ids = tf.expand_dims(answer_ids, 1)
    is_equal = tf.reduce_all(tf.equal(answer_ids, window_ids), axis=-1)
    score = tf.cast(tf.reduce_any(is_equal, axis=-1), tf.float32)

    return score
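

# Minimal usage sketch (illustrative, not part of the original API): the
# indicator is 1.0 when the cleaned answer appears as a contiguous subspan of
# the cleaned context. Token ids here are made up.
def _demo_indicator_score(vocab):
    answer_ids = tf.constant([[5, 6]], dtype=tf.int32)
    answer_mask = tf.constant([[1, 1]], dtype=tf.int32)
    context_ids = tf.constant([[4, 5, 6, 7]], dtype=tf.int32)
    # Expected [1.0]: [5, 6] is a subspan of [4, 5, 6, 7] (assuming none of
    # these ids are in the normalized set).
    return indicator_score(answer_ids, answer_mask, context_ids, vocab)
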
def rc_span(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    rc_model,
    vocab,
    max_length=10,
    no_answer_bias=0,
):
    """Computes exact match score from QA model run on context.

  Args:
    question_ids: <int32> [batch_size, question_len]
    question_mask: <int32> [batch_size, question_len]
    context_ids: <int32> [batch_size, context_len]
    context_mask: <int32> [batch_size, context_len]
    rc_model: Extractive question answering model.
    vocab: Instance of text_utils.Vocab.
    max_length: Max answer length.
    no_answer_bias: Log-odds ratio for answer span over NULL.

  Returns:
    score: <float32> [batch_size]
  """
    # Mask out stop id in context if present.
    stop_id = vocab.t2i(vocab.SEP)
    stop_mask = tf.cast(tf.not_equal(context_ids, stop_id), tf.int32)
    context_mask *= stop_mask

    # Prepare rc inputs.
    input_ids, input_mask, segment_ids = _get_rc_model_input(
        question_ids=question_ids,
        question_mask=question_mask,
        context_ids=context_ids,
        context_mask=context_mask,
        vocab=vocab)

    # Get start/end logits from RC model.
    outputs = rc_model(inputs=dict(input_ids=input_ids,
                                   input_mask=input_mask,
                                   segment_ids=segment_ids),
                       signature="extractive_qa",
                       as_dict=True)

    # Dimensions. Note that the logits span the full joined RC input
    # (question + context + special tokens), not just the context.
    batch_size = tensor_utils.shape(input_ids, 0)
    input_len = tensor_utils.shape(input_ids, 1)

    # Decode span.
    start_logits = tf.reshape(outputs["start_logits"], [-1, input_len])
    end_logits = tf.reshape(outputs["end_logits"], [-1, input_len])
    start, end, span_scores = max_scoring_span(start_scores=start_logits,
                                               end_scores=end_logits,
                                               max_length=max_length,
                                               no_answer_bias=no_answer_bias)

    # Expand shape to be compatible for broadcasting.
    start = tf.reshape(start, [-1, 1])
    end = tf.reshape(end, [-1, 1])

    # Create mask where mask[i, j] = True iff start[i] <= j <= end[i].
    # [batch_size, rc_input_len]
    mask = tf.tile(tf.expand_dims(tf.range(input_len), 0), [batch_size, 1])
    mask = tf.logical_and(tf.greater_equal(mask, start),
                          tf.less_equal(mask, end))

    # Gather padded answer span from context.
    answer_span = tensor_utils.boolean_mask(input_ids, mask)

    return answer_span, span_scores
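

# Illustrative call sketch (not part of the original API). `rc_model` is
# expected to behave like a TF-Hub-style module exposing an "extractive_qa"
# signature that returns "start_logits" and "end_logits"; the toy stand-in
# below just emits uniform logits so the sketch is self-contained.
def _demo_rc_span(question_ids, question_mask, context_ids, context_mask,
                  vocab):
    def toy_rc_model(inputs, signature, as_dict):
        del signature, as_dict  # Unused by the toy stand-in.
        zeros = tf.zeros_like(inputs["input_ids"], dtype=tf.float32)
        return dict(start_logits=zeros, end_logits=zeros)

    return rc_span(
        question_ids=question_ids,
        question_mask=question_mask,
        context_ids=context_ids,
        context_mask=context_mask,
        rc_model=toy_rc_model,
        vocab=vocab)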