def __init__(self, db_uri, server_uri):
    """Create the metadata store for ``db_uri`` and a model-center client.

    A MongoDB URI produces a ``MongoStore`` built from the parts returned by
    ``parse_mongo_uri``; any other engine falls back to ``SqlAlchemyStore``.

    :param db_uri: Database URI; its engine decides which store class is used.
    :param server_uri: Address handed to ``ModelCenterClient``.
    """
    engine_name = extract_db_engine_from_uri(db_uri)
    if DBType.value_of(engine_name) != DBType.MONGODB:
        self.store = SqlAlchemyStore(db_uri)
    else:
        user, secret, mongo_host, mongo_port, database = parse_mongo_uri(db_uri)
        self.store = MongoStore(host=mongo_host, port=int(mongo_port),
                                username=user, password=secret, db=database)
    self.model_center_client = ModelCenterClient(server_uri)
def __init__(self, db_uri):
    """Remember ``db_uri`` and open the backing store it points to.

    MongoDB URIs are split with ``parse_mongo_uri`` and opened as a
    ``MongoStore``; all other engines go through ``SqlAlchemyStore``.

    :param db_uri: Database URI, also kept on ``self.db_uri``.
    """
    self.db_uri = db_uri
    is_mongo = DBType.value_of(extract_db_engine_from_uri(self.db_uri)) == DBType.MONGODB
    if not is_mongo:
        self.store = SqlAlchemyStore(self.db_uri)
        return
    user, secret, mongo_host, mongo_port, database = parse_mongo_uri(self.db_uri)
    self.store = MongoStore(host=mongo_host, port=int(mongo_port),
                            username=user, password=secret, db=database)
def __init__(self, store_uri, notification_uri=None):
    """Open the model repository store and, optionally, a notification client.

    :param store_uri: Database URI; MongoDB URIs yield a ``MongoStore``,
        everything else a ``SqlAlchemyStore``.
    :param notification_uri: When given, a ``NotificationClient`` bound to
        ``DEFAULT_NAMESPACE`` is created; otherwise ``notification_client`` is None.
    """
    engine_name = extract_db_engine_from_uri(store_uri)
    if DBType.value_of(engine_name) == DBType.MONGODB:
        user, secret, mongo_host, mongo_port, database = parse_mongo_uri(store_uri)
        repo_store = MongoStore(host=mongo_host, port=int(mongo_port),
                                username=user, password=secret, db=database)
    else:
        repo_store = SqlAlchemyStore(store_uri)
    self.model_repo_store = repo_store
    self.notification_client = None
    if notification_uri is None:
        return
    self.notification_client = NotificationClient(notification_uri, default_namespace=DEFAULT_NAMESPACE)
def _clear_db(self):
    """Reset the configured database.

    SQLite: every table registered on ``base.metadata`` is dropped and
    recreated. MongoDB: ``MongoStoreConnManager().drop_all()`` is called.
    Any other ``db_type`` is left untouched.
    """
    db_type = self.db_type
    if db_type == DBType.MONGODB:
        MongoStoreConnManager().drop_all()
    elif db_type == DBType.SQLITE:
        sqlite_store = SqlAlchemyStore(self.store_uri)
        base.metadata.drop_all(sqlite_store.db_engine)
        base.metadata.create_all(sqlite_store.db_engine)
def _clear_db(self):
    """Reset the database described by ``self.master_config``.

    SQLite: drop and recreate all tables known to ``base.metadata``.
    MongoDB: call ``MongoStoreConnManager().drop_all()``.
    Other database types are not touched.
    """
    config = self.master_config
    db_type = config.get_db_type()
    if db_type == DBType.MONGODB:
        MongoStoreConnManager().drop_all()
    elif db_type == DBType.SQLITE:
        sqlite_store = SqlAlchemyStore(config.get_db_uri())
        base.metadata.drop_all(sqlite_store.db_engine)
        base.metadata.create_all(sqlite_store.db_engine)
def __init__(self, store_uri=None, port=_PORT, start_default_notification: bool = True, notification_uri=None,
             ha_manager=None, server_uri=None, ha_storage=None, ttl_ms: int = 10000):
    """Build an AIFlow server and attach a high-availability service to it.

    :param store_uri: Database URI for the base server and, when ``ha_storage``
        is not given, for the HA storage (MongoDB URIs yield a ``MongoStore``,
        all others a ``SqlAlchemyStore``).
    :param port: Port forwarded to the base server constructor.
    :param start_default_notification: Forwarded to the base server constructor.
    :param notification_uri: Forwarded to the base server constructor.
    :param ha_manager: HA manager; defaults to ``SimpleAIFlowServerHaManager()``.
    :param server_uri: Address of this server as seen by the HA service. Required.
    :param ha_storage: Storage for HA state; derived from ``store_uri`` when None.
    :param ttl_ms: Forwarded to ``HighAvailableService``.
    :raises ValueError: If ``server_uri`` is None.
    """
    # Validate before the base server is constructed, so a missing argument
    # fails fast instead of after the parent has already been built (it may
    # start a default notification service — NOTE(review): confirm in base class).
    if server_uri is None:
        raise ValueError("server_uri is required!")
    super(HighAvailableAIFlowServer, self).__init__(store_uri, port, start_default_notification, notification_uri)
    if ha_manager is None:
        ha_manager = SimpleAIFlowServerHaManager()
    if ha_storage is None:
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            # Use a distinct name so the server ``port`` parameter is not rebound.
            username, password, host, mongo_port, db = parse_mongo_uri(store_uri)
            ha_storage = MongoStore(host=host, port=int(mongo_port), username=username, password=password, db=db)
        else:
            ha_storage = SqlAlchemyStore(store_uri)
    self.ha_service = HighAvailableService(ha_manager, server_uri, ha_storage, ttl_ms)
    add_HighAvailabilityManagerServicer_to_server(self.ha_service, self.server)
def tearDown(self) -> None:
    """Rebuild the SQLite schema, clear the default graph and check ``list_job``.

    Asserts that ``client.list_job(page_size=10, offset=0)`` returns None after
    the reset.
    """
    print('tearDown')
    sqlite_store = SqlAlchemyStore(_SQLITE_DB_URI)
    for schema_op in (base.metadata.drop_all, base.metadata.create_all):
        schema_op(sqlite_store.db_engine)
    af.default_graph().clear_graph()
    self.assertIsNone(client.list_job(page_size=10, offset=0))
def tearDown(self) -> None:
    """Stop the client and any started servers, then drop the SQLite tables."""
    self.client.stop_listen_event()
    self.client.disable_high_availability()
    # server2/server3 may be None when a test never started them.
    for server in (self.server1, self.server2, self.server3):
        if server is not None:
            server.stop()
    sqlite_store = SqlAlchemyStore(_SQLITE_DB_URI)
    base.metadata.drop_all(sqlite_store.db_engine)
def setUp(self) -> None:
    """Start one HA server on port 50051 and an HA-enabled client.

    The client is given ``localhost:50052`` as well, although only server1 is
    started here; server2/server3 start as None.
    """
    # The store instance is discarded; presumably constructed only for its
    # constructor side effects (schema creation) — TODO confirm.
    SqlAlchemyStore(_SQLITE_DB_URI)
    self.server1 = HighAvailableAIFlowServer(server_uri='localhost:50051',
                                             port=50051,
                                             store_uri=_SQLITE_DB_URI)
    self.server1.run()
    self.server2 = None
    self.server3 = None
    self.config = ProjectConfig()
    self.config.set_enable_ha(True)
    self.client = AIFlowClient(project_config=self.config,
                               server_uri='localhost:50052,localhost:50051')
def stop(self, clear_sql_lite_db_file=True) -> None:
    """
    Stop the AI flow master.

    :param clear_sql_lite_db_file: When True and the configured database type is
                                   SQLite, drop all tables and delete the SQLite
                                   database file after the server has stopped.
    """
    self.server.stop()
    config = self.master_config
    if config.get_db_type() != DBType.SQLITE or not clear_sql_lite_db_file:
        return
    sqlite_store = SqlAlchemyStore(config.get_db_uri())
    base.metadata.drop_all(sqlite_store.db_engine)
    os.remove(config.get_sql_lite_db_file())
def stop(self, clear_sql_lite_db_file=False):
    """Stop the AIFlow server and release what it holds.

    Shuts down ``self.executor``, calls ``self.server.stop(0)``, and stops the
    HA service when HA is enabled. Then, for SQLite with
    ``clear_sql_lite_db_file`` set, drops all tables and deletes the database
    file; for MongoDB, disconnects everything via ``MongoStoreConnManager``.
    """
    self.executor.shutdown()
    self.server.stop(0)
    if self.enabled_ha:
        self.ha_service.stop()
    if self.db_type == DBType.SQLITE:
        if clear_sql_lite_db_file:
            sqlite_store = SqlAlchemyStore(self.store_uri)
            base.metadata.drop_all(sqlite_store.db_engine)
            # store_uri is expected to be 'sqlite:///<path>'; strip the scheme
            # (10 characters) to obtain the file path.
            sqlite_scheme = 'sqlite:///'
            os.remove(self.store_uri[len(sqlite_scheme):])
    elif self.db_type == DBType.MONGODB:
        MongoStoreConnManager().disconnect_all()
    logging.info('AIFlow server stopped.')
def _add_ha_service(self, ha_manager, ha_server_uri, ha_storage, store_uri, ttl_ms):
    """Create the HA service and register it on ``self.server``.

    :param ha_manager: HA manager; ``SimpleAIFlowServerHaManager()`` when None.
    :param ha_server_uri: Address of this server for HA purposes. Required.
    :param ha_storage: HA storage; derived from ``store_uri`` when None
        (MongoDB URIs yield a ``MongoStore``, others a ``SqlAlchemyStore``).
    :param store_uri: Database URI used only when ``ha_storage`` is None.
    :param ttl_ms: Forwarded to ``HighAvailableService``.
    :raises ValueError: If ``ha_server_uri`` is None.
    """
    manager = SimpleAIFlowServerHaManager() if ha_manager is None else ha_manager
    if ha_server_uri is None:
        raise ValueError("ha_server_uri is required with ha enabled!")
    storage = ha_storage
    if storage is None:
        engine_name = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(engine_name) != DBType.MONGODB:
            storage = SqlAlchemyStore(store_uri)
        else:
            user, secret, mongo_host, mongo_port, database = parse_mongo_uri(store_uri)
            storage = MongoStore(host=mongo_host, port=int(mongo_port),
                                 username=user, password=secret, db=database)
    self.ha_service = HighAvailableService(manager, ha_server_uri, storage, ttl_ms)
    add_HighAvailabilityManagerServicer_to_server(self.ha_service, self.server)
def setUp(self) -> None:
    """Start a notification master, one HA-enabled AIFlow server and a client.

    The notification master listens on port 30031 with in-memory event
    storage; server1 listens on 50051 with its scheduler and default
    notification disabled. server2/server3 start as None.
    """
    # Instance discarded; presumably constructed for its side effects — TODO confirm.
    SqlAlchemyStore(_SQLITE_DB_URI)
    event_storage = MemoryEventStorage()
    self.notification = NotificationMaster(service=NotificationService(storage=event_storage), port=30031)
    self.notification.run()
    self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                port=50051,
                                enabled_ha=True,
                                start_scheduler_service=False,
                                ha_server_uri='localhost:50051',
                                notification_uri='localhost:30031',
                                start_default_notification=False)
    self.server1.run()
    self.server2 = None
    self.server3 = None
    self.config = ProjectConfig()
    self.config.set_enable_ha(True)
    self.config.set_notification_service_uri('localhost:30031')
    self.client = AIFlowClient(project_config=self.config,
                               server_uri='localhost:50052,localhost:50051')
def __init__(self, store_uri=None, port=_PORT, start_default_notification: bool = True, notification_uri=None,
             ha_manager=None, server_uri=None, ha_storage=None, ttl_ms: int = 10000):
    """Build an AIFlow server and attach a high-availability service to it.

    :param store_uri: Database URI for the base server; also used for a
        ``SqlAlchemyStore`` HA storage when ``ha_storage`` is None.
    :param port: Port forwarded to the base server constructor.
    :param start_default_notification: Forwarded to the base server constructor.
    :param notification_uri: Forwarded to the base server constructor.
    :param ha_manager: HA manager; defaults to ``SimpleAIFlowServerHaManager()``.
    :param server_uri: Address of this server as seen by the HA service. Required.
    :param ha_storage: Storage for HA state.
    :param ttl_ms: Forwarded to ``HighAvailableService``.
    :raises ValueError: If ``server_uri`` is None.
    """
    # Validate before constructing the base server so a missing argument fails
    # fast instead of after the parent (which may start a default notification
    # service — NOTE(review): confirm in base class) has been built.
    if server_uri is None:
        raise ValueError("server_uri is required!")
    super(HighAvailableAIFlowServer, self).__init__(store_uri, port, start_default_notification, notification_uri)
    if ha_manager is None:
        ha_manager = SimpleAIFlowServerHaManager()
    if ha_storage is None:
        ha_storage = SqlAlchemyStore(store_uri)
    self.ha_service = HighAvailableService(ha_manager, server_uri, ha_storage, ttl_ms)
    add_HighAvailabilityManagerServicer_to_server(self.ha_service, self.server)
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for registered models and model versions.

    All persistence goes through ``self.model_repo_store``. When a
    notification URI is configured, model-version create/update/delete
    operations also publish a ``BaseEvent`` through ``self.notification_client``.
    """

    def __init__(self, store_uri, notification_uri=None):
        """
        :param store_uri: Database URI; MongoDB URIs yield a ``MongoStore``,
            all others a ``SqlAlchemyStore``.
        :param notification_uri: Optional notification service address; when
            None no events are sent.
        """
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(store_uri)
            self.model_repo_store = MongoStore(host=host, port=int(port), username=username,
                                               password=password, db=db)
        else:
            self.model_repo_store = SqlAlchemyStore(store_uri)
        self.notification_client = None
        if notification_uri is not None:
            self.notification_client = NotificationClient(notification_uri, default_namespace=DEFAULT_NAMESPACE)

    @catch_exception
    def createRegisteredModel(self, request, context):
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          event_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        if model_version_param.current_stage is not None:
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            # The store may return None (handled in the return below); there is
            # no version to describe then, so do not dereference it for an event.
            if self.notification_client is not None and model_version_meta is not None:
                self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                              json.dumps(model_version_meta.__dict__),
                                                              event_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                          json.dumps(model_meta_param.__dict__),
                                                          ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
def _clear_db(self):
    """Drop and recreate all ``base.metadata`` tables when the master uses SQLite.

    For any other database type this is a no-op.
    """
    config = self.master_config
    if config.get_db_type() != DBType.SQLITE:
        return
    sqlite_store = SqlAlchemyStore(config.get_db_uri())
    for schema_op in (base.metadata.drop_all, base.metadata.create_all):
        schema_op(sqlite_store.db_engine)
def tearDownClass(cls) -> None:
    """Stop event listening, drop the SQLite tables and delete the database file."""
    client.stop_listen_event()
    sqlite_engine = SqlAlchemyStore(_SQLITE_DB_URI).db_engine
    base.metadata.drop_all(sqlite_engine)
    os.remove(_SQLITE_DB_FILE)
class MetricService(MetricServiceServicer):
    """gRPC servicer for metric metadata and metric summaries backed by ``SqlAlchemyStore``.

    NOTE(review): only ``registerMetricMeta`` carries ``@catch_exception``; the
    other RPCs let exceptions escape (the two delete RPCs catch their own).
    Left unchanged to preserve the wire behaviour callers may rely on.
    """

    def __init__(self, db_uri):
        self.db_uri = db_uri
        self.store = SqlAlchemyStore(db_uri)

    @staticmethod
    def _optional_value(message, field_name):
        """Return ``message.<field_name>.value`` if ``HasField`` reports it set, else None."""
        return getattr(message, field_name).value if message.HasField(field_name) else None

    @catch_exception
    def registerMetricMeta(self, request, context):
        metric_meta_proto = request.metric_meta
        metric_meta = proto_to_metric_meta(metric_meta_proto)
        res_metric_meta = self.store.register_metric_meta(
            metric_meta.name, metric_meta.dataset_id, metric_meta.model_name, metric_meta.model_version,
            metric_meta.job_id, metric_meta.start_time, metric_meta.end_time, metric_meta.metric_type,
            metric_meta.uri, metric_meta.tags, metric_meta.metric_description, metric_meta.properties)
        return _warp_metric_meta_response(res_metric_meta)

    def registerMetricSummary(self, request, context):
        metric_summary_proto = request.metric_summary
        metric_summary = proto_to_metric_summary(metric_summary_proto)
        res_metric_summary = self.store.register_metric_summary(metric_id=metric_summary.metric_id,
                                                                metric_key=metric_summary.metric_key,
                                                                metric_value=metric_summary.metric_value)
        return _warp_metric_summary_response(res_metric_summary)

    def updateMetricMeta(self, request, context):
        metric_meta_proto = request.metric_meta
        if metric_meta_proto.metric_type == MetricTypeProto.DATASET:
            metric_type = MetricType.DATASET
        else:
            metric_type = MetricType.MODEL
        # Previously inverted: the map was forwarded only when EMPTY and dropped
        # (None) when it had content. An empty map means "not updated".
        properties = None if metric_meta_proto.properties == {} else metric_meta_proto.properties
        opt = self._optional_value
        res_metric_meta = self.store.update_metric_meta(
            metric_meta_proto.uuid,
            opt(metric_meta_proto, 'name'),
            opt(metric_meta_proto, 'dataset_id'),
            opt(metric_meta_proto, 'model_name'),
            opt(metric_meta_proto, 'model_version'),
            opt(metric_meta_proto, 'job_id'),
            opt(metric_meta_proto, 'start_time'),
            opt(metric_meta_proto, 'end_time'),
            metric_type,
            opt(metric_meta_proto, 'uri'),
            opt(metric_meta_proto, 'tags'),
            opt(metric_meta_proto, 'metric_description'),
            properties)
        return _warp_metric_meta_response(res_metric_meta)

    def updateMetricSummary(self, request, context):
        metric_summary_proto = request.metric_summary
        opt = self._optional_value
        res_metric_summary = self.store.update_metric_summary(
            uuid=metric_summary_proto.uuid,
            metric_id=opt(metric_summary_proto, 'metric_id'),
            metric_key=opt(metric_summary_proto, 'metric_key'),
            metric_value=opt(metric_summary_proto, 'metric_value'))
        return _warp_metric_summary_response(res_metric_summary)

    def getMetricMeta(self, request, context):
        metric_name = request.metric_name
        res_metric_meta = self.store.get_metric_meta(metric_name)
        return _warp_metric_meta_response(res_metric_meta)

    def getDatasetMetricMeta(self, request, context):
        dataset_id = request.dataset_id
        metric_meta = self.store.get_dataset_metric_meta(dataset_id=dataset_id)
        return _warp_list_metric_meta_response(metric_meta)

    def getModelMetricMeta(self, request, context):
        model_name = request.model_name
        model_version = request.model_version
        metric_meta = self.store.get_model_metric_meta(model_name=model_name, model_version=model_version)
        return _warp_list_metric_meta_response(metric_meta)

    def getMetricSummary(self, request, context):
        metric_id = request.metric_id
        metric_summary = self.store.get_metric_summary(metric_id=metric_id)
        return _warp_list_metric_summary_response(metric_summary)

    def deleteMetricMeta(self, request, context):
        uuid = request.uuid
        try:
            self.store.delete_metric_meta(uuid)
            return Response(return_code=str(ReturnCode.SUCCESS), return_msg='', data='')
        except Exception as e:
            return Response(return_code=str(ReturnCode.INTERNAL_ERROR), return_msg=str(e), data='')

    def deleteMetricSummary(self, request, context):
        uuid = request.uuid
        try:
            self.store.delete_metric_summary(uuid)
            return Response(return_code=str(ReturnCode.SUCCESS), return_msg='', data='')
        except Exception as e:
            return Response(return_code=str(ReturnCode.INTERNAL_ERROR), return_msg=str(e), data='')
class MetadataService(metadata_service_pb2_grpc.MetadataServiceServicer):
    """gRPC servicer for datasets, models, model versions, projects, artifacts and workflows.

    Metadata lives in ``self.store``; model and model-version details are
    additionally created/read/deleted through ``self.model_center_client``.
    Every RPC is wrapped with ``@catch_exception`` (the workflow RPCs were
    previously missing it, unlike every other RPC here).
    """

    def __init__(self, db_uri, server_uri):
        """
        :param db_uri: Database URI; MongoDB URIs yield a ``MongoStore``,
            all others a ``SqlAlchemyStore``.
        :param server_uri: Address handed to ``ModelCenterClient``.
        """
        db_engine = extract_db_engine_from_uri(db_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(db_uri)
            self.store = MongoStore(host=host, port=int(port), username=username, password=password, db=db)
        else:
            self.store = SqlAlchemyStore(db_uri)
        self.model_center_client = ModelCenterClient(server_uri)

    # ---- dataset api ----

    @catch_exception
    def getDatasetById(self, request, context):
        dataset = self.store.get_dataset_by_id(request.id)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset))

    @catch_exception
    def getDatasetByName(self, request, context):
        dataset = self.store.get_dataset_by_name(request.name)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset))

    @catch_exception
    def registerDataset(self, request, context):
        dataset = transform_dataset_meta(request.dataset)
        dataset_meta = self.store.register_dataset(name=dataset.name,
                                                   data_format=dataset.data_format,
                                                   description=dataset.description,
                                                   uri=dataset.uri,
                                                   properties=dataset.properties,
                                                   name_list=dataset.schema.name_list,
                                                   type_list=dataset.schema.type_list)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset_meta))

    @catch_exception
    def registerDatasetWithCatalog(self, request, context):
        dataset = transform_dataset_meta(request.dataset)
        dataset_meta = self.store.register_dataset_with_catalog(name=dataset.name,
                                                                catalog_name=dataset.catalog_name,
                                                                catalog_type=dataset.catalog_type,
                                                                catalog_database=dataset.catalog_database,
                                                                catalog_connection_uri=dataset.catalog_connection_uri,
                                                                catalog_table=dataset.catalog_table)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset_meta))

    @catch_exception
    def registerDatasets(self, request, context):
        _datasets = ProtoToMeta.proto_to_dataset_meta_list(request.datasets)
        response = self.store.register_datasets(_datasets)
        return _warp_dataset_list_response(response)

    @catch_exception
    def updateDataset(self, request, context):
        # An empty properties map / empty schema lists mean "leave unchanged" (None).
        properties = None if request.properties == {} else request.properties
        name_list = request.name_list
        type_list = request.type_list
        if not name_list:
            name_list = None
        if not type_list:
            data_type_list = None
        else:
            data_type_list = [DataType(DataTypeProto.Name(data_type)) for data_type in type_list]
        dataset_meta = self.store.update_dataset(
            dataset_name=request.name,
            data_format=request.data_format.value if request.HasField('data_format') else None,
            description=request.description.value if request.HasField('description') else None,
            uri=request.uri.value if request.HasField('uri') else None,
            properties=properties,
            name_list=name_list,
            type_list=data_type_list,
            catalog_name=request.catalog_name.value if request.HasField('catalog_name') else None,
            catalog_type=request.catalog_type.value if request.HasField('catalog_type') else None,
            catalog_database=request.catalog_database.value if request.HasField('catalog_database') else None,
            catalog_connection_uri=request.catalog_connection_uri.value
            if request.HasField('catalog_connection_uri') else None,
            catalog_table=request.catalog_table.value if request.HasField('catalog_table') else None)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset_meta))

    @catch_exception
    def listDatasets(self, request, context):
        dataset_meta_list = self.store.list_datasets(request.page_size, request.offset)
        return _warp_dataset_list_response(dataset_meta_list)

    @catch_exception
    def deleteDatasetById(self, request, context):
        status = self.store.delete_dataset_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteDatasetByName(self, request, context):
        status = self.store.delete_dataset_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- model relation api ----

    @catch_exception
    def getModelRelationById(self, request, context):
        model_meta = self.store.get_model_relation_by_id(request.id)
        return _wrap_meta_response(MetaToProto.model_relation_meta_to_proto(model_meta))

    @catch_exception
    def getModelRelationByName(self, request, context):
        model_meta = self.store.get_model_relation_by_name(request.name)
        return _wrap_meta_response(MetaToProto.model_relation_meta_to_proto(model_meta))

    @catch_exception
    def registerModelRelation(self, request, context):
        model = transform_model_relation_meta(request.model_relation)
        response = self.store.register_model_relation(name=model.name, project_id=model.project_id)
        return _wrap_meta_response(MetaToProto.model_relation_meta_to_proto(response))

    @catch_exception
    def listModelRelation(self, request, context):
        model_list = self.store.list_model_relation(request.page_size, request.offset)
        return _warp_model_relation_list_response(model_list)

    @catch_exception
    def deleteModelRelationById(self, request, context):
        status = self.store.delete_model_relation_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteModelRelationByName(self, request, context):
        status = self.store.delete_model_relation_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- model api ----

    @catch_exception
    def getModelById(self, request, context):
        model_relation = self.store.get_model_relation_by_id(request.id)
        if model_relation is None:
            model_detail = None
        else:
            model_detail = self.model_center_client.get_registered_model_detail(model_relation.name)
        return _wrap_meta_response(MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def getModelByName(self, request, context):
        model_relation = self.store.get_model_relation_by_name(request.name)
        model_detail = self.model_center_client.get_registered_model_detail(request.name)
        return _wrap_meta_response(MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def registerModel(self, request, context):
        model = transform_model_meta(request.model)
        model_detail = self.model_center_client.create_registered_model(model.name, model.model_desc)
        model_relation = self.store.register_model_relation(name=model.name, project_id=model.project_id)
        return _wrap_meta_response(MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def deleteModelById(self, request, context):
        model_relation = self.store.get_model_relation_by_id(request.id)
        if model_relation is None:
            return _wrap_delete_response(Status.ERROR)
        model_relation_status = self.store.delete_model_relation_by_id(request.id)
        self.model_center_client.delete_registered_model(model_relation.name)
        return _wrap_delete_response(model_relation_status)

    @catch_exception
    def deleteModelByName(self, request, context):
        model_relation_status = self.store.delete_model_relation_by_name(request.name)
        self.model_center_client.delete_registered_model(request.name)
        return _wrap_delete_response(model_relation_status)

    # ---- model version relation api ----

    @catch_exception
    def getModelVersionRelationByVersion(self, request, context):
        model_version_meta = self.store.get_model_version_relation_by_version(request.name, request.model_id)
        return _wrap_meta_response(MetaToProto.model_version_relation_meta_to_proto(model_version_meta))

    @catch_exception
    def listModelVersionRelation(self, request, context):
        model_version_meta_list = self.store.list_model_version_relation(request.model_id,
                                                                         request.page_size,
                                                                         request.offset)
        return _warp_model_version_relation_list_response(model_version_meta_list)

    @catch_exception
    def registerModelVersionRelation(self, request, context):
        model_version = transform_model_version_relation_meta(request.model_version_relation)
        response = self.store.register_model_version_relation(version=model_version.version,
                                                              model_id=model_version.model_id,
                                                              project_snapshot_id=model_version.project_snapshot_id)
        return _wrap_meta_response(MetaToProto.model_version_relation_meta_to_proto(response))

    @catch_exception
    def deleteModelVersionRelationByVersion(self, request, context):
        status = self.store.delete_model_version_relation_by_version(request.name, request.model_id)
        return _wrap_delete_response(status)

    # ---- model version api ----

    @catch_exception
    def getModelVersionByVersion(self, request, context):
        model_version_relation = self.store.get_model_version_relation_by_version(request.name, request.model_id)
        if model_version_relation is None:
            model_version_detail = None
        else:
            model_relation = self.store.get_model_relation_by_id(model_version_relation.model_id)
            model_version_detail = self.model_center_client.get_model_version_detail(model_relation.name,
                                                                                     request.name)
        return _wrap_meta_response(MetaToProto.model_version_meta_to_proto(model_version_relation,
                                                                           model_version_detail))

    @catch_exception
    def registerModelVersion(self, request, context):
        model_version = transform_model_version_meta(request.model_version)
        model_relation = self.store.get_model_relation_by_id(model_version.model_id)
        model_version_detail = self.model_center_client.create_model_version(model_relation.name,
                                                                             model_version.model_path,
                                                                             model_version.model_type,
                                                                             model_version.version_desc,
                                                                             request.model_version.current_stage)
        model_version_relation = self.store.register_model_version_relation(
            version=model_version_detail.model_version,
            model_id=model_version.model_id,
            project_snapshot_id=model_version.project_snapshot_id)
        return _wrap_meta_response(MetaToProto.model_version_meta_to_proto(model_version_relation,
                                                                           model_version_detail))

    @catch_exception
    def deleteModelVersionByVersion(self, request, context):
        model_version_relation = self.store.get_model_version_relation_by_version(request.name, request.model_id)
        if model_version_relation is None:
            return _wrap_delete_response(Status.ERROR)
        model_version_status = self.store.delete_model_version_relation_by_version(request.name, request.model_id)
        model_relation = self.store.get_model_relation_by_id(model_version_relation.model_id)
        if model_relation is not None:
            self.model_center_client.delete_model_version(model_relation.name, request.name)
        return _wrap_delete_response(model_version_status)

    def _model_version_store_response(self, model_name, model_version_detail):
        """Wrap ``model_version_detail`` together with its version relation for ``model_name``.

        The relation is looked up only when a detail was found; otherwise None
        is passed for both relation and detail-derived lookups.
        """
        if model_version_detail is None:
            model_version_relation = None
        else:
            model_relation = self.store.get_model_relation_by_name(model_name)
            model_version_relation = self.store.get_model_version_relation_by_version(
                model_version_detail.model_version, model_relation.uuid)
        return _wrap_meta_response(MetaToProto.model_version_store_to_proto(model_version_relation,
                                                                            model_version_detail))

    @catch_exception
    def getDeployedModelVersion(self, request, context):
        model_version_detail = self.store.get_deployed_model_version(request.name)
        return self._model_version_store_response(request.name, model_version_detail)

    @catch_exception
    def getLatestValidatedModelVersion(self, request, context):
        model_version_detail = self.store.get_latest_validated_model_version(request.name)
        return self._model_version_store_response(request.name, model_version_detail)

    @catch_exception
    def getLatestGeneratedModelVersion(self, request, context):
        model_version_detail = self.store.get_latest_generated_model_version(request.name)
        return self._model_version_store_response(request.name, model_version_detail)

    # ---- project api ----

    @catch_exception
    def getProjectById(self, request, context):
        project_meta = self.store.get_project_by_id(request.id)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(project_meta))

    @catch_exception
    def getProjectByName(self, request, context):
        project_meta = self.store.get_project_by_name(request.name)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(project_meta))

    @catch_exception
    def listProject(self, request, context):
        project_meta_list = self.store.list_project(request.page_size, request.offset)
        return _warp_project_list_response(project_meta_list)

    @catch_exception
    def registerProject(self, request, context):
        project = transform_project_meta(request.project)
        response = self.store.register_project(name=project.name, uri=project.uri, properties=project.properties)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(response))

    @catch_exception
    def updateProject(self, request, context):
        properties = None if request.properties == {} else request.properties
        project = self.store.update_project(project_name=request.name,
                                            uri=request.uri.value if request.HasField('uri') else None,
                                            properties=properties)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(project))

    @catch_exception
    def deleteProjectById(self, request, context):
        status = self.store.delete_project_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteProjectByName(self, request, context):
        status = self.store.delete_project_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- artifact api ----

    @catch_exception
    def getArtifactById(self, request, context):
        artifact_meta = self.store.get_artifact_by_id(request.id)
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(artifact_meta))

    @catch_exception
    def getArtifactByName(self, request, context):
        artifact_meta = self.store.get_artifact_by_name(request.name)
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(artifact_meta))

    @catch_exception
    def registerArtifact(self, request, context):
        artifact = transform_artifact_meta(request.artifact)
        response = self.store.register_artifact(name=artifact.name,
                                                artifact_type=artifact.artifact_type,
                                                description=artifact.description,
                                                uri=artifact.uri,
                                                properties=artifact.properties)
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(response))

    @catch_exception
    def updateArtifact(self, request, context):
        properties = None if request.properties == {} else request.properties
        artifact = self.store.update_artifact(
            name=request.name,
            artifact_type=request.artifact_type.value if request.HasField('artifact_type') else None,
            properties=properties,
            description=request.description.value if request.HasField('description') else None,
            uri=request.uri.value if request.HasField('uri') else None,
        )
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(artifact))

    @catch_exception
    def listArtifact(self, request, context):
        artifact_meta_list = self.store.list_artifact(request.page_size, request.offset)
        return _warp_artifact_list_response(artifact_meta_list)

    @catch_exception
    def deleteArtifactById(self, request, context):
        status = self.store.delete_artifact_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteArtifactByName(self, request, context):
        status = self.store.delete_artifact_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- workflow api ----

    @catch_exception
    def registerWorkflow(self, request, context):
        workflow = transform_workflow_meta(request.workflow)
        response = self.store.register_workflow(name=workflow.name,
                                                project_id=workflow.project_id,
                                                properties=workflow.properties)
        return _wrap_meta_response(MetaToProto.workflow_meta_to_proto(response))

    @catch_exception
    def updateWorkflow(self, request, context):
        properties = None if request.properties == {} else request.properties
        workflow = self.store.update_workflow(workflow_name=request.workflow_name,
                                              project_name=request.project_name,
                                              properties=properties)
        return _wrap_meta_response(MetaToProto.workflow_meta_to_proto(workflow))

    @catch_exception
    def getWorkflowById(self, request, context):
        workflow = self.store.get_workflow_by_id(workflow_id=request.id)
        return _wrap_meta_response(MetaToProto.workflow_meta_to_proto(workflow))

    @catch_exception
    def getWorkflowByName(self, request, context):
        workflow = self.store.get_workflow_by_name(project_name=request.project_name,
                                                   workflow_name=request.workflow_name)
        return _wrap_meta_response(MetaToProto.workflow_meta_to_proto(workflow))

    @catch_exception
    def deleteWorkflowById(self, request, context):
        status = self.store.delete_workflow_by_id(workflow_id=request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteWorkflowByName(self, request, context):
        status = self.store.delete_workflow_by_name(project_name=request.project_name,
                                                    workflow_name=request.workflow_name)
        return _wrap_delete_response(status)

    @catch_exception
    def listWorkflows(self, request, context):
        workflow_meta_list = self.store.list_workflows(project_name=request.project_name,
                                                       page_size=request.page_size,
                                                       offset=request.offset)
        return _wrap_workflow_list_response(workflow_meta_list)
def __init__(self, db_uri):
    """Keep the database URI and open a SQLAlchemy-backed store on it."""
    self.db_uri = db_uri
    self.store = SqlAlchemyStore(self.db_uri)
class MetadataService(metadata_service_pb2_grpc.MetadataServiceServicer):
    """gRPC servicer for the metadata service, backed by a ``SqlAlchemyStore``.

    Example, model-relation, model-version-relation, job, workflow-execution,
    project and artifact RPCs are served from the local store only. The
    "model" and "model version" RPCs additionally call the model center
    through ``ModelCenterClient`` to create, fetch or delete the
    registered-model side of the record.
    """

    def __init__(self, db_uri, server_uri):
        """Open the store on ``db_uri`` and a model center client on ``server_uri``."""
        # Previously ``db_uri = db_uri`` (a no-op); keep the URI on the instance.
        self.db_uri = db_uri
        self.store = SqlAlchemyStore(db_uri)
        self.model_center_client = ModelCenterClient(server_uri)

    @staticmethod
    def _optional_value(request, field_name):
        """Return ``request.<field_name>.value`` if that field is set, else ``None``.

        Each optional field is named exactly once here, so the ``HasField``
        check and the attribute read can never disagree.
        """
        if request.HasField(field_name):
            return getattr(request, field_name).value
        return None

    # ---- example api ----

    @catch_exception
    def getExampleById(self, request, context):
        """Return the example with id ``request.id``."""
        example = self.store.get_example_by_id(request.id)
        return _wrap_meta_response(MetaToProto.example_meta_to_proto(example))

    @catch_exception
    def getExampleByName(self, request, context):
        """Return the example named ``request.name``."""
        example = self.store.get_example_by_name(request.name)
        return _wrap_meta_response(MetaToProto.example_meta_to_proto(example))

    @catch_exception
    def registerExample(self, request, context):
        """Register the example carried in ``request.example``."""
        example = transform_example_meta(request.example)
        example_meta = self.store.register_example(name=example.name,
                                                   support_type=example.support_type,
                                                   data_type=example.data_type,
                                                   data_format=example.data_format,
                                                   description=example.description,
                                                   batch_uri=example.batch_uri,
                                                   stream_uri=example.stream_uri,
                                                   create_time=example.create_time,
                                                   update_time=example.update_time,
                                                   properties=example.properties,
                                                   name_list=example.schema.name_list,
                                                   type_list=example.schema.type_list)
        return _wrap_meta_response(MetaToProto.example_meta_to_proto(example_meta))

    @catch_exception
    def registerExampleWithCatalog(self, request, context):
        """Register the example in ``request.example`` using its catalog fields."""
        example = transform_example_meta(request.example)
        example_meta = self.store.register_example_with_catalog(
            name=example.name,
            support_type=example.support_type,
            catalog_name=example.catalog_name,
            catalog_type=example.catalog_type,
            catalog_database=example.catalog_database,
            catalog_connection_uri=example.catalog_connection_uri,
            catalog_version=example.catalog_version,
            catalog_table=example.catalog_table)
        return _wrap_meta_response(MetaToProto.example_meta_to_proto(example_meta))

    @catch_exception
    def registerExamples(self, request, context):
        """Register every example in ``request.examples``."""
        _examples = ProtoToMeta.proto_to_example_meta_list(request.examples)
        response = self.store.register_examples(_examples)
        return _warp_example_list_response(response)

    @catch_exception
    def updateExample(self, request, context):
        """Update the example ``request.name``; unset fields are passed as ``None``."""
        # 0 is treated as "not provided" for the enum field.
        support_type = None if request.support_type == 0 else ExampleSupportType(
            ExampleSupportTypeProto.Name(request.support_type))
        properties = None if request.properties == {} else request.properties
        name_list = request.name_list if request.name_list else None
        type_list = request.type_list
        if not type_list:
            data_type_list = None
        else:
            data_type_list = [DataType(DataTypeProto.Name(data_type)) for data_type in type_list]
        optional = self._optional_value
        example_meta = self.store.update_example(
            example_name=request.name,
            support_type=support_type,
            data_type=optional(request, 'data_type'),
            data_format=optional(request, 'data_format'),
            description=optional(request, 'description'),
            batch_uri=optional(request, 'batch_uri'),
            stream_uri=optional(request, 'stream_uri'),
            update_time=optional(request, 'update_time'),
            properties=properties,
            name_list=name_list,
            type_list=data_type_list,
            catalog_name=optional(request, 'catalog_name'),
            catalog_type=optional(request, 'catalog_type'),
            catalog_database=optional(request, 'catalog_database'),
            catalog_version=optional(request, 'catalog_version'),
            catalog_connection_uri=optional(request, 'catalog_connection_uri'),
            catalog_table=optional(request, 'catalog_table'))
        return _wrap_meta_response(MetaToProto.example_meta_to_proto(example_meta))

    @catch_exception
    def listExample(self, request, context):
        """List one page (page_size/offset) of examples."""
        example_meta_list = self.store.list_example(request.page_size, request.offset)
        return _warp_example_list_response(example_meta_list)

    @catch_exception
    def deleteExampleById(self, request, context):
        """Delete the example with id ``request.id``."""
        status = self.store.delete_example_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteExampleByName(self, request, context):
        """Delete the example named ``request.name``."""
        status = self.store.delete_example_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- model relation api ----

    @catch_exception
    def getModelRelationById(self, request, context):
        """Return the model relation with id ``request.id``."""
        model_meta = self.store.get_model_relation_by_id(request.id)
        return _wrap_meta_response(MetaToProto.model_relation_meta_to_proto(model_meta))

    @catch_exception
    def getModelRelationByName(self, request, context):
        """Return the model relation named ``request.name``."""
        model_meta = self.store.get_model_relation_by_name(request.name)
        return _wrap_meta_response(MetaToProto.model_relation_meta_to_proto(model_meta))

    @catch_exception
    def registerModelRelation(self, request, context):
        """Register the model relation carried in ``request.model_relation``."""
        model = transform_model_relation_meta(request.model_relation)
        response = self.store.register_model_relation(name=model.name, project_id=model.project_id)
        return _wrap_meta_response(MetaToProto.model_relation_meta_to_proto(response))

    @catch_exception
    def listModelRelation(self, request, context):
        """List one page (page_size/offset) of model relations."""
        model_list = self.store.list_model_relation(request.page_size, request.offset)
        return _warp_model_relation_list_response(model_list)

    @catch_exception
    def deleteModelRelationById(self, request, context):
        """Delete the model relation with id ``request.id``."""
        status = self.store.delete_model_relation_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteModelRelationByName(self, request, context):
        """Delete the model relation named ``request.name``."""
        status = self.store.delete_model_relation_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- model api (store relation + model center detail) ----

    @catch_exception
    def getModelById(self, request, context):
        """Return the model relation with id ``request.id`` plus its model center detail."""
        model_relation = self.store.get_model_relation_by_id(request.id)
        if model_relation is None:
            model_detail = None
        else:
            model_detail = self.model_center_client.get_registered_model_detail(model_relation.name)
        return _wrap_meta_response(MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def getModelByName(self, request, context):
        """Return the model named ``request.name`` from the store and the model center."""
        model_relation = self.store.get_model_relation_by_name(request.name)
        model_detail = self.model_center_client.get_registered_model_detail(request.name)
        return _wrap_meta_response(MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def registerModel(self, request, context):
        """Create the registered model in the model center, then its store relation."""
        model = transform_model_meta(request.model)
        model_detail = self.model_center_client.create_registered_model(model.name,
                                                                         ModelType.Name(model.model_type),
                                                                         model.model_desc)
        model_relation = self.store.register_model_relation(name=model.name, project_id=model.project_id)
        return _wrap_meta_response(MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def deleteModelById(self, request, context):
        """Delete model ``request.id`` from the store and the model center.

        Returns an ERROR delete response when no relation has that id.
        """
        model_relation = self.store.get_model_relation_by_id(request.id)
        if model_relation is None:
            return _wrap_delete_response(Status.ERROR)
        model_relation_status = self.store.delete_model_relation_by_id(request.id)
        self.model_center_client.delete_registered_model(model_relation.name)
        return _wrap_delete_response(model_relation_status)

    @catch_exception
    def deleteModelByName(self, request, context):
        """Delete model ``request.name`` from the store and the model center."""
        model_relation_status = self.store.delete_model_relation_by_name(request.name)
        self.model_center_client.delete_registered_model(request.name)
        return _wrap_delete_response(model_relation_status)

    # ---- model version relation api ----

    @catch_exception
    def getModelVersionRelationByVersion(self, request, context):
        """Return the version relation ``request.name`` of model ``request.model_id``."""
        model_version_meta = self.store.get_model_version_relation_by_version(request.name, request.model_id)
        return _wrap_meta_response(MetaToProto.model_version_relation_meta_to_proto(model_version_meta))

    @catch_exception
    def listModelVersionRelation(self, request, context):
        """List one page (page_size/offset) of version relations of ``request.model_id``."""
        model_version_meta_list = self.store.list_model_version_relation(request.model_id,
                                                                         request.page_size,
                                                                         request.offset)
        return _warp_model_version_relation_list_response(model_version_meta_list)

    @catch_exception
    def registerModelVersionRelation(self, request, context):
        """Register the version relation carried in ``request.model_version_relation``."""
        model_version = transform_model_version_relation_meta(request.model_version_relation)
        response = self.store.register_model_version_relation(
            version=model_version.version,
            model_id=model_version.model_id,
            workflow_execution_id=model_version.workflow_execution_id)
        return _wrap_meta_response(MetaToProto.model_version_relation_meta_to_proto(response))

    @catch_exception
    def deleteModelVersionRelationByVersion(self, request, context):
        """Delete the version relation ``request.name`` of model ``request.model_id``."""
        status = self.store.delete_model_version_relation_by_version(request.name, request.model_id)
        return _wrap_delete_response(status)

    # ---- model version api (store relation + model center detail) ----

    @catch_exception
    def getModelVersionByVersion(self, request, context):
        """Return version ``request.name`` of model ``request.model_id`` with its model center detail."""
        model_version_relation = self.store.get_model_version_relation_by_version(request.name,
                                                                                  request.model_id)
        if model_version_relation is None:
            model_version_detail = None
        else:
            model_relation = self.store.get_model_relation_by_id(model_version_relation.model_id)
            model_version_detail = self.model_center_client.get_model_version_detail(model_relation.name,
                                                                                     request.name)
        return _wrap_meta_response(
            MetaToProto.model_version_meta_to_proto(model_version_relation, model_version_detail))

    @catch_exception
    def registerModelVersion(self, request, context):
        """Create the version in the model center, then record its relation in the store."""
        model_version = transform_model_version_meta(request.model_version)
        model_relation = self.store.get_model_relation_by_id(model_version.model_id)
        model_version_detail = self.model_center_client.create_model_version(model_relation.name,
                                                                             model_version.model_path,
                                                                             model_version.model_metric,
                                                                             model_version.model_flavor,
                                                                             model_version.version_desc,
                                                                             request.model_version.current_stage)
        # The version number assigned by the model center is the one recorded locally.
        model_version_relation = self.store.register_model_version_relation(
            version=model_version_detail.model_version,
            model_id=model_version.model_id,
            workflow_execution_id=model_version.workflow_execution_id)
        return _wrap_meta_response(
            MetaToProto.model_version_meta_to_proto(model_version_relation, model_version_detail))

    @catch_exception
    def deleteModelVersionByVersion(self, request, context):
        """Delete version ``request.name`` of model ``request.model_id`` from store and model center.

        Returns an ERROR delete response when the version relation does not exist.
        """
        model_version_relation = self.store.get_model_version_relation_by_version(request.name,
                                                                                  request.model_id)
        if model_version_relation is None:
            return _wrap_delete_response(Status.ERROR)
        model_version__status = self.store.delete_model_version_relation_by_version(request.name,
                                                                                    request.model_id)
        model_relation = self.store.get_model_relation_by_id(model_version_relation.model_id)
        if model_relation is not None:
            self.model_center_client.delete_model_version(model_relation.name, request.name)
        return _wrap_delete_response(model_version__status)

    def _wrap_model_version_store_response(self, model_name, model_version_detail):
        """Pair a store-side model version of ``model_name`` with its relation and wrap both.

        Shared by the deployed / latest-validated / latest-generated lookups.
        """
        if model_version_detail is None:
            model_version_relation = None
        else:
            model_relation = self.store.get_model_relation_by_name(model_name)
            model_version_relation = self.store.get_model_version_relation_by_version(
                model_version_detail.model_version, model_relation.uuid)
        return _wrap_meta_response(
            MetaToProto.model_version_store_to_proto(model_version_relation, model_version_detail))

    @catch_exception
    def getDeployedModelVersion(self, request, context):
        """Return the deployed version of model ``request.name``."""
        model_version_detail = self.store.get_deployed_model_version(request.name)
        return self._wrap_model_version_store_response(request.name, model_version_detail)

    @catch_exception
    def getLatestValidatedModelVersion(self, request, context):
        """Return the latest validated version of model ``request.name``."""
        model_version_detail = self.store.get_latest_validated_model_version(request.name)
        return self._wrap_model_version_store_response(request.name, model_version_detail)

    @catch_exception
    def getLatestGeneratedModelVersion(self, request, context):
        """Return the latest generated version of model ``request.name``."""
        model_version_detail = self.store.get_latest_generated_model_version(request.name)
        return self._wrap_model_version_store_response(request.name, model_version_detail)

    # ---- job api ----

    @catch_exception
    def getJobById(self, request, context):
        """Return the job with id ``request.id``."""
        job_meta = self.store.get_job_by_id(request.id)
        return _wrap_meta_response(MetaToProto.job_meta_to_proto(job_meta))

    @catch_exception
    def getJobByName(self, request, context):
        """Return the job named ``request.name``."""
        job_meta = self.store.get_job_by_name(request.name)
        return _wrap_meta_response(MetaToProto.job_meta_to_proto(job_meta))

    @catch_exception
    def updateJob(self, request, context):
        """Update job ``request.name``; unset optional fields are passed as ``None``."""
        properties = None if request.properties == {} else request.properties
        # 0 is treated as "not provided" for the enum field.
        job_state = None if request.job_state == 0 else State(StateProto.Name(request.job_state))
        optional = self._optional_value
        job = self.store.update_job(job_name=request.name,
                                    job_state=job_state,
                                    properties=properties,
                                    job_id=optional(request, 'job_id'),
                                    workflow_execution_id=optional(request, 'workflow_execution_id'),
                                    end_time=optional(request, 'end_time'),
                                    log_uri=optional(request, 'log_uri'),
                                    signature=optional(request, 'signature'))
        return _wrap_meta_response(MetaToProto.job_meta_to_proto(job))

    @catch_exception
    def listJob(self, request, context):
        """List one page (page_size/offset) of jobs."""
        job_meta_list = self.store.list_job(request.page_size, request.offset)
        return _warp_job_list_response(job_meta_list)

    @catch_exception
    def registerJob(self, request, context):
        """Register the job carried in ``request.job``."""
        job = transform_job_meta(request.job)
        response = self.store.register_job(name=job.name,
                                           workflow_execution_id=job.workflow_execution_id,
                                           job_state=job.job_state,
                                           properties=job.properties,
                                           job_id=job.job_id,
                                           start_time=job.start_time,
                                           end_time=job.end_time,
                                           log_uri=job.log_uri,
                                           signature=job.signature)
        return _wrap_meta_response(MetaToProto.job_meta_to_proto(response))

    @catch_exception
    def updateJobState(self, request, context):
        """Set the state of job ``request.name``."""
        _state = ProtoToMeta.proto_to_state(request.state)
        uuid = self.store.update_job_state(_state, request.name)
        return _wrap_update_response(uuid)

    @catch_exception
    def updateJobEndTime(self, request, context):
        """Set the end time of job ``request.name``."""
        uuid = self.store.update_job_end_time(request.end_time, request.name)
        return _wrap_update_response(uuid)

    @catch_exception
    def deleteJobById(self, request, context):
        """Delete the job with id ``request.id``."""
        status = self.store.delete_job_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteJobByName(self, request, context):
        """Delete the job named ``request.name``."""
        status = self.store.delete_job_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- workflow execution api ----

    @catch_exception
    def getWorkFlowExecutionById(self, request, context):
        """Return the workflow execution with id ``request.id``."""
        execution_meta = self.store.get_workflow_execution_by_id(request.id)
        return _wrap_meta_response(MetaToProto.workflow_execution_meta_to_proto(execution_meta))

    @catch_exception
    def getWorkFlowExecutionByName(self, request, context):
        """Return the workflow execution named ``request.name``."""
        execution_meta = self.store.get_workflow_execution_by_name(request.name)
        return _wrap_meta_response(MetaToProto.workflow_execution_meta_to_proto(execution_meta))

    @catch_exception
    def updateWorkflowExecution(self, request, context):
        """Update workflow execution ``request.name``; unset optional fields are passed as ``None``."""
        properties = None if request.properties == {} else request.properties
        # 0 is treated as "not provided" for the enum field.
        execution_state = None if request.execution_state == 0 else State(
            StateProto.Name(request.execution_state))
        optional = self._optional_value
        # Fixed: log_uri and workflow_json were previously read as
        # ``request.log_uri_value`` / ``request.workjson.value``, which do not
        # match the fields checked with HasField.
        workflow_execution = self.store.update_workflow_execution(
            execution_name=request.name,
            execution_state=execution_state,
            project_id=optional(request, 'project_id'),
            properties=properties,
            end_time=optional(request, 'end_time'),
            log_uri=optional(request, 'log_uri'),
            workflow_json=optional(request, 'workflow_json'),
            signature=optional(request, 'signature'))
        return _wrap_meta_response(MetaToProto.workflow_execution_meta_to_proto(workflow_execution))

    @catch_exception
    def listWorkFlowExecution(self, request, context):
        """List one page (page_size/offset) of workflow executions."""
        workflow_execution_meta_list = self.store.list_workflow_execution(request.page_size, request.offset)
        return _warp_workflow_execution_list_response(workflow_execution_meta_list)

    @catch_exception
    def registerWorkFlowExecution(self, request, context):
        """Register the workflow execution carried in ``request.workflow_execution``."""
        workflow_execution = transform_workflow_execution_meta(request.workflow_execution)
        response = self.store.register_workflow_execution(name=workflow_execution.name,
                                                          project_id=workflow_execution.project_id,
                                                          execution_state=workflow_execution.execution_state,
                                                          properties=workflow_execution.properties,
                                                          start_time=workflow_execution.start_time,
                                                          end_time=workflow_execution.end_time,
                                                          log_uri=workflow_execution.log_uri,
                                                          workflow_json=workflow_execution.workflow_json,
                                                          signature=workflow_execution.signature)
        return _wrap_meta_response(MetaToProto.workflow_execution_meta_to_proto(response))

    @catch_exception
    def updateWorkflowExecutionEndTime(self, request, context):
        """Set the end time of workflow execution ``request.name``."""
        uuid = self.store.update_workflow_execution_end_time(request.end_time, request.name)
        return _wrap_update_response(uuid)

    @catch_exception
    def updateWorkflowExecutionState(self, request, context):
        """Set the state of workflow execution ``request.name``."""
        _state = ProtoToMeta.proto_to_state(request.state)
        uuid = self.store.update_workflow_execution_state(_state, request.name)
        return _wrap_update_response(uuid)

    @catch_exception
    def deleteWorkflowExecutionById(self, request, context):
        """Delete the workflow execution with id ``request.id``."""
        status = self.store.delete_workflow_execution_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteWorkflowExecutionByName(self, request, context):
        """Delete the workflow execution named ``request.name``."""
        status = self.store.delete_workflow_execution_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- project api ----

    @catch_exception
    def getProjectById(self, request, context):
        """Return the project with id ``request.id``."""
        project_meta = self.store.get_project_by_id(request.id)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(project_meta))

    @catch_exception
    def getProjectByName(self, request, context):
        """Return the project named ``request.name``."""
        project_meta = self.store.get_project_by_name(request.name)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(project_meta))

    @catch_exception
    def listProject(self, request, context):
        """List one page (page_size/offset) of projects."""
        project_meta_list = self.store.list_project(request.page_size, request.offset)
        return _warp_project_list_response(project_meta_list)

    @catch_exception
    def registerProject(self, request, context):
        """Register the project carried in ``request.project``."""
        project = transform_project_meta(request.project)
        response = self.store.register_project(name=project.name,
                                               uri=project.uri,
                                               properties=project.properties,
                                               user=project.user,
                                               password=project.password,
                                               project_type=project.project_type)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(response))

    @catch_exception
    def updateProject(self, request, context):
        """Update project ``request.name``; unset optional fields are passed as ``None``."""
        properties = None if request.properties == {} else request.properties
        optional = self._optional_value
        project = self.store.update_project(project_name=request.name,
                                            uri=optional(request, 'uri'),
                                            properties=properties,
                                            user=optional(request, 'user'),
                                            password=optional(request, 'password'),
                                            project_type=optional(request, 'project_type'))
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(project))

    @catch_exception
    def deleteProjectById(self, request, context):
        """Delete the project with id ``request.id``."""
        status = self.store.delete_project_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteProjectByName(self, request, context):
        """Delete the project named ``request.name``."""
        status = self.store.delete_project_by_name(request.name)
        return _wrap_delete_response(status)

    # ---- artifact api ----

    @catch_exception
    def getArtifactById(self, request, context):
        """Return the artifact with id ``request.id``."""
        artifact_meta = self.store.get_artifact_by_id(request.id)
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(artifact_meta))

    @catch_exception
    def getArtifactByName(self, request, context):
        """Return the artifact named ``request.name``."""
        artifact_meta = self.store.get_artifact_by_name(request.name)
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(artifact_meta))

    @catch_exception
    def registerArtifact(self, request, context):
        """Register the artifact carried in ``request.artifact``."""
        artifact = transform_artifact_meta(request.artifact)
        response = self.store.register_artifact(name=artifact.name,
                                                data_format=artifact.data_format,
                                                description=artifact.description,
                                                batch_uri=artifact.batch_uri,
                                                stream_uri=artifact.stream_uri,
                                                create_time=artifact.create_time,
                                                update_time=artifact.update_time,
                                                properties=artifact.properties)
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(response))

    @catch_exception
    def updateArtifact(self, request, context):
        """Update artifact ``request.name``; unset optional fields are passed as ``None``."""
        properties = None if request.properties == {} else request.properties
        optional = self._optional_value
        artifact = self.store.update_artifact(artifact_name=request.name,
                                              data_format=optional(request, 'data_format'),
                                              properties=properties,
                                              description=optional(request, 'description'),
                                              batch_uri=optional(request, 'batch_uri'),
                                              stream_uri=optional(request, 'stream_uri'),
                                              update_time=optional(request, 'update_time'))
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(artifact))

    @catch_exception
    def listArtifact(self, request, context):
        """List one page (page_size/offset) of artifacts."""
        artifact_meta_list = self.store.list_artifact(request.page_size, request.offset)
        return _warp_artifact_list_response(artifact_meta_list)

    @catch_exception
    def deleteArtifactById(self, request, context):
        """Delete the artifact with id ``request.id``."""
        status = self.store.delete_artifact_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteArtifactByName(self, request, context):
        """Delete the artifact named ``request.name``."""
        status = self.store.delete_artifact_by_name(request.name)
        return _wrap_delete_response(status)
def __init__(self, store_uri, server_uri, notification_uri=None):
    """Open the model repository store and a notification client.

    The notification client targets ``notification_uri`` when one is given,
    otherwise ``server_uri``.
    """
    self.model_repo_store = SqlAlchemyStore(store_uri)
    target_uri = server_uri if notification_uri is None else notification_uri
    self.notification_client = NotificationClient(target_uri)
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for registered models and their versions.

    Records are kept in a ``SqlAlchemyStore``. Creating, re-staging or
    deleting a model version also sends a ``BaseEvent`` through the
    notification client.
    """

    def __init__(self, store_uri, server_uri, notification_uri=None):
        """Open the model repository store and a notification client.

        The notification client targets ``notification_uri`` when given,
        otherwise ``server_uri``.
        """
        self.model_repo_store = SqlAlchemyStore(store_uri)
        if notification_uri is None:
            self.notification_client = NotificationClient(server_uri)
        else:
            self.notification_client = NotificationClient(notification_uri)

    def _send_stage_event(self, model_version_meta, current_stage):
        """Emit the event mapped to ``current_stage`` for ``model_version_meta``.

        The event key is the model name and the value is the JSON-dumped
        ``__dict__`` of the version meta.
        """
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(current_stage))
        self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                      json.dumps(model_version_meta.__dict__),
                                                      event_type))

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Create a registered model from the request's name, type and description."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(
            registered_model_param.model_name,
            ModelType.Name(registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update a registered model; responds with ``None`` data if the store returns nothing."""
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            ModelType.Name(registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete a registered model and echo back ``request.model_meta``."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """Return every registered model as a ``RegisteredModelMetas`` message."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(
            registered_models=[registered_model.to_meta_proto() for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Return the detail of a registered model, or ``None`` data if it is unknown."""
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version and announce it with the event for its current stage."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        self._send_stage_event(model_version_meta, model_version_param.current_stage)
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version; send a stage event when a new stage was requested."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        # Guard on the store result: previously a None result crashed here on
        # ``.model_name`` before reaching the None-aware return below.
        if model_version_param.current_stage is not None and model_version_meta is not None:
            self._send_stage_event(model_version_meta, model_version_param.current_stage)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version, send MODEL_DELETED, and echo back ``request.model_meta``."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                      json.dumps(model_meta_param.__dict__),
                                                      ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Return the detail of a model version, or ``None`` data if it is unknown."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
def __init__(self, db_uri, server_uri):
    """Open a SQLAlchemy store on ``db_uri`` and a model center client on ``server_uri``."""
    # Previously ``db_uri = db_uri`` (a no-op self-assignment); keep the URI on
    # the instance as the other store-backed services do.
    self.db_uri = db_uri
    self.store = SqlAlchemyStore(db_uri)
    self.model_center_client = ModelCenterClient(server_uri)
def _get_store(db_uri=''):
    """Build and return a ``SqlAlchemyStore`` for ``db_uri`` (empty string by default)."""
    store = SqlAlchemyStore(db_uri)
    return store
class MetricService(MetricServiceServicer):
    """gRPC servicer that stores metric metadata and metric summaries."""

    def __init__(self, db_uri):
        """Keep ``db_uri`` and open a SQLAlchemy-backed store on it."""
        self.db_uri = db_uri
        self.store = SqlAlchemyStore(self.db_uri)

    @staticmethod
    def _optional_field(request, field_name):
        """Return ``request.<field_name>.value`` when the field is set, otherwise None."""
        if not request.HasField(field_name):
            return None
        return getattr(request, field_name).value

    @catch_exception
    def registerMetricMeta(self, request, context):
        """Register the metric meta carried in ``request.metric_meta``."""
        meta = proto_to_metric_meta(request.metric_meta)
        registered = self.store.register_metric_meta(
            meta.metric_name,
            meta.metric_type,
            meta.project_name,
            meta.metric_desc,
            meta.dataset_name,
            meta.model_name,
            meta.job_name,
            meta.start_time,
            meta.end_time,
            meta.uri,
            meta.tags,
            meta.properties)
        return _warp_metric_meta_response(registered)

    @catch_exception
    def updateMetricMeta(self, request, context):
        """Update the stored metric meta from ``request.metric_meta``."""
        meta = proto_to_metric_meta(request.metric_meta)
        updated = self.store.update_metric_meta(
            meta.metric_name,
            meta.metric_desc,
            meta.project_name,
            meta.dataset_name,
            meta.model_name,
            meta.job_name,
            meta.start_time,
            meta.end_time,
            meta.uri,
            meta.tags,
            meta.properties)
        return _warp_metric_meta_response(updated)

    @catch_exception
    def deleteMetricMeta(self, request, context):
        """Delete a metric meta by name; a store failure is reported as INTERNAL_ERROR."""
        target_name = request.metric_name
        try:
            self.store.delete_metric_meta(target_name)
            return Response(return_code=str(ReturnCode.SUCCESS), return_msg='', data='')
        except Exception as err:
            return Response(return_code=str(ReturnCode.INTERNAL_ERROR), return_msg=str(err), data='')

    @catch_exception
    def getMetricMeta(self, request, context):
        """Return the metric meta named ``request.metric_name``."""
        found = self.store.get_metric_meta(request.metric_name)
        return _warp_metric_meta_response(found)

    @catch_exception
    def listDatasetMetricMetas(self, request, context):
        """List metric metas of a dataset, optionally restricted to a project."""
        dataset = request.dataset_name
        project = self._optional_field(request, 'project_name')
        metas = self.store.list_dataset_metric_metas(dataset_name=dataset, project_name=project)
        return _warp_list_metric_metas_response(metas)

    @catch_exception
    def listModelMetricMetas(self, request, context):
        """List metric metas of a model, optionally restricted to a project."""
        model = request.model_name
        project = self._optional_field(request, 'project_name')
        metas = self.store.list_model_metric_metas(model_name=model, project_name=project)
        return _warp_list_metric_metas_response(metas)

    @catch_exception
    def registerMetricSummary(self, request, context):
        """Register the metric summary carried in ``request.metric_summary``."""
        summary = proto_to_metric_summary(request.metric_summary)
        registered = self.store.register_metric_summary(
            summary.metric_name,
            summary.metric_key,
            summary.metric_value,
            summary.metric_timestamp,
            summary.model_version,
            summary.job_execution_id)
        return _warp_metric_summary_response(registered)

    @catch_exception
    def updateMetricSummary(self, request, context):
        """Update the metric summary identified by ``request.metric_summary.uuid``."""
        summary = proto_to_metric_summary(request.metric_summary)
        updated = self.store.update_metric_summary(
            summary.uuid,
            summary.metric_name,
            summary.metric_key,
            summary.metric_value,
            summary.metric_timestamp,
            summary.model_version,
            summary.job_execution_id)
        return _warp_metric_summary_response(updated)

    @catch_exception
    def deleteMetricSummary(self, request, context):
        """Delete a metric summary by uuid; a store failure is reported as INTERNAL_ERROR."""
        summary_uuid = request.uuid
        try:
            self.store.delete_metric_summary(summary_uuid)
            return Response(return_code=str(ReturnCode.SUCCESS), return_msg='', data='')
        except Exception as err:
            return Response(return_code=str(ReturnCode.INTERNAL_ERROR), return_msg=str(err), data='')

    @catch_exception
    def getMetricSummary(self, request, context):
        """Return the metric summary with uuid ``request.uuid``."""
        found = self.store.get_metric_summary(request.uuid)
        return _warp_metric_summary_response(found)

    @catch_exception
    def listMetricSummaries(self, request, context):
        """List metric summaries, filtered by whichever optional fields are set."""
        # Positional order must match store.list_metric_summaries.
        filters = [self._optional_field(request, field)
                   for field in ('metric_name', 'metric_key', 'model_version', 'start_time', 'end_time')]
        summaries = self.store.list_metric_summaries(*filters)
        return _warp_list_metric_summaries_response(summaries)