Example #1
  def __init__(self,
               nq_server: server.NQServer,
               state: Optional[types.EnvState] = None,
               random_state: Optional[np.random.RandomState] = None,
               training: bool = True,
               stop_after_seeing_new_results: bool = False):
    super().__init__()
    self.nq_server = nq_server
    self.training = training

    self.first_time = True  # Used for initial debug logging

    self.stop_after_seeing_new_results = stop_after_seeing_new_results

    self.descriptor = get_descriptor()
    self.grammar = self.descriptor.extras['grammar']
    self.tokenizer = self.descriptor.extras['tokenizer']
    self.action_space = len(self.grammar.productions())

    self.idf_lookup = utils.IDFLookup.get_instance(
        path=common_flags.IDF_LOOKUP_PATH.value)

    trie_start_time = tf.timestamp()
    if common_flags.GLOBAL_TRIE_PATH.value is None:
      self.global_trie = pygtrie.Trie.fromkeys((x for x in map(
          functools.partial(
              to_action_tuple, grammar=self.grammar, tokenizer=self.tokenizer),
          self.idf_lookup.lookup) if x))
      self._logging_info('Built trie of size %s in %s s', len(self.global_trie),
                         tf.timestamp() - trie_start_time)
    else:
      with tf.io.gfile.GFile(common_flags.GLOBAL_TRIE_PATH.value,
                             'rb') as trie_f:
        self.global_trie = pickle.load(trie_f)
      self._logging_info('Restored trie of size %s in %s s',
                         len(self.global_trie),
                         tf.timestamp() - trie_start_time)

    # Tracks the learner's global step count; updated in step().
    self.training_steps = 0

    # Tries for the current results.  These are built only once, the first time
    # after a new set of results is obtained; `None` means they have not been
    # built yet for the current set of results.
    self.known_word_tries = None  # type: Optional[state_tree.KnownWordTries]
    self.valid_word_actions = None  # type: Optional[state_tree.ValidWordActions]
    self.use_rf_restrict = False

    self.state = state
    if state and state.tree is None:
      self.state.tree = state_tree.NQStateTree(grammar=self.grammar)

    self.bert_config: configs.BertConfig = self.descriptor.extras['bert_config']
    self.sequence_length: int = self.descriptor.extras['sequence_length']
    self.action_history = []
    self.n_episode = 0

    self._rand = np.random.RandomState()
    if random_state:
      # set_state expects the tuple from get_state(), not a RandomState object.
      self._rand.set_state(random_state.get_state())
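A minimal construction sketch for the constructor above. The class name `NQEnv` and the bare `server.NQServer()` call are assumptions made for illustration; only the keyword arguments come from the signature shown.

nq_server = server.NQServer()  # assumed constructor for the results server
env = NQEnv(                   # `NQEnv` is a placeholder name for the class above
    nq_server=nq_server,
    state=None,                # the EnvState is created later, in reset()
    random_state=None,         # defaults to an unseeded np.random.RandomState
    training=True,
    stop_after_seeing_new_results=False)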
Example #2
def compute_bert_state(query: str,
                       documents: Sequence[environment_pb2.Document],
                       idf_lookup: IdfLookupFn, context_size: int,
                       title_size: int, seq_length: int,
                       tokenize_fn: TokenizeFn,
                       tokens_to_ids_fn: TokensToIdsFn) -> Dict[str, Any]:
    """Computes the BERT state as in the agent."""

    obs_fragments = make_bert_state_impl(
        query=query,
        tree=state_tree.NQStateTree(nltk.CFG.fromstring("Q -> 'unused'")),
        documents=documents,
        idf_lookup=idf_lookup,
        tokenize_fn=tokenize_fn,
        context_size=context_size,
        max_length=seq_length,
        max_title_length=title_size)

    token_ids, type_ids, float_values = nqutils.ObsFragment.combine_and_expand(
        fragments=obs_fragments,
        length=seq_length,
        type_vocabs=TYPE_VOCABS,
        float_names=FLOAT_NAMES,
        tokens_to_id_fn=tokens_to_ids_fn,
    )
    token_ids = np.array(token_ids, np.int32)
    type_ids = np.array(type_ids, np.int32).T
    float_values = np.array(float_values, np.float32).T

    return {
        'obs_fragments': obs_fragments,
        'token_ids': token_ids,
        'type_ids': type_ids,
        'float_values': float_values,
    }
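A hedged call sketch for `compute_bert_state`. The toy tokenizer, the constant IDF lookup, and the size values are placeholders chosen for illustration; only the parameter names come from the signature above, and `documents` is assumed to hold already-retrieved `environment_pb2.Document` protos.

def toy_tokenize(text):
    # Whitespace tokenizer standing in for the real WordPiece tokenize_fn.
    return text.lower().split()

def toy_tokens_to_ids(tokens):
    # Deterministic stand-in for the real vocabulary lookup.
    return [abs(hash(token)) % 30000 for token in tokens]

bert_state = compute_bert_state(
    query='who wrote the example passage',
    documents=documents,           # Sequence[environment_pb2.Document]
    idf_lookup=lambda token: 1.0,  # constant IDF, for illustration only
    context_size=64,
    title_size=10,
    seq_length=512,
    tokenize_fn=toy_tokenize,
    tokens_to_ids_fn=toy_tokens_to_ids)
print(sorted(bert_state.keys()))   # ['float_values', 'obs_fragments', 'token_ids', 'type_ids']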
Example #3
    def reset(
        self,
        index: Optional[Union[int, environment_pb2.GetQueryResponse]] = None,
        documents: Optional[Sequence[environment_pb2.Document]] = None
    ) -> Tuple[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Mapping[
            str, Any]]:
        """Resets the environment to play a query.

        Args:
          index:  If set as an integer, requests this specific query from the
            environment.  If set as a `GetQueryResponse`, resets for this
            query.  If `None`, a random query will be requested from the
            environment.
          documents:  If set, uses these results instead of actually querying
            the environment.  Useful during pre-training.

        Returns:
          This function changes the state of the environment, as the name
          suggests.  It also returns a tuple of (Observation, InfoDict), where
          `Observation` is the observation for the agent at the beginning of
          the episode (after the original query has been issued).
        """

        if index is None or isinstance(index, int):
            query = self._get_query(index)
        else:
            query = index
        self._logging_info('Original Query [%d/%d]: %s | gold answers: %s.',
                           query.index, query.total, query.query,
                           query.gold_answer)
        self.state = types.EnvState(original_query=query,
                                    k=common_flags.K.value)
        self.state.add_history_entry(
            self._get_env_output(query=utils.escape_for_lucene(query.query),
                                 original_query=query,
                                 documents=documents))

        if (common_flags.RELEVANCE_FEEDBACK_RESTRICT.value == 1
                and self.training and query.gold_answer):
            target_query = (
                f'{utils.escape_for_lucene(query.query)} '
                f'+(contents:"{utils.escape_for_lucene(query.gold_answer[0])}")'
            )
            self.state.target_documents = self._get_env_output(
                query=target_query, original_query=query).documents

            # Check that the target documents lead to a higher reward;
            # otherwise we drop them.
            if self._target_reward() < self._initial_reward():
                self.state.target_documents = None

        # Signal we have not yet built the trie for these results.
        self.known_word_tries = None
        self.valid_word_actions = None
        self.use_rf_restrict = False

        documents = self.state.history[-1].documents
        # If we do not have documents at this point, we have to skip the episode.
        if not documents:
            raise mzcore.SkipEpisode(
                f'No documents for original query {query.query}')

        self.state.tree = state_tree.NQStateTree(grammar=self.grammar)

        self._logging_info(
            'Initial result score: %s',
            self._compute_reward(self.state.history[0].documents))

        self.action_history = []
        self.n_episode += 1
        obs = self._obs()
        info = self._info_dict(obs=obs)
        return obs.observation, info
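A hedged sketch of one episode driven through `reset`. The `env.step(...)` call, its return signature, and the `agent` policy object are assumptions; only `reset()` returning an (observation, info) pair is taken from the code above.

observation, info = env.reset()    # issues the original query for a random index
done = False
while not done:
    action = agent.select_action(observation)            # placeholder policy
    observation, reward, done, info = env.step(action)   # assumed step() signature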
Example #4
    def test_state_tree_state_part(self):
        tree = state_tree.NQStateTree(grammar=state_tree.NQCFG(' \n '.join([
            "_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_",
            '_W_ -> _Vw_',
            '_W_ -> _Vw_ _Vsh_',
            "_Vw_ -> 'cov'",
            "_Vw_ -> 'out'",
            "_Vsh_ -> '##er'",
        ])))
        tree.grammar.set_start(tree.grammar.productions()[0].lhs())

        apply_productions(
            tree, ["_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_"])

        self.assertEqual(
            bert_state_lib.state_tree_state_part(tree=tree,
                                                 idf_lookup=self.idf_lookup),
            nqutils.ObsFragment(text=nqutils.Text(
                tokens='[neg] [title] [pos] [contents] [SEP]'.split()),
                                type_values={
                                    'state_part':
                                    ['tree_leaves'] * 4 + ['[SEP]'],
                                },
                                float_values={
                                    'mr_score': [0.0] * 5,
                                    'idf_score': [0.0] * 5,
                                }))

        # Does not change the leaves.
        apply_productions(tree, ['_W_ -> _Vw_'])
        self.assertEqual(
            bert_state_lib.state_tree_state_part(tree=tree,
                                                 idf_lookup=self.idf_lookup),
            nqutils.ObsFragment(text=nqutils.Text(
                tokens='[neg] [title] [pos] [contents] [SEP]'.split()),
                                type_values={
                                    'state_part':
                                    ['tree_leaves'] * 4 + ['[SEP]'],
                                },
                                float_values={
                                    'mr_score': [0.0] * 5,
                                    'idf_score': [0.0] * 5,
                                }))

        apply_productions(tree, ["_Vw_ -> 'out'"])
        self.assertEqual(
            bert_state_lib.state_tree_state_part(tree=tree,
                                                 idf_lookup=self.idf_lookup),
            nqutils.ObsFragment(text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] [SEP]'.split()),
                                type_values={
                                    'state_part':
                                    ['tree_leaves'] * 5 + ['[SEP]'],
                                },
                                float_values={
                                    'mr_score': [0.0] * 6,
                                    'idf_score': [0.0] * 6,
                                }))

        # Does not change the leaves.
        apply_productions(tree, ['_W_ -> _Vw_ _Vsh_'])
        self.assertEqual(
            bert_state_lib.state_tree_state_part(tree=tree,
                                                 idf_lookup=self.idf_lookup),
            nqutils.ObsFragment(text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] [SEP]'.split()),
                                type_values={
                                    'state_part':
                                    ['tree_leaves'] * 5 + ['[SEP]'],
                                },
                                float_values={
                                    'mr_score': [0.0] * 6,
                                    'idf_score': [0.0] * 6,
                                }))

        apply_productions(tree, ["_Vw_ -> 'cov'"])
        self.assertEqual(
            bert_state_lib.state_tree_state_part(tree=tree,
                                                 idf_lookup=self.idf_lookup),
            nqutils.ObsFragment(text=nqutils.Text(
                tokens='[neg] [title] out [pos] [contents] cov [SEP]'.split()),
                                type_values={
                                    'state_part':
                                    ['tree_leaves'] * 6 + ['[SEP]'],
                                },
                                float_values={
                                    'mr_score': [0.0] * 7,
                                    'idf_score': [0.0] * 7,
                                }))

        apply_productions(tree, ["_Vsh_ -> '##er'"])
        self.assertEqual(
            bert_state_lib.state_tree_state_part(tree=tree,
                                                 idf_lookup=self.idf_lookup),
            nqutils.ObsFragment(
                text=nqutils.Text(
                    tokens='[neg] [title] out [pos] [contents] cov ##er [SEP]'.
                    split()),
                type_values={
                    'state_part': ['tree_leaves'] * 7 + ['[SEP]'],
                },
                float_values={
                    'mr_score': [0.0] * 8,
                    #                         cov ##er
                    'idf_score': [0.0] * 5 + [5.25, 5.25, 0],
                }))
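A self-contained illustration (plain nltk, no project code) of the grammar the test above expands. It only shows what the production strings parse to, which helps when reading the `apply_productions` calls; `state_tree.NQCFG` itself may behave differently.

import nltk

grammar = nltk.CFG.fromstring('\n'.join([
    "_Q_ -> '[neg]' '[title]' _W_ '[pos]' '[contents]' _W_",
    "_W_ -> _Vw_",
    "_W_ -> _Vw_ _Vsh_",
    "_Vw_ -> 'cov'",
    "_Vw_ -> 'out'",
    "_Vsh_ -> '##er'",
]))
for production in grammar.productions():
    print(production)  # e.g. _Vw_ -> 'cov'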