コード例 #1
0
        class Wrapper:
            """Wrap an instance of ``cls`` (the class being decorated) for distributed execution.

            The wrapper keeps the wrapped instance, a reference to ``proxy``, a lookup table of bound
            message handlers and a ``RegisterTable``. ``launch`` is the universal entry point for running
            a ``cls`` instance in distributed mode. Attributes that are not found on the wrapper itself
            are looked up on the wrapped instance.
            """
            def __init__(self, *args, **kwargs):
                """Build the wrapped instance and register every handler found in ``handler_dict``.

                Args:
                    *args: Positional arguments forwarded unchanged to ``cls``.
                    **kwargs: Keyword arguments forwarded unchanged to ``cls``.
                """
                self.local_instance = cls(*args, **kwargs)
                self.proxy = proxy
                self._handler_function = {}
                self._registry_table = RegisterTable(self.proxy.get_peers())
                # Use functools.partial to freeze each handling function's local_instance and proxy
                # arguments to self.local_instance and self.proxy, so that ``launch`` only has to
                # supply the triggering messages.
                for handler_fn, constraint in handler_dict.items():
                    self._handler_function[handler_fn] = partial(
                        handler_fn, self.local_instance, self.proxy)
                    self._registry_table.register_event_handler(
                        constraint, handler_fn)

            def __getattr__(self, name):
                """Delegate attribute lookups that failed on the wrapper to the wrapped instance.

                Python calls ``__getattr__`` only after the normal lookup (instance dict, class, bases)
                has failed, so ``name`` is never present in ``self.__dict__`` here.

                Args:
                    name: Name of the attribute being looked up.

                Returns:
                    The attribute ``name`` of ``self.local_instance``.

                Raises:
                    AttributeError: If ``local_instance`` itself has not been set yet, or if the wrapped
                        instance has no attribute ``name``.
                """
                if name == "local_instance":
                    # Reaching this point means ``local_instance`` is missing from the instance dict
                    # (e.g. the object was created with ``__new__`` and ``__init__`` has not run).
                    # Evaluating ``self.local_instance`` below would re-enter this method forever.
                    raise AttributeError(
                        f"{type(self).__name__!r} object has no attribute {name!r}")

                return getattr(self.local_instance, name)

            def launch(self):
                """Universal entry point for running a ``cls`` instance in distributed mode.

                Every message yielded by ``self.proxy.receive()`` is pushed into the register table;
                each ``(handler, messages)`` pair the table then reports is dispatched to the bound
                handler stored for that handler function.
                """
                for msg in self.proxy.receive():
                    self._registry_table.push(msg)
                    triggered_event = self._registry_table.get()

                    for handler_fn, msg_lst in triggered_event:
                        self._handler_function[handler_fn](msg_lst)
コード例 #2
0
 def __init__(self, *args, **kwargs):
     """Create the wrapped ``cls`` instance and wire up its message handlers.

     Args:
         *args: Positional arguments forwarded unchanged to ``cls``.
         **kwargs: Keyword arguments forwarded unchanged to ``cls``.
     """
     wrapped = cls(*args, **kwargs)
     self.local_instance = wrapped
     self.proxy = proxy
     bound_handlers = self._handler_function = {}
     table = self._registry_table = RegisterTable(self.proxy.get_peers())
     # Each handler is stored with its first two arguments (the wrapped instance and the
     # proxy) already bound, and is registered in the table under its trigger constraint.
     for fn, condition in handler_dict.items():
         bound_handlers[fn] = partial(fn, wrapped, self.proxy)
         table.register_event_handler(condition, fn)
コード例 #3
0
 def __init__(self,
              group_name: str,
              num_actors: int,
              update_trigger: str = None,
              proxy_options: dict = None,
              log_dir: str = None):
     """Set up the proxy, the register table and the logger for a group of remote actors.

     Args:
         group_name: Group name passed to the internal ``Proxy``.
         num_actors: Number of peers of type ``"actor"`` requested from the ``Proxy``.
         update_trigger: Last field of the event registered for ``MessageTag.FINISHED`` messages
             from actors. Defaults to None, in which case the number of actor peers is used.
         proxy_options: Extra keyword arguments for the internal ``Proxy``. Defaults to None,
             which means no extra arguments.
         log_dir: Folder handed to the logger as ``dump_folder``. Defaults to None, in which case
             the current working directory at construction time is used.
     """
     self.agent = None
     peers = {"actor": num_actors}
     if proxy_options is None:
         proxy_options = {}
     if log_dir is None:
         # Resolve the working directory per call: a ``getcwd()`` default in the signature would
         # be evaluated only once, when the function is defined.
         log_dir = getcwd()
     self._proxy = Proxy(group_name, "learner", peers, **proxy_options)
     self._actors = self._proxy.peers_name["actor"]  # remote actor ID's
     self._registry_table = RegisterTable(self._proxy.peers_name)
     if update_trigger is None:
         update_trigger = len(self._actors)
     self._registry_table.register_event_handler(
         f"actor:{MessageTag.FINISHED.value}:{update_trigger}",
         self._on_rollout_finish)
     self.logger = Logger("ACTOR_PROXY", dump_folder=log_dir)
コード例 #4
0
ファイル: test_registry_table.py プロジェクト: xniac/maro
 def setUp(self) -> None:
     """Announce the reset and give the test a new ``RegisterTable`` built from ``get_peers``."""
     print("clear register table before each test.")
     self.register_table = RegisterTable(get_peers)
コード例 #5
0
ファイル: test_registry_table.py プロジェクト: xniac/maro
class TestRegisterTable(unittest.TestCase):
    """Unit tests for registering event handlers on a ``RegisterTable`` and triggering them.

    NOTE(review): several checks use ``assertIsNotNone`` on the value returned by ``get()``. The
    negative checks compare that value to ``[]``, so ``get()`` presumably returns a list even when
    nothing triggered; if so, ``assertIsNotNone`` cannot fail. Confirm against ``RegisterTable``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Build the shared messages: five sources per worker type, one message per tag each."""
        print("The register table unit test start!")
        tags = ("tag_a", "tag_b")
        # message_dict[worker_type][tag] -> messages ordered by source index 1..5.
        cls.message_dict = {}
        for worker in ("worker_a", "worker_b"):
            by_tag = defaultdict(list)
            cls.message_dict[worker] = by_tag
            for index in range(1, 6):
                for tag in tags:
                    by_tag[tag].append(
                        SessionMessage(tag=tag,
                                       source=f"{worker}.{index}",
                                       destination="test"))

    @classmethod
    def tearDownClass(cls) -> None:
        """Announce the end of the test class."""
        print("The register table unit test finished!")

    def setUp(self) -> None:
        """Start every test with a new register table."""
        print("clear register table before each test.")
        self.register_table = RegisterTable(get_peers)

    def test_unit_conditional_event(self):
        """A single unit event reacts to matching messages and ignores unrelated ones."""
        messages = self.message_dict
        # One message from worker_a with tag_a is enough to trigger.
        self.register_table.register_event_handler("worker_a:tag_a:1",
                                                   handle_function)

        # Matching source and tag: a result is expected after every push.
        for matching in messages["worker_a"]["tag_a"]:
            self.register_table.push(matching)
            self.assertIsNotNone(self.register_table.get())

        # Different source and tag: nothing may be triggered.
        for unrelated in messages["worker_b"]["tag_b"]:
            self.register_table.push(unrelated)
            self.assertEqual(self.register_table.get(), [])

    def test_special_symbol(self):
        """The ``*`` tag wildcard accepts any tag from the named source."""
        messages = self.message_dict
        self.register_table.register_event_handler("worker_a:*:1",
                                                   handle_function)

        # Every worker_a message, whatever its tag, is expected to produce a result.
        from_worker_a = messages["worker_a"]["tag_a"] + messages["worker_a"]["tag_b"]
        for accepted in from_worker_a:
            self.register_table.push(accepted)
            self.assertIsNotNone(self.register_table.get())

        # worker_b messages do not match the source part of the event.
        for rejected in messages["worker_b"]["tag_a"]:
            self.register_table.push(rejected)
            self.assertEqual(self.register_table.get(), [])

    def test_percentage_case(self):
        """A percentage threshold triggers only once enough tag_a messages were collected."""
        messages = self.message_dict
        # Any source, tag_a, threshold given as 50%.
        self.register_table.register_event_handler("*:tag_a:50%",
                                                   handle_function)

        tag_a_messages = messages["worker_a"]["tag_a"] + messages["worker_b"]["tag_a"]
        for count, msg in enumerate(tag_a_messages, start=1):
            self.register_table.push(msg)
            triggered = self.register_table.get()
            if count % 5:
                # Fewer than five messages since the last trigger: nothing yet.
                self.assertEqual(triggered, [])
            else:
                # Every fifth message is expected to complete the event.
                self.assertIsNotNone(triggered)

    def test_conditional_event(self):
        """Two unit events combined with ``AND`` and then with ``OR``."""
        messages = self.message_dict
        tag_a_messages = messages["worker_a"]["tag_a"] + messages["worker_b"]["tag_a"]

        # AND: needs one worker_a:tag_a message together with one worker_b:tag_a message.
        self.register_table.register_event_handler(
            ("worker_a:tag_a:1", "worker_b:tag_a:1", "AND"), handle_function)
        for position, msg in enumerate(tag_a_messages):
            self.register_table.push(msg)
            if position < 5:
                # Only worker_a messages so far, so the combination cannot be satisfied.
                self.assertEqual(self.register_table.get(), [])
            else:
                self.assertIsNotNone(self.register_table.get())

        # OR: either unit event alone is sufficient. The handler is added to the same table,
        # which still holds the AND handler registered above.
        self.register_table.register_event_handler(
            ("worker_a:tag_a:1", "worker_b:tag_a:1", "OR"), handle_function)
        for msg in tag_a_messages:
            self.register_table.push(msg)
            self.assertIsNotNone(self.register_table.get())

    def test_complicated_conditional_event(self):
        """A nested event: (worker_a:tag_a AND worker_b:tag_a) AND worker_a:tag_b."""
        messages = self.message_dict
        nested_event = (("worker_a:tag_a:1", "worker_b:tag_a:1", "AND"),
                        "worker_a:tag_b:1", "AND")
        self.register_table.register_event_handler(nested_event,
                                                   handle_function)

        # Feed the inner AND pair first, without inspecting the table in between.
        for msg in messages["worker_a"]["tag_a"] + messages["worker_b"]["tag_a"]:
            self.register_table.push(msg)

        # Each worker_a:tag_b message then supplies the remaining operand.
        for closing in messages["worker_a"]["tag_b"]:
            self.register_table.push(closing)
            self.assertIsNotNone(self.register_table.get())

    def test_multiple_trigger(self):
        """One message that satisfies two registered events yields two results."""
        messages = self.message_dict
        # Exact tag match and wildcard tag match, both for worker_a.
        for event in ("worker_a:tag_a:1", "worker_a:*:1"):
            self.register_table.register_event_handler(event, handle_function)

        for msg in messages["worker_a"]["tag_a"]:
            self.register_table.push(msg)
            outcome = self.register_table.get()
            # Two entries are expected, and both must carry the same messages.
            self.assertEqual(len(outcome), 2)
            first, second = outcome
            self.assertEqual(first[1], second[1])
コード例 #6
0
class ActorProxy(object):
    """Learner-side proxy that sends roll-out requests to a group of remote actors.

    Args:
        group_name (str): Name of the group the actors belong to. The actors (and roll-out clients,
            if any) must have been given the same group name.
        num_actors (int): Number of actors expected in the group named ``group_name``.
        update_trigger (str): Number or percentage of ``MessageTag.FINISHED`` messages that must
            arrive before a learner update, i.e., model training, is triggered. Defaults to None,
            in which case the number of actor peers is used.
        proxy_options (dict): Keyword arguments for the internal ``Proxy`` instance. See the
            ``Proxy`` class for details. Defaults to None.
    """
    def __init__(
        self,
        group_name: str,
        num_actors: int,
        update_trigger: str = None,
        proxy_options: dict = None
    ):
        self.agent = None
        options = {} if proxy_options is None else proxy_options
        self._proxy = Proxy(group_name, "learner", {"actor": num_actors}, **options)
        self._actors = self._proxy.peers_name["actor"]  # IDs of the remote actors
        self._registry_table = RegisterTable(self._proxy.peers_name)
        # Without an explicit trigger, use the number of actor peers as the event's last field.
        trigger = len(self._actors) if update_trigger is None else update_trigger
        self._registry_table.register_event_handler(
            f"actor:{MessageTag.FINISHED.value}:{trigger}", self._on_rollout_finish
        )
        self.logger = InternalLogger("ACTOR_PROXY")

    def roll_out(self, index: int, training: bool = True, model_by_agent: dict = None, exploration_params=None):
        """Request a roll-out from every remote actor and collect the results.

        Args:
            index (int): Index of the roll-out request; replies carrying another index are ignored.
            training (bool): If true, the roll-out request is for training purposes.
            model_by_agent (dict): Models to be broadcast to remote actors for inference.
                Defaults to None.
            exploration_params: Exploration parameters to be used by the remote roll-out actors.
                Defaults to None.

        Returns:
            A ``(env_metrics, details)`` pair unpacked from the first result the register table
            reports for a ``MessageTag.FINISHED`` message.
        """
        request = {
            PayloadKey.ROLLOUT_INDEX: index,
            PayloadKey.TRAINING: training,
            PayloadKey.MODEL: model_by_agent,
            PayloadKey.EXPLORATION_PARAMS: exploration_params
        }
        # Every actor receives the same payload object.
        per_actor_requests = [(actor, request) for actor in self._actors]
        self._proxy.iscatter(MessageTag.ROLLOUT, SessionType.TASK, per_actor_requests)
        self.logger.info(f"Sent roll-out requests to {self._actors} for ep-{index}")

        # Consume replies until the register table reports a result.
        # NOTE(review): if ``receive()`` ever stops yielding before that happens, the return
        # statement below fails because the two names were never bound.
        for msg in self._proxy.receive():
            received_index = msg.payload[PayloadKey.ROLLOUT_INDEX]
            if received_index != index:
                # NOTE(review): the check is ``!=`` although the message says "or greater".
                self.logger.info(
                    f"Ignore a message of type {msg.tag} with ep {received_index} "
                    f"(expected {index} or greater)"
                )
                continue
            if msg.tag != MessageTag.FINISHED:
                continue
            # A truthy result ends the wait; the caller can then move on to the next episode.
            outcome = self._registry_table.push(msg)
            if outcome:
                env_metrics, details = outcome[0]
                break

        return env_metrics, details

    def _on_rollout_finish(self, messages: List[Message]):
        """Group the metrics and the details carried by ``messages`` by message source.

        Args:
            messages: Messages whose payloads hold ``PayloadKey.METRICS`` and ``PayloadKey.DETAILS``.

        Returns:
            A ``(metrics, details)`` pair of dicts keyed by ``msg.source``.
        """
        metrics_by_source = {item.source: item.payload[PayloadKey.METRICS] for item in messages}
        details_by_source = {item.source: item.payload[PayloadKey.DETAILS] for item in messages}
        return metrics_by_source, details_by_source

    def terminate(self):
        """Tell the remote actors to exit."""
        self._proxy.ibroadcast(
            component_type="actor",
            tag=MessageTag.EXIT,
            session_type=SessionType.NOTIFICATION,
        )
        self.logger.info("Exiting...")
コード例 #7
0
 def setUp(self) -> None:
     """Announce the reset and build a new ``RegisterTable`` from a ``FakedProxy``'s ``peers_dict``."""
     print("clear register table before each test.")
     self.register_table = RegisterTable(FakedProxy().peers_dict)