Exemplo n.º 1
0
class CWriteFileBase(object):
    """Base class for generators that emit a single text file.

    Subclasses override ``implement()`` to return the file content and
    ``is_debug()`` to return False once the content should be written to
    ``m_file_path``; while ``is_debug()`` is True the content is only printed.
    """

    def __init__(self, file_path):
        self.m_file_handler = CFileHandle()
        self.m_file_path = file_path

    def is_debug(self):
        """Return True to print the generated content instead of writing it."""
        return True

    def implement(self):
        """Return the file content; subclasses override this."""
        return ""

    def write(self, encoding="utf8"):
        """Generate the content, then print it (debug) or write it out.

        When the target directory does not exist, the user is asked on stdin
        whether to create it.

        Args:
            encoding: passed through to ``m_file_handler.clear_write``.

        Raises:
            SystemExit: if the user declines to create the missing directory.
        """
        content = ""
        content += self.implement()
        if self.is_debug() is True:
            print(content)
            return
        dirname = os.path.dirname(self.m_file_path)
        # A bare file name has no directory part (dirname == ""); it targets
        # the current directory, so there is nothing to create. Without this
        # guard os.path.exists("") is False and os.makedirs("") would raise.
        if dirname and not os.path.exists(dirname):
            user_input = input("[Tip] dir is not exist, create ? [y/n]")
            if user_input.lower() != "y":
                raise SystemExit("[Waring] Create Failed")
            os.makedirs(dirname)
        self.m_file_handler.clear_write(content, self.m_file_path,
                                        encoding)
Exemplo n.º 2
0
 def __init__(self, file_path, root="."):
     """Create a file handler and reset generator state.

     The output location is computed by ``__compare_file_path(file_path,
     root)``; presumably it fills ``m_file_path`` (TODO confirm in class).
     """
     self.m_file_handler = CFileHandle()
     self.m_import_list = []
     self.m_namespace = ""
     self.m_content = ""
     self.m_file_path = ""
     self.__compare_file_path(file_path, root)
Exemplo n.º 3
0
 def __init__(self, parser, file_path, root="."):
     """Initialise the CWriteBase part with *parser*, then reset local state.

     The base initialiser runs first so the assignments below take effect
     after it; the output location is then computed by
     ``__compare_file_path(file_path, root)``.
     """
     CWriteBase.__init__(self, parser)
     self.m_file_handler = CFileHandle()
     self.m_namespace = ""
     self.m_content = ""
     self.m_file_path = ""
     self.m_file_name = ""
     self.__compare_file_path(file_path, root)
Exemplo n.º 4
0
Arquivo: main.py Projeto: MwlLj/sql2go
 def __write_config(self, path):
     """Hand a default DB connection config to ``CFileHandle.clear_write``.

     The content is one ``key=value`` line per setting, and the "utf8"
     encoding argument is passed along with *path*.
     """
     defaults = (
         ("host", "localhost"),
         ("port", "3306"),
         ("dbname", "test"),
         ("username", "root"),
         ("userpwd", "123456"),
     )
     text = "".join("{0}={1}\n".format(key, val) for key, val in defaults)
     CFileHandle().clear_write(text, path, "utf8")
Exemplo n.º 5
0
 def __init__(self, parser, file_path, root="."):
     """Store *parser*, reset generator state and compute the output path.

     The output location is computed by ``__compare_file_path(file_path,
     root)`` once every attribute has been reset.
     """
     self.m_file_handler = CFileHandle()
     self.m_parser = parser
     self.m_procedure_info_list = []
     self.m_namespace = ""
     self.m_content = ""
     self.m_file_path = ""
     self.m_file_name = ""
     self.__compare_file_path(file_path, root)
Exemplo n.º 6
0
 def __init__(self, file_path, root="."):
     """Reset generator state, derive the output path, then init CWriteCppBase.

     ``m_uuid`` is a fresh random UUID as 32 upper-case hex digits without
     dashes. ``CWriteCppBase.__init__`` runs last with the computed
     ``m_file_path`` and assigns ``m_file_handler`` / ``m_file_path`` again.
     """
     self.m_file_handler = CFileHandle()
     self.m_file_path = ""
     self.m_content = ""
     self.m_namespace = ""
     # UUID.hex is the dash-free lower-case form; upper() matches the
     # original "strip dashes, upper-case" transformation.
     self.m_uuid = uuid.uuid4().hex.upper()
     self.m_struct_list = []
     self.m_class_implement = {}
     self.__compare_file_path(file_path, root)
     CWriteCppBase.__init__(self, self.m_file_path)
Exemplo n.º 7
0
 def __init__(self, file_path):
     """Remember the target *file_path* and create the file handler used to write it."""
     handler = CFileHandle()
     self.m_file_handler = handler
     self.m_file_path = file_path
Exemplo n.º 8
0
class CWriteCppBase(object):
    """Template base class that generates one C++ source or header file.

    ``write()`` assembles, in order: the include-guard opening (when
    ``define_name()`` is non-empty), the ``#include`` lines, the namespace /
    class bodies (or a bare ``implement("", "")`` when ``namespace_list()`` is
    empty) and the include-guard closing. Subclasses override the hook
    methods. While ``is_debug()`` returns True the result is printed instead
    of being passed to ``m_file_handler.clear_write``.
    """

    def __init__(self, file_path):
        self.m_file_handler = CFileHandle()
        self.m_file_path = file_path

    def is_debug(self):
        """Return True to print the output instead of writing the file."""
        return True

    def is_header(self):
        """Return True when generating a header (guard lines and class wrappers)."""
        return True

    def define_name(self):
        """Return the include-guard macro name; "" disables guard lines.

        Used as ``#ifndef xxx`` / ``#define xxx``.
        """
        return ""

    def include_sys_list(self):
        """Return names emitted as ``#include <xxx>``."""
        return []

    def include_other_list(self):
        """Return names emitted as ``#include "xxx"``."""
        return []

    def namespace_list(self):
        """Return ``[(namespace1, [class1, class2]), (namespace2, [...])]``."""
        return []

    def implement(self, namespace_name, class_name):
        """Return the body for one class, or free code when class_name is ""."""
        return ""

    def namespace_implement_begin(self, namespace):
        """Return text emitted right after a namespace is opened."""
        return ""

    def namespace_implement_end(self, namespace):
        """Return text emitted right before a namespace is closed."""
        return ""

    def write(self, encoding="utf8"):
        """Assemble the whole file, then print it (debug) or write it.

        Args:
            encoding: passed through to ``m_file_handler.clear_write``.
        """
        content = ""
        if self.define_name() != "":
            content += self.write_header()
        sys_include_len = len(self.include_sys_list())
        other_include_len = len(self.include_other_list())
        include_count = sys_include_len + other_include_len
        if include_count > 0:
            content += self.write_includes()
        namespace_list_len = len(self.namespace_list())
        if namespace_list_len > 0:
            content += self.write_namespace()
        else:
            # No namespaces declared: emit the implementation at file scope.
            content += self.implement("", "")
        if self.define_name() != "":
            content += self.write_tail()
        if self.is_debug() is True:
            print(content)
        else:
            self.m_file_handler.clear_write(content, self.m_file_path,
                                            encoding)

    def write_member_var(self, param_type, param_name):
        """Return a member declaration ``T name;`` (type mapped by type_change)."""
        content = ""
        param_type, is_custom_type = self.type_change(param_type)
        content += "{0} {1};".format(param_type, param_name)
        return content

    def write_get_method(self, param_type, param_name):
        """Return an inline const getter ``const T &getName() const``."""
        content = ""
        param_type, is_custom_type = self.type_change(param_type)
        content += "const {0} &get{1}() const".format(
            param_type, CStringTools.upperFirstByte(param_name))
        content += " { return this->" + param_name + "; }"
        return content

    def write_get_mut_method(self, param_type, param_name):
        """Return an inline mutable getter ``T &getMutName()``."""
        content = ""
        param_type, is_custom_type = self.type_change(param_type)
        content += "{0} &getMut{1}()".format(
            param_type, CStringTools.upperFirstByte(param_name))
        content += " { return this->" + param_name + "; }"
        return content

    def write_set_method(self, param_type, param_name):
        """Return an inline setter ``void setName(const T &name)``."""
        content = ""
        param_type, is_custom_type = self.type_change(param_type)
        content += "void set{0}(const {1} &{2})".format(
            CStringTools.upperFirstByte(param_name), param_type, param_name)
        content += " { this->" + "{0} = {0}".format(param_name) + "; }"
        return content

    def __write_construction_param(self, param_type, param_name):
        """Return ``const T &name``, or "" when type or name is None."""
        content = ""
        if param_type is None or param_name is None:
            return content
        param_type, is_custom_type = self.type_change(param_type)
        content += "const {0} &{1}".format(param_type, param_name)
        return content

    def write_construction_param_list(self, param_list):
        """Join ``(type, name)`` pairs into a constructor parameter list.

        Entries whose type or name is None are skipped entirely. They do not
        leave a dangling ", " behind, which would make the C++ invalid.
        """
        parts = [self.__write_construction_param(param_type, param_name)
                 for param_type, param_name in param_list]
        return ", ".join(part for part in parts if part != "")

    def write_default_init_param_list(self, param_list):
        """Return a member-initializer list that uses default values."""
        return self.__write_init_param_list(param_list, True)

    def write_member_init_param_list(self, param_list):
        """Return a member-initializer list copying same-named parameters."""
        return self.__write_init_param_list(param_list, False)

    def __write_init_param(self, param_type, param_name, is_default):
        """Return one ``name(value)`` initializer, or "" if type/name is None.

        With ``is_default`` the value is ``""`` for std::string, ``0`` for
        other non-custom types and empty for custom types (as reported by
        ``type_change``). Otherwise the member is initialised from the
        same-named constructor parameter.
        """
        content = ""
        if param_type is None or param_name is None:
            return content
        param_type, is_custom_type = self.type_change(param_type)
        value = param_name
        if is_custom_type is False:
            if param_type == "std::string":
                if is_default is True:
                    value = '""'
            else:
                if is_default is True:
                    value = "0"
        else:
            if is_default is True:
                value = ""
        content += "{0}({1})".format(param_name, value)
        return content

    def __write_init_param_list(self, param_list, is_default):
        """Join initializers for ``(type, name)`` pairs, skipping None entries.

        Raises:
            RuntimeError: if an entry does not have exactly two elements.
        """
        parts = []
        for param in param_list:
            if len(param) != 2:
                raise RuntimeError("param format error")
            param_type, param_name = param
            part = self.__write_init_param(param_type, param_name,
                                           is_default)
            # Skipped (None) entries produce "" and must not add a separator.
            if part != "":
                parts.append(part)
        return ", ".join(parts)

    def write_header(self):
        """Return the include-guard opening (headers only)."""
        content = ""
        if self.is_header() is True:
            content += "#ifndef {0}\n".format(self.define_name())
            content += "#define {0}\n".format(self.define_name())
            content += "\n"
        return content

    def write_includes(self):
        """Return system then local ``#include`` lines plus a blank line."""
        content = ""
        for include in self.include_sys_list():
            content += '#include <{0}>\n'.format(include)
        for include in self.include_other_list():
            content += '#include "{0}"\n'.format(include)
        content += "\n"
        return content

    def write_namespace(self):
        """Return every namespace block described by ``namespace_list()``.

        Empty namespace names are emitted without a ``namespace`` wrapper and
        empty class names are skipped. A namespace with no classes gets a
        single ``implement(namespace, "")`` call.
        """
        content = ""
        namespaces = self.namespace_list()
        for np_name, class_list in namespaces:
            if np_name != "":
                content += "namespace {0}\n".format(np_name)
                content += "{\n\n"
            content += self.namespace_implement_begin(np_name)
            for class_name in class_list:
                if class_name == "":
                    continue
                if self.is_header() is True:
                    content += "class {0}\n".format(class_name)
                    content += "{\n"
                content += self.implement(np_name, class_name)
                if self.is_header() is True:
                    content += "};\n\n"
            if len(class_list) == 0:
                content += self.implement(np_name, "")
            content += self.namespace_implement_end(np_name)
            if np_name != "":
                content += "}\n\n"
        return content

    def write_tail(self):
        """Return the include-guard closing (headers only)."""
        content = ""
        if self.is_header() is True:
            content += "#endif // {0}\n".format(self.define_name())
        return content

    def type_change(self, param_type):
        """Map a source type to its C++ spelling.

        Returns:
            ``(cpp_type, is_custom_type)``; the default is the identity with
            ``is_custom_type`` False.
        """
        return param_type, False
Exemplo n.º 9
0
class CWriteParamClass(CWriteBase):
    """Generate ``<name>_db_param.h``: C++ parameter classes per SQL method.

    For every method in the parsed info dict an input class and/or an output
    class is emitted inside the namespace taken from the info dict. Each
    field gets a setter and a getter, plus a companion ``<name>Used`` bool
    member with a getter.
    """

    def __init__(self, file_path, root="."):
        self.m_file_handler = CFileHandle()
        self.m_file_path = ""
        self.m_content = ""
        self.m_namespace = ""
        self.m_import_list = []
        self.__compare_file_path(file_path, root)

    def __compare_file_path(self, file_path, root):
        """Set ``m_file_path`` to ``<root>/<stem of file_path>_db_param.h``."""
        basename = os.path.basename(file_path)
        filename, fileext = os.path.splitext(basename)
        self.m_file_path = os.path.join(root, filename + "_db_param.h")

    def define_name(self):
        """Return the include-guard macro derived from the namespace."""
        return "__{0}_DB_PARAM_H__".format(self.m_namespace.upper())

    def include_sys_list(self):
        """Return the system headers the generated classes use."""
        return ["string", "list"]

    def include_other_list(self):
        """Return the import list with every double quote removed."""
        return [path.replace('"', "") for path in self.m_import_list]

    def namespace(self):
        """Return the namespace of the generated code."""
        return self.m_namespace

    def class_name(self):
        """Return the nominal class name of this generator."""
        return "CDbParam"

    def write(self, info_dict):
        """Generate the header from *info_dict* and write it as UTF-8.

        Raises:
            RuntimeError: if the namespace is missing or empty.
        """
        # Fetch the namespace; it is mandatory.
        namespace = info_dict.get(CSqlParse.NAMESPACE)
        if namespace is None or namespace == "":
            raise RuntimeError("namespace is empty")
        self.m_namespace = namespace
        import_list = info_dict.get(CSqlParse.IMPORT_LIST)
        # A missing import list means "no extra includes"; storing None would
        # make include_other_list() raise TypeError when it iterates.
        self.m_import_list = import_list if import_list is not None else []
        content = ""
        content += self.write_header()
        content += self.write_includes()
        content += self.write_namespace_define()
        # No enclosing class: each parameter class is emitted at namespace scope.
        content += self.__write_implement(info_dict)
        content += self.write_namespace_end()
        content += self.write_tail()
        self.m_content += content
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __write_implement(self, info_dict):
        """Return the input/output classes for every method in *info_dict*."""
        content = ""
        method_list = info_dict.get(CSqlParse.METHOD_LIST)
        for method_info in method_list:
            func_name = method_info.get(CSqlParse.FUNC_NAME)
            input_params = method_info.get(CSqlParse.INPUT_PARAMS)
            output_params = method_info.get(CSqlParse.OUTPUT_PARAMS)
            if input_params is not None:
                content += self.__write_class(func_name, input_params, True)
            if output_params is not None:
                content += self.__write_class(func_name, output_params, False)
        return content

    def __write_class(self, func_name, param_list, is_input):
        """Return one parameter class with default and full constructors."""
        content = ""
        class_name = ""
        if is_input is True:
            class_name = self.get_input_class_name(func_name)
        else:
            class_name = self.get_output_class_name(func_name)
        content += "class {0}\n".format(class_name)
        content += "{\n"
        content += "public:\n"
        content += "\t" * 1 + "explicit {0}()\n".format(class_name)
        content += "\t" * 2 + ": {0}".format(
            self.write_default_init_param_list(param_list)) + " {}\n"
        content += "\t" * 1 + "explicit {0}({1})\n".format(
            class_name, self.write_construction_param_list(param_list))
        content += "\t" * 2 + ": {0}".format(
            self.write_member_init_param_list(param_list)) + " {}\n"
        content += "\t" * 1 + "virtual ~{0}()".format(class_name) + " {}\n"
        content += self.__write_methods(param_list)
        content += self.__write_private_member(param_list)
        content += "};\n"
        content += "\n"
        return content

    def __write_methods(self, param_list):
        """Return setters/getters for each param and a getter for its Used flag."""
        content = ""
        content += "\n"
        content += "public:\n"
        for param in param_list:
            param_type = param.get(CSqlParse.PARAM_TYPE)
            param_name = param.get(CSqlParse.PARAM_NAME)
            if param_type is None or param_name is None:
                continue
            content += "\t" * 1 + self.write_set_method(
                param_type, param_name) + "\n"
            content += "\t" * 1 + self.write_get_method(
                param_type, param_name) + "\n"
            content += "\t" * 1 + self.write_get_method(
                "bool", param_name + "Used") + "\n"
        return content

    def __write_private_member(self, param_list):
        """Return the private members: each param plus a ``<name>Used`` bool."""
        content = ""
        content += "\n"
        content += "private:\n"
        for param in param_list:
            param_type = param.get(CSqlParse.PARAM_TYPE)
            param_name = param.get(CSqlParse.PARAM_NAME)
            if param_type is None or param_name is None:
                continue
            content += "\t" * 1 + self.write_member_var(
                param_type, param_name) + "\n"
            content += "\t" * 1 + self.write_member_var(
                "bool", param_name + "Used") + "\n"
        return content
Exemplo n.º 10
0
class CWriteSqliteImpH(CWriteBase):
    """Generate ``<name>_db_handler.h`` declaring the ``CDbHandler`` class.

    The declared class holds an ``sql::CConnPool`` plus a dial string, and
    gets one method declaration per parsed method (via
    ``write_method_define``).
    """

    def __init__(self, parser, file_path, root="."):
        CWriteBase.__init__(self, parser)
        self.m_file_handler = CFileHandle()
        self.m_file_name = ""
        self.m_file_path = ""
        self.m_content = ""
        self.m_namespace = ""
        self.__compare_file_path(file_path, root)

    def __compare_file_path(self, file_path, root):
        """Record the input stem and derive ``<root>/<stem>_db_handler.h``."""
        stem = os.path.splitext(os.path.basename(file_path))[0]
        self.m_file_name = stem
        self.m_file_path = os.path.join(root, stem + "_db_handler.h")

    def define_name(self):
        """Return the include-guard macro derived from the namespace."""
        return "__{0}_DB_HANDLER_H__".format(self.m_namespace.upper())

    def include_sys_list(self):
        """Return the system headers to include."""
        return ["stdint.h"]

    def include_other_list(self):
        """Return the local headers: the matching param header and sql.h."""
        return ["{0}_db_param.h".format(self.m_file_name), "sql.h"]

    def namespace(self):
        """Return the namespace of the generated code."""
        return self.m_namespace

    def class_name(self):
        """Return the generated class name."""
        return "CDbHandler"

    def write(self, info_dict):
        """Generate the header from *info_dict* and write it as UTF-8.

        Raises:
            RuntimeError: if the namespace is missing or empty.
        """
        namespace = info_dict.get(CSqlParse.NAMESPACE)
        if namespace in (None, ""):
            raise RuntimeError("namespace is empty")
        self.m_namespace = namespace
        sections = [
            self.write_header(),
            self.write_includes(),
            self.write_namespace_define(),
            self.write_class_define(),
            self.__write_implement(info_dict),
            self.write_class_end(),
            self.write_namespace_end(),
            self.write_tail(),
        ]
        self.m_content += "".join(sections)
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __write_implement(self, info_dict):
        """Return constructor/destructor declarations, methods and members."""
        ctor_and_dtor = [
            "public:\n",
            "\texplicit {0}(const std::string &dial, sql::ISql *s, int max = 1);\n".format(
                self.class_name()),
            "\tvirtual ~{0}();\n".format(self.class_name()),
        ]
        methods = info_dict.get(CSqlParse.METHOD_LIST)
        return ("".join(ctor_and_dtor) + self.__write_methods(methods)
                + self.__write_private_member())

    def __write_methods(self, method_list):
        """Return the method declarations and the connection helpers."""
        chunks = ["\n", "public:\n", "\t/*@@start@@*/\n"]
        chunks.extend(self.write_method_define(info) for info in method_list)
        chunks.append("\npublic:\n")
        chunks.append("\tsql::IConnect *connect() { return m_connPool.connect(m_dial); }\n")
        chunks.append("\tvoid freeConnect(sql::IConnect *conn) { m_connPool.freeConnect(conn); }\n")
        return "".join(chunks)

    def __write_private_member(self):
        """Return the private section: the connection pool and dial string."""
        return ("\n" + "private:\n"
                + "\tsql::CConnPool m_connPool;\n"
                + "\tstd::string m_dial;\n")
Exemplo n.º 11
0
 def __init__(self, file_path, root="."):
     """Reset generator state and compute the output path.

     ``m_class_set`` starts empty; the path is computed by
     ``__compare_file_path(file_path, root)``.
     """
     self.m_file_handler = CFileHandle()
     self.m_class_set = set()
     self.m_content = ""
     self.m_file_path = ""
     self.__compare_file_path(file_path, root)
Exemplo n.º 12
0
class CWriteParamClass(CWriteBase):
    """Generate ``<name>_db_param.go``: one Go struct per method input/output.

    Struct names already emitted are tracked in ``m_class_set``, so each
    struct is written at most once per generator instance.
    """

    def __init__(self, file_path, root="."):
        self.m_file_handler = CFileHandle()
        self.m_file_path = ""
        self.m_content = ""
        self.m_class_set = set()
        self.__compare_file_path(file_path, root)

    def __class_is_writed(self, is_input, method):
        """Return True if the struct was already emitted, else record it and return False."""
        if is_input is True:
            struct_name = self.get_input_struct_name(method)
        else:
            struct_name = self.get_output_struct_name(method)
        seen = struct_name in self.m_class_set
        if not seen:
            self.m_class_set.add(struct_name)
        return seen

    def __compare_file_path(self, file_path, root):
        """Set ``m_file_path`` to ``<root>/<stem of file_path>_db_param.go``."""
        stem = os.path.splitext(os.path.basename(file_path))[0]
        self.m_file_path = os.path.join(root, stem + "_db_param.go")

    def write(self, info_dict):
        """Generate the Go file from *info_dict* and write it as UTF-8.

        Raises:
            RuntimeError: if the namespace (used as the package) is missing.
        """
        package = info_dict.get(CSqlParse.NAMESPACE)
        if package in (None, ""):
            raise RuntimeError("namespace is empty")
        self.__write_header(package)
        structs = self.__write_structs(info_dict.get(CSqlParse.METHOD_LIST))
        self.m_content += structs
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __write_structs(self, method_list):
        """Return the input and output structs for every method."""
        blocks = []
        for method in method_list:
            in_params = method.get(CSqlParse.INPUT_PARAMS)
            out_params = method.get(CSqlParse.OUTPUT_PARAMS)
            name = method.get(CSqlParse.FUNC_NAME)
            if in_params is not None:
                blocks.append(self.__write_struct(
                    method, True, name, in_params,
                    self.get_input_struct_name(method)))
            if out_params is not None:
                blocks.append(self.__write_struct(
                    method, False, name, out_params,
                    self.get_output_struct_name(method)))
        return "".join(blocks)

    def __write_struct(self, method, is_input, method_name, params,
                       struct_name):
        """Return one Go struct, or "" if a struct of that name was already emitted.

        Every param yields an exported field plus a companion bool field
        named with ``get_isvail_join_str()``.

        Raises:
            SystemExit: if a param lacks its type or name.
        """
        if self.__class_is_writed(is_input, method) is True:
            return ""
        lines = ["type {0} struct".format(struct_name), " {\n"]
        for param in params:
            p_type = param.get(CSqlParse.PARAM_TYPE)
            p_name = param.get(CSqlParse.PARAM_NAME)
            if p_type is None or p_name is None:
                raise SystemExit(
                    "[Error] method: {0}, type or name is none".format(
                        method_name))
            p_type = self.type_change(p_type)
            field = CStringTools.upperFirstByte(p_name)
            lines.append("\t{0} {1}\n".format(field, p_type))
            lines.append("\t{0}{1} bool\n".format(
                field, self.get_isvail_join_str()))
        lines.append("}\n\n")
        return "".join(lines)

    def __write_header(self, namespace):
        """Append the Go ``package`` clause for *namespace* to ``m_content``."""
        self.m_content += "package {0}\n\n".format(namespace)
Exemplo n.º 13
0
class CWriteSqliteImpH(CWriteBase):
    """Generate ``<name>_db_handler.h`` declaring a sqlite-backed ``CDbHandler``.

    The declared class holds a ``sqlite3 *`` handle and a ``std::mutex`` and
    gets one method declaration per parsed method (via
    ``write_method_define``).
    """

    def __init__(self, file_path, root="."):
        self.m_file_handler = CFileHandle()
        self.m_file_name = ""
        self.m_file_path = ""
        self.m_content = ""
        self.m_namespace = ""
        self.__compare_file_path(file_path, root)

    def __compare_file_path(self, file_path, root):
        """Record the input stem and derive ``<root>/<stem>_db_handler.h``."""
        stem = os.path.splitext(os.path.basename(file_path))[0]
        self.m_file_name = stem
        self.m_file_path = os.path.join(root, stem + "_db_handler.h")

    def define_name(self):
        """Return the include-guard macro derived from the namespace."""
        return "__{0}_DB_HANDLER_H__".format(self.m_namespace.upper())

    def include_sys_list(self):
        """Return the system headers to include."""
        return ["stdint.h", "mutex"]

    def include_other_list(self):
        """Return the local headers: the matching param header."""
        return ["{0}_db_param.h".format(self.m_file_name)]

    def namespace(self):
        """Return the namespace of the generated code."""
        return self.m_namespace

    def class_name(self):
        """Return the generated class name."""
        return "CDbHandler"

    def write(self, info_dict):
        """Generate the header from *info_dict* and write it as UTF-8.

        Raises:
            RuntimeError: if the namespace is missing or empty.
        """
        namespace = info_dict.get(CSqlParse.NAMESPACE)
        if namespace in (None, ""):
            raise RuntimeError("namespace is empty")
        self.m_namespace = namespace
        sections = [
            self.write_header(),
            self.write_includes(),
            # Forward declaration so the header needs no sqlite3.h include.
            "struct sqlite3;\n",
            self.write_namespace_define(),
            self.write_class_define(),
            self.__write_implement(info_dict),
            self.write_class_end(),
            self.write_namespace_end(),
            self.write_tail(),
        ]
        self.m_content += "".join(sections)
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __write_implement(self, info_dict):
        """Return constructor/destructor declarations, methods and members."""
        ctor_and_dtor = [
            "public:\n",
            "\texplicit {0}(const std::string &dbpath, bool isMemory = false);\n".format(
                self.class_name()),
            "\tvirtual ~{0}();\n".format(self.class_name()),
        ]
        methods = info_dict.get(CSqlParse.METHOD_LIST)
        return ("".join(ctor_and_dtor) + self.__write_methods(methods)
                + self.__write_private_member())

    def __write_methods(self, method_list):
        """Return the public method declarations after the start marker."""
        chunks = ["\n", "public:\n", "\t/*@@start@@*/\n"]
        chunks.extend(self.write_method_define(info) for info in method_list)
        return "".join(chunks)

    def __write_private_member(self):
        """Return the private section: the sqlite handle and its mutex."""
        return ("\n" + "private:\n"
                + "\tsqlite3 *m_db;\n"
                + "\tstd::mutex m_mutex;\n")
Exemplo n.º 14
0
class CWriteInterface(CWriteBase):
    DICT_KEY_METHOD_DEFINE = "method_define"
    DICT_KEY_PROCEDURE_INFO = "procedure_info"
    DICT_KEY_INPUT_PARAMS = "input_params"
    DICT_KEY_OUTPUT_PARAMS = "output_params"
    DICT_KEY_OUTPUT_CLASSNAME = "output_classname"

    def __init__(self, parser, file_path, root="."):
        """Store *parser*, reset generator state and compute the output path.

        *parser* is kept for ``get_methodinfo_by_methodname`` lookups; the
        path is computed by ``__compare_file_path(file_path, root)``.
        """
        self.m_file_handler = CFileHandle()
        self.m_parser = parser
        self.m_procedure_info_list = []
        self.m_namespace = ""
        self.m_content = ""
        self.m_file_path = ""
        self.m_file_name = ""
        self.__compare_file_path(file_path, root)

    def __compare_file_path(self, file_path, root):
        """Record the input stem and set ``m_file_path`` to ``<root>/<stem>_db_handler.go``."""
        stem, _ext = os.path.splitext(os.path.basename(file_path))
        self.m_file_name = stem
        self.m_file_path = os.path.join(root, "{0}_db_handler.go".format(stem))

    def get_class_name(self):
        """Return the name of the generated Go handler struct."""
        handler_struct = "CDbHandler"
        return handler_struct

    def write(self, info_dict):
        """Generate the Go handler file from *info_dict* and write it as UTF-8.

        Raises:
            RuntimeError: if the namespace is missing or empty.
            SystemExit: if the method list is missing.
        """
        package = info_dict.get(CSqlParse.NAMESPACE)
        if package in (None, ""):
            raise RuntimeError("namespace is empty")
        self.m_namespace = package
        self.__write_header(package)
        fixed_sections = (
            self.__write_stuct,
            self.__write_connect,
            self.__write_connect_by_rule,
            self.__write_connect_by_cfg,
            self.__write_disconnect,
        )
        for section in fixed_sections:
            self.m_content += section()
        tables_sql = info_dict.get(CSqlParse.CREATE_TABELS_SQL)
        function_sqls = info_dict.get(CSqlParse.CREATE_FUNCTION_SQLS)
        self.m_content += self.__write_create(tables_sql, function_sqls)
        # One wrapper per stored procedure / method.
        methods = info_dict.get(CSqlParse.METHOD_LIST)
        if methods is None:
            raise SystemExit("[Error] method is None")
        self.m_content += self.__write_struct_method(methods)
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __join_method_param(self, method, method_define, param_no):
        """Append the Go parameter list of *method* to *method_define*.

        Parameters are numbered from *param_no* upward (``input0``,
        ``output0``, ``input1``, ...). When the method has a
        ``SUB_FUNC_SORT_LIST``, the listed functions are joined in order:
        the method's own entry uses its own params, and other entries are
        looked up through ``m_parser`` and joined recursively.

        Returns:
            ``(method_define, param_no)``: the extended definition and the
            next free parameter number. A ``None`` *method* (e.g. a
            sub-function the parser does not know) contributes nothing.
        """
        if method is None:
            # Always return a pair: every caller unpacks two values.
            return method_define, param_no
        sub_func_list = method.get(CSqlParse.SUB_FUNC_SORT_LIST)
        func_name = method.get(CSqlParse.FUNC_NAME)
        input_params = method.get(CSqlParse.INPUT_PARAMS)
        output_params = method.get(CSqlParse.OUTPUT_PARAMS)
        has_params = input_params is not None or output_params is not None

        def inner(method_define, param_no):
            """Append this method's own input/output parameters."""
            input_class_name = self.get_input_struct_name(method)
            output_class_name = self.get_output_struct_name(method)
            # Array ("true") inputs/outputs are passed as Go slices.
            input_str = input_class_name
            if method.get(CSqlParse.IN_ISARR) == "true":
                input_str = "[]{0}".format(input_class_name)
            output_str = output_class_name
            if method.get(CSqlParse.OUT_ISARR) == "true":
                output_str = "[]{0}".format(output_class_name)
            has_input = input_params is not None and len(input_params) > 0
            has_output = output_params is not None and len(output_params) > 0
            if has_input and has_output:
                method_define += "input{2} *{0}, output{2} *{1}".format(
                    input_str, output_str, str(param_no))
            elif has_input:
                method_define += "input{1} *{0}".format(
                    input_str, str(param_no))
            elif has_output:
                method_define += "output{1} *{0}".format(
                    output_str, str(param_no))
            # The number advances even when nothing was appended.
            param_no += 1
            return method_define, param_no

        if sub_func_list is None:
            return inner(method_define, param_no)
        length = len(sub_func_list)
        for i, (sub_func_name, _sub_func_index) in enumerate(sub_func_list, 1):
            if func_name == sub_func_name:
                method_define, param_no = inner(method_define, param_no)
            else:
                method_info = self.m_parser.get_methodinfo_by_methodname(
                    sub_func_name)
                method_define, param_no = self.__join_method_param(
                    method_info, method_define, param_no)
            # NOTE(review): the separator depends on this method's params,
            # not on whether the entry just joined contributed anything.
            if i < length and has_params:
                method_define += ", "
        return method_define, param_no

    def __sub_func_index_change(self, sub_func_index):
        """Encode a signed sub-function index string as a short suffix.

        "" -> "", negative -> "N" + absolute value, positive -> "P" + the
        original string unchanged, zero -> "0".

        Raises:
            ValueError: if a non-empty value is not an integer literal.
        """
        if sub_func_index == "":
            return ""
        value = int(sub_func_index)
        if value < 0:
            return "N" + str(-value)
        if value > 0:
            return "P" + sub_func_index
        return "0"

    def __write_struct_method(self, method_list):
        """Return one Go method on the handler struct per eligible method.

        A method is eligible when it carries ``IS_BRACE`` and is not flagged
        with ``IS_GROUP`` True. Each generated method returns
        ``(error, uint64)``.
        """
        functions = []
        for method in method_list:
            if method.get(CSqlParse.IS_BRACE) is None:
                continue
            if method.get(CSqlParse.IS_GROUP) is True:
                continue
            go_name = self.get_interface_name(method.get(CSqlParse.FUNC_NAME))
            params, _ = self.__join_method_param(method, "", 0)
            if params is None:
                break
            signature = "{0}({1})".format(go_name, params)
            functions.append("func (this *{0}) ".format(self.get_class_name())
                             + signature + " (error, uint64) {\n")
            functions.append(self.get_method_imp(method))
            functions.append("}\n\n")
        return "".join(functions)

    def __replace_sql_brace(self, input_params, sql, is_group):
        """Replace brace placeholders in *sql* with driver placeholders.

        Each distinct placeholder ``(number, keyword)`` found by
        ``CStringTools.get_brace_format_list`` is replaced by ``%s`` when
        ``input_params[number]`` is a condition param, else by ``?``.

        Returns:
            ``(sql, fulls)``, or ``(sql, [])`` when *input_params* is None.

        Raises:
            SystemExit: for non-group SQL, when the distinct placeholder count
                differs from the param count or an index exceeds the params.
        """
        if input_params is None:
            return sql, []
        fulls, max_number = CStringTools.get_brace_format_list(sql)
        n_params = len(input_params)
        distinct = set(fulls)
        n_distinct = len(distinct)
        if is_group is False:
            if n_params != n_distinct:
                raise SystemExit(
                    "[Param Length Error] may be last #define error ? fulllen length({1}) != params length({2})\n[sql] : \t{0}".format(
                        sql, n_distinct, n_params))
            if n_params < max_number + 1:
                raise SystemExit(
                    "[Param Match Error] may be last #define error ? input param length == {1}, max index == {2}\n[sql] : \t{0}".format(
                        sql, n_params, max_number))
        for number, keyword in distinct:
            is_condition = input_params[number].get(
                CSqlParse.PARAM_IS_CONDITION) is True
            placeholder = "%s" if is_condition else "?"
            # NOTE(review): keyword is used as a regex pattern; if it is meant
            # literally, re.escape(keyword) may be needed — confirm.
            sql = re.sub(keyword, placeholder, sql)
        return sql, fulls

    def get_method_imp(self, method):
        """Return the body of the generated Go method for *method*.

        The emitted body:
        - declares ``rowCount``;
        - fails with "db is nil" when ``this.m_db`` is unset;
        - begins a transaction;
        - appends the statements produced by ``__write_input``;
        - commits.

        The generated method returns ``(error, uint64)``: the error (nil on
        success) and ``rowCount``, which the query loops increment once per
        row read.

        Args:
            method: parser method-info mapping of the method being generated.

        Returns:
            str: Go source lines indented by one tab, without the enclosing
            ``func ... {`` / ``}``.
        """
        content = ""
        content += "\t" * 1 + 'var rowCount uint64 = 0\n'
        content += "\t" * 1 + "if this.m_db == nil {\n"
        content += "\t" * 1 + '\treturn errors.New("db is nil"), 0\n'
        content += "\t" * 1 + "}\n"
        content += "\t" * 1 + "tx, err := this.m_db.Begin()\n"
        content += "\t" * 1 + "if err != nil {\n"
        content += "\t" * 1 + '\treturn err, 0\n'
        content += "\t" * 1 + "}\n"
        content += "\t" * 1 + "var result sql.Result\n"
        content += "\t" * 1 + "var _ = result\n"
        content += "\t" * 1 + "var _ error = err\n"
        statements, _ = self.__write_input(method, "", 0)
        content += statements
        # Commit can fail. Report that error instead of returning nil as if
        # the transaction had been persisted. `err` is the function-scope
        # variable declared by `tx, err := this.m_db.Begin()` above.
        content += "\t" * 1 + "err = tx.Commit()\n"
        content += "\t" * 1 + "if err != nil {\n"
        content += "\t" * 2 + "return err, rowCount\n"
        content += "\t" * 1 + "}\n"
        content += "\t" * 1 + 'return nil, rowCount\n'
        return content

    def __write_input(self, method, content, param_no):
        """Append the Go statements that execute *method*'s SQL inside ``tx``.

        For every SQL statement, the generated code:
        - calls ``tx.Query`` when the method has output params, otherwise
          ``tx.Exec``;
        - rolls back and returns ``err, rowCount`` on error;
        - for queries, walks the rows, counting them in ``rowCount`` and
          filling them via ``__write_output``.

        When IN_ISARR is "true", the statement is wrapped in
        ``for _, v := range *input{n}`` and ``v`` supplies the arguments.

        When SUB_FUNC_SORT_LIST is set, its entries are emitted in order. The
        entry naming this method emits its own SQL. Any other entry is looked
        up through ``self.m_parser`` and emitted recursively.

        Placeholders of parameters flagged PARAM_IS_CONDITION become ``%s``
        and are spliced into the SQL text with ``fmt.Sprintf``. NOTE(review):
        this is only safe for trusted values. All other placeholders become
        ``?`` bind arguments. NOTE(review): a literal ``%`` in the SQL (e.g. a
        LIKE pattern) is also interpreted by Sprintf.

        Args:
            method: parser method-info mapping, read through ``CSqlParse`` keys.
            content: Go source generated so far; new statements are appended.
            param_no: index of the next ``input{n}`` / ``rows{n}`` /
                ``output{n}`` variable.

        Returns:
            tuple: ``(content, param_no)``, i.e. the extended source and the
            next free index. One index is consumed per emitted SQL statement.
        """
        in_ismul = method.get(CSqlParse.IN_ISARR) == "true"
        func_name = method.get(CSqlParse.FUNC_NAME)
        input_params = method.get(CSqlParse.INPUT_PARAMS)
        output_params = method.get(CSqlParse.OUTPUT_PARAMS)
        sub_func_list = method.get(CSqlParse.SUB_FUNC_SORT_LIST)
        has_output = output_params is not None and len(output_params) > 0

        def inner(content, param_no):
            """Emit this method's own SQL statement; return (content, next index)."""
            # Backslashes (line-continuation residue) are removed from the SQL.
            sql = re.sub(r"\\", "", method.get(CSqlParse.SQL))
            tc = 1
            var_name = "input{0}".format(str(param_no))
            if in_ismul:
                tc = 2
                var_name = "v"
                content += "\t" * 1 + "for _, v := range *input" + str(
                    param_no) + " {\n"
            if has_output:
                content += "\t" * tc + "rows{0}, err := tx.Query(".format(
                    str(param_no))
            else:
                content += "\t" * tc + "result, err = tx.Exec("
            sql, fulls = self.__replace_sql_brace(input_params, sql, False)
            content += 'fmt.Sprintf(`{0}`'.format(sql)
            # The Sprintf arguments must line up with the "%s" markers: one per
            # placeholder occurrence, in the order given by ``fulls``.
            # __write_query_params derives the "?" binds from the same
            # sequence. ``fulls`` is assumed to be in textual order; TODO
            # confirm against CStringTools.get_brace_format_list. Iterating
            # input_params instead would misplace out-of-order conditions and
            # drop repeated ones.
            for number, _keyword in fulls:
                param = input_params[number]
                if param.get(CSqlParse.PARAM_IS_CONDITION) is True:
                    content += ", {1}.{0}".format(
                        CStringTools.upperFirstByte(
                            param.get(CSqlParse.PARAM_NAME)), var_name)
            content += ")"
            content += self.__write_query_params(input_params, var_name, fulls)
            content += ")\n"
            content += "\t" * tc + 'if err != nil {\n'
            content += "\t" * (tc + 1) + 'tx.Rollback()\n'
            content += "\t" * (tc + 1) + 'return err, rowCount\n'
            content += "\t" * tc + '}\n'
            if has_output:
                content += "\t" * tc + 'defer rows{0}.Close()\n'.format(
                    str(param_no))
                content += "\t" * tc + 'for rows' + str(
                    param_no) + '.Next() {\n'
                content += "\t" * (tc + 1) + 'rowCount += 1\n'
                content += self.__write_output(tc + 1, method, param_no)
                content += "\t" * tc + '}\n'
            else:
                content += "\t" * tc + "var _ = result\n"
            if in_ismul:
                content += "\t" * 1 + '}\n'
            return content, param_no + 1

        if sub_func_list is None:
            return inner(content, param_no)
        for sub_func_name, _sub_func_index in sub_func_list:
            if func_name == sub_func_name:
                content, param_no = inner(content, param_no)
            else:
                method_info = self.m_parser.get_methodinfo_by_methodname(
                    sub_func_name)
                content, param_no = self.__write_input(method_info, content,
                                                       param_no)
        return content, param_no

    def __write_output(self, tc, method, param_no):
        """Return Go statements that scan the current row of ``rows{param_no}``.

        The fragment is indented with *tc* tabs. It:
        - declares one local per output parameter, typed via
          ``type_null_change``;
        - scans the row into those locals; on a scan error, ``continue``
          skips the row;
        - copies each value (via ``type_back``) and its ``.Valid`` flag into
          the result.

        The result target depends on OUT_ISARR. When it is "true", the values
        go into ``tmp`` (a fresh output struct), which is appended to
        ``*output{param_no}``. Otherwise they go into ``output{param_no}``
        directly.

        Returns:
            str: the Go fragment, or "" when the method has no output params.
        """
        out_ismul = method.get(CSqlParse.OUT_ISARR) == "true"
        output_params = method.get(CSqlParse.OUTPUT_PARAMS)
        output_class_name = self.get_output_struct_name(method)
        if output_params is None or len(output_params) == 0:
            return ""
        pad = "\t" * tc
        suffix = str(param_no)
        parts = []
        if out_ismul:
            parts.append(pad + "tmp := {0}".format(output_class_name) + "{}\n")
        for param in output_params:
            go_type = self.type_null_change(param.get(CSqlParse.PARAM_TYPE))
            parts.append(pad + "var {0} {1}\n".format(
                param.get(CSqlParse.PARAM_NAME), go_type))
        scan_targets = ", ".join(
            "&{0}".format(param.get(CSqlParse.PARAM_NAME))
            for param in output_params)
        parts.append(pad + "scanErr := rows{0}.Scan(".format(suffix)
                     + scan_targets + ")\n")
        parts.append(pad + "if scanErr != nil {\n")
        parts.append("\t" * (tc + 1) + "continue\n")
        parts.append(pad + "}\n")
        target = "tmp" if out_ismul else "output{0}".format(suffix)
        for param in output_params:
            name = param.get(CSqlParse.PARAM_NAME)
            field = CStringTools.upperFirstByte(name)
            parts.append(pad + "{0}.{1} = {2}\n".format(
                target, field,
                self.type_back(param.get(CSqlParse.PARAM_TYPE), name)))
            parts.append(pad + "{0}.{1}{3} = {2}.Valid\n".format(
                target, field, name, self.get_isvail_join_str()))
        if out_ismul:
            parts.append(
                pad + "*output{0} = append(*output{0}, tmp)\n".format(suffix))
        return "".join(parts)

    def __write_query_params(self, input_params, var_name, fulls):
        """Return the bind-argument suffix of a generated ``tx.Query``/``tx.Exec``.

        One ``<var_name>.<Field>`` argument is produced for each entry of
        *fulls* whose parameter is not flagged PARAM_IS_CONDITION. Condition
        params are spliced into the SQL text with ``fmt.Sprintf`` instead.
        Arguments follow the order of *fulls*, which must match the order of
        the ``?`` markers in the SQL. That order is presumably textual; TODO
        confirm against ``CStringTools.get_brace_format_list``.

        Args:
            input_params: sequence of parameter-info mappings, or None.
            var_name: Go expression holding the input struct (e.g. "input0").
            fulls: ``(index, pattern)`` pairs from ``__replace_sql_brace``.

        Returns:
            str: a ``", a.X, a.Y"`` style suffix, or "" when nothing is bound.
        """
        if input_params is None:
            return ""
        bound = []
        for number, _keyword in fulls:
            param = input_params[number]
            if param.get(CSqlParse.PARAM_IS_CONDITION) is True:
                continue
            bound.append("{0}.{1}".format(
                var_name,
                CStringTools.upperFirstByte(param.get(CSqlParse.PARAM_NAME))))
        if not bound:
            return ""
        return ", " + ", ".join(bound)

    def __write_stuct(self):
        """Return the Go struct declaration of the generated DB handler.

        The struct, named by ``get_class_name()``, has a single
        ``m_db *sql.DB`` field.
        """
        struct_name = self.get_class_name()
        return "".join((
            "type {0} struct ".format(struct_name),
            " {\n",
            "\t" * 1 + "m_db *sql.DB\n",
            "}\n\n",
        ))

    def __write_connect(self):
        """Return Go source for the handler's ``Connect`` method.

        The generated method works as follows:
        - For ``dbtype == "mysql"`` it builds the DSN
          ``username:userpwd@tcp(host:port)/dbname``.
        - For ``"sqlite3"`` it uses *dbname* as the data source name.
        - Any other type returns a "dbtype not support" error.

        It then opens the database into ``this.m_db`` and sizes the pool
        (2000 open / 1000 idle connections).
        """
        content = ""
        # Use the same receiver type as the struct emitted by __write_stuct
        # (get_class_name()) rather than a hard-coded "CDbHandler".
        content += "func (this *{0}) Connect(host string, port uint, username string, userpwd string, dbname string, dbtype string) (err error) ".format(
            self.get_class_name()) + "{\n"
        content += "\t" * 1 + "b := bytes.Buffer{}\n"
        content += "\t" * 1 + "b.WriteString(username)\n"
        content += "\t" * 1 + 'b.WriteString(":")\n'
        content += "\t" * 1 + 'b.WriteString(userpwd)\n'
        content += "\t" * 1 + 'b.WriteString("@tcp(")\n'
        content += "\t" * 1 + 'b.WriteString(host)\n'
        content += "\t" * 1 + 'b.WriteString(":")\n'
        content += "\t" * 1 + 'b.WriteString(strconv.FormatUint(uint64(port), 10))\n'
        content += "\t" * 1 + 'b.WriteString(")/")\n'
        content += "\t" * 1 + 'b.WriteString(dbname)\n'
        content += "\t" * 1 + 'var name string\n'
        content += "\t" * 1 + 'if dbtype == "mysql" {\n'
        content += "\t" * 2 + 'name = b.String()\n'
        content += "\t" * 1 + '} else if dbtype == "sqlite3" {\n'
        content += "\t" * 2 + 'name = dbname\n'
        content += "\t" * 1 + '} else {\n'
        content += "\t" * 2 + 'return errors.New("dbtype not support")\n'
        content += "\t" * 1 + '}\n'
        content += "\t" * 1 + 'this.m_db, err = sql.Open(dbtype, name)\n'
        content += "\t" * 1 + 'if err != nil {\n'
        content += "\t" * 2 + 'return err\n'
        content += "\t" * 1 + '}\n'
        content += "\t" * 1 + 'this.m_db.SetMaxOpenConns(2000)\n'
        content += "\t" * 1 + 'this.m_db.SetMaxIdleConns(1000)\n'
        # NOTE(review): Ping's error is discarded, so Connect returns nil even
        # if the server cannot be reached.
        content += "\t" * 1 + 'this.m_db.Ping()\n'
        content += "\t" * 1 + 'return nil\n'
        content += "}\n\n"
        return content

    def __write_connect_by_rule(self):
        """Return Go source for the handler's ``ConnectByRule`` method.

        The generated method passes *rule* verbatim as the data source name
        to ``sql.Open(dbtype, rule)`` and sizes the pool like ``Connect``.
        """
        content = ""
        # Receiver type matches the struct emitted by __write_stuct.
        content += "func (this *{0}) ConnectByRule(rule string, dbtype string) (err error) ".format(
            self.get_class_name()) + "{\n"
        content += "\t" * 1 + 'this.m_db, err = sql.Open(dbtype, rule)\n'
        content += "\t" * 1 + 'if err != nil {\n'
        content += "\t" * 2 + 'return err\n'
        content += "\t" * 1 + '}\n'
        content += "\t" * 1 + 'this.m_db.SetMaxOpenConns(2000)\n'
        content += "\t" * 1 + 'this.m_db.SetMaxIdleConns(1000)\n'
        # NOTE(review): Ping's error is discarded, as in Connect.
        content += "\t" * 1 + 'this.m_db.Ping()\n'
        content += "\t" * 1 + 'return nil\n'
        content += "}\n\n"
        return content

    def __write_connect_by_cfg(self):
        """Return Go source for the handler's ``ConnectByCfg`` method.

        The generated method reads ``key=value`` lines from the file at
        *path*. Recognised keys are host, port, username, userpwd, dbname
        and dbtype; other lines are ignored. Unset keys keep the defaults:
        localhost, 3306, root, 123456, test, mysql. The method then
        delegates to ``Connect``.
        """
        content = ""
        # Receiver type matches the struct emitted by __write_stuct.
        content += "func (this *{0}) ConnectByCfg(path string) error ".format(
            self.get_class_name()) + "{\n"
        content += "\t" * 1 + 'fi, err := os.Open(path)\n'
        content += "\t" * 1 + 'if err != nil {\n'
        content += "\t" * 2 + 'return err\n'
        content += "\t" * 1 + '}\n'
        content += "\t" * 1 + 'defer fi.Close()\n'
        content += "\t" * 1 + 'br := bufio.NewReader(fi)\n'
        content += "\t" * 1 + 'var host string = "localhost"\n'
        content += "\t" * 1 + 'var port uint = 3306\n'
        content += "\t" * 1 + 'var username string = "root"\n'
        content += "\t" * 1 + 'var userpwd string = "123456"\n'
        content += "\t" * 1 + 'var dbname string = "test"\n'
        content += "\t" * 1 + 'var dbtype string = "mysql"\n'
        # Compile the pattern once, outside the loop. The key is everything
        # before the FIRST "=", so a value such as a password may itself
        # contain "=". The previous greedy "(.*)?=(.*)?" split at the last
        # "=" instead.
        content += "\t" * 1 + 'r := regexp.MustCompile("^([^=]*)=(.*)$")\n'
        content += "\t" * 1 + 'for {\n'
        content += "\t" * 2 + 'a, _, c := br.ReadLine()\n'
        content += "\t" * 2 + 'if c == io.EOF {\n'
        content += "\t" * 3 + 'break\n'
        content += "\t" * 2 + '}\n'
        # Any other read error is returned; previously only io.EOF ended the
        # loop and other errors were ignored.
        content += "\t" * 2 + 'if c != nil {\n'
        content += "\t" * 3 + 'return c\n'
        content += "\t" * 2 + '}\n'
        content += "\t" * 2 + 'content := string(a)\n'
        content += "\t" * 2 + 'ret := r.FindStringSubmatch(content)\n'
        content += "\t" * 2 + 'if len(ret) != 3 {\n'
        content += "\t" * 3 + 'continue\n'
        content += "\t" * 2 + '}\n'
        content += "\t" * 2 + 'k := ret[1]\n'
        content += "\t" * 2 + 'v := ret[2]\n'
        content += "\t" * 2 + 'switch k {\n'
        content += "\t" * 2 + 'case "host":\n'
        content += "\t" * 3 + 'host = v\n'
        content += "\t" * 2 + 'case "port":\n'
        content += "\t" * 3 + 'port_tmp, _ := strconv.ParseUint(v, 10, 32)\n'
        content += "\t" * 3 + 'port = uint(port_tmp)\n'
        content += "\t" * 2 + 'case "username":\n'
        content += "\t" * 3 + 'username = v\n'
        content += "\t" * 2 + 'case "userpwd":\n'
        content += "\t" * 3 + 'userpwd = v\n'
        content += "\t" * 2 + 'case "dbname":\n'
        content += "\t" * 3 + 'dbname = v\n'
        content += "\t" * 2 + 'case "dbtype":\n'
        content += "\t" * 3 + 'dbtype = v\n'
        content += "\t" * 2 + '}\n'
        content += "\t" * 1 + '}\n'
        content += "\t" * 1 + 'return this.Connect(host, port, username, userpwd, dbname, dbtype)\n'
        content += "}\n\n"
        return content

    def __write_disconnect(self):
        """Return Go source for the handler's ``Disconnect`` method.

        The generated method closes ``this.m_db``. It does nothing when no
        connection was ever opened, mirroring the nil check in the generated
        query methods.
        """
        content = ""
        # Receiver type matches the struct emitted by __write_stuct.
        content += "func (this *{0}) Disconnect() ".format(
            self.get_class_name()) + "{\n"
        content += "\t" * 1 + 'if this.m_db != nil {\n'
        content += "\t" * 2 + 'this.m_db.Close()\n'
        content += "\t" * 1 + '}\n'
        content += "}\n\n"
        return content

    def __write_create(self, create_sql, create_functions):
        """Return Go source for the handler's ``Create`` method.

        The generated method executes each ``;``-separated statement of
        *create_sql*, then each entry of *create_functions* verbatim, and
        returns the first error it meets.

        Args:
            create_sql: table-creation SQL. Backslashes are stripped and the
                text is split on ``;``. NOTE(review): a ``;`` inside a string
                literal or a trigger body is split as well.
            create_functions: iterable of SQL strings executed one by one;
                None is treated as empty.
        """
        content = ""
        create_sql = re.sub(r"\\", "", create_sql)
        sqls = create_sql.split(";")
        # Receiver type matches the struct emitted by __write_stuct.
        content += "func (this *{0}) Create() (error) ".format(
            self.get_class_name()) + "{\n"
        content += "\t" * 1 + "var err error = nil\n"
        content += "\t" * 1 + "var _ error = err\n"

        def err_content():
            """Go snippet returning ``err`` if the previous Exec failed."""
            con = ""
            con += "\t" * 1 + "if err != nil {\n"
            con += "\t" * 2 + "return err\n"
            con += "\t" * 1 + "}\n"
            return con

        for sql in sqls:
            # Skip blank fragments, e.g. the "\n" left after the final ";".
            # Emitting them would only Exec a bare ";".
            if sql.strip() == "":
                continue
            content += "\t" * 1 + "_, err = this.m_db.Exec(`{0}`)\n".format(
                sql + ";")
            content += err_content()
        for sql in create_functions or []:
            content += "\t" * 1 + "_, err = this.m_db.Exec(`{0}`)\n".format(
                sql)
            content += err_content()
        content += "\t" * 1 + "return nil\n"
        content += "}\n\n"
        return content

    def __write_header(self, namespace):
        """Append the Go ``package`` clause and import block to ``self.m_content``.

        The import list is fixed. It covers the packages referenced by the
        code this writer generates:
        - bufio, io, os, regexp, strconv: ConnectByCfg;
        - bytes: Connect;
        - fmt: the Sprintf-wrapped queries;
        - errors: error values;
        - database/sql: the handle type.
        """
        packages = ("bufio", "bytes", "database/sql", "io", "os", "regexp",
                    "strconv", "fmt", "errors")
        header = "package {0}\n\n".format(namespace) + "import (\n"
        header += "".join("\t" * 1 + '"{0}"\n'.format(pkg) for pkg in packages)
        header += ")\n\n"
        self.m_content += header
Exemplo n.º 15
0
class CWriteCMakeListsBase(object):
    """Template-method generator for a ``CMakeLists.txt``.

    Subclasses override the hook methods (``mode``, ``obj_name``,
    ``include_paths``, ``header_list``, ``link_lib_dirs``, the per-OS library
    lists and the two ``insert_*`` snippets) and then call :meth:`write`.
    While :meth:`is_debug` returns True, the script is printed instead of
    written to ``file_path``.
    """

    # Values returned by mode(). NOTE(review): write() emits no target for
    # MODE_DYNAMIC_LIB (neither add_library(... SHARED) nor add_executable),
    # so a subclass using it has to add the target itself, e.g. in
    # insert_befor_if().
    MODE_STATIC_LIB = 1
    MODE_DYNAMIC_LIB = 2
    MODE_EXECUTE = 3

    def __init__(self, file_path):
        self.m_file_handler = CFileHandle()
        # Destination of the generated script when not in debug mode.
        self.m_file_path = file_path

    def is_debug(self):
        """Return True to print the script instead of writing it (default)."""
        return True

    def mode(self):
        """Return the target kind, one of the MODE_* constants."""
        return CWriteCMakeListsBase.MODE_STATIC_LIB

    def include_paths(self):
        """Return the entries of ``include_directories(...)``."""
        return []

    def header_list(self):
        """Return extra entries appended to ``HEADER_LIST``."""
        return []

    def link_lib_dirs(self):
        """Return the entries of ``link_directories(...)`` (executables only)."""
        return []

    def win32_library_list(self):
        """Return libraries linked when CMAKE_SYSTEM_NAME matches "Windows"."""
        return []

    def linux_library_list(self):
        """Return libraries linked when CMAKE_SYSTEM_NAME matches "Linux"."""
        return []

    def obj_name(self):
        """Return the CMake target name."""
        return ""

    def insert_after_include_dirs(self):
        """Return raw CMake text placed after the include/link directories."""
        return ""

    def insert_befor_if(self):
        """Return raw CMake text placed after the target definition.

        For executables, this text comes before the platform-specific link
        block.
        """
        return ""

    def write(self, encoding="utf8"):
        """Assemble the script, then print it or write it to ``m_file_path``.

        Args:
            encoding: encoding handed to the file handler when writing.
        """
        mode = self.mode()
        is_exe = mode == CWriteCMakeListsBase.MODE_EXECUTE
        content = ""
        content += self.__write_include_dirs()
        if is_exe:
            content += self.__write_link_lib_dirs()
        content += self.insert_after_include_dirs()
        content += self.__write_base_h_cpp()
        content += self.__write_add_header()
        if is_exe:
            content += self.__write_add_execute()
            content += self.__write_set_target_properties()
        elif mode == CWriteCMakeListsBase.MODE_STATIC_LIB:
            content += self.__write_add_library()
        content += self.insert_befor_if()
        if is_exe:
            content += self.__write_target_link_libs()
        if self.is_debug() is True:
            print(content)
        else:
            self.m_file_handler.clear_write(content, self.m_file_path,
                                            encoding)

    def __write_include_dirs(self):
        """Return the ``include_directories (...)`` block."""
        content = ""
        content += "include_directories (\n"
        for d in self.include_paths():
            content += "\t" * 1 + d
            content += "\n"
        content += ")\n"
        content += "\n"
        return content

    def __write_link_lib_dirs(self):
        """Return the ``link_directories (...)`` block."""
        content = ""
        content += "link_directories (\n"
        for lib in self.link_lib_dirs():
            content += "\t" * 1 + lib
            content += "\n"
        content += ")\n"
        content += "\n"
        return content

    def __write_base_h_cpp(self):
        """Collect sources from ``source/`` and headers from ``include/*.h``."""
        content = ""
        content += "aux_source_directory(source SOURCE_LIST)\n"
        content += 'FILE (GLOB HEADER_LIST "include/*.h")\n'
        content += "\n"
        return content

    def __write_add_header(self):
        """Extend ``HEADER_LIST`` with the entries of :meth:`header_list`."""
        content = ""
        content += "set (HEADER_LIST\n"
        content += "\t" * 1 + "${HEADER_LIST}" + "\n"
        for h in self.header_list():
            content += "\t" * 1 + h
            content += "\n"
        content += ")\n"
        content += "\n"
        return content

    def __write_add_library(self):
        """Return the ``add_library (<name> STATIC ...)`` line."""
        content = ""
        content += "add_library (" + self.obj_name(
        ) + " STATIC ${SOURCE_LIST} ${HEADER_LIST})\n"
        content += "\n"
        return content

    def __write_add_execute(self):
        """Return the ``add_executable (<name> ...)`` line."""
        content = ""
        content += "add_executable (" + self.obj_name(
        ) + " ${SOURCE_LIST} ${HEADER_LIST})\n"
        content += "\n"
        return content

    def __write_set_target_properties(self):
        """Give Debug builds of the target a ``_d`` postfix."""
        content = ""
        content += 'set_target_properties({0} PROPERTIES DEBUG_POSTFIX "_d")\n'.format(
            self.obj_name())
        content += "\n"
        return content

    def __write_target_link_libs(self):
        """Return the per-OS ``target_link_libraries`` block.

        The Linux branch also adds ``-static`` to the C/C++ flags and, for
        Release builds, a post-build step that strips the binary and copies
        it to ``~/nfs/${PLATFORM_NAME}``.
        """
        content = ""
        content += 'if (CMAKE_SYSTEM_NAME MATCHES "Windows")\n'
        content += "\t" * 1 + "target_link_libraries ({0}\n".format(
            self.obj_name())
        for lib in self.win32_library_list():
            content += "\t" * 2 + lib + "\n"
        content += "\t" * 1 + ")\n"
        content += 'elseif (CMAKE_SYSTEM_NAME MATCHES "Linux")\n'
        content += "\t" * 1 + "target_link_libraries ({0}\n".format(
            self.obj_name())
        for lib in self.linux_library_list():
            content += "\t" * 2 + lib + "\n"
        content += "\t" * 1 + ")\n"
        content += "\n"
        content += "\t" * 1 + 'set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -static")\n'
        content += "\t" * 1 + 'set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -static")\n'
        content += "\n"
        content += "\n"
        # Name the variable instead of expanding it. if() dereferences a
        # variable name itself, whereas "${CMAKE_BUILD_TYPE}" expands to
        # nothing when no build type is set and leaves the malformed
        # "if ( STREQUAL ...)".
        content += "\t" * 1 + 'if (CMAKE_BUILD_TYPE STREQUAL "Release")\n'
        content += "\t" * 2 + 'add_custom_command(TARGET ' + self.obj_name(
        ) + ' POST_BUILD COMMAND echo "strip"\n'
        content += "\t" * 4 + 'COMMAND ${STRIP} ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/' + self.obj_name(
        ) + "\n"
        content += "\t" * 4 + 'COMMAND cp ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/' + self.obj_name(
        ) + ' ~/nfs/${PLATFORM_NAME} -f\n'
        content += "\t" * 4 + ")\n"
        content += "\t" * 1 + 'endif ()\n'
        content += "\n"
        content += "endif ()\n"
        return content
Exemplo n.º 16
0
class CWriteSqliteImpCpp(CWriteBase):
    """Writes ``<stem>_db_handler.cpp``, the C++ implementation of ``CDbHandler``.

    This variant targets a connection pool. The generated constructor takes
    a dial string, an ``sql::ISql *`` and a pool size, and runs every CREATE
    TABLE statement on one pooled connection. Method bodies come from
    ``write_method_implement``, presumably inherited from CWriteBase.
    """

    def __init__(self, parser, file_path, root="."):
        CWriteBase.__init__(self, parser)
        self.m_file_handler = CFileHandle()
        self.m_file_name = ""
        self.m_file_path = ""
        self.m_content = ""
        self.m_namespace = ""
        self.__compare_file_path(file_path, root)

    def __compare_file_path(self, file_path, root):
        """Set m_file_name to *file_path*'s stem and m_file_path to
        ``<root>/<stem>_db_handler.cpp``."""
        basename = os.path.basename(file_path)
        filename, _ = os.path.splitext(basename)
        self.m_file_name = filename
        self.m_file_path = os.path.join(root, filename + "_db_handler.cpp")

    def define_name(self):
        """Return the include-guard macro name derived from the namespace."""
        return "__{0}_DB_HANDLER_H__".format(self.m_namespace.upper())

    def include_sys_list(self):
        """Return the system headers included by the .cpp."""
        return ["stdio.h", "string.h", "sstream"]

    def include_other_list(self):
        """Return the project headers included by the .cpp."""
        return [self.m_file_name + "_db_handler.h"]

    def namespace(self):
        """Return the C++ namespace taken from the parsed info."""
        return self.m_namespace

    def class_name(self):
        """Return the generated C++ class name."""
        return "CDbHandler"

    def write(self, info_dict):
        """Render the .cpp for *info_dict* and write it to ``m_file_path``.

        Raises:
            RuntimeError: if the parsed info carries no namespace.
        """
        # Fetch the namespace
        namespace = info_dict.get(CSqlParse.NAMESPACE)
        if namespace is None or namespace == "":
            raise RuntimeError("namespace is empty")
        self.m_namespace = namespace
        content = ""
        # content += self.write_header()
        content += self.write_includes()
        content += self.write_namespace_define()
        # content += self.write_class_define()
        content += self.__write_implement(info_dict)
        # content += self.write_class_end()
        content += self.write_namespace_end()
        # content += self.write_tail()
        self.m_content += content
        # print(self.m_content)
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __write_implement(self, info_dict):
        """Return the constructor, destructor and method implementations."""
        create_table_list = info_dict.get(CSqlParse.CREATE_TABLE_LIST)
        content = ""
        content += "{0}::{0}(const std::string &dial, sql::ISql *s, int max)\n".format(
            self.class_name())
        content += "\t" * 1 + ": m_connPool(s, max)\n"
        content += "\t" * 1 + ', m_dial(dial)\n'
        content += "{\n"
        content += "\t" * 1 + 'sql::IConnect *conn = m_connPool.connect(m_dial);\n'
        # A missing CREATE_TABLE_LIST (None) means "no tables", the same way
        # the sqlite3 writer treats a missing CREATE_TABELS_SQL.
        if create_table_list:
            content += "\t" * 1 + "if (conn != nullptr) {\n"
            content += "\t" * 2 + 'std::string sql("");\n'
            for create_sql in create_table_list:
                content += "\t" * 2 + 'sql = "\\\n{0}";\n'.format(create_sql)
                content += "\t" * 2 + "conn->exec(sql);\n"
                content += "\t" * 2 + "conn->compare(sql);\n"
            content += "\t" * 1 + "}\n"
        content += "\t" * 1 + "m_connPool.freeConnect(conn);\n"
        content += "}\n"
        content += "\n"
        content += "{0}::~{0}()\n".format(self.class_name())
        content += "{\n"
        content += "}\n"
        method_list = info_dict.get(CSqlParse.METHOD_LIST)
        content += self.__write_methods(method_list)
        return content

    def __write_methods(self, method_list):
        """Return every method implementation followed by the
        ``/*@@start@@*/`` marker; None is treated as no methods."""
        content = ""
        content += "\n"
        for method_info in method_list or []:
            content += self.write_method_implement(method_info)
        content += "/*@@start@@*/" + "\n\n"
        return content
Exemplo n.º 17
0
class CWriteSqliteImpCpp(CWriteBase):
    """Writes ``<stem>_db_handler.cpp``, the C++ implementation of ``CDbHandler``.

    This variant uses SQLite directly:
    - The generated constructor opens *dbpath*, or an in-memory database
      when ``isMemory`` is true.
    - If the open succeeds, it executes the CREATE TABLE SQL.
    - The destructor closes the handle.
    """

    def __init__(self, file_path, root="."):
        # NOTE(review): CWriteBase.__init__ is not called here (the
        # parser-based variant calls it with a parser). Confirm the inherited
        # helpers need no state from it.
        self.m_file_handler = CFileHandle()
        self.m_file_name = ""
        self.m_file_path = ""
        self.m_content = ""
        self.m_namespace = ""
        self.__compare_file_path(file_path, root)

    def __compare_file_path(self, file_path, root):
        """Set m_file_name to *file_path*'s stem and m_file_path to
        ``<root>/<stem>_db_handler.cpp``."""
        basename = os.path.basename(file_path)
        filename, _ = os.path.splitext(basename)
        self.m_file_name = filename
        self.m_file_path = os.path.join(root, filename + "_db_handler.cpp")

    def define_name(self):
        """Return the include-guard macro name derived from the namespace."""
        return "__{0}_DB_HANDLER_H__".format(self.m_namespace.upper())

    def include_sys_list(self):
        """Return the system headers included by the .cpp."""
        return ["stdio.h", "string.h", "sstream"]

    def include_other_list(self):
        """Return the sqlite and project headers included by the .cpp."""
        return ["sqlite3.h", self.m_file_name + "_db_handler.h"]

    def namespace(self):
        """Return the C++ namespace taken from the parsed info."""
        return self.m_namespace

    def class_name(self):
        """Return the generated C++ class name."""
        return "CDbHandler"

    def write(self, info_dict):
        """Render the .cpp for *info_dict* and write it to ``m_file_path``.

        Raises:
            RuntimeError: if the parsed info carries no namespace.
        """
        # Fetch the namespace
        namespace = info_dict.get(CSqlParse.NAMESPACE)
        if namespace is None or namespace == "":
            raise RuntimeError("namespace is empty")
        self.m_namespace = namespace
        content = ""
        # content += self.write_header()
        content += self.write_includes()
        content += self.write_namespace_define()
        # content += self.write_class_define()
        content += self.__write_implement(info_dict)
        # content += self.write_class_end()
        content += self.write_namespace_end()
        # content += self.write_tail()
        self.m_content += content
        # print(self.m_content)
        self.m_file_handler.clear_write(self.m_content, self.m_file_path,
                                        "utf8")

    def __write_implement(self, info_dict):
        """Return the constructor, destructor and method implementations."""
        create_tables_sql = info_dict.get(CSqlParse.CREATE_TABELS_SQL)
        content = ""
        content += "{0}::{0}(const std::string &dbpath, bool isMemory)\n".format(
            self.class_name())
        content += "\t" * 1 + ": m_db(nullptr)\n"
        content += "\t" * 1 + ", m_mutex()\n"
        content += "{\n"
        content += "\t" * 1 + "sqlite3_threadsafe();\n"
        content += "\t" * 1 + "sqlite3_config(SQLITE_CONFIG_MULTITHREAD);\n"
        content += "\t" * 1 + "int ret = SQLITE_OK;\n"
        content += "\t" * 1 + "if (isMemory == false) {\n"
        content += "\t" * 2 + "ret = sqlite3_open(dbpath.c_str(), &m_db);\n"
        content += "\t" * 1 + "}\n"
        content += "\t" * 1 + "else {\n"
        content += "\t" * 2 + 'ret = sqlite3_open(":memory:", &m_db);\n'
        content += "\t" * 1 + "}\n"
        if create_tables_sql is not None:
            content += "\t" * 1 + "if (ret == SQLITE_OK) {\n"
            content += "\t" * 2 + 'std::string sql = "\\\n{0}";\n'.format(
                create_tables_sql)
            content += "\t" * 2 + "sqlite3_exec(m_db, sql.c_str(), nullptr, nullptr, nullptr);\n"
            content += "\t" * 1 + "}\n"
        content += "}\n"
        content += "\n"
        content += "{0}::~{0}()\n".format(self.class_name())
        content += "{\n"
        content += "\t" * 1 + "sqlite3_close(m_db);\n"
        content += "}\n"
        method_list = info_dict.get(CSqlParse.METHOD_LIST)
        content += self.__write_methods(method_list)
        return content

    def __write_methods(self, method_list):
        """Return every method implementation followed by the
        ``/*@@start@@*/`` marker; None is treated as no methods."""
        content = ""
        content += "\n"
        for method_info in method_list or []:
            content += self.write_method_implement(method_info)
        content += "/*@@start@@*/" + "\n\n"
        return content