예제 #1
0
    def setUpClass(cls) -> None:
        """Build the shared message fixture for the register table tests.

        ``cls.message_dict`` maps each worker group ("worker_a", "worker_b") to a
        ``defaultdict(list)`` keyed by tag ("tag_a", "tag_b"). Each tag list holds one
        ``SessionMessage`` per source "<group>.1" .. "<group>.5", all addressed to "test".
        """
        print(f"The register table unit test start!")
        tag_type = ["tag_a", "tag_b"]

        # Prepare message dict for test: one tag-keyed bucket per worker group.
        cls.message_dict = {group: defaultdict(list) for group in ("worker_a", "worker_b")}

        for group, bucket in cls.message_dict.items():
            for index in range(1, 6):
                source = f"{group}.{index}"
                for tag in tag_type:
                    bucket[tag].append(
                        SessionMessage(tag=tag, source=source, destination="test"))
예제 #2
0
    def reply(self,
              received_message: SessionMessage,
              tag: Union[str, Enum] = None,
              payload=None,
              ack_reply: bool = False) -> List[str]:
        """Reply a received message.

        The reply is addressed to ``received_message.source`` and reuses its session id.

        Args:
            received_message (Message): The message need to reply.
            tag (str|Enum): New message tag, if None, keeps the original message's tag. Defaults to None.
            payload (object): New message payload, if None, keeps the original message's payload. Defaults to None.
            ack_reply (bool): If True, it is acknowledge reply. Defaults to False.

        Returns:
            List[str]: Message belonged session id.
        """
        # For task sessions, ack_reply selects the RECEIVE stage and a normal reply the
        # COMPLETE stage; any other session type is always replied at RECEIVE.
        if received_message.session_type == SessionType.TASK:
            session_stage = TaskSessionStage.RECEIVE if ack_reply else TaskSessionStage.COMPLETE
        else:
            session_stage = NotificationSessionStage.RECEIVE

        # Fall back to the received message's values only when the argument is None,
        # as documented. Falsy-but-valid values such as 0, "" or {} are sent as given.
        replied_message = SessionMessage(
            tag=tag if tag is not None else received_message.tag,
            source=self._name,
            destination=received_message.source,
            session_id=received_message.session_id,
            payload=payload if payload is not None else received_message.payload,
            session_stage=session_stage)
        return self.isend(replied_message)
예제 #3
0
    def _scatter(self,
                 tag: Union[str, Enum],
                 session_type: SessionType,
                 destination_payload_list: list,
                 session_id: str = None) -> List[str]:
        """Send one message per (destination, payload) pair and collect their session ids.

        Args:
            tag (str|Enum): Tag applied to every outgoing message.
            session_type (SessionType): Session type applied to every outgoing message.
            destination_payload_list (list): Iterable of ``(destination, payload)`` pairs.
            session_id (str): Session id passed to every message. Defaults to None.

        Returns:
            List[str]: Session ids of the messages whose send reported success
            (a falsy status from the driver).

        Raises:
            The driver's failure status, raised as-is when a send fails and
            fault tolerance is disabled.
        """
        delivered_ids = []

        for destination, payload in destination_payload_list:
            outgoing = SessionMessage(tag=tag,
                                      source=self._name,
                                      destination=destination,
                                      session_id=session_id,
                                      payload=payload,
                                      session_type=session_type)
            # A falsy status means success; otherwise the status is something raisable.
            failure = self._driver.send(outgoing)

            if not failure:
                delivered_ids.append(outgoing.session_id)
                continue

            if not self._is_enable_fault_tolerant:
                raise failure

            # Fault-tolerant mode: log and skip this destination.
            self._logger.warn(
                f"{self._name} failure to send message to {outgoing.destination}, as {str(failure)}"
            )

        return delivered_ids
예제 #4
0
파일: send.py 프로젝트: yumiaoGitHub/maro
def master(group_name: str, is_immediate: bool = False):
    """Start a master proxy and send a "sum" task to every worker peer.

    A numpy array of five random integers in [0, 100) is generated and printed. It is
    then sent as the payload of a TASK-session message to each peer listed under
    "worker", and every replied message is printed.

    Args:
        group_name (str): Identifier for the group of all communication components.
        is_immediate (bool): If True, run in async mode; otherwise, run in sync mode.
            Async mode: ``isend`` returns only the session id, so local work with a
                higher priority can run before waiting for the peers' replies.
            Sync mode: ``send`` blocks until the proxy returns all replied messages.
    """
    proxy = Proxy(group_name=group_name,
                  component_type="master",
                  expected_peers={"worker": 1})

    random_integer_list = np.random.randint(0, 100, 5)
    print(f"generate random integer list: {random_integer_list}.")

    for worker_name in proxy.peers_name["worker"]:
        task_message = SessionMessage(tag="sum",
                                      source=proxy.name,
                                      destination=worker_name,
                                      payload=random_integer_list,
                                      session_type=SessionType.TASK)
        if not is_immediate:
            replies = proxy.send(task_message, timeout=-1)
        else:
            pending_session = proxy.isend(task_message)
            # Do some tasks with higher priority here.
            replies = proxy.receive_by_id(pending_session, timeout=-1)

        for reply_message in replies:
            print(
                f"{proxy.name} receive {reply_message.source}, replied payload is {reply_message.payload}."
            )
예제 #5
0
    def forward(self,
                received_message: SessionMessage,
                destination: str,
                tag: Union[str, Enum] = None,
                payload=None) -> List[str]:
        """
        Forward a received message to another destination.

        The forwarded message keeps the received message's session id and session stage.
        Its source becomes this proxy.

        Args:
            received_message (Message): The message need to forward.
            destination (str): The receiver of message.
            tag (str|Enum): Message tag, which is customized by the user, for specific application logic.
                If None, keeps the original message's tag. Defaults to None.
            payload (object): Message payload, such as model parameters, experiences, etc.
                If None, keeps the original message's payload. Defaults to None.

        Returns:
            session_id List[str]: Message belonged session id.
        """
        # Compare against None rather than truthiness, so falsy-but-meaningful values
        # (0, "", [], {}) are forwarded as given instead of being silently replaced.
        forward_message = SessionMessage(
            tag=tag if tag is not None else received_message.tag,
            source=self._name,
            destination=destination,
            session_id=received_message.session_id,
            payload=payload if payload is not None else received_message.payload,
            session_stage=received_message.session_stage)
        return self.isend(forward_message)
예제 #6
0
    def reply(self,
              received_message: SessionMessage,
              tag: Union[str, Enum] = None,
              payload=None,
              ack_reply: bool = False) -> List[str]:
        """
        Reply a received message, addressing it back to the message's source.

        Args:
            received_message (Message): The message need to reply,
            tag (str|Enum): Message tag, which is customized by the user, for specific application logic.
                A falsy tag reuses the received message's tag,
            payload (object): Message payload, such as model parameters, experiences, etc.
                Sent exactly as given (None included),
            ack_reply (bool): If True, it is acknowledge reply.

        Returns:
            session_id List[str]: Message belonged session id.
        """
        # Non-task sessions always reply at RECEIVE; task sessions pick RECEIVE for an
        # acknowledgement and COMPLETE otherwise.
        if received_message.session_type != SessionType.TASK:
            session_stage = NotificationSessionStage.RECEIVE
        elif ack_reply:
            session_stage = TaskSessionStage.RECEIVE
        else:
            session_stage = TaskSessionStage.COMPLETE

        reply_tag = tag or received_message.tag
        answer = SessionMessage(
            tag=reply_tag,
            source=self._name,
            destination=received_message.source,
            session_id=received_message.session_id,
            payload=payload,
            session_stage=session_stage)
        return self.isend(answer)
예제 #7
0
    def test_send(self):
        """Each worker proxy should receive exactly the payload the master sent it."""
        master = TestProxy.master_proxy
        for worker in TestProxy.worker_proxies:
            outgoing = SessionMessage(tag="unit_test",
                                      source=master.component_name,
                                      destination=worker.component_name,
                                      payload="hello_world!")
            master.isend(outgoing)

            # Single, non-continuous receive pass on the worker side.
            for incoming in worker.receive(is_continuous=False):
                self.assertEqual(outgoing.payload, incoming.payload)
예제 #8
0
    def test_decorator(self):
        """The first receiver's reply should carry the sent counter incremented by one."""
        sender = TestDecorator.sender_proxy
        request = SessionMessage(
            tag="unittest",
            source=sender.component_name,
            destination=sender.peers["receiver"][0],
            payload={"counter": 0})
        replies = sender.send(request)

        expected_counter = request.payload["counter"] + 1
        self.assertEqual(expected_counter, replies[0].payload["counter"])
예제 #9
0
    def test_send(self):
        """Every peer's receiver should get the payload the sender addressed to it."""
        for peer_name in self.peer_list:
            outgoing = SessionMessage(tag="unit_test",
                                      source="sender",
                                      destination=peer_name,
                                      payload="hello_world")
            self.sender.send(outgoing)

            receiver = self.receivers[peer_name]
            for incoming in receiver.receive(is_continuous=False):
                self.assertEqual(incoming.payload, outgoing.payload)
예제 #10
0
    def test_reply(self):
        """A worker's "world!" reply should complete the master's "hello " message."""
        master = TestProxy.master_proxy
        for worker in TestProxy.worker_proxies:
            request = SessionMessage(tag="unit_test",
                                     source=master.component_name,
                                     destination=worker.component_name,
                                     payload="hello ")
            session_ids = master.isend(request)

            # Answer whatever the worker picks up in one non-continuous receive pass.
            for incoming in worker.receive(is_continuous=False):
                worker.reply(received_message=incoming, tag="unit_test", payload="world!")

            first_reply = master.receive_by_id(session_ids)[0]
            self.assertEqual(request.payload + first_reply.payload, "hello world!")
예제 #11
0
    def test_broadcast(self):
        """Every receiver should get the payload of a single broadcast message.

        One ``message_receive`` task per peer is submitted to a thread pool before the
        broadcast is sent. Each task's result is then compared with the broadcast payload.
        """
        executor = ThreadPoolExecutor(max_workers=len(TestDriver.peer_list))
        try:
            all_task = [executor.submit(message_receive, TestDriver.receivers[peer])
                        for peer in TestDriver.peer_list]

            message = SessionMessage(
                tag="unit_test",
                source="sender",
                destination="*",
                payload="hello_world"
            )
            TestDriver.sender.broadcast(topic="receiver", message=message)

            for task in as_completed(all_task):
                self.assertEqual(task.result(), message.payload)
        finally:
            # Release the pool's worker threads once the test is done. wait=False keeps a
            # failed broadcast from blocking here on receivers that will never complete.
            executor.shutdown(wait=False)
예제 #12
0
    def test_rejoin(self):
        """Scatter reply counts should track a peer leaving and then rejoining."""
        master = TestRejoin.master_proxy
        destination_payload_list = [(peer, "continuous") for peer in TestRejoin.peers]

        def scatter_continuous():
            # Send the "continuous" notification to every known peer; return scatter's result.
            return master.scatter(
                tag="cont",
                session_type=SessionType.NOTIFICATION,
                destination_payload_list=destination_payload_list
            )

        # Check all connected.
        self.assertEqual(len(scatter_continuous()), TestRejoin.peers_number)

        # Disconnect one peer.
        master.isend(SessionMessage(
            tag="stop",
            source=master.component_name,
            destination=TestRejoin.peers[1],
            payload=None,
            session_type=SessionType.TASK
        ))

        # Now, 1 peer exited, only have 2 peers.
        time.sleep(2)
        self.assertEqual(len(scatter_continuous()), TestRejoin.peers_number - 1)

        # Wait for rejoin.
        time.sleep(5)
        # Now, all peers rejoin.
        # NOTE(review): with every peer back, one MORE reply than peers_number is expected
        # here. Confirm whether this intentionally counts an extra reply (e.g. to the
        # earlier "stop" task) or should be peers_number.
        self.assertEqual(len(scatter_continuous()), TestRejoin.peers_number + 1)
예제 #13
0
    def _broadcast(self,
                   tag: Union[str, Enum],
                   session_type: SessionType,
                   session_id: str = None,
                   payload=None) -> List[str]:
        """Broadcast message to all peers, and return list of session id.

        Args:
            tag (str|Enum): Message tag.
            session_type (SessionType): Session type of the broadcast message.
            session_id (str): Session id for the message. Defaults to None.
            payload (object): Message payload. Defaults to None.

        Returns:
            List[str]: On success, the message's session id repeated once per peer
            returned by ``get_peers()``. An empty list when the broadcast fails and
            fault tolerance is enabled.

        Raises:
            The driver's failure status, raised as-is when the broadcast fails and
            fault tolerance is disabled.
        """
        message = SessionMessage(tag=tag,
                                 source=self._name,
                                 destination="*",
                                 payload=payload,
                                 session_id=session_id,
                                 session_type=session_type)

        # A falsy status means success; otherwise the status is something raisable.
        broadcast_status = self._driver.broadcast(message)

        if not broadcast_status:
            # All peers share the same message, so its session id appears once per peer.
            return [message.session_id] * len(self.get_peers())

        if not self._is_enable_fault_tolerant:
            raise broadcast_status

        self._logger.warn(
            f"{self._name} failure to broadcast message to any peers, as {str(broadcast_status)}"
        )
        # Previously this path fell off the end and returned None. Return an empty list so
        # callers can always iterate the result, matching the List[str] annotation.
        return []
예제 #14
0
    def forward(self,
                received_message: SessionMessage,
                destination: str,
                tag: Union[str, Enum] = None,
                payload=None) -> List[str]:
        """Forward a received message.

        The forwarded message keeps the received message's session id and session stage.
        Its source becomes this proxy.

        Args:
            received_message (Message): The message need to forward.
            destination (str): The receiver of message.
            tag (str|Enum): New message tag, if None, keeps the original message's tag. Defaults to None.
            payload (object): Message payload, if None, keeps the original message's payload. Defaults to None.

        Returns:
            List[str]: Message belonged session id.
        """
        # Fall back only when the argument is None, as documented. A truthiness check
        # would wrongly replace falsy-but-valid values such as 0, "", [] or {}.
        forward_message = SessionMessage(
            tag=tag if tag is not None else received_message.tag,
            source=self._name,
            destination=destination,
            session_id=received_message.session_id,
            payload=payload if payload is not None else received_message.payload,
            session_stage=received_message.session_stage)
        return self.isend(forward_message)