class AIFlowMaster(object):
    """
    AI flow master.

    Owns a ``MasterConfig`` and the ``AIFlowServer`` built from it.
    """

    def __init__(self, config_file: Text = None) -> None:
        """
        Set the master attribute according to the master config file.

        :param config_file: master configuration file.
        """
        super().__init__()
        self.config_file = config_file
        # The server is only created by start(); it stays None until then.
        self.server = None
        self.master_config = MasterConfig()

    def start(self, is_block=False) -> None:
        """
        Start the AI flow master.

        Loads the configuration from ``config_file`` when one was given,
        otherwise only sets the master port to the module default. The
        resulting configuration is published as ``GLOBAL_MASTER_CONFIG``
        before the server is created and run.

        :param is_block: AI flow master will run non-stop if True.
        """
        if self.config_file is not None:
            self.master_config.load_from_file(self.config_file)
        else:
            self.master_config.set_master_port(str(_PORT))
        global GLOBAL_MASTER_CONFIG
        GLOBAL_MASTER_CONFIG = self.master_config
        logging.info("AI Flow Master Config {}".format(GLOBAL_MASTER_CONFIG))
        self.server = AIFlowServer(
            store_uri=self.master_config.get_db_uri(),
            port=str(self.master_config.get_master_port()),
            start_default_notification=self.master_config.start_default_notification(),
            notification_uri=self.master_config.get_notification_uri())
        self.server.run(is_block=is_block)

    def stop(self, clear_sql_lite_db_file=True) -> None:
        """
        Stop the AI flow master.

        Calling this before :meth:`start` is a no-op: there is no server to
        stop and the configuration has not been loaded yet, so no database
        clean-up is attempted either.

        :param clear_sql_lite_db_file: If True, the sqlite database files will
                                       be deleted when the server stops working.
        """
        if self.server is None:
            # Previously this raised AttributeError on ``None.stop()``.
            logging.warning("AI Flow Master was not started; nothing to stop.")
            return
        self.server.stop()
        if self.master_config.get_db_type() == DBType.SQLITE and clear_sql_lite_db_file:
            store = SqlAlchemyStore(self.master_config.get_db_uri())
            base.metadata.drop_all(store.db_engine)
            os.remove(self.master_config.get_sql_lite_db_file())

    def _clear_db(self):
        """Drop and re-create every table when the backing store is SQLite."""
        if self.master_config.get_db_type() == DBType.SQLITE:
            store = SqlAlchemyStore(self.master_config.get_db_uri())
            base.metadata.drop_all(store.db_engine)
            base.metadata.create_all(store.db_engine)
class TestTensorFlowIrisModel(unittest.TestCase):
    """Register a TensorFlow saved model and watch for its model-version event."""

    def setUp(self) -> None:
        """Start an AIFlowServer on a fresh SQLite file and connect a client to it."""
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI, port=_PORT)
        self.server.run()
        self.client = AIFlowClient(server_uri='localhost:' + _PORT)

    def tearDown(self) -> None:
        """Stop listening, stop the server and delete the SQLite file."""
        self.client.stop_listen_event()
        self.server.stop()
        os.remove(_SQLITE_DB_FILE)

    def test_save_and_load_model(self):
        """
        Create a registered model and a model version, then start listening
        for events keyed by the model name with a watcher that reloads the
        saved model and checks that every input/output tensor named by its
        signature can be found in the graph.
        """
        iris_model = fit_and_save_model()
        tf_graph = tf.Graph()
        registered_model = self.client.create_registered_model(
            model_name='iris_model',
            model_type=ModelType.SAVED_MODEL,
            model_desc='iris model')
        self.client.create_model_version(
            model_name=registered_model.model_name,
            model_path=iris_model.path,
            model_metric='http://metric',
            model_flavor='{"meta_graph_tags":["serve"],"signature_def_map_key":"predict"}',
            version_desc='iris model')

        class IrisWatcher(EventWatcher):
            """Reload the model described by each notification into ``tf_graph``."""

            def process(self, notifications):
                for notification in notifications:
                    payload = json.loads(notification.value)
                    model_path = payload.get('_model_path')
                    # The flavor is itself a JSON document nested in the payload.
                    flavor = json.loads(payload.get('_model_flavor'))
                    # Scope the session so it is closed once the checks are
                    # done; the original created one per notification and
                    # never released it.
                    with tf.Session(graph=tf_graph) as session:
                        signature_def = tensorflow.load_model(
                            model_uri=model_path,
                            meta_graph_tags=flavor.get('meta_graph_tags'),
                            signature_def_map_key=flavor.get('signature_def_map_key'),
                            tf_session=session)
                        for _, input_signature in signature_def.inputs.items():
                            t_input = tf_graph.get_tensor_by_name(input_signature.name)
                            assert t_input is not None
                        for _, output_signature in signature_def.outputs.items():
                            t_output = tf_graph.get_tensor_by_name(output_signature.name)
                            assert t_output is not None

        # NOTE(review): the test returns right after registering the watcher,
        # so the watcher's assertions presumably run on a listener thread (if
        # at all before tearDown) - confirm they can actually fail the test.
        self.client.start_listen_event(key=registered_model.model_name,
                                       watcher=IrisWatcher())
def setUpClass(cls) -> None:
    """
    Start one AIFlowServer for the whole test class on a fresh SQLite file
    and publish a connected ``AIFlowClient`` as the module-level ``client``.
    """
    global client
    # Remove any database file that is already present at that path.
    db_file_present = os.path.exists(_SQLITE_DB_FILE)
    if db_file_present:
        os.remove(_SQLITE_DB_FILE)
    cls.server = AIFlowServer(port=_PORT, store_uri=_SQLITE_DB_URI)
    cls.server.run()
    server_uri = 'localhost:' + _PORT
    client = AIFlowClient(server_uri=server_uri)
def start(self, is_block=False) -> None:
    """
    Start the AI flow master.

    Loads the configuration from ``config_file`` when one was given,
    otherwise only sets the master port to the module default, publishes the
    configuration as ``GLOBAL_MASTER_CONFIG`` and then runs either a plain
    ``AIFlowServer`` or, when HA is enabled in the configuration, a
    ``HighAvailableAIFlowServer``.

    :param is_block: AI flow master will run non-stop if True.
    """
    if self.config_file is not None:
        self.master_config.load_from_file(self.config_file)
    else:
        self.master_config.set_master_port(str(_PORT))
    global GLOBAL_MASTER_CONFIG
    GLOBAL_MASTER_CONFIG = self.master_config
    logging.info("AI Flow Master Config {}".format(GLOBAL_MASTER_CONFIG))

    config = self.master_config
    # Normalise the port once: the configured value may be an int, and the
    # HA branch used to concatenate it to the host without str(), which
    # raised TypeError.
    port = str(config.get_master_port())
    common_kwargs = dict(
        store_uri=config.get_db_uri(),
        port=port,
        start_default_notification=config.start_default_notification(),
        notification_uri=config.get_notification_uri())

    if not config.get_enable_ha():
        self.server = AIFlowServer(**common_kwargs)
    else:
        self.server = HighAvailableAIFlowServer(
            server_uri=config.get_master_ip() + ":" + port,
            ttl_ms=config.get_ha_ttl_ms(),
            **common_kwargs)
    self.server.run(is_block=is_block)
def setUp(self):
    """
    Start an AIFlowServer that runs only the scheduling service, configured
    to use the ``MockScheduler`` class, on a fresh SQLite file.
    """
    scheduler_config = SchedulerConfig()
    scheduler_config.set_scheduler_class_name(
        'ai_flow.test.scheduler.test_scheduling_service.MockScheduler')

    # Remove any database file that is already present at that path.
    if os.path.exists(_SQLITE_DB_FILE):
        os.remove(_SQLITE_DB_FILE)

    # Everything except the scheduling service is switched off.
    server_options = {
        'store_uri': _SQLITE_DB_URI,
        'port': _PORT,
        'start_default_notification': False,
        'start_deploy_service': False,
        'start_meta_service': False,
        'start_metric_service': False,
        'start_model_center_service': False,
        'start_scheduling_service': True,
        'scheduler_config': scheduler_config,
    }
    self.server = AIFlowServer(**server_options)
    self.server.run()
class TestSklearnModel(unittest.TestCase):
    """Register a pickled scikit-learn model and watch for its model-version event."""

    def setUp(self) -> None:
        """Start an AIFlowServer on a fresh SQLite file and connect a client to it."""
        stale_db = os.path.exists(_SQLITE_DB_FILE)
        if stale_db:
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(port=_PORT, store_uri=_SQLITE_DB_URI)
        self.server.run()
        server_uri = 'localhost:' + _PORT
        self.client = AIFlowClient(server_uri=server_uri)

    def tearDown(self) -> None:
        """Stop listening, stop the server and delete the SQLite file."""
        self.client.stop_listen_event()
        self.server.stop()
        os.remove(_SQLITE_DB_FILE)

    def test_save_and_load_model(self):
        """
        Fit a model, pickle it into a temporary directory, register it with a
        model version, then start listening for events keyed by the model
        name with a watcher that reloads the model and compares predictions.
        """
        fitted = fit_model()

        model_dir = tempfile.mkdtemp()
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        pickle_path = os.path.join(model_dir, 'model.pkl')
        save_model(sk_model=fitted.model,
                   output_path=pickle_path,
                   serialization_format=SERIALIZATION_FORMAT_PICKLE)

        registered = self.client.create_registered_model(
            model_name='knn_model',
            model_type=ModelType.SAVED_MODEL,
            model_desc='knn model')
        self.client.create_model_version(
            model_name=registered.model_name,
            model_path=pickle_path,
            model_metric='http://metric',
            model_flavor='sklearn',
            version_desc='knn model')

        class KnnWatcher(EventWatcher):
            """Reload the model named by each notification and compare predictions."""

            def process(self, notifications):
                for event in notifications:
                    event_payload = json.loads(event.value)
                    load_path = event_payload.get('_model_path')
                    reloaded = sklearn.load_model(model_uri=load_path)
                    expected = fitted.model.predict(fitted.inference_data)
                    actual = reloaded.predict(fitted.inference_data)
                    numpy.testing.assert_array_equal(expected, actual)
                    # The model file is deleted once it has been checked.
                    os.remove(load_path)

        watcher = KnnWatcher()
        self.client.start_listen_event(key=registered.model_name, watcher=watcher)
def setUpClass(cls) -> None:
    """
    Start one AIFlowServer for the whole test class on a fresh SQLite file,
    keep handles to its scheduler manager and listener manager on the class,
    and publish a connected ``AIFlowClient`` as the module-level ``client``.
    """
    global client
    # Remove any database file that is already present at that path.
    db_file_present = os.path.exists(_SQLITE_DB_FILE)
    if db_file_present:
        os.remove(_SQLITE_DB_FILE)

    cls.server = AIFlowServer(port=_PORT, store_uri=_SQLITE_DB_URI)
    # The scheduler manager is taken from the deploy service before the
    # server is run.
    cls.sc_manager = cls.server.deploy_service.scheduler_manager
    cls.server.run()

    server_uri = 'localhost:' + _PORT
    client = AIFlowClient(server_uri=server_uri)
    cls.ls_manager = cls.sc_manager.listener_manager
def setUpClass(cls) -> None:
    """
    Re-create the ``test_aiflow_client`` MySQL database, start an
    AIFlowServer backed by it and attach three clients to the
    ``test_client`` module as ``client``, ``client1`` and ``client2``.
    """
    print("TestAIFlowClientMySQL setUpClass")
    db_server_url = get_mysql_server_url()
    cls.db_name = 'test_aiflow_client'
    cls.engine = sqlalchemy.create_engine(db_server_url)

    # Drop first, then create, so every run starts from an empty database.
    ddl_templates = (
        'DROP DATABASE IF EXISTS %s',
        'CREATE DATABASE IF NOT EXISTS %s',
    )
    for ddl in ddl_templates:
        cls.engine.execute(ddl % cls.db_name)

    cls.store_uri = '%s/%s' % (db_server_url, cls.db_name)
    cls.server = AIFlowServer(port=_PORT, store_uri=cls.store_uri)
    cls.server.run()

    # Three independent clients, all pointing at the same server.
    server_uri = 'localhost:' + _PORT
    for attr_name in ('client', 'client1', 'client2'):
        setattr(test_client, attr_name, AIFlowClient(server_uri=server_uri))
def setUp(self) -> None:
    """Start an AIFlowServer on a fresh SQLite file and connect a client to it."""
    # Remove any database file that is already present at that path.
    db_file_present = os.path.exists(_SQLITE_DB_FILE)
    if db_file_present:
        os.remove(_SQLITE_DB_FILE)

    self.server = AIFlowServer(port=_PORT, store_uri=_SQLITE_DB_URI)
    self.server.run()

    server_uri = 'localhost:' + _PORT
    self.client = AIFlowClient(server_uri=server_uri)
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import sys
import tempfile

# Make the project root importable when this script is run from its own
# directory; this has to happen before the ai_flow import below.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../../../..")))

from ai_flow.rest_endpoint.service.server import AIFlowServer

if __name__ == '__main__':
    # Run a blocking AIFlowServer backed by a throw-away SQLite file.
    # mkstemp() returns an open descriptor; only the path is needed here.
    fd, temp_dbfile = tempfile.mkstemp()
    os.close(fd)
    db_uri = '%s%s' % ('sqlite:///', temp_dbfile)
    try:
        server = AIFlowServer(store_uri=db_uri)
        server.run(is_block=True)
    finally:
        # mkstemp() never deletes the file itself, so remove it once the
        # server returns or is interrupted instead of leaking it.
        if os.path.exists(temp_dbfile):
            os.remove(temp_dbfile)
class TestSchedulingService(unittest.TestCase):
    """
    Exercise ``SchedulingClient`` against an AIFlowServer that runs only the
    scheduling service, with the scheduler replaced by a mock in every test.

    Assertions on returned names and list sizes use ``assertEqual``; the
    earlier ``assertTrue(expected, actual)`` form treated ``actual`` as the
    failure message and therefore could never fail.
    """

    _SCHEDULER_CLASS = 'ai_flow.test.scheduler.test_scheduling_service.MockScheduler'

    def setUp(self):
        """Start a scheduling-only server on a fresh SQLite file."""
        config = SchedulerConfig()
        config.set_scheduler_class_name(self._SCHEDULER_CLASS)
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)
        self.server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                   port=_PORT,
                                   start_default_notification=False,
                                   start_deploy_service=False,
                                   start_meta_service=False,
                                   start_metric_service=False,
                                   start_model_center_service=False,
                                   start_scheduling_service=True,
                                   scheduler_config=config)
        self.server.run()

    def tearDown(self):
        """Stop the server and delete the SQLite file if it exists."""
        self.server.stop()
        if os.path.exists(_SQLITE_DB_FILE):
            os.remove(_SQLITE_DB_FILE)

    def _install_scheduler(self, scheduler_cls):
        """Plug the patched class's instance into the running scheduling service and return it."""
        scheduler = scheduler_cls.return_value
        self.server.scheduling_service._scheduler = scheduler
        return scheduler

    @staticmethod
    def _new_client():
        """Return a SchedulingClient connected to the test server."""
        return SchedulingClient("localhost:{}".format(_PORT))

    @staticmethod
    def _running_job(job_name='job_name'):
        """Build a RUNNING JobInfo whose workflow execution 'id' is in INIT state."""
        return JobInfo(job_name=job_name,
                       state=State.RUNNING,
                       workflow_execution=WorkflowExecutionInfo(execution_id='id',
                                                                state=State.INIT))

    def _assert_running_job(self, job):
        """Check that ``job`` mirrors what ``_running_job()`` describes."""
        self.assertEqual('job_name', job.name)
        self.assertEqual(StateProto.RUNNING, job.job_state)
        self.assertEqual('id', job.workflow_execution.execution_id)
        self.assertEqual(StateProto.INIT, job.workflow_execution.execution_state)

    # def test_submit_workflow(self):
    #     with mock.patch('ai_flow.test.scheduler.test_scheduling_service.MockScheduler') as mockScheduler:
    #         instance = mockScheduler.return_value
    #         instance.submit_workflow.return_value = WorkflowInfo(workflow_name='test_workflow')
    #         client = SchedulingClient("localhost:{}".format(_PORT))
    #         workflow = client.submit_workflow_to_scheduler(namespace='namespace', workflow_name='test_workflow')
    #         print(workflow)

    def test_delete_none_workflow(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.delete_workflow.return_value = None
            client = self._new_client()
            with self.assertRaises(Exception):
                client.delete_workflow(namespace='namespace',
                                       workflow_name='test_workflow')

    def test_delete_workflow(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.delete_workflow.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = self._new_client()
            workflow = client.delete_workflow(namespace='namespace',
                                              workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_pause_workflow(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.pause_workflow_scheduling.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = self._new_client()
            workflow = client.pause_workflow_scheduling(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_resume_workflow(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.resume_workflow_scheduling.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = self._new_client()
            workflow = client.resume_workflow_scheduling(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_get_workflow(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.get_workflow.return_value = WorkflowInfo(
                workflow_name='test_workflow')
            client = self._new_client()
            workflow = client.get_workflow(namespace='namespace',
                                           workflow_name='test_workflow')
            self.assertEqual('test_workflow', workflow.name)

    def test_list_workflows(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.list_workflows.return_value = [
                WorkflowInfo(workflow_name='test_workflow_1'),
                WorkflowInfo(workflow_name='test_workflow_2')
            ]
            client = self._new_client()
            workflow_list = client.list_workflows(namespace='namespace')
            self.assertEqual(2, len(workflow_list))

    def test_start_new_workflow_execution(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.start_new_workflow_execution.return_value \
                = WorkflowExecutionInfo(execution_id='id', state=State.INIT)
            client = self._new_client()
            workflow_execution = client.start_new_workflow_execution(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT, workflow_execution.execution_state)

    def test_kill_all_workflow_execution(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.kill_all_workflow_execution.return_value \
                = [WorkflowExecutionInfo(execution_id='id_1', state=State.INIT),
                   WorkflowExecutionInfo(execution_id='id_2', state=State.INIT)]
            client = self._new_client()
            workflow_execution_list = client.kill_all_workflow_executions(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual(2, len(workflow_execution_list))

    def test_kill_workflow_execution(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.kill_workflow_execution.return_value \
                = WorkflowExecutionInfo(execution_id='id', state=State.RUNNING)
            client = self._new_client()
            workflow_execution = client.kill_workflow_execution(
                execution_id='id')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.RUNNING,
                             workflow_execution.execution_state)

    def test_get_workflow_execution(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.get_workflow_execution.return_value \
                = WorkflowExecutionInfo(execution_id='id', state=State.INIT)
            client = self._new_client()
            workflow_execution = client.get_workflow_execution(
                execution_id='id')
            self.assertEqual('id', workflow_execution.execution_id)
            self.assertEqual(StateProto.INIT, workflow_execution.execution_state)

    def test_list_workflow_executions(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.list_workflow_executions.return_value \
                = [WorkflowExecutionInfo(execution_id='id_1', state=State.INIT),
                   WorkflowExecutionInfo(execution_id='id_2', state=State.INIT)]
            client = self._new_client()
            workflow_execution_list = client.list_workflow_executions(
                namespace='namespace', workflow_name='test_workflow')
            self.assertEqual(2, len(workflow_execution_list))

    def test_start_job(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.start_job.return_value = self._running_job()
            client = self._new_client()
            job = client.start_job(job_name='job_name', execution_id='id')
            self._assert_running_job(job)

    def test_stop_job(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.stop_job.return_value = self._running_job()
            client = self._new_client()
            job = client.stop_job(job_name='job_name', execution_id='id')
            self._assert_running_job(job)

    def test_restart_job(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.restart_job.return_value = self._running_job()
            client = self._new_client()
            job = client.restart_job(job_name='job_name', execution_id='id')
            self._assert_running_job(job)

    def test_get_job(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.get_job.return_value = self._running_job()
            client = self._new_client()
            job = client.get_job(job_name='job_name', execution_id='id')
            self._assert_running_job(job)

    def test_list_jobs(self):
        with mock.patch(self._SCHEDULER_CLASS) as scheduler_cls:
            scheduler = self._install_scheduler(scheduler_cls)
            scheduler.list_jobs.return_value = [
                self._running_job('job_name_1'),
                self._running_job('job_name_2'),
            ]
            client = self._new_client()
            job_list = client.list_jobs(execution_id='id')
            self.assertEqual(2, len(job_list))