コード例 #1
0
ファイル: t5_agent_lib.py プロジェクト: greck2908/language
def make_query(base: str, addition_terms: Sequence[Term],
               subtraction_terms: Sequence[Term],
               or_terms: Sequence[Term]) -> str:
    """Modifies a `base`-query with addition, subtraction and "or" terms.

  The term sequences are consumed position by position: for every index `i`
  the subtraction, addition and "or" term at `i` are rendered (in that order)
  and joined with spaces.  The sequences may differ in length; a sequence that
  runs out simply contributes nothing for the remaining positions instead of
  truncating the longer sequences.

  Args:
    base:  Base query, as is, i.e. without manual escaping nor operators.
    addition_terms:  Sequence of terms which will be added to `base` with the
      "+" operator.
    subtraction_terms:  Sequence of terms which will be added to `base` with the
      "-" operator.
    or_terms: Sequence of terms which will be added to `base` directly without
      any operator.

  Returns:
    The Lucene-escaped `base` followed by the rendered terms, with leading and
    trailing whitespace stripped, ready to be issued to lucene.
  """
    import itertools  # Function-scope so the block does not rely on file imports.

    escaped_base = utils.escape_for_lucene(base)

    # Private padding marker: distinguishes "this sequence is exhausted" from
    # any real element, so aligned (equal-length) input renders exactly as
    # before.
    exhausted = object()

    modifications = []
    # `zip` would silently drop every term past the end of the shortest
    # sequence; `zip_longest` keeps them all.
    for subtraction_term, addition_term, or_term in itertools.zip_longest(
            subtraction_terms, addition_terms, or_terms, fillvalue=exhausted):
        rendered = [
            _term_to_str(term, operator=op)
            for term, op in ((subtraction_term, "-"),
                             (addition_term, "+"),
                             (or_term, ""))
            if term is not exhausted
        ]
        # Drop falsy renderings (presumably empty strings from `_term_to_str`)
        # so they do not introduce stray separators within this position.
        modifications.append(" ".join(filter(None, rendered)))

    full_query = f'{escaped_base} {" ".join(modifications)}'
    return full_query.strip()
コード例 #2
0
def to_lucene_query(base_query: str,
                    adjustments: Optional[List[QueryAdjustment]] = None,
                    escape_query: bool = True) -> str:
    """Builds a valid lucene query from a base query and query adjustments.

  Args:
    base_query: str, The base query which is prepended to the new query.
    adjustments: List[QueryAdjustment], List of all the adjustments
      (=reformulations) applied to the base query. The terms in the adjustments
      can be templates or words directly.  Adjustments whose operator is not
      MINUS, PLUS or APPEND contribute nothing to the query.
    escape_query:  If truthy, escapes `base_query`.  Needed for Lucene, not
      needed for running against WebSearch.  Adjustment terms are always
      escaped, independently of this flag.

  Returns:
    A valid and escaped lucene query.
  """
    # A conditional instead of a dict keyed on True/False: the dict lookup
    # raised KeyError for non-bool flags such as None or "".
    lucene_query = (utils.escape_for_lucene(base_query)
                    if escape_query else base_query)

    for adjustment in adjustments or []:
        # `term` needs to be properly escaped if it contains any special
        # characters; this is done even when `escape_query` is False.
        term = utils.escape_for_lucene(adjustment.term)

        # Strip surrounding '[' / ']' characters from the field's value
        # (presumably a bracketed token such as "[contents]" -- confirm
        # against the `Field` enum).
        field = adjustment.field.value.strip('[]')
        if adjustment.operator == Operator.MINUS:
            lucene_query += f' -({field}:"{term}")'
        elif adjustment.operator == Operator.PLUS:
            lucene_query += f' +({field}:"{term}")'
        elif adjustment.operator == Operator.APPEND:
            if adjustment.field == Field.ALL:
                # No field restriction: append the bare (unquoted) term.
                lucene_query += f' {term}'
            else:
                lucene_query += f' ({field}:"{term}")'
        # Any other operator is ignored.

    return lucene_query
コード例 #3
0
 def test_escape_for_lucene(self):
     # ':' and '-' in the input must each come back prefixed by a backslash.
     expected = "foo\\:bar\\-baz"
     actual = utils.escape_for_lucene("foo:bar-baz")
     self.assertEqual(actual, expected)
コード例 #4
0
    def reset(
        self,
        index: Optional[Union[int, environment_pb2.GetQueryResponse]] = None,
        documents: Optional[Sequence[environment_pb2.Document]] = None
    ) -> Tuple[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Mapping[
            str, Any]]:
        """Resets the environment to play a query.

    Builds a fresh `EnvState` for the query, issues the (escaped) original
    query, optionally fetches relevance-feedback target documents, clears the
    per-episode caches and returns the initial observation.

    Args:
      index:  If set as an integer, requests this specific query from the
        environment.  If set as a `GetQueryResponse`, resets for this query.  If
        `None`, a random query will be requested from the environment.
      documents:  If set, uses these results instead of actually querying the
        environment.  Useful during pre-training.

    Returns:
      This function changes the state of the environment, as the name suggests.
      It also returns a tuple of (Observation, InfoDict), where `Observation`
      is the observation for the agent at the beginning of the episode (after
      the original query has been issued).

    Raises:
      mzcore.SkipEpisode: If the original query yields no documents.
    """

        # An int (or None) selects a query from the environment; anything else
        # is taken to already be the query response itself.
        if index is None or isinstance(index, int):
            query = self._get_query(index)
        else:
            query = index
        self._logging_info('Original Query [%d/%d]: %s | gold answers: %s.',
                           query.index, query.total, query.query,
                           query.gold_answer)
        # Start a brand-new episode state; any previous state is discarded.
        self.state = types.EnvState(original_query=query,
                                    k=common_flags.K.value)
        # Issue the escaped original query (or use the supplied `documents`)
        # and record the result as the first history entry.
        self.state.add_history_entry(
            self._get_env_output(query=utils.escape_for_lucene(query.query),
                                 original_query=query,
                                 documents=documents))

        # Relevance-feedback restriction (training only, and only when a gold
        # answer exists): fetch "target" documents with a query that requires
        # the first gold answer as a phrase in the `contents` field.
        if (common_flags.RELEVANCE_FEEDBACK_RESTRICT.value == 1
                and self.training and query.gold_answer):
            target_query = (
                f'{utils.escape_for_lucene(query.query)} '
                f'+(contents:"{utils.escape_for_lucene(query.gold_answer[0])}")'
            )
            self.state.target_documents = self._get_env_output(
                query=target_query, original_query=query).documents

            # Keep the target documents only if their reward is at least the
            # initial reward; otherwise drop them.
            if self._target_reward() < self._initial_reward():
                self.state.target_documents = None

        # Signal we have not yet built the trie for these results; the word
        # tries, valid word actions and RF-restrict flag are per-episode caches
        # (presumably rebuilt lazily from the new results -- confirm).
        self.known_word_tries = None
        self.valid_word_actions = None
        self.use_rf_restrict = False

        # NOTE: rebinds the `documents` parameter to the results stored in the
        # latest history entry (presumably the supplied documents when given).
        documents = self.state.history[-1].documents
        # If we do not have documents at this point, we have to skip the episode.
        if not documents:
            raise mzcore.SkipEpisode(
                f'No documents for original query {query.query}')

        self.state.tree = state_tree.NQStateTree(grammar=self.grammar)

        # history[0] is presumably the original-query entry added above.
        self._logging_info(
            'Initial result score: %s',
            self._compute_reward(self.state.history[0].documents))

        self.action_history = []
        self.n_episode += 1
        obs = self._obs()
        info = self._info_dict(obs=obs)
        return obs.observation, info