Пример #1
0
 def test_history_state_template(self):
     """Each document yields an `answer [SEP] context [SEP] title [SEP]` fragment."""
     flag_overrides = get_updated_default_flags(context_window_size=3,
                                                context_title_size=1)
     with flagsaver.flagsaver(**flag_overrides):
         docs = [
             environment_pb2.Document(
                 content='this is a high term',
                 answer=environment_pb2.Answer(answer='high', mr_score=42.0),
                 title='high title'),
             environment_pb2.Document(
                 content='this is a low term instead',
                 answer=environment_pb2.Answer(answer='low', mr_score=1.0),
                 title='low title'),
         ]
         # Words not listed here fall back to an IDF of 0.0.
         idf_table = collections.defaultdict(float, {
             'high': 10.0,
             'low': 5.0,
             'term': 7.5,
             'title': 8.0,
         })
         actual = bert_state_lib.history_state_part(
             documents=docs,
             tokenize_fn=self.tokenizer.tokenize,
             idf_lookup=idf_table,
             context_size=3,
             max_length=128,
             max_title_length=1)

     def expected_fragment(word, mr_score, word_idf):
         # Layout: 1-token answer [SEP] 3-token context window [SEP]
         # 1-token (truncated) title [SEP].
         return nqutils.ObsFragment(
             text=nqutils.Text(tokens=[
                 word, '[SEP]', 'a', word, 'term', '[SEP]', word, '[SEP]'
             ]),
             type_values={
                 'state_part': (['history_answer', '[SEP]'] +
                                ['history_context'] * 3 + ['[SEP]'] +
                                ['history_title', '[SEP]']),
             },
             float_values={
                 'mr_score': [mr_score] * 8,
                 'idf_score': [
                     word_idf, 0.0, 0.0, word_idf, 7.5, 0.0, word_idf, 0.0
                 ],
             })

     self.assertEqual(actual, [
         expected_fragment('high', 42.0, 10.0),
         expected_fragment('low', 1.0, 5.0),
     ])
Пример #2
0
 def test_original_query_state_template(self):
     """The original query is framed as `[CLS] query-tokens [SEP]`."""
     actual = bert_state_lib.original_query_state_part(
         query='who is on the cover of ufc 2',
         tokenize_fn=self.tokenizer.tokenize,
         idf_lookup=self.idf_lookup)

     query_words = 'who is on the cover of ufc 2'.split()
     expected = nqutils.ObsFragment(
         text=nqutils.Text(tokens=['[CLS]'] + query_words + ['[SEP]']),
         type_values={
             'state_part': (['[CLS]'] + ['original_query'] * len(query_words) +
                            ['[SEP]']),
         },
         float_values={
             'mr_score': [0.0] * (len(query_words) + 2),
             'idf_score': [
                 0.0,  # [CLS]
                 0.0,  # who
                 0.0,  # is
                 0.0,  # on
                 0.0,  # the
                 5.25,  # cover
                 0.0,  # of
                 8.5,  # ufc
                 3.5,  # 2
                 0.0,  # [SEP]
             ],
         })
     self.assertEqual(actual, expected)
Пример #3
0
def history_state_part(documents: Sequence[environment_pb2.Document],
                       tokenize_fn: TokenizeFn, idf_lookup: IdfLookupFn,
                       context_size: int, max_length: int,
                       max_title_length: int) -> List[nqutils.ObsFragment]:
    """Computes the history part of the BERT state.

    Each processed document becomes one `ObsFragment` laid out as
    `answer [SEP] context-window [SEP] title [SEP]`. The context window is the
    text returned by `HistoryEntry.get_window_around_substr` around the
    answer string, and the title is truncated to `max_title_length` tokens.

    Args:
      documents: Documents to convert, processed in order.
      tokenize_fn: Tokenizer applied to the answer, context window and title.
      idf_lookup: Word -> IDF mapping, indexed with `[]` for every word, so it
        must supply a value for unknown words (e.g. a `defaultdict`).
      context_size: Window size passed to `get_window_around_substr`.
      max_length: Processing stops once the fragments built so far hold at
        least this many tokens in total. The last fragment may push the total
        past `max_length`; no truncation happens here.
      max_title_length: Maximum number of title tokens kept.

    Returns:
      One fragment per processed document, in input order.
    """
    history_fragments = []
    total_tokens = 0

    for doc in documents:
        # Resolve the answer once so that a missing answer is handled the same
        # way for the answer tokens, the context window and the mr_score.
        # Previously only the tokenization was guarded, and the later
        # `doc.answer.answer` / `doc.answer.mr_score` accesses would crash.
        if doc.answer:
            answer_text = doc.answer.answer
            mr_score = doc.answer.mr_score
        else:
            answer_text = ''
            mr_score = 0.0

        answer_tokens = tokenize_fn(answer_text)
        _, context_window = types.HistoryEntry.get_window_around_substr(
            doc.content, answer_text, context_size)
        context_tokens = tokenize_fn(context_window)
        title_tokens = tokenize_fn(doc.title if doc.title else '')
        title_tokens = title_tokens[:max_title_length]

        # Each section contributes its tokens followed by a '[SEP]' that is
        # labelled '[SEP]' and has an IDF of 0.
        sections = (
            ('history_answer', answer_tokens),
            ('history_context', context_tokens),
            ('history_title', title_tokens),
        )
        tokens = []
        state_part = []
        idf_scores = []
        for part_name, part_tokens in sections:
            tokens.extend(part_tokens)
            tokens.append('[SEP]')
            state_part.extend([part_name] * len(part_tokens))
            state_part.append('[SEP]')
            idf_scores.extend(
                idf_lookup[word]
                for word in nqutils.bert_tokens_to_words(part_tokens))
            idf_scores.append(0.)

        history_fragments.append(
            nqutils.ObsFragment(
                text=nqutils.Text(tokens=tokens),
                type_values={
                    'state_part': state_part,
                },
                float_values={
                    # The document's answer score is repeated on every
                    # position of its fragment, separators included.
                    'mr_score': [mr_score] * len(tokens),
                    'idf_score': idf_scores,
                },
            ))

        # Stop early if we already have enough history to fill the whole
        # state. A running total avoids re-summing every fragment each time.
        total_tokens += len(tokens)
        if total_tokens >= max_length:
            logging.info('BERT state reached max_length: %d', max_length)
            break

    return history_fragments
Пример #4
0
def state_tree_state_part(tree: state_tree.NQStateTree,
                          idf_lookup: IdfLookupFn) -> nqutils.ObsFragment:
    """Builds the BERT-state fragment holding the state tree's leaves.

    The fragment is the leaf tokens of `tree.root` followed by one '[SEP]'.
    Leaf positions are labelled 'tree_leaves' and the separator '[SEP]'. Every
    position gets an mr_score of 0.0. Leaf positions get the IDF of their
    word (via `nqutils.bert_tokens_to_words`), and the separator gets 0.

    Args:
      tree: State tree whose `root.leaves()` supplies the tokens.
      idf_lookup: Word -> IDF mapping, indexed with `[]`.

    Returns:
      The state-tree `ObsFragment`.
    """
    leaf_tokens = tree.root.leaves()
    leaf_count = len(leaf_tokens)

    leaf_idfs = [
        idf_lookup[word] for word in nqutils.bert_tokens_to_words(leaf_tokens)
    ]
    leaf_idfs.append(0.)  # trailing [SEP]

    framed_tokens = list(leaf_tokens)
    framed_tokens.append('[SEP]')

    labels = ['tree_leaves'] * leaf_count
    labels.append('[SEP]')

    return nqutils.ObsFragment(
        text=nqutils.Text(tokens=framed_tokens),
        type_values={'state_part': labels},
        float_values={
            'mr_score': [0.0] * (leaf_count + 1),
            'idf_score': leaf_idfs,
        },
    )
Пример #5
0
def original_query_state_part(query: str, tokenize_fn: TokenizeFn,
                              idf_lookup: IdfLookupFn) -> nqutils.ObsFragment:
    """Builds the BERT-state fragment for the user's original query.

    The query tokens are framed as `[CLS] tokens [SEP]`. Query positions are
    labelled 'original_query', and the two frame tokens are labelled with
    themselves. Every position gets an mr_score of 0.0. Query positions get
    the IDF of their word (via `nqutils.bert_tokens_to_words`), and both frame
    tokens get 0.

    Args:
      query: Raw query string.
      tokenize_fn: Tokenizer applied to `query`.
      idf_lookup: Word -> IDF mapping, indexed with `[]`.

    Returns:
      The original-query `ObsFragment`.
    """
    query_tokens = tokenize_fn(query)
    token_count = len(query_tokens)
    query_words = nqutils.bert_tokens_to_words(query_tokens)

    framed_tokens = ['[CLS]'] + list(query_tokens) + ['[SEP]']
    labels = ['[CLS]'] + ['original_query'] * token_count + ['[SEP]']
    idf_scores = [0.] + [idf_lookup[word] for word in query_words] + [0.]

    return nqutils.ObsFragment(
        text=nqutils.Text(tokens=framed_tokens),
        type_values={'state_part': labels},
        float_values={
            'mr_score': [0.0] * (token_count + 2),
            'idf_score': idf_scores,
        },
    )
Пример #6
0
    def test_state_tree_state_part(self):
        """The state-tree fragment follows the tree's leaves as rules apply."""
        grammar_rules = [
            "_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_",
            '_W_ -> _Vw_',
            '_W_ -> _Vw_ _Vsh_',
            "_Vw_ -> 'cov'",
            "_Vw_ -> 'out'",
            "_Vsh_ -> '##er'",
        ]
        tree = state_tree.NQStateTree(
            grammar=state_tree.NQCFG(' \n '.join(grammar_rules)))
        tree.grammar.set_start(tree.grammar.productions()[0].lhs())

        def assert_leaves(leaf_text, idf_scores=None):
            # Checks the fragment for the tree's *current* leaves. Defaults to
            # all-zero IDF scores (leaves plus the trailing [SEP]).
            leaves = leaf_text.split()
            if idf_scores is None:
                idf_scores = [0.0] * (len(leaves) + 1)
            self.assertEqual(
                bert_state_lib.state_tree_state_part(
                    tree=tree, idf_lookup=self.idf_lookup),
                nqutils.ObsFragment(
                    text=nqutils.Text(tokens=leaves + ['[SEP]']),
                    type_values={
                        'state_part': ['tree_leaves'] * len(leaves) + ['[SEP]'],
                    },
                    float_values={
                        'mr_score': [0.0] * (len(leaves) + 1),
                        'idf_score': idf_scores,
                    }))

        apply_productions(
            tree, ["_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_"])
        assert_leaves('[neg] [title] [pos] [contents]')

        # Does not change the leaves.
        apply_productions(tree, ['_W_ -> _Vw_'])
        assert_leaves('[neg] [title] [pos] [contents]')

        apply_productions(tree, ["_Vw_ -> 'out'"])
        assert_leaves('[neg] [title] out [pos] [contents]')

        # Does not change the leaves.
        apply_productions(tree, ['_W_ -> _Vw_ _Vsh_'])
        assert_leaves('[neg] [title] out [pos] [contents]')

        apply_productions(tree, ["_Vw_ -> 'cov'"])
        assert_leaves('[neg] [title] out [pos] [contents] cov')

        apply_productions(tree, ["_Vsh_ -> '##er'"])
        #                                                 cov   ##er
        assert_leaves('[neg] [title] out [pos] [contents] cov ##er',
                      [0.0] * 5 + [5.25, 5.25, 0])