コード例 #1
0
    def __init__(self, include_dir_paths=None):
        """Configure include search directories and create a scanner/parser.

        :param include_dir_paths: None, a single directory path, or a
            list/tuple of directory paths searched for included .thrift files.
            When no directories are supplied the current working directory is
            used. The lib/thrift/src directory four levels above this module
            is always added, and duplicates are dropped keeping the first
            occurrence, so the resulting order is the caller's order.
        """
        object.__init__(self)

        if include_dir_paths is None:
            candidate_dirs = []
        elif isinstance(include_dir_paths, (list, tuple)):
            candidate_dirs = list(include_dir_paths)
        else:
            candidate_dirs = [include_dir_paths]

        if not candidate_dirs:
            candidate_dirs.append(os.getcwd())

        this_module_dir = os.path.dirname(os.path.realpath(__file__))
        thrift_lib_src_dir = os.path.abspath(
            os.path.join(this_module_dir, '..', '..', '..', '..',
                         'lib', 'thrift', 'src'))
        # Appending unconditionally is safe: the de-duplication below keeps
        # only the first occurrence of each directory.
        candidate_dirs.append(thrift_lib_src_dir)

        self.__include_dir_paths = tuple(
            candidate_dir
            for position, candidate_dir in enumerate(candidate_dirs)
            if candidate_dir not in candidate_dirs[:position]
        )

        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()
コード例 #2
0
ファイル: compiler.py プロジェクト: minorg/thryft
    def __init__(self, document_root_dir_path, include_dir_paths):
        """Store the document root and build the ordered include search path.

        :param document_root_dir_path: root directory handed on to documents.
        :param include_dir_paths: None, a single directory path, or a
            list/tuple of directory paths searched for included .thrift files.
            None is treated as "no extra directories" rather than as a path.
            The lib/thrift/src directory four levels above this module is
            always appended, and duplicates are dropped keeping first
            occurrence.
        """
        object.__init__(self)

        self.__document_root_dir_path = document_root_dir_path

        # Previously None became [None], which later crashed in os.path.join
        # when resolving includes.
        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)
        # De-duplicate while preserving the caller's search order.
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        self.__document_cache = {}
        self.__scanner = Scanner()
        self.__parser = Parser()
コード例 #3
0
    def __init__(self, include_dir_paths=None):
        """Configure include search directories and create a scanner/parser.

        :param include_dir_paths: None, a single directory path, or a
            list/tuple of directory paths searched for included .thrift files.
            When no directories are supplied the current working directory is
            used. The lib/thrift/src directory four levels above this module
            is always added, and duplicates are dropped keeping the first
            occurrence, so the resulting order is the caller's order.
        """
        object.__init__(self)

        if include_dir_paths is None:
            candidate_dirs = []
        elif isinstance(include_dir_paths, (list, tuple)):
            candidate_dirs = list(include_dir_paths)
        else:
            candidate_dirs = [include_dir_paths]

        if not candidate_dirs:
            candidate_dirs.append(os.getcwd())

        this_module_dir = os.path.dirname(os.path.realpath(__file__))
        thrift_lib_src_dir = os.path.abspath(
            os.path.join(this_module_dir, '..', '..', '..', '..',
                         'lib', 'thrift', 'src'))
        # Appending unconditionally is safe: the de-duplication below keeps
        # only the first occurrence of each directory.
        candidate_dirs.append(thrift_lib_src_dir)

        self.__include_dir_paths = tuple(
            candidate_dir
            for position, candidate_dir in enumerate(candidate_dirs)
            if candidate_dir not in candidate_dirs[:position]
        )

        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()
コード例 #4
0
 def _runTest(self, thrift_file_path):
     """Check that scanning a .thrift file is lossless.

     The file must yield at least one token, and concatenating every token's
     text must reproduce the file contents exactly.

     :param thrift_file_path: path of the .thrift file under test
     """
     tokens = Scanner().tokenize(thrift_file_path)
     self.assertNotEqual(0, len(tokens))
     actual_text = ''.join(token.text for token in tokens)
     # NOTE(review): 'rb' yields bytes under Python 3 while token.text is
     # presumably str; this comparison relies on Python 2 str semantics.
     with open(thrift_file_path, 'rb') as thrift_file:
         expected_text = thrift_file.read()
     self.assertEqual(expected_text, actual_text)
コード例 #5
0
 def _runTest(self, thrift_file_path):
     """Check that a .thrift file scans and parses without raising.

     On failure, the file path, the traceback, and every scanned token
     (index, type, text length, text) are written to stderr before the
     original exception is re-raised, to help diagnose parser errors.

     :param thrift_file_path: path of the .thrift file under test
     """
     tokens = []
     try:
         tokens = Scanner().tokenize(thrift_file_path)
         Parser().parse(tokens)
     except Exception:
         # sys.stderr.write instead of `print >>` so the diagnostics don't
         # themselves raise (and mask the real error) under Python 3.
         sys.stderr.write('Error parsing %s\n' % (thrift_file_path,))
         traceback.print_exc()
         for token in tokens:
             sys.stderr.write('%s %s : %d : %s\n' % (
                 token.index, token.type, len(token.text), token.text))
         sys.stderr.write('\n')
         raise
コード例 #6
0
ファイル: compiler.py プロジェクト: minorg/thryft
class Compiler(object):
    """Compile .thrift files into generator-specific document constructs.

    Each file is tokenized and parsed once; the AST is cached by absolute
    path. A private AST visitor then asks a generator to build the matching
    constructs. The resulting document is cached per (path, generator), so
    repeated compiles, including those triggered by `include` statements,
    reuse earlier work.
    """

    class __AstVisitor(object):
        """Walk a parsed Thrift AST and build constructs through a generator.

        One visitor is created per Compiler.compile call. It keeps a stack of
        enclosing constructs (document, compound type, function, service) so
        each new construct gets the right parent. It also keeps a cache of
        types keyed by Thrift qualified name.
        """

        def __init__(self, compiler, document_root_dir_path, generator, include_dir_paths):
            object.__init__(self)
            self.__compiler = compiler
            self.__document_root_dir_path = document_root_dir_path
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Innermost enclosing construct is last; when empty, the generator
            # itself is used as the parent.
            self.__scope_stack = []
            # Thrift qualified name -> type (or definition found in an include).
            self.__type_by_thrift_qname_cache = {}
            # Include abspaths whose definitions were referenced (dict as set).
            self.__used_include_abspaths = {}
            # Include constructs visited so far, in visit order.
            self.__visited_includes = []

        def __construct(self, class_name, annotation_nodes=None, **kwds):
            """Instantiate the generator's `class_name` construct.

            The parent is the innermost scope on the stack, or the generator
            at top level. Annotation AST nodes are converted to instances of
            the construct class's `<UpperCamelName>Annotation` class when it
            exists, otherwise to its generic `Annotation` class. A `native`
            annotation is not passed to the construct. Instead it replaces
            the construct with the same-named class from a Python module that
            sits next to the parent document (`<document>.py`), when such a
            module exists.

            :param class_name: attribute name of the construct class on the generator
            :param annotation_nodes: annotation AST nodes, or None
            :param kwds: keyword arguments forwarded to the construct class
            :return: the new construct, or its native replacement
            """
            if len(self.__scope_stack) > 0:
                parent = self.__scope_stack[-1]
            else:
                parent = self.__generator
            kwds['parent'] = parent

            construct_class = getattr(self.__generator, class_name)

            native = False
            if annotation_nodes:
                annotations = []
                for annotation_node in annotation_nodes:
                    if annotation_node.name == 'native':
                        # Marker only; handled after construction below.
                        native = True
                        continue

                    # Prefer a specialized annotation class, fall back to the
                    # generic one.
                    try:
                        annotation_class_name = upper_camelize(
                            annotation_node.name) + 'Annotation'
                        annotation_class = getattr(
                            construct_class, annotation_class_name)
                    except AttributeError:
                        annotation_class = getattr(
                            construct_class, 'Annotation')
                    annotations.append(
                        annotation_class(
                            name=annotation_node.name,
                            value=annotation_node.value
                        )
                    )
                kwds['annotations'] = tuple(annotations)

            construct = construct_class(**kwds)

            if not native:
                return construct

            parent_document = construct._parent_document()
            native_module_file_path = os.path.splitext(
                parent_document.path)[0] + '.py'
            if not os.path.isfile(native_module_file_path):
                # No native module next to the document: keep the generated
                # construct. Returning None (the previous behavior) made
                # callers store None as a definition/field and crash later.
                return construct
            native_module_dir_path, native_module_file_name = \
                os.path.split(native_module_file_path)
            native_module_name = \
                os.path.splitext(native_module_file_name)[0]
            # NOTE(review): imp.load_module registers the module in
            # sys.modules under its bare name and re-executes it on each call.
            native_module = \
                imp.load_module(
                    native_module_name,
                    *imp.find_module(
                        native_module_name,
                        [native_module_dir_path]
                    )
                )

            native_construct_class = getattr(native_module, construct.name)

            construct = native_construct_class(
                overridden_construct=construct, **kwds)

            return construct

        def __get_type(self, type_thrift_qname, resolve=True):
            """Look up a type by Thrift qualified name (e.g. struct_include_file.Struct).

            Names registered in this visitor are served from the cache. On a
            miss, and only when `resolve` is true, the definitions of every
            visited include are searched. A matching type or typedef is
            cached, and its include is marked as used.

            :raises KeyError: if the name is not found
            """
            try:
                return self.__type_by_thrift_qname_cache[type_thrift_qname]
            except KeyError:
                if not resolve:
                    raise
                for include in self.__visited_includes:
                    for definition in include.document.definitions:
                        if isinstance(definition, (_Type, Typedef)):
                            definition_thrift_qname = definition.thrift_qname()
                            if definition_thrift_qname == type_thrift_qname:
                                self.__type_by_thrift_qname_cache[definition_thrift_qname] = definition
                                self.__used_include_abspaths[include.abspath] = True
                                return definition
                raise KeyError(type_thrift_qname)

        def __put_type(self, type_thrift_qname, type_):
            """Register a type under `type_thrift_qname`.

            Typedefs are registered as the type they alias.

            :raises CompileException: if the name is already registered
            """
            if type_thrift_qname in self.__type_by_thrift_qname_cache:
                raise CompileException("duplicate type %s" % type_thrift_qname)
            if isinstance(type_, Typedef):
                type_ = type_.type
            self.__type_by_thrift_qname_cache[type_thrift_qname] = type_

        def visit_base_type_node(self, base_type_node):
            """Return the cached base type, creating `<Name>Type` on first use."""
            try:
                return self.__get_type(base_type_node.name, resolve=False)
            except KeyError:
                base_type = getattr(self.__generator, base_type_node.name.capitalize(
                ) + 'Type')(name=base_type_node.name)
                self.__put_type(base_type_node.name, base_type)
                return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an enum, exception or struct type and validate its members.

            Enums: enumerator names must be unique. Enumerators without an
            explicit value are numbered by position. A valueless enumerator
            after a valued one is rejected.

            Other compound types: field names must not collide by exact,
            lower-case or lower-camel-case spelling. Required fields must not
            follow optional ones. Field ids must be unique, and either all
            fields or no fields carry ids.

            :raises CompileException: on any of the violations above
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    # annotation_nodes (not annotations) so annotations are
                    # converted like every other construct's and `native`
                    # overrides work for compound types.
                    annotation_nodes=compound_type_node.annotations,
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type_map here to allow recursive
            # definitions
            self.__put_type(compound_type.thrift_qname(), compound_type)

            if construct_class_name == 'EnumType':
                enum_type_node = compound_type_node
                have_enumerator_with_value = False
                enumerator_node_names = []
                for enumerator_i, enumerator_node in enumerate(enum_type_node.enumerators):
                    if enumerator_node.name in enumerator_node_names:
                        raise CompileException("%s has a duplicate enumerator name, %s" % (
                            enum_type_node.name, enumerator_node.name), ast_node=enumerator_node)
                    enumerator_node_names.append(enumerator_node.name)

                    if enumerator_node.value is not None:
                        have_enumerator_with_value = True
                        assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(
                            enumerator_node.value)
                        value = enumerator_node.value.value
                    else:
                        # NOTE(review): only a valueless enumerator *after* a
                        # valued one is caught; the reverse order is accepted.
                        if have_enumerator_with_value:
                            raise CompileException(
                                "%s has mix of enumerators with and without values, must be one or the other" % enum_type_node.name, ast_node=enum_type_node)
                        value = enumerator_i

                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotation_nodes=enumerator_node.annotations,
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value
                        )
                    )
            else:
                field_name_variations = []
                id_count = 0
                for field_node in compound_type_node.fields:
                    field_name = field_node.name
                    if field_name in field_name_variations:
                        raise CompileException("compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name), ast_node=field_node)

                    field_name_lower = field_name.lower()
                    if field_name_lower in field_name_variations:
                        raise CompileException("compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name), ast_node=field_node)

                    field_name_lower_camelized = lower_camelize(field_name)
                    if field_name_lower_camelized in field_name_variations:
                        raise CompileException("compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name), ast_node=field_node)

                    field_name_variations.append(field_name)
                    field_name_variations.append(field_name_lower)
                    field_name_variations.append(field_name_lower_camelized)

                    field = field_node.accept(self)
                    if field.required:
                        if len(compound_type.fields) > 0:
                            if not compound_type.fields[-1].required:
                                raise CompileException("compound type %s has a required field %s after an optional field %s" % (
                                    compound_type_node.name, field.name, compound_type.fields[-1].name), ast_node=compound_type_node)
                    if field.id is not None:
                        id_count += 1
                        for existing_field in compound_type.fields:
                            if existing_field.id == field.id:
                                raise CompileException("compound type %s has duplicate field id %d (%s and %s fields)" % (
                                    compound_type_node.name, field.id, field.name, existing_field.name), ast_node=compound_type_node)
                    compound_type.fields.append(field)
                if len(compound_type.fields) > 0:
                    if id_count != 0 and id_count != len(compound_type_node.fields):
                        raise CompileException("compound type %s has some fields with ids and some fields without" %
                                               compound_type_node.name, ast_node=compound_type_node)

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotation_nodes=const_node.annotations,
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc comment text, or None when there is no doc node."""
            if doc_node is not None:
                return doc_node.text
            else:
                return None

        def visit_document_node(self, document_node):
            """Build a Document from its headers and definitions.

            After the walk, each visited include is flagged `used` if any of
            its definitions were resolved through __get_type.
            """
            document = \
                self.__construct(
                    'Document',
                    document_root_dir_path=self.__document_root_dir_path,
                    path=document_node.path
                )
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)

            for include in self.__visited_includes:
                include.used = include.abspath in self.__used_include_abspaths

            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType', exception_type_node)

        def visit_field_node(self, field_node):
            if field_node.value is not None:
                value = field_node.value.accept(self)
            else:
                value = None
            return \
                self.__construct(
                    'Field',
                    annotation_nodes=field_node.annotations,
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function with its parameters, return field and throws.

            :raises CompileException: if a required parameter follows an
                optional one
            """
            function = \
                self.__construct(
                    'Function',
                    annotation_nodes=function_node.annotations,
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                parameter = parameter_node.accept(self)
                if parameter.required:
                    if len(function.parameters) > 0:
                        if not function.parameters[-1].required:
                            # ast_node added for source-location reporting,
                            # like the compound-type checks.
                            raise CompileException("function %s has a required parameter %s after an optional parameter %s" % (
                                function.name, parameter.name, function.parameters[-1].name), ast_node=parameter_node)
                function.parameters.append(parameter)
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Resolve, compile and record an included .thrift file.

            Search order: the configured include directories, then the
            directory of the innermost enclosing Document. The first existing
            match is compiled with the same generator.

            :raises CompileException: if no directory contains the include
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(
                    include_dir_path, include_file_relpath)
                if os.path.exists(include_file_path):
                    include_file_path = os.path.abspath(include_file_path)
                    included_document = \
                        self.__compiler.compile(
                            generator=self.__generator,
                            thrift_file_path=include_file_path
                        )
                    include = \
                        self.__construct(
                            'Include',
                            abspath=include_file_path,
                            annotation_nodes=include_node.annotations,
                            doc=self.__visit_doc_node(include_node.doc),
                            document=included_document,
                            relpath=include_file_relpath
                        )
                    self.__visited_includes.append(include)
                    return include
            raise CompileException("include path not found: %s" %
                                   include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(element_value.accept(self) for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            return tuple((key_value.accept(self), value_value.accept(self)) for key_value, value_value in map_literal_node.value)

        def visit_map_type_node(self, map_type_node):
            """Return the cached MapType for this node's name, creating it on first use."""
            try:
                return self.__get_type(map_type_node.name, resolve=False)
            except KeyError:
                map_type = self.__construct('MapType', key_type=map_type_node.key_type.accept(
                    self), value_type=map_type_node.value_type.accept(self))
                self.__put_type(map_type_node.name, map_type)
                return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotation_nodes=namespace_node.annotations,
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the cached list/set type for this node's name, creating it on first use."""
            try:
                return self.__get_type(sequence_type_node.name, resolve=False)
            except KeyError:
                sequence_type = self.__construct(
                    construct_class_name, element_type=sequence_type_node.element_type.accept(self))
                self.__put_type(sequence_type_node.name, sequence_type)
                return sequence_type

        def visit_service_node(self, service_node):
            """Build a Service; function names must be unique case-insensitively.

            :raises CompileException: on a duplicate function name
            """
            service = \
                self.__construct(
                    'Service',
                    annotation_nodes=service_node.annotations,
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            function_names_lower = []
            for function_node in service_node.functions:
                function = function_node.accept(self)

                function_name_lower = function.name.lower()
                if function_name_lower in function_names_lower:
                    raise CompileException(
                        "duplicate (case-insensitive) function name '%s'" % function.name, ast_node=function_node)
                function_names_lower.append(function_name_lower)

                service.functions.append(function)
            self.__scope_stack.pop(-1)

            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a named type reference.

            The qualified name is tried first. An unqualified name is then
            retried prefixed with the root document's name.

            :raises CompileException: if the type cannot be resolved
            """
            try:
                try:
                    return self.__get_type(type_node.qname)
                except KeyError:
                    if type_node.qname == type_node.name:
                        document = self.__scope_stack[0]
                        return self.__get_type(document.name + '.' + type_node.qname)
                    else:
                        raise
            except KeyError:
                raise CompileException(
                    "unrecognized type '%s'" % type_node.qname, ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            """Build a Typedef and register it (as its aliased type) by qualified name."""
            typedef = \
                self.__construct(
                    'Typedef',
                    annotation_nodes=typedef_node.annotations,
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )

            self.__put_type(typedef.thrift_qname(), typedef)

            return typedef

    def __init__(self, document_root_dir_path, include_dir_paths):
        """Store the document root and build the ordered include search path.

        :param document_root_dir_path: root directory passed to every Document
        :param include_dir_paths: None, a single directory path, or a
            list/tuple of directory paths searched for included files. None
            means "no extra directories". The lib/thrift/src directory four
            levels above this module is always appended, and duplicates are
            dropped keeping first occurrence.
        """
        object.__init__(self)

        self.__document_root_dir_path = document_root_dir_path

        # Previously None became [None], which crashed in os.path.join when
        # resolving includes.
        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)
        # De-duplicate while preserving the caller's search order.
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        # abspath -> [(None, ast), (generator, document), ...]
        self.__document_cache = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, *args, **kwds):
        return self.compile(*args, **kwds)

    def compile(self, generator, thrift_file_path):
        """Compile `thrift_file_path` for `generator` and return its document.

        The parsed AST is cached per absolute path. The generated document is
        cached per (path, generator), with generators compared by ==.

        :raises ValueError: if generator is None
        """
        if generator is None:
            raise ValueError('generator must not be None')

        thrift_file_path = os.path.abspath(thrift_file_path)

        path_document_cache = self.__document_cache.get(thrift_file_path)
        if path_document_cache is not None:
            # Bound up front so the assert fails cleanly instead of raising
            # NameError if the placeholder entry is ever missing.
            document_node = None
            for other_generator, generator_document in path_document_cache:
                if other_generator is None:
                    # Special placeholder for the AST
                    document_node = generator_document
                elif generator == other_generator:
                    return generator_document
            assert document_node is not None
        else:
            tokens = self.__scanner.tokenize(thrift_file_path)
            document_node = self.__parser.parse(
                thrift_file_path=thrift_file_path, tokens=tokens)
            # Can't hash the generator easily so just use its __eq__
            self.__document_cache[thrift_file_path] = path_document_cache = [
                (None, document_node)]

        ast_visitor = \
            self.__AstVisitor(
                compiler=self,
                document_root_dir_path=self.__document_root_dir_path,
                generator=generator,
                include_dir_paths=self.__include_dir_paths
            )
        generator_document = document_node.accept(ast_visitor)
        path_document_cache.append((generator, generator_document))
        return generator_document

    @property
    def document_root_dir_path(self):
        return self.__document_root_dir_path

    @property
    def include_dir_paths(self):
        return self.__include_dir_paths
コード例 #7
0
class Compiler(object):
    """Parse .thrift files and optionally turn them into generator documents.

    Parsed ASTs are cached per absolute file path for the compiler's
    lifetime. When compile() is given a generator, each AST is walked by a
    fresh __AstVisitor that builds the generator's constructs.
    """

    class __AstVisitor(object):
        """Translate a parsed Thrift AST into a generator's constructs.

        AST nodes call back into the visit_* methods through accept(); each
        method returns the object built for its node.
        """

        # The plain Generator's NativeType is where the base-class walk in
        # __native_type_candidate_bases stops.
        __ROOT_GENERATOR = Generator()

        def __init__(self, compiler, generator, include_dir_paths):
            object.__init__(self)
            # Owning Compiler; used to compile included files.
            self.__compiler = compiler
            # Its attributes (Document, Field, StructType, ...) build constructs.
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Constructs being filled in (Document, compound types, Function,
            # Service); the innermost is passed as 'parent' to new constructs.
            self.__scope_stack = []
            # Built types keyed by base/container type name or thrift_qname().
            self.__type_cache = {}

        def __construct(self, class_name, **kwds):
            """Build the generator construct named class_name from kwds.

            kwds['parent'] is set to the innermost scope on the scope stack,
            or to the generator itself when the stack is empty. A named
            construct whose annotations include 'native' is instead built
            from an override class found next to an enclosing .thrift
            document (see __find_native_override_class), when there is one.

            Returns:
                The override instance, or getattr(generator, class_name)(**kwds).
            """
            if self.__scope_stack:
                kwds['parent'] = self.__scope_stack[-1]
            else:
                kwds['parent'] = self.__generator

            annotations = kwds.get('annotations')
            is_native = (annotations is not None and
                         kwds.get('name') is not None and
                         'native' in annotations)
            if is_native:
                override_class = self.__find_native_override_class()
                if override_class is not None:
                    return override_class(**kwds)

            return getattr(self.__generator, class_name)(**kwds)

        def __find_native_override_class(self):
            """Return a class overriding the generator's NativeType, or None.

            Enclosing Documents are searched innermost first. For each one,
            the .py module with the same base name as the .thrift file is
            imported and scanned for a class that subclasses, but is not, one
            of the candidate bases (tried in order).
            """
            candidate_bases = None
            for scope in reversed(self.__scope_stack):
                if not isinstance(scope, Document):
                    continue
                overrides_module = self.__load_overrides_module(scope.path)
                if overrides_module is None:
                    continue
                if candidate_bases is None:
                    # Computed lazily: only needed once a module has loaded.
                    candidate_bases = self.__native_type_candidate_bases()
                for candidate_base in candidate_bases:
                    for attr in dir(overrides_module):
                        value = getattr(overrides_module, attr)
                        if isclass(value) and \
                           issubclass(value, candidate_base) and \
                           value != candidate_base:
                            return value
            return None

        @staticmethod
        def __load_overrides_module(thrift_file_path):
            """Import the .py module sharing thrift_file_path's base name.

            Returns None if no such file exists or if importing it raises
            ImportError (which is logged with its traceback).
            """
            module_file_path = os.path.splitext(thrift_file_path)[0] + '.py'
            if not os.path.isfile(module_file_path):
                return None
            module_dir_path, module_file_name = os.path.split(module_file_path)
            module_name = os.path.splitext(module_file_name)[0]
            try:
                module_info = imp.find_module(module_name, [module_dir_path])
            except ImportError:
                logging.error("error importing overrides module %s",
                              module_file_path,
                              exc_info=True)
                return None
            try:
                return imp.load_module(module_name, *module_info)
            except ImportError:
                logging.error("error importing overrides module %s",
                              module_file_path,
                              exc_info=True)
                return None
            finally:
                # imp.find_module returns an open file (None for packages);
                # imp.load_module leaves closing it to the caller.
                module_file = module_info[0]
                if module_file is not None:
                    module_file.close()

        def __native_type_candidate_bases(self):
            """Return the generator's NativeType plus its public first bases.

            The chain of leftmost base classes is followed upward and stops
            at object or at the root Generator's NativeType. Bases whose
            names start with '_' are stepped through but not collected. This
            lets a generator whose NativeType subclasses another generator's
            NativeType reuse override classes written against the parent.
            """
            root_native_type = self.__ROOT_GENERATOR.NativeType
            native_type = self.__generator.NativeType
            candidates = [native_type]
            class_ = native_type
            while class_.__bases__:
                first_base = class_.__bases__[0]
                if first_base is object or first_base is root_native_type:
                    break
                if not first_base.__name__.startswith('_'):
                    candidates.append(first_base)
                class_ = first_base
            return candidates

        def __cached_type(self, cache_key, build):
            """Return the type cached under cache_key, building it on a miss.

            Only a cache miss (KeyError) triggers build(); any other error is
            left to propagate.
            """
            try:
                return self.__type_cache[cache_key]
            except KeyError:
                pass
            built_type = build()
            self.__type_cache[cache_key] = built_type
            return built_type

        def visit_annotation_node(self, annotation_node):
            """Return the annotation's value."""
            return annotation_node.value

        def __visit_annotation_nodes(self, annotation_nodes):
            """Return {annotation name: visited value}; {} for None."""
            if annotation_nodes is None:
                return {}
            return dict((annotation_node.name, annotation_node.accept(self))
                        for annotation_node in annotation_nodes)

        def visit_base_type_node(self, base_type_node):
            """Return the generator's type for a base type such as 'i32'.

            Built as getattr(generator, name.capitalize() + 'Type')(name=name)
            on first use and cached by name.
            """
            name = base_type_node.name
            return self.__cached_type(
                name,
                lambda: getattr(self.__generator,
                                name.capitalize() + 'Type')(name=name))

        def visit_bool_literal_node(self, bool_literal_node):
            """Return the literal's value."""
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name,
                                       compound_type_node):
            """Build an enum, exception or struct type and its members.

            A 'native' override is returned as-is without visiting members.
            Otherwise the type is cached under its thrift_qname() before its
            members are visited, so members may refer to the type itself.

            Raises:
                CompileException: for an enum where some enumerators have
                    explicit values and others do not.
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=self.__visit_annotation_nodes(compound_type_node.annotations),
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            if isinstance(compound_type, NativeType):
                return compound_type
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type cache here to allow
            # recursive definitions
            self.__type_cache[compound_type.thrift_qname()] = compound_type

            if construct_class_name == 'EnumType':
                enumerator_nodes = list(compound_type_node.enumerators)
                has_value = [enumerator_node.value is not None
                             for enumerator_node in enumerator_nodes]
                # Reject a mix in either order, e.g. {A, B = 5} as well as
                # {A = 5, B}.
                if any(has_value) and not all(has_value):
                    raise CompileException(
                        "%s has mix of enumerators with and without values, must be one or the other"
                        % compound_type_node.name,
                        ast_node=compound_type_node)
                for enumerator_i, enumerator_node in enumerate(enumerator_nodes):
                    if enumerator_node.value is None:
                        # Without explicit values, the position is the value.
                        value = enumerator_i
                    else:
                        assert isinstance(enumerator_node.value,
                                          Ast.IntLiteralNode), type(
                                              enumerator_node.value)
                        value = enumerator_node.value.value
                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotations=self.__visit_annotation_nodes(
                                enumerator_node.annotations),
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value))
            else:
                for field in compound_type_node.fields:
                    compound_type.fields.append(field.accept(self))

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            """Build a Const, visiting its type and value."""
            return \
                self.__construct(
                    'Const',
                    annotations=self.__visit_annotation_nodes(const_node.annotations),
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc text, or None when there is no doc node."""
            if doc_node is not None:
                return doc_node.text
            return None

        def visit_document_node(self, document_node):
            """Build a Document and fill in its headers and definitions."""
            document = self.__construct('Document', path=document_node.path)
            # Children see the document as their parent scope.
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)

            return document

        def visit_enum_type_node(self, enum_node):
            """Build an EnumType."""
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            """Build an ExceptionType."""
            return self.__visit_compound_type_node('ExceptionType',
                                                   exception_type_node)

        def visit_field_node(self, field_node):
            """Build a Field, visiting its type."""
            return \
                self.__construct(
                    'Field',
                    annotations=self.__visit_annotation_nodes(field_node.annotations),
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=field_node.value
                )

        def visit_float_literal_node(self, float_literal_node):
            """Return the literal's value."""
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function with its parameters, return field and throws."""
            function = \
                self.__construct(
                    'Function',
                    annotations=self.__visit_annotation_nodes(function_node.annotations),
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                function.parameters.append(parameter_node.accept(self))
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Compile an included file and register its types.

            The include directories are tried first, then the directory of
            the innermost enclosing document; the first existing path is
            compiled with this visitor's generator.

            Raises:
                CompileException: if the file exists in none of them.
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(include_dir_path,
                                                 include_file_relpath)
                if not os.path.exists(include_file_path):
                    continue
                include_file_path = os.path.abspath(include_file_path)
                included_document = self.__compiler.compile(
                    (include_file_path, ), generator=self.__generator)[0]
                include = \
                    self.__construct(
                        'Include',
                        annotations=self.__visit_annotation_nodes(include_node.annotations),
                        doc=self.__visit_doc_node(include_node.doc),
                        document=included_document,
                        path=include_file_relpath
                    )
                for definition in included_document.definitions:
                    # NOTE(review): visit_typedef_node caches typedef.type,
                    # but an included Typedef is cached as the Typedef itself
                    # (the original carried a "# .type" remark) — confirm.
                    if isinstance(definition, (_Type, Typedef)):
                        self.__type_cache[definition.thrift_qname()] = definition
                return include
            raise CompileException("include path not found: %s" %
                                   include_file_relpath,
                                   ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            """Return the literal's value."""
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            """Return the visited element values as a tuple."""
            return tuple(
                element_value.accept(self)
                for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            """Build (or reuse) a ListType."""
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            """Return a dict of visited keys to visited values."""
            return dict((key_value.accept(self), value_value.accept(self))
                        for key_value, value_value in
                        map_literal_node.value.iteritems())

        def visit_map_type_node(self, map_type_node):
            """Build (or reuse, keyed by node name) a MapType."""
            return self.__cached_type(
                map_type_node.name,
                lambda: self.__construct(
                    'MapType',
                    key_type=map_type_node.key_type.accept(self),
                    value_type=map_type_node.value_type.accept(self)))

        def visit_namespace_node(self, namespace_node):
            """Build a Namespace."""
            return \
                self.__construct(
                    'Namespace',
                    annotations=self.__visit_annotation_nodes(namespace_node.annotations),
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name,
                                       sequence_type_node):
            """Build (or reuse, keyed by node name) a list or set type."""
            return self.__cached_type(
                sequence_type_node.name,
                lambda: self.__construct(
                    construct_class_name,
                    element_type=sequence_type_node.element_type.accept(self)))

        def visit_service_node(self, service_node):
            """Build a Service and its functions."""
            service = \
                self.__construct(
                    'Service',
                    annotations=self.__visit_annotation_nodes(service_node.annotations),
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            for function_node in service_node.functions:
                service.functions.append(function_node.accept(self))

            self.__scope_stack.pop(-1)

            return service

        def visit_set_type_node(self, set_type_node):
            """Build (or reuse) a SetType."""
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            """Return the literal's value."""
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            """Build a StructType."""
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a reference to a named type.

            Looks up the qname; an unqualified name is also tried prefixed
            with the name of the bottom-most scope (the document).

            Raises:
                CompileException: if the type is not in the type cache.
            """
            qname = type_node.qname
            try:
                return self.__type_cache[qname]
            except KeyError:
                pass
            if qname == type_node.name:
                document = self.__scope_stack[0]
                try:
                    return self.__type_cache[document.name + '.' + qname]
                except KeyError:
                    pass
            raise CompileException("unrecognized type '%s'" %
                                   type_node.qname,
                                   ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            """Build a Typedef and cache its aliased type under its qname."""
            typedef = \
                self.__construct(
                    'Typedef',
                    annotations=self.__visit_annotation_nodes(typedef_node.annotations),
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )

            self.__type_cache[typedef.thrift_qname()] = typedef.type

            return typedef

    def __init__(self, include_dir_paths=None):
        """
        Args:
            include_dir_paths: a directory path, or a list/tuple of them,
                searched for included files. None or empty means the current
                working directory. The directory
                ../../../../lib/thrift/src (relative to this module's
                directory) is always added; duplicates are dropped, keeping
                first occurrences.
        """
        object.__init__(self)

        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        if not include_dir_paths:
            include_dir_paths.append(os.getcwd())
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        include_dir_paths.append(
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            )))
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        # Parsed AST document nodes keyed by absolute .thrift path.
        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, thrift_file_paths, generator=None):
        """Shorthand for compile(thrift_file_paths, generator=generator)."""
        return self.compile(thrift_file_paths, generator=generator)

    def compile(self, thrift_file_paths, generator=None):
        """Parse, and optionally generate, one or more .thrift files.

        Args:
            thrift_file_paths: a path, or a list/tuple of paths.
            generator: when None, the parsed AST document nodes are returned;
                otherwise each AST is visited to build generator documents.

        Returns:
            A tuple with one result per input path, in input order.
        """
        if not isinstance(thrift_file_paths, (list, tuple)):
            thrift_file_paths = (thrift_file_paths, )

        documents = []
        for thrift_file_path in thrift_file_paths:
            thrift_file_path = os.path.abspath(thrift_file_path)
            document_node = self.__parsed_thrift_files_by_path.get(
                thrift_file_path)
            if document_node is None:
                tokens = self.__scanner.tokenize(thrift_file_path)
                document_node = self.__parser.parse(tokens)
                self.__parsed_thrift_files_by_path[
                    thrift_file_path] = document_node
            if generator is None:
                documents.append(document_node)
                continue
            # A fresh visitor per document: its scope stack and type cache
            # start empty each time.
            ast_visitor = \
                self.__AstVisitor(
                    compiler=self,
                    generator=generator,
                    include_dir_paths=self.__include_dir_paths
                )
            documents.append(document_node.accept(ast_visitor))
        return tuple(documents)

    @property
    def include_dir_paths(self):
        """Tuple of deduplicated include directory paths."""
        return self.__include_dir_paths
コード例 #8
0
ファイル: compiler.py プロジェクト: financeCoding/thryft
class Compiler(object):
    class __AstVisitor(object):
        __ROOT_GENERATOR = Generator()

        def __init__(self, compiler, generator, include_dir_paths):
            """Store the collaborators used while walking one AST.

            Args:
                compiler: the owning Compiler; used to compile included files.
                generator: object whose attributes build each construct.
                include_dir_paths: directories searched for included files.
            """
            object.__init__(self)
            self.__compiler, self.__generator, self.__include_dir_paths = \
                compiler, generator, include_dir_paths
            # Enclosing constructs, innermost last; it parents new constructs.
            self.__scope_stack = []
            # Types already built, keyed by type name or qualified name.
            self.__type_cache = {}

        def __construct(self, class_name, **kwds):
            """Build the generator construct named class_name from kwds.

            kwds['parent'] is set to the innermost scope on the scope stack,
            or to the generator itself when the stack is empty. A named
            construct whose annotations include 'native' is instead built
            from an override class found next to an enclosing .thrift
            document (see __find_native_override_class), when there is one.

            Returns:
                The override instance, or getattr(generator, class_name)(**kwds).
            """
            if self.__scope_stack:
                kwds['parent'] = self.__scope_stack[-1]
            else:
                kwds['parent'] = self.__generator

            annotations = kwds.get('annotations')
            is_native = (annotations is not None and
                         kwds.get('name') is not None and
                         'native' in annotations)
            if is_native:
                override_class = self.__find_native_override_class()
                if override_class is not None:
                    return override_class(**kwds)

            return getattr(self.__generator, class_name)(**kwds)

        def __find_native_override_class(self):
            """Return a class overriding the generator's NativeType, or None.

            Enclosing Documents are searched innermost first. For each one,
            the .py module with the same base name as the .thrift file is
            imported and scanned for a class that subclasses, but is not, one
            of the candidate bases (tried in order).
            """
            candidate_bases = None
            for scope in reversed(self.__scope_stack):
                if not isinstance(scope, Document):
                    continue
                overrides_module = self.__load_overrides_module(scope.path)
                if overrides_module is None:
                    continue
                if candidate_bases is None:
                    # Computed lazily: only needed once a module has loaded.
                    candidate_bases = self.__native_type_candidate_bases()
                for candidate_base in candidate_bases:
                    for attr in dir(overrides_module):
                        value = getattr(overrides_module, attr)
                        if isclass(value) and \
                           issubclass(value, candidate_base) and \
                           value != candidate_base:
                            return value
            return None

        @staticmethod
        def __load_overrides_module(thrift_file_path):
            """Import the .py module sharing thrift_file_path's base name.

            Returns None if no such file exists or if importing it raises
            ImportError (which is logged with its traceback).
            """
            module_file_path = os.path.splitext(thrift_file_path)[0] + '.py'
            if not os.path.isfile(module_file_path):
                return None
            module_dir_path, module_file_name = os.path.split(module_file_path)
            module_name = os.path.splitext(module_file_name)[0]
            try:
                module_info = imp.find_module(module_name, [module_dir_path])
            except ImportError:
                logging.error(
                    "error importing overrides module %s",
                    module_file_path,
                    exc_info=True
                )
                return None
            try:
                return imp.load_module(module_name, *module_info)
            except ImportError:
                logging.error(
                    "error importing overrides module %s",
                    module_file_path,
                    exc_info=True
                )
                return None
            finally:
                # imp.find_module returns an open file (None for packages);
                # imp.load_module leaves closing it to the caller.
                module_file = module_info[0]
                if module_file is not None:
                    module_file.close()

        def __native_type_candidate_bases(self):
            """Return the generator's NativeType plus its public first bases.

            The chain of leftmost base classes is followed upward and stops
            at object or at the root Generator's NativeType. Bases whose
            names start with '_' are stepped through but not collected. This
            lets a generator whose NativeType subclasses another generator's
            NativeType reuse override classes written against the parent.
            """
            root_native_type = self.__ROOT_GENERATOR.NativeType
            native_type = self.__generator.NativeType
            candidates = [native_type]
            class_ = native_type
            while class_.__bases__:
                first_base = class_.__bases__[0]
                if first_base is object or first_base is root_native_type:
                    break
                if not first_base.__name__.startswith('_'):
                    candidates.append(first_base)
                class_ = first_base
            return candidates

        def visit_annotation_node(self, annotation_node):
            """Return the value carried by an annotation node."""
            annotation_value = annotation_node.value
            return annotation_value

        def __visit_annotation_nodes(self, annotation_nodes):
            """Map each annotation's name to its visited value.

            None yields an empty dict; a repeated name keeps the last value.
            """
            annotations = {}
            if annotation_nodes is not None:
                for node in annotation_nodes:
                    key = node.name
                    annotations[key] = node.accept(self)
            return annotations

        def visit_base_type_node(self, base_type_node):
            """Return the generator's type for a base type such as 'i32'.

            Built as getattr(generator, name.capitalize() + 'Type')(name=name)
            on first use and cached by name afterwards.
            """
            name = base_type_node.name
            try:
                return self.__type_cache[name]
            except KeyError:
                # Only a cache miss should trigger construction; a bare
                # except would also swallow unrelated errors.
                pass
            base_type = getattr(self.__generator, name.capitalize() + 'Type')(name=name)
            self.__type_cache[name] = base_type
            return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            """Return the value carried by a bool literal node."""
            literal_value = bool_literal_node.value
            return literal_value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an enum, exception or struct type and its members.

            A 'native' override is returned as-is without visiting members.
            Otherwise the type is cached under its thrift_qname() before its
            members are visited, so members may refer to the type itself.

            Raises:
                CompileException: for an enum where some enumerators have
                    explicit values and others do not.
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=self.__visit_annotation_nodes(compound_type_node.annotations),
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            if isinstance(compound_type, NativeType):
                return compound_type
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type cache here to allow
            # recursive definitions
            self.__type_cache[compound_type.thrift_qname()] = compound_type

            if construct_class_name == 'EnumType':
                enumerator_nodes = list(compound_type_node.enumerators)
                has_value = [enumerator_node.value is not None
                             for enumerator_node in enumerator_nodes]
                # Reject a mix in either order, e.g. {A, B = 5} as well as
                # {A = 5, B}.
                if any(has_value) and not all(has_value):
                    raise CompileException("%s has mix of enumerators with and without values, must be one or the other" % compound_type_node.name, ast_node=compound_type_node)
                for enumerator_i, enumerator_node in enumerate(enumerator_nodes):
                    if enumerator_node.value is None:
                        # Without explicit values, the position is the value.
                        value = enumerator_i
                    else:
                        assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(enumerator_node.value)
                        value = enumerator_node.value.value
                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotations=self.__visit_annotation_nodes(enumerator_node.annotations),
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value
                        )
                    )
            else:
                for field in compound_type_node.fields:
                    compound_type.fields.append(field.accept(self))

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            """Build a Const, visiting its declared type and its value."""
            annotations = self.__visit_annotation_nodes(const_node.annotations)
            doc = self.__visit_doc_node(const_node.doc)
            const_type = const_node.type.accept(self)
            const_value = const_node.value.accept(self)
            return self.__construct('Const',
                                    annotations=annotations,
                                    doc=doc,
                                    name=const_node.name,
                                    type=const_type,
                                    value=const_value)

        def __visit_doc_node(self, doc_node):
            """Return the doc comment text, or None if there is no doc node."""
            return None if doc_node is None else doc_node.text

        def visit_document_node(self, document_node):
            """Build a Document and populate its headers and definitions.

            The document is pushed as the current scope while its children
            are visited, so it becomes their parent.
            """
            document = self.__construct('Document', path=document_node.path)
            self.__scope_stack.append(document)

            for child_node in document_node.headers:
                document.headers.append(child_node.accept(self))
            for child_node in document_node.definitions:
                document.definitions.append(child_node.accept(self))

            self.__scope_stack.pop()
            return document

        def visit_enum_type_node(self, enum_node):
            """Build an EnumType via the shared compound-type visitor."""
            return self.__visit_compound_type_node(
                construct_class_name='EnumType',
                compound_type_node=enum_node)

        def visit_exception_type_node(self, exception_type_node):
            """Build an ExceptionType via the shared compound-type visitor."""
            return self.__visit_compound_type_node(
                construct_class_name='ExceptionType',
                compound_type_node=exception_type_node)

        def visit_field_node(self, field_node):
            """Build a Field, visiting its type; the default value is passed through."""
            annotations = self.__visit_annotation_nodes(field_node.annotations)
            doc = self.__visit_doc_node(field_node.doc)
            field_type = field_node.type.accept(self)
            return self.__construct('Field',
                                    annotations=annotations,
                                    doc=doc,
                                    id=field_node.id,
                                    name=field_node.name,
                                    required=field_node.required,
                                    type=field_type,
                                    value=field_node.value)

        def visit_float_literal_node(self, float_literal_node):
            """Return the value carried by a float literal node."""
            literal_value = float_literal_node.value
            return literal_value

        def visit_function_node(self, function_node):
            """Build a Function with its parameters, return field and throws.

            The function is the current scope while its members are visited.
            """
            function = self.__construct(
                'Function',
                annotations=self.__visit_annotation_nodes(function_node.annotations),
                doc=self.__visit_doc_node(function_node.doc),
                name=function_node.name,
                oneway=function_node.oneway)
            self.__scope_stack.append(function)

            for node in function_node.parameters:
                function.parameters.append(node.accept(self))
            return_field_node = function_node.return_field
            if return_field_node is not None:
                function.return_field = return_field_node.accept(self)
            for node in function_node.throws:
                function.throws.append(node.accept(self))

            self.__scope_stack.pop()
            return function

        def visit_include_node(self, include_node):
            """Compile an included .thrift file and register its types.

            Candidate directories are the configured include directories
            followed by the directory of the innermost enclosing document;
            the first existing match is compiled with this visitor's
            generator, and its types and typedefs enter the type cache.

            Raises:
                CompileException: if no candidate path exists.
            """
            search_dir_paths = list(self.__include_dir_paths)
            enclosing_documents = [scope for scope in reversed(self.__scope_stack)
                                   if isinstance(scope, Document)]
            if enclosing_documents:
                search_dir_paths.append(os.path.dirname(enclosing_documents[0].path))

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for search_dir_path in search_dir_paths:
                candidate_path = os.path.join(search_dir_path, include_file_relpath)
                if not os.path.exists(candidate_path):
                    continue
                candidate_path = os.path.abspath(candidate_path)
                included_document = self.__compiler.compile(
                    (candidate_path,), generator=self.__generator)[0]
                include = self.__construct(
                    'Include',
                    annotations=self.__visit_annotation_nodes(include_node.annotations),
                    doc=self.__visit_doc_node(include_node.doc),
                    document=included_document,
                    path=include_file_relpath)
                for definition in included_document.definitions:
                    # Typedefs are cached as the Typedef itself, not its .type.
                    if isinstance(definition, (_Type, Typedef)):
                        self.__type_cache[definition.thrift_qname()] = definition
                return include
            raise CompileException("include path not found: %s" % include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            """Return the value stored on an integer literal AST node, unchanged."""
            literal_value = int_literal_node.value
            return literal_value

        def visit_list_literal_node(self, list_literal_node):
            """Visit each element of a list literal and return the results as a tuple."""
            visited_elements = []
            for element_node in list_literal_node.value:
                visited_elements.append(element_node.accept(self))
            return tuple(visited_elements)

        def visit_list_type_node(self, list_type_node):
            """Return the 'ListType' construct for a list type node, via the sequence-type helper."""
            construct_class_name = 'ListType'
            return self.__visit_sequence_type_node(construct_class_name, list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            """Visit every key and value of a map literal and return them as a dict.

            ``.items()`` is used instead of the Python-2-only ``.iteritems()``
            so this works on both Python 2 and 3. The explicit loop keeps the
            original evaluation order: each key is visited before its value.
            When two keys visit to equal values, the later entry wins.
            """
            result = {}
            for key_node, value_node in map_literal_node.value.items():
                key = key_node.accept(self)
                result[key] = value_node.accept(self)
            return result

        def visit_map_type_node(self, map_type_node):
            """Return the 'MapType' construct for a map type node, memoized by name.

            The construct is stored in the type cache under the node's name, so
            repeated references to the same map type return the same object.
            Only a missing key (KeyError) counts as a cache miss. Any other
            failure while reading the cache propagates instead of being
            silently treated as a miss.
            """
            type_cache = self.__type_cache
            cache_key = map_type_node.name
            try:
                return type_cache[cache_key]
            except KeyError:
                pass
            # The construct is built outside the except clause so that errors
            # raised while building it are not chained to the KeyError.
            map_type = self.__construct(
                'MapType',
                key_type=map_type_node.key_type.accept(self),
                value_type=map_type_node.value_type.accept(self)
            )
            type_cache[cache_key] = map_type
            return map_type

        def visit_namespace_node(self, namespace_node):
            """Build a 'Namespace' construct from a namespace AST node's name and scope."""
            annotations = self.__visit_annotation_nodes(namespace_node.annotations)
            doc = self.__visit_doc_node(namespace_node.doc)
            namespace = self.__construct(
                'Namespace',
                annotations=annotations,
                doc=doc,
                name=namespace_node.name,
                scope=namespace_node.scope
            )
            return namespace

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the sequence type construct for a node, memoized by name.

            :param construct_class_name: name of the construct to build on a
                cache miss ('ListType' or 'SetType' in the callers shown).
            :param sequence_type_node: AST node with ``name`` and ``element_type``.

            Only a missing key (KeyError) counts as a cache miss. Any other
            failure while reading the cache propagates instead of being
            silently treated as a miss.
            """
            type_cache = self.__type_cache
            cache_key = sequence_type_node.name
            try:
                return type_cache[cache_key]
            except KeyError:
                pass
            # The construct is built outside the except clause so that errors
            # raised while building it are not chained to the KeyError.
            sequence_type = self.__construct(
                construct_class_name,
                element_type=sequence_type_node.element_type.accept(self)
            )
            type_cache[cache_key] = sequence_type
            return sequence_type

        def visit_service_node(self, service_node):
            """Build a 'Service' construct and visit its functions inside its scope.

            The service is on top of the scope stack while its functions are
            visited, and it is popped before returning.
            """
            annotations = self.__visit_annotation_nodes(service_node.annotations)
            doc = self.__visit_doc_node(service_node.doc)
            service = self.__construct(
                'Service',
                annotations=annotations,
                doc=doc,
                name=service_node.name
            )

            scope_stack = self.__scope_stack
            scope_stack.append(service)
            for member_function_node in service_node.functions:
                service.functions.append(member_function_node.accept(self))
            scope_stack.pop()

            return service

        def visit_set_type_node(self, set_type_node):
            """Return the 'SetType' construct for a set type node, via the sequence-type helper."""
            construct_class_name = 'SetType'
            return self.__visit_sequence_type_node(construct_class_name, set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            """Return the value stored on a string literal AST node, unchanged."""
            literal_value = string_literal_node.value
            return literal_value

        def visit_struct_type_node(self, struct_type_node):
            """Delegate to __visit_compound_type_node with the 'StructType' construct class."""
            construct_class_name = 'StructType'
            return self.__visit_compound_type_node(construct_class_name, struct_type_node)

        def visit_type_node(self, type_node):
            """Resolve a type reference node to a type previously registered in the cache.

            The node's qname is looked up first. If that misses and the name
            is unqualified (qname == name), the lookup is retried as
            "<document name>.<qname>", using the first entry on the scope stack
            as the document.

            Raises CompileException if the type cannot be resolved.
            """
            def unrecognized():
                return CompileException("unrecognized type '%s'" % type_node.qname, ast_node=type_node)

            type_cache = self.__type_cache
            try:
                return type_cache[type_node.qname]
            except KeyError:
                if type_node.qname != type_node.name:
                    raise unrecognized()
                # NOTE(review): assumes scope_stack[0] is the Document currently
                # being compiled; confirm against visit_document_node.
                document = self.__scope_stack[0]
                try:
                    return type_cache[document.name + '.' + type_node.qname]
                except KeyError:
                    raise unrecognized()

        def visit_typedef_node(self, typedef_node):
            """Build a 'Typedef' construct and register its aliased type in the type cache.

            The cache entry is keyed by the typedef's thrift_qname(), and its
            value is the aliased type (typedef.type), not the Typedef itself.
            """
            annotations = self.__visit_annotation_nodes(typedef_node.annotations)
            doc = self.__visit_doc_node(typedef_node.doc)
            alias_name = typedef_node.name
            aliased_type = typedef_node.type.accept(self)
            typedef = self.__construct(
                'Typedef',
                annotations=annotations,
                doc=doc,
                name=alias_name,
                type=aliased_type
            )

            qualified_name = typedef.thrift_qname()
            self.__type_cache[qualified_name] = typedef.type
            return typedef

    def __init__(self, include_dir_paths=None):
        """Create a compiler with an ordered, de-duplicated include search path.

        :param include_dir_paths: None, a list/tuple of directory paths, or
            any other value, which is treated as a single directory. If no
            directories are given (None or empty), the current working
            directory is used. The ``lib/thrift/src`` directory four levels
            above this module's directory is always added at the end.
            Duplicates are removed, keeping each path's first position.
        """
        object.__init__(self)

        if include_dir_paths is None:
            requested_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            # Copy so the caller's sequence is never mutated.
            requested_paths = list(include_dir_paths)
        else:
            requested_paths = [include_dir_paths]
        if not requested_paths:
            requested_paths.append(os.getcwd())

        this_module_dir_path = os.path.dirname(os.path.realpath(__file__))
        # Appended unconditionally; the de-duplication pass below drops it if
        # the caller already listed it earlier.
        requested_paths.append(
            os.path.abspath(
                os.path.join(this_module_dir_path, '..', '..', '..', '..', 'lib', 'thrift', 'src')
            )
        )

        unique_paths = []
        for candidate_path in requested_paths:
            if candidate_path not in unique_paths:
                unique_paths.append(candidate_path)
        self.__include_dir_paths = tuple(unique_paths)

        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, thrift_file_paths, generator=None):
        """Make the compiler callable; forwards both arguments unchanged to compile()."""
        compile_documents = self.compile
        return compile_documents(thrift_file_paths, generator=generator)

    def compile(self, thrift_file_paths, generator=None):
        """Parse one or more .thrift files and optionally run a generator over them.

        :param thrift_file_paths: a list/tuple of paths. Any other value is
            treated as a single path.
        :param generator: if None, the parsed AST document nodes are returned.
            Otherwise each AST is visited with a new __AstVisitor bound to this
            generator, and the visitor's results are returned.
        :returns: a tuple with one entry per input path, in input order.

        Parsed ASTs are cached by absolute path and reused on later calls.
        """
        if isinstance(thrift_file_paths, (list, tuple)):
            requested_paths = thrift_file_paths
        else:
            requested_paths = (thrift_file_paths,)

        ast_cache = self.__parsed_thrift_files_by_path
        results = []
        for requested_path in requested_paths:
            absolute_path = os.path.abspath(requested_path)
            document_node = ast_cache.get(absolute_path)
            if document_node is None:
                tokens = self.__scanner.tokenize(absolute_path)
                document_node = self.__parser.parse(tokens)
                ast_cache[absolute_path] = document_node

            if generator is None:
                results.append(document_node)
                continue

            visitor = self.__AstVisitor(
                compiler=self,
                generator=generator,
                include_dir_paths=self.__include_dir_paths
            )
            results.append(document_node.accept(visitor))
        return tuple(results)

    @property
    def include_dir_paths(self):
        """The de-duplicated tuple of include directories set up by __init__."""
        search_path = self.__include_dir_paths
        return search_path