def _world_function():
    """
    Run the overworld for ``overworld_agent`` until it selects a task world.

    ``overworld.parley()`` is polled repeatedly:

    * ``None`` means no choice yet, so sleep 0.5 s and poll again.
    * ``self.manager.EXIT_STR`` removes the agent and returns immediately.
    * Any other value is a world type. If ``onboard_map`` maps it to an
      onboarding world, a temporary onboarding agent is run through that
      world first and its results are copied onto ``agent_state``. The agent
      state is then added to that world type's pool.

    The free names used here (``self``, ``overworld_name``,
    ``overworld_agent``, ``onboard_map``, ``agent_state``, ``task``) come
    from the enclosing scope, which is not visible in this snippet.

    :return: the last value returned by ``overworld.parley()``, or ``None``
        if the loop never ran.
    """
    world_generator = utils.get_world_fn_attr(
        self._world_module, overworld_name, "generate_world"
    )
    overworld = world_generator(self.opt, [overworld_agent])
    # Initialise so the final return cannot hit an UnboundLocalError when the
    # overworld is already done or the system is shutting down.
    world_type = None
    while not overworld.episode_done() and not self.system_done:
        world_type = overworld.parley()
        if world_type is None:
            # No selection yet; back off briefly before polling again.
            time.sleep(0.5)
            continue

        if world_type == self.manager.EXIT_STR:
            self.manager._remove_agent(overworld_agent.id)
            return world_type

        # perform onboarding
        if onboard_map is not None:
            onboard_type = onboard_map.get(world_type)
            if onboard_type:
                # A separate agent id keeps onboarding distinct from the
                # overworld agent; it shares the overworld agent's data.
                onboard_id = 'onboard-{}-{}'.format(overworld_agent.id, time.time())
                agent = self.manager._create_agent(onboard_id, overworld_agent.id)
                agent.data = overworld_agent.data
                agent_state.set_active_agent(agent)
                agent_state.assign_agent_to_task(agent, onboard_id)
                _, onboard_data = self._run_world(task, onboard_type, [agent])
                agent_state.onboard_data = onboard_data
                agent_state.data = agent.data
        self.manager.add_agent_to_pool(agent_state, world_type)
        log_utils.print_and_log(logging.INFO, 'onboarding/overworld complete')

    return world_type
Beispiel #2
0
 def on_disconnect(*args):
     """
     Handle the websocket closing.

     This is not a no-op. It logs the disconnect along with whatever
     positional arguments the websocket library supplied, marks this client
     as no longer alive, and makes sure the socket is closed. Reconnecting
     is not done here.
     """
     disconnect_note = f'World server disconnected: {args}'
     log_utils.print_and_log(logging.INFO, disconnect_note)
     self.alive = False
     self._ensure_closed()
 def _done_callback(fut):
     e = fut.exception()
     if e is not None:
         log_utils.print_and_log(
             logging.ERROR,
             'World {} had error {}'.format(task_id, repr(e)),
             should_print=True,
         )
         traceback.print_exc(file=sys.stdout)
         if self.debug:
             raise e
Beispiel #4
0
 def on_error(ws, error):
     """
     Websocket error handler.

     :param ws: the websocket app reporting the error (not used here).
     :param error: the error object supplied by the websocket library.

     How each error is handled:

     * If ``error.errno`` equals ``errno.ECONNREFUSED``, the socket is closed
       and ``self.use_socket`` is set to False.
     * Any other error that has an ``errno`` attribute is only logged as a
       warning.
     * Errors without ``errno`` (AttributeError), and anything raised inside
       the ``try`` body, land in the ``except`` branch. That branch ignores
       ``websocket.WebSocketConnectionClosedException``; for anything else it
       logs a "Restarting" warning and closes the socket.
     """
     try:
         if error.errno == errno.ECONNREFUSED:
             self._ensure_closed()
             self.use_socket = False
             # NOTE(review): this exception is caught immediately by the
             # ``except BaseException`` below and never leaves this function.
             # The net effect is a "... Restarting" warning and a second
             # ``self._ensure_closed()`` call. Confirm this is intended.
             raise Exception("Socket refused connection, cancelling")
         else:
             log_utils.print_and_log(
                 logging.WARN,
                 'Socket logged error: {}'.format(repr(error)))
     except BaseException:
         # Reached when ``error`` has no ``errno`` attribute, or when anything
         # in the try body raised (including the deliberate raise above).
         if type(error) is websocket.WebSocketConnectionClosedException:
             return  # Connection closed is noop
         log_utils.print_and_log(
             logging.WARN,
             'Socket logged error: {} Restarting'.format(repr(error)),
         )
         self._ensure_closed()
Beispiel #5
0
 def on_message(*args):
     """
     Handle a raw packet arriving on the world-server socket.

     ``args[1]`` is parsed as JSON and dispatched on its ``type`` field:

     * ``conn_success`` marks the connection alive.
     * ``pong`` records the receive time in ``self.last_pong``.
     * Anything else has its ``content`` logged at DEBUG level. Then every
       event in ``content['entry'][i]['messaging']`` is passed, in order, to
       ``self.message_callback``.
     """
     packet = json.loads(args[1])
     packet_type = packet['type']
     if packet_type == 'conn_success':
         # No action for successful connection
         self.alive = True
     elif packet_type == 'pong':
         # No further action for pongs
         self.last_pong = time.time()
     else:
         content = packet['content']
         log_utils.print_and_log(
             logging.DEBUG, f'Message data received: {content}'
         )
         for entry in content['entry']:
             for event in entry['messaging']:
                 self.message_callback(event)
    def shutdown(self):
        """
        Perform client shutdown cleanup.

        Steps, in order:

        1. Clear ``self.is_running`` and shut down the world runner.
        2. Stop the socket, unless server setup was bypassed.
        3. Expire every conversation.

        An exception from any of these steps is logged at ERROR level instead
        of propagating. Afterwards, whether or not an error occurred, the
        server task is deleted unless server setup was bypassed.
        """
        try:
            self.is_running = False
            self.world_runner.shutdown()
            if not self.bypass_server_setup:
                self.socket.keep_running = False
            self._expire_all_conversations()
        except BaseException as err:
            failure_note = f'world ended in error: {err}'
            log_utils.print_and_log(logging.ERROR, failure_note)
        finally:
            if not self.bypass_server_setup:
                task_name = self.server_task_name
                server_utils.delete_server(task_name, self.opt['local'])
Beispiel #7
0
 def run_socket(*args):
     """
     Keep a websocket connection to the world server running.

     The ``wss://`` address is built from the part of ``self.server_url``
     after its first ``https://`` (up to any second occurrence). While
     ``self.keep_running`` is true, each pass creates a fresh
     ``websocket.WebSocketApp`` wired to ``on_message``, ``on_error``,
     ``on_disconnect`` and ``on_socket_open``, then blocks in
     ``run_forever`` until it returns. An ``Exception`` from a pass is
     logged as a warning. Every pass ends with a 0.2 s pause before the
     next attempt.
     """
     url_without_scheme = self.server_url.split('https://')[1]
     address = f"wss://{url_without_scheme}/"
     while self.keep_running:
         try:
             self.ws = websocket.WebSocketApp(
                 address,
                 on_message=on_message,
                 on_error=on_error,
                 on_close=on_disconnect,
             )
             self.ws.on_open = on_socket_open
             self.ws.run_forever(ping_interval=1, ping_timeout=0.9)
         except Exception as exc:
             log_utils.print_and_log(
                 logging.WARN,
                 f'Socket error {exc!r}, attempting restart',
             )
         time.sleep(0.2)
        def _done_callback(fut):
            """
            Clean up after the task world ``task_id`` finishes.

            Meant to be used as a ``Future.add_done_callback`` hook.

            If the world ended with an exception, it is logged along with its
            traceback, and every agent is sent a "world closed" message. The
            exception is not re-raised. Otherwise a "no error" entry is logged.

            In both cases ``self.active_worlds[task_id]`` is cleared. Then,
            for each agent, its data is copied back to its agent state, and
            the agent is routed by ``agent.data.get("next_task")``:

            * ``None``: relaunch the overworld and make the overworld agent
              active again.
            * ``self.EXIT_STR``: remove the agent.
            * Anything else: add the agent to that task's pool.

            :param fut: the completed future of the task world.
            """
            e = fut.exception()
            if e is not None:
                log_utils.print_and_log(
                    logging.ERROR,
                    'World {} had error {}'.format(world_type, repr(e)),
                    should_print=True,
                )
                # traceback.print_exc() reports whatever exception this thread
                # is currently handling. A done-callback is not guaranteed to
                # run inside an ``except`` block (e.g. when added after the
                # future already finished), in which case it prints
                # "NoneType: None". Print the future's exception explicitly.
                traceback.print_exception(
                    type(e), e, e.__traceback__, file=sys.stdout
                )
                for agent in agents:
                    self.observe_message(
                        agent.id,
                        'Sorry, this world closed. Returning to overworld.')
            else:
                log_utils.print_and_log(
                    logging.INFO,
                    'World {} had no error'.format(world_type),
                    should_print=True,
                )
            self.active_worlds[task_id] = None
            for agent in agents:
                self.after_agent_removed(agent.id)
                agent_state = self.get_agent_state(agent.id)
                agent_state.data = agent.data
                next_task = agent.data.get("next_task")
                log_utils.print_and_log(logging.INFO,
                                        "Next task: {}".format(next_task))
                if next_task is None:
                    self._launch_overworld(agent.id)
                    overworld_agent = agent_state.get_overworld_agent()
                    overworld_agent.data = agent_state.data
                    agent_state.set_active_agent(overworld_agent)
                elif next_task == self.EXIT_STR:
                    self._remove_agent(agent.id)
                else:
                    self.add_agent_to_pool(agent_state, next_task)
Beispiel #9
0
 def on_socket_open(*args):
     """Log the opened socket at DEBUG level, then call ``self._send_world_alive()``."""
     open_note = f'Socket open: {args}'
     log_utils.print_and_log(logging.DEBUG, open_note)
     self._send_world_alive()
    def start_task(self):
        """
        Launch task worlds whenever an agent pool is large enough.

        Sets ``self.running``, prints "Starting task", then repeats the
        following until ``self.running`` becomes false, sleeping
        ``utils.THREAD_MEDIUM_SLEEP`` seconds between passes.

        Each pass holds ``self.agent_pool_change_condition`` and visits every
        unique agent pool. For each world type it:

        1. Enforces the config's ``max_time_in_pool``, when one is set.
        2. If the pool holds at least ``self.max_agents_for[world_type]``
           agents, builds a new conversation from the front of the pool.
        3. Assigns roles, then drops agents left with a ``None`` ``disp_id``.
        4. Removes the remaining agents from the pool and links them as each
           other's message partners.
        5. Launches the task world on ``self.world_runner``.
        """
        self.running = True
        print("Starting task")
        while self.running:
            # Loop forever until the server is shut down
            with self.agent_pool_change_condition:
                for world_type, pool in self._get_unique_pool().items():
                    config = self.task_configs[world_type]
                    pool_timeout = config.max_time_in_pool
                    if pool_timeout is not None:
                        # check if agent has exceeded max time in pool
                        self.check_timeout_in_pool(
                            world_type, pool, pool_timeout, config.backup_task
                        )

                    conv_size = self.max_agents_for[world_type]
                    if len(pool) < conv_size:
                        continue

                    log_utils.print_and_log(
                        logging.INFO, 'starting pool', should_print=True
                    )
                    # enough agents in pool to start new conversation
                    self.conversation_index += 1
                    task_id = f't_{self.conversation_index}'

                    # Add the required number of valid agents to the conv
                    chosen_states = list(pool[:conv_size])
                    agents = []
                    for agent_state in chosen_states:
                        new_agent = self._create_agent(
                            task_id, agent_state.get_id()
                        )
                        new_agent.onboard_data = agent_state.onboard_data
                        new_agent.data = agent_state.data
                        agent_state.assign_agent_to_task(new_agent, task_id)
                        agent_state.set_active_agent(new_agent)
                        agents.append(new_agent)
                        # reset wait message state
                        agent_state.stored_data['seen_wait_message'] = False

                    assign_roles = utils.get_assign_roles_fn(
                        self.world_module, self.taskworld_map[world_type]
                    )
                    if assign_roles is None:
                        assign_roles = utils.default_assign_roles_fn
                    assign_roles(agents)

                    # Allow task creator to filter out workers and run
                    # versions of the task that require fewer agents
                    agents = [ag for ag in agents if ag.disp_id is not None]
                    for member in agents:
                        # Remove selected workers from the agent pool
                        self.remove_agent_from_pool(
                            self.get_agent_state(member.id),
                            world_type=world_type,
                            mark_removed=False,
                        )
                    for member in agents:
                        partners = list(agents)
                        partners.remove(member)
                        member.message_partners = partners

                    done_callback = self._get_done_callback_for_agents(
                        task_id, world_type, agents
                    )

                    # launch task world.
                    future = self.world_runner.launch_task_world(
                        task_id, self.taskworld_map[world_type], agents
                    )
                    future.add_done_callback(done_callback)
                    self.active_worlds[task_id] = future

            time.sleep(utils.THREAD_MEDIUM_SLEEP)
 def _log_debug(self, text: str):
     """
     Log ``text`` at DEBUG level (and print it) with a timestamp prefix.

     The prefix is the current local time formatted as
     ``YYYY-MM-DD HH:MM:SS``.
     """
     # Named ``stamp`` (not ``time``) to avoid shadowing the ``time`` module.
     stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
     log_utils.print_and_log(
         logging.DEBUG, f'{stamp}: {text}', should_print=True
     )
    def _manager_loop_fn(self):
        """
        Run one iteration of the manager's main loop to launch worlds.

        Holds ``self.agent_pool_change_condition`` and visits every unique
        agent pool. For each world type it:

        1. Enforces the config's ``max_time_in_pool``, when one is set.
        2. If the pool holds at least ``self.max_agents_for[world_type]``
           agents, creates task agents for that many entries from the front
           of the pool and assigns roles.
        3. Drops agents the role function left with a ``None`` ``disp_id``.
        4. Removes the remaining agents from the pool and links them as each
           other's message partners.
        5. Launches the task world on ``self.world_runner``.
        """
        with self.agent_pool_change_condition:
            valid_pools = self._get_unique_pool()
            for world_type, agent_pool in valid_pools.items():
                # check if agent has exceeded max time in pool
                world_config = self.task_configs[world_type]
                if world_config.max_time_in_pool is not None:
                    self.check_timeout_in_pool(
                        world_type,
                        agent_pool,
                        world_config.max_time_in_pool,
                        world_config.backup_task,
                    )

                needed_agents = self.max_agents_for[world_type]
                if len(agent_pool) >= needed_agents:
                    log_utils.print_and_log(
                        logging.INFO, 'starting pool', should_print=True
                    )
                    # enough agents in pool to start new conversation
                    self.conversation_index += 1
                    task_id = 't_{}'.format(self.conversation_index)

                    # Add the required number of valid agents to the conv
                    agent_states = list(agent_pool[:needed_agents])
                    agents = []
                    for state in agent_states:
                        agent = self._create_agent(task_id, state.get_id())
                        agent.onboard_data = state.onboard_data
                        agent.data = state.data
                        state.assign_agent_to_task(agent, task_id)
                        state.set_active_agent(agent)
                        agents.append(agent)
                        # reset wait message state
                        state.stored_data['seen_wait_message'] = False
                    assign_role_function = utils.get_assign_roles_fn(
                        self.world_module, self.taskworld_map[world_type]
                    )
                    if assign_role_function is None:
                        assign_role_function = utils.default_assign_roles_fn
                    assign_role_function(agents)
                    # Allow task creator to filter out workers and run
                    # versions of the task that require fewer agents.
                    # Agents the role function did not give a display id are
                    # excluded, matching ``start_task``; without this they
                    # would be pulled from the pool and launched anyway.
                    agents = [a for a in agents if a.disp_id is not None]
                    for a in agents:
                        # Remove selected workers from the agent pool
                        self.remove_agent_from_pool(
                            self.get_agent_state(a.id),
                            world_type=world_type,
                            mark_removed=False,
                        )
                    for a in agents:
                        partner_list = agents.copy()
                        partner_list.remove(a)
                        a.message_partners = partner_list

                    done_callback = self._get_done_callback_for_agents(
                        task_id, world_type, agents
                    )

                    # launch task world.
                    future = self.world_runner.launch_task_world(
                        task_id, self.taskworld_map[world_type], agents
                    )
                    future.add_done_callback(done_callback)
                    self.active_worlds[task_id] = future
 def _world_fn():
     """Log that ``task_name`` is starting, then return ``self._run_world``'s result."""
     start_note = f'Starting task {task_name}...'
     log_utils.print_and_log(logging.INFO, start_note)
     return self._run_world(task, world_name, agents)