def test_history_state_template(self):
    """history_state_part emits one ObsFragment per document.

    Each fragment interleaves answer / context-window / title tokens with
    '[SEP]' markers, broadcasts the document's mr_score over all tokens,
    and looks up per-word idf scores (0.0 at separators).
    """
    # Shrink context/title windows so the expected fragments stay tiny.
    new_flags = get_updated_default_flags(context_window_size=3, context_title_size=1)
    with flagsaver.flagsaver(**new_flags):
        history_state_part = bert_state_lib.history_state_part(
            documents=[
                environment_pb2.Document(
                    content='this is a high term',
                    answer=environment_pb2.Answer(
                        answer='high', mr_score=42.0),
                    title='high title'),
                environment_pb2.Document(
                    content='this is a low term instead',
                    answer=environment_pb2.Answer(answer='low', mr_score=1.0),
                    title='low title'),
            ],
            tokenize_fn=self.tokenizer.tokenize,
            # idf values for words appearing in the fixture documents;
            # any other word defaults to 0.0.
            idf_lookup=collections.defaultdict(float, (
                ('high', 10.0),
                ('low', 5.0),
                ('term', 7.5),
                ('title', 8.0),
            )),
            context_size=3,
            max_length=128,
            max_title_length=1)
        self.assertEqual(history_state_part, [
            nqutils.ObsFragment(
                text=nqutils.Text(
                    tokens='high [SEP] a high term [SEP] high [SEP]'.split()),
                type_values={
                    'state_part': ['history_answer', '[SEP]'] +
                                  ['history_context'] * 3 + ['[SEP]'] +
                                  ['history_title'] + ['[SEP]']
                },
                float_values={
                    # mr_score is broadcast over every token in the fragment.
                    'mr_score': [42.0] * 8,
                    'idf_score': [10.0, 0.0, 0.0, 10.0, 7.5, 0.0, 10.0, 0.0],
                }),
            nqutils.ObsFragment(
                text=nqutils.Text(
                    tokens='low [SEP] a low term [SEP] low [SEP]'.split()),
                type_values={
                    'state_part': ['history_answer', '[SEP]'] +
                                  ['history_context'] * 3 + ['[SEP]'] +
                                  ['history_title'] + ['[SEP]']
                },
                float_values={
                    'mr_score': [1.0] * 8,
                    'idf_score': [5.0, 0.0, 0.0, 5.0, 7.5, 0.0, 5.0, 0.0],
                }),
        ])
def test_original_query_state_template(self):
    """original_query_state_part wraps the query in [CLS] ... [SEP] markers.

    mr_score is zero for the query part; idf scores come from
    self.idf_lookup per word (0.0 at the [CLS]/[SEP] markers).
    """
    self.assertEqual(
        bert_state_lib.original_query_state_part(
            query='who is on the cover of ufc 2',
            tokenize_fn=self.tokenizer.tokenize,
            idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[CLS] who is on the cover of ufc 2 [SEP]'.split()),
            type_values={
                'state_part': ['[CLS]'] + ['original_query'] * 8 + ['[SEP]'],
            },
            float_values={
                # 8 query tokens plus the two marker tokens.
                'mr_score': [0.0] * (8 + 2),
                'idf_score': [
                    0.0,   # [CLS]
                    0.0,   # who
                    0.0,   # is
                    0.0,   # on
                    0.0,   # the
                    5.25,  # cover
                    0.0,   # of
                    8.5,   # UFC
                    3.5,   # 2
                    0.0    # [SEP]
                ]
            }))
def history_state_part(documents: Sequence[environment_pb2.Document],
                       tokenize_fn: TokenizeFn, idf_lookup: IdfLookupFn,
                       context_size: int, max_length: int,
                       max_title_length: int) -> List[nqutils.ObsFragment]:
    """Computes the history part of the BERT state.

    For each document this builds one ObsFragment holding three
    '[SEP]'-terminated segments — the answer tokens, a context window of
    `context_size` around the answer, and the title truncated to
    `max_title_length` tokens. Each token carries the document's mr_score
    and a per-word idf score (0.0 at separators). Iteration stops early
    once the accumulated token count reaches `max_length`.
    """
    fragments = []
    total_tokens = 0
    for doc in documents:
        answer_tokens = tokenize_fn(doc.answer.answer if doc.answer else '')
        _, window = types.HistoryEntry.get_window_around_substr(
            doc.content, doc.answer.answer, context_size)
        context_tokens = tokenize_fn(window)
        title_tokens = tokenize_fn(doc.title if doc.title else '')
        title_tokens = title_tokens[:max_title_length]

        # Each segment is (tokens, state_part label); a '[SEP]' follows each.
        segments = [
            (answer_tokens, 'history_answer'),
            (context_tokens, 'history_context'),
            (title_tokens, 'history_title'),
        ]

        all_tokens = []
        state_part = []
        idf_scores = []
        for tokens, label in segments:
            all_tokens.extend(tokens)
            all_tokens.append('[SEP]')
            state_part.extend([label] * len(tokens))
            state_part.append('[SEP]')
            # idf is looked up on whole words reassembled from wordpieces;
            # the trailing separator scores 0.
            idf_scores.extend(
                idf_lookup[word]
                for word in nqutils.bert_tokens_to_words(tokens))
            idf_scores.append(0.)

        fragments.append(
            nqutils.ObsFragment(
                text=nqutils.Text(tokens=all_tokens),
                type_values={
                    'state_part': state_part,
                },
                float_values={
                    'mr_score': [doc.answer.mr_score] * len(all_tokens),
                    'idf_score': idf_scores,
                },
            ))
        total_tokens += len(all_tokens)

        # Stop early if we already have enough history to fill up the state.
        if total_tokens >= max_length:
            logging.info('BERT state reached max_length: %d', max_length)
            break
    return fragments
def state_tree_state_part(tree: state_tree.NQStateTree,
                          idf_lookup: IdfLookupFn) -> nqutils.ObsFragment:
    """Computes the state tree part of the BERT state.

    Emits the tree's current leaf tokens followed by a single '[SEP]';
    mr_score is zero throughout and idf scores are looked up per word
    (0.0 at the separator).
    """
    leaf_tokens = tree.root.leaves()
    num_leaves = len(leaf_tokens)
    idf_scores = [
        idf_lookup[word]
        for word in nqutils.bert_tokens_to_words(leaf_tokens)
    ]
    idf_scores.append(0.)
    return nqutils.ObsFragment(
        text=nqutils.Text(tokens=list(leaf_tokens) + ['[SEP]']),
        type_values={
            'state_part': ['tree_leaves'] * num_leaves + ['[SEP]'],
        },
        float_values={
            'mr_score': [0.0] * (num_leaves + 1),
            'idf_score': idf_scores,
        },
    )
def original_query_state_part(query: str, tokenize_fn: TokenizeFn,
                              idf_lookup: IdfLookupFn) -> nqutils.ObsFragment:
    """Computes the original query part of the BERT state.

    Emits '[CLS]' + query tokens + '[SEP]'; mr_score is zero throughout
    and idf scores are looked up per word (0.0 at the markers).
    """
    query_tokens = tokenize_fn(query)
    num_query_tokens = len(query_tokens)
    words = nqutils.bert_tokens_to_words(query_tokens)
    idf_scores = [0.] + [idf_lookup[word] for word in words] + [0.]
    return nqutils.ObsFragment(
        text=nqutils.Text(tokens=['[CLS]'] + list(query_tokens) + ['[SEP]']),
        type_values={
            'state_part':
                ['[CLS]'] + ['original_query'] * num_query_tokens + ['[SEP]'],
        },
        float_values={
            # [CLS] and [SEP] account for the +2.
            'mr_score': [0.0] * (num_query_tokens + 2),
            'idf_score': idf_scores,
        },
    )
def test_state_tree_state_part(self):
    """state_tree_state_part reflects the tree's leaves as productions apply.

    Applies grammar productions one at a time and checks that the emitted
    fragment tracks exactly the terminal leaves (nonterminal expansions
    that introduce no terminals leave the fragment unchanged).
    """
    tree = state_tree.NQStateTree(grammar=state_tree.NQCFG(' \n '.join([
        "_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_",
        '_W_ -> _Vw_',
        '_W_ -> _Vw_ _Vsh_',
        "_Vw_ -> 'cov'",
        "_Vw_ -> 'out'",
        "_Vsh_ -> '##er'",
    ])))
    tree.grammar.set_start(tree.grammar.productions()[0].lhs())
    apply_productions(
        tree, ["_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_"])
    # Only the terminal symbols appear as leaves; _W_ is still unexpanded.
    self.assertEqual(
        bert_state_lib.state_tree_state_part(tree=tree,
                                             idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[neg] [title] [pos] [contents] [SEP]'.split()),
            type_values={
                'state_part': ['tree_leaves'] * 4 + ['[SEP]'],
            },
            float_values={
                'mr_score': [0.0] * 5,
                'idf_score': [0.0] * 5,
            }))
    # Does not change the leaves.
    apply_productions(tree, ['_W_ -> _Vw_'])
    self.assertEqual(
        bert_state_lib.state_tree_state_part(tree=tree,
                                             idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[neg] [title] [pos] [contents] [SEP]'.split()),
            type_values={
                'state_part': ['tree_leaves'] * 4 + ['[SEP]'],
            },
            float_values={
                'mr_score': [0.0] * 5,
                'idf_score': [0.0] * 5,
            }))
    apply_productions(tree, ["_Vw_ -> 'out'"])
    # 'out' becomes a new terminal leaf.
    self.assertEqual(
        bert_state_lib.state_tree_state_part(tree=tree,
                                             idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] [SEP]'.split()),
            type_values={
                'state_part': ['tree_leaves'] * 5 + ['[SEP]'],
            },
            float_values={
                'mr_score': [0.0] * 6,
                'idf_score': [0.0] * 6,
            }))
    # Does not change the leaves.
    apply_productions(tree, ['_W_ -> _Vw_ _Vsh_'])
    self.assertEqual(
        bert_state_lib.state_tree_state_part(tree=tree,
                                             idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] [SEP]'.split()),
            type_values={
                'state_part': ['tree_leaves'] * 5 + ['[SEP]'],
            },
            float_values={
                'mr_score': [0.0] * 6,
                'idf_score': [0.0] * 6,
            }))
    apply_productions(tree, ["_Vw_ -> 'cov'"])
    self.assertEqual(
        bert_state_lib.state_tree_state_part(tree=tree,
                                             idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] cov [SEP]'.split()),
            type_values={
                'state_part': ['tree_leaves'] * 6 + ['[SEP]'],
            },
            float_values={
                'mr_score': [0.0] * 7,
                'idf_score': [0.0] * 7,
            }))
    apply_productions(tree, ["_Vsh_ -> '##er'"])
    # The wordpieces 'cov ##er' form the word 'cover', so both pieces get
    # that word's idf score.
    self.assertEqual(
        bert_state_lib.state_tree_state_part(tree=tree,
                                             idf_lookup=self.idf_lookup),
        nqutils.ObsFragment(
            text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] cov ##er [SEP]'.
                split()),
            type_values={
                'state_part': ['tree_leaves'] * 7 + ['[SEP]'],
            },
            float_values={
                'mr_score': [0.0] * 8,
                # cov ##er
                'idf_score': [0.0] * 5 + [5.25, 5.25, 0],
            }))