예제 #1
0
 def project_config(self) -> ProjectConfig:
     """
     Lazily load and cache the project configuration.

     The configuration is read from ``self.project_config_file`` on first
     access. The cache is only populated after the file has been loaded
     successfully, so a failed load is retried on the next access instead
     of leaving an empty configuration behind.

     return: The project configuration(ai_flow.project.project_config.ProjectConfig)
     """
     if self._project_config is None:
         loaded_config = ProjectConfig()
         # Load into a local first: caching the object before loading would
         # keep a blank config around if load_from_file() raises.
         loaded_config.load_from_file(self.project_config_file)
         self._project_config = loaded_config
     return self._project_config
예제 #2
0
 def setUp(self) -> None:
     """
     Start one high-availability AIFlow server on port 50051 and create an
     HA-enabled client whose server list names ports 50052 and 50051.
     """
     # The store is instantiated up front; the instance itself is not kept.
     SqlAlchemyStore(_SQLITE_DB_URI)
     primary_server = HighAvailableAIFlowServer(
         store_uri=_SQLITE_DB_URI,
         port=50051,
         server_uri='localhost:50051')
     self.server1 = primary_server
     primary_server.run()
     # Only the first server is started here; the other two slots stay empty.
     self.server2 = None
     self.server3 = None
     ha_config = ProjectConfig()
     self.config = ha_config
     ha_config.set_enable_ha(True)
     self.client = AIFlowClient(
         server_uri='localhost:50052,localhost:50051',
         project_config=ha_config)
def get_project_description_from(project_path: Text) -> ProjectDesc:
    """
    Load a project descriptor for a given project path.

    The descriptor records the absolute project path, the file paths found
    under the 'jar_dependencies', 'python_codes' and 'resources'
    sub-directories, and the configuration loaded from 'project.yaml'.
    The project's temp directory is created if it does not exist yet.

    :param project_path: the path of a ai flow project.
    :return: a ProjectDesc object that contains the structure information of this project.
    """
    project_spec = ProjectDesc()
    project_path = os.path.abspath(project_path)
    project_spec.project_path = project_path
    project_path_obj = Path(project_path)
    project_spec.jar_dependencies = get_file_paths_from(
        str(project_path_obj / 'jar_dependencies'))
    project_spec.python_dependencies = get_file_paths_from(
        str(project_path_obj / 'python_codes'))
    project_spec.resources = get_file_paths_from(
        str(project_path_obj / 'resources'))

    # exist_ok avoids the check-then-create race when several processes
    # prepare the same project at the same time.
    os.makedirs(project_spec.get_absolute_temp_path(), exist_ok=True)

    project_spec.project_config = ProjectConfig()
    project_spec.project_config.load_from_file(
        os.path.join(project_path, 'project.yaml'))

    # adapter to old scheduler: values from the module-level default config
    # override what was read from project.yaml.
    default_uuid = _default_project_config.get_project_uuid()
    if default_uuid is not None:
        project_spec.project_config.set_project_uuid(default_uuid)
    if 'entry_module_path' in _default_project_config:
        project_spec.project_config[
            'entry_module_path'] = _default_project_config['entry_module_path']

    project_spec.project_name = project_spec.project_config.get_project_name()
    return project_spec
예제 #4
0
 def setUp(self) -> None:
     """
     Start a notification master on port 30031 and one HA-enabled AIFlow
     server on port 50051, then create an HA-enabled client configured with
     the notification service URI.
     """
     # The store is instantiated up front; the instance itself is not kept.
     SqlAlchemyStore(_SQLITE_DB_URI)
     event_service = NotificationService(storage=MemoryEventStorage())
     self.notification = NotificationMaster(service=event_service, port=30031)
     self.notification.run()
     server_options = dict(
         store_uri=_SQLITE_DB_URI,
         port=50051,
         enabled_ha=True,
         start_scheduler_service=False,
         ha_server_uri='localhost:50051',
         notification_uri='localhost:30031',
         start_default_notification=False)
     self.server1 = AIFlowServer(**server_options)
     self.server1.run()
     # Only the first server is started here; the other two slots stay empty.
     self.server2 = None
     self.server3 = None
     ha_config = ProjectConfig()
     self.config = ha_config
     ha_config.set_enable_ha(True)
     ha_config.set_notification_service_uri('localhost:30031')
     self.client = AIFlowClient(
         server_uri='localhost:50052,localhost:50051',
         project_config=ha_config)
예제 #5
0
def build_project_context(project_path: Text) -> ProjectContext:
    """
    Create a ProjectContext for the project located at ``project_path``.

    The path is made absolute, stored on the context, and the project
    configuration is loaded from the file the context reports via
    ``get_project_config_file()``.

    :param project_path: the path of a ai flow project.
    :return: a ProjectContext object that contains the structure information of this project.
    """
    context = ProjectContext()
    context.project_path = os.path.abspath(project_path)
    config = ProjectConfig()
    context.project_config = config
    config.load_from_file(context.get_project_config_file())
    return context
 def test_load_project_config(self):
     """Check the values parsed from the project.yaml next to this test file."""
     config_file = os.path.join(get_file_dir(__file__), 'project.yaml')
     loaded = ProjectConfig()
     loaded.load_from_file(config_file)
     self.assertEqual(loaded.get_server_uri(), "localhost:50051")
     self.assertIsNone(loaded.get('ai_flow config', None))
     # Plain key lookups, checked in the same order as before.
     expected_entries = (
         ('ai_flow_home', '/opt/ai_flow'),
         ('ai_flow_job_master.host', 'localhost'),
         ('ai_flow_job_master.port', 8081),
         ('ai_flow_conf', 'taskmanager.slot=2'),
     )
     for key, expected_value in expected_entries:
         self.assertEqual(loaded[key], expected_value)
예제 #7
0
 def test_translate_ai_graph_to_workflow(self):
     """
     Split a generated AI graph into job sub-graphs and build a workflow
     from it, checking the node, edge and dataset counts along the way.
     """
     workflow_file = os.path.join(os.path.dirname(__file__), 'workflow_1.yaml')
     init_workflow_config(workflow_file)

     context = ProjectContext()
     context.project_path = '/tmp'
     context.project_config = ProjectConfig()
     context.project_config.set_project_name('test_project')

     ai_graph: AIGraph = build_ai_graph(9, 3)
     graph_splitter = GraphSplitter()
     split_result = graph_splitter.split(ai_graph)
     self.assertEqual(3, len(split_result.nodes))
     self.assertEqual(1, len(split_result.edges))
     self.assertEqual(2, len(split_result.edges.get('job_2')))

     first_job_graph = split_result.nodes.get('job_0')
     self.assertIn('AINode_4', first_job_graph.nodes)
     self.assertIn('AINode_4', first_job_graph.edges)

     workflow_builder = WorkflowConstructor()
     workflow_builder.register_job_generator('mock', MockJobGenerator())
     built_workflow = workflow_builder.build_workflow(split_result, context)
     self.assertEqual(3, len(built_workflow.nodes))

     first_job = built_workflow.jobs.get('job_0')
     self.assertEqual(1, len(first_job.input_dataset_list))
     self.assertEqual(1, len(first_job.output_dataset_list))
예제 #8
0
def unset_project_config():
    """
    Reset the module-level default project config to a fresh ProjectConfig
    and clear the flag that marks it as explicitly set.
    """
    global _default_project_config, _default_project_config_set_flag
    _default_project_config, _default_project_config_set_flag = ProjectConfig(), False
예제 #9
0
class TestHighAvailableAIFlowServer(unittest.TestCase):
    """
    Fail-over test for HA-enabled AIFlow servers.

    Each test starts a notification master and one AIFlow server, then
    brings up further servers and stops the active one to check that the
    client switches to a living member.
    """

    @staticmethod
    def start_aiflow_server(host, port):
        """
        Start an HA-enabled AIFlowServer and return it.

        :param host: host name used to build the server's HA URI.
        :param port: port to listen on; it is converted to a string and
                     also used to build the HA URI ``host:port``.
        :return: the running AIFlowServer instance.
        """
        port = str(port)
        server_uri = host + ":" + port
        server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                              port=port,
                              enabled_ha=True,
                              start_scheduler_service=False,
                              ha_server_uri=server_uri,
                              notification_uri='localhost:30031',
                              start_default_notification=False)
        server.run()
        return server

    def wait_for_new_members_detected(self, new_member_uri, timeout_seconds=60):
        """
        Block until the client lists ``new_member_uri`` as a living member.

        The member list is polled once per second. If the member has not
        shown up after ``timeout_seconds`` the test is failed instead of
        waiting forever, so a broken server cannot hang the test run.

        :param new_member_uri: the ``host:port`` URI to wait for.
        :param timeout_seconds: maximum time to wait before failing the test.
        """
        deadline = time.monotonic() + timeout_seconds
        while new_member_uri not in self.client.living_aiflow_members:
            if time.monotonic() >= deadline:
                self.fail("Member {} was not detected within {} seconds".format(
                    new_member_uri, timeout_seconds))
            time.sleep(1)

    def setUp(self) -> None:
        """
        Start the notification master (port 30031) and the first AIFlow
        server (port 50051), then create an HA-enabled client.
        """
        SqlAlchemyStore(_SQLITE_DB_URI)
        self.notification = NotificationMaster(
            service=NotificationService(storage=MemoryEventStorage()),
            port=30031)
        self.notification.run()
        self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                    port=50051,
                                    enabled_ha=True,
                                    start_scheduler_service=False,
                                    ha_server_uri='localhost:50051',
                                    notification_uri='localhost:30031',
                                    start_default_notification=False)
        self.server1.run()
        # The second and third servers are started by the test itself.
        self.server2 = None
        self.server3 = None
        self.config = ProjectConfig()
        self.config.set_enable_ha(True)
        self.config.set_notification_service_uri('localhost:30031')
        self.client = AIFlowClient(
            server_uri='localhost:50052,localhost:50051',
            project_config=self.config)

    def tearDown(self) -> None:
        """
        Stop the client's background activity, every server that was
        started, and the notification master, then drop all tables.
        """
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.server1 is not None:
            self.server1.stop()
        if self.server2 is not None:
            self.server2.stop()
        if self.server3 is not None:
            self.server3.stop()
        if self.notification is not None:
            self.notification.stop()
        store = SqlAlchemyStore(_SQLITE_DB_URI)
        base.metadata.drop_all(store.db_engine)

    def test_server_change(self) -> None:
        """
        Register a project, then stop the active server twice and check
        that the client moves to the newly started server each time while
        still seeing the registered project.
        """
        self.client.register_project("test_project")
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50051")
        self.assertEqual(projects[0].name, "test_project")

        # Fail over from server1 to server2.
        self.server2 = self.start_aiflow_server("localhost", 50052)
        self.wait_for_new_members_detected("localhost:50052")
        self.server1.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50052")
        self.assertEqual(projects[0].name, "test_project")

        # Fail over from server2 to server3.
        self.server3 = self.start_aiflow_server("localhost", 50053)
        self.wait_for_new_members_detected("localhost:50053")
        self.server2.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50053")
        self.assertEqual(projects[0].name, "test_project")
예제 #10
0
 def __init__(self,
              server_uri=_SERVER_URI,
              notification_service_uri=None,
              project_config: ProjectConfig = None):
     """
     Initialise every client mixin and, when high availability is enabled,
     connect to the first reachable AIFlow server.

     :param server_uri: URI of the AIFlow server. With HA enabled this is a
                        comma-separated list of candidate URIs that are
                        tried in order.
     :param notification_service_uri: URI of the notification service. When
                        it is None here and in ``project_config``, the
                        notification client is pointed at ``server_uri``.
     :param project_config: optional project configuration supplying the HA
                        flag, the retry/list intervals and fall-back URIs.
     :raises Exception: if HA is enabled and no candidate server answers, or
                        a candidate answers ``listMembers`` with a
                        non-success return code.
     """
     # NOTE(review): the base clients receive ``server_uri`` exactly as
     # passed in, i.e. before the fall-back to project_config below.
     MetadataClient.__init__(self, server_uri)
     ModelCenterClient.__init__(self, server_uri)
     DeployClient.__init__(self, server_uri)
     MetricClient.__init__(self, server_uri)
     self.enable_ha = False
     self.list_member_interval_ms = 5000
     self.retry_interval_ms = 1000
     self.retry_timeout_ms = 10000
     if project_config is not None:
         if server_uri is None:
             server_uri = project_config.get_master_uri()
         if notification_service_uri is None:
             notification_service_uri = project_config.get_notification_service_uri(
             )
         self.enable_ha = project_config.get_enable_ha()
         self.list_member_interval_ms = project_config.get_list_member_interval_ms(
         )
         self.retry_interval_ms = project_config.get_retry_interval_ms()
         self.retry_timeout_ms = project_config.get_retry_timeout_ms()
     if notification_service_uri is None:
         NotificationClient.__init__(
             self,
             server_uri,
             enable_ha=self.enable_ha,
             list_member_interval_ms=self.list_member_interval_ms,
             retry_interval_ms=self.retry_interval_ms,
             retry_timeout_ms=self.retry_timeout_ms)
     else:
         NotificationClient.__init__(
             self,
             notification_service_uri,
             enable_ha=self.enable_ha,
             list_member_interval_ms=self.list_member_interval_ms,
             retry_interval_ms=self.retry_interval_ms,
             retry_timeout_ms=self.retry_timeout_ms)
     if self.enable_ha:
         candidate_uris = server_uri.split(",")
         self.living_aiflow_members = []
         self.current_aiflow_uri = None
         last_error = None
         # Use a dedicated loop variable so the ``server_uri`` parameter is
         # not silently overwritten while probing candidates.
         for candidate_uri in candidate_uris:
             channel = grpc.insecure_channel(candidate_uri)
             high_availability_stub = HighAvailabilityManagerStub(channel)
             try:
                 request = ListMembersRequest(timeout_seconds=0)
                 response = high_availability_stub.listMembers(request)
                 if response.return_code == ReturnStatus.CALL_SUCCESS:
                     self.living_aiflow_members = [
                         proto_to_member(proto).server_uri
                         for proto in response.members
                     ]
                 else:
                     raise Exception(response.return_msg)
                 self.current_aiflow_uri = candidate_uri
                 self.high_availability_stub = high_availability_stub
                 break
             except grpc.RpcError as e:
                 # This candidate is unreachable: release its channel
                 # before moving on, otherwise it is leaked.
                 channel.close()
                 last_error = e
         if self.current_aiflow_uri is None:
             raise Exception(
                 "No available aiflow server uri!") from last_error
         self.aiflow_ha_change_lock = threading.Lock()
         self.aiflow_ha_running = True
         self._replace_aiflow_stubs(self.current_aiflow_uri)
         # Background daemon thread running self._list_aiflow_members.
         self.list_aiflow_member_thread = threading.Thread(
             target=self._list_aiflow_members, daemon=True)
         self.list_aiflow_member_thread.start()
예제 #11
0
    Load a project context for a given project path.
    :param project_path: the path of a ai flow project.
    :return: a ProjectContext object that contains the structure information of this project.
    """
    project_context = ProjectContext()
    project_path = os.path.abspath(project_path)
    project_context.project_path = project_path
    project_context.project_config = ProjectConfig()
    project_context.project_config.load_from_file(
        project_context.get_project_config_file())
    return project_context


# Module-level project context, created empty when the module is imported.
__current_project_context__ = ProjectContext()

# Module-level project configuration; init_project_config() loads a file into it.
__current_project_config__ = ProjectConfig()


def init_project_config(project_config_file):
    """
    Load the given project configuration file into the module-level
    project configuration object.

    :param project_config_file: path of the project configuration file.
    """
    current_config = __current_project_config__
    current_config.load_from_file(project_config_file)


def init_project_context(project_path: Text):
    """
    Load project configuration and project context of the ai flow project.

    :param project_path: the path of the ai flow project; it is passed to
                         build_project_context().
    """
    global __current_project_context__, __current_project_config__
    # Build a fresh context (including its configuration) for the project.
    # NOTE(review): the visible body ends after this call and the globals
    # declared above are not reassigned here -- confirm against the full source.
    project_context = build_project_context(project_path)
예제 #12
0
class JobRuntimeEnv(object):
    """
    JobRuntimeEnv represents the environment information needed for an ai flow job to run. It contains:
    1. project configuration.
    2. workflow configuration.
    3. Job running depends on resource files.
    4. Information when the job is executed.

    Every path is derived from the job's working directory. The project
    configuration, workflow configuration and job execution info are loaded
    from files under that directory on first access and cached afterwards.
    """

    def __init__(self,
                 working_dir: Text,
                 job_execution_info: 'JobExecutionInfo' = None):
        """
        :param working_dir: the working directory of the job.
        :param job_execution_info: execution information of the job; when
               omitted it is read from the 'job_execution_info' file in the
               working directory on first access.
        """
        self._working_dir: Text = working_dir
        self._job_execution_info: JobExecutionInfo = job_execution_info
        self._workflow_config: WorkflowConfig = None
        self._project_config: ProjectConfig = None

    @property
    def working_dir(self) -> Text:
        """
        return: The working directory of the job.
        """
        return self._working_dir

    @property
    def workflow_name(self) -> Text:
        """
        return: The name of the workflow which the job belongs.
        """
        return self.job_execution_info.workflow_execution.workflow_info.workflow_name

    @property
    def workflow_dir(self) -> Text:
        """
        return: The directory of the workflow file.
        """
        return os.path.join(self.working_dir, self.workflow_name)

    @property
    def job_name(self) -> Text:
        """
        return: The name of the job.
        """
        # The job_execution_info property already loads the info on demand,
        # so no separate "is None" branch is needed here.
        return self.job_execution_info.job_name

    @property
    def log_dir(self) -> Text:
        """
        return: The directory where job logs are stored.
        """
        return os.path.join(self._working_dir, 'logs')

    @property
    def resource_dir(self) -> Text:
        """
        return: The directory where job resource files are stored.
        """
        return os.path.join(self._working_dir, 'resources')

    @property
    def generated_dir(self) -> Text:
        """
        return: The directory where the job stores the generated executable files.
        """
        return os.path.join(self._working_dir, 'generated')

    @property
    def dependencies_dir(self) -> Text:
        """
        return: The directory where the job runs dependent files.
        """
        return os.path.join(self._working_dir, 'dependencies')

    @property
    def python_dep_dir(self) -> Text:
        """
        return: The directory where the job runs dependent python files.
        """
        return os.path.join(self.dependencies_dir, 'python')

    @property
    def go_dep_dir(self) -> Text:
        """
        return: The directory where the job runs dependent go files.
        """
        return os.path.join(self.dependencies_dir, 'go')

    @property
    def jar_dep_dir(self) -> Text:
        """
        return: The directory where the job runs dependent jar files.
        """
        return os.path.join(self.dependencies_dir, 'jar')

    @property
    def project_config_file(self) -> Text:
        """
        return: The project configuration file path.
        """
        return os.path.join(self.working_dir, 'project.yaml')

    @property
    def project_config(self) -> 'ProjectConfig':
        """
        The configuration is loaded from project_config_file on first access.

        return: The project configuration(ai_flow.project.project_config.ProjectConfig)
        """
        if self._project_config is None:
            loaded_config = ProjectConfig()
            # Load into a local first: caching the object before loading
            # would keep a blank config around if load_from_file() raises.
            loaded_config.load_from_file(self.project_config_file)
            self._project_config = loaded_config
        return self._project_config

    @property
    def workflow_config_file(self) -> Text:
        """
        return: The workflow configuration file path.
        """
        return os.path.join(self.workflow_dir, '{}.yaml'.format(self.workflow_name))

    @property
    def workflow_config(self) -> 'WorkflowConfig':
        """
        The configuration is loaded from workflow_config_file on first access.

        return: The workflow configuration(ai_flow.workflow.workflow_config.WorkflowConfig)
        """
        if self._workflow_config is None:
            self._workflow_config = load_workflow_config(config_path=self.workflow_config_file)
        return self._workflow_config

    @property
    def workflow_entry_file(self) -> Text:
        """
        return: The path of file that defines the workflow.
        """
        return os.path.join(self.workflow_dir, '{}.py'.format(self.workflow_name))

    @property
    def job_execution_info(self) -> 'JobExecutionInfo':
        """
        When no info was given to the constructor, it is read from the
        'job_execution_info' file in the working directory on first access.

        return: Information when the job is executed.
        """
        if self._job_execution_info is None:
            self._job_execution_info = serialization_utils.read_object_from_serialized_file(
                os.path.join(self.working_dir, 'job_execution_info'))
        return self._job_execution_info

    def save_job_execution_info(self):
        """
        Serialize the job execution info to the 'job_execution_info' file in
        the working directory, replacing an existing file.

        Does nothing when no job execution info is held.
        """
        if self._job_execution_info is None:
            return
        file_path = os.path.join(self.working_dir, 'job_execution_info')
        # Remove any previous file before writing the new one.
        if os.path.exists(file_path):
            os.remove(file_path)
        with open(file_path, 'wb') as fp:
            fp.write(serialization_utils.serialize(self._job_execution_info))