def __init__(self, opt: Opt, agents=None, shared=None):
    """
    Build per-batch-row clones of every agent plus the bookkeeping state
    (acts, goals, metrics, episode-done flags) used to run batched parleys.
    """
    super().__init__(opt, agents, shared)
    self.batchsize = opt["batchsize"]
    self.batch_agents = []
    self.batch_acts = []
    self.batch_goals = []  # for case when num_episodes < batchsize
    self.batch_tod_world_metrics = []
    for batch_idx in range(self.batchsize):
        row_agents = []
        for agent_idx, agent in enumerate(agents):
            if agent_idx == SYSTEM_UTT_IDX:
                # Handled separately: the system-utterance slot is expected to
                # be the same object as the API_CALL agent, so alias the clone
                # already built for that slot instead of sharing a new one.
                row_agents.append(row_agents[API_CALL_IDX])
                continue
            shared_state = agent.share()
            row_opt = copy.deepcopy(shared_state["opt"])
            row_opt["batchindex"] = batch_idx
            row_agents.append(shared_state["class"](row_opt, shared_state))
        self.batch_agents.append(row_agents)
        self.batch_acts.append([Message.padding_example()] * 4)
        self.batch_tod_world_metrics.append(tod_metrics.TodMetrics())
    self.end_episode = [False] * self.batchsize
    self.max_turns = self.opt.get("max_turns", 30)
    self.turns = 0
    self.need_grounding = True
def _observe_and_act(
    self, observe_idx, act_idx, info="for regular parley", override_act_idx=None
):
    """
    Have one agent slot observe a prior act and produce its own act, for
    every row of the batch.

    :param observe_idx: index into ``self.batch_acts[i]`` of the message each
        agent should observe.
    :param act_idx: slot whose act is recorded into
        ``self.batch_acts[i][act_idx]``.
    :param info: human-readable label used only in the LocalHumanAgent prompt.
    :param override_act_idx: if set, the agent at this index acts instead of
        the one at ``act_idx`` (output is still recorded under ``act_idx``).
    """
    # Which agent acts may differ from where its output is recorded
    # (e.g. the system-utterance slot reusing the API-call agent).
    act_agent_idx = override_act_idx if override_act_idx else act_idx
    act_agent = self.agents[act_agent_idx]
    record_output_idx = act_idx
    if hasattr(act_agent, "batch_act"):
        # Batched path: gather one observation per batch row, act once.
        batch_observations = []
        for i in range(self.batchsize):
            if not self.end_episode[i]:
                observe = self.batch_acts[i][observe_idx]
                observe = self.batch_agents[i][act_agent_idx].observe(observe)
                batch_observations.append(Message(observe))
            else:
                # We're done with this episode, so just do a pad.
                # NOTE: This could cause issues with RL down the line
                batch_observations.append(Message.padding_example())
                self.batch_acts[i][record_output_idx] = {"text": "", "id": ""}
        batch_actions = act_agent.batch_act(batch_observations)
        for i in range(self.batchsize):
            if self.end_episode[i]:
                continue
            self.batch_acts[i][record_output_idx] = batch_actions[i]
            # Let the per-row clone track its own conversation history.
            self.batch_agents[i][record_output_idx].self_observe(batch_actions[i])
    else:  # Run on agents individually
        for i in range(self.batchsize):
            act_agent = (
                self.batch_agents[i][override_act_idx]
                if override_act_idx
                else self.batch_agents[i][act_idx]
            )
            if hasattr(act_agent, "episode_done") and act_agent.episode_done():
                self.end_episode[i] = True
            if self.end_episode[i]:
                # Following line exists because:
                # 1. Code for writing conversations is not happy if an "id"
                #    does not exist with a sample
                # 2. Because of the `self.end_episode` code, no agent will see
                #    this example anyway.
                self.batch_acts[i][record_output_idx] = {"text": "", "id": ""}
                continue
            act_agent.observe(self.batch_acts[i][observe_idx])
            if isinstance(act_agent, LocalHumanAgent):
                print(
                    f"Getting message for {SPEAKER_TO_NAME[record_output_idx]} for {info} in batch {i}"
                )
            try:
                self.batch_acts[i][record_output_idx] = act_agent.act()
            except StopIteration:
                # A human/scripted agent signals end-of-episode by raising.
                self.end_episode[i] = True
    # Metrics pass: record each still-live row's new act, and finish any row
    # whose act contains the standard "DONE" marker.
    for i in range(self.batchsize):
        if self.end_episode[i]:
            continue
        self.batch_tod_world_metrics[i].handle_message(
            self.batch_acts[i][record_output_idx], SPEAKER_TO_NAME[act_agent_idx]
        )
        if tod.STANDARD_DONE in self.batch_acts[i][record_output_idx].get(
            "text", ""
        ):
            # User models trained to output a "DONE" on last turn; same with
            # human agents.
            self.end_episode[i] = True