コード例 #1
0
def create_env(max_history_entries: int):
    """Return an `env.NQEnv` test fixture preloaded with a canned history.

    The environment is constructed while the flag overrides returned by
    `get_updated_default_flags()` are active; the `flagsaver` context is
    exited (and the previous flag values restored) when this function returns.

    Args:
      max_history_entries: Slice bound applied to the three canned history
        entries; only `history[:max_history_entries]` is kept.

    Returns:
      An `env.NQEnv` with `nq_server=None`, `training=False`,
      `stop_after_seeing_new_results=True`, an empty original query, and
      `state.history` set to the selected canned entries.
    """
    # Positional `create_document` arguments: one inner tuple per document,
    # one group of three documents per history entry (in history order).
    step_documents = (
        (('1', 1), ('2', 2), ('3', 3)),
        (('3', 3), ('4', 4), ('0', 0)),
        (('5', 5), ('3', 3), ('-1', -1)),
    )
    flag_overrides = get_updated_default_flags()
    with flagsaver.flagsaver(**flag_overrides):
        empty_query = environment_pb2.GetQueryResponse()
        environment = env.NQEnv(
            nq_server=None,
            state=types.EnvState(original_query=empty_query),
            training=False,
            stop_after_seeing_new_results=True)
        # All three entries are always built; truncation happens afterwards.
        canned_history = [
            types.HistoryEntry(
                query='query',
                original_query='original_query',
                documents=[create_document(*spec) for spec in specs])
            for specs in step_documents
        ]
        environment.state.history = canned_history[:max_history_entries]
        return environment
コード例 #2
0
  def process(self, element, *args, **kwargs):
    """Runs `_RUNS_PER_QUERY` episodes for one JSON-encoded query.

    We really want results for every episode, so a failing episode is retried
    up to `_NUM_RETRIES` times (at most `_NUM_RETRIES + 1` attempts per run).
    Runs that still fail are logged and emitted on the `failed` tagged output,
    so a smaller follow-up job can rerun them and join the data.

    Args:
      element: JSON string containing at least the keys `question` and
        `answer`.
      *args: Unused.
      **kwargs: Unused.

    Yields:
      For every successful run, a JSON string with keys `query`, `state`,
      `metrics` and `gold_answer`. For every permanently failed run, a
      `beam.pvalue.TaggedOutput` with tag `failed` wrapping the raw `element`.
    """
    element_json = json.loads(element)
    test_query = environment_pb2.GetQueryResponse(
        query=element_json['question'], gold_answer=element_json['answer'])

    for _ in range(_RUNS_PER_QUERY.value):
      self.started.inc()

      attempts = 0
      # Created once per run, so it is shared by all retries of that run.
      initial_inference_cache = []
      while True:
        attempts += 1
        # Note:  Any interaction with the environment might fail, even
        #        resetting an episode.  So everything that touches the
        #        environment (including reading back its state) stays inside
        #        the try-body.
        try:
          episode = self.mzconfig.new_episode(
              environment=self.environment, index=test_query)
          run_episode(
              episode=episode,
              mzconfig=self.mzconfig,
              initial_inference_service=self.initial_inference_stub,
              recurrent_inference_service=self.recurrent_inference_stub,
              initial_inference_cache=initial_inference_cache,
              counter=self.cache_hit)
          metrics_dict = self.environment.special_episode_statistics_learner(
              return_as_dict=True)
          state_dict = self.environment.state.json_repr()
          output = {
              'query': self.environment.state.original_query.query,
              'state': state_dict,
              'metrics': metrics_dict,
              'gold_answer': element_json['answer'],
          }
        except (grpc.RpcError, core.RLEnvironmentError) as e:
          if attempts > _NUM_RETRIES.value:
            logging.info('Episode permanently failed: %s', e)
            self.failed.inc()
            yield beam.pvalue.TaggedOutput(tag='failed', value=element)
            break
          logging.info('Episode failed: %s, retrying', e)
          self.retries.inc()
        else:
          # Only count the run as completed once its full result was built;
          # previously an attempt failing while reading back the state was
          # counted as completed and then retried (double counting).
          self.completed.inc()
          yield json.dumps(output)
          break
コード例 #3
0
    def test_valid_full_words_in_obs(self):
        """Checks the valid-word sets that `_obs()` derives from the history.

        A single history entry with three documents is installed. Its query is
        then rewritten several times (with +/- contents/title operators, with
        plain appended terms, and with both), and the all/question/answer/
        document valid-word sets are re-checked after each rewrite.
        """
        flag_overrides = get_updated_default_flags(num_documents_to_retrieve=3)
        with flagsaver.flagsaver(**flag_overrides):
            question = environment_pb2.GetQueryResponse()
            question.query = 'who was the first emperor of ancient china'
            nqenv = env.NQEnv(
                nq_server=None,
                state=types.EnvState(original_query=question),
                training=False,
                stop_after_seeing_new_results=True)

            retrieved_docs = [
                create_document(content='Chuanqi Huangdi',
                                mr_score=5,
                                answer='Huangdi'),
                create_document(content='Yuan Dynasty Provinces',
                                mr_score=3,
                                answer='Yuan'),
                create_document(content='Huangdi Yinfujing Yellow Emperor',
                                mr_score=-1,
                                answer='Yinfujing'),
            ]
            # Only the original query is in the history.
            nqenv.state.history = [
                types.HistoryEntry(query=question.query,
                                   original_query=question.query,
                                   documents=retrieved_docs)
            ]

            def assert_valid_words(expected_all, expected_question,
                                   expected_answer):
                """Asserts the valid-word sets of a freshly built observation."""
                valid_words = nqenv._obs().valid_words
                all_words = set(valid_words.all_valid_words.full_words)
                self.assertEqual(all_words, expected_all)
                self.assertEqual(
                    set(valid_words.question_valid_words.full_words),
                    expected_question)
                self.assertEqual(
                    set(valid_words.answer_valid_words.full_words),
                    expected_answer)
                # Document words must always equal the overall valid words.
                self.assertEqual(
                    set(valid_words.document_valid_words.full_words),
                    all_words)

            def set_last_query(query):
                """Overwrites the query of the most recent history entry."""
                nqenv.state.history[-1].query = query

            assert_valid_words(
                {'dynasty', 'yinfujing', 'yellow', 'yuan', 'chuanqi',
                 'huangdi', 'provinces', 'emperor'},
                {'first', 'emperor', 'ancient', 'china'},
                {'huangdi', 'yuan', 'yinfujing'})

            # Query includes terms with +/- contents/titles.
            set_last_query(
                'who was the first emperor of ancient china '
                '+(contents:"emperor") -(contents:"provinces") '
                '+(title:"huangdi") -(title:"chinese")')
            assert_valid_words(
                {'dynasty', 'yinfujing', 'yellow', 'yuan', 'chuanqi'},
                {'first', 'ancient', 'china'},
                {'yuan', 'yinfujing'})

            # Query includes only added terms.
            set_last_query(
                'who was the first emperor of ancient china emperor '
                'huangdi')
            assert_valid_words(
                {'dynasty', 'yinfujing', 'yellow', 'yuan', 'chuanqi',
                 'provinces'},
                {'first', 'ancient', 'china'},
                {'yuan', 'yinfujing'})

            # Query includes both +/- contents/titles terms and added terms.
            set_last_query(
                'who was the first emperor of ancient china emperor '
                'huangdi +(contents:"emperor") provinces '
                '-(title:"chuanqi") chinese')
            assert_valid_words(
                {'dynasty', 'yinfujing', 'yellow', 'yuan'},
                {'first', 'ancient', 'china'},
                {'yuan', 'yinfujing'})