def test_reader_inputs(self):
  concat_inputs = orqa_ops.reader_inputs(
      question_token_ids=[0, 1],
      block_token_ids=[[2, 3, 4], [5, 6, 0]],
      block_lengths=[3, 2],
      block_token_map=[[1, 2, 5], [1, 3, 4]],
      answer_token_ids=[[3, 4], [7, 0]],
      answer_lengths=[2, 1],
      cls_token_id=10,
      sep_token_id=11,
      max_sequence_len=10)

  self.assertAllEqual(concat_inputs.token_ids.numpy(),
                      [[10, 0, 1, 11, 2, 3, 4, 11, 0, 0],
                       [10, 0, 1, 11, 5, 6, 11, 0, 0, 0]])
  self.assertAllEqual(concat_inputs.mask.numpy(),
                      [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                       [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
  self.assertAllEqual(concat_inputs.segment_ids.numpy(),
                      [[0, 0, 0, 0, 1, 1, 1, 1, 0, 0],
                       [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]])
  self.assertAllEqual(concat_inputs.block_mask.numpy(),
                      [[0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
                       [0, 0, 0, 0, 1, 1, 0, 0, 0, 0]])
  self.assertAllEqual(concat_inputs.token_map.numpy(),
                      [[-1, -1, -1, -1, 1, 2, 5, -1, -1, -1],
                       [-1, -1, -1, -1, 1, 3, -1, -1, -1, -1]])
  self.assertAllEqual(concat_inputs.gold_starts.numpy(), [[5], [-1]])
  self.assertAllEqual(concat_inputs.gold_ends.numpy(), [[6], [-1]])
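
# The layout asserted above is [CLS] question [SEP] block [SEP] plus padding,
# with segment 1 covering the block and its trailing [SEP]. Below is a minimal
# pure-Python sketch of that packing for a single (question, block) pair;
# pack_reader_inputs is a hypothetical reference helper, not part of orqa_ops.
def pack_reader_inputs(question, block, cls_id, sep_id, max_len):
  tokens = [cls_id] + question + [sep_id] + block + [sep_id]
  pad = max_len - len(tokens)
  token_ids = tokens + [0] * pad
  mask = [1] * len(tokens) + [0] * pad
  # Segment 0: [CLS] + question + first [SEP]; segment 1: block + final [SEP].
  segment_ids = [0] * (len(question) + 2) + [1] * (len(block) + 1) + [0] * pad
  return token_ids, mask, segment_ids


# First row of the test case above: question [0, 1], block [2, 3, 4].
assert pack_reader_inputs([0, 1], [2, 3, 4], 10, 11, 10) == (
    [10, 0, 1, 11, 2, 3, 4, 11, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
    [0, 0, 0, 0, 1, 1, 1, 1, 0, 0])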
def read(features, retriever_logits, blocks, mode, params, labels):
  """Do reading: score candidate answer spans within the retrieved blocks."""
  tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
      params["reader_module_path"])

  orig_blocks = blocks

  (orig_tokens, block_token_map, block_token_ids, blocks) = (
      bert_utils.tokenize_with_original_mapping(blocks, tokenizer))

  question_token_ids = tokenizer.tokenize(
      tf.expand_dims(features["question"], 0))
  question_token_ids = tf.cast(question_token_ids.flat_values, tf.int32)

  orig_tokens = orig_tokens.to_tensor(default_value="")
  block_lengths = tf.cast(block_token_ids.row_lengths(), tf.int32)
  block_token_ids = tf.cast(block_token_ids.to_tensor(), tf.int32)
  block_token_map = tf.cast(block_token_map.to_tensor(), tf.int32)

  answer_token_ids = tokenizer.tokenize(labels).merge_dims(1, 2)
  answer_lengths = tf.cast(answer_token_ids.row_lengths(), tf.int32)
  answer_token_ids = tf.cast(answer_token_ids.to_tensor(), tf.int32)

  cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
  sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))
  concat_inputs = orqa_ops.reader_inputs(
      question_token_ids=question_token_ids,
      block_token_ids=block_token_ids,
      block_lengths=block_lengths,
      block_token_map=block_token_map,
      answer_token_ids=answer_token_ids,
      answer_lengths=answer_lengths,
      cls_token_id=tf.cast(cls_token_id, tf.int32),
      sep_token_id=tf.cast(sep_token_id, tf.int32),
      max_sequence_len=params["reader_seq_len"])

  tf.summary.scalar("reader_nonpad_ratio",
                    tf.reduce_mean(tf.cast(concat_inputs.mask, tf.float32)))

  reader_module = hub.Module(
      params["reader_module_path"],
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)

  concat_outputs = reader_module(
      dict(
          input_ids=concat_inputs.token_ids,
          input_mask=concat_inputs.mask,
          segment_ids=concat_inputs.segment_ids),
      signature="tokens",
      as_dict=True)

  concat_token_emb = concat_outputs["sequence_output"]

  # [num_spans], [num_spans], [reader_beam_size, num_spans]
  candidate_starts, candidate_ends, candidate_mask = span_candidates(
      concat_inputs.block_mask, params["max_span_width"])

  # Score with an MLP to enable start/end interaction:
  # score(s, e) = w·σ(w_s·h_s + w_e·h_e)
  kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)

  # [reader_beam_size, max_sequence_len, span_hidden_size * 2]
  projection = tf.layers.dense(
      concat_token_emb,
      params["span_hidden_size"] * 2,
      kernel_initializer=kernel_initializer)

  # [reader_beam_size, max_sequence_len, span_hidden_size]
  start_projection, end_projection = tf.split(projection, 2, -1)

  # [reader_beam_size, num_candidates, span_hidden_size]
  candidate_start_projections = tf.gather(
      start_projection, candidate_starts, axis=1)
  candidate_end_projection = tf.gather(
      end_projection, candidate_ends, axis=1)
  candidate_hidden = candidate_start_projections + candidate_end_projection

  candidate_hidden = tf.nn.relu(candidate_hidden)
  candidate_hidden = tf.keras.layers.LayerNormalization(
      axis=-1)(candidate_hidden)

  # [reader_beam_size, num_candidates, 1]
  reader_logits = tf.layers.dense(
      candidate_hidden, 1, kernel_initializer=kernel_initializer)

  # [reader_beam_size, num_candidates]
  reader_logits = tf.squeeze(reader_logits, -1)
  reader_logits += mask_to_score(candidate_mask)
  reader_logits += tf.expand_dims(retriever_logits, -1)

  # [reader_beam_size, num_candidates]
  candidate_orig_starts = tf.gather(
      params=concat_inputs.token_map, indices=candidate_starts, axis=-1)
  candidate_orig_ends = tf.gather(
      params=concat_inputs.token_map, indices=candidate_ends, axis=-1)

  return ReaderOutputs(
      logits=reader_logits,
      candidate_starts=candidate_starts,
      candidate_ends=candidate_ends,
      candidate_orig_starts=candidate_orig_starts,
      candidate_orig_ends=candidate_orig_ends,
      blocks=blocks,
      orig_blocks=orig_blocks,
      orig_tokens=orig_tokens,
      token_ids=concat_inputs.token_ids,
      gold_starts=concat_inputs.gold_starts,
      gold_ends=concat_inputs.gold_ends)
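
# The function above uses span_candidates and mask_to_score helpers defined
# elsewhere in the module. The sketches below are illustrative only, not the
# module's actual implementations, and reuse the module's existing TF1-style
# `tf` import: enumerate every (start, start + width) pair up to
# max_span_width, keep spans whose endpoints both fall on block tokens, and
# turn the resulting {0, 1} mask into an additive logit penalty.
def span_candidates_sketch(block_mask, max_span_width):
  """Enumerates candidate spans given a [batch, seq_len] block mask."""
  seq_len = tf.shape(block_mask)[1]
  # starts repeats each position max_span_width times; widths cycles 0..W-1,
  # so starts[k] and widths[k] together cover every (start, width) pair.
  starts = tf.reshape(
      tf.tile(tf.expand_dims(tf.range(seq_len), 1), [1, max_span_width]), [-1])
  widths = tf.tile(tf.range(max_span_width), [seq_len])
  ends = starts + widths
  # Spans running past the end of the sequence are never valid.
  in_bounds = tf.cast(ends < seq_len, block_mask.dtype)
  ends = tf.minimum(ends, seq_len - 1)
  # A span is valid only if both endpoints land on block (non-special) tokens.
  valid = (tf.gather(block_mask, starts, axis=1) *
           tf.gather(block_mask, ends, axis=1) * in_bounds)
  return starts, ends, valid


def mask_to_score_sketch(mask):
  """Maps a {0, 1} mask to an additive score of 0 or a large negative value."""
  return (1.0 - tf.cast(mask, tf.float32)) * -10000.0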
def read(features, retriever_logits, blocks, mode, params):
  """Do reading: classify each retrieved block from its pooled [CLS] output."""
  tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
      params["reader_module_path"])

  (orig_tokens, block_token_map, block_token_ids, blocks) = (
      bert_utils.tokenize_with_original_mapping(blocks, tokenizer))

  # NOTE: we assume that the batch size is 1.
  question_token_ids = tokenizer.tokenize(
      tf.expand_dims(features["question"], 0))
  question_token_ids = tf.cast(question_token_ids.flat_values, tf.int32)

  orig_tokens = orig_tokens.to_tensor(default_value="")
  block_lengths = tf.cast(block_token_ids.row_lengths(), tf.int32)
  block_token_ids = tf.cast(block_token_ids.to_tensor(), tf.int32)
  block_token_map = tf.cast(block_token_map.to_tensor(), tf.int32)

  fake_answer = tf.constant([""])
  answer_token_ids = tokenizer.tokenize(fake_answer).merge_dims(1, 2)
  answer_lengths = tf.cast(answer_token_ids.row_lengths(), tf.int32)
  answer_token_ids = tf.cast(answer_token_ids.to_tensor(), tf.int32)

  cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
  sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))
  concat_inputs = orqa_ops.reader_inputs(
      question_token_ids=question_token_ids,
      block_token_ids=block_token_ids,
      block_lengths=block_lengths,
      block_token_map=block_token_map,
      answer_token_ids=answer_token_ids,
      answer_lengths=answer_lengths,
      cls_token_id=tf.cast(cls_token_id, tf.int32),
      sep_token_id=tf.cast(sep_token_id, tf.int32),
      max_sequence_len=params["reader_seq_len"])

  tf.summary.scalar("reader_nonpad_ratio",
                    tf.reduce_mean(tf.cast(concat_inputs.mask, tf.float32)))

  reader_module = hub.Module(
      params["reader_module_path"],
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)

  concat_outputs = reader_module(
      dict(
          input_ids=concat_inputs.token_ids,
          input_mask=concat_inputs.mask,
          segment_ids=concat_inputs.segment_ids),
      signature="tokens",
      as_dict=True)

  with tf.variable_scope("pooler"):
    # [reader_beam_size, hidden_size]
    first_token_tensor = concat_outputs["pooled_output"]

    kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)

    # [reader_beam_size, num_classes]
    reader_logits = tf.layers.dense(
        first_token_tensor,
        params["num_classes"],
        kernel_initializer=kernel_initializer)

  final_logits = reader_logits + tf.expand_dims(retriever_logits, -1)

  return final_logits
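
# For reference, the final broadcast above adds each block's retriever score to
# all of that block's class logits. A tiny numeric sketch with made-up values
# (reader_beam_size=2, num_classes=3), equivalent to
# reader_logits + tf.expand_dims(retriever_logits, -1):
import numpy as np

reader_logits_example = np.array([[0.2, -0.1, 0.5],
                                  [1.0, 0.3, -0.4]], dtype=np.float32)
retriever_logits_example = np.array([2.0, -1.0], dtype=np.float32)
final_logits_example = reader_logits_example + retriever_logits_example[:, None]
# final_logits_example:
# [[ 2.2  1.9  2.5]
#  [ 0.  -0.7 -1.4]]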