class Wrapper:
    """A wrapper class for ``cls``, the class to be decorated.

    It contains a reference to the ``proxy`` and a ``message handler`` lookup table and defines a launch method as
    the universal entry point for running a ``cls`` instance in distributed mode.

    Note: ``cls``, ``proxy`` and ``handler_dict`` are not defined in this class; they are free variables resolved
    from an enclosing scope (presumably the decorator that builds this class -- verify against the caller).
    ``handler_dict`` maps each handler function to the event constraint it is registered under.
    """

    def __init__(self, *args, **kwargs):
        self.local_instance = cls(*args, **kwargs)
        self.proxy = proxy
        self._handler_function = {}
        self._registry_table = RegisterTable(self.proxy.get_peers())
        # Use functools.partial to freeze handling function's local_instance and proxy
        # arguments to self.local_instance and self.proxy.
        for handler_fn, constraint in handler_dict.items():
            self._handler_function[handler_fn] = partial(
                handler_fn, self.local_instance, self.proxy)
            self._registry_table.register_event_handler(
                constraint, handler_fn)

    def __getattr__(self, name):
        """Delegate attribute lookups that fail on the wrapper to the wrapped ``cls`` instance.

        Python only calls ``__getattr__`` after normal lookup (instance ``__dict__``, class attributes) has failed,
        so the wrapper's own attributes never reach this method.

        Raises:
            AttributeError: if ``local_instance`` has not been set on this wrapper (e.g. an instance created
                through ``__new__`` by ``copy``/``pickle``), or if the wrapped instance has no attribute ``name``.
        """
        # Read local_instance straight from __dict__: evaluating ``self.local_instance`` while it is missing
        # would re-enter __getattr__ and recurse until RecursionError instead of raising AttributeError.
        try:
            local_instance = self.__dict__["local_instance"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}") from None
        return getattr(local_instance, name)

    def launch(self):
        """Universal entry point for running a ``cls`` instance in distributed mode.

        For every message yielded by ``self.proxy.receive()``, push it into the registry table, then call the
        bound handler for each ``(handler_fn, msg_lst)`` pair returned by the table's ``get()``.
        """
        for msg in self.proxy.receive():
            self._registry_table.push(msg)
            triggered_event = self._registry_table.get()
            for handler_fn, msg_lst in triggered_event:
                self._handler_function[handler_fn](msg_lst)
def __init__(self, *args, **kwargs):
    """Create the wrapped ``cls`` instance and wire up its message handlers.

    ``cls``, ``proxy`` and ``handler_dict`` are free names resolved outside this function (presumably from an
    enclosing decorator scope -- verify). Each handler in ``handler_dict`` is stored pre-bound to the wrapped
    instance and the proxy, and is registered in the registry table under its event constraint.
    """
    self.local_instance = cls(*args, **kwargs)
    self.proxy = proxy
    self._handler_function = {}
    self._registry_table = RegisterTable(self.proxy.get_peers())

    bound_handlers = self._handler_function
    registry = self._registry_table
    for fn, event in handler_dict.items():
        # Freeze the first two positional arguments (wrapped instance, proxy) so the
        # stored callable only needs the triggering messages.
        bound_handlers[fn] = partial(fn, self.local_instance, self.proxy)
        registry.register_event_handler(event, fn)
def __init__(self, group_name: str, num_actors: int, update_trigger: str = None, proxy_options: dict = None,
             log_dir: str = None):
    """Set up the learner-side proxy, the actor registry table and the logger.

    Args:
        group_name (str): Group identifier passed to the internal ``Proxy``.
        num_actors (int): Expected number of ``"actor"`` peers.
        update_trigger (str): Number or percentage of ``MessageTag.FINISHED`` messages that fires
            ``self._on_rollout_finish``. Defaults to None, meaning the number of actor peers reported by the proxy.
        proxy_options (dict): Extra keyword arguments for the internal ``Proxy``. Defaults to None (no extras).
        log_dir (str): Folder passed to ``Logger`` as ``dump_folder``. Defaults to None, meaning the current
            working directory at construction time.
    """
    self.agent = None
    peers = {"actor": num_actors}
    if proxy_options is None:
        proxy_options = {}
    self._proxy = Proxy(group_name, "learner", peers, **proxy_options)
    self._actors = self._proxy.peers_name["actor"]  # remote actor ID's
    self._registry_table = RegisterTable(self._proxy.peers_name)
    if update_trigger is None:
        update_trigger = len(self._actors)
    self._registry_table.register_event_handler(
        f"actor:{MessageTag.FINISHED.value}:{update_trigger}", self._on_rollout_finish)
    # Resolve the default per call: a ``getcwd()`` default in the signature is evaluated only once, when the
    # function is defined (at import), and would ignore any later change of working directory.
    if log_dir is None:
        log_dir = getcwd()
    self.logger = Logger("ACTOR_PROXY", dump_folder=log_dir)
def setUp(self) -> None:
    """Give each test a fresh, empty ``RegisterTable``."""
    print("clear register table before each test.")
    # NOTE(review): ``get_peers`` is passed as-is (not called), unlike ``proxy.get_peers()`` elsewhere;
    # presumably it is a module-level peers mapping -- confirm.
    fresh_table = RegisterTable(get_peers)
    self.register_table = fresh_table
class TestRegisterTable(unittest.TestCase):
    """Unit tests for ``RegisterTable`` event registration and triggering.

    ``setUpClass`` builds a shared pool of ``SessionMessage`` objects once: for each worker (``worker_a``,
    ``worker_b``), sources ``<worker>.1`` to ``<worker>.5`` each contribute one ``tag_a`` and one ``tag_b``
    message, stored as ``message_dict[worker][tag]`` in source order. ``setUp`` gives every test a fresh table.
    """

    def setUp(self) -> None:
        print("clear register table before each test.")
        self.register_table = RegisterTable(get_peers)

    @classmethod
    def setUpClass(cls) -> None:
        print("The register table unit test start!")
        # Prepare message dict for test: message_dict[worker][tag] -> [SessionMessage, ...].
        tag_type = ["tag_a", "tag_b"]
        cls.message_dict = {}
        for worker in ("worker_a", "worker_b"):
            cls.message_dict[worker] = defaultdict(list)
            for source in [f"{worker}.{idx}" for idx in range(1, 6)]:
                for tag in tag_type:
                    message = SessionMessage(tag=tag, source=source, destination="test")
                    cls.message_dict[worker][tag].append(message)

    @classmethod
    def tearDownClass(cls) -> None:
        print("The register table unit test finished!")

    def _assert_triggered(self) -> None:
        """Assert that the latest push fired at least one registered handler."""
        # get() returns a list ([] when nothing fired), so assertIsNotNone(get()) could never fail;
        # require a non-empty (truthy) result instead.
        self.assertTrue(self.register_table.get())

    def _assert_not_triggered(self) -> None:
        """Assert that the latest push fired no registered handler."""
        self.assertEqual(self.register_table.get(), [])

    def test_unit_conditional_event(self):
        # Accept a message from worker_a with tag_a.
        unit_event_1 = "worker_a:tag_a:1"
        self.register_table.register_event_handler(unit_event_1, handle_function)

        for msg in TestRegisterTable.message_dict["worker_a"]["tag_a"]:
            # A message from worker_a with tag_a triggers the handler function each time.
            self.register_table.push(msg)
            self._assert_triggered()

        for msg in TestRegisterTable.message_dict["worker_b"]["tag_b"]:
            # A message from worker_b with tag_b never triggers the register table.
            self.register_table.push(msg)
            self._assert_not_triggered()

    def test_special_symbol(self):
        # Accept a message from worker_a with any tag.
        unit_event_2 = "worker_a:*:1"
        self.register_table.register_event_handler(unit_event_2, handle_function)

        for msg in TestRegisterTable.message_dict["worker_a"]["tag_a"] + \
                TestRegisterTable.message_dict["worker_a"]["tag_b"]:
            # A message from worker_a with any tag triggers the handler function each time.
            self.register_table.push(msg)
            self._assert_triggered()

        for msg in TestRegisterTable.message_dict["worker_b"]["tag_a"]:
            # A message from worker_b with tag_a does not trigger the handler function.
            self.register_table.push(msg)
            self._assert_not_triggered()

    def test_percentage_case(self):
        # Accept messages from any source with tag_a until their number reaches 50% of the number of sources
        # (presumably 5 of the 10 peers provided by get_peers -- matches the every-5th check below).
        percentage_event = "*:tag_a:50%"
        self.register_table.register_event_handler(percentage_event, handle_function)

        for idx, msg in enumerate(
                TestRegisterTable.message_dict["worker_a"]["tag_a"] +
                TestRegisterTable.message_dict["worker_b"]["tag_a"]):
            # Messages with tag_a trigger the handler function on every 5th message received.
            self.register_table.push(msg)
            if (idx + 1) % 5 == 0:
                self._assert_triggered()
            else:
                self._assert_not_triggered()

    def test_conditional_event(self):
        # Accept the combination of two messages: one from worker_a with tag_a, and one from worker_b with tag_a.
        and_conditional_event = ("worker_a:tag_a:1", "worker_b:tag_a:1", "AND")
        self.register_table.register_event_handler(and_conditional_event, handle_function)

        for idx, msg in enumerate(
                TestRegisterTable.message_dict["worker_a"]["tag_a"] +
                TestRegisterTable.message_dict["worker_b"]["tag_a"]):
            # The first 5 messages all come from worker_a, so the AND combination can only be satisfied once
            # worker_b messages start arriving (idx >= 5).
            self.register_table.push(msg)
            if idx >= 5:
                self._assert_triggered()
            else:
                self._assert_not_triggered()

        # Accept the message from worker_a with tag_a, or from worker_b with tag_a.
        or_conditional_event = ("worker_a:tag_a:1", "worker_b:tag_a:1", "OR")
        self.register_table.register_event_handler(or_conditional_event, handle_function)

        for msg in TestRegisterTable.message_dict["worker_a"]["tag_a"] + \
                TestRegisterTable.message_dict["worker_b"]["tag_a"]:
            # Each tag_a message from worker_a or worker_b triggers the handler function.
            self.register_table.push(msg)
            self._assert_triggered()

    def test_complicated_conditional_event(self):
        # Accept the combination of three messages: one from worker_a with tag_a, one from worker_b with tag_a,
        # and one from worker_a with tag_b.
        recurrent_conditional_event = (("worker_a:tag_a:1", "worker_b:tag_a:1", "AND"), "worker_a:tag_b:1", "AND")
        self.register_table.register_event_handler(recurrent_conditional_event, handle_function)

        for msg in TestRegisterTable.message_dict["worker_a"]["tag_a"] + \
                TestRegisterTable.message_dict["worker_b"]["tag_a"]:
            self.register_table.push(msg)

        for msg in TestRegisterTable.message_dict["worker_a"]["tag_b"]:
            self.register_table.push(msg)
            self._assert_triggered()

    def test_multiple_trigger(self):
        # Accept a message from worker_a with tag_a.
        unit_event_1 = "worker_a:tag_a:1"
        # Accept a message from worker_a with any tag.
        unit_event_2 = "worker_a:*:1"
        self.register_table.register_event_handler(unit_event_1, handle_function)
        self.register_table.register_event_handler(unit_event_2, handle_function)

        for msg in TestRegisterTable.message_dict["worker_a"]["tag_a"]:
            # Each message from worker_a with tag_a triggers both handler functions, and both of them
            # receive the same messages.
            self.register_table.push(msg)
            res = self.register_table.get()
            self.assertEqual(len(res), 2)
            self.assertEqual(res[0][1], res[1][1])
class ActorProxy(object):
    """Actor proxy that manages a set of remote actors.

    Args:
        group_name (str): Identifier of the group to which the actor belongs. It must be the same group name
            assigned to the actors (and roll-out clients, if any).
        num_actors (int): Expected number of actors in the group identified by ``group_name``.
        update_trigger (str): Number or percentage of ``MessageTag.FINISHED`` messages required to trigger
            learner updates, i.e., model training. Defaults to None, meaning the number of actor peers
            reported by the internal proxy.
        proxy_options (dict): Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class
            for details. Defaults to None.
    """
    def __init__(
        self,
        group_name: str,
        num_actors: int,
        update_trigger: str = None,
        proxy_options: dict = None
    ):
        self.agent = None
        peers = {"actor": num_actors}
        if proxy_options is None:
            proxy_options = {}
        self._proxy = Proxy(group_name, "learner", peers, **proxy_options)
        self._actors = self._proxy.peers_name["actor"]  # remote actor ID's
        self._registry_table = RegisterTable(self._proxy.peers_name)
        if update_trigger is None:
            update_trigger = len(self._actors)
        self._registry_table.register_event_handler(
            f"actor:{MessageTag.FINISHED.value}:{update_trigger}", self._on_rollout_finish
        )
        self.logger = InternalLogger("ACTOR_PROXY")

    def roll_out(self, index: int, training: bool = True, model_by_agent: dict = None, exploration_params=None):
        """Collect roll-out data from remote actors.

        Args:
            index (int): Index of roll-out requests.
            training (bool): If true, the roll-out request is for training purposes.
            model_by_agent (dict): Models to be broadcast to remote actors for inference. Defaults to None.
            exploration_params: Exploration parameters to be used by the remote roll-out actors. Defaults to None.

        Returns:
            An ``(env_metrics, details)`` pair unpacked from the first element returned by the registry table's
            ``push`` once the ``FINISHED`` trigger registered in ``__init__`` fires (presumably the output of
            ``_on_rollout_finish``, i.e. two dicts keyed by message source -- verify against ``RegisterTable``).

        Raises:
            RuntimeError: if the proxy's message stream ends before the trigger fires.
        """
        payload = {
            PayloadKey.ROLLOUT_INDEX: index,
            PayloadKey.TRAINING: training,
            PayloadKey.MODEL: model_by_agent,
            PayloadKey.EXPLORATION_PARAMS: exploration_params
        }
        self._proxy.iscatter(MessageTag.ROLLOUT, SessionType.TASK, [(actor, payload) for actor in self._actors])
        self.logger.info(f"Sent roll-out requests to {self._actors} for ep-{index}")

        # Receive roll-out results from remote actors
        for msg in self._proxy.receive():
            # NOTE(review): the check discards any index other than ``index`` (newer ones too), although the
            # log text says "or greater" -- confirm which behavior is intended.
            if msg.payload[PayloadKey.ROLLOUT_INDEX] != index:
                self.logger.info(
                    f"Ignore a message of type {msg.tag} with ep {msg.payload[PayloadKey.ROLLOUT_INDEX]} "
                    f"(expected {index} or greater)"
                )
                continue

            if msg.tag == MessageTag.FINISHED:
                # If enough update messages have been received, return the results so the caller can
                # start the next episode.
                result = self._registry_table.push(msg)
                if result:
                    env_metrics, details = result[0]
                    return env_metrics, details

        # The message stream ended without firing the trigger; fail with a clear message rather than an
        # UnboundLocalError on never-assigned results.
        raise RuntimeError(
            f"Message stream ended before enough roll-out results were received for ep-{index}"
        )

    def _on_rollout_finish(self, messages: List[Message]):
        """Collect per-source metrics and details from the ``FINISHED`` messages that fired the trigger.

        Returns:
            A ``(metrics, details)`` pair of dicts keyed by ``msg.source``.
        """
        metrics = {msg.source: msg.payload[PayloadKey.METRICS] for msg in messages}
        details = {msg.source: msg.payload[PayloadKey.DETAILS] for msg in messages}
        return metrics, details

    def terminate(self):
        """Tell the remote actors to exit."""
        self._proxy.ibroadcast(
            component_type="actor", tag=MessageTag.EXIT, session_type=SessionType.NOTIFICATION
        )
        self.logger.info("Exiting...")
def setUp(self) -> None:
    """Reset ``self.register_table`` before each test, seeded with a fresh ``FakedProxy``'s ``peers_dict``."""
    print("clear register table before each test.")
    faked = FakedProxy()
    self.register_table = RegisterTable(faked.peers_dict)