예제 #1
0
 def __init__(self,
              store_uri=None,
              port=_PORT,
              start_default_notification: bool = True,
              notification_uri=None):
     """Create the gRPC server and register every endpoint servicer on it.

     Args:
         store_uri: Passed to the notification, model center, metadata and
             metric services as their store / db URI.
         port: Port opened with ``add_insecure_port`` on ``[::]``; also used
             to build the ``localhost:<port>`` URI handed to the model
             center, metadata and deploy services.
         start_default_notification: When true, a NotificationService built
             from ``store_uri`` is registered on this same server. The flag
             is kept on ``self.start_default_notification``.
         notification_uri: Forwarded unchanged to ModelCenterService.
     """
     worker_pool = futures.ThreadPoolExecutor(max_workers=10)
     self.executor = Executor(worker_pool)
     self.server = grpc.server(self.executor)
     self.start_default_notification = start_default_notification
     local_uri = 'localhost:{}'.format(port)

     if start_default_notification:
         logging.info("start default notification service.")
         notification_servicer = NotificationService(store_uri)
         notification_service_pb2_grpc.add_NotificationServiceServicer_to_server(
             notification_servicer, self.server)

     model_center_servicer = ModelCenterService(
         store_uri=store_uri,
         server_uri=local_uri,
         notification_uri=notification_uri)
     model_center_service_pb2_grpc.add_ModelCenterServiceServicer_to_server(
         model_center_servicer, self.server)

     metadata_servicer = MetadataService(db_uri=store_uri, server_uri=local_uri)
     metadata_service_pb2_grpc.add_MetadataServiceServicer_to_server(
         metadata_servicer, self.server)

     metric_servicer = MetricService(db_uri=store_uri)
     metric_service_pb2_grpc.add_MetricServiceServicer_to_server(
         metric_servicer, self.server)

     # Unlike the other servicers, the deploy servicer is kept on the
     # instance as self.deploy_service.
     self.deploy_service = DeployService(server_uri=local_uri)
     deploy_service_pb2_grpc.add_DeployServiceServicer_to_server(
         self.deploy_service, self.server)

     listen_address = '[::]:' + str(port)
     self.server.add_insecure_port(listen_address)
예제 #2
0
 def __init__(self, service, port=_PORT):
     """Host a single notification servicer on a fresh gRPC server.

     Args:
         service: Servicer registered through
             ``add_NotificationServiceServicer_to_server``; also kept on
             ``self.service``.
         port: Port opened with ``add_insecure_port`` on all interfaces
             (``[::]``).
     """
     worker_pool = futures.ThreadPoolExecutor(max_workers=10)
     self.executor = Executor(worker_pool)
     self.server = grpc.server(self.executor)
     self.service = service
     notification_service_pb2_grpc.add_NotificationServiceServicer_to_server(
         self.service, self.server)
     bind_address = '[::]:' + str(port)
     self.server.add_insecure_port(bind_address)
예제 #3
0
    def __init__(self,
                 store_uri=None,
                 port=_PORT,
                 start_default_notification: bool = True,
                 notification_uri=None,
                 start_meta_service: bool = True,
                 start_model_center_service: bool = True,
                 start_metric_service: bool = True,
                 start_scheduler_service: bool = True,
                 scheduler_service_config: Dict = None,
                 enabled_ha: bool = False,
                 ha_manager=None,
                 ha_server_uri=None,
                 ha_storage=None,
                 ttl_ms: int = 10000):
        """Create the gRPC server and register the enabled endpoint servicers.

        Each ``start_*`` flag controls whether the corresponding servicer is
        registered. The server opens an insecure port on ``[::]:<port>``.
        Services that need their own address are given ``localhost:<port>``.

        Args:
            store_uri: Kept on ``self.store_uri``. Its DB engine (via
                ``extract_db_engine_from_uri``) sets ``self.db_type``. It is
                also passed to the enabled services and to the HA setup.
                NOTE(review): defaults to None but is passed straight to
                ``extract_db_engine_from_uri``; confirm that helper accepts
                None.
            port: Port to listen on.
            start_default_notification: Register a NotificationService built
                with ``NotificationService.from_storage_uri(store_uri)``.
            notification_uri: Notification endpoint for ModelCenterService.
                When None and the default notification service is started
                here, ``localhost:<port>`` is used instead.
            start_meta_service: Register MetadataService.
            start_model_center_service: Register ModelCenterService.
            start_metric_service: Register MetricService.
            start_scheduler_service: Call ``self._add_scheduler_service``.
            scheduler_service_config: Passed to ``_add_scheduler_service``.
            enabled_ha: Call ``self._add_ha_service``. Also stored on
                ``self.enabled_ha``.
            ha_manager, ha_server_uri, ha_storage, ttl_ms: Forwarded, together
                with ``store_uri``, to ``_add_ha_service``. ``ttl_ms`` is
                presumably a time-to-live in milliseconds (TODO confirm).
        """
        self.store_uri = store_uri
        db_engine = extract_db_engine_from_uri(store_uri)
        self.db_type = DBType.value_of(db_engine)
        worker_pool = futures.ThreadPoolExecutor(max_workers=10)
        self.executor = Executor(worker_pool)
        self.server = grpc.server(self.executor)
        self.start_default_notification = start_default_notification
        self.enabled_ha = enabled_ha
        local_uri = 'localhost:{}'.format(port)

        if start_default_notification:
            logging.info("start default notification service.")
            notification_servicer = NotificationService.from_storage_uri(store_uri)
            notification_service_pb2_grpc.add_NotificationServiceServicer_to_server(
                notification_servicer, self.server)

        if start_model_center_service:
            logging.info("start model center service.")
            # Use this server's own address when it runs the default
            # notification service and no explicit endpoint was given.
            # Otherwise pass notification_uri through (possibly None).
            if start_default_notification and notification_uri is None:
                model_center_nf_uri = local_uri
            else:
                model_center_nf_uri = notification_uri
            model_center_servicer = ModelCenterService(
                store_uri=store_uri,
                notification_uri=model_center_nf_uri)
            model_center_service_pb2_grpc.add_ModelCenterServiceServicer_to_server(
                model_center_servicer, self.server)

        if start_meta_service:
            logging.info("start meta service.")
            metadata_servicer = MetadataService(db_uri=store_uri,
                                                server_uri=local_uri)
            metadata_service_pb2_grpc.add_MetadataServiceServicer_to_server(
                metadata_servicer, self.server)

        if start_metric_service:
            logging.info("start metric service.")
            metric_servicer = MetricService(db_uri=store_uri)
            metric_service_pb2_grpc.add_MetricServiceServicer_to_server(
                metric_servicer, self.server)

        if start_scheduler_service:
            self._add_scheduler_service(scheduler_service_config)

        if enabled_ha:
            self._add_ha_service(ha_manager, ha_server_uri, ha_storage,
                                 store_uri, ttl_ms)

        bind_address = '[::]:' + str(port)
        self.server.add_insecure_port(bind_address)
예제 #4
0
    def __init__(self,
                 store_uri=None,
                 port=_PORT,
                 start_default_notification: bool = True,
                 notification_uri=None,
                 start_meta_service: bool = True,
                 start_model_center_service: bool = True,
                 start_metric_service: bool = True,
                 start_deploy_service: bool = True,
                 start_scheduling_service: bool = True,
                 scheduler_config: Dict = None):
        """Create the gRPC server and register the enabled endpoint servicers.

        Each ``start_*`` flag controls whether the corresponding servicer is
        registered. The server opens an insecure port on ``[::]:<port>``.
        Services that need their own address are given ``localhost:<port>``.

        Args:
            store_uri: Store / db URI passed to the notification, model
                center, metadata and metric services.
            port: Port to listen on.
            start_default_notification: Register a NotificationService on this
                server.
            notification_uri: Passed to ModelCenterService as given. It is also
                the scheduler's fallback notification endpoint when
                ``start_default_notification`` is false.
            start_meta_service: Register MetadataService.
            start_model_center_service: Register ModelCenterService.
            start_metric_service: Register MetricService.
            start_deploy_service: Register DeployService. The outcome is
                recorded on ``self.start_deploy_service``, and
                ``self.deploy_service`` is only set when it is true.
            start_scheduling_service: Register SchedulingService on
                ``self.scheduling_service``.
            scheduler_config: Mapping-like config read via ``.get`` for
                ``notification_uri``, ``properties``, ``repository`` and
                ``scheduler_class_name``. It is only read, never written.
                When None, a default selecting the Airflow scheduler with
                repository ``/tmp/airflow`` is used.
        """
        worker_pool = futures.ThreadPoolExecutor(max_workers=10)
        self.executor = Executor(worker_pool)
        self.server = grpc.server(self.executor)
        self.start_default_notification = start_default_notification
        local_uri = 'localhost:{}'.format(port)

        if start_default_notification:
            logging.info("start default notification service.")
            notification_servicer = NotificationService(store_uri)
            notification_service_pb2_grpc.add_NotificationServiceServicer_to_server(
                notification_servicer, self.server)

        if start_model_center_service:
            logging.info("start model center service.")
            model_center_servicer = ModelCenterService(
                store_uri=store_uri,
                server_uri=local_uri,
                notification_uri=notification_uri)
            model_center_service_pb2_grpc.add_ModelCenterServiceServicer_to_server(
                model_center_servicer, self.server)

        if start_meta_service:
            logging.info("start meta service.")
            metadata_servicer = MetadataService(db_uri=store_uri,
                                                server_uri=local_uri)
            metadata_service_pb2_grpc.add_MetadataServiceServicer_to_server(
                metadata_servicer, self.server)

        if start_metric_service:
            logging.info("start metric service.")
            metric_servicer = MetricService(db_uri=store_uri)
            metric_service_pb2_grpc.add_MetricServiceServicer_to_server(
                metric_servicer, self.server)

        # The flag is always recorded; self.deploy_service only exists when
        # the service is started.
        self.start_deploy_service = bool(start_deploy_service)
        if self.start_deploy_service:
            logging.info("start deploy service.")
            self.deploy_service = DeployService(server_uri=local_uri)
            deploy_service_pb2_grpc.add_DeployServiceServicer_to_server(
                self.deploy_service, self.server)

        if start_scheduling_service:
            logging.info("start scheduling service.")
            # NOTE(review): when start_default_notification is true, an
            # explicitly passed notification_uri is not used as the fallback
            # here; confirm that is intended.
            fallback_nf_uri = (local_uri if start_default_notification
                               else notification_uri)
            source_config = scheduler_config
            if source_config is None:
                source_config = SchedulerConfig()
                source_config.set_notification_service_uri(fallback_nf_uri)
                source_config.set_scheduler_class_name(
                    'ai_flow.scheduler.implements.airflow_scheduler.AirFlowScheduler')
                source_config.set_repository('/tmp/airflow')

            # Copy the four recognised settings into a fresh config. A missing
            # notification URI falls back to fallback_nf_uri.
            effective_config = SchedulerConfig()
            configured_nf_uri = source_config.get('notification_uri')
            if configured_nf_uri is None:
                configured_nf_uri = fallback_nf_uri
            effective_config.set_notification_service_uri(configured_nf_uri)
            effective_config.set_properties(source_config.get('properties'))
            effective_config.set_repository(source_config.get('repository'))
            effective_config.set_scheduler_class_name(
                source_config.get('scheduler_class_name'))
            self.scheduling_service = SchedulingService(effective_config)
            scheduling_service_pb2_grpc.add_SchedulingServiceServicer_to_server(
                self.scheduling_service, self.server)

        bind_address = '[::]:' + str(port)
        self.server.add_insecure_port(bind_address)