예제 #1
0
def master(group_name: str, worker_num: int, is_immediate: bool = False):
    """Run the master side: broadcast an ``"INC"`` notification to all workers.

    A ``Proxy`` is created for the ``"master"`` component that expects
    ``worker_num`` peers of type ``"worker"``. The master broadcasts a
    notification tagged ``"INC"`` and prints one line per reply it gets back.

    Args:
        group_name (str): Identifier for the group of all communication components.
        worker_num (int): The number of worker peers the proxy expects.
        is_immediate (bool): If True, run in async mode; otherwise, run in sync mode.
            Async mode: ``ibroadcast`` only returns the session ids of the sent
                messages, so work with a higher local priority can be done before
                the replies are collected with ``receive_by_id``.
            Sync mode: ``broadcast`` blocks until the proxy returns all the
                replied messages.
    """
    proxy = Proxy(
        group_name=group_name,
        component_type="master",
        expected_peers={"worker": worker_num},
    )

    if not is_immediate:
        # Sync mode: blocks until every reply has arrived.
        replied_msgs = proxy.broadcast(tag="INC",
                                       session_type=SessionType.NOTIFICATION)
    else:
        # Async mode: only the session ids come back at first.
        session_ids = proxy.ibroadcast(tag="INC",
                                       session_type=SessionType.NOTIFICATION)
        # Tasks with a higher priority could be done here before waiting on replies.
        replied_msgs = proxy.receive_by_id(session_ids)

    for reply in replied_msgs:
        print(f"{proxy.component_name} get receive notification from {reply.source} "
              f"with message session stage {reply.session_stage}.")


class ActorProxy(object):
    """Dispatch roll-out requests to ``"actor_worker"`` peers and merge their replies.

    Wraps a communication ``Proxy`` registered with ``component_type="actor_proxy"``.
    """

    def __init__(self, proxy_params):
        """Create the underlying proxy.

        Args:
            proxy_params (dict): Keyword arguments forwarded unchanged to ``Proxy``;
                ``component_type`` is always set to ``"actor_proxy"`` here.
        """
        self._proxy = Proxy(component_type="actor_proxy", **proxy_params)

    def roll_out(self,
                 mode: RolloutMode,
                 models: dict = None,
                 epsilon_dict: dict = None,
                 seed: int = None):
        """Ask the actor workers to roll out, or tell them to exit.

        Args:
            mode: Roll-out mode. ``RolloutMode.EXIT`` broadcasts the mode to all
                peers and returns immediately. Any other mode is scattered to every
                ``"actor_worker"`` peer together with the arguments below.
            models: Sent unchanged to every worker under ``PayloadKey.MODEL``.
            epsilon_dict: Sent unchanged to every worker under ``PayloadKey.EPSILON``.
            seed: If given, the i-th worker (in ``get_peers`` order) receives
                ``seed + i``; otherwise every worker receives ``None``.

        Returns:
            ``(None, None)`` for ``RolloutMode.EXIT``. Otherwise a tuple
            ``(performance, exp_by_agent)``:
            - ``performance`` maps each reply's source to its reported performance.
            - ``exp_by_agent`` maps each agent id to a ``defaultdict(list)`` that
              concatenates, field by field, the experiences from all replies.
              Replies whose experience is ``None`` add nothing to it.
        """
        if mode == RolloutMode.EXIT:
            # TODO: session type: notification
            self._proxy.broadcast(tag=MessageType.ROLLOUT,
                                  session_type=SessionType.TASK,
                                  payload={PayloadKey.RolloutMode: mode})
            return None, None

        performance, exp_by_agent = {}, {}

        # One (destination, payload) pair per worker; the seed is offset by the
        # worker's position so the workers do not share a random seed.
        workers = self._proxy.get_peers("actor_worker")
        destination_payloads = []
        for offset, worker in enumerate(workers):
            destination_payloads.append((worker, {
                PayloadKey.MODEL: models,
                PayloadKey.RolloutMode: mode,
                PayloadKey.EPSILON: epsilon_dict,
                PayloadKey.SEED: None if seed is None else seed + offset,
            }))

        # TODO: double check when ack enable
        replies = self._proxy.scatter(tag=MessageType.ROLLOUT,
                                      session_type=SessionType.TASK,
                                      destination_payload_list=destination_payloads)

        for reply in replies:
            body = reply.payload
            performance[reply.source] = body[PayloadKey.PERFORMANCE]
            experiences = body[PayloadKey.EXPERIENCE]
            if experiences is None:
                continue
            for agent_id, exp_set in experiences.items():
                merged = exp_by_agent.setdefault(agent_id, defaultdict(list))
                for field, values in exp_set.items():
                    merged[field].extend(values)

        return performance, exp_by_agent