コード例 #1
0
ファイル: test_rejoin.py プロジェクト: ysqyang/maro
def actor_init(queue, redis_port):
    """Run a test actor that answers control messages sent by the master.

    Each received message is printed, then handled according to its tag:

    * ``"cont"``   -> reply with ``tag="recv"`` and keep listening.
    * ``"stop"``   -> reply, put this proxy's name on ``queue``, leave the
      loop, close the proxy and exit the process with status 1.
    * ``"finish"`` -> reply and exit the process with status 0 right away.

    Any other tag is only printed.

    Args:
        queue: Object whose ``put`` receives the proxy name when a ``"stop"``
            message arrives (presumably a ``multiprocessing.Queue`` shared
            with the test driver -- confirm against the caller).
        redis_port: Port used in the proxy's ``redis_address`` on localhost.
    """
    actor = Proxy(component_type="actor",
                  expected_peers={"master": 1},
                  redis_address=("localhost", redis_port),
                  **PROXY_PARAMETER)

    # Keep pulling messages from the proxy until told to stop or finish.
    for incoming in actor.receive(is_continuous=True):
        tag = incoming.tag
        print(f"receive message from master. {tag}")

        if tag == "cont":
            actor.reply(message=incoming, tag="recv", payload="successful receive!")
            continue

        if tag == "stop":
            actor.reply(message=incoming, tag="recv", payload=f"{actor.name} exited!")
            queue.put(actor.name)
            break

        if tag == "finish":
            actor.reply(message=incoming, tag="recv", payload=f"{actor.name} finish!")
            sys.exit(0)

    # Reached after "stop" (or if the receive loop ends by itself): release
    # the proxy, then exit with a non-zero status.
    actor.close()
    sys.exit(1)
コード例 #2
0
ファイル: actor.py プロジェクト: yumiaoGitHub/maro
    def as_worker(self, group: str, proxy_options=None, log_dir: str = None):
        """Executes an event loop where roll-outs are performed on demand from a remote learner.

        Messages received from the proxy are handled by tag:

        * ``MessageTag.EXIT``: log, close the proxy and exit the process with status 0.
        * ``MessageTag.ROLLOUT``: call ``self.roll_out`` with the roll-out index, training flag, models and
          exploration parameters taken from the payload. If roll-out data is returned, a
          ``MessageTag.FINISHED`` message carrying the index, metrics and details is sent to the first
          learner peer; if the data is ``None`` the roll-out is logged as aborted and nothing is sent.

        ``self.env.reset()`` is called after every message that does not cause an exit.

        Args:
            group (str): Identifier of the group to which the actor belongs. It must be the same group name
                assigned to the learner (and decision clients, if any).
            proxy_options (dict): Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class
                for details. Defaults to None.
            log_dir (str): Directory the logger dumps to. Defaults to the current working directory at the
                time this method is called.
        """
        if proxy_options is None:
            proxy_options = {}
        if log_dir is None:
            # Resolved here rather than in the signature: a ``getcwd()`` default would be evaluated once at
            # import time and go stale if the process changes directory afterwards.
            log_dir = getcwd()
        proxy = Proxy(group, "actor", {"learner": 1}, **proxy_options)
        logger = Logger(proxy.name, dump_folder=log_dir)
        for msg in proxy.receive():
            if msg.tag == MessageTag.EXIT:
                logger.info("Exiting...")
                proxy.close()
                sys.exit(0)
            elif msg.tag == MessageTag.ROLLOUT:
                ep = msg.payload[PayloadKey.ROLLOUT_INDEX]
                logger.info(f"Rolling out ({ep})...")
                metrics, rollout_data = self.roll_out(
                    ep,
                    training=msg.payload[PayloadKey.TRAINING],
                    model_by_agent=msg.payload[PayloadKey.MODEL],
                    exploration_params=msg.payload[PayloadKey.EXPLORATION_PARAMS]
                )
                if rollout_data is None:
                    logger.info(f"Roll-out {ep} aborted")
                else:
                    logger.info(f"Roll-out {ep} finished")
                    rollout_finish_msg = Message(
                        MessageTag.FINISHED,
                        proxy.name,
                        proxy.peers_name["learner"][0],
                        payload={
                            PayloadKey.ROLLOUT_INDEX: ep,
                            PayloadKey.METRICS: metrics,
                            PayloadKey.DETAILS: rollout_data
                        }
                    )
                    proxy.isend(rollout_finish_msg)
            self.env.reset()
コード例 #3
0
class ActorProxy(object):
    """Actor proxy that manages a set of remote actors.

    Args:
        group_name (str): Identifier of the group to which the actor belongs. It must be the same group name
            assigned to the actors (and roll-out clients, if any).
        num_actors (int): Expected number of actors in the group identified by ``group_name``.
        update_trigger (str): Number or percentage of ``MessageTag.FINISHED`` messages required to trigger
            learner updates, i.e., model training. Defaults to the number of actors found in the group.
        proxy_options (dict): Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class
            for details. Defaults to None.
        log_dir (str): Directory the logger dumps to. Defaults to the current working directory at
            construction time.
    """
    def __init__(self,
                 group_name: str,
                 num_actors: int,
                 update_trigger: str = None,
                 proxy_options: dict = None,
                 log_dir: str = None):
        self.agent = None
        peers = {"actor": num_actors}
        if proxy_options is None:
            proxy_options = {}
        if log_dir is None:
            # Resolved at construction time; a ``getcwd()`` default in the signature would be frozen at import.
            log_dir = getcwd()
        self._proxy = Proxy(group_name, "learner", peers, **proxy_options)
        self._actors = self._proxy.peers_name["actor"]  # remote actor ID's
        self._registry_table = RegisterTable(self._proxy.peers_name)
        if update_trigger is None:
            update_trigger = len(self._actors)
        # Run _on_rollout_finish once `update_trigger` FINISHED messages from actors have been pushed.
        self._registry_table.register_event_handler(
            f"actor:{MessageTag.FINISHED.value}:{update_trigger}",
            self._on_rollout_finish)
        self.logger = Logger("ACTOR_PROXY", dump_folder=log_dir)

    def roll_out(self,
                 index: int,
                 training: bool = True,
                 model_by_agent: dict = None,
                 exploration_params=None):
        """Collect roll-out data from remote actors.

        Args:
            index (int): Index of roll-out requests.
            training (bool): If true, the roll-out request is for training purposes.
            model_by_agent (dict): Models to be broadcast to remote actors for inference. Defaults to None.
            exploration_params: Exploration parameters to be used by the remote roll-out actors. Defaults to None.

        Returns:
            The ``(env_metrics, details)`` pair taken from the first registry-table result once the update
            trigger fires. This is presumably the output of ``_on_rollout_finish``: two dicts keyed by
            message source.

        Raises:
            RuntimeError: If the proxy's receive loop ends before enough roll-out results have arrived.
        """
        payload = {
            PayloadKey.ROLLOUT_INDEX: index,
            PayloadKey.TRAINING: training,
            PayloadKey.MODEL: model_by_agent,
            PayloadKey.EXPLORATION_PARAMS: exploration_params
        }
        self._proxy.iscatter(MessageTag.ROLLOUT, SessionType.TASK,
                             [(actor, payload) for actor in self._actors])
        self.logger.info(
            f"Sent roll-out requests to {self._actors} for ep-{index}")

        # Receive roll-out results from remote actors
        for msg in self._proxy.receive():
            msg_index = msg.payload[PayloadKey.ROLLOUT_INDEX]
            if msg_index != index:
                # Results tagged with any other roll-out index do not belong to this request.
                self.logger.info(
                    f"Ignore a message of type {msg.tag} with ep {msg_index} "
                    f"(expected {index})")
                continue
            if msg.tag == MessageTag.FINISHED:
                # Once enough FINISHED messages have been pushed, the registry table yields the handler
                # result and this roll-out is complete.
                result = self._registry_table.push(msg)
                if result:
                    env_metrics, details = result[0]
                    return env_metrics, details

        # Previously this path crashed with an UnboundLocalError on env_metrics/details.
        raise RuntimeError(
            f"Receive loop ended before roll-out {index} collected enough results from {self._actors}")

    def _on_rollout_finish(self, messages: List[Message]):
        """Group the metrics and details of FINISHED messages by their source.

        Args:
            messages: FINISHED messages collected by the registry table.

        Returns:
            Tuple ``(metrics, details)`` of dicts mapping each ``msg.source`` to the payload's
            ``PayloadKey.METRICS`` and ``PayloadKey.DETAILS`` values, respectively.
        """
        metrics = {
            msg.source: msg.payload[PayloadKey.METRICS]
            for msg in messages
        }
        details = {
            msg.source: msg.payload[PayloadKey.DETAILS]
            for msg in messages
        }
        return metrics, details

    def terminate(self):
        """Tell the remote actors to exit, then close the internal proxy."""
        self._proxy.ibroadcast(component_type="actor",
                               tag=MessageTag.EXIT,
                               session_type=SessionType.NOTIFICATION)
        self.logger.info("Exiting...")
        self._proxy.close()