def _send_request_and_receive_response(self, server_uri, file_path):
    """Send a PARSE_DAG_REQUEST for ``file_path`` and assert a response arrives.

    :param server_uri: Address of the notification server to connect to.
    :param file_path: Sent as the request event's value; combined with a
        nanosecond timestamp to build a unique event key.
    """
    key = '{}_{}'.format(file_path, time.time_ns())
    client = NotificationClient(server_uri=server_uri,
                                default_namespace=SCHEDULER_NAMESPACE)
    event = BaseEvent(key=key,
                      event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                      value=file_path)
    watcher: ResponseWatcher = ResponseWatcher()
    # Subscribe before sending so the response cannot be published before
    # anyone is listening for it.
    handle = client.start_listen_event(key=key,
                                       event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
                                       watcher=watcher)
    try:
        client.send_event(event)
        res: BaseEvent = watcher.get_result()
    finally:
        # Stop the listener even if sending or waiting fails, so it does not leak.
        handle.stop()
    # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
    self.assertEqual(event.key, res.key)
    # NOTE(review): this compares the request with its own input and always
    # passes; confirm whether res.value was meant to be checked instead.
    self.assertEqual(event.value, file_path)
class EventSchedulerClient(object):
    """Client for the event-based scheduler, built on the notification service.

    Every call follows the same request/response pattern:
    1. Start listening for the response under a freshly generated key.
    2. Send the request event under that key.
    3. Block until the watcher yields the response.
    4. Stop listening.
    """

    def __init__(self, server_uri=None, namespace=None, ns_client=None):
        """Create the client.

        :param server_uri: Notification server address; used only when
            ``ns_client`` is not supplied.
        :param namespace: Passed as the second positional argument to
            ``NotificationClient`` when ``ns_client`` is not supplied.
        :param ns_client: An existing notification client to reuse instead.
        """
        if ns_client is None:
            self.ns_client = NotificationClient(server_uri, namespace)
        else:
            self.ns_client = ns_client

    @staticmethod
    def generate_id(id):
        """Return ``'<id>_<nanosecond timestamp>'`` for use as a request key.

        The parameter name shadows the ``id`` builtin; it is kept so existing
        keyword callers keep working.
        """
        return '{}_{}'.format(id, time.time_ns())

    def _request(self, key, response_event_type, request_event):
        """Send ``request_event`` and block until the response for ``key`` arrives.

        Listening starts before the request is sent so the response cannot be
        missed. The listener is stopped in ``finally`` so it never leaks, even
        when sending or waiting raises.

        :param key: Key the response is expected under.
        :param response_event_type: Event type of the expected response.
        :param request_event: The event to send.
        :return: Whatever ``ResponseWatcher.get_result()`` returns.
        """
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle = self.ns_client.start_listen_event(
            key=key,
            event_type=response_event_type,
            namespace=SCHEDULER_NAMESPACE,
            watcher=watcher)
        try:
            self.ns_client.send_event(request_event)
            return watcher.get_result()
        finally:
            handler.stop()

    def _request_execution_context(self, request_id, body):
        """Send a ``RequestEvent`` and wrap the response body in an ExecutionContext.

        :param request_id: Key used both for the request and for the awaited
            RESPONSE event.
        :param body: Serialized request message.
        :return: ``ExecutionContext`` whose ``dagrun_id`` is the response body.
        """
        response = self._request(
            request_id,
            SchedulerInnerEventType.RESPONSE.value,
            RequestEvent(request_id=request_id, body=body).to_event())
        result: ResponseEvent = ResponseEvent.from_base_event(response)
        return ExecutionContext(dagrun_id=result.body)

    def trigger_parse_dag(self) -> bool:
        """Ask the scheduler to parse DAG files and wait for its response.

        :return: Always ``True`` once a response has been received.
        """
        key = self.generate_id('')
        self._request(
            key,
            SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
            BaseEvent(
                key=key,
                event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                value=''))
        return True

    def schedule_dag(self, dag_id) -> ExecutionContext:
        """Request a run of ``dag_id``; the returned context carries the response body."""
        request_id = self.generate_id(dag_id)
        return self._request_execution_context(
            request_id, RunDagMessage(dag_id).to_json())

    def stop_dag_run(self, dag_id,
                     context: ExecutionContext) -> ExecutionContext:
        """Request that the dag run in ``context`` of ``dag_id`` be stopped."""
        request_id = self.generate_id(str(dag_id) + str(context.dagrun_id))
        return self._request_execution_context(
            request_id,
            StopDagRunMessage(
                dag_id=dag_id,
                dagrun_id=context.dagrun_id).to_json())

    def schedule_task(self, dag_id: str, task_id: str,
                      action: SchedulingAction,
                      context: ExecutionContext) -> ExecutionContext:
        """Send ``action`` for ``task_id`` of ``dag_id`` within the dag run in ``context``."""
        request_id = self.generate_id(context.dagrun_id)
        return self._request_execution_context(
            request_id,
            ExecuteTaskMessage(
                dag_id=dag_id,
                task_id=task_id,
                dagrun_id=context.dagrun_id,
                action=action.value).to_json())
Example #3
0
class NotificationTest(unittest.TestCase):
    """Integration tests for NotificationClient against an in-process master.

    The master is started once for the class. Each test starts with cleaned
    storage and a fresh client pointed at ``localhost:50051``.
    """

    @classmethod
    def setUpClass(cls):
        cls.storage = EventModelStorage()
        cls.master = NotificationMaster(NotificationService(cls.storage))
        cls.master.run()

    @classmethod
    def tearDownClass(cls):
        cls.master.stop()

    def setUp(self):
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self):
        # Stop any listeners a test may have left running.
        self.client.stop_listen_events()
        self.client.stop_listen_event()

    @staticmethod
    def _collecting_watcher(event_list):
        """Return an EventWatcher that appends every delivered event to ``event_list``."""

        class CollectingWatcher(EventWatcher):
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[Event]):
                self.event_list.extend(events)

        return CollectingWatcher(event_list)

    @staticmethod
    def _wait_until(predicate, timeout=5.0, interval=0.05):
        """Poll ``predicate`` until it is true or ``timeout`` seconds elapse.

        Events reach a watcher asynchronously after ``start_listen_event(s)``
        returns. Waiting here keeps tests from stopping the listener before
        delivery. On timeout this simply returns; the caller's assertion then
        reports the real count.
        """
        import time
        deadline = time.monotonic() + timeout
        while not predicate() and time.monotonic() < deadline:
            time.sleep(interval)

    def test_send_event(self):
        event = self.client.send_event(Event(key="key", value="value1"))
        self.assertTrue(event.version > 0)

    def test_list_events(self):
        """Listing from event1's version is expected to return only the two later events."""
        event1 = self.client.send_event(Event(key="key", value="value1"))
        self.client.send_event(Event(key="key", value="value2"))
        self.client.send_event(Event(key="key", value="value3"))
        events = self.client.list_events("key", version=event1.version)
        self.assertEqual(2, len(events))

    def test_listen_events(self):
        """A key listener started at event1's version should see the two later events."""
        event_list = []
        event1 = self.client.send_event(Event(key="key", value="value1"))
        self.client.start_listen_event(key="key",
                                       watcher=self._collecting_watcher(event_list),
                                       version=event1.version)
        self.client.send_event(Event(key="key", value="value2"))
        self.client.send_event(Event(key="key", value="value3"))
        # Let delivery finish before stopping, otherwise the count races the listener.
        self._wait_until(lambda: len(event_list) >= 2)
        self.client.stop_listen_event("key")
        events = self.client.list_events("key", version=event1.version)
        self.assertEqual(2, len(events))
        self.assertEqual(2, len(event_list))

    def test_all_listen_events(self):
        """list_all_events(start_time) is expected to include event2 (created at start_time) and event3."""
        self.client.send_event(Event(key="key", value="value1"))
        event = self.client.send_event(Event(key="key", value="value2"))
        start_time = event.create_time
        self.client.send_event(Event(key="key", value="value3"))
        events = self.client.list_all_events(start_time)
        self.assertEqual(2, len(events))

    def test_listen_all_events(self):
        """A listener on all keys should receive events sent to three different keys."""
        event_list = []
        try:
            self.client.start_listen_events(watcher=self._collecting_watcher(event_list))
            self.client.send_event(Event(key="key1", value="value1"))
            self.client.send_event(Event(key="key2", value="value2"))
            self.client.send_event(Event(key="key3", value="value3"))
            # Let delivery finish before the listener is stopped in ``finally``.
            self._wait_until(lambda: len(event_list) >= 3)
        finally:
            self.client.stop_listen_events()
        self.assertEqual(3, len(event_list))