def _get_rc_model_input(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    vocab,
):
  """Create RC module input from separate batched components.

  Args:
    question_ids: <int32> [batch_size, question_len]
    question_mask: <int32> [batch_size, question_len]
    context_ids: <int32> [batch_size, context_len]
    context_mask: <int32> [batch_size, context_len]
    vocab: Instance of text_utils.Vocab.

  Returns:
    input_ids: <int32> [batch_size, rc_input_len]
    input_mask: <int32> [batch_size, rc_input_len]
    segment_ids: <int32> [batch_size, rc_input_len]
  """
  # Get batch size.
  batch_size = tensor_utils.shape(context_ids, 0)

  # Get special tokens.
  cls = vocab.t2i(vocab.CLS)
  sep = vocab.t2i(vocab.SEP)

  # Join question, context, and special tokens.
  cls_batch = tf.fill([batch_size, 1], cls)
  sep_batch = tf.fill([batch_size, 1], sep)
  input_ids = tf.concat(
      [cls_batch, question_ids, sep_batch, context_ids, sep_batch], axis=1)

  # Create and join segment ids.
  segment_a_ids = tf.fill(
      [batch_size, tensor_utils.shape(question_ids, 1) + 2], 0)
  segment_b_ids = tf.fill(
      [batch_size, tensor_utils.shape(context_ids, 1) + 1], 1)
  segment_ids = tf.concat([segment_a_ids, segment_b_ids], axis=1)

  # Create joined mask, accounting for the gaps left by special tokens.
  gap_mask = tf.fill([batch_size, 1], 1)
  input_mask = tf.concat(
      [gap_mask, question_mask, gap_mask, context_mask, gap_mask], axis=1)
  bool_mask = tf.cast(input_mask, tf.bool)

  # Select unmasked items and move all padding to the end.
  # Right now this looks like this:
  #   [CLS] X X X [PAD] ... [SEP] Y Y Y [PAD] ... [SEP] [PAD] ...
  # And we want to change it to look like this:
  #   [CLS] X X X [SEP] Y Y Y [SEP] [PAD] ...
  input_ids = tensor_utils.boolean_mask(input_ids, bool_mask)
  input_mask = tensor_utils.boolean_mask(input_mask, bool_mask)
  segment_ids = tensor_utils.boolean_mask(segment_ids, bool_mask)

  return input_ids, input_mask, segment_ids
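

# A minimal usage sketch of the packing above (illustrative only; assumes a
# Vocab whose CLS and SEP map to 101 and 102, as in the standard BERT
# vocabulary, and eager execution):
#
#   input_ids, input_mask, segment_ids = _get_rc_model_input(
#       question_ids=tf.constant([[7, 8, 0]]),   # 0 is padding.
#       question_mask=tf.constant([[1, 1, 0]]),
#       context_ids=tf.constant([[9, 10, 11]]),
#       context_mask=tf.constant([[1, 1, 1]]),
#       vocab=vocab)
#   # input_ids   -> [[101, 7, 8, 102, 9, 10, 11, 102]]  (padding moved to end)
#   # segment_ids -> [[0, 0, 0, 0, 1, 1, 1, 1]]
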
def exact_match(answer_ids, prediction_ids, vocab):
  """Compute exact match score between answer tokens and prediction tokens.

  Args:
    answer_ids: <int32> [batch_size, answer_length]
    prediction_ids: <int32> [batch_size, prediction_length]
    vocab: Instance of text_utils.Vocab.

  Returns:
    score: <float32> [batch_size] tensor of {0.0, 1.0}.
  """
  batch_size = tensor_utils.shape(answer_ids, 0)

  # Get cleanable words.
  remove_ids = list(_get_normalized_set(vocab))
  remove_ids = tf.reshape(remove_ids, [1, 1, -1])
  remove_ids = tf.tile(remove_ids, [batch_size, 1, 1])

  # Clean answer: remove tokens that are in the normalized set.
  should_keep = tf.reduce_all(
      tf.not_equal(tf.expand_dims(answer_ids, -1), remove_ids), axis=-1)
  answer_ids = tensor_utils.boolean_mask(answer_ids, should_keep)

  # Clean prediction: remove tokens that are in the normalized set.
  should_keep = tf.reduce_all(
      tf.not_equal(tf.expand_dims(prediction_ids, -1), remove_ids), axis=-1)
  prediction_ids = tensor_utils.boolean_mask(prediction_ids, should_keep)

  # Cleaned lengths.
  answer_len = tensor_utils.shape(answer_ids, 1)
  prediction_len = tensor_utils.shape(prediction_ids, 1)

  # Pad the shorter one to the length of the longer.
  padding = tf.maximum(0, prediction_len - answer_len)
  answer_ids = tf.pad(answer_ids, [[0, 0], [0, padding]])
  padding = tf.maximum(0, answer_len - prediction_len)
  prediction_ids = tf.pad(prediction_ids, [[0, 0], [0, padding]])

  # Check for equality: Padded A == Padded B?
  is_equal = tf.reduce_all(tf.equal(answer_ids, prediction_ids), axis=1)
  score = tf.cast(is_equal, tf.float32)

  return score
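

# A minimal sketch of the intended behavior (hypothetical token ids; assumes
# _get_normalized_set(vocab) contains the id for "the", say 4, so articles
# are stripped before comparison):
#
#   answer_ids = tf.constant([[4, 7, 8]])           # "the X Y"
#   prediction_ids = tf.constant([[7, 8]])          # "X Y"
#   exact_match(answer_ids, prediction_ids, vocab)  # -> [1.0]
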
def indicator_score(answer_ids, answer_mask, context_ids, vocab):
  """Compute indicator score of answer and context.

  Checks if the answer tokens are a subspan of the context.

  Args:
    answer_ids: <int32> [batch_size, answer_length]
    answer_mask: <int32> [batch_size, answer_length]
    context_ids: <int32> [batch_size, context_length]
    vocab: Instance of text_utils.Vocab.

  Returns:
    score: <float32> [batch_size] tensor of {0.0, 1.0}.
  """
  batch_size = tensor_utils.shape(answer_ids, 0)

  # Get cleanable words.
  remove_ids = list(_get_normalized_set(vocab))
  remove_ids = tf.reshape(remove_ids, [1, 1, -1])
  remove_ids = tf.tile(remove_ids, [batch_size, 1, 1])

  # Clean answer: remove tokens that are in the normalized set.
  should_keep = tf.reduce_all(
      tf.not_equal(tf.expand_dims(answer_ids, -1), remove_ids), axis=-1)
  answer_ids = tensor_utils.boolean_mask(answer_ids, should_keep)
  answer_mask = tensor_utils.boolean_mask(answer_mask, should_keep)

  # Clean context: remove tokens that are in the normalized set.
  should_keep = tf.reduce_all(
      tf.not_equal(tf.expand_dims(context_ids, -1), remove_ids), axis=-1)
  context_ids = tensor_utils.boolean_mask(context_ids, should_keep)

  # Cleaned lengths.
  answer_len = tensor_utils.shape(answer_ids, 1)
  context_len = tensor_utils.shape(context_ids, 1)

  # Pad start of context (to select NULL for over-length indices).
  context_ids = tf.pad(context_ids, [[0, 0], [1, 0]])
  context_len += 1

  # Sliding window approach: take the full context of length N and gather
  # it into a tensor with all windows of length M (a N x M tensor).
  # [context_len, answer_len]
  window_idx = tf.range(answer_len)
  window_idx = tf.tile(tf.expand_dims(window_idx, 0), [context_len, 1])
  offsets = tf.expand_dims(tf.range(context_len), 1)
  window_idx += offsets
  window_idx *= tf.cast(tf.less(window_idx, context_len), tf.int32)

  # [batch_size, context_len * answer_len]
  window_idx = tf.reshape(window_idx, [1, -1])
  window_idx = tf.tile(window_idx, [batch_size, 1])

  # [batch_size, context_len * answer_len]
  batch_idx = tf.range(batch_size)
  batch_idx = tf.expand_dims(batch_idx, 1)
  batch_idx = tf.tile(batch_idx, [1, context_len * answer_len])

  # [batch_size, context_len, answer_len]
  batch_idx = tf.reshape(batch_idx, [-1])
  window_idx = tf.reshape(window_idx, [-1])
  coords = tf.stack([batch_idx, window_idx], axis=1)
  window_ids = tf.gather_nd(context_ids, coords)
  window_ids = tf.reshape(window_ids, [batch_size, context_len, answer_len])

  # [batch_size, context_len, answer_len]
  answer_mask = tf.expand_dims(answer_mask, 1)
  window_ids *= answer_mask

  # Check for equality. The whole window has to match the answer, but only
  # one window in the context has to match for the indicator to be positive.
  answer_ids = tf.expand_dims(answer_ids, 1)
  is_equal = tf.reduce_all(tf.equal(answer_ids, window_ids), axis=-1)
  score = tf.cast(tf.reduce_any(is_equal, axis=-1), tf.float32)

  return score
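

# A minimal sketch of the sliding-window match (hypothetical token ids;
# assumes none of them fall in the normalized set): the answer [7, 8] occurs
# as a contiguous subspan of the context, so the indicator fires.
#
#   score = indicator_score(
#       answer_ids=tf.constant([[7, 8]]),
#       answer_mask=tf.constant([[1, 1]]),
#       context_ids=tf.constant([[5, 7, 8, 9]]),
#       vocab=vocab)
#   # score -> [1.0]; with context [[5, 7, 9, 8]] it would be [0.0].
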
def rc_span(
    question_ids,
    question_mask,
    context_ids,
    context_mask,
    rc_model,
    vocab,
    max_length=10,
    no_answer_bias=0,
):
  """Runs the extractive QA model on the context and decodes the best span.

  Args:
    question_ids: <int32> [batch_size, question_len]
    question_mask: <int32> [batch_size, question_len]
    context_ids: <int32> [batch_size, context_len]
    context_mask: <int32> [batch_size, context_len]
    rc_model: Extractive question answering model.
    vocab: Instance of text_utils.Vocab.
    max_length: Max answer length.
    no_answer_bias: Log-odds ratio for answer span over NULL.

  Returns:
    answer_span: <int32> [batch_size, span_len] token ids of the decoded
      span, zero-padded past the span end.
    span_scores: <float32> [batch_size] scores of the decoded spans.
  """
  # Mask out stop id in context if present.
  stop_id = vocab.t2i(vocab.SEP)
  stop_mask = tf.cast(tf.not_equal(context_ids, stop_id), tf.int32)
  context_mask *= stop_mask

  # Prepare rc inputs.
  input_ids, input_mask, segment_ids = _get_rc_model_input(
      question_ids=question_ids,
      question_mask=question_mask,
      context_ids=context_ids,
      context_mask=context_mask,
      vocab=vocab)

  # Get start/end logits from RC model.
  outputs = rc_model(
      inputs=dict(
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids),
      signature="extractive_qa",
      as_dict=True)

  # Dimensions.
  batch_size = tensor_utils.shape(input_ids, 0)
  context_len = tensor_utils.shape(input_ids, 1)

  # Decode span.
  start_logits = tf.reshape(outputs["start_logits"], [-1, context_len])
  end_logits = tf.reshape(outputs["end_logits"], [-1, context_len])
  start, end, span_scores = max_scoring_span(
      start_scores=start_logits,
      end_scores=end_logits,
      max_length=max_length,
      no_answer_bias=no_answer_bias)

  # Expand shape to be compatible for broadcasting.
  start = tf.reshape(start, [-1, 1])
  end = tf.reshape(end, [-1, 1])

  # Create mask where mask[i, j] = True if start[i] <= j <= end[i].
  # [batch_size, max_rc_input_len]
  mask = tf.tile(tf.expand_dims(tf.range(context_len), 0), [batch_size, 1])
  mask = tf.logical_and(
      tf.greater_equal(mask, start), tf.less_equal(mask, end))

  # Gather padded answer span from context.
  answer_span = tensor_utils.boolean_mask(input_ids, mask)

  return answer_span, span_scores
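

# A minimal usage sketch (illustrative only; assumes `rc_model` is a TF-Hub
# module exposing the "extractive_qa" signature used above, loaded from a
# hypothetical `rc_model_path`, with batched question/context ids coming
# from the data pipeline):
#
#   rc_model = hub.Module(rc_model_path, trainable=False)
#   answer_span, span_scores = rc_span(
#       question_ids=question_ids,
#       question_mask=question_mask,
#       context_ids=context_ids,
#       context_mask=context_mask,
#       rc_model=rc_model,
#       vocab=vocab,
#       max_length=10,
#       no_answer_bias=0)
#   # The decoded span can then be scored against a reference answer, e.g.
#   # exact_match(answer_ids, answer_span, vocab).
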