Beispiel #1
0
    def _get_response_timeout_loop(
        agent: Agent,
        world: World,
        timeout: int = DEFAULT_TIMEOUT,
        timeout_msg: str = 'You have timed out',
    ) -> Optional[Message]:
        """
        Ask the agent for a single act, ending the episode on timeout or EXIT.

        The act is requested through ``TimeoutUtils.get_timeout_act``. The
        episode is flagged as done (``world.episodeDone = True``) and ``None``
        is returned in two situations:

        * no act came back (``None``): the agent is additionally shown
          ``timeout_msg``;
        * the act's ``"text"`` field equals ``"EXIT"``, compared
          case-insensitively. A missing or falsy text is treated as ``""``.

        :param agent:
            agent who is acting
        :param world:
            world in which agent is acting
        :param timeout:
            timeout in secs, forwarded to ``TimeoutUtils.get_timeout_act``
        :param timeout_msg:
            what to say to agent when they timeout

        :return response:
            the agent's act if one was given and it was not an EXIT request,
            else None
        """
        act = TimeoutUtils.get_timeout_act(agent, timeout)

        if act is not None:
            # `or ""` guards against an explicit None stored under "text".
            text = act.get("text", "") or ""
            if text.upper() != "EXIT":
                return act

        # Reaching this point means a timeout or an explicit EXIT: in both
        # cases the episode is over and no response is handed back.
        world.episodeDone = True  # type: ignore
        if act is None:
            # Only a timeout triggers the notification to the agent.
            agent.observe({"id": "", "text": timeout_msg})
        return None
Beispiel #2
0
    def _get_time(self, world: World) -> Tuple[float, float, float, float]:
        """
        Return train, log, validate, and save timing.

        If relying on the time or epoch count for validation/logging/max train
        purposes (i.e. any of ``max_train_time``, ``log_every_n_secs``,
        ``val_every_n_secs``, ``val_every_n_epochs`` or ``max_num_epochs`` is
        finite), the epoch counts are gathered from all workers with
        ``all_gather_list`` and the timer readings are passed through
        ``sync_object`` so that every worker sees the same values
        (presumably the primary worker's; confirm against ``sync_object``).

        Otherwise, it's not super relevant what we do here: the local timer
        readings are returned and the total epoch count is estimated as
        ``num_workers() * world.get_total_epochs()``.

        **SIDE EFFECT**: Update _total_epochs trained.

        :param world:
            current running world

        :return (train, log, valid, save):
            return time for each of train, log, validation, and save
        """
        # Any finite limit means decisions depend on these numbers, so they
        # must agree across workers.
        needs_sync = (
            self.max_train_time < float('inf')
            or self.log_every_n_secs < float('inf')
            or self.val_every_n_secs < float('inf')
            or self.val_every_n_epochs < float('inf')
            or self.max_num_epochs < float('inf')
        )

        if needs_sync:
            # Exact total: sum of every worker's own epoch count.
            self._total_epochs = self._preempted_epochs + sum(
                all_gather_list(world.get_total_epochs())
            )
            train_time, log_time, validate_time, save_time = sync_object(
                (
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                    self.save_time.time(),
                )
            )
        else:
            train_time, log_time, validate_time, save_time = (
                self.train_time.time(),
                self.log_time.time(),
                self.validate_time.time(),
                self.save_time.time(),
            )
            # Cheap estimate with no communication: scale the local count by
            # the number of workers.
            self._total_epochs = self._preempted_epochs + (
                num_workers() * world.get_total_epochs()
            )

        return train_time, log_time, validate_time, save_time