def create_env(max_history_entries: int):
    """Build an `NQEnv` whose history holds canned entries for tests.

    Three history entries are prepared, each with the query 'query', the
    original query 'original_query' and three documents. Only the first
    `max_history_entries` of them are installed on the environment state.

    Args:
      max_history_entries: Number of leading canned history entries to keep.

    Returns:
      The constructed `env.NQEnv` with `state.history` populated.
    """
    # Each inner tuple lists the scores of the documents for one history entry;
    # every document's first argument is the string form of its score.
    scores_per_entry = (
        (1, 2, 3),
        (3, 4, 0),
        (5, 3, -1),
    )

    flag_overrides = get_updated_default_flags()
    with flagsaver.flagsaver(**flag_overrides):
        empty_query = environment_pb2.GetQueryResponse()
        nqenv = env.NQEnv(
            nq_server=None,
            state=types.EnvState(original_query=empty_query),
            training=False,
            stop_after_seeing_new_results=True)

        canned_history = [
            types.HistoryEntry(
                query='query',
                original_query='original_query',
                documents=[
                    create_document(str(score), score) for score in scores
                ])
            for scores in scores_per_entry
        ]
        nqenv.state.history = canned_history[:max_history_entries]
        return nqenv
def process(self, element, *args, **kwargs):
    """Run episodes for one JSON-encoded query, retrying failed episodes.

    `element` is parsed as JSON and must provide the keys 'question' and
    'answer'. For each of the `_RUNS_PER_QUERY` runs an episode is attempted
    until it succeeds or until more than `_NUM_RETRIES` attempts have failed.

    Args:
      element: JSON string describing the query and its gold answer.
      *args: Unused.
      **kwargs: Unused.

    Yields:
      For a successful run, a JSON string with the keys 'query', 'state',
      'metrics' and 'gold_answer'. For a run that permanently failed, a
      `beam.pvalue.TaggedOutput` tagged 'failed' carrying the raw `element`,
      so failed inputs can be collected separately.
    """
    parsed = json.loads(element)
    test_query = environment_pb2.GetQueryResponse(
        query=parsed['question'], gold_answer=parsed['answer'])

    for _ in range(_RUNS_PER_QUERY.value):
        self.started.inc()
        # The cache list is created once per run and handed to every attempt
        # of that run.
        inference_cache = []
        attempt = 0
        while True:
            attempt += 1
            # Any interaction with the environment may fail, including the
            # creation of a new episode, so all of it lives in the try-body.
            try:
                episode = self.mzconfig.new_episode(
                    environment=self.environment, index=test_query)
                run_episode(
                    episode=episode,
                    mzconfig=self.mzconfig,
                    initial_inference_service=self.initial_inference_stub,
                    recurrent_inference_service=self.recurrent_inference_stub,
                    initial_inference_cache=inference_cache,
                    counter=self.cache_hit)
                metrics = self.environment.special_episode_statistics_learner(
                    return_as_dict=True)
                self.completed.inc()
                state_repr = self.environment.state.json_repr()
                yield json.dumps({
                    'query': self.environment.state.original_query.query,
                    'state': state_repr,
                    'metrics': metrics,
                    'gold_answer': parsed['answer'],
                })
                break
            except (grpc.RpcError, core.RLEnvironmentError) as err:
                if attempt <= _NUM_RETRIES.value:
                    # Still within the retry budget: count it and try again.
                    logging.info('Episode failed: %s, retrying', err)
                    self.retries.inc()
                    continue
                # Retry budget exhausted: emit the input on the 'failed' tag.
                logging.info('Episode permanently failed: %s', err)
                self.failed.inc()
                yield beam.pvalue.TaggedOutput(tag='failed', value=element)
                break
def test_valid_full_words_in_obs(self):
    """Checks the full-word valid-word sets in the observation.

    The environment history holds a single entry with three documents. The
    entry's query is rewritten several times (plain, with +/- contents/title
    operators, with appended terms, and with both), and after each rewrite
    the all/question/answer/document valid-word sets are compared against
    the expected values.
    """
    flag_overrides = get_updated_default_flags(num_documents_to_retrieve=3, )
    with flagsaver.flagsaver(**flag_overrides):
        original_query = environment_pb2.GetQueryResponse(
            query='who was the first emperor of ancient china')
        nqenv = env.NQEnv(
            nq_server=None,
            state=types.EnvState(original_query=original_query),
            training=False,
            stop_after_seeing_new_results=True)

        documents = [
            create_document(
                content='Chuanqi Huangdi', mr_score=5, answer='Huangdi'),
            create_document(
                content='Yuan Dynasty Provinces', mr_score=3, answer='Yuan'),
            create_document(
                content='Huangdi Yinfujing Yellow Emperor',
                mr_score=-1,
                answer='Yinfujing'),
        ]
        # Only the original query is in the history.
        nqenv.state.history = [
            types.HistoryEntry(
                query=original_query.query,
                original_query=original_query.query,
                documents=documents)
        ]

        def check_full_words(expected_all, expected_question, expected_answer):
            # Takes a fresh observation and verifies every full-word set; the
            # document words are expected to coincide with the "all" words.
            obs = nqenv._obs()
            self.assertEqual(
                set(obs.valid_words.all_valid_words.full_words), expected_all)
            self.assertEqual(
                set(obs.valid_words.question_valid_words.full_words),
                expected_question)
            self.assertEqual(
                set(obs.valid_words.answer_valid_words.full_words),
                expected_answer)
            self.assertEqual(
                set(obs.valid_words.document_valid_words.full_words),
                set(obs.valid_words.all_valid_words.full_words))

        check_full_words(
            expected_all={
                'dynasty', 'yinfujing', 'yellow', 'yuan', 'chuanqi', 'huangdi',
                'provinces', 'emperor'
            },
            expected_question={'first', 'emperor', 'ancient', 'china'},
            expected_answer={'huangdi', 'yuan', 'yinfujing'})

        # query includes terms with +/- contents/titles
        nqenv.state.history[-1].query = (
            'who was the first emperor of ancient china '
            '+(contents:"emperor") -(contents:"provinces") '
            '+(title:"huangdi") -(title:"chinese")')
        check_full_words(
            expected_all={'dynasty', 'yinfujing', 'yellow', 'yuan', 'chuanqi'},
            expected_question={'first', 'ancient', 'china'},
            expected_answer={'yuan', 'yinfujing'})

        # query includes only added terms
        nqenv.state.history[-1].query = (
            'who was the first emperor of ancient china emperor '
            'huangdi')
        check_full_words(
            expected_all={
                'dynasty', 'yinfujing', 'yellow', 'yuan', 'chuanqi',
                'provinces'
            },
            expected_question={'first', 'ancient', 'china'},
            expected_answer={'yuan', 'yinfujing'})

        # query includes both terms with +/- contents/titles and added terms
        nqenv.state.history[-1].query = (
            'who was the first emperor of ancient china emperor '
            'huangdi +(contents:"emperor") provinces '
            '-(title:"chuanqi") chinese')
        check_full_words(
            expected_all={'dynasty', 'yinfujing', 'yellow', 'yuan'},
            expected_question={'first', 'ancient', 'china'},
            expected_answer={'yuan', 'yinfujing'})