예제 #1
0
 def __init__(self, db_uri, server_uri):
     """Create the backing store and the model center client.

     :param db_uri: Database URI. When its engine resolves to
                    ``DBType.MONGODB`` a ``MongoStore`` is built from the
                    parsed URI; otherwise a ``SqlAlchemyStore`` is used.
     :param server_uri: URI passed to ``ModelCenterClient``.
     """
     engine_name = extract_db_engine_from_uri(db_uri)
     if DBType.value_of(engine_name) == DBType.MONGODB:
         # The Mongo URI is split into discrete connection settings.
         user, secret, mongo_host, mongo_port, db_name = parse_mongo_uri(db_uri)
         mongo_settings = dict(host=mongo_host,
                               port=int(mongo_port),
                               username=user,
                               password=secret,
                               db=db_name)
         self.store = MongoStore(**mongo_settings)
     else:
         self.store = SqlAlchemyStore(db_uri)
     self.model_center_client = ModelCenterClient(server_uri)
 def __init__(self, db_uri):
     """Remember the database URI and create the matching store.

     :param db_uri: Database URI. A ``MongoStore`` is created when its engine
                    resolves to ``DBType.MONGODB``; a ``SqlAlchemyStore``
                    otherwise.
     """
     self.db_uri = db_uri
     engine_name = extract_db_engine_from_uri(self.db_uri)
     if DBType.value_of(engine_name) == DBType.MONGODB:
         user, secret, mongo_host, mongo_port, db_name = parse_mongo_uri(self.db_uri)
         mongo_settings = dict(host=mongo_host,
                               port=int(mongo_port),
                               username=user,
                               password=secret,
                               db=db_name)
         self.store = MongoStore(**mongo_settings)
     else:
         self.store = SqlAlchemyStore(self.db_uri)
예제 #3
0
 def __init__(self, store_uri, notification_uri=None):
     """Create the model repository store and, optionally, a notification client.

     :param store_uri: Database URI used to pick and build the store
                       (``MongoStore`` for ``DBType.MONGODB``, otherwise
                       ``SqlAlchemyStore``).
     :param notification_uri: When not None, a ``NotificationClient`` is
                              created for it; otherwise
                              ``self.notification_client`` stays None.
     """
     engine_name = extract_db_engine_from_uri(store_uri)
     if DBType.value_of(engine_name) == DBType.MONGODB:
         user, secret, mongo_host, mongo_port, db_name = parse_mongo_uri(store_uri)
         mongo_settings = dict(host=mongo_host,
                               port=int(mongo_port),
                               username=user,
                               password=secret,
                               db=db_name)
         self.model_repo_store = MongoStore(**mongo_settings)
     else:
         self.model_repo_store = SqlAlchemyStore(store_uri)

     self.notification_client = None
     if notification_uri is None:
         return
     self.notification_client = NotificationClient(notification_uri,
                                                   default_namespace=DEFAULT_NAMESPACE)
예제 #4
0
 def _clear_db(self):
     """Reset the backing database according to ``self.db_type``.

     For SQLite all tables in ``base.metadata`` are dropped and recreated; for
     MongoDB ``MongoStoreConnManager().drop_all()`` is called. Other database
     types are left untouched.
     """
     if self.db_type == DBType.SQLITE:
         engine = SqlAlchemyStore(self.store_uri).db_engine
         base.metadata.drop_all(engine)
         base.metadata.create_all(engine)
         return
     if self.db_type == DBType.MONGODB:
         MongoStoreConnManager().drop_all()
예제 #5
0
 def _clear_db(self):
     """Reset the database described by ``self.master_config``.

     For SQLite all tables in ``base.metadata`` are dropped and recreated; for
     MongoDB ``MongoStoreConnManager().drop_all()`` is called. Other database
     types are left untouched.
     """
     if self.master_config.get_db_type() == DBType.SQLITE:
         engine = SqlAlchemyStore(self.master_config.get_db_uri()).db_engine
         base.metadata.drop_all(engine)
         base.metadata.create_all(engine)
         return
     if self.master_config.get_db_type() == DBType.MONGODB:
         MongoStoreConnManager().drop_all()
예제 #6
0
 def __init__(self,
              store_uri=None,
              port=_PORT,
              start_default_notification: bool = True,
              notification_uri=None,
              ha_manager=None,
              server_uri=None,
              ha_storage=None,
              ttl_ms: int = 10000):
     """Initialise the base server and register the high-availability service.

     :param store_uri: Database URI, forwarded to the base class and used to
                       build the default HA storage.
     :param port: Port forwarded to the base class.
     :param start_default_notification: Forwarded to the base class.
     :param notification_uri: Forwarded to the base class.
     :param ha_manager: HA manager; defaults to a new
                        ``SimpleAIFlowServerHaManager``.
     :param server_uri: URI of this server, required.
     :param ha_storage: HA storage; when None a ``MongoStore`` or
                        ``SqlAlchemyStore`` is built from ``store_uri``.
     :param ttl_ms: Passed to ``HighAvailableService``.
     :raises ValueError: If ``server_uri`` is None.
     """
     super(HighAvailableAIFlowServer, self).__init__(
         store_uri, port, start_default_notification, notification_uri)

     if ha_manager is None:
         ha_manager = SimpleAIFlowServerHaManager()
     if server_uri is None:
         raise ValueError("server_uri is required!")
     if ha_storage is None:
         engine_name = extract_db_engine_from_uri(store_uri)
         if DBType.value_of(engine_name) == DBType.MONGODB:
             # Distinct local names keep the ``port`` parameter from being
             # rebound by the tuple unpacking.
             user, secret, mongo_host, mongo_port, db_name = parse_mongo_uri(store_uri)
             ha_storage = MongoStore(host=mongo_host,
                                     port=int(mongo_port),
                                     username=user,
                                     password=secret,
                                     db=db_name)
         else:
             ha_storage = SqlAlchemyStore(store_uri)

     self.ha_service = HighAvailableService(ha_manager, server_uri, ha_storage, ttl_ms)
     add_HighAvailabilityManagerServicer_to_server(self.ha_service, self.server)
 def tearDown(self) -> None:
     """Recreate the SQLite tables, clear the default graph and check no jobs remain."""
     print('tearDown')
     engine = SqlAlchemyStore(_SQLITE_DB_URI).db_engine
     base.metadata.drop_all(engine)
     base.metadata.create_all(engine)
     af.default_graph().clear_graph()
     # With freshly recreated tables the job listing is expected to be None.
     self.assertIsNone(client.list_job(page_size=10, offset=0))
예제 #8
0
 def tearDown(self) -> None:
     """Shut down the client and any started servers, then drop the SQLite tables."""
     self.client.stop_listen_event()
     self.client.disable_high_availability()
     # Servers are stopped in order; slots that were never started are None.
     for slot in ('server1', 'server2', 'server3'):
         running = getattr(self, slot)
         if running is not None:
             running.stop()
     engine = SqlAlchemyStore(_SQLITE_DB_URI).db_engine
     base.metadata.drop_all(engine)
예제 #9
0
 def setUp(self) -> None:
     """Start one HA server on port 50051 and create an HA-enabled client."""
     # NOTE(review): the store is constructed only for its side effects,
     # presumably creating the tables - confirm against SqlAlchemyStore.
     SqlAlchemyStore(_SQLITE_DB_URI)
     self.server1 = HighAvailableAIFlowServer(store_uri=_SQLITE_DB_URI,
                                              port=50051,
                                              server_uri='localhost:50051')
     self.server1.run()
     self.server2 = self.server3 = None

     self.config = project_config = ProjectConfig()
     project_config.set_enable_ha(True)
     # The client is given two addresses; only localhost:50051 is started here.
     self.client = AIFlowClient(server_uri='localhost:50052,localhost:50051',
                                project_config=project_config)
예제 #10
0
    def stop(self, clear_sql_lite_db_file=True) -> None:
        """
        Stop the AI flow master.

        :param clear_sql_lite_db_file: If True and the configured database is SQLite, the tables are dropped
                                       and the sqlite database file is deleted once the server has stopped.
        """
        self.server.stop()
        uses_sqlite = self.master_config.get_db_type() == DBType.SQLITE
        if not (uses_sqlite and clear_sql_lite_db_file):
            return
        engine = SqlAlchemyStore(self.master_config.get_db_uri()).db_engine
        base.metadata.drop_all(engine)
        os.remove(self.master_config.get_sql_lite_db_file())
예제 #11
0
    def stop(self, clear_sql_lite_db_file=False):
        """
        Stop the AIFlow server and release its resources.

        :param clear_sql_lite_db_file: If True and the database is SQLite, the tables are dropped and the
                                       database file is removed.
        """
        self.executor.shutdown()
        self.server.stop(0)
        if self.enabled_ha:
            self.ha_service.stop()

        wipe_sqlite = self.db_type == DBType.SQLITE and clear_sql_lite_db_file
        if wipe_sqlite:
            engine = SqlAlchemyStore(self.store_uri).db_engine
            base.metadata.drop_all(engine)
            # NOTE(review): 10 == len('sqlite:///'); assumes the store URI
            # starts with that prefix - confirm.
            sqlite_file = self.store_uri[10:]
            os.remove(sqlite_file)
        elif self.db_type == DBType.MONGODB:
            MongoStoreConnManager().disconnect_all()

        logging.info('AIFlow server stopped.')
예제 #12
0
 def _add_ha_service(self, ha_manager, ha_server_uri, ha_storage, store_uri,
                     ttl_ms):
     """Create the high-availability service and register it on ``self.server``.

     :param ha_manager: HA manager; defaults to a new
                        ``SimpleAIFlowServerHaManager`` when None.
     :param ha_server_uri: URI of this server, required.
     :param ha_storage: HA storage; when None a ``MongoStore`` or
                        ``SqlAlchemyStore`` is built from ``store_uri``.
     :param store_uri: Database URI used to build the default storage.
     :param ttl_ms: Passed to ``HighAvailableService``.
     :raises ValueError: If ``ha_server_uri`` is None.
     """
     if ha_manager is None:
         ha_manager = SimpleAIFlowServerHaManager()
     if ha_server_uri is None:
         raise ValueError("ha_server_uri is required with ha enabled!")
     if ha_storage is None:
         engine_name = extract_db_engine_from_uri(store_uri)
         if DBType.value_of(engine_name) == DBType.MONGODB:
             user, secret, mongo_host, mongo_port, db_name = parse_mongo_uri(store_uri)
             ha_storage = MongoStore(host=mongo_host,
                                     port=int(mongo_port),
                                     username=user,
                                     password=secret,
                                     db=db_name)
         else:
             ha_storage = SqlAlchemyStore(store_uri)

     self.ha_service = HighAvailableService(ha_manager, ha_server_uri, ha_storage, ttl_ms)
     add_HighAvailabilityManagerServicer_to_server(self.ha_service, self.server)
예제 #13
0
 def setUp(self) -> None:
     """Start a notification master and one HA-enabled AIFlow server, then build the client."""
     # NOTE(review): the store is constructed only for its side effects,
     # presumably creating the tables - confirm against SqlAlchemyStore.
     SqlAlchemyStore(_SQLITE_DB_URI)

     in_memory_service = NotificationService(storage=MemoryEventStorage())
     self.notification = NotificationMaster(service=in_memory_service, port=30031)
     self.notification.run()

     self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                 port=50051,
                                 enabled_ha=True,
                                 start_scheduler_service=False,
                                 ha_server_uri='localhost:50051',
                                 notification_uri='localhost:30031',
                                 start_default_notification=False)
     self.server1.run()
     self.server2 = self.server3 = None

     self.config = project_config = ProjectConfig()
     project_config.set_enable_ha(True)
     project_config.set_notification_service_uri('localhost:30031')
     # The client is given two addresses; only localhost:50051 is started here.
     self.client = AIFlowClient(server_uri='localhost:50052,localhost:50051',
                                project_config=project_config)
예제 #14
0
 def __init__(self,
              store_uri=None,
              port=_PORT,
              start_default_notification: bool = True,
              notification_uri=None,
              ha_manager=None,
              server_uri=None,
              ha_storage=None,
              ttl_ms: int = 10000):
     """Initialise the base server and register the high-availability service.

     :param store_uri: Database URI, forwarded to the base class and used for
                       the default ``SqlAlchemyStore`` HA storage.
     :param port: Port forwarded to the base class.
     :param start_default_notification: Forwarded to the base class.
     :param notification_uri: Forwarded to the base class.
     :param ha_manager: HA manager; defaults to a new
                        ``SimpleAIFlowServerHaManager``.
     :param server_uri: URI of this server, required.
     :param ha_storage: HA storage; defaults to ``SqlAlchemyStore(store_uri)``.
     :param ttl_ms: Passed to ``HighAvailableService``.
     :raises ValueError: If ``server_uri`` is None.
     """
     super(HighAvailableAIFlowServer, self).__init__(
         store_uri, port, start_default_notification, notification_uri)

     if ha_manager is None:
         ha_manager = SimpleAIFlowServerHaManager()
     if server_uri is None:
         raise ValueError("server_uri is required!")
     if ha_storage is None:
         ha_storage = SqlAlchemyStore(store_uri)

     self.ha_service = HighAvailableService(ha_manager, server_uri, ha_storage, ttl_ms)
     add_HighAvailabilityManagerServicer_to_server(self.ha_service, self.server)
예제 #15
0
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for registered models and model versions.

    Every request is delegated to ``self.model_repo_store``. When a
    notification client is configured, model version create/update/delete
    operations additionally publish a ``BaseEvent``.
    """

    def __init__(self, store_uri, notification_uri=None):
        """Create the model repository store and, optionally, a notification client.

        :param store_uri: Database URI; a ``MongoStore`` is built when its
                          engine resolves to ``DBType.MONGODB``, otherwise a
                          ``SqlAlchemyStore``.
        :param notification_uri: When not None, a ``NotificationClient`` is
                                 created for it; otherwise no events are sent.
        """
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(store_uri)
            self.model_repo_store = MongoStore(host=host,
                                               port=int(port),
                                               username=username,
                                               password=password,
                                               db=db)
        else:
            self.model_repo_store = SqlAlchemyStore(store_uri)
        self.notification_client = None
        if notification_uri is not None:
            self.notification_client = NotificationClient(notification_uri, default_namespace=DEFAULT_NAMESPACE)

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Create a registered model from the request and return its meta."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update a registered model; the response wraps None when the store returns None."""
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete the registered model named in the request and echo its meta back."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """Return the metas of all registered models known to the store."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Return the detail of a registered model, or a wrapped None when the store returns None."""
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version and, if a notification client is set, publish an event for it."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          event_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version and publish an event when its stage is set.

        The response wraps None when the store returns None. In that case no
        event is sent either, because there is no model version meta to
        describe in the event.
        """
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        if model_version_param.current_stage is not None:
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            # Guard against a None result: dereferencing it here would raise
            # AttributeError before the None-aware return below is reached.
            if self.notification_client is not None and model_version_meta is not None:
                self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                              json.dumps(model_version_meta.__dict__),
                                                              event_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version, publish a MODEL_DELETED event if possible, and echo the request meta."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                          json.dumps(model_meta_param.__dict__),
                                                          ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Return the detail of a model version, or a wrapped None when the store returns None."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
예제 #16
0
 def _clear_db(self):
     """Drop and recreate all tables when the configured database is SQLite."""
     if self.master_config.get_db_type() != DBType.SQLITE:
         return
     engine = SqlAlchemyStore(self.master_config.get_db_uri()).db_engine
     base.metadata.drop_all(engine)
     base.metadata.create_all(engine)
예제 #17
0
 def tearDownClass(cls) -> None:
     """Stop the client's event listener, drop the SQLite tables and delete the database file."""
     client.stop_listen_event()
     engine = SqlAlchemyStore(_SQLITE_DB_URI).db_engine
     base.metadata.drop_all(engine)
     os.remove(_SQLITE_DB_FILE)
예제 #18
0
class MetricService(MetricServiceServicer):
    """gRPC servicer for metric metas and metric summaries, backed by a ``SqlAlchemyStore``."""

    def __init__(self, db_uri):
        """Remember the database URI and open the store.

        :param db_uri: Database URI passed to ``SqlAlchemyStore``.
        """
        self.db_uri = db_uri
        self.store = SqlAlchemyStore(db_uri)

    @catch_exception
    def registerMetricMeta(self, request, context):
        """Register the metric meta carried by the request and return the stored meta."""
        metric_meta_proto = request.metric_meta

        metric_meta = proto_to_metric_meta(metric_meta_proto)

        res_metric_meta = self.store.register_metric_meta(
            metric_meta.name, metric_meta.dataset_id, metric_meta.model_name,
            metric_meta.model_version, metric_meta.job_id,
            metric_meta.start_time, metric_meta.end_time,
            metric_meta.metric_type, metric_meta.uri, metric_meta.tags,
            metric_meta.metric_description, metric_meta.properties)

        return _warp_metric_meta_response(res_metric_meta)

    def registerMetricSummary(self, request, context):
        """Register the metric summary carried by the request and return the stored summary."""
        metric_summary_proto = request.metric_summary

        metric_summary = proto_to_metric_summary(metric_summary_proto)

        res_metric_summary = self.store.register_metric_summary(
            metric_id=metric_summary.metric_id,
            metric_key=metric_summary.metric_key,
            metric_value=metric_summary.metric_value)
        return _warp_metric_summary_response(res_metric_summary)

    def updateMetricMeta(self, request, context):
        """Update a metric meta identified by its uuid.

        Optional fields that are not set on the request are forwarded to the
        store as None. The properties map is forwarded only when it is
        non-empty.
        """
        metric_meta_proto = request.metric_meta

        def _optional(field_name):
            # Set fields carry their payload in ``.value``; unset fields are
            # forwarded as None.
            if metric_meta_proto.HasField(field_name):
                return getattr(metric_meta_proto, field_name).value
            return None

        if metric_meta_proto.metric_type == MetricTypeProto.DATASET:
            metric_type = MetricType.DATASET
        else:
            metric_type = MetricType.MODEL

        # Fix: the previous condition was inverted, so the properties were
        # forwarded only when the map was empty and dropped whenever the
        # caller actually supplied some.
        properties = None if {} == metric_meta_proto.properties else metric_meta_proto.properties

        res_metric_meta = self.store.update_metric_meta(
            metric_meta_proto.uuid,
            _optional('name'),
            _optional('dataset_id'),
            _optional('model_name'),
            _optional('model_version'),
            _optional('job_id'),
            _optional('start_time'),
            _optional('end_time'),
            metric_type,
            _optional('uri'),
            _optional('tags'),
            _optional('metric_description'),
            properties)

        return _warp_metric_meta_response(res_metric_meta)

    def updateMetricSummary(self, request, context):
        """Update a metric summary; fields not set on the request are forwarded as None."""
        metric_summary_proto = request.metric_summary
        res_metric_summary = self.store.update_metric_summary(
            uuid=metric_summary_proto.uuid,
            metric_id=metric_summary_proto.metric_id.value
            if metric_summary_proto.HasField('metric_id') else None,
            metric_key=metric_summary_proto.metric_key.value
            if metric_summary_proto.HasField('metric_key') else None,
            metric_value=metric_summary_proto.metric_value.value
            if metric_summary_proto.HasField('metric_value') else None)
        return _warp_metric_summary_response(res_metric_summary)

    def getMetricMeta(self, request, context):
        """Look up a metric meta by ``request.metric_name``."""
        metric_name = request.metric_name
        res_metric_meta = self.store.get_metric_meta(metric_name)
        return _warp_metric_meta_response(res_metric_meta)

    def getDatasetMetricMeta(self, request, context):
        """Return the metric metas stored for ``request.dataset_id``."""
        dataset_id = request.dataset_id
        metric_meta = self.store.get_dataset_metric_meta(dataset_id=dataset_id)
        return _warp_list_metric_meta_response(metric_meta)

    def getModelMetricMeta(self, request, context):
        """Return the metric metas stored for the requested model name and version."""
        model_name = request.model_name
        model_version = request.model_version
        metric_meta = self.store.get_model_metric_meta(
            model_name=model_name, model_version=model_version)
        return _warp_list_metric_meta_response(metric_meta)

    def getMetricSummary(self, request, context):
        """Return the metric summaries stored for ``request.metric_id``."""
        metric_id = request.metric_id
        metric_summary = self.store.get_metric_summary(metric_id=metric_id)
        return _warp_list_metric_summary_response(metric_summary)

    def deleteMetricMeta(self, request, context):
        """Delete a metric meta by uuid.

        Returns a SUCCESS response, or an INTERNAL_ERROR response carrying
        the exception text when the store raises.
        """
        uuid = request.uuid
        try:
            self.store.delete_metric_meta(uuid)
            return Response(return_code=str(ReturnCode.SUCCESS),
                            return_msg='',
                            data='')
        except Exception as e:
            return Response(return_code=str(ReturnCode.INTERNAL_ERROR),
                            return_msg=str(e),
                            data='')

    def deleteMetricSummary(self, request, context):
        """Delete a metric summary by uuid.

        Returns a SUCCESS response, or an INTERNAL_ERROR response carrying
        the exception text when the store raises.
        """
        uuid = request.uuid
        try:
            self.store.delete_metric_summary(uuid)
            return Response(return_code=str(ReturnCode.SUCCESS),
                            return_msg='',
                            data='')
        except Exception as e:
            return Response(return_code=str(ReturnCode.INTERNAL_ERROR),
                            return_msg=str(e),
                            data='')
예제 #19
0
class MetadataService(metadata_service_pb2_grpc.MetadataServiceServicer):
    def __init__(self, db_uri, server_uri):
        """Create the metadata store and the model center client.

        :param db_uri: Database URI. A ``MongoStore`` is created when its engine
                       resolves to ``DBType.MONGODB``; a ``SqlAlchemyStore``
                       otherwise.
        :param server_uri: URI passed to ``ModelCenterClient``.
        """
        engine_name = extract_db_engine_from_uri(db_uri)
        if DBType.value_of(engine_name) == DBType.MONGODB:
            user, secret, mongo_host, mongo_port, db_name = parse_mongo_uri(db_uri)
            mongo_settings = dict(host=mongo_host,
                                  port=int(mongo_port),
                                  username=user,
                                  password=secret,
                                  db=db_name)
            self.store = MongoStore(**mongo_settings)
        else:
            self.store = SqlAlchemyStore(db_uri)
        self.model_center_client = ModelCenterClient(server_uri)

    '''dataset api'''

    @catch_exception
    def getDatasetById(self, request, context):
        dataset = self.store.get_dataset_by_id(request.id)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset))

    @catch_exception
    def getDatasetByName(self, request, context):
        dataset = self.store.get_dataset_by_name(request.name)
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset))

    @catch_exception
    def registerDataset(self, request, context):
        dataset = transform_dataset_meta(request.dataset)
        dataset_meta = self.store.register_dataset(
            name=dataset.name,
            data_format=dataset.data_format,
            description=dataset.description,
            uri=dataset.uri,
            properties=dataset.properties,
            name_list=dataset.schema.name_list,
            type_list=dataset.schema.type_list)
        return _wrap_meta_response(
            MetaToProto.dataset_meta_to_proto(dataset_meta))

    @catch_exception
    def registerDatasetWithCatalog(self, request, context):
        dataset = transform_dataset_meta(request.dataset)
        dataset_meta = self.store.register_dataset_with_catalog(
            name=dataset.name,
            catalog_name=dataset.catalog_name,
            catalog_type=dataset.catalog_type,
            catalog_database=dataset.catalog_database,
            catalog_connection_uri=dataset.catalog_connection_uri,
            catalog_table=dataset.catalog_table)
        return _wrap_meta_response(
            MetaToProto.dataset_meta_to_proto(dataset_meta))

    @catch_exception
    def registerDatasets(self, request, context):
        _datasets = ProtoToMeta.proto_to_dataset_meta_list(request.datasets)
        response = self.store.register_datasets(_datasets)
        return _warp_dataset_list_response(response)

    @catch_exception
    def updateDataset(self, request, context):
        """Update the dataset named ``request.name``.

        Empty properties, an empty name list and an empty type list are
        forwarded to the store as None, as are optional fields that are not
        set on the request.
        """
        def _optional(field_name):
            # Set fields carry their payload in ``.value``; unset ones map to None.
            if request.HasField(field_name):
                return getattr(request, field_name).value
            return None

        properties = None if request.properties == {} else request.properties
        name_list = request.name_list or None
        type_list = request.type_list
        if type_list:
            # Convert each proto enum number to the DataType with the same name.
            data_type_list = [DataType(DataTypeProto.Name(proto_type)) for proto_type in type_list]
        else:
            data_type_list = None

        dataset_meta = self.store.update_dataset(
            dataset_name=request.name,
            data_format=_optional('data_format'),
            description=_optional('description'),
            uri=_optional('uri'),
            properties=properties,
            name_list=name_list,
            type_list=data_type_list,
            catalog_name=_optional('catalog_name'),
            catalog_type=_optional('catalog_type'),
            catalog_database=_optional('catalog_database'),
            catalog_connection_uri=_optional('catalog_connection_uri'),
            catalog_table=_optional('catalog_table'))
        return _wrap_meta_response(MetaToProto.dataset_meta_to_proto(dataset_meta))

    @catch_exception
    def listDatasets(self, request, context):
        dataset_meta_list = self.store.list_datasets(request.page_size,
                                                     request.offset)
        return _warp_dataset_list_response(dataset_meta_list)

    @catch_exception
    def deleteDatasetById(self, request, context):
        status = self.store.delete_dataset_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteDatasetByName(self, request, context):
        status = self.store.delete_dataset_by_name(request.name)
        return _wrap_delete_response(status)

    '''model relation api'''

    @catch_exception
    def getModelRelationById(self, request, context):
        model_meta = self.store.get_model_relation_by_id(request.id)
        return _wrap_meta_response(
            MetaToProto.model_relation_meta_to_proto(model_meta))

    @catch_exception
    def getModelRelationByName(self, request, context):
        model_meta = self.store.get_model_relation_by_name(request.name)
        return _wrap_meta_response(
            MetaToProto.model_relation_meta_to_proto(model_meta))

    @catch_exception
    def registerModelRelation(self, request, context):
        model = transform_model_relation_meta(request.model_relation)
        response = self.store.register_model_relation(
            name=model.name, project_id=model.project_id)
        return _wrap_meta_response(
            MetaToProto.model_relation_meta_to_proto(response))

    @catch_exception
    def listModelRelation(self, request, context):
        model_list = self.store.list_model_relation(request.page_size,
                                                    request.offset)
        return _warp_model_relation_list_response(model_list)

    @catch_exception
    def deleteModelRelationById(self, request, context):
        status = self.store.delete_model_relation_by_id(request.id)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteModelRelationByName(self, request, context):
        status = self.store.delete_model_relation_by_name(request.name)
        return _wrap_delete_response(status)

    '''model api'''

    @catch_exception
    def getModelById(self, request, context):
        model_relation = self.store.get_model_relation_by_id(request.id)
        if model_relation is None:
            model_detail = None
        else:
            model_detail = self.model_center_client.get_registered_model_detail(
                model_relation.name)
        return _wrap_meta_response(
            MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def getModelByName(self, request, context):
        model_relation = self.store.get_model_relation_by_name(request.name)
        model_detail = self.model_center_client.get_registered_model_detail(
            request.name)
        return _wrap_meta_response(
            MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def registerModel(self, request, context):
        model = transform_model_meta(request.model)
        model_detail = self.model_center_client.create_registered_model(
            model.name, model.model_desc)
        model_relation = self.store.register_model_relation(
            name=model.name, project_id=model.project_id)
        return _wrap_meta_response(
            MetaToProto.model_meta_to_proto(model_relation, model_detail))

    @catch_exception
    def deleteModelById(self, request, context):
        model_relation = self.store.get_model_relation_by_id(request.id)
        if model_relation is None:
            return _wrap_delete_response(Status.ERROR)
        else:
            model_relation_status = self.store.delete_model_relation_by_id(
                request.id)
            self.model_center_client.delete_registered_model(
                model_relation.name)
            return _wrap_delete_response(model_relation_status)

    @catch_exception
    def deleteModelByName(self, request, context):
        model_relation_status = self.store.delete_model_relation_by_name(
            request.name)
        self.model_center_client.delete_registered_model(request.name)
        return _wrap_delete_response(model_relation_status)

    '''model version relation api'''

    @catch_exception
    def getModelVersionRelationByVersion(self, request, context):
        model_version_meta = self.store.get_model_version_relation_by_version(
            request.name, request.model_id)
        return _wrap_meta_response(
            MetaToProto.model_version_relation_meta_to_proto(
                model_version_meta))

    @catch_exception
    def listModelVersionRelation(self, request, context):
        model_version_meta_list = self.store.list_model_version_relation(
            request.model_id, request.page_size, request.offset)
        return _warp_model_version_relation_list_response(
            model_version_meta_list)

    @catch_exception
    def registerModelVersionRelation(self, request, context):
        """Register the relation described by ``request.model_version_relation``."""
        meta = transform_model_version_relation_meta(
            request.model_version_relation)
        registered = self.store.register_model_version_relation(
            version=meta.version,
            model_id=meta.model_id,
            project_snapshot_id=meta.project_snapshot_id)
        proto = MetaToProto.model_version_relation_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def deleteModelVersionRelationByVersion(self, request, context):
        """Delete one model version relation and wrap the status the store returns."""
        return _wrap_delete_response(
            self.store.delete_model_version_relation_by_version(
                request.name, request.model_id))

    '''model version api'''

    @catch_exception
    def getModelVersionByVersion(self, request, context):
        """Return a model version built from the store relation and model-center detail.

        When the store has no matching version relation, both the relation
        and the detail handed to the converter are ``None``.
        """
        version_relation = self.store.get_model_version_relation_by_version(
            request.name, request.model_id)
        detail = None
        if version_relation is not None:
            # The model center client is queried by model name, so look the
            # owning model up through its id first.
            owner = self.store.get_model_relation_by_id(
                version_relation.model_id)
            detail = self.model_center_client.get_model_version_detail(
                owner.name, request.name)
        proto = MetaToProto.model_version_meta_to_proto(version_relation,
                                                        detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerModelVersion(self, request, context):
        """Create a model version in the model center and record it in the store.

        The version saved in the store relation is the one reported back by
        the model center (``model_version`` of the returned detail).
        """
        meta = transform_model_version_meta(request.model_version)
        owner = self.store.get_model_relation_by_id(meta.model_id)
        detail = self.model_center_client.create_model_version(
            owner.name,
            meta.model_path,
            meta.model_type,
            meta.version_desc,
            request.model_version.current_stage)
        version_relation = self.store.register_model_version_relation(
            version=detail.model_version,
            model_id=meta.model_id,
            project_snapshot_id=meta.project_snapshot_id)
        proto = MetaToProto.model_version_meta_to_proto(version_relation,
                                                        detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def deleteModelVersionByVersion(self, request, context):
        """Delete a model version from the store and from the model center.

        Returns ``_wrap_delete_response(Status.ERROR)`` when the store has no
        matching version relation. The model center entry is only removed
        when the owning model relation can still be found in the store.
        """
        version_relation = self.store.get_model_version_relation_by_version(
            request.name, request.model_id)
        if version_relation is None:
            return _wrap_delete_response(Status.ERROR)
        status = self.store.delete_model_version_relation_by_version(
            request.name, request.model_id)
        owner = self.store.get_model_relation_by_id(version_relation.model_id)
        if owner is not None:
            self.model_center_client.delete_model_version(owner.name,
                                                          request.name)
        return _wrap_delete_response(status)

    @catch_exception
    def getDeployedModelVersion(self, request, context):
        """Return the deployed version of the model named ``request.name``.

        If the store reports no deployed version, the converter receives
        ``None`` for both the relation and the detail.
        """
        detail = self.store.get_deployed_model_version(request.name)
        version_relation = None
        if detail is not None:
            owner = self.store.get_model_relation_by_name(request.name)
            version_relation = self.store.get_model_version_relation_by_version(
                detail.model_version, owner.uuid)
        proto = MetaToProto.model_version_store_to_proto(version_relation,
                                                         detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def getLatestValidatedModelVersion(self, request, context):
        """Return the latest validated version of the model named ``request.name``.

        If the store reports no such version, the converter receives ``None``
        for both the relation and the detail.
        """
        detail = self.store.get_latest_validated_model_version(request.name)
        version_relation = None
        if detail is not None:
            owner = self.store.get_model_relation_by_name(request.name)
            version_relation = self.store.get_model_version_relation_by_version(
                detail.model_version, owner.uuid)
        proto = MetaToProto.model_version_store_to_proto(version_relation,
                                                         detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def getLatestGeneratedModelVersion(self, request, context):
        """Return the latest generated version of the model named ``request.name``.

        If the store reports no such version, the converter receives ``None``
        for both the relation and the detail.
        """
        detail = self.store.get_latest_generated_model_version(request.name)
        version_relation = None
        if detail is not None:
            owner = self.store.get_model_relation_by_name(request.name)
            version_relation = self.store.get_model_version_relation_by_version(
                detail.model_version, owner.uuid)
        proto = MetaToProto.model_version_store_to_proto(version_relation,
                                                         detail)
        return _wrap_meta_response(proto)

    '''project api'''

    @catch_exception
    def getProjectById(self, request, context):
        """Look up the project with id ``request.id`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_project_by_id(request.id)
        proto = MetaToProto.project_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def getProjectByName(self, request, context):
        """Look up the project named ``request.name`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_project_by_name(request.name)
        proto = MetaToProto.project_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def listProject(self, request, context):
        """List projects using ``request.page_size`` and ``request.offset``."""
        return _warp_project_list_response(
            self.store.list_project(request.page_size, request.offset))

    @catch_exception
    def registerProject(self, request, context):
        """Register the project described by ``request.project`` in the store."""
        meta = transform_project_meta(request.project)
        registered = self.store.register_project(
            name=meta.name,
            uri=meta.uri,
            properties=meta.properties)
        proto = MetaToProto.project_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def updateProject(self, request, context):
        """Update the project named ``request.name``.

        An empty ``properties`` map and an unset ``uri`` field are both
        forwarded to the store as ``None``.
        """
        if request.properties == {}:
            new_properties = None
        else:
            new_properties = request.properties
        if request.HasField('uri'):
            new_uri = request.uri.value
        else:
            new_uri = None
        updated = self.store.update_project(project_name=request.name,
                                            uri=new_uri,
                                            properties=new_properties)
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(updated))

    @catch_exception
    def deleteProjectById(self, request, context):
        """Delete the project with id ``request.id`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_project_by_id(request.id))

    @catch_exception
    def deleteProjectByName(self, request, context):
        """Delete the project named ``request.name`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_project_by_name(request.name))

    '''artifact api'''

    @catch_exception
    def getArtifactById(self, request, context):
        """Look up the artifact with id ``request.id`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_artifact_by_id(request.id)
        proto = MetaToProto.artifact_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def getArtifactByName(self, request, context):
        """Look up the artifact named ``request.name`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_artifact_by_name(request.name)
        proto = MetaToProto.artifact_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerArtifact(self, request, context):
        """Register the artifact described by ``request.artifact`` in the store."""
        meta = transform_artifact_meta(request.artifact)
        registered = self.store.register_artifact(name=meta.name,
                                                  artifact_type=meta.artifact_type,
                                                  description=meta.description,
                                                  uri=meta.uri,
                                                  properties=meta.properties)
        proto = MetaToProto.artifact_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def updateArtifact(self, request, context):
        """Update the artifact named ``request.name``.

        Optional fields that are not set on the request, and an empty
        ``properties`` map, are forwarded to the store as ``None``.
        """
        def optional(field):
            # Unwrap ``.value`` only when the field is present on the request.
            if request.HasField(field):
                return getattr(request, field).value
            return None

        if request.properties == {}:
            new_properties = None
        else:
            new_properties = request.properties
        updated = self.store.update_artifact(
            name=request.name,
            artifact_type=optional('artifact_type'),
            properties=new_properties,
            description=optional('description'),
            uri=optional('uri'),
        )
        proto = MetaToProto.artifact_meta_to_proto(updated)
        return _wrap_meta_response(proto)

    @catch_exception
    def listArtifact(self, request, context):
        """List artifacts using ``request.page_size`` and ``request.offset``."""
        return _warp_artifact_list_response(
            self.store.list_artifact(request.page_size, request.offset))

    @catch_exception
    def deleteArtifactById(self, request, context):
        """Delete the artifact with id ``request.id`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_artifact_by_id(request.id))

    @catch_exception
    def deleteArtifactByName(self, request, context):
        """Delete the artifact named ``request.name`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_artifact_by_name(request.name))

    @catch_exception
    def registerWorkflow(self, request, context):
        """Register the workflow described by ``request.workflow`` in the store."""
        meta = transform_workflow_meta(request.workflow)
        registered = self.store.register_workflow(
            name=meta.name,
            project_id=meta.project_id,
            properties=meta.properties)
        proto = MetaToProto.workflow_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def updateWorkflow(self, request, context):
        """Update the workflow ``request.workflow_name`` of project ``request.project_name``.

        An empty ``properties`` map is forwarded to the store as ``None``.
        """
        if request.properties == {}:
            new_properties = None
        else:
            new_properties = request.properties
        updated = self.store.update_workflow(workflow_name=request.workflow_name,
                                             project_name=request.project_name,
                                             properties=new_properties)
        proto = MetaToProto.workflow_meta_to_proto(updated)
        return _wrap_meta_response(proto)

    def getWorkflowById(self, request, context):
        """Look up the workflow with id ``request.id`` and return it via ``_wrap_meta_response``."""
        # NOTE(review): sibling handlers are decorated with @catch_exception;
        # this one is not - confirm that is intentional.
        found = self.store.get_workflow_by_id(workflow_id=request.id)
        proto = MetaToProto.workflow_meta_to_proto(found)
        return _wrap_meta_response(proto)

    def getWorkflowByName(self, request, context):
        """Look up a workflow by ``request.project_name`` and ``request.workflow_name``."""
        # NOTE(review): sibling handlers are decorated with @catch_exception;
        # this one is not - confirm that is intentional.
        found = self.store.get_workflow_by_name(project_name=request.project_name,
                                                workflow_name=request.workflow_name)
        proto = MetaToProto.workflow_meta_to_proto(found)
        return _wrap_meta_response(proto)

    def deleteWorkflowById(self, request, context):
        """Delete the workflow with id ``request.id`` and wrap the store's status."""
        # NOTE(review): sibling handlers are decorated with @catch_exception;
        # this one is not - confirm that is intentional.
        return _wrap_delete_response(
            self.store.delete_workflow_by_id(workflow_id=request.id))

    def deleteWorkflowByName(self, request, context):
        """Delete a workflow by ``request.project_name`` and ``request.workflow_name``."""
        # NOTE(review): sibling handlers are decorated with @catch_exception;
        # this one is not - confirm that is intentional.
        return _wrap_delete_response(
            self.store.delete_workflow_by_name(
                project_name=request.project_name,
                workflow_name=request.workflow_name))

    def listWorkflows(self, request, context):
        """List workflows of ``request.project_name`` using the request's paging fields."""
        # NOTE(review): sibling handlers are decorated with @catch_exception;
        # this one is not - confirm that is intentional.
        workflows = self.store.list_workflows(project_name=request.project_name,
                                              page_size=request.page_size,
                                              offset=request.offset)
        return _wrap_workflow_list_response(workflows)
예제 #20
0
 def __init__(self, db_uri):
     """Keep ``db_uri`` on the instance and build a ``SqlAlchemyStore`` from it."""
     self.db_uri = db_uri
     self.store = SqlAlchemyStore(self.db_uri)
예제 #21
0
class MetadataService(metadata_service_pb2_grpc.MetadataServiceServicer):
    def __init__(self, db_uri, server_uri):
        """Create the service's two backends.

        :param db_uri: URI handed to ``SqlAlchemyStore``.
        :param server_uri: URI handed to ``ModelCenterClient``.
        """
        # The former ``db_uri = db_uri`` self-assignment was a no-op and has
        # been dropped.
        self.store = SqlAlchemyStore(db_uri)
        self.model_center_client = ModelCenterClient(server_uri)

    '''example api'''

    @catch_exception
    def getExampleById(self, request, context):
        """Look up the example with id ``request.id`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_example_by_id(request.id)
        proto = MetaToProto.example_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def getExampleByName(self, request, context):
        """Look up the example named ``request.name`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_example_by_name(request.name)
        proto = MetaToProto.example_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerExample(self, request, context):
        """Register the example described by ``request.example`` in the store.

        The schema's ``name_list`` and ``type_list`` are passed as separate
        keyword arguments.
        """
        meta = transform_example_meta(request.example)
        registered = self.store.register_example(
            name=meta.name,
            support_type=meta.support_type,
            data_type=meta.data_type,
            data_format=meta.data_format,
            description=meta.description,
            batch_uri=meta.batch_uri,
            stream_uri=meta.stream_uri,
            create_time=meta.create_time,
            update_time=meta.update_time,
            properties=meta.properties,
            name_list=meta.schema.name_list,
            type_list=meta.schema.type_list)
        proto = MetaToProto.example_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerExampleWithCatalog(self, request, context):
        """Register the example in ``request.example`` together with its catalog fields."""
        meta = transform_example_meta(request.example)
        registered = self.store.register_example_with_catalog(
            name=meta.name,
            support_type=meta.support_type,
            catalog_name=meta.catalog_name,
            catalog_type=meta.catalog_type,
            catalog_database=meta.catalog_database,
            catalog_connection_uri=meta.catalog_connection_uri,
            catalog_version=meta.catalog_version,
            catalog_table=meta.catalog_table)
        proto = MetaToProto.example_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerExamples(self, request, context):
        """Register every example in ``request.examples`` with one store call."""
        metas = ProtoToMeta.proto_to_example_meta_list(request.examples)
        return _warp_example_list_response(self.store.register_examples(metas))

    @catch_exception
    def updateExample(self, request, context):
        """Update the example named ``request.name``.

        Values the request leaves empty are forwarded to the store as
        ``None``: a ``support_type`` of ``0``, an empty ``properties`` map,
        empty ``name_list`` / ``type_list`` and any optional field that is not
        set. Entries of ``type_list`` are converted to ``DataType`` members by
        their proto enum name.
        """
        def optional(field):
            # Unwrap ``.value`` only when the field is present on the request.
            if request.HasField(field):
                return getattr(request, field).value
            return None

        if request.support_type == 0:
            support_type = None
        else:
            support_type = ExampleSupportType(
                ExampleSupportTypeProto.Name(request.support_type))
        if request.properties == {}:
            new_properties = None
        else:
            new_properties = request.properties
        names = request.name_list or None
        raw_types = request.type_list
        if raw_types:
            converted_types = [DataType(DataTypeProto.Name(raw))
                               for raw in raw_types]
        else:
            converted_types = None
        updated = self.store.update_example(
            example_name=request.name,
            support_type=support_type,
            data_type=optional('data_type'),
            data_format=optional('data_format'),
            description=optional('description'),
            batch_uri=optional('batch_uri'),
            stream_uri=optional('stream_uri'),
            update_time=optional('update_time'),
            properties=new_properties,
            name_list=names,
            type_list=converted_types,
            catalog_name=optional('catalog_name'),
            catalog_type=optional('catalog_type'),
            catalog_database=optional('catalog_database'),
            catalog_version=optional('catalog_version'),
            catalog_connection_uri=optional('catalog_connection_uri'),
            catalog_table=optional('catalog_table'))
        return _wrap_meta_response(MetaToProto.example_meta_to_proto(updated))

    @catch_exception
    def listExample(self, request, context):
        """List examples using ``request.page_size`` and ``request.offset``."""
        return _warp_example_list_response(
            self.store.list_example(request.page_size, request.offset))

    @catch_exception
    def deleteExampleById(self, request, context):
        """Delete the example with id ``request.id`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_example_by_id(request.id))

    @catch_exception
    def deleteExampleByName(self, request, context):
        """Delete the example named ``request.name`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_example_by_name(request.name))

    '''model relation api'''

    @catch_exception
    def getModelRelationById(self, request, context):
        """Look up the model relation with id ``request.id`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_model_relation_by_id(request.id)
        proto = MetaToProto.model_relation_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def getModelRelationByName(self, request, context):
        """Look up the model relation named ``request.name`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_model_relation_by_name(request.name)
        proto = MetaToProto.model_relation_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerModelRelation(self, request, context):
        """Register the model relation described by ``request.model_relation``."""
        meta = transform_model_relation_meta(request.model_relation)
        registered = self.store.register_model_relation(
            name=meta.name,
            project_id=meta.project_id)
        proto = MetaToProto.model_relation_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def listModelRelation(self, request, context):
        """List model relations using ``request.page_size`` and ``request.offset``."""
        return _warp_model_relation_list_response(
            self.store.list_model_relation(request.page_size, request.offset))

    @catch_exception
    def deleteModelRelationById(self, request, context):
        """Delete the model relation with id ``request.id`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_model_relation_by_id(request.id))

    @catch_exception
    def deleteModelRelationByName(self, request, context):
        """Delete the model relation named ``request.name`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_model_relation_by_name(request.name))

    '''model api'''

    @catch_exception
    def getModelById(self, request, context):
        """Return the model with id ``request.id``: store relation plus model-center detail.

        The model center is only queried when the store finds a relation;
        otherwise the converter receives ``None`` for both parts.
        """
        relation = self.store.get_model_relation_by_id(request.id)
        detail = None
        if relation is not None:
            detail = self.model_center_client.get_registered_model_detail(
                relation.name)
        proto = MetaToProto.model_meta_to_proto(relation, detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def getModelByName(self, request, context):
        """Return the model named ``request.name``: store relation plus model-center detail."""
        model_name = request.name
        relation = self.store.get_model_relation_by_name(model_name)
        detail = self.model_center_client.get_registered_model_detail(
            model_name)
        proto = MetaToProto.model_meta_to_proto(relation, detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerModel(self, request, context):
        """Register ``request.model`` in the model center, then record its relation in the store.

        The model type is sent to the model center as its enum name
        (``ModelType.Name``).
        """
        meta = transform_model_meta(request.model)
        type_name = ModelType.Name(meta.model_type)
        detail = self.model_center_client.create_registered_model(
            meta.name, type_name, meta.model_desc)
        relation = self.store.register_model_relation(
            name=meta.name,
            project_id=meta.project_id)
        proto = MetaToProto.model_meta_to_proto(relation, detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def deleteModelById(self, request, context):
        """Delete the model identified by ``request.id``.

        Returns ``_wrap_delete_response(Status.ERROR)`` when the store has no
        model relation with that id. Otherwise the relation is removed from
        the store, the registered model with the relation's name is removed
        through the model center client, and the store's status is wrapped.
        """
        relation = self.store.get_model_relation_by_id(request.id)
        if relation is None:
            return _wrap_delete_response(Status.ERROR)
        # Store row first, then the model center entry (same order as before).
        status = self.store.delete_model_relation_by_id(request.id)
        self.model_center_client.delete_registered_model(relation.name)
        return _wrap_delete_response(status)

    @catch_exception
    def deleteModelByName(self, request, context):
        """Delete the model named ``request.name`` from the store and the model center."""
        model_name = request.name
        status = self.store.delete_model_relation_by_name(model_name)
        self.model_center_client.delete_registered_model(model_name)
        return _wrap_delete_response(status)

    '''model version relation api'''

    @catch_exception
    def getModelVersionRelationByVersion(self, request, context):
        """Fetch one model version relation from the store.

        ``request.name`` and ``request.model_id`` are forwarded, in that
        order, to ``store.get_model_version_relation_by_version``.
        """
        relation = self.store.get_model_version_relation_by_version(
            request.name, request.model_id)
        proto = MetaToProto.model_version_relation_meta_to_proto(relation)
        return _wrap_meta_response(proto)

    @catch_exception
    def listModelVersionRelation(self, request, context):
        """List version relations of ``request.model_id`` with the request's paging fields."""
        relations = self.store.list_model_version_relation(
            request.model_id,
            request.page_size,
            request.offset)
        return _warp_model_version_relation_list_response(relations)

    @catch_exception
    def registerModelVersionRelation(self, request, context):
        """Register the relation described by ``request.model_version_relation``."""
        meta = transform_model_version_relation_meta(
            request.model_version_relation)
        registered = self.store.register_model_version_relation(
            version=meta.version,
            model_id=meta.model_id,
            workflow_execution_id=meta.workflow_execution_id)
        proto = MetaToProto.model_version_relation_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def deleteModelVersionRelationByVersion(self, request, context):
        """Delete one model version relation and wrap the status the store returns."""
        return _wrap_delete_response(
            self.store.delete_model_version_relation_by_version(
                request.name, request.model_id))

    '''model version api'''

    @catch_exception
    def getModelVersionByVersion(self, request, context):
        """Return a model version built from the store relation and model-center detail.

        When the store has no matching version relation, both the relation
        and the detail handed to the converter are ``None``.
        """
        version_relation = self.store.get_model_version_relation_by_version(
            request.name, request.model_id)
        detail = None
        if version_relation is not None:
            # The model center client is queried by model name, so look the
            # owning model up through its id first.
            owner = self.store.get_model_relation_by_id(
                version_relation.model_id)
            detail = self.model_center_client.get_model_version_detail(
                owner.name, request.name)
        proto = MetaToProto.model_version_meta_to_proto(version_relation,
                                                        detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def registerModelVersion(self, request, context):
        """Create a model version in the model center and record it in the store.

        The version saved in the store relation is the one reported back by
        the model center (``model_version`` of the returned detail).
        """
        meta = transform_model_version_meta(request.model_version)
        owner = self.store.get_model_relation_by_id(meta.model_id)
        detail = self.model_center_client.create_model_version(
            owner.name,
            meta.model_path,
            meta.model_metric,
            meta.model_flavor,
            meta.version_desc,
            request.model_version.current_stage)
        version_relation = self.store.register_model_version_relation(
            version=detail.model_version,
            model_id=meta.model_id,
            workflow_execution_id=meta.workflow_execution_id)
        proto = MetaToProto.model_version_meta_to_proto(version_relation,
                                                        detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def deleteModelVersionByVersion(self, request, context):
        """Delete a model version from the store and from the model center.

        Returns ``_wrap_delete_response(Status.ERROR)`` when the store has no
        matching version relation. The model center entry is only removed
        when the owning model relation can still be found in the store.
        """
        version_relation = self.store.get_model_version_relation_by_version(
            request.name, request.model_id)
        if version_relation is None:
            return _wrap_delete_response(Status.ERROR)
        status = self.store.delete_model_version_relation_by_version(
            request.name, request.model_id)
        owner = self.store.get_model_relation_by_id(version_relation.model_id)
        if owner is not None:
            self.model_center_client.delete_model_version(owner.name,
                                                          request.name)
        return _wrap_delete_response(status)

    @catch_exception
    def getDeployedModelVersion(self, request, context):
        """Return the deployed version of the model named ``request.name``.

        If the store reports no deployed version, the converter receives
        ``None`` for both the relation and the detail.
        """
        detail = self.store.get_deployed_model_version(request.name)
        version_relation = None
        if detail is not None:
            owner = self.store.get_model_relation_by_name(request.name)
            version_relation = self.store.get_model_version_relation_by_version(
                detail.model_version, owner.uuid)
        proto = MetaToProto.model_version_store_to_proto(version_relation,
                                                         detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def getLatestValidatedModelVersion(self, request, context):
        """Return the latest validated version of the model named ``request.name``.

        If the store reports no such version, the converter receives ``None``
        for both the relation and the detail.
        """
        detail = self.store.get_latest_validated_model_version(request.name)
        version_relation = None
        if detail is not None:
            owner = self.store.get_model_relation_by_name(request.name)
            version_relation = self.store.get_model_version_relation_by_version(
                detail.model_version, owner.uuid)
        proto = MetaToProto.model_version_store_to_proto(version_relation,
                                                         detail)
        return _wrap_meta_response(proto)

    @catch_exception
    def getLatestGeneratedModelVersion(self, request, context):
        """Return the latest generated version of the model named ``request.name``.

        If the store reports no such version, the converter receives ``None``
        for both the relation and the detail.
        """
        detail = self.store.get_latest_generated_model_version(request.name)
        version_relation = None
        if detail is not None:
            owner = self.store.get_model_relation_by_name(request.name)
            version_relation = self.store.get_model_version_relation_by_version(
                detail.model_version, owner.uuid)
        proto = MetaToProto.model_version_store_to_proto(version_relation,
                                                         detail)
        return _wrap_meta_response(proto)

    '''job api'''

    @catch_exception
    def getJobById(self, request, context):
        """Look up the job with id ``request.id`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_job_by_id(request.id)
        proto = MetaToProto.job_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def getJobByName(self, request, context):
        """Look up the job named ``request.name`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_job_by_name(request.name)
        proto = MetaToProto.job_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def updateJob(self, request, context):
        """Update the job named ``request.name``.

        An empty ``properties`` map, a ``job_state`` of ``0`` and any optional
        field that is not set on the request are forwarded to the store as
        ``None``. A non-zero ``job_state`` is converted to ``State`` by its
        proto enum name.
        """
        def optional(field):
            # Unwrap ``.value`` only when the field is present on the request.
            if request.HasField(field):
                return getattr(request, field).value
            return None

        if request.properties == {}:
            new_properties = None
        else:
            new_properties = request.properties
        if request.job_state == 0:
            new_state = None
        else:
            new_state = State(StateProto.Name(request.job_state))
        updated = self.store.update_job(
            job_name=request.name,
            job_state=new_state,
            properties=new_properties,
            job_id=optional('job_id'),
            workflow_execution_id=optional('workflow_execution_id'),
            end_time=optional('end_time'),
            log_uri=optional('log_uri'),
            signature=optional('signature'))
        return _wrap_meta_response(MetaToProto.job_meta_to_proto(updated))

    @catch_exception
    def listJob(self, request, context):
        """List jobs using ``request.page_size`` and ``request.offset``."""
        return _warp_job_list_response(
            self.store.list_job(request.page_size, request.offset))

    @catch_exception
    def registerJob(self, request, context):
        """Register the job described by ``request.job`` in the store."""
        meta = transform_job_meta(request.job)
        registered = self.store.register_job(
            name=meta.name,
            workflow_execution_id=meta.workflow_execution_id,
            job_state=meta.job_state,
            properties=meta.properties,
            job_id=meta.job_id,
            start_time=meta.start_time,
            end_time=meta.end_time,
            log_uri=meta.log_uri,
            signature=meta.signature)
        proto = MetaToProto.job_meta_to_proto(registered)
        return _wrap_meta_response(proto)

    @catch_exception
    def updateJobState(self, request, context):
        """Set the state of the job named ``request.name``.

        ``request.state`` is converted with ``ProtoToMeta.proto_to_state``;
        the value returned by the store goes through ``_wrap_update_response``.
        """
        new_state = ProtoToMeta.proto_to_state(request.state)
        updated_uuid = self.store.update_job_state(new_state, request.name)
        return _wrap_update_response(updated_uuid)

    @catch_exception
    def updateJobEndTime(self, request, context):
        """Set the end time of the job named ``request.name`` to ``request.end_time``."""
        return _wrap_update_response(
            self.store.update_job_end_time(request.end_time, request.name))

    @catch_exception
    def deleteJobById(self, request, context):
        """Delete the job with id ``request.id`` and wrap the store's status."""
        return _wrap_delete_response(self.store.delete_job_by_id(request.id))

    @catch_exception
    def deleteJobByName(self, request, context):
        """Delete the job named ``request.name`` and wrap the store's status."""
        return _wrap_delete_response(
            self.store.delete_job_by_name(request.name))

    '''workflow execution api'''

    @catch_exception
    def getWorkFlowExecutionById(self, request, context):
        """Look up the workflow execution with id ``request.id`` and return it via ``_wrap_meta_response``."""
        found = self.store.get_workflow_execution_by_id(request.id)
        proto = MetaToProto.workflow_execution_meta_to_proto(found)
        return _wrap_meta_response(proto)

    @catch_exception
    def getWorkFlowExecutionByName(self, request, context):
        """Look up a workflow execution by ``request.name``.

        The store's result is converted with
        ``MetaToProto.workflow_execution_meta_to_proto`` and wrapped.
        """
        found = self.store.get_workflow_execution_by_name(request.name)
        execution_proto = MetaToProto.workflow_execution_meta_to_proto(found)
        return _wrap_meta_response(execution_proto)

    @catch_exception
    def updateWorkflowExecution(self, request, context):
        """Update the workflow execution named ``request.name`` and return it.

        Values forwarded to ``store.update_workflow_execution``:

        * ``properties`` - ``None`` when the request's map compares equal to
          ``{}``, otherwise the map itself.
        * ``execution_state`` - ``None`` when the request's enum value is 0,
          otherwise ``State(StateProto.Name(value))``.
        * ``project_id``, ``end_time``, ``log_uri``, ``workflow_json`` and
          ``signature`` - the field's ``.value`` when ``HasField`` reports it
          as set, otherwise ``None``.

        The store's result is converted to proto and wrapped with
        ``_wrap_meta_response``. ``context`` is not used by this method.
        """
        def optional(field_name):
            # Read the same field that is tested with HasField, so the guard
            # and the access cannot drift apart (the previous code read the
            # non-existent ``log_uri_value`` and ``workjson`` attributes).
            return getattr(request, field_name).value if request.HasField(field_name) else None

        properties = None if request.properties == {} else request.properties
        execution_state = None if request.execution_state == 0 else State(StateProto.Name(request.execution_state))
        workflow_execution = self.store.update_workflow_execution(
            execution_name=request.name,
            execution_state=execution_state,
            project_id=optional('project_id'),
            properties=properties,
            end_time=optional('end_time'),
            log_uri=optional('log_uri'),
            workflow_json=optional('workflow_json'),
            signature=optional('signature'),
        )
        return _wrap_meta_response(MetaToProto.workflow_execution_meta_to_proto(workflow_execution))

    @catch_exception
    def listWorkFlowExecution(self, request, context):
        """Fetch workflow executions using the request's page size and offset.

        The resulting list is packed with
        ``_warp_workflow_execution_list_response``.
        """
        executions = self.store.list_workflow_execution(request.page_size, request.offset)
        return _warp_workflow_execution_list_response(executions)

    @catch_exception
    def registerWorkFlowExecution(self, request, context):
        """Register the workflow execution carried in ``request.workflow_execution``.

        The proto is converted with ``transform_workflow_execution_meta``,
        its attributes are forwarded to ``store.register_workflow_execution``
        by keyword, and the stored record is returned as a wrapped proto.
        """
        execution_meta = transform_workflow_execution_meta(request.workflow_execution)
        registered = self.store.register_workflow_execution(
            name=execution_meta.name,
            project_id=execution_meta.project_id,
            execution_state=execution_meta.execution_state,
            properties=execution_meta.properties,
            start_time=execution_meta.start_time,
            end_time=execution_meta.end_time,
            log_uri=execution_meta.log_uri,
            workflow_json=execution_meta.workflow_json,
            signature=execution_meta.signature,
        )
        registered_proto = MetaToProto.workflow_execution_meta_to_proto(registered)
        return _wrap_meta_response(registered_proto)

    @catch_exception
    def updateWorkflowExecutionEndTime(self, request, context):
        """Set the end time of the workflow execution named ``request.name``.

        The store's return value is wrapped with ``_wrap_update_response``.
        """
        updated_id = self.store.update_workflow_execution_end_time(request.end_time, request.name)
        return _wrap_update_response(updated_id)

    @catch_exception
    def updateWorkflowExecutionState(self, request, context):
        """Set the state of the workflow execution named ``request.name``.

        ``request.state`` is converted with ``ProtoToMeta.proto_to_state``
        first; the store's return value is wrapped with
        ``_wrap_update_response``.
        """
        new_state = ProtoToMeta.proto_to_state(request.state)
        updated_id = self.store.update_workflow_execution_state(new_state, request.name)
        return _wrap_update_response(updated_id)

    @catch_exception
    def deleteWorkflowExecutionById(self, request, context):
        """Delete the workflow execution identified by ``request.id``."""
        delete_status = self.store.delete_workflow_execution_by_id(request.id)
        return _wrap_delete_response(delete_status)

    @catch_exception
    def deleteWorkflowExecutionByName(self, request, context):
        """Delete the workflow execution named ``request.name``."""
        delete_status = self.store.delete_workflow_execution_by_name(request.name)
        return _wrap_delete_response(delete_status)

    '''project api'''

    @catch_exception
    def getProjectById(self, request, context):
        """Look up a project by ``request.id`` and return it as a wrapped proto."""
        found = self.store.get_project_by_id(request.id)
        project_proto = MetaToProto.project_meta_to_proto(found)
        return _wrap_meta_response(project_proto)

    @catch_exception
    def getProjectByName(self, request, context):
        """Look up a project by ``request.name`` and return it as a wrapped proto."""
        found = self.store.get_project_by_name(request.name)
        project_proto = MetaToProto.project_meta_to_proto(found)
        return _wrap_meta_response(project_proto)

    @catch_exception
    def listProject(self, request, context):
        """Fetch projects using the request's page size and offset.

        The resulting list is packed with ``_warp_project_list_response``.
        """
        projects = self.store.list_project(request.page_size, request.offset)
        return _warp_project_list_response(projects)

    @catch_exception
    def registerProject(self, request, context):
        """Register the project carried in ``request.project``.

        The proto is converted with ``transform_project_meta``; its attributes
        are forwarded to ``store.register_project`` by keyword and the stored
        record is returned as a wrapped proto.
        """
        project_meta = transform_project_meta(request.project)
        registered = self.store.register_project(
            name=project_meta.name,
            uri=project_meta.uri,
            properties=project_meta.properties,
            user=project_meta.user,
            password=project_meta.password,
            project_type=project_meta.project_type,
        )
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(registered))

    @catch_exception
    def updateProject(self, request, context):
        """Update the project named ``request.name`` and return it as a wrapped proto.

        ``properties`` is forwarded as ``None`` when the request's map compares
        equal to ``{}``. ``uri``, ``user``, ``password`` and ``project_type``
        are forwarded as the field's ``.value`` when ``HasField`` reports the
        field as set, otherwise as ``None``.
        """
        def optional(field_name):
            # HasField is checked before the field is read, as in the inline form.
            return getattr(request, field_name).value if request.HasField(field_name) else None

        props = None if request.properties == {} else request.properties
        updated = self.store.update_project(
            project_name=request.name,
            uri=optional('uri'),
            properties=props,
            user=optional('user'),
            password=optional('password'),
            project_type=optional('project_type'),
        )
        return _wrap_meta_response(MetaToProto.project_meta_to_proto(updated))

    @catch_exception
    def deleteProjectById(self, request, context):
        """Delete the project identified by ``request.id`` and wrap the store's status."""
        delete_status = self.store.delete_project_by_id(request.id)
        return _wrap_delete_response(delete_status)

    @catch_exception
    def deleteProjectByName(self, request, context):
        """Delete the project named ``request.name`` and wrap the store's status."""
        delete_status = self.store.delete_project_by_name(request.name)
        return _wrap_delete_response(delete_status)

    '''artifact api'''

    @catch_exception
    def getArtifactById(self, request, context):
        """Look up an artifact by ``request.id`` and return it as a wrapped proto."""
        found = self.store.get_artifact_by_id(request.id)
        artifact_proto = MetaToProto.artifact_meta_to_proto(found)
        return _wrap_meta_response(artifact_proto)

    @catch_exception
    def getArtifactByName(self, request, context):
        """Look up an artifact by ``request.name`` and return it as a wrapped proto."""
        found = self.store.get_artifact_by_name(request.name)
        artifact_proto = MetaToProto.artifact_meta_to_proto(found)
        return _wrap_meta_response(artifact_proto)

    @catch_exception
    def registerArtifact(self, request, context):
        """Register the artifact carried in ``request.artifact``.

        The proto is converted with ``transform_artifact_meta``; its attributes
        are forwarded to ``store.register_artifact`` by keyword and the stored
        record is returned as a wrapped proto.
        """
        artifact_meta = transform_artifact_meta(request.artifact)
        registered = self.store.register_artifact(
            name=artifact_meta.name,
            data_format=artifact_meta.data_format,
            description=artifact_meta.description,
            batch_uri=artifact_meta.batch_uri,
            stream_uri=artifact_meta.stream_uri,
            create_time=artifact_meta.create_time,
            update_time=artifact_meta.update_time,
            properties=artifact_meta.properties,
        )
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(registered))

    @catch_exception
    def updateArtifact(self, request, context):
        """Update the artifact named ``request.name`` and return it as a wrapped proto.

        ``properties`` is forwarded as ``None`` when the request's map compares
        equal to ``{}``. ``data_format``, ``description``, ``batch_uri``,
        ``stream_uri`` and ``update_time`` are forwarded as the field's
        ``.value`` when ``HasField`` reports the field as set, otherwise as
        ``None``.
        """
        def optional(field_name):
            # HasField is checked before the field is read, as in the inline form.
            return getattr(request, field_name).value if request.HasField(field_name) else None

        props = None if request.properties == {} else request.properties
        updated = self.store.update_artifact(
            artifact_name=request.name,
            data_format=optional('data_format'),
            properties=props,
            description=optional('description'),
            batch_uri=optional('batch_uri'),
            stream_uri=optional('stream_uri'),
            update_time=optional('update_time'),
        )
        return _wrap_meta_response(MetaToProto.artifact_meta_to_proto(updated))

    @catch_exception
    def listArtifact(self, request, context):
        """Fetch artifacts using the request's page size and offset.

        The resulting list is packed with ``_warp_artifact_list_response``.
        """
        artifacts = self.store.list_artifact(request.page_size, request.offset)
        return _warp_artifact_list_response(artifacts)

    @catch_exception
    def deleteArtifactById(self, request, context):
        """Delete the artifact identified by ``request.id`` and wrap the store's status."""
        delete_status = self.store.delete_artifact_by_id(request.id)
        return _wrap_delete_response(delete_status)

    @catch_exception
    def deleteArtifactByName(self, request, context):
        """Delete the artifact named ``request.name`` and wrap the store's status."""
        delete_status = self.store.delete_artifact_by_name(request.name)
        return _wrap_delete_response(delete_status)
예제 #22
0
 def __init__(self, store_uri, server_uri, notification_uri=None):
     """Create the model store and the notification client.

     The store is a ``SqlAlchemyStore`` built from ``store_uri``. The
     notification client connects to ``notification_uri`` when one is given
     and to ``server_uri`` otherwise.
     """
     self.model_repo_store = SqlAlchemyStore(store_uri)
     target_uri = server_uri if notification_uri is None else notification_uri
     self.notification_client = NotificationClient(target_uri)
예제 #23
0
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """Servicer for registered-model and model-version operations.

    Every method delegates to ``self.model_repo_store`` (a
    ``SqlAlchemyStore``). Creating, updating (with a stage) and deleting a
    model version additionally sends a ``BaseEvent`` through
    ``self.notification_client``.
    """

    def __init__(self, store_uri, server_uri, notification_uri=None):
        """Create the model store and the notification client.

        The notification client connects to ``notification_uri`` when one is
        given and to ``server_uri`` otherwise.
        """
        self.model_repo_store = SqlAlchemyStore(store_uri)
        if notification_uri is None:
            self.notification_client = NotificationClient(server_uri)
        else:
            self.notification_client = NotificationClient(notification_uri)

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Create a registered model from the request and return its meta proto."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              ModelType.Name(
                                                                                  registered_model_param.model_type),
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update a registered model and return its meta proto.

        The model to update is identified by the name parsed with
        ``RegisteredModel.from_proto``; the new name, type and description
        come from ``RegisteredModelParam.from_proto``. ``None`` is wrapped
        when the store returns ``None``.
        """
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            ModelType.Name(
                registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete the registered model named in the request and echo ``request.model_meta``."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """Return every registered model from the store as ``RegisteredModelMetas``."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Return the detail proto of the registered model named in the request.

        ``None`` is wrapped when the store returns ``None``.
        """
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version, send the matching event, and return its meta proto.

        The event type is looked up in ``MODEL_VERSION_TO_EVENT_TYPE`` from
        the requested stage; the event value is the JSON dump of the new
        version's ``__dict__``.
        """
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        model_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
        self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                      json.dumps(model_version_meta.__dict__),
                                                      model_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version and return its meta proto.

        An event is sent only when the request carries a stage and the store
        returned an updated version. ``None`` is wrapped when the store
        returns ``None``.
        """
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        # The return below tolerates a None result from the store, so the
        # event must too: without this guard a None result raised
        # AttributeError on ``model_version_meta.model_name``.
        if model_version_meta is not None and model_version_param.current_stage is not None:
            model_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          model_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version, send a ``MODEL_DELETED`` event, and echo ``request.model_meta``."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                      json.dumps(model_meta_param.__dict__),
                                                      ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Return the meta proto of the model version named in the request.

        ``None`` is wrapped when the store returns ``None``.
        """
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
예제 #24
0
 def __init__(self, db_uri, server_uri):
     """Create the metadata store and the model-center client.

     The store is a ``SqlAlchemyStore`` built from ``db_uri``; the client is
     a ``ModelCenterClient`` built from ``server_uri``.
     """
     # The former ``db_uri = db_uri`` self-assignment was a no-op and is dropped.
     self.store = SqlAlchemyStore(db_uri)
     self.model_center_client = ModelCenterClient(server_uri)
def _get_store(db_uri=''):
    """Build and return a ``SqlAlchemyStore`` for ``db_uri`` (empty string by default)."""
    store = SqlAlchemyStore(db_uri)
    return store
예제 #26
0
class MetricService(MetricServiceServicer):
    def __init__(self, db_uri):
        """Remember ``db_uri`` and open a ``SqlAlchemyStore`` on it."""
        self.db_uri = db_uri
        self.store = SqlAlchemyStore(db_uri)

    @catch_exception
    def registerMetricMeta(self, request, context):
        """Register the metric meta carried in ``request.metric_meta``.

        The proto is converted with ``proto_to_metric_meta``; its attributes
        are passed positionally to ``store.register_metric_meta`` and the
        store's result is packed with ``_warp_metric_meta_response``.
        """
        meta = proto_to_metric_meta(request.metric_meta)
        stored_meta = self.store.register_metric_meta(
            meta.metric_name,
            meta.metric_type,
            meta.project_name,
            meta.metric_desc,
            meta.dataset_name,
            meta.model_name,
            meta.job_name,
            meta.start_time,
            meta.end_time,
            meta.uri,
            meta.tags,
            meta.properties,
        )
        return _warp_metric_meta_response(stored_meta)

    @catch_exception
    def updateMetricMeta(self, request, context):
        """Update the metric meta described by ``request.metric_meta``.

        The proto is converted with ``proto_to_metric_meta``; its attributes
        are passed positionally to ``store.update_metric_meta`` and the
        store's result is packed with ``_warp_metric_meta_response``.
        """
        meta = proto_to_metric_meta(request.metric_meta)
        stored_meta = self.store.update_metric_meta(
            meta.metric_name,
            meta.metric_desc,
            meta.project_name,
            meta.dataset_name,
            meta.model_name,
            meta.job_name,
            meta.start_time,
            meta.end_time,
            meta.uri,
            meta.tags,
            meta.properties,
        )
        return _warp_metric_meta_response(stored_meta)

    @catch_exception
    def deleteMetricMeta(self, request, context):
        """Delete the metric meta named ``request.metric_name``.

        Returns a ``Response`` with the ``SUCCESS`` code and empty message on
        success. Any exception raised inside the ``try`` block is turned into
        a ``Response`` with the ``INTERNAL_ERROR`` code and the exception text
        as its message.
        """
        name = request.metric_name
        try:
            self.store.delete_metric_meta(name)
            return Response(
                return_code=str(ReturnCode.SUCCESS),
                return_msg='',
                data='',
            )
        except Exception as err:
            return Response(
                return_code=str(ReturnCode.INTERNAL_ERROR),
                return_msg=str(err),
                data='',
            )

    @catch_exception
    def getMetricMeta(self, request, context):
        """Look up the metric meta named ``request.metric_name``.

        The store's result is packed with ``_warp_metric_meta_response``.
        """
        found_meta = self.store.get_metric_meta(request.metric_name)
        return _warp_metric_meta_response(found_meta)

    @catch_exception
    def listDatasetMetricMetas(self, request, context):
        """List metric metas for the dataset named ``request.dataset_name``.

        ``project_name`` is forwarded as ``request.project_name.value`` when
        ``HasField`` reports the field as set, otherwise as ``None``.
        """
        dataset = request.dataset_name
        project = None
        if request.HasField('project_name'):
            project = request.project_name.value
        found_metas = self.store.list_dataset_metric_metas(dataset_name=dataset, project_name=project)
        return _warp_list_metric_metas_response(found_metas)

    @catch_exception
    def listModelMetricMetas(self, request, context):
        """List metric metas for the model named ``request.model_name``.

        ``project_name`` is forwarded as ``request.project_name.value`` when
        ``HasField`` reports the field as set, otherwise as ``None``.
        """
        model = request.model_name
        project = None
        if request.HasField('project_name'):
            project = request.project_name.value
        found_metas = self.store.list_model_metric_metas(model_name=model, project_name=project)
        return _warp_list_metric_metas_response(found_metas)

    @catch_exception
    def registerMetricSummary(self, request, context):
        """Register the metric summary carried in ``request.metric_summary``.

        The proto is converted with ``proto_to_metric_summary``; its
        attributes are passed positionally to ``store.register_metric_summary``
        and the result is packed with ``_warp_metric_summary_response``.
        """
        summary = proto_to_metric_summary(request.metric_summary)
        stored_summary = self.store.register_metric_summary(
            summary.metric_name,
            summary.metric_key,
            summary.metric_value,
            summary.metric_timestamp,
            summary.model_version,
            summary.job_execution_id,
        )
        return _warp_metric_summary_response(stored_summary)

    @catch_exception
    def updateMetricSummary(self, request, context):
        """Update the metric summary described by ``request.metric_summary``.

        The proto is converted with ``proto_to_metric_summary``; its ``uuid``
        and remaining attributes are passed positionally to
        ``store.update_metric_summary`` and the result is packed with
        ``_warp_metric_summary_response``.
        """
        summary = proto_to_metric_summary(request.metric_summary)
        stored_summary = self.store.update_metric_summary(
            summary.uuid,
            summary.metric_name,
            summary.metric_key,
            summary.metric_value,
            summary.metric_timestamp,
            summary.model_version,
            summary.job_execution_id,
        )
        return _warp_metric_summary_response(stored_summary)

    @catch_exception
    def deleteMetricSummary(self, request, context):
        """Delete the metric summary identified by ``request.uuid``.

        Returns a ``Response`` with the ``SUCCESS`` code and empty message on
        success. Any exception raised inside the ``try`` block is turned into
        a ``Response`` with the ``INTERNAL_ERROR`` code and the exception text
        as its message.
        """
        summary_uuid = request.uuid
        try:
            self.store.delete_metric_summary(summary_uuid)
            return Response(
                return_code=str(ReturnCode.SUCCESS),
                return_msg='',
                data='',
            )
        except Exception as err:
            return Response(
                return_code=str(ReturnCode.INTERNAL_ERROR),
                return_msg=str(err),
                data='',
            )

    @catch_exception
    def getMetricSummary(self, request, context):
        """Look up the metric summary identified by ``request.uuid``.

        The store's result is packed with ``_warp_metric_summary_response``.
        """
        found_summary = self.store.get_metric_summary(request.uuid)
        return _warp_metric_summary_response(found_summary)

    @catch_exception
    def listMetricSummaries(self, request, context):
        """List metric summaries matching the optional filters in the request.

        ``metric_name``, ``metric_key``, ``model_version``, ``start_time`` and
        ``end_time`` are each forwarded as the field's ``.value`` when
        ``HasField`` reports the field as set, otherwise as ``None``. The
        store's result is packed with ``_warp_list_metric_summaries_response``.
        """
        def optional(field_name):
            # HasField is checked before the field is read, as in the inline form.
            return getattr(request, field_name).value if request.HasField(field_name) else None

        name_filter = optional('metric_name')
        key_filter = optional('metric_key')
        version_filter = optional('model_version')
        start_filter = optional('start_time')
        end_filter = optional('end_time')
        found_summaries = self.store.list_metric_summaries(
            name_filter, key_filter, version_filter, start_filter, end_filter)
        return _warp_list_metric_summaries_response(found_summaries)