Example #1
class TestAIFlowContext(unittest.TestCase):
    def setUp(self):
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI, port=_PORT,
                                   start_default_notification=False,
                                   start_meta_service=True,
                                   start_metric_service=False,
                                   start_model_center_service=False,
                                   start_scheduler_service=False)
        self.server.run()

    def tearDown(self):
        self.server.stop()
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)

    def test_init_ai_flow_context(self):
        init_ai_flow_context()
        project_config = current_project_config()
        self.assertEqual('test_project', project_config.get_project_name())
        self.assertEqual('a', project_config.get('a'))
        project_context = current_project_context()
        self.assertEqual('test_project', project_context.project_name)
        workflow_config_ = current_workflow_config()
        self.assertEqual('test_ai_flow_context', workflow_config_.workflow_name)
        self.assertEqual(5, len(workflow_config_.job_configs))
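The examples rely on module-level constants _SQLITE_DB_FILE, _SQLITE_DB_URI and _PORT that are not shown. A minimal sketch of plausible definitions follows; the concrete file name and port are assumptions, not the project's actual values.

import os

# Hypothetical values for the constants used throughout these examples.
_SQLITE_DB_FILE = 'aiflow.db'
_SQLITE_DB_URI = 'sqlite:///' + _SQLITE_DB_FILE
_PORT = '50051'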
Example #2
class TestTensorFlowIrisModel(unittest.TestCase):
    def setUp(self) -> None:
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                   port=_PORT,
                                   start_scheduler_service=False)
        self.server.run()
        self.client = AIFlowClient(server_uri='localhost:' + _PORT)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.server.stop()
        os.remove(_SQLITE_DB_FILE)

    def test_save_and_load_model(self):
        iris_model = fit_and_save_model()
        tf_graph = tf.Graph()
        registered_model = self.client.create_registered_model(
            model_name='iris_model', model_desc='iris model')
        self.client.create_model_version(
            model_name=registered_model.model_name,
            model_path=iris_model.path,
            model_type=
            '{"meta_graph_tags":["serve"],"signature_def_map_key":"predict"}',
            version_desc='iris model')

        class IrisWatcher(EventWatcher):
            def process(self, notifications):
                for notification in notifications:
                    model_path = json.loads(
                        notification.value).get('_model_path')
                    model_flavor = json.loads(
                        notification.value).get('_model_type')
                    print(json.loads(notification.value).keys())
                    print(model_path)
                    signature_def = load_tensorflow_saved_model(
                        model_uri=model_path,
                        meta_graph_tags=json.loads(model_flavor).get(
                            'meta_graph_tags'),
                        signature_def_map_key=json.loads(model_flavor).get(
                            'signature_def_map_key'),
                        tf_session=tf.Session(graph=tf_graph))
                    for _, input_signature in signature_def.inputs.items():
                        t_input = tf_graph.get_tensor_by_name(
                            input_signature.name)
                        assert t_input is not None
                    for _, output_signature in signature_def.outputs.items():
                        t_output = tf_graph.get_tensor_by_name(
                            output_signature.name)
                        assert t_output is not None

        self.client.start_listen_event(key=registered_model.model_name,
                                       watcher=IrisWatcher())
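For reference, IrisWatcher.process() above reads only the '_model_path' and '_model_type' keys from each notification value. A minimal sketch of such a payload follows; the surrounding envelope and the path are assumptions, while the key names and the model_type string come from the test itself.

import json

# Hypothetical notification value as consumed by IrisWatcher.process().
example_value = json.dumps({
    '_model_path': '/tmp/iris_saved_model',  # hypothetical path
    '_model_type': '{"meta_graph_tags":["serve"],"signature_def_map_key":"predict"}',
})
assert json.loads(json.loads(example_value)['_model_type'])['signature_def_map_key'] == 'predict'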
Example #3
class TestSklearnModel(unittest.TestCase):
    def setUp(self) -> None:
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                   port=_PORT,
                                   start_scheduler_service=False)
        self.server.run()
        self.client = AIFlowClient(server_uri='localhost:' + _PORT)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.server.stop()
        os.remove(_SQLITE_DB_FILE)

    def test_save_and_load_model(self):
        knn_model = fit_model()
        model_path = tempfile.mkdtemp()
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        model_path = os.path.join(model_path, 'model.pkl')
        save_model(sk_model=knn_model.model,
                   output_path=model_path,
                   serialization_format=SERIALIZATION_FORMAT_PICKLE)
        registered_model = self.client.create_registered_model(
            model_name='knn_model', model_desc='knn model')
        self.client.create_model_version(
            model_name=registered_model.model_name,
            model_path=model_path,
            model_type='sklearn',
            version_desc='knn model')

        class KnnWatcher(EventWatcher):
            def process(self, notifications):
                for notification in notifications:
                    load_path = json.loads(
                        notification.value).get('_model_path')
                    reloaded_knn_model = load_scikit_learn_model(
                        model_uri=load_path)
                    numpy.testing.assert_array_equal(
                        knn_model.model.predict(knn_model.inference_data),
                        reloaded_knn_model.predict(knn_model.inference_data))
                    os.remove(load_path)

        self.client.start_listen_event(key=registered_model.model_name,
                                       watcher=KnnWatcher())
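The test above depends on a fit_model() helper that is not shown. A minimal stand-in sketch follows, assuming a scikit-learn KNN classifier trained on the iris dataset; the helper name and the .model/.inference_data attributes are inferred from how the test uses them.

from collections import namedtuple

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

# The test only needs an object exposing .model and .inference_data.
KnnModel = namedtuple('KnnModel', ['model', 'inference_data'])


def fit_model():
    iris = load_iris()
    knn = KNeighborsClassifier()
    knn.fit(iris.data, iris.target)
    # Reuse a slice of the training data for the predict round-trip check.
    return KnnModel(model=knn, inference_data=iris.data[:10])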
Example #4
class TestHighAvailableAIFlowServer(unittest.TestCase):
    @staticmethod
    def start_aiflow_server(host, port):
        port = str(port)
        server_uri = host + ":" + port
        server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                              port=port,
                              enabled_ha=True,
                              start_scheduler_service=False,
                              ha_server_uri=server_uri,
                              notification_uri='localhost:30031',
                              start_default_notification=False)
        server.run()
        return server

    def wait_for_new_members_detected(self, new_member_uri):
        while True:
            living_member = self.client.living_aiflow_members
            if new_member_uri in living_member:
                break
            else:
                time.sleep(1)

    def setUp(self) -> None:
        SqlAlchemyStore(_SQLITE_DB_URI)
        self.notification = NotificationMaster(
            service=NotificationService(storage=MemoryEventStorage()),
            port=30031)
        self.notification.run()
        self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                    port=50051,
                                    enabled_ha=True,
                                    start_scheduler_service=False,
                                    ha_server_uri='localhost:50051',
                                    notification_uri='localhost:30031',
                                    start_default_notification=False)
        self.server1.run()
        self.server2 = None
        self.server3 = None
        self.config = ProjectConfig()
        self.config.set_enable_ha(True)
        self.config.set_notification_service_uri('localhost:30031')
        self.client = AIFlowClient(
            server_uri='localhost:50052,localhost:50051',
            project_config=self.config)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.server1 is not None:
            self.server1.stop()
        if self.server2 is not None:
            self.server2.stop()
        if self.server3 is not None:
            self.server3.stop()
        if self.notification is not None:
            self.notification.stop()
        store = SqlAlchemyStore(_SQLITE_DB_URI)
        base.metadata.drop_all(store.db_engine)

    def test_server_change(self) -> None:
        self.client.register_project("test_project")
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50051")
        self.assertEqual(projects[0].name, "test_project")

        self.server2 = self.start_aiflow_server("localhost", 50052)
        self.wait_for_new_members_detected("localhost:50052")
        self.server1.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50052")
        self.assertEqual(projects[0].name, "test_project")

        self.server3 = self.start_aiflow_server("localhost", 50053)
        self.wait_for_new_members_detected("localhost:50053")
        self.server2.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50053")
        self.assertEqual(projects[0].name, "test_project")
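Note that wait_for_new_members_detected above loops forever if the new member never registers, which can hang the suite. A bounded variant is sketched below; the timeout parameter and the standalone-function form are additions for illustration, not part of the original API.

import time


def wait_for_member(client, new_member_uri, timeout_s=30):
    """Hypothetical bounded wait that fails instead of hanging."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if new_member_uri in client.living_aiflow_members:
            return
        time.sleep(1)
    raise TimeoutError('member {} not detected within {}s'.format(new_member_uri, timeout_s))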
Example #5
class TestSchedulerService(unittest.TestCase):
    def setUp(self):
        config = SchedulerServiceConfig()
        config.set_scheduler_class_name(SCHEDULER_CLASS)
        config.set_scheduler_config({})
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                   port=_PORT,
                                   start_default_notification=False,
                                   start_meta_service=False,
                                   start_metric_service=False,
                                   start_model_center_service=False,
                                   start_scheduler_service=True,
                                   scheduler_service_config=config)
        self.server.run()

    def tearDown(self):
        self.server.stop()
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)

    def test_submit_workflow(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            instance.submit_workflow.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = SchedulerClient("localhost:{}".format(_PORT))
            with self.assertRaises(Exception) as context:
                workflow = client.submit_workflow_to_scheduler(
                    namespace='namespace',
                    workflow_name='test_workflow',
                    workflow_json='')
            self.assertTrue('workflow json is empty' in str(context.exception))

    def test_pause_workflow(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.pause_workflow_scheduling.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow = client.pause_workflow_scheduling(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_resume_workflow(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.resume_workflow_scheduling.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow = client.resume_workflow_scheduling(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_start_new_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.start_new_workflow_execution.return_value \
                = WorkflowExecutionInfo(workflow_execution_id='id', status=Status.INIT)
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution = client.start_new_workflow_execution(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             workflow_execution.execution_state)

    def test_kill_all_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.stop_all_workflow_execution.return_value \
                = [WorkflowExecutionInfo(workflow_execution_id='id_1', status=Status.INIT),
                   WorkflowExecutionInfo(workflow_execution_id='id_2', status=Status.INIT)]
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution_list = client.kill_all_workflow_executions(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual(2, len(workflow_execution_list))

    def test_kill_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.stop_workflow_execution.return_value \
                = WorkflowExecutionInfo(workflow_execution_id='id', status=Status.RUNNING)
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution = client.kill_workflow_execution(
                execution_id='id')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.RUNNING,
                             workflow_execution.execution_state)

    def test_get_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.get_workflow_execution.return_value \
                = WorkflowExecutionInfo(workflow_execution_id='id', status=Status.INIT)
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution = client.get_workflow_execution(
                execution_id='id')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             workflow_execution.execution_state)

    def test_list_workflow_executions(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.list_workflow_executions.return_value \
                = [WorkflowExecutionInfo(workflow_execution_id='id_1', status=Status.INIT),
                   WorkflowExecutionInfo(workflow_execution_id='id_2', status=Status.INIT)]
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution_list = client.list_workflow_executions(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual(2, len(workflow_execution_list))

    def test_start_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.start_job_execution.return_value \
                = JobExecutionInfo(job_name='job_name',
                                   status=Status.RUNNING,
                                   workflow_execution=WorkflowExecutionInfo(workflow_execution_id='id',
                                                                            status=Status.INIT))
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.start_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_stop_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.stop_job_execution.return_value \
                = JobExecutionInfo(job_name='job_name',
                                   status=Status.RUNNING,
                                   workflow_execution=WorkflowExecutionInfo(workflow_execution_id='id',
                                                                            status=Status.INIT))
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.stop_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_restart_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.restart_job_execution.return_value \
                = JobExecutionInfo(job_name='job_name',
                                   status=Status.RUNNING,
                                   workflow_execution=WorkflowExecutionInfo(workflow_execution_id='id',
                                                                            status=Status.INIT))
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.restart_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_get_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.get_job_executions.return_value \
                = [JobExecutionInfo(job_name='job_name',
                                    status=Status.RUNNING,
                                    workflow_execution=WorkflowExecutionInfo(workflow_execution_id='id',
                                                                             status=Status.INIT))]
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.get_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_list_jobs(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance

            instance.list_job_executions.return_value \
                = [JobExecutionInfo(job_name='job_name_1',
                                    status=Status.RUNNING,
                                    workflow_execution=WorkflowExecutionInfo(workflow_execution_id='id',
                                                                             status=Status.INIT)),
                   JobExecutionInfo(job_name='job_name_2',
                                    status=Status.RUNNING,
                                    workflow_execution=WorkflowExecutionInfo(workflow_execution_id='id',
                                                                             status=Status.INIT))]
            client = SchedulerClient("localhost:{}".format(_PORT))
            job_list = client.list_jobs(execution_id='id')
            self.assertEqual(2, len(job_list))
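The scheduler tests patch the implementation referenced by the module-level constant SCHEDULER_CLASS, which is not shown here. A sketch of what it might look like follows; the dotted path is purely an assumption for illustration.

# Hypothetical dotted path to the scheduler class mocked above; the real test
# module points this at the project's actual scheduler implementation.
SCHEDULER_CLASS = 'ai_flow.plugin_interface.scheduler_interface.Scheduler'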
class AIFlowServerRunner(object):
    """
    AI flow server runner. This class is the runner class for the AIFlowServer. It parses the server
    configuration and manages the life cycle of the AIFlowServer.
    """
    def __init__(self,
                 config_file: Text = None,
                 enable_ha=False,
                 server_uri: str = None,
                 ttl_ms=10000) -> None:
        """
        Set the server attributes according to the server configuration file.

        :param config_file: server configuration file.
        """
        super().__init__()
        self.config_file = config_file
        self.server = None
        self.server_config = AIFlowServerConfig()
        self.enable_ha = enable_ha
        self.server_uri = server_uri
        self.ttl_ms = ttl_ms

    def start(self, is_block=False) -> None:
        """
        Start the AI flow runner.

        :param is_block: If True, the call blocks and the AI flow runner keeps running until it is stopped.
        """
        if self.config_file is not None:
            self.server_config.load_from_file(self.config_file)
        else:
            self.server_config.set_server_port(str(_PORT))
        global GLOBAL_MASTER_CONFIG
        GLOBAL_MASTER_CONFIG = self.server_config
        logging.info("AI Flow Master Config {}".format(GLOBAL_MASTER_CONFIG))
        self.server = AIFlowServer(
            store_uri=self.server_config.get_db_uri(),
            port=str(self.server_config.get_server_port()),
            start_default_notification=self.server_config.start_default_notification(),
            notification_uri=self.server_config.get_notification_uri(),
            start_meta_service=self.server_config.start_meta_service(),
            start_model_center_service=self.server_config.start_model_center_service(),
            start_metric_service=self.server_config.start_metric_service(),
            start_scheduler_service=self.server_config.start_scheduler_service(),
            scheduler_service_config=self.server_config.get_scheduler_config(),
            enabled_ha=self.server_config.get_enable_ha(),
            ha_server_uri=self.server_config.get_server_ip() + ":" + str(self.server_config.get_server_port()),
            ttl_ms=self.server_config.get_ha_ttl_ms())
        self.server.run(is_block=is_block)

    def stop(self, clear_sql_lite_db_file=True) -> None:
        """
        Stop the AI flow runner.

        :param clear_sql_lite_db_file: If True, the sqlite database file will be deleted when the server stops.
        """
        self.server.stop(clear_sql_lite_db_file)

    def _clear_db(self):
        self.server._clear_db()
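A minimal usage sketch for AIFlowServerRunner follows; passing config_file=None falls back to the default port set in start(), and the interaction step is only a placeholder.

if __name__ == '__main__':
    runner = AIFlowServerRunner(config_file=None)
    runner.start(is_block=False)
    try:
        pass  # interact with the running server here
    finally:
        runner.stop()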