class TestAIFlowContext(unittest.TestCase):

    def setUp(self):
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI, port=_PORT,
                                   start_default_notification=False,
                                   start_meta_service=True,
                                   start_metric_service=False,
                                   start_model_center_service=False,
                                   start_scheduler_service=False)
        self.server.run()

    def tearDown(self):
        self.server.stop()
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)

    def test_init_ai_flow_context(self):
        init_ai_flow_context()
        project_config = current_project_config()
        self.assertEqual('test_project', project_config.get_project_name())
        self.assertEqual('a', project_config.get('a'))
        project_context = current_project_context()
        self.assertEqual('test_project', project_context.project_name)
        workflow_config_ = current_workflow_config()
        self.assertEqual('test_ai_flow_context', workflow_config_.workflow_name)
        self.assertEqual(5, len(workflow_config_.job_configs))
class TestTensorFlowIrisModel(unittest.TestCase):

    def setUp(self) -> None:
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI, port=_PORT,
                                   start_scheduler_service=False)
        self.server.run()
        self.client = AIFlowClient(server_uri='localhost:' + _PORT)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.server.stop()
        os.remove(_SQLITE_DB_FILE)

    def test_save_and_load_model(self):
        iris_model = fit_and_save_model()
        tf_graph = tf.Graph()
        registered_model = self.client.create_registered_model(
            model_name='iris_model', model_desc='iris model')
        self.client.create_model_version(
            model_name=registered_model.model_name,
            model_path=iris_model.path,
            model_type='{"meta_graph_tags":["serve"],"signature_def_map_key":"predict"}',
            version_desc='iris model')

        class IrisWatcher(EventWatcher):

            def process(self, notifications):
                for notification in notifications:
                    # Parse the notification payload once and look up the
                    # model path and flavor metadata.
                    model_meta = json.loads(notification.value)
                    model_path = model_meta.get('_model_path')
                    model_flavor = json.loads(model_meta.get('_model_type'))
                    print(model_meta.keys())
                    print(model_path)
                    signature_def = load_tensorflow_saved_model(
                        model_uri=model_path,
                        meta_graph_tags=model_flavor.get('meta_graph_tags'),
                        signature_def_map_key=model_flavor.get('signature_def_map_key'),
                        tf_session=tf.Session(graph=tf_graph))
                    # Every input/output tensor named in the signature must be
                    # resolvable in the restored graph.
                    for _, input_signature in signature_def.inputs.items():
                        t_input = tf_graph.get_tensor_by_name(input_signature.name)
                        assert t_input is not None
                    for _, output_signature in signature_def.outputs.items():
                        t_output = tf_graph.get_tensor_by_name(output_signature.name)
                        assert t_output is not None

        self.client.start_listen_event(key=registered_model.model_name,
                                       watcher=IrisWatcher())
class TestSklearnModel(unittest.TestCase):

    def setUp(self) -> None:
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI, port=_PORT,
                                   start_scheduler_service=False)
        self.server.run()
        self.client = AIFlowClient(server_uri='localhost:' + _PORT)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.server.stop()
        os.remove(_SQLITE_DB_FILE)

    def test_save_and_load_model(self):
        knn_model = fit_model()
        model_path = tempfile.mkdtemp()
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        model_path = os.path.join(model_path, 'model.pkl')
        save_model(sk_model=knn_model.model,
                   output_path=model_path,
                   serialization_format=SERIALIZATION_FORMAT_PICKLE)
        registered_model = self.client.create_registered_model(
            model_name='knn_model', model_desc='knn model')
        self.client.create_model_version(
            model_name=registered_model.model_name,
            model_path=model_path,
            model_type='sklearn',
            version_desc='knn model')

        class KnnWatcher(EventWatcher):

            def process(self, notifications):
                for notification in notifications:
                    load_path = json.loads(notification.value).get('_model_path')
                    reloaded_knn_model = load_scikit_learn_model(model_uri=load_path)
                    # The reloaded model must produce the same predictions as
                    # the in-memory model it was saved from.
                    numpy.testing.assert_array_equal(
                        knn_model.model.predict(knn_model.inference_data),
                        reloaded_knn_model.predict(knn_model.inference_data))
                    os.remove(load_path)

        self.client.start_listen_event(key=registered_model.model_name,
                                       watcher=KnnWatcher())
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import sys
import tempfile

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../../../..")))

from ai_flow.endpoint.server.server import AIFlowServer

if __name__ == '__main__':
    # Back the server's metadata store with a temporary sqlite database file,
    # then run the server in blocking mode.
    fd, temp_dbfile = tempfile.mkstemp()
    os.close(fd)
    db_uri = '%s%s' % ('sqlite:///', temp_dbfile)
    server = AIFlowServer(store_uri=db_uri)
    server.run(is_block=True)
class TestHighAvailableAIFlowServer(unittest.TestCase):

    @staticmethod
    def start_aiflow_server(host, port):
        port = str(port)
        server_uri = host + ":" + port
        server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                              port=port,
                              enabled_ha=True,
                              start_scheduler_service=False,
                              ha_server_uri=server_uri,
                              notification_uri='localhost:30031',
                              start_default_notification=False)
        server.run()
        return server

    def wait_for_new_members_detected(self, new_member_uri):
        while True:
            living_members = self.client.living_aiflow_members
            if new_member_uri in living_members:
                break
            else:
                time.sleep(1)

    def setUp(self) -> None:
        SqlAlchemyStore(_SQLITE_DB_URI)
        self.notification = NotificationMaster(
            service=NotificationService(storage=MemoryEventStorage()),
            port=30031)
        self.notification.run()
        self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                    port=50051,
                                    enabled_ha=True,
                                    start_scheduler_service=False,
                                    ha_server_uri='localhost:50051',
                                    notification_uri='localhost:30031',
                                    start_default_notification=False)
        self.server1.run()
        self.server2 = None
        self.server3 = None
        self.config = ProjectConfig()
        self.config.set_enable_ha(True)
        self.config.set_notification_service_uri('localhost:30031')
        self.client = AIFlowClient(server_uri='localhost:50052,localhost:50051',
                                   project_config=self.config)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.server1 is not None:
            self.server1.stop()
        if self.server2 is not None:
            self.server2.stop()
        if self.server3 is not None:
            self.server3.stop()
        if self.notification is not None:
            self.notification.stop()
        store = SqlAlchemyStore(_SQLITE_DB_URI)
        base.metadata.drop_all(store.db_engine)

    def test_server_change(self) -> None:
        self.client.register_project("test_project")
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50051")
        self.assertEqual(projects[0].name, "test_project")

        # Start a second server; the client should fail over to it once the
        # first server is stopped.
        self.server2 = self.start_aiflow_server("localhost", 50052)
        self.wait_for_new_members_detected("localhost:50052")
        self.server1.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50052")
        self.assertEqual(projects[0].name, "test_project")

        # Fail over a second time, from server2 to server3.
        self.server3 = self.start_aiflow_server("localhost", 50053)
        self.wait_for_new_members_detected("localhost:50053")
        self.server2.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50053")
        self.assertEqual(projects[0].name, "test_project")
class TestSchedulerService(unittest.TestCase):

    def setUp(self):
        config = SchedulerServiceConfig()
        config.set_scheduler_class_name(SCHEDULER_CLASS)
        config.set_scheduler_config({})
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI, port=_PORT,
                                   start_default_notification=False,
                                   start_meta_service=False,
                                   start_metric_service=False,
                                   start_model_center_service=False,
                                   start_scheduler_service=True,
                                   scheduler_service_config=config)
        self.server.run()

    def tearDown(self):
        self.server.stop()
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)

    def test_submit_workflow(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            instance.submit_workflow.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = SchedulerClient("localhost:{}".format(_PORT))
            with self.assertRaises(Exception) as context:
                client.submit_workflow_to_scheduler(
                    namespace='namespace',
                    workflow_name='test_workflow',
                    workflow_json='')
            self.assertTrue('workflow json is empty' in str(context.exception))

    def test_pause_workflow(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.pause_workflow_scheduling.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow = client.pause_workflow_scheduling(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_resume_workflow(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.resume_workflow_scheduling.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow = client.resume_workflow_scheduling(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_start_new_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.start_new_workflow_execution.return_value = \
                WorkflowExecutionInfo(workflow_execution_id='id',
                                      status=Status.INIT)
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution = client.start_new_workflow_execution(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             workflow_execution.execution_state)

    def test_kill_all_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.stop_all_workflow_execution.return_value = [
                WorkflowExecutionInfo(workflow_execution_id='id_1',
                                      status=Status.INIT),
                WorkflowExecutionInfo(workflow_execution_id='id_2',
                                      status=Status.INIT)]
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution_list = client.kill_all_workflow_executions(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual(2, len(workflow_execution_list))

    def test_kill_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.stop_workflow_execution.return_value = \
                WorkflowExecutionInfo(workflow_execution_id='id',
                                      status=Status.RUNNING)
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution = client.kill_workflow_execution(
                execution_id='id')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.RUNNING,
                             workflow_execution.execution_state)

    def test_get_workflow_execution(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.get_workflow_execution.return_value = \
                WorkflowExecutionInfo(workflow_execution_id='id',
                                      status=Status.INIT)
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution = client.get_workflow_execution(
                execution_id='id')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             workflow_execution.execution_state)

    def test_list_workflow_executions(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.list_workflow_executions.return_value = [
                WorkflowExecutionInfo(workflow_execution_id='id_1',
                                      status=Status.INIT),
                WorkflowExecutionInfo(workflow_execution_id='id_2',
                                      status=Status.INIT)]
            client = SchedulerClient("localhost:{}".format(_PORT))
            workflow_execution_list = client.list_workflow_executions(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual(2, len(workflow_execution_list))

    def test_start_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.start_job_execution.return_value = JobExecutionInfo(
                job_name='job_name',
                status=Status.RUNNING,
                workflow_execution=WorkflowExecutionInfo(
                    workflow_execution_id='id', status=Status.INIT))
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.start_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_stop_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.stop_job_execution.return_value = JobExecutionInfo(
                job_name='job_name',
                status=Status.RUNNING,
                workflow_execution=WorkflowExecutionInfo(
                    workflow_execution_id='id', status=Status.INIT))
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.stop_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_restart_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.restart_job_execution.return_value = JobExecutionInfo(
                job_name='job_name',
                status=Status.RUNNING,
                workflow_execution=WorkflowExecutionInfo(
                    workflow_execution_id='id', status=Status.INIT))
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.restart_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)
    def test_get_job(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.get_job_executions.return_value = [
                JobExecutionInfo(job_name='job_name',
                                 status=Status.RUNNING,
                                 workflow_execution=WorkflowExecutionInfo(
                                     workflow_execution_id='id',
                                     status=Status.INIT))]
            client = SchedulerClient("localhost:{}".format(_PORT))
            job = client.get_job(job_name='job_name', execution_id='id')
            self.assertEqual('job_name', job.name)
            self.assertEqual(StateProto.RUNNING, job.job_state)
            self.assertEqual('id', job.workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT,
                             job.workflow_execution.execution_state)

    def test_list_jobs(self):
        with mock.patch(SCHEDULER_CLASS) as mockScheduler:
            instance = mockScheduler.return_value
            self.server.scheduler_service._scheduler = instance
            instance.list_job_executions.return_value = [
                JobExecutionInfo(job_name='job_name_1',
                                 status=Status.RUNNING,
                                 workflow_execution=WorkflowExecutionInfo(
                                     workflow_execution_id='id',
                                     status=Status.INIT)),
                JobExecutionInfo(job_name='job_name_2',
                                 status=Status.RUNNING,
                                 workflow_execution=WorkflowExecutionInfo(
                                     workflow_execution_id='id',
                                     status=Status.INIT))]
            client = SchedulerClient("localhost:{}".format(_PORT))
            job_list = client.list_jobs(execution_id='id')
            self.assertEqual(2, len(job_list))
class AIFlowServerRunner(object):
    """
    AI Flow server runner. This is the runner class for the AIFlowServer:
    it parses the server configuration and manages the life cycle of the
    AIFlowServer.
    """

    def __init__(self, config_file: Text = None, enable_ha=False,
                 server_uri: str = None, ttl_ms=10000) -> None:
        """
        Set the server attributes according to the server config file.

        :param config_file: Server configuration file.
        """
        super().__init__()
        self.config_file = config_file
        self.server = None
        self.server_config = AIFlowServerConfig()
        self.enable_ha = enable_ha
        self.server_uri = server_uri
        self.ttl_ms = ttl_ms

    def start(self, is_block=False) -> None:
        """
        Start the AI Flow server runner.

        :param is_block: The runner blocks and runs until stopped if True.
        """
        if self.config_file is not None:
            self.server_config.load_from_file(self.config_file)
        else:
            self.server_config.set_server_port(str(_PORT))
        global GLOBAL_MASTER_CONFIG
        GLOBAL_MASTER_CONFIG = self.server_config
        logging.info("AI Flow Master Config {}".format(GLOBAL_MASTER_CONFIG))
        self.server = AIFlowServer(
            store_uri=self.server_config.get_db_uri(),
            port=str(self.server_config.get_server_port()),
            start_default_notification=self.server_config.start_default_notification(),
            notification_uri=self.server_config.get_notification_uri(),
            start_meta_service=self.server_config.start_meta_service(),
            start_model_center_service=self.server_config.start_model_center_service(),
            start_metric_service=self.server_config.start_metric_service(),
            start_scheduler_service=self.server_config.start_scheduler_service(),
            scheduler_service_config=self.server_config.get_scheduler_config(),
            enabled_ha=self.server_config.get_enable_ha(),
            ha_server_uri=self.server_config.get_server_ip() + ":"
                          + str(self.server_config.get_server_port()),
            ttl_ms=self.server_config.get_ha_ttl_ms())
        self.server.run(is_block=is_block)

    def stop(self, clear_sql_lite_db_file=True) -> None:
        """
        Stop the AI Flow server runner.

        :param clear_sql_lite_db_file: The sqlite database file is deleted
                                       when the server stops if True.
        """
        self.server.stop(clear_sql_lite_db_file)

    def _clear_db(self):
        self.server._clear_db()
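# A minimal usage sketch of AIFlowServerRunner. This is illustrative only and
# not part of the original sources; the config file path 'master.yaml' is a
# hypothetical placeholder for a valid server configuration file.
if __name__ == '__main__':
    runner = AIFlowServerRunner(config_file='master.yaml')
    runner.start(is_block=False)  # start the server without blocking
    try:
        pass  # interact with the running server here
    finally:
        runner.stop()  # stop the server and clear the sqlite db file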