Example No. 1
    def __init__(self, opt: Opt, agents=None, shared=None):
        super().__init__(opt, agents, shared)
        self.batchsize = opt["batchsize"]
        self.batch_agents = []
        self.batch_acts = []
        self.batch_goals = []  # for case when num_episodes < batchsize
        self.batch_tod_world_metrics = []
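        # Build one set of per-row agent copies (via share()), act buffers, and
        # metrics trackers, so each row of the batch tracks its own episode state.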
        for i in range(self.batchsize):
            here_agents = []
            for j, agent in enumerate(agents):
                if (
                    j == SYSTEM_UTT_IDX
                ):  # handled separately because we expect it to be the same agent as API_CALL
                    here_agents.append(here_agents[API_CALL_IDX])
                    continue
                share = agent.share()
                batch_opt = copy.deepcopy(share["opt"])
                batch_opt["batchindex"] = i
                here_agents.append(share["class"](batch_opt, share))
            self.batch_agents.append(here_agents)
            self.batch_acts.append([Message.padding_example()] * 4)
            self.batch_tod_world_metrics.append(tod_metrics.TodMetrics())
        self.end_episode = [False] * self.batchsize

        self.max_turns = self.opt.get("max_turns", 30)
        self.turns = 0
        self.need_grounding = True
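
For reference, the constructor above relies on ParlAI's share()/clone convention: each agent exposes a share() dict containing at least its class and options, and a per-row copy is rebuilt from that dict with a row-specific batchindex. The snippet below is a minimal, self-contained sketch of that pattern; EchoAgent and its behavior are hypothetical stand-ins, not part of the example above.

class EchoAgent:
    # Hypothetical stand-in agent that follows the share()/clone convention
    # assumed by the constructor above.
    def __init__(self, opt, shared=None):
        self.opt = opt
        self.history = []  # per-copy state, not shared between rows

    def share(self):
        # Expose enough to rebuild a copy: the constructor above reads "class" and "opt".
        return {"class": type(self), "opt": self.opt}

    def observe(self, msg):
        self.history.append(msg)
        return msg

    def act(self):
        text = self.history[-1].get("text", "") if self.history else ""
        return {"text": text, "id": "EchoAgent"}


# One copy per batch row, each with its own history but shared options.
parent = EchoAgent({"batchsize": 2})
copies = []
for i in range(2):
    shared = parent.share()
    row_opt = dict(shared["opt"], batchindex=i)
    copies.append(shared["class"](row_opt, shared))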
Example No. 2
    def _observe_and_act(
        self, observe_idx, act_idx, info="for regular parley", override_act_idx=None
    ):
        act_agent_idx = override_act_idx if override_act_idx else act_idx
        act_agent = self.agents[act_agent_idx]
        record_output_idx = act_idx
        if hasattr(act_agent, "batch_act"):
            batch_observations = []
            for i in range(self.batchsize):
                if not self.end_episode[i]:
                    observe = self.batch_acts[i][observe_idx]
                    observe = self.batch_agents[i][act_agent_idx].observe(observe)
                    batch_observations.append(Message(observe))
                else:
                    # We're done with this episode, so just do a pad.
                    # NOTE: This could cause issues with RL down the line
                    batch_observations.append(Message.padding_example())
                    self.batch_acts[i][record_output_idx] = {"text": "", "id": ""}
            batch_actions = act_agent.batch_act(batch_observations)
            for i in range(self.batchsize):
                if self.end_episode[i]:
                    continue
                self.batch_acts[i][record_output_idx] = batch_actions[i]
                self.batch_agents[i][record_output_idx].self_observe(batch_actions[i])
        else:  # Run on agents individually
            for i in range(self.batchsize):
                act_agent = (
                    self.batch_agents[i][override_act_idx]
                    if override_act_idx
                    else self.batch_agents[i][act_idx]
                )
                if hasattr(act_agent, "episode_done") and act_agent.episode_done():
                    self.end_episode[i] = True
                if self.end_episode[i]:
                    # The following line exists because:
                    # 1. Code for writing conversations is not happy if an "id" does not exist on a sample
                    # 2. Because of the `self.end_episode` code, no agent will see this example anyway.
                    self.batch_acts[i][record_output_idx] = {"text": "", "id": ""}
                    continue
                act_agent.observe(self.batch_acts[i][observe_idx])
                if isinstance(act_agent, LocalHumanAgent):
                    print(
                        f"Getting message for {SPEAKER_TO_NAME[record_output_idx]} for {info} in batch {i}"
                    )
                try:
                    self.batch_acts[i][record_output_idx] = act_agent.act()
                except StopIteration:
                    self.end_episode[i] = True
        for i in range(self.batchsize):
            if self.end_episode[i]:
                continue
            self.batch_tod_world_metrics[i].handle_message(
                self.batch_acts[i][record_output_idx], SPEAKER_TO_NAME[act_agent_idx]
            )
            if tod.STANDARD_DONE in self.batch_acts[i][record_output_idx].get(
                "text", ""
            ):
                # User models are trained to output "DONE" on the last turn; same with human agents.
                self.end_episode[i] = True
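
For context, a helper like _observe_and_act is typically driven once per speaker slot inside the world's parley step. The loop below is only a rough sketch under assumed names: the index constants other than API_CALL_IDX and SYSTEM_UTT_IDX, their concrete values, and the exact turn ordering are illustrative guesses, not taken from the examples above.

# Illustrative speaker indices; only API_CALL_IDX and SYSTEM_UTT_IDX appear in
# the examples above, and the values chosen here are assumptions.
USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX = 0, 1, 2, 3


def run_batch_episode(world, max_turns=30):
    # Sketch of a driver loop: alternate speakers until every row in the batch
    # has ended its episode or the turn budget runs out.
    turns = 0
    while not all(world.end_episode) and turns < max_turns:
        world._observe_and_act(SYSTEM_UTT_IDX, USER_UTT_IDX, "user turn")
        world._observe_and_act(USER_UTT_IDX, API_CALL_IDX, "api call turn")
        world._observe_and_act(API_CALL_IDX, API_RESP_IDX, "api response turn")
        # The system utterance is produced by the same agent as the API call
        # (see the constructor above), hence the override of the acting index.
        world._observe_and_act(
            API_RESP_IDX, SYSTEM_UTT_IDX, "system turn", override_act_idx=API_CALL_IDX
        )
        turns += 1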