def main():
    """Build and launch a three-job demo whose jobs are wired by job-status rules.

    Rules registered here (downstream job, upstream job, upstream status, action):
      * task_2 / task_1 / RUNNING  -> START
      * task_2 / task_1 / FINISHED -> STOP
      * task_3 / task_1 / RUNNING  -> START
      * task_3 / task_2 / KILLED   -> RESTART
    ``stop_workflow_executions`` is called on the workflow name before the
    workflow is submitted and a new execution is started.
    """
    af.init_ai_flow_context()

    jobs = (
        ('task_1', "sleep 30"),
        ('task_2', "sleep 60"),
        ('task_3', "echo hello"),
    )
    for job_name, command in jobs:
        with af.job_config(job_name):
            af.user_define_operation(BashProcessor(command))

    rules = (
        ('task_2', 'task_1', Status.RUNNING, JobAction.START),
        ('task_2', 'task_1', Status.FINISHED, JobAction.STOP),
        ('task_3', 'task_1', Status.RUNNING, JobAction.START),
        ('task_3', 'task_2', Status.KILLED, JobAction.RESTART),
    )
    for downstream, upstream, trigger_status, job_action in rules:
        af.action_on_job_status(downstream, upstream,
                                upstream_job_status=trigger_status,
                                action=job_action)

    name = af.current_workflow_config().workflow_name
    stop_workflow_executions(name)
    af.workflow_operation.submit_workflow(name)
    af.workflow_operation.start_new_workflow_execution(name)
def test_three_task(self):
    """task_3 subscribes to task_1's custom event and gets an events handler.

    A job-status rule from task_2 to task_3 is also registered. Subscriptions to
    JOB_STATUS_CHANGED events are not supported yet, so that rule's
    subscription is deliberately not asserted.
    """
    for job_name in ('task_1', 'task_2', 'task_3'):
        with af.job_config(job_name):
            af.user_define_operation(processor=None)

    af.action_on_event(job_name='task_3', event_key='a', event_type='a',
                       event_value='a', sender='task_1')
    af.action_on_job_status(job_name='task_3', upstream_job_name='task_2',
                            upstream_job_status=Status.FINISHED,
                            action=JobAction.START)

    workflow = af.workflow_operation.submit_workflow(
        workflow_name='test_dag_generator')
    generated = workflow.properties.get('code')

    self.assertIn(".subscribe_event('a', 'a', 'default', 'task_1')", generated)
    self.assertIn(".set_events_handler(AIFlowHandler(configs_op_", generated)
def run_workflow(client: NotificationClient):
    """Submit a workflow where task_5 starts once both k_1 and k_2 events arrive.

    Defines task_3, task_4 and task_5. task_5 is registered to START when a
    MeetAllEventCondition on (k_1=v_1) and (k_2=v_2) from any namespace/sender
    is met. The function then starts an execution and blocks until this
    workflow's task_5 TaskInstance reaches State.SUCCESS, polling once per
    second.

    Args:
        client: Not referenced by this function; kept for interface compatibility.
    """
    with af.job_config('task_3'):
        af.user_define_operation(processor=PyProcessor3())
    with af.job_config('task_4'):
        af.user_define_operation(processor=PyProcessor4())
    with af.job_config('task_5'):
        af.user_define_operation(processor=PyProcessor5())

    event_condition = af.MeetAllEventCondition()
    event_condition.add_event(event_key='k_1', event_value='v_1', namespace='*', sender='*')
    event_condition.add_event(event_key='k_2', event_value='v_2', namespace='*', sender='*')
    af.action_on_events(job_name='task_5', event_condition=event_condition,
                        action=af.JobAction.START)

    workflow_name = af.current_workflow_config().workflow_name
    af.workflow_operation.submit_workflow(workflow_name=workflow_name)
    af.workflow_operation.start_new_workflow_execution(workflow_name=workflow_name)

    dag_id = 'test_project.{}'.format(workflow_name)
    while True:
        with create_session() as session:
            # Scope the lookup to this workflow's DAG. The old query matched any
            # 'task_5' in the database and crashed on `.state` when none existed.
            ti = session.query(TaskInstance) \
                .filter(TaskInstance.dag_id == dag_id,
                        TaskInstance.task_id == 'task_5') \
                .first()
            succeeded = ti is not None and ti.state == State.SUCCESS
        if succeeded:
            break
        # Sleep on every unsuccessful poll, including before the DAG run exists,
        # instead of busy-spinning.
        time.sleep(1)
def main():
    """Build and launch a two-job demo where task_2 follows task_1.

    The dependency uses ``af.action_on_job_status`` with its default status and
    action arguments. ``stop_workflow_executions`` is called on the workflow
    name before the workflow is submitted and a new execution is started.
    """
    af.init_ai_flow_context()

    for job_name in ('task_1', 'task_2'):
        with af.job_config(job_name):
            af.user_define_operation(BashProcessor("echo hello"))

    af.action_on_job_status('task_2', 'task_1')

    name = af.current_workflow_config().workflow_name
    stop_workflow_executions(name)
    af.workflow_operation.submit_workflow(name)
    af.workflow_operation.start_new_workflow_execution(name)
def test_cluster_flink_java_task(self):
    """Run Flink's bundled WordCount jar as job task_2 and expect FINISHED.

    Needs the FLINK_HOME environment variable. The example jar is copied into
    ``<project_path>/dependencies/jar``. The whole ``dependencies`` directory is
    removed afterwards, even when the job or the assertion fails, so a failed
    run does not leave it behind for later tests.
    """
    flink_home = os.environ.get('FLINK_HOME')
    if not flink_home:
        # Without this check, os.path.join(None, ...) raises an unhelpful TypeError.
        self.fail('FLINK_HOME must point at a Flink installation')

    word_count_jar = os.path.join(flink_home, 'examples', 'batch', 'WordCount.jar')
    output_file = os.path.join(flink_home, 'log', 'output')
    if os.path.exists(output_file):
        os.remove(output_file)

    dep_dir = os.path.join(project_path, 'dependencies')
    jar_dir = os.path.join(dep_dir, 'jar')
    try:
        os.makedirs(jar_dir, exist_ok=True)
        shutil.copy(word_count_jar, jar_dir)

        args = [
            '--input', os.path.join(flink_home, 'conf', 'flink-conf.yaml'),
            '--output', output_file,
        ]
        with af.job_config('task_2'):
            af.user_define_operation(processor=flink.FlinkJavaProcessor(
                entry_class=None, main_jar_file='WordCount.jar', args=args))

        af.workflow_operation.submit_workflow(
            workflow_name=af.current_workflow_config().workflow_name)
        af.workflow_operation.start_job_execution(job_name='task_2', execution_id='1')
        je = af.workflow_operation.get_job_execution(job_name='task_2', execution_id='1')
        self.assertEqual(Status.FINISHED, je.status)
    finally:
        if os.path.exists(dep_dir):
            shutil.rmtree(dep_dir)
def test_one_task(self):
    """A single processor-less job is rendered as `op_0 = AIFlowOperator`."""
    with af.job_config('task_1'):
        af.user_define_operation(processor=None)

    workflow = af.workflow_operation.submit_workflow(
        workflow_name='test_dag_generator')
    self.assertIn('op_0 = AIFlowOperator', workflow.properties.get('code'))
def test_two_task(self):
    """task_2 listening to task_1's event produces op_1 subscription and handler code."""
    for job_name in ('task_1', 'task_2'):
        with af.job_config(job_name):
            af.user_define_operation(processor=None)

    af.action_on_event(job_name='task_2', event_key='a', event_type='a',
                       event_value='a', sender='task_1')

    generated = af.workflow_operation.submit_workflow(
        workflow_name='test_dag_generator').properties.get('code')

    expected_snippets = (
        "op_1.subscribe_event('a', 'a', 'default', 'task_1')",
        "op_1.set_events_handler(AIFlowHandler(configs_op_1))",
    )
    for snippet in expected_snippets:
        self.assertIn(snippet, generated)
def test_action_on_job_status_two_status(self):
    """Two status rules between the same pair of jobs both appear in the generated code."""
    for job_name in ('task_1', 'task_2'):
        with af.job_config(job_name):
            af.user_define_operation(processor=None)

    for trigger_status, job_action in ((Status.RUNNING, JobAction.START),
                                       (Status.FINISHED, JobAction.STOP)):
        af.action_on_job_status(job_name='task_2', upstream_job_name='task_1',
                                upstream_job_status=trigger_status,
                                action=job_action)

    generated = af.workflow_operation.submit_workflow(
        workflow_name='test_dag_generator').properties.get('code')

    for snippet in ('"event_value": "RUNNING"', '"event_value": "FINISHED"'):
        self.assertIn(snippet, generated)
def test_bash_task(self):
    """A bash job's status is FINISHED when read right after start_job_execution returns."""
    with af.job_config('task_1'):
        af.user_define_operation(processor=bash.BashProcessor(
            bash_command='echo "Xiao ming hello world!"'))

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name='test_bash')
    ops.start_job_execution(job_name='task_1', execution_id='1')

    job_execution = ops.get_job_execution(job_name='task_1', execution_id='1')
    self.assertEqual(Status.FINISHED, job_execution.status)
def test_python_task(self):
    """A Python job's status is FINISHED when read right after start_job_execution returns."""
    with af.job_config('task_1'):
        af.user_define_operation(processor=PyProcessor1())

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name=af.current_workflow_config().workflow_name)
    ops.start_job_execution(job_name='task_1', execution_id='1')

    job_execution = ops.get_job_execution(job_name='task_1', execution_id='1')
    self.assertEqual(Status.FINISHED, job_execution.status)
def test_action_on_job_status(self):
    """Job-status rules make each downstream op subscribe to TASK_STATUS_CHANGED.

    task_2 depends on task_1 with default arguments. task_3 starts when task_2
    is RUNNING.
    """
    for job_name in ('task_1', 'task_2', 'task_3'):
        with af.job_config(job_name):
            af.user_define_operation(processor=None)

    af.action_on_job_status(job_name='task_2', upstream_job_name='task_1')
    af.action_on_job_status(job_name='task_3', upstream_job_name='task_2',
                            upstream_job_status=Status.RUNNING,
                            action=JobAction.START)

    generated = af.workflow_operation.submit_workflow(
        workflow_name='test_dag_generator').properties.get('code')

    self.assertIn(
        "op_1.subscribe_event('test_dag_generator.task_1', 'TASK_STATUS_CHANGED', 'test_project', 'task_1')",
        generated)
    self.assertIn(
        "op_2.subscribe_event('test_dag_generator.task_2', 'TASK_STATUS_CHANGED', 'test_project', 'task_2')",
        generated)
def run_workflow(client: NotificationClient):
    """Run two bash jobs, task_2 started when task_1 FINISHES, and wait for task_2.

    Polls this workflow's task_2 TaskInstance once per second until its state
    is State.SUCCESS.

    Args:
        client: Not referenced by this function.
    """
    commands = {
        'task_1': 'echo "Xiao ming hello world!"',
        'task_2': 'echo "Xiao li hello world!"',
    }
    for job_name, command in commands.items():
        with af.job_config(job_name):
            af.user_define_operation(processor=bash.BashProcessor(bash_command=command))

    af.action_on_job_status('task_2', 'task_1', Status.FINISHED, JobAction.START)

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name=af.current_workflow_config().workflow_name)
    ops.start_new_workflow_execution(workflow_name=af.current_workflow_config().workflow_name)

    dag_id = 'test_project.{}'.format(af.current_workflow_config().workflow_name)
    while True:
        with create_session() as session:
            task_instance = session.query(TaskInstance) \
                .filter(TaskInstance.dag_id == dag_id,
                        TaskInstance.task_id == 'task_2') \
                .first()
            if task_instance is not None and task_instance.state == State.SUCCESS:
                break
            time.sleep(1)
def test_cluster_flink_task(self):
    """A Source -> Transformer -> Sink job (task_2) ends in the FINISHED state."""
    with af.job_config('task_2'):
        source = af.user_define_operation(processor=Source())
        transformed = af.transform(input=[source], transform_processor=Transformer())
        af.user_define_operation(input=[transformed], processor=Sink())

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name=af.current_workflow_config().workflow_name)
    ops.start_job_execution(job_name='task_2', execution_id='1')

    job_execution = ops.get_job_execution(job_name='task_2', execution_id='1')
    self.assertEqual(Status.FINISHED, job_execution.status)
def test_stop_python_task(self):
    """Stopping a Python job leaves it FAILED with an 'err' entry in its properties."""
    time.sleep(1)  # NOTE(review): reason for this initial pause is not visible here
    with af.job_config('task_1'):
        af.user_define_operation(processor=PyProcessor2())

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name='test_python')
    ops.start_job_execution(job_name='task_1', execution_id='1')
    time.sleep(2)  # presumably lets the job get going before it is stopped
    ops.stop_job_execution(job_name='task_1', execution_id='1')

    job_execution = ops.get_job_execution(job_name='task_1', execution_id='1')
    self.assertEqual(Status.FAILED, job_execution.status)
    self.assertIn('err', job_execution.properties)
def test_stop_bash_task(self):
    """Stopping a `sleep 10` bash job leaves it FAILED with an 'err' property."""
    time.sleep(1)  # NOTE(review): reason for this initial pause is not visible here
    with af.job_config('task_1'):
        af.user_define_operation(processor=bash.BashProcessor(bash_command='sleep 10'))

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name='test_bash')
    ops.start_job_execution(job_name='task_1', execution_id='1')
    ops.stop_job_execution(job_name='task_1', execution_id='1')

    job_execution = ops.get_job_execution(job_name='task_1', execution_id='1')
    self.assertEqual(Status.FAILED, job_execution.status)
    self.assertIn('err', job_execution.properties)
def test_periodic_interval_workflow(self):
    """An interval-based PeriodicConfig yields a scheduled DAG (schedule_interval/timedelta)."""
    config = af.current_workflow_config()
    config.periodic_config = PeriodicConfig(trigger_config={
        'start_date': "2020,1,1,,,,Asia/Chongqing",
        'interval': "1,1,1,",
    })

    with af.job_config('task_1'):
        af.user_define_operation(processor=None)

    generated = af.workflow_operation.submit_workflow(
        workflow_name='test_dag_generator').properties.get('code')

    for snippet in ('op_0 = AIFlowOperator', 'datetime', 'schedule_interval', 'timedelta'):
        self.assertIn(snippet, generated)
def test_stop_local_flink_task(self):
    """Stopping a Source -> Transformer2 -> Sink job leaves it FAILED with an 'err' property."""
    with af.job_config('task_1'):
        source = af.user_define_operation(processor=Source())
        transformed = af.transform(input=[source], transform_processor=Transformer2())
        af.user_define_operation(input=[transformed], processor=Sink())

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name='test_python')
    ops.start_job_execution(job_name='task_1', execution_id='1')
    time.sleep(2)  # presumably lets the job get going before it is stopped
    ops.stop_job_execution(job_name='task_1', execution_id='1')

    job_execution = ops.get_job_execution(job_name='task_1', execution_id='1')
    self.assertEqual(Status.FAILED, job_execution.status)
    self.assertIn('err', job_execution.properties)
def run_workflow(client: NotificationClient):
    """Run a single Python job workflow and wait until its DagRun is SUCCESS.

    The execution info returned by start_new_workflow_execution is handed to
    ``set_workflow_execution_info``. The first DagRun found for this workflow's
    dag_id is then polled once per second.

    Args:
        client: Not referenced by this function.
    """
    with af.job_config('task_1'):
        af.user_define_operation(processor=PyProcessor1())

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name=af.current_workflow_config().workflow_name)
    execution_info = ops.start_new_workflow_execution(
        workflow_name=af.current_workflow_config().workflow_name)
    set_workflow_execution_info(execution_info)

    dag_id = 'test_project.{}'.format(af.current_workflow_config().workflow_name)
    while True:
        with create_session() as session:
            run = session.query(DagRun).filter(DagRun.dag_id == dag_id).first()
            if run is not None and run.state == State.SUCCESS:
                break
            time.sleep(1)
def run_workflow(client: NotificationClient):
    """Run a one-job bash workflow and wait until two TaskExecutions of it exist.

    The job name comes from the module-level ``task_name``. Polling happens
    once per second against this workflow's dag_id.

    Args:
        client: Not referenced by this function.
    """
    with af.job_config(task_name):
        af.user_define_operation(processor=bash.BashProcessor(
            bash_command='echo "Xiao ming hello world!"'))

    ops = af.workflow_operation
    ops.submit_workflow(workflow_name=af.current_workflow_config().workflow_name)
    ops.start_new_workflow_execution(workflow_name=af.current_workflow_config().workflow_name)

    dag_id = 'test_project.{}'.format(af.current_workflow_config().workflow_name)
    while True:
        with create_session() as session:
            executions = session.query(TaskExecution) \
                .filter(TaskExecution.dag_id == dag_id,
                        TaskExecution.task_id == task_name) \
                .all()
            # NOTE(review): exits only on exactly two executions; if a third
            # appears between polls this loop never ends. Confirm intent.
            if len(executions) == 2:
                break
            time.sleep(1)
def run_workflow():
    """Define and launch the KNN train -> validate -> predict workflow.

    All datasets, the model and artifacts are registered under names prefixed
    with "<project name>.". The jobs are linked through model-version events of
    the trained model. ``validate`` is registered on MODEL_GENERATED and
    ``predict`` on MODEL_VALIDATED.
    """
    af.init_ai_flow_context()
    prefix = af.current_project_config().get_project_name() + "."

    # Training: read the training dataset and train the KNN model.
    with af.job_config('train'):
        train_ds = af.register_dataset(name=prefix + 'train_dataset',
                                       uri=DATASET_URI.format('train'))
        train_input = af.read_dataset(dataset_info=train_ds,
                                      read_dataset_processor=DatasetReader())
        model = af.register_model(model_name=prefix + 'KNN', model_desc='KNN model')
        af.train(input=[train_input], training_processor=ModelTrainer(), model_info=model)

    # Validation: check the model on the 'test' dataset before it is used to predict.
    with af.job_config('validate'):
        validate_ds = af.register_dataset(name=prefix + 'validate_dataset',
                                          uri=DATASET_URI.format('test'))
        validate_input = af.read_dataset(dataset_info=validate_ds,
                                         read_dataset_processor=ValidateDatasetReader())
        validate_artifact = prefix + 'validate_artifact'
        af.register_artifact(name=validate_artifact,
                             uri=get_file_dir(__file__) + '/validate_result')
        af.model_validate(input=[validate_input], model_info=model,
                          model_validation_processor=ModelValidator(validate_artifact))

    # Prediction (inference) using flink; results go to predict_result.csv.
    with af.job_config('predict'):
        predict_ds = af.register_dataset(name=prefix + 'predict_dataset',
                                         uri=DATASET_URI.format('test'))
        predict_input = af.read_dataset(dataset_info=predict_ds,
                                        read_dataset_processor=Source())
        predictions = af.predict(input=[predict_input], model_info=model,
                                 prediction_processor=Predictor())
        result_ds = af.register_dataset(name=prefix + 'write_dataset',
                                        uri=get_file_dir(__file__) + '/predict_result.csv')
        af.write_dataset(input=predictions, dataset_info=result_ds,
                         write_dataset_processor=Sink())

    # Control edges: train -> validate -> predict, driven by model-version events.
    edges = (
        ('validate', ModelVersionEventType.MODEL_GENERATED),
        ('predict', ModelVersionEventType.MODEL_VALIDATED),
    )
    for job_name, event_type in edges:
        af.action_on_model_version_event(job_name=job_name,
                                         model_version_event_type=event_type,
                                         model_name=model.name)

    workflow_name = af.current_workflow_config().workflow_name
    af.workflow_operation.submit_workflow(workflow_name)
    af.workflow_operation.start_new_workflow_execution(workflow_name)
def run_workflow():
    """Define and launch the logistic-regression train/validate/push/predict workflow.

    Names are prefixed with "<project name>.". ``validate`` is registered on the
    trained model's MODEL_GENERATED event and ``push`` on its MODEL_VALIDATED
    event.

    NOTE(review): no control edge is registered for ``predict`` in this
    function. Confirm how it is intended to be started.
    """
    af.init_ai_flow_context()
    prefix = af.current_project_config().get_project_name() + "."

    # Training: read, transform, then train on the 'train' dataset.
    with af.job_config('train'):
        train_ds = af.register_dataset(name=prefix + 'train_dataset',
                                       uri=DATASET_URI.format('train'))
        train_raw = af.read_dataset(dataset_info=train_ds,
                                    read_dataset_processor=TrainDatasetReader())
        train_features = af.transform(input=[train_raw],
                                      transform_processor=TrainDatasetTransformer())
        model = af.register_model(model_name=prefix + 'logistic-regression',
                                  model_desc='logistic regression model')
        af.train(input=[train_features], training_processor=ModelTrainer(),
                 model_info=model)

    # Validation on the 'evaluate' dataset.
    with af.job_config('validate'):
        validate_ds = af.register_dataset(name=prefix + 'validate_dataset',
                                          uri=DATASET_URI.format('evaluate'))
        validate_raw = af.read_dataset(dataset_info=validate_ds,
                                       read_dataset_processor=ValidateDatasetReader())
        validate_features = af.transform(input=[validate_raw],
                                         transform_processor=ValidateTransformer())
        validate_artifact = prefix + 'validate_artifact'
        af.register_artifact(name=validate_artifact,
                             uri=get_file_dir(__file__) + '/validate_result')
        af.model_validate(input=[validate_features], model_info=model,
                          model_validation_processor=ModelValidator(validate_artifact))

    # Push model to serving, recording the pushed model as an artifact.
    with af.job_config('push'):
        pushed_artifact = prefix + 'push_model_artifact'
        af.register_artifact(name=pushed_artifact,
                             uri=get_file_dir(__file__) + '/pushed_model')
        af.push_model(model_info=model,
                      pushing_model_processor=ModelPusher(pushed_artifact))

    # Prediction on the 'predict' dataset; results written via DatasetWriter.
    with af.job_config('predict'):
        predict_ds = af.register_dataset(name=prefix + 'predict_dataset',
                                         uri=DATASET_URI.format('predict'))
        predict_raw = af.read_dataset(dataset_info=predict_ds,
                                      read_dataset_processor=PredictDatasetReader())
        predict_features = af.transform(input=[predict_raw],
                                        transform_processor=PredictTransformer())
        predictions = af.predict(input=[predict_features], model_info=model,
                                 prediction_processor=ModelPredictor())
        export_ds = af.register_dataset(name=prefix + 'export_dataset',
                                        uri=get_file_dir(__file__) + '/predict_result')
        af.write_dataset(input=predictions, dataset_info=export_ds,
                         write_dataset_processor=DatasetWriter())

    edges = (
        ('validate', ModelVersionEventType.MODEL_GENERATED),
        ('push', ModelVersionEventType.MODEL_VALIDATED),
    )
    for job_name, event_type in edges:
        af.action_on_model_version_event(job_name=job_name,
                                         model_version_event_type=event_type,
                                         model_name=model.name)

    # Run workflow
    workflow_name = af.current_workflow_config().workflow_name
    af.workflow_operation.submit_workflow(workflow_name)
    af.workflow_operation.start_new_workflow_execution(workflow_name)
def run_workflow():
    """Define and launch a five-stage logistic-regression workflow.

    Stages: train -> evaluate -> validate -> push -> predict. Each stage is
    chained to the previous one with ``af.action_on_job_status`` using its
    default status/action arguments. Names are prefixed with "<project name>.".
    """
    af.init_ai_flow_context()
    prefix = af.current_project_config().get_project_name() + "."

    # Training: read, preprocess and train on the 'train' dataset.
    with af.job_config('train'):
        train_ds = af.register_dataset(name=prefix + 'train_dataset',
                                       uri=DATASET_URI.format('train'))
        train_raw = af.read_dataset(dataset_info=train_ds,
                                    read_dataset_processor=DatasetReader())
        train_features = af.transform(input=[train_raw],
                                      transform_processor=DatasetTransformer())
        model = af.register_model(model_name=prefix + 'logistic-regression',
                                  model_desc='logistic regression model')
        af.train(input=[train_features], training_processor=ModelTrainer(),
                 model_info=model)

    # Evaluation on the 'evaluate' dataset; results saved to an artifact path.
    with af.job_config('evaluate'):
        evaluate_ds = af.register_dataset(name=prefix + 'evaluate_dataset',
                                          uri=DATASET_URI.format('evaluate'))
        evaluate_raw = af.read_dataset(dataset_info=evaluate_ds,
                                       read_dataset_processor=EvaluateDatasetReader())
        evaluate_features = af.transform(input=[evaluate_raw],
                                         transform_processor=EvaluateTransformer())
        evaluate_artifact = prefix + 'evaluate_artifact'
        af.register_artifact(name=evaluate_artifact,
                             uri=get_file_dir(__file__) + '/evaluate_result')
        af.evaluate(input=[evaluate_features], model_info=model,
                    evaluation_processor=ModelEvaluator(evaluate_artifact))

    # Validation before the model is used to predict (also the 'evaluate' dataset).
    with af.job_config('validate'):
        validate_ds = af.register_dataset(name=prefix + 'validate_dataset',
                                          uri=DATASET_URI.format('evaluate'))
        validate_raw = af.read_dataset(dataset_info=validate_ds,
                                       read_dataset_processor=ValidateDatasetReader())
        validate_features = af.transform(input=[validate_raw],
                                         transform_processor=ValidateTransformer())
        validate_artifact = prefix + 'validate_artifact'
        af.register_artifact(name=validate_artifact,
                             uri=get_file_dir(__file__) + '/validate_result')
        af.model_validate(input=[validate_features], model_info=model,
                          model_validation_processor=ModelValidator(validate_artifact))

    # Push model to serving, recording the pushed model as an artifact.
    with af.job_config('push'):
        pushed_artifact = prefix + 'push_model_artifact'
        af.register_artifact(name=pushed_artifact,
                             uri=get_file_dir(__file__) + '/pushed_model')
        af.push_model(model_info=model,
                      pushing_model_processor=ModelPusher(pushed_artifact))

    # Prediction (inference) on the 'predict' dataset; results saved via DatasetWriter.
    with af.job_config('predict'):
        predict_ds = af.register_dataset(name=prefix + 'predict_dataset',
                                         uri=DATASET_URI.format('predict'))
        predict_raw = af.read_dataset(dataset_info=predict_ds,
                                      read_dataset_processor=PredictDatasetReader())
        predict_features = af.transform(input=[predict_raw],
                                        transform_processor=PredictTransformer())
        predictions = af.predict(input=[predict_features], model_info=model,
                                 prediction_processor=ModelPredictor())
        result_ds = af.register_dataset(name=prefix + 'write_dataset',
                                        uri=get_file_dir(__file__) + '/predict_result')
        af.write_dataset(input=predictions, dataset_info=result_ds,
                         write_dataset_processor=DatasetWriter())

    # Control edges: each stage follows the previous one.
    stages = ('train', 'evaluate', 'validate', 'push', 'predict')
    for upstream, downstream in zip(stages, stages[1:]):
        af.action_on_job_status(downstream, upstream)

    # Run workflow
    workflow_name = af.current_workflow_config().workflow_name
    af.workflow_operation.submit_workflow(workflow_name)
    af.workflow_operation.start_new_workflow_execution(workflow_name)
def run_workflow():
    """Define and launch a periodic-training logistic-regression workflow.

    The train job is configured as periodic: it runs every `interval` seconds,
    as defined in workflow_config.yaml. ``validate`` is registered on the
    trained model's MODEL_GENERATED event and ``push`` on its MODEL_VALIDATED
    event. ``predict`` has no control edge registered here. Names are prefixed
    with "<project name>.".
    """
    af.init_ai_flow_context()
    prefix = af.current_project_config().get_project_name() + "."

    # Training: read, preprocess and train on the 'train' dataset.
    with af.job_config('train'):
        train_ds = af.register_dataset(name=prefix + 'train_dataset',
                                       uri=DATASET_URI.format('train'))
        train_raw = af.read_dataset(dataset_info=train_ds,
                                    read_dataset_processor=DatasetReader())
        train_features = af.transform(input=[train_raw],
                                      transform_processor=DatasetTransformer())
        model = af.register_model(model_name=prefix + 'logistic-regression',
                                  model_desc='logistic regression model')
        af.train(input=[train_features], training_processor=ModelTrainer(),
                 model_info=model)

    # Validation before the model is used to predict (on the 'evaluate' dataset).
    with af.job_config('validate'):
        validate_ds = af.register_dataset(name=prefix + 'validate_dataset',
                                          uri=DATASET_URI.format('evaluate'))
        validate_raw = af.read_dataset(dataset_info=validate_ds,
                                       read_dataset_processor=ValidateDatasetReader())
        validate_features = af.transform(input=[validate_raw],
                                         transform_processor=ValidateTransformer())
        validate_artifact = prefix + 'validate_artifact'
        af.register_artifact(name=validate_artifact,
                             uri=get_file_dir(__file__) + '/validate_result')
        af.model_validate(input=[validate_features], model_info=model,
                          model_validation_processor=ModelValidator(validate_artifact))

    # Push model to serving, recording the pushed model as an artifact.
    with af.job_config('push'):
        pushed_artifact = prefix + 'push_model_artifact'
        af.register_artifact(name=pushed_artifact,
                             uri=get_file_dir(__file__) + '/pushed_model')
        af.push_model(model_info=model,
                      pushing_model_processor=ModelPusher(pushed_artifact))

    # Prediction (inference) on the 'predict' dataset; results saved via DatasetWriter.
    with af.job_config('predict'):
        predict_ds = af.register_dataset(name=prefix + 'predict_dataset',
                                         uri=DATASET_URI.format('predict'))
        predict_raw = af.read_dataset(dataset_info=predict_ds,
                                      read_dataset_processor=PredictDatasetReader())
        predict_features = af.transform(input=[predict_raw],
                                        transform_processor=PredictTransformer())
        predictions = af.predict(input=[predict_features], model_info=model,
                                 prediction_processor=ModelPredictor())
        result_ds = af.register_dataset(name=prefix + 'write_dataset',
                                        uri=get_file_dir(__file__) + '/predict_result')
        af.write_dataset(input=predictions, dataset_info=result_ds,
                         write_dataset_processor=DatasetWriter())

    # Control edges: a new model version triggers validation, and a validated
    # version triggers the pusher.
    edges = (
        ('validate', ModelVersionEventType.MODEL_GENERATED),
        ('push', ModelVersionEventType.MODEL_VALIDATED),
    )
    for job_name, event_type in edges:
        af.action_on_model_version_event(job_name=job_name,
                                         model_version_event_type=event_type,
                                         model_name=model.name)

    # Run workflow
    workflow_name = af.current_workflow_config().workflow_name
    af.workflow_operation.submit_workflow(workflow_name)
    af.workflow_operation.start_new_workflow_execution(workflow_name)