Example #1
0
    def __init__(self, include_dir_paths=None):
        """Set up the include search path and the scanner/parser pair.

        include_dir_paths may be None, a single path, or a list/tuple of
        paths; with no paths at all the current working directory is used.
        The bundled lib/thrift/src directory (four directories above the one
        containing this module) is always added at the end, and duplicate
        entries are dropped while keeping each path's first occurrence.
        """
        object.__init__(self)

        if include_dir_paths is None:
            requested_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            requested_paths = list(include_dir_paths)
        else:
            requested_paths = [include_dir_paths]
        if not requested_paths:
            requested_paths.append(os.getcwd())

        this_module_dir = os.path.dirname(os.path.realpath(__file__))
        bundled_thrift_src = os.path.abspath(os.path.join(
            this_module_dir, '..', '..', '..', '..', 'lib', 'thrift', 'src'))
        # Appending unconditionally is safe: the de-duplication below keeps
        # an earlier caller-supplied copy and drops this one.
        requested_paths.append(bundled_thrift_src)

        unique_paths = []
        for requested_path in requested_paths:
            if requested_path not in unique_paths:
                unique_paths.append(requested_path)
        self.__include_dir_paths = tuple(unique_paths)

        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()
Example #2
0
    def __init__(self, document_root_dir_path, include_dir_paths):
        """Record the document root and build the include search path.

        Args:
            document_root_dir_path: stored as-is.
            include_dir_paths: None, a single path, or a list/tuple of paths.
                None is treated as "no explicit include paths".

        The bundled lib/thrift/src directory (four directories above the one
        containing this module) is always appended. Duplicate entries are
        dropped while keeping each path's first occurrence.
        """
        object.__init__(self)

        self.__document_root_dir_path = document_root_dir_path

        if include_dir_paths is None:
            # Without this branch None was wrapped as [None], leaving a None
            # entry in the include path that is not a usable directory.
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)
        # Order-preserving de-duplication (equality-based, so entries need
        # not be hashable).
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        self.__document_cache = {}
        self.__scanner = Scanner()
        self.__parser = Parser()
Example #3
0
    def __init__(self, include_dir_paths=None):
        """Set up the include search path and the scanner/parser pair.

        include_dir_paths may be None, a single path, or a list/tuple of
        paths; with no paths at all the current working directory is used.
        The bundled lib/thrift/src directory (four directories above the one
        containing this module) is always added at the end, and duplicate
        entries are dropped while keeping each path's first occurrence.
        """
        object.__init__(self)

        if include_dir_paths is None:
            requested_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            requested_paths = list(include_dir_paths)
        else:
            requested_paths = [include_dir_paths]
        if not requested_paths:
            requested_paths.append(os.getcwd())

        this_module_dir = os.path.dirname(os.path.realpath(__file__))
        bundled_thrift_src = os.path.abspath(os.path.join(
            this_module_dir, '..', '..', '..', '..', 'lib', 'thrift', 'src'))
        # Appending unconditionally is safe: the de-duplication below keeps
        # an earlier caller-supplied copy and drops this one.
        requested_paths.append(bundled_thrift_src)

        unique_paths = []
        for requested_path in requested_paths:
            if requested_path not in unique_paths:
                unique_paths.append(requested_path)
        self.__include_dir_paths = tuple(unique_paths)

        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()
Example #4
0
 def _runTest(self, thrift_file_path):
     """Scan and parse one Thrift file, dumping diagnostics on failure.

     If scanning or parsing raises, the file path, the traceback and every
     token scanned so far (index, type, text length, text) are written to
     stderr and the original exception is re-raised.
     """
     tokens = []
     try:
         tokens = Scanner().tokenize(thrift_file_path)
         Parser().parse(tokens)
     except Exception:
         # sys.stderr.write works on Python 2 and 3 alike; the Python 2
         # ``print >>`` statement raises TypeError on Python 3 and would
         # hide the real parse error.
         sys.stderr.write('Error parsing %s\n' % (thrift_file_path,))
         traceback.print_exc()
         for token in tokens:
             sys.stderr.write('%s %s : %s : %s\n' % (
                 token.index, token.type, len(token.text), token.text))
         sys.stderr.write('\n')
         raise
            for function in self.functions:
                methods.extend(function.java_definitions())
            return methods

        def java_repr(self):
            delegate_name = self.__java_delegate_name()
            name = self.java_name()

            sections = []
            sections.append(self.__java_markers())
            sections.append("""public final static com.google.inject.name.Named DELEGATE_NAME = com.google.inject.name.Names.named("%(delegate_name)s");""" % locals())
            sections.append("\n\n".join([self._java_constructor()] + self._java_methods()))
            sections.append("\n".join(self._java_member_declarations()))
            sections.append(self.__java_log_messages())
            sections.append(self.__java_log_return_values())
            sections = "\n\n".join(indent(' ' * 4, sections))

            service_qname = java_generator.JavaGenerator.Service.java_qname(self)  # @UndefinedVariable

            return """\
@com.google.inject.Singleton
public class %(name)s implements %(service_qname)s {
%(sections)s
}""" % locals()


# Module-level registrations, executed once when this module is imported:
# the Java logging annotation parsers are added to Parser. The tuple argument
# presumably lists the AST node types each annotation is accepted on
# (java_log_level: exception types, fields, functions; the _post/_pre
# variants: functions only).
Parser.register_annotation_parser(JavaLogExceptionStackTraceAnnotationParser())
Parser.register_annotation_parser(JavaLogLevelAnnotationParser('java_log_level', (Ast.ExceptionTypeNode, Ast.FieldNode, Ast.FunctionNode)))
Parser.register_annotation_parser(JavaLogLevelAnnotationParser('java_log_level_post', (Ast.FunctionNode,)))
Parser.register_annotation_parser(JavaLogLevelAnnotationParser('java_log_level_pre', (Ast.FunctionNode,)))
            for function in self.functions:
                methods.extend(function.java_definitions())
            return methods

        def java_repr(self):
            delegate_name = self.__java_delegate_name()
            name = self.java_name()

            sections = []
            sections.append(self.__java_markers())
            sections.append(
                """public final static com.google.inject.name.Named DELEGATE_NAME = com.google.inject.name.Names.named("%(delegate_name)s");"""
                % locals())
            sections.append("\n\n".join([self._java_constructor()] +
                                        self._java_methods()))
            sections.append("\n".join(self._java_member_declarations()))
            sections = "\n\n".join(indent(' ' * 4, sections))

            service_qname = java_generator.JavaGenerator.Service.java_qname(
                self)  # @UndefinedVariable

            return """\
@com.google.inject.Singleton
public class %(name)s implements %(service_qname)s {
%(sections)s
}""" % locals()


# Module-level registrations, executed once at import time: add the Java
# logging annotation parsers to Parser.
Parser.register_annotation_parser(JavaLogExceptionStackTraceAnnotationParser())
Parser.register_annotation_parser(JavaLogLevelAnnotationParser())
Example #7
0
    from thryft.generators.sql.sql_i64_type import SqlI64Type as I64Type  # @UnusedImport
    from thryft.generators.sql.sql_include import SqlInclude as Include  # @UnusedImport
    from thryft.generators.sql.sql_list_type import SqlListType as ListType  # @UnusedImport
    from thryft.generators.sql.sql_map_type import SqlMapType as MapType  # @UnusedImport
    from thryft.generators.sql.sql_native_type import SqlNativeType as NativeType  # @UnusedImport
    from thryft.generators.sql.sql_service import SqlService as Service  # @UnusedImport
    from thryft.generators.sql.sql_set_type import SqlSetType as SetType  # @UnusedImport
    from thryft.generators.sql.sql_string_type import SqlStringType as StringType  # @UnusedImport
    from thryft.generators.sql.sql_struct_type import SqlStructType as StructType  # @UnusedImport
    from thryft.generators.sql.sql_typedef import SqlTypedef as Typedef  # @UnusedImport


def __parse_sql_foreign_key_annotation(ast_node, name, value, **kwds):
    """Parse a ``table.column`` foreign-key annotation onto *ast_node*.

    Appends an ``Ast.AnnotationNode`` whose value is the
    ``(table_name, column_name)`` tuple to ``ast_node.annotations``; any
    extra keyword arguments are forwarded to the AnnotationNode.

    Raises:
        ValueError: if *value* is missing, is not exactly two dot-separated
            parts, or either part is empty.
    """
    if value is None:
        # Without this check a valueless annotation failed with an
        # AttributeError from value.split instead of a ValueError.
        raise ValueError("@%s requires a table.column value" % (name,))
    value_parts = value.split('.')
    if len(value_parts) != 2 or not all(value_parts):
        raise ValueError("@%s must specify table.column: '%s'" % (name, value))
    table_name, column_name = value_parts

    annotation = Ast.AnnotationNode(name=name, value=(table_name, column_name), **kwds)

    ast_node.annotations.append(annotation)
# Import-time registration: route 'sql_foreign_key' annotations on field
# nodes to the parser function above.
Parser.register_annotation(Ast.FieldNode, 'sql_foreign_key', __parse_sql_foreign_key_annotation)

def __parse_sql_unique_annotation(ast_node, name, value, **kwds):
    """Attach a valueless ``@sql_unique``-style annotation to *ast_node*.

    Extra keyword arguments are forwarded to the ``Ast.AnnotationNode``.

    Raises:
        ValueError: if the annotation was given a value.
    """
    if value is None:
        ast_node.annotations.append(Ast.AnnotationNode(name=name, **kwds))
        return
    raise ValueError("@%(name)s does not take a value" % {'name': name})
# Import-time registration: route 'sql_unique' annotations on field nodes to
# the parser function above.
Parser.register_annotation(Ast.FieldNode, 'sql_unique', __parse_sql_unique_annotation)
Example #8
0
    class MapType(Generator.MapType, _ContainerType):  # @UndefinedVariable
        """Map type construct; adds no behavior beyond its two bases."""
        pass

    class Service(Generator.Service, _NamedConstruct):  # @UndefinedVariable
        def lint(self):
            """Report out-of-order function names, then lint every function.

            check_lexicographic_order receives the names seen so far plus the
            current one; a non-None result is logged as the name the current
            function should come after.
            """
            seen_names = []
            for current in self.functions:
                should_follow = check_lexicographic_order(seen_names, current.name)
                if should_follow is not None:
                    self._logger.warn(
                        "function %s in %s is out of lexicographic order (should be after %s)",
                        current.name,
                        self._parent_document().path,
                        should_follow,
                    )
                seen_names.append(current.name)

                current.lint()

    class SetType(Generator.SetType, _SequenceType):  # @UndefinedVariable
        """Set type construct; adds no behavior beyond its two bases."""
        pass

    class StringType(Generator.StringType, _Type):  # @UndefinedVariable
        """String type construct; adds no behavior beyond its two bases."""
        pass

    class StructType(Generator.StructType, _CompoundType):  # @UndefinedVariable
        """Struct type construct; adds no behavior beyond its two bases."""
        pass


# Import-time registration of the valueless lint annotations with Parser; the
# tuples presumably list the AST node types each annotation may decorate.
Parser.register_annotation_parser(ValuelessAnnotationParser(
    'lint_require_field_ids_recursive', (Ast.ExceptionTypeNode, Ast.StructTypeNode)))
Parser.register_annotation_parser(ValuelessAnnotationParser(
    'lint_suppress', (Ast.EnumTypeNode, Ast.ExceptionTypeNode, Ast.FunctionNode, Ast.ServiceNode, Ast.StructTypeNode)))
                {
                    '_all': {'enabled': False},
                    'dynamic': 'strict',
                }
            for annotation in self.annotations:
                if annotation.name == 'elastic_search_mappings_base':
                    mappings[document_type].update(annotation.value)
            if 'properties' in mappings[document_type]:
                updated_properties = OrderedDict()
                updated_properties.update(
                    mappings[document_type]['properties'])
                updated_properties.update(properties)
                properties = updated_properties
            mappings[document_type]['properties'] = properties

            return mappings

    def __init__(self, settings=None, template=None):
        """Initialize the generator with optional settings and template.

        Args:
            settings: None or a dict.
            template: None or a str.

        Raises:
            TypeError: if either argument has the wrong type.
        """
        Generator.__init__(self)
        if not (settings is None or isinstance(settings, dict)):
            raise TypeError('settings must be a dict')
        self._settings = settings
        if not (template is None or isinstance(template, str)):
            raise TypeError('template must be a str')
        self._template = template


# Import-time registration of the Elastic Search annotation parsers with Parser.
Parser.register_annotation_parser(ElasticSearchDocumentTypeAnnotationParser())
Parser.register_annotation_parser(ElasticSearchMappingAnnotationParser())
Parser.register_annotation_parser(ElasticSearchMappingsBaseAnnotationParser())
Example #10
0
    def __init__(self,
                 default_methods=False,
                 function_overloads=False,
                 mutable_compound_types=False,
                 namespace_prefix=None,
                 **kwds):
        """Store this generator's options and forward the rest to Generator.

        The four options are kept in private attributes; any other keyword
        arguments go to Generator.__init__.
        """
        Generator.__init__(self, **kwds)
        (self.__default_methods,
         self.__function_overloads,
         self.__mutable_compound_types,
         self.__namespace_prefix) = (default_methods,
                                     function_overloads,
                                     mutable_compound_types,
                                     namespace_prefix)

    @property
    def default_methods(self):
        """The ``default_methods`` value passed to the constructor."""
        return self.__default_methods

    @property
    def function_overloads(self):
        """The ``function_overloads`` value passed to the constructor."""
        return self.__function_overloads

    @property
    def mutable_compound_types(self):
        """The ``mutable_compound_types`` value passed to the constructor."""
        return self.__mutable_compound_types

    @property
    def namespace_prefix(self):
        """The ``namespace_prefix`` value passed to the constructor (default None)."""
        return self.__namespace_prefix


# Import-time registration of Java-specific annotation parsers with Parser.
# java_extends / java_implements use AnnotationParser; the two to_string
# annotations use ValuelessAnnotationParser on field nodes. The second
# argument presumably lists the AST node types each annotation applies to.
Parser.register_annotation_parser(JavaFinalAnnotationParser())
Parser.register_annotation_parser(AnnotationParser('java_extends', (Ast.ExceptionTypeNode, Ast.StructTypeNode)))
Parser.register_annotation_parser(AnnotationParser('java_implements', (Ast.EnumTypeNode, Ast.ExceptionTypeNode, Ast.ServiceNode, Ast.StructTypeNode)))
Parser.register_annotation_parser(ValuelessAnnotationParser('java_escape_to_string', Ast.FieldNode))
Parser.register_annotation_parser(ValuelessAnnotationParser('java_exclude_from_to_string', Ast.FieldNode))
Example #11
0
            mappings[document_type] = \
                {
                    '_all': {'enabled': False},
                    'dynamic': 'strict',
                }
            for annotation in self.annotations:
                if annotation.name == 'elastic_search_mappings_base':
                    mappings[document_type].update(annotation.value)
            if 'properties' in mappings[document_type]:
                updated_properties = OrderedDict()
                updated_properties.update(mappings[document_type]['properties'])
                updated_properties.update(properties)
                properties = updated_properties
            mappings[document_type]['properties'] = properties

            return mappings

    def __init__(self, settings=None, template=None):
        """Initialize the generator with optional settings and template.

        Args:
            settings: None or a dict.
            template: None or a str.

        Raises:
            TypeError: if either argument has the wrong type.
        """
        Generator.__init__(self)
        if not (settings is None or isinstance(settings, dict)):
            raise TypeError('settings must be a dict')
        self._settings = settings
        if not (template is None or isinstance(template, str)):
            raise TypeError('template must be a str')
        self._template = template


# Import-time registration of the Elastic Search annotation parsers with Parser.
Parser.register_annotation_parser(ElasticSearchDocumentTypeAnnotationParser())
Parser.register_annotation_parser(ElasticSearchMappingAnnotationParser())
Parser.register_annotation_parser(ElasticSearchMappingsBaseAnnotationParser())
Example #12
0
class Compiler(object):
    """Scan, parse and optionally lower Thrift files to generator constructs.

    Parsed document ASTs are cached by absolute file path for the lifetime
    of the Compiler, so each file is scanned and parsed at most once even if
    it is compiled or included repeatedly.
    """

    class __AstVisitor(object):
        """Walk a parsed Thrift AST and build the matching generator constructs.

        Compiler.compile creates one visitor per file it visits. Include
        directives call back into the compiler, which creates a separate
        visitor for the included file.
        """

        # A plain Generator; its NativeType marks where the base-class walk
        # in __construct_native_override stops.
        __ROOT_GENERATOR = Generator()

        def __init__(self, compiler, generator, include_dir_paths):
            object.__init__(self)
            self.__compiler = compiler
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Constructs currently being built (documents, compound types,
            # services, functions); the last entry is the innermost scope.
            self.__scope_stack = []
            # Constructed types keyed by the Thrift name or qualified name
            # they were registered under.
            self.__type_cache = {}

        def __construct(self, class_name, **kwds):
            """Instantiate the generator's ``class_name`` construct.

            The innermost enclosing scope (or the generator itself at top
            level) is passed as ``parent``. A named construct carrying a
            ``native`` annotation is first offered to an overrides module
            next to an enclosing document (see __construct_native_override);
            otherwise the generator's own class is used.
            """
            if len(self.__scope_stack) > 0:
                parent = self.__scope_stack[-1]
            else:
                parent = self.__generator
            kwds['parent'] = parent

            annotations = kwds.get('annotations')
            native = \
                annotations is not None and \
                kwds.get('name') is not None and \
                'native' in annotations
            if native:
                override = self.__construct_native_override(kwds)
                if override is not None:
                    return override

            return getattr(self.__generator, class_name)(**kwds)

        def __construct_native_override(self, kwds):
            """Return an override instance for a ``native`` construct, or None.

            For each enclosing Document, innermost first, loads a sibling
            ``<document>.py`` module and instantiates (with ``kwds``) the
            first class found that strictly subclasses the generator's
            NativeType or one of its public first-base ancestors. Documents
            with no such module, a module that fails to import, or no
            matching class are skipped.
            """
            for scope in reversed(self.__scope_stack):
                if not isinstance(scope, Document):
                    continue

                document = scope
                overrides_module_file_path = os.path.splitext(document.path)[0] + '.py'
                if not os.path.isfile(overrides_module_file_path):
                    continue
                overrides_module_dir_path, overrides_module_file_name = \
                    os.path.split(overrides_module_file_path)
                overrides_module_name = \
                    os.path.splitext(overrides_module_file_name)[0]
                overrides_module_file = None
                try:
                    overrides_module_file, overrides_module_pathname, overrides_module_description = \
                        imp.find_module(
                            overrides_module_name,
                            [overrides_module_dir_path]
                        )
                    overrides_module = \
                        imp.load_module(
                            overrides_module_name,
                            overrides_module_file,
                            overrides_module_pathname,
                            overrides_module_description
                        )
                except ImportError:
                    logging.error(
                        "error importing overrides module %s",
                        overrides_module_file_path,
                        exc_info=True
                    )
                    continue
                finally:
                    # imp.find_module hands back an open file that the caller
                    # is responsible for closing, success or failure.
                    if overrides_module_file is not None:
                        overrides_module_file.close()

                # Find an override implementation corresponding to our generator.
                # For example, find JavaDateTime by looking in the module for a
                # class that inherits JavaNativeType.
                # The first-bases walk below covers the case where we want
                # JavaDateTime but the generator gave us something that itself
                # inherits from JavaNativeType, e.g. SubJavaNativeType.
                # There is no SubJavaDateTime(SubJavaNativeType), so we also
                # consider SubJavaNativeType's parent (JavaNativeType) and look
                # for a subclass of that.
                root_construct_class = getattr(self.__ROOT_GENERATOR, 'NativeType')
                generator_construct_class = getattr(self.__generator, 'NativeType')
                parent_construct_classes = [generator_construct_class]

                def _collect_first_bases(class_, first_bases):
                    # Follow the first-base chain up to (excluding) object or
                    # the root NativeType, keeping only public class names.
                    if len(class_.__bases__) == 0:
                        return
                    first_base = class_.__bases__[0]
                    if first_base is object or first_base is root_construct_class:
                        return
                    if first_base.__name__[0] != '_':
                        first_bases.append(first_base)
                    _collect_first_bases(first_base, first_bases)
                _collect_first_bases(generator_construct_class, parent_construct_classes)

                for parent_construct_class in parent_construct_classes:
                    for attr in dir(overrides_module):
                        value = getattr(overrides_module, attr)
                        if isclass(value) and \
                           issubclass(value, parent_construct_class) and \
                           value != parent_construct_class:
                            return value(**kwds)
                # No matching class in this module; try the next enclosing document.

            return None

        def visit_annotation_node(self, annotation_node):
            return annotation_node.value

        def __visit_annotation_nodes(self, annotation_nodes):
            """Map annotation name -> visited value; {} when there are none."""
            if annotation_nodes is None:
                return {}
            return dict((annotation_node.name, annotation_node.accept(self))
                        for annotation_node in annotation_nodes)

        def visit_base_type_node(self, base_type_node):
            """Return the shared construct for a base type (e.g. i32 -> I32Type)."""
            try:
                return self.__type_cache[base_type_node.name]
            except KeyError:
                base_type = getattr(self.__generator, base_type_node.name.capitalize() + 'Type')(name=base_type_node.name)
                self.__type_cache[base_type_node.name] = base_type
                return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an EnumType, ExceptionType or StructType and its members.

            Raises:
                CompileException: if an enum mixes enumerators with and
                    without explicit values.
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=self.__visit_annotation_nodes(compound_type_node.annotations),
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            if isinstance(compound_type, NativeType):
                # Native overrides are returned as-is; members are not visited.
                return compound_type
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type cache before visiting its
            # members to allow recursive definitions.
            self.__type_cache[compound_type.thrift_qname()] = compound_type

            if construct_class_name == 'EnumType':
                enumerator_nodes = compound_type_node.enumerators
                has_value = [enumerator_node.value is not None
                             for enumerator_node in enumerator_nodes]
                # Reject any mix of valued and unvalued enumerators,
                # regardless of which kind comes first.
                if any(has_value) and not all(has_value):
                    raise CompileException("%s has mix of enumerators with and without values, must be one or the other" % compound_type_node.name, ast_node=compound_type_node)
                for enumerator_i, enumerator_node in enumerate(enumerator_nodes):
                    if enumerator_node.value is None:
                        # Unvalued enums number their enumerators by position.
                        value = enumerator_i
                    else:
                        assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(enumerator_node.value)
                        value = enumerator_node.value.value
                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotations=self.__visit_annotation_nodes(enumerator_node.annotations),
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value
                        )
                    )
            else:
                for field in compound_type_node.fields:
                    compound_type.fields.append(field.accept(self))

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotations=self.__visit_annotation_nodes(const_node.annotations),
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc comment text, or None when there is no doc node."""
            if doc_node is None:
                return None
            return doc_node.text

        def visit_document_node(self, document_node):
            """Build a Document with its headers and definitions, in order."""
            document = self.__construct('Document', path=document_node.path)
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)

            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType', exception_type_node)

        def visit_field_node(self, field_node):
            return \
                self.__construct(
                    'Field',
                    annotations=self.__visit_annotation_nodes(field_node.annotations),
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=field_node.value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function with its parameters, return field and throws."""
            function = \
                self.__construct(
                    'Function',
                    annotations=self.__visit_annotation_nodes(function_node.annotations),
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                function.parameters.append(parameter_node.accept(self))
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Resolve and compile an included file, exposing its types here.

            The include path is tried against each configured include
            directory, then the directory of the nearest enclosing document.
            The first existing match is compiled with the same generator and
            its types and typedefs are added to this visitor's type cache.

            Raises:
                CompileException: if no candidate path exists.
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(include_dir_path, include_file_relpath)
                if os.path.exists(include_file_path):
                    include_file_path = os.path.abspath(include_file_path)
                    included_document = self.__compiler.compile((include_file_path,), generator=self.__generator)[0]
                    include = \
                        self.__construct(
                            'Include',
                            annotations=self.__visit_annotation_nodes(include_node.annotations),
                            doc=self.__visit_doc_node(include_node.doc),
                            document=included_document,
                            path=include_file_relpath
                        )
                    for definition in included_document.definitions:
                        if isinstance(definition, _Type):
                            self.__type_cache[definition.thrift_qname()] = definition
                        elif isinstance(definition, Typedef):
                            self.__type_cache[definition.thrift_qname()] = definition  # .type
                    return include
            raise CompileException("include path not found: %s" % include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(element_value.accept(self) for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            return dict((key_value.accept(self), value_value.accept(self)) for key_value, value_value in map_literal_node.value.iteritems())

        def visit_map_type_node(self, map_type_node):
            """Return the cached MapType for this node's name, creating it once."""
            try:
                return self.__type_cache[map_type_node.name]
            except KeyError:
                map_type = self.__construct('MapType', key_type=map_type_node.key_type.accept(self), value_type=map_type_node.value_type.accept(self))
                self.__type_cache[map_type_node.name] = map_type
                return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotations=self.__visit_annotation_nodes(namespace_node.annotations),
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the cached List/Set type for this node's name, creating it once."""
            try:
                return self.__type_cache[sequence_type_node.name]
            except KeyError:
                sequence_type = self.__construct(construct_class_name, element_type=sequence_type_node.element_type.accept(self))
                self.__type_cache[sequence_type_node.name] = sequence_type
                return sequence_type

        def visit_service_node(self, service_node):
            """Build a Service and its functions."""
            service = \
                self.__construct(
                    'Service',
                    annotations=self.__visit_annotation_nodes(service_node.annotations),
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            for function_node in service_node.functions:
                service.functions.append(function_node.accept(self))

            self.__scope_stack.pop(-1)

            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a reference to a previously defined or included type.

            The qualified name is looked up first. An unqualified name
            (qname == name) is then retried prefixed with the name of the
            outermost scope, i.e. the document being visited.

            Raises:
                CompileException: if the type cannot be resolved.
            """
            type_cache = self.__type_cache
            try:
                return type_cache[type_node.qname]
            except KeyError:
                pass
            if type_node.qname == type_node.name:
                document = self.__scope_stack[0]
                try:
                    return type_cache[document.name + '.' + type_node.qname]
                except KeyError:
                    pass
            raise CompileException("unrecognized type '%s'" % type_node.qname, ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            """Build a Typedef and register its aliased type under its name."""
            typedef = \
                self.__construct(
                    'Typedef',
                    annotations=self.__visit_annotation_nodes(typedef_node.annotations),
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )

            self.__type_cache[typedef.thrift_qname()] = typedef.type

            return typedef

    def __init__(self, include_dir_paths=None):
        """Configure the include search path.

        include_dir_paths may be None, a single path, or a list/tuple of
        paths; with no paths at all the current working directory is used.
        The bundled lib/thrift/src directory (four directories above the one
        containing this module) is always appended, and duplicates are
        dropped while keeping each path's first occurrence.
        """
        object.__init__(self)

        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        if len(include_dir_paths) == 0:
            include_dir_paths.append(os.getcwd())
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        # Parsed document ASTs keyed by absolute file path.
        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, thrift_file_paths, generator=None):
        """Alias for compile()."""
        return self.compile(thrift_file_paths, generator=generator)

    def compile(self, thrift_file_paths, generator=None):
        """Compile one or more Thrift files.

        Args:
            thrift_file_paths: a single path or a list/tuple of paths.
            generator: when None, the parsed document AST nodes are returned;
                otherwise each AST is visited to build that generator's
                Document constructs.

        Returns:
            A tuple with one result per input path, in input order.
        """
        if not isinstance(thrift_file_paths, (list, tuple)):
            thrift_file_paths = (thrift_file_paths,)

        documents = []
        for thrift_file_path in thrift_file_paths:
            thrift_file_path = os.path.abspath(thrift_file_path)
            document_node = self.__parsed_thrift_files_by_path.get(thrift_file_path)
            if document_node is None:
                tokens = self.__scanner.tokenize(thrift_file_path)
                document_node = self.__parser.parse(tokens)
                self.__parsed_thrift_files_by_path[thrift_file_path] = document_node
            if generator is not None:
                ast_visitor = \
                    self.__AstVisitor(
                        compiler=self,
                        generator=generator,
                        include_dir_paths=self.__include_dir_paths
                    )
                document = document_node.accept(ast_visitor)
                documents.append(document)
            else:
                documents.append(document_node)
        return tuple(documents)

    @property
    def include_dir_paths(self):
        """Tuple of directories searched for include files, in search order."""
        return self.__include_dir_paths
Example #13
0
                if len(function_names) > 0 and cmp(function.name, function_names[-1]) < 0:
                    after_function_name = ""
                    for function_name_i in xrange(len(function_names) - 1, -1, -1):
                        test_function_name = function_names[function_name_i]
                        if cmp(function.name, test_function_name) >= 0:
                            after_function_name = test_function_name
                            break
                    self._logger.warn(
                        "function %s in %s is out of lexicographic order (should be after %s)",
                        function.name,
                        self._parent_document().path,
                        after_function_name,
                    )
                function_names.append(function.name)

    class SetType(Generator.SetType, _SequenceType):  # @UndefinedVariable
        """Set type construct; adds no behavior beyond its two bases."""
        pass

    class StringType(Generator.StringType, _Type):  # @UndefinedVariable
        """String type construct; adds no behavior beyond its two bases."""
        pass

    class StructType(Generator.StructType, _CompoundType):  # @UndefinedVariable
        """Struct type construct; adds no behavior beyond its two bases."""
        pass


# Import-time registration of the valueless lint_suppress annotation with
# Parser; the tuple presumably lists the AST node types it may decorate.
Parser.register_annotation_parser(
    ValuelessAnnotationParser(
        "lint_suppress", (Ast.EnumTypeNode, Ast.ExceptionTypeNode, Ast.ServiceNode, Ast.StructTypeNode)
    )
)
Example #14
0
from thryft.compiler.parser import Parser
from thryft.generator.generator import Generator
from thryft.generators.js.js_view_metadata_annotation_parser import JsViewMetadataAnnotationParser


class JsGenerator(Generator):
    """Generator whose construct classes are the JavaScript (Js*) implementations.

    Each class-body import below binds a thryft.generators.js construct class
    under its generic alias (BinaryType, Document, StructType, ...), so the
    aliases become class attributes of JsGenerator.
    """
    # Import order is kept as-is; the imported modules may have import-time
    # side effects (NOTE(review): not verified from here).
    from thryft.generators.js.js_binary_type import JsBinaryType as BinaryType  # @UnusedImport
    from thryft.generators.js.js_bool_type import JsBoolType as BoolType  # @UnusedImport
    from thryft.generators.js.js_byte_type import JsByteType as ByteType  # @UnusedImport
    from thryft.generators.js.js_const import JsConst as Const  # @UnusedImport
    from thryft.generators.js.js_document import JsDocument as Document  # @UnusedImport
    from thryft.generators.js.js_double_type import JsDoubleType as DoubleType  # @UnusedImport
    from thryft.generators.js.js_enum_type import JsEnumType as EnumType  # @UnusedImport
    from thryft.generators.js.js_exception_type import JsExceptionType as ExceptionType  # @UnusedImport
    from thryft.generators.js.js_field import JsField as Field  # @UnusedImport
    from thryft.generators.js.js_function import JsFunction as Function  # @UnusedImport
    from thryft.generators.js.js_i16_type import JsI16Type as I16Type  # @UnusedImport
    from thryft.generators.js.js_i32_type import JsI32Type as I32Type  # @UnusedImport
    from thryft.generators.js.js_i64_type import JsI64Type as I64Type  # @UnusedImport
    from thryft.generators.js.js_include import JsInclude as Include  # @UnusedImport
    from thryft.generators.js.js_list_type import JsListType as ListType  # @UnusedImport
    from thryft.generators.js.js_map_type import JsMapType as MapType  # @UnusedImport
    from thryft.generators.js.js_service import JsService as Service  # @UnusedImport
    from thryft.generators.js.js_set_type import JsSetType as SetType  # @UnusedImport
    from thryft.generators.js.js_string_type import JsStringType as StringType  # @UnusedImport
    from thryft.generators.js.js_struct_type import JsStructType as StructType  # @UnusedImport
    from thryft.generators.js.js_typedef import JsTypedef as Typedef  # @UnusedImport

# Runs at import time: register JsViewMetadataAnnotationParser on Parser.
Parser.register_annotation_parser(JsViewMetadataAnnotationParser())
Example #15
0
    from thryft.generators.js.js_i64_type import JsI64Type as I64Type  # @UnusedImport
    from thryft.generators.js.js_include import JsInclude as Include  # @UnusedImport
    from thryft.generators.js.js_list_type import JsListType as ListType  # @UnusedImport
    from thryft.generators.js.js_map_type import JsMapType as MapType  # @UnusedImport
    from thryft.generators.js.js_native_type import JsNativeType as NativeType  # @UnusedImport
    from thryft.generators.js.js_service import JsService as Service  # @UnusedImport
    from thryft.generators.js.js_set_type import JsSetType as SetType  # @UnusedImport
    from thryft.generators.js.js_string_type import JsStringType as StringType  # @UnusedImport
    from thryft.generators.js.js_struct_type import JsStructType as StructType  # @UnusedImport


def __parse_js_view_metadata_annotation(ast_node, name, value, **kwds):
    """Parse a @js_view_metadata field annotation and attach it to the field.

    The annotation text must be a JSON object. Keys other than
    'displayFormat' and 'editControl' are logged as warnings but kept in the
    attached value.

    :param ast_node: the Ast.FieldNode being annotated
    :param name: the annotation name
    :param value: raw annotation text; must be a JSON object literal
    :param kwds: passed through unchanged to Ast.AnnotationNode
    :raises ValueError: if value is missing, is not valid JSON, or is not a
        JSON object
    """
    assert isinstance(ast_node, Ast.FieldNode)

    if value is None:
        # json.loads(None) would raise TypeError; report it like other bad input.
        raise ValueError("@%s requires a JSON object value" % name)

    try:
        value = json.loads(value)
    except ValueError as e:
        raise ValueError("@%s contains invalid JSON: '%s', exception: %s" % (name, value, e))
    if not isinstance(value, dict):
        raise ValueError("expected @%s to contain a JSON object, found '%s'" % (name, value))

    for subname in value:
        if subname not in ('displayFormat', 'editControl'):
            # Lazy %-args: report the offending key itself.
            logging.warning("unknown %s property '%s'", name, subname)

    ast_node.annotations.append(Ast.AnnotationNode(name=name, value=value, **kwds))

Parser.register_annotation(Ast.FieldNode, 'js_view_metadata', __parse_js_view_metadata_annotation)
Example #16
0
def __parse_sql_foreign_key_annotation(ast_node, name, value, **kwds):
    """Parse a @sql_foreign_key annotation of the form "table.column".

    Appends an Ast.AnnotationNode to ast_node whose value is the
    (table_name, column_name) tuple.

    :param ast_node: the annotated field node
    :param name: the annotation name
    :param value: raw annotation text, expected to be "table.column"
    :param kwds: passed through unchanged to Ast.AnnotationNode
    :raises ValueError: if value is missing, does not split into exactly two
        '.'-separated parts, or either part is empty
    """
    if value is None:
        # Annotations may arrive without a value; fail with ValueError rather
        # than an AttributeError from None.split().
        raise ValueError("@%s must specify table.column" % name)

    value_parts = value.split('.')
    if len(value_parts) != 2 or not all(value_parts):
        raise ValueError("@%s must specify table.column: '%s'" %
                         (name, value))
    table_name, column_name = value_parts

    ast_node.annotations.append(
        Ast.AnnotationNode(name=name,
                           value=(table_name, column_name),
                           **kwds))


Parser.register_annotation(Ast.FieldNode, 'sql_foreign_key',
                           __parse_sql_foreign_key_annotation)


def __parse_sql_unique_annotation(ast_node, name, value, **kwds):
    """Attach a valueless @sql_unique annotation to ast_node.

    :raises ValueError: if the annotation was given a value
    """
    if value is None:
        unique_annotation = Ast.AnnotationNode(name=name, **kwds)
        ast_node.annotations.append(unique_annotation)
        return
    raise ValueError("@%(name)s does not take a value" % {'name': name})


Parser.register_annotation(Ast.FieldNode, 'sql_unique', __parse_sql_unique_annotation)
Example #17
0
from thryft.generators.sql.sql_foreign_key_annotation_parser import SqlForeignKeyAnnotationParser


class SqlGenerator(Generator):
    """Generator whose construct classes are the SQL (Sql*) implementations.

    Each class-body import below binds a thryft.generators.sql construct class
    under its generic alias (BinaryType, Document, StructType, ...), so the
    aliases become class attributes of SqlGenerator.
    """
    # Import order is kept as-is; the imported modules may have import-time
    # side effects (NOTE(review): not verified from here).
    from thryft.generators.sql.sql_binary_type import SqlBinaryType as BinaryType  # @UnusedImport
    from thryft.generators.sql.sql_bool_type import SqlBoolType as BoolType  # @UnusedImport
    from thryft.generators.sql.sql_byte_type import SqlByteType as ByteType  # @UnusedImport
    from thryft.generators.sql.sql_const import SqlConst as Const  # @UnusedImport
    from thryft.generators.sql.sql_document import SqlDocument as Document  # @UnusedImport
    from thryft.generators.sql.sql_double_type import SqlDoubleType as DoubleType  # @UnusedImport
    from thryft.generators.sql.sql_enum_type import SqlEnumType as EnumType  # @UnusedImport
    from thryft.generators.sql.sql_exception_type import SqlExceptionType as ExceptionType  # @UnusedImport
    from thryft.generators.sql.sql_field import SqlField as Field  # @UnusedImport
    from thryft.generators.sql.sql_function import SqlFunction as Function  # @UnusedImport
    from thryft.generators.sql.sql_i16_type import SqlI16Type as I16Type  # @UnusedImport
    from thryft.generators.sql.sql_i32_type import SqlI32Type as I32Type  # @UnusedImport
    from thryft.generators.sql.sql_i64_type import SqlI64Type as I64Type  # @UnusedImport
    from thryft.generators.sql.sql_include import SqlInclude as Include  # @UnusedImport
    from thryft.generators.sql.sql_list_type import SqlListType as ListType  # @UnusedImport
    from thryft.generators.sql.sql_map_type import SqlMapType as MapType  # @UnusedImport
    from thryft.generators.sql.sql_service import SqlService as Service  # @UnusedImport
    from thryft.generators.sql.sql_set_type import SqlSetType as SetType  # @UnusedImport
    from thryft.generators.sql.sql_string_type import SqlStringType as StringType  # @UnusedImport
    from thryft.generators.sql.sql_struct_type import SqlStructType as StructType  # @UnusedImport
    from thryft.generators.sql.sql_typedef import SqlTypedef as Typedef  # @UnusedImport


# Annotation-parser registrations, run when this module is imported:
# an AnnotationParser for @sql_column on struct nodes, the foreign-key parser,
# and a ValuelessAnnotationParser for @sql_unique on field nodes.
Parser.register_annotation_parser(AnnotationParser('sql_column', Ast.StructTypeNode))
Parser.register_annotation_parser(SqlForeignKeyAnnotationParser())
Parser.register_annotation_parser(ValuelessAnnotationParser('sql_unique', Ast.FieldNode))
Example #18
0
class Compiler(object):
    """Tokenize, parse and optionally generate constructs for Thrift documents.

    Parsed document ASTs are cached per absolute file path, so a file that is
    compiled more than once (directly or through includes) is scanned and
    parsed only once. When compile() is given a generator, each AST is walked
    by a fresh __AstVisitor that builds that generator's constructs.
    """

    class __AstVisitor(object):
        """Walk a document AST and build the corresponding generator constructs."""

        # Plain Generator whose NativeType marks where the search for
        # native-type override base classes stops.
        __ROOT_GENERATOR = Generator()

        def __init__(self, compiler, generator, include_dir_paths):
            """
            :param compiler: the Compiler used to compile included files
            :param generator: object whose attributes (Document, StructType,
                Field, ...) are the construct classes to instantiate
            :param include_dir_paths: directories searched for include files
            """
            object.__init__(self)
            self.__compiler = compiler
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Constructs currently being built, outermost first; the last one
            # becomes the parent of whatever is constructed next.
            self.__scope_stack = []
            # Type constructs keyed by name or thrift qname, so repeated
            # references resolve to one construct and recursive types work.
            self.__type_cache = {}

        def __construct(self, class_name, **kwds):
            """Instantiate the generator's class_name construct with kwds.

            The innermost scope (or the generator itself when nothing is on
            the scope stack) is passed as ``parent``. A named construct with a
            ``native`` annotation is instead built from an override class
            found in a ``.py`` module next to an enclosing document, if any.
            """
            if self.__scope_stack:
                kwds['parent'] = self.__scope_stack[-1]
            else:
                kwds['parent'] = self.__generator

            if self.__is_native(kwds.get('annotations'), kwds.get('name')):
                for scope in reversed(self.__scope_stack):
                    if not isinstance(scope, Document):
                        continue
                    overrides_module = self.__load_overrides_module(scope)
                    if overrides_module is None:
                        continue
                    override_class = \
                        self.__find_native_override_class(overrides_module)
                    if override_class is not None:
                        return override_class(**kwds)

            return getattr(self.__generator, class_name)(**kwds)

        @staticmethod
        def __is_native(annotations, name):
            """Return True if a named construct's annotations include 'native'."""
            if annotations is None or name is None:
                return False
            return any(annotation == 'native' for annotation in annotations)

        def __load_overrides_module(self, document):
            """Import the ``.py`` file next to document's ``.thrift`` file.

            Returns the module, or None if there is no such file or it fails
            to import (the failure is logged with its traceback).
            """
            overrides_module_file_path = \
                os.path.splitext(document.path)[0] + '.py'
            if not os.path.isfile(overrides_module_file_path):
                return None
            overrides_module_dir_path, overrides_module_file_name = \
                os.path.split(overrides_module_file_path)
            overrides_module_name = \
                os.path.splitext(overrides_module_file_name)[0]
            try:
                module_file, module_pathname, module_description = \
                    imp.find_module(overrides_module_name,
                                    [overrides_module_dir_path])
                try:
                    return imp.load_module(overrides_module_name,
                                           module_file,
                                           module_pathname,
                                           module_description)
                finally:
                    # imp.load_module does not close the file that
                    # imp.find_module opened; the caller must.
                    if module_file is not None:
                        module_file.close()
            except ImportError:
                logging.error("error importing overrides module %s",
                              overrides_module_file_path,
                              exc_info=True)
                return None

        def __native_type_candidate_classes(self):
            """List the NativeType classes an override class may subclass.

            Starts with the generator's own NativeType, then follows the
            first-base chain upward, skipping private ('_'-prefixed) classes
            and stopping at object or the root Generator's NativeType. This
            covers a generator whose NativeType itself subclasses another
            generator's (e.g. SubJavaNativeType(JavaNativeType)) when the
            override module only subclasses the parent's class.
            """
            root_construct_class = self.__ROOT_GENERATOR.NativeType
            generator_construct_class = self.__generator.NativeType
            candidate_classes = [generator_construct_class]
            class_ = generator_construct_class
            while class_.__bases__:
                first_base = class_.__bases__[0]
                if first_base is object or first_base is root_construct_class:
                    break
                if not first_base.__name__.startswith('_'):
                    candidate_classes.append(first_base)
                class_ = first_base
            return candidate_classes

        def __find_native_override_class(self, overrides_module):
            """Return the first class in overrides_module that strictly
            subclasses a candidate NativeType class, or None."""
            for candidate_class in self.__native_type_candidate_classes():
                for attr in dir(overrides_module):
                    value = getattr(overrides_module, attr)
                    if isclass(value) and \
                       issubclass(value, candidate_class) and \
                       value != candidate_class:
                        return value
            return None

        def visit_annotation_node(self, annotation_node):
            return annotation_node.value

        def __visit_annotation_nodes(self, annotation_nodes):
            """Map annotation name -> visited value; {} when there are none."""
            if annotation_nodes is None:
                return {}
            return dict((annotation_node.name, annotation_node.accept(self))
                        for annotation_node in annotation_nodes)

        def visit_base_type_node(self, base_type_node):
            """Return the cached base type construct, creating it on first use.

            The construct class name is derived from the node name, e.g.
            'i32' -> 'I32Type'.
            """
            try:
                return self.__type_cache[base_type_node.name]
            except KeyError:
                pass
            base_type = getattr(self.__generator,
                                base_type_node.name.capitalize() +
                                'Type')(name=base_type_node.name)
            self.__type_cache[base_type_node.name] = base_type
            return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name,
                                       compound_type_node):
            """Build an enum, exception or struct construct with its members.

            Native overrides are returned as-is, without visiting members or
            being added to the type cache.

            :raises CompileException: if an enum mixes enumerators with and
                without explicit values
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=self.__visit_annotation_nodes(compound_type_node.annotations),
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            if isinstance(compound_type, NativeType):
                return compound_type
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type cache before visiting its
            # members to allow recursive definitions.
            self.__type_cache[compound_type.thrift_qname()] = compound_type

            if construct_class_name == 'EnumType':
                enumerator_nodes = compound_type_node.enumerators
                # Check the whole list so a mix is rejected in either order.
                has_valued = any(enumerator_node.value is not None
                                 for enumerator_node in enumerator_nodes)
                has_unvalued = any(enumerator_node.value is None
                                   for enumerator_node in enumerator_nodes)
                if has_valued and has_unvalued:
                    raise CompileException(
                        "%s has mix of enumerators with and without values, must be one or the other"
                        % compound_type_node.name,
                        ast_node=compound_type_node)
                for enumerator_i, enumerator_node in enumerate(enumerator_nodes):
                    if enumerator_node.value is None:
                        # Implicit values are the enumerator's position.
                        value = enumerator_i
                    else:
                        assert isinstance(enumerator_node.value,
                                          Ast.IntLiteralNode), type(
                                              enumerator_node.value)
                        value = enumerator_node.value.value
                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotations=self.__visit_annotation_nodes(
                                enumerator_node.annotations),
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value))
            else:
                for field in compound_type_node.fields:
                    compound_type.fields.append(field.accept(self))

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotations=self.__visit_annotation_nodes(const_node.annotations),
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc comment text, or None when there is no doc node."""
            return doc_node.text if doc_node is not None else None

        def visit_document_node(self, document_node):
            """Build a Document construct and visit its headers and definitions."""
            document = self.__construct('Document', path=document_node.path)
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)

            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType',
                                                   exception_type_node)

        def visit_field_node(self, field_node):
            return \
                self.__construct(
                    'Field',
                    annotations=self.__visit_annotation_nodes(field_node.annotations),
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=field_node.value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function construct with parameters, return field and throws."""
            function = \
                self.__construct(
                    'Function',
                    annotations=self.__visit_annotation_nodes(function_node.annotations),
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                function.parameters.append(parameter_node.accept(self))
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Compile an included file and register its types for lookup.

            The include path is searched in the configured include
            directories first, then in the directory of the innermost
            enclosing document.

            :raises CompileException: if the file is not found in any of them
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(include_dir_path,
                                                 include_file_relpath)
                if not os.path.exists(include_file_path):
                    continue
                include_file_path = os.path.abspath(include_file_path)
                included_document = self.__compiler.compile(
                    (include_file_path, ), generator=self.__generator)[0]
                include = \
                    self.__construct(
                        'Include',
                        annotations=self.__visit_annotation_nodes(include_node.annotations),
                        doc=self.__visit_doc_node(include_node.doc),
                        document=included_document,
                        path=include_file_relpath
                    )
                for definition in included_document.definitions:
                    if isinstance(definition, _Type):
                        self.__type_cache[
                            definition.thrift_qname()] = definition
                    elif isinstance(definition, Typedef):
                        # NOTE(review): included typedefs are cached as the
                        # Typedef itself, while visit_typedef_node caches
                        # typedef.type for local ones; confirm intended.
                        self.__type_cache[
                            definition.thrift_qname()] = definition
                return include
            raise CompileException("include path not found: %s" %
                                   include_file_relpath,
                                   ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(
                element_value.accept(self)
                for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            return dict((key_value.accept(self), value_value.accept(self))
                        for key_value, value_value in
                        map_literal_node.value.iteritems())

        def visit_map_type_node(self, map_type_node):
            """Return the cached MapType for this node's name, creating it once."""
            try:
                return self.__type_cache[map_type_node.name]
            except KeyError:
                pass
            map_type = self.__construct(
                'MapType',
                key_type=map_type_node.key_type.accept(self),
                value_type=map_type_node.value_type.accept(self))
            self.__type_cache[map_type_node.name] = map_type
            return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotations=self.__visit_annotation_nodes(namespace_node.annotations),
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name,
                                       sequence_type_node):
            """Return the cached list/set type for this node's name, creating it once."""
            try:
                return self.__type_cache[sequence_type_node.name]
            except KeyError:
                pass
            sequence_type = self.__construct(
                construct_class_name,
                element_type=sequence_type_node.element_type.accept(self))
            self.__type_cache[sequence_type_node.name] = sequence_type
            return sequence_type

        def visit_service_node(self, service_node):
            """Build a Service construct and visit its functions."""
            service = \
                self.__construct(
                    'Service',
                    annotations=self.__visit_annotation_nodes(service_node.annotations),
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            for function_node in service_node.functions:
                service.functions.append(function_node.accept(self))

            self.__scope_stack.pop(-1)

            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a named type reference from the type cache.

            An unqualified name is retried qualified with the name of the
            bottom scope (the document being visited).

            :raises CompileException: if the type cannot be found
            """
            try:
                try:
                    return self.__type_cache[type_node.qname]
                except KeyError:
                    if type_node.qname == type_node.name:
                        document = self.__scope_stack[0]
                        return self.__type_cache[document.name + '.' +
                                                 type_node.qname]
                    else:
                        raise
            except KeyError:
                raise CompileException("unrecognized type '%s'" %
                                       type_node.qname,
                                       ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            """Build a Typedef and cache its aliased type under the typedef's qname."""
            typedef = \
                self.__construct(
                    'Typedef',
                    annotations=self.__visit_annotation_nodes(typedef_node.annotations),
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )

            self.__type_cache[typedef.thrift_qname()] = typedef.type

            return typedef

    def __init__(self, include_dir_paths=None):
        """
        :param include_dir_paths: a directory path or list/tuple of paths to
            search for include files; defaults to the current working
            directory. The lib/thrift/src directory four levels above this
            module is always appended, and duplicates are dropped while
            keeping first-seen order.
        """
        object.__init__(self)

        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        if not include_dir_paths:
            include_dir_paths.append(os.getcwd())
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        # Parsed document ASTs keyed by absolute file path.
        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, thrift_file_paths, generator=None):
        return self.compile(thrift_file_paths, generator=generator)

    def compile(self, thrift_file_paths, generator=None):
        """Compile one path or a list/tuple of paths.

        :returns: a tuple with one entry per path: the generator's Document
            construct when a generator is given, else the parsed AST node
        """
        if not isinstance(thrift_file_paths, (list, tuple)):
            thrift_file_paths = (thrift_file_paths, )

        documents = []
        for thrift_file_path in thrift_file_paths:
            thrift_file_path = os.path.abspath(thrift_file_path)
            document_node = self.__parsed_thrift_files_by_path.get(
                thrift_file_path)
            if document_node is None:
                tokens = self.__scanner.tokenize(thrift_file_path)
                document_node = self.__parser.parse(tokens)
                self.__parsed_thrift_files_by_path[
                    thrift_file_path] = document_node
            if generator is not None:
                # A fresh visitor per document: scope stack and type cache
                # are not shared between top-level documents.
                ast_visitor = \
                    self.__AstVisitor(
                        compiler=self,
                        generator=generator,
                        include_dir_paths=self.__include_dir_paths
                    )
                documents.append(document_node.accept(ast_visitor))
            else:
                documents.append(document_node)
        return tuple(documents)

    @property
    def include_dir_paths(self):
        """Tuple of de-duplicated include directories searched for includes."""
        return self.__include_dir_paths
Example #19
0
class Compiler(object):
    """Compile Thrift files into generator-specific document constructs.

    Each file is tokenized and parsed at most once. The resulting AST is then
    walked by an ``__AstVisitor`` bound to a particular generator, which asks
    the generator to build its own construct classes (``Document``,
    ``StructType``, ``Field``, ...). The per-generator result is cached
    alongside the AST, keyed by the file's absolute path.
    """

    class __AstVisitor(object):
        """Visitor that turns the AST of one document into generator constructs.

        It keeps a stack of enclosing constructs, so that every new construct
        gets the right ``parent``. It also keeps a cache that maps Thrift
        qualified names to types, so that type references (including
        references into included documents) can be resolved.
        """

        def __init__(self, compiler, document_root_dir_path, generator, include_dir_paths):
            object.__init__(self)
            self.__compiler = compiler
            self.__document_root_dir_path = document_root_dir_path
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Enclosing constructs, innermost last; used to set ``parent``.
            self.__scope_stack = []
            # Thrift qualified name -> type construct.
            self.__type_by_thrift_qname_cache = {}
            # Absolute paths of includes whose definitions were referenced.
            self.__used_include_abspaths = {}
            # Include constructs created while visiting this document.
            self.__visited_includes = []

        def __construct(self, class_name, annotation_nodes=None, **kwds):
            """Instantiate the generator's ``class_name`` construct.

            The construct's ``parent`` is the innermost enclosing scope, or the
            generator itself at top level.

            Annotation AST nodes are converted to the construct class's
            ``<UpperCamelName>Annotation`` class when it defines one, and to its
            generic ``Annotation`` class otherwise. A ``native`` annotation is
            not passed through. Instead, it requests that the construct be
            replaced by a same-named class from a ``.py`` module next to the
            parent document.

            :param class_name: attribute name of the construct class on the generator
            :param annotation_nodes: optional iterable of annotation AST nodes
            :param kwds: keyword arguments forwarded to the construct class
            :return: the construct, or its native override when one is found
            """
            if len(self.__scope_stack) > 0:
                parent = self.__scope_stack[-1]
            else:
                parent = self.__generator
            kwds['parent'] = parent

            construct_class = getattr(self.__generator, class_name)

            native = False
            if annotation_nodes:
                annotations = []
                for annotation_node in annotation_nodes:
                    if annotation_node.name == 'native':
                        native = True
                        continue

                    try:
                        annotation_class_name = upper_camelize(
                            annotation_node.name) + 'Annotation'
                        annotation_class = getattr(
                            construct_class, annotation_class_name)
                    except AttributeError:
                        annotation_class = getattr(
                            construct_class, 'Annotation')
                    annotations.append(
                        annotation_class(
                            name=annotation_node.name,
                            value=annotation_node.value
                        )
                    )
                kwds['annotations'] = tuple(annotations)

            construct = construct_class(**kwds)

            if not native:
                return construct

            parent_document = construct._parent_document()
            native_module_file_path = os.path.splitext(
                parent_document.path)[0] + '.py'
            if not os.path.isfile(native_module_file_path):
                # No native implementation next to the document: fall back to
                # the generated construct instead of returning None, which
                # callers would otherwise append into definitions/fields.
                return construct
            native_module_dir_path, native_module_file_name = \
                os.path.split(native_module_file_path)
            native_module_name = \
                os.path.splitext(native_module_file_name)[0]
            module_file, module_path, module_description = \
                imp.find_module(native_module_name, [native_module_dir_path])
            try:
                native_module = imp.load_module(
                    native_module_name,
                    module_file,
                    module_path,
                    module_description
                )
            finally:
                # imp.find_module opens the module file; the caller must close
                # it even when loading fails.
                if module_file is not None:
                    module_file.close()

            native_construct_class = getattr(native_module, construct.name)

            return native_construct_class(
                overridden_construct=construct, **kwds)

        def __get_type(self, type_thrift_qname, resolve=True):
            """Look up a type by its Thrift qualified name.

            The local cache is consulted first. On a miss, when ``resolve`` is
            true, the definitions of every include visited so far are searched
            for a ``_Type`` or ``Typedef`` with that qualified name. A match is
            cached, its include is marked as used, and the match is returned.

            :raises KeyError: if the name is not cached and either ``resolve``
                is false or no included definition matches
            """
            # e.g., struct_include_file.Struct
            try:
                return self.__type_by_thrift_qname_cache[type_thrift_qname]
            except KeyError:
                if not resolve:
                    raise
                for include in self.__visited_includes:
                    for definition in include.document.definitions:
                        if isinstance(definition, (_Type, Typedef)):
                            definition_thrift_qname = definition.thrift_qname()
                            if definition_thrift_qname == type_thrift_qname:
                                # NOTE(review): unlike __put_type, a Typedef
                                # found here is cached and returned as-is rather
                                # than unwrapped to its aliased type. Confirm
                                # that this is intended.
                                self.__type_by_thrift_qname_cache[definition_thrift_qname] = definition
                                self.__used_include_abspaths[include.abspath] = True
                                return definition
                raise KeyError(type_thrift_qname)

        def __put_type(self, type_thrift_qname, type_):
            """Cache ``type_`` under ``type_thrift_qname``.

            A Typedef is stored as the type it aliases, so later lookups of the
            typedef's name yield the underlying type.

            :raises CompileException: if the name is already cached
            """
            if type_thrift_qname in self.__type_by_thrift_qname_cache:
                raise CompileException("duplicate type %s" % type_thrift_qname)
            if isinstance(type_, Typedef):
                type_ = type_.type
            self.__type_by_thrift_qname_cache[type_thrift_qname] = type_

        def visit_base_type_node(self, base_type_node):
            """Return the base type, creating it on first use.

            On first use the type is built from the generator's
            ``<Capitalized name>Type`` class. It is instantiated directly
            rather than through ``__construct``, so it receives no ``parent``.
            """
            try:
                return self.__get_type(base_type_node.name, resolve=False)
            except KeyError:
                base_type = getattr(self.__generator, base_type_node.name.capitalize(
                ) + 'Type')(name=base_type_node.name)
                self.__put_type(base_type_node.name, base_type)
                return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an enum, exception or struct construct together with its members.

            The type is registered under its qualified name before its members
            are visited, so that members may refer to the type itself.

            Enums: enumerator names must be unique. Either every enumerator has
            an explicit integer value, or none has one, in which case each
            value is the enumerator's position.

            Structs/exceptions: a field name must not match an earlier field's
            name, lower-cased name or lower-camelized name. A required field
            may not follow an optional one. Field ids must be unique, and
            either every field has an id or none does.

            :raises CompileException: when any of the rules above is violated
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotation_nodes=compound_type_node.annotations,
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type_map here to allow recursive
            # definitions
            self.__put_type(compound_type.thrift_qname(), compound_type)

            if construct_class_name == 'EnumType':
                enum_type_node = compound_type_node
                have_enumerator_with_value = False
                have_enumerator_without_value = False
                enumerator_node_names = set()
                for enumerator_i, enumerator_node in enumerate(enum_type_node.enumerators):
                    if enumerator_node.name in enumerator_node_names:
                        raise CompileException("%s has a duplicate enumerator name, %s" % (
                            enum_type_node.name, enumerator_node.name), ast_node=enumerator_node)
                    enumerator_node_names.add(enumerator_node.name)

                    # Reject a mix of valued and unvalued enumerators no matter
                    # which kind appears first.
                    if enumerator_node.value is not None:
                        if have_enumerator_without_value:
                            raise CompileException(
                                "%s has mix of enumerators with and without values, must be one or the other" % enum_type_node.name, ast_node=enum_type_node)
                        have_enumerator_with_value = True
                        assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(
                            enumerator_node.value)
                        value = enumerator_node.value.value
                    else:
                        if have_enumerator_with_value:
                            raise CompileException(
                                "%s has mix of enumerators with and without values, must be one or the other" % enum_type_node.name, ast_node=enum_type_node)
                        have_enumerator_without_value = True
                        value = enumerator_i

                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotation_nodes=enumerator_node.annotations,
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value
                        )
                    )
            else:
                field_name_variations = set()
                id_count = 0
                for field_node in compound_type_node.fields:
                    field_name = field_node.name
                    if field_name in field_name_variations:
                        raise CompileException("compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name), ast_node=field_node)

                    field_name_lower = field_name.lower()
                    if field_name_lower in field_name_variations:
                        raise CompileException("compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name), ast_node=field_node)

                    field_name_lower_camelized = lower_camelize(field_name)
                    if field_name_lower_camelized in field_name_variations:
                        raise CompileException("compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name), ast_node=field_node)

                    field_name_variations.add(field_name)
                    field_name_variations.add(field_name_lower)
                    field_name_variations.add(field_name_lower_camelized)

                    field = field_node.accept(self)
                    if field.required:
                        if len(compound_type.fields) > 0:
                            if not compound_type.fields[-1].required:
                                raise CompileException("compound type %s has a required field %s after an optional field %s" % (
                                    compound_type_node.name, field.name, compound_type.fields[-1].name), ast_node=compound_type_node)
                    if field.id is not None:
                        id_count += 1
                        for existing_field in compound_type.fields:
                            if existing_field.id == field.id:
                                raise CompileException("compound type %s has duplicate field id %d (%s and %s fields)" % (
                                    compound_type_node.name, field.id, field.name, existing_field.name), ast_node=compound_type_node)
                    compound_type.fields.append(field)
                if len(compound_type.fields) > 0:
                    if id_count != 0 and id_count != len(compound_type_node.fields):
                        raise CompileException("compound type %s has some fields with ids and some fields without" %
                                               compound_type_node.name, ast_node=compound_type_node)

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotation_nodes=const_node.annotations,
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc comment text, or None when there is no doc node."""
            if doc_node is not None:
                return doc_node.text
            else:
                return None

        def visit_document_node(self, document_node):
            """Build the Document construct and all its headers and definitions.

            After visiting, every include seen is flagged ``used`` when one of
            its definitions was resolved by ``__get_type``.
            """
            document = \
                self.__construct(
                    'Document',
                    document_root_dir_path=self.__document_root_dir_path,
                    path=document_node.path
                )
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)

            for include in self.__visited_includes:
                include.used = include.abspath in self.__used_include_abspaths

            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType', exception_type_node)

        def visit_field_node(self, field_node):
            if field_node.value is not None:
                value = field_node.value.accept(self)
            else:
                value = None
            return \
                self.__construct(
                    'Field',
                    annotation_nodes=field_node.annotations,
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function construct with its parameters, return field and throws.

            :raises CompileException: if a required parameter follows an
                optional one
            """
            function = \
                self.__construct(
                    'Function',
                    annotation_nodes=function_node.annotations,
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                parameter = parameter_node.accept(self)
                if parameter.required:
                    if len(function.parameters) > 0:
                        if not function.parameters[-1].required:
                            raise CompileException("function %s has a required parameter %s after an optional parameter %s" % (
                                function.name, parameter.name, function.parameters[-1].name), ast_node=parameter_node)
                function.parameters.append(parameter)
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Resolve, compile and wrap an included Thrift file.

            The configured include directories are searched first, then the
            directory of the innermost enclosing Document.

            :raises CompileException: if the file is not found in any directory
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(
                    include_dir_path, include_file_relpath)
                if os.path.exists(include_file_path):
                    include_file_path = os.path.abspath(include_file_path)
                    included_document = \
                        self.__compiler.compile(
                            generator=self.__generator,
                            thrift_file_path=include_file_path
                        )
                    include = \
                        self.__construct(
                            'Include',
                            abspath=include_file_path,
                            annotation_nodes=include_node.annotations,
                            doc=self.__visit_doc_node(include_node.doc),
                            document=included_document,
                            relpath=include_file_relpath
                        )
                    self.__visited_includes.append(include)
                    return include
            raise CompileException("include path not found: %s" %
                                   include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(element_value.accept(self) for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            return tuple((key_value.accept(self), value_value.accept(self)) for key_value, value_value in map_literal_node.value)

        def visit_map_type_node(self, map_type_node):
            """Return the cached map type with this name, building it on first use."""
            try:
                return self.__get_type(map_type_node.name, resolve=False)
            except KeyError:
                map_type = self.__construct('MapType', key_type=map_type_node.key_type.accept(
                    self), value_type=map_type_node.value_type.accept(self))
                self.__put_type(map_type_node.name, map_type)
                return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotation_nodes=namespace_node.annotations,
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the cached list/set type with this name, building it on first use."""
            try:
                return self.__get_type(sequence_type_node.name, resolve=False)
            except KeyError:
                sequence_type = self.__construct(
                    construct_class_name, element_type=sequence_type_node.element_type.accept(self))
                self.__put_type(sequence_type_node.name, sequence_type)
                return sequence_type

        def visit_service_node(self, service_node):
            """Build a Service construct and its functions.

            :raises CompileException: if two function names are equal ignoring case
            """
            service = \
                self.__construct(
                    'Service',
                    annotation_nodes=service_node.annotations,
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            function_names_lower = set()
            for function_node in service_node.functions:
                function = function_node.accept(self)

                function_name_lower = function.name.lower()
                if function_name_lower in function_names_lower:
                    raise CompileException(
                        "duplicate (case-insensitive) function name '%s'" % function.name, ast_node=function_node)
                function_names_lower.add(function_name_lower)

                service.functions.append(function)
            self.__scope_stack.pop(-1)

            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a reference to a named type.

            An unqualified name that is not found as-is is retried after being
            qualified with the name of the outermost scope (the document).

            :raises CompileException: if the type cannot be resolved
            """
            try:
                try:
                    return self.__get_type(type_node.qname)
                except KeyError:
                    if type_node.qname == type_node.name:
                        document = self.__scope_stack[0]
                        return self.__get_type(document.name + '.' + type_node.qname)
                    else:
                        raise
            except KeyError:
                raise CompileException(
                    "unrecognized type '%s'" % type_node.qname, ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            typedef = \
                self.__construct(
                    'Typedef',
                    annotation_nodes=typedef_node.annotations,
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )

            self.__put_type(typedef.thrift_qname(), typedef)

            return typedef

    def __init__(self, document_root_dir_path, include_dir_paths):
        """
        :param document_root_dir_path: passed to every Document construct built
        :param include_dir_paths: a directory path or a list/tuple of them,
            searched in order when resolving ``include`` statements. The
            bundled ``lib/thrift/src`` directory is appended when absent, and
            duplicates are dropped.
        """
        object.__init__(self)

        self.__document_root_dir_path = document_root_dir_path

        if isinstance(include_dir_paths, (list, tuple)):
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)
        self.__include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in self.__include_dir_paths:
                self.__include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(self.__include_dir_paths)

        # Absolute path -> list of (generator, document). The entry whose
        # generator is None holds the parsed AST.
        self.__document_cache = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, *args, **kwds):
        """Alias for ``compile``."""
        return self.compile(*args, **kwds)

    def compile(self, generator, thrift_file_path):
        """Compile a single Thrift file with ``generator``.

        The file's AST is parsed once and cached under its absolute path. Each
        generator (compared with ``==``) gets its own document, which is also
        cached, so compiling the same file again with an equal generator
        returns the previously built document.

        NOTE(review): a document is only cached after it has been fully
        visited, so an include cycle would recurse here without bound.
        Presumably cycles are unsupported; confirm.

        :param generator: the generator whose construct classes are used
        :param thrift_file_path: path of the Thrift file to compile
        :return: the generator's Document construct for the file
        :raises ValueError: if ``generator`` is None
        """
        if generator is None:
            raise ValueError('generator must not be None')

        thrift_file_path = os.path.abspath(thrift_file_path)

        document_node = None
        path_document_cache = self.__document_cache.get(thrift_file_path)
        if path_document_cache is not None:
            for other_generator, generator_document in path_document_cache:
                if other_generator is None:
                    # Special placeholder for the AST
                    document_node = generator_document
                elif generator == other_generator:
                    return generator_document
            assert document_node is not None
        else:
            tokens = self.__scanner.tokenize(thrift_file_path)
            document_node = self.__parser.parse(
                thrift_file_path=thrift_file_path, tokens=tokens)
            # Can't hash the generator easily so just use its __eq__
            self.__document_cache[thrift_file_path] = path_document_cache = [
                (None, document_node)]

        ast_visitor = \
            self.__AstVisitor(
                compiler=self,
                document_root_dir_path=self.__document_root_dir_path,
                generator=generator,
                include_dir_paths=self.__include_dir_paths
            )
        generator_document = document_node.accept(ast_visitor)
        path_document_cache.append((generator, generator_document))
        return generator_document

    @property
    def document_root_dir_path(self):
        """The document root passed to every Document construct."""
        return self.__document_root_dir_path

    @property
    def include_dir_paths(self):
        """The de-duplicated tuple of include directories."""
        return self.__include_dir_paths
Example #20
0
            ]

        def _java_methods(self):
            """Return one flat list holding every Java method definition.

            Each function's ``java_definitions()`` output is included in turn,
            in the order the functions appear in ``self.functions``.
            """
            return [
                java_definition
                for service_function in self.functions
                for java_definition in service_function.java_definitions()
            ]

        def java_repr(self):
            delegate_name = self.__java_delegate_name()
            name = self.java_name()

            sections = []
            sections.append(self.__java_markers())
            sections.append("""public final static com.google.inject.name.Named DELEGATE_NAME = com.google.inject.name.Names.named("%(delegate_name)s");""" % locals())
            sections.append("\n\n".join([self._java_constructor()] + self._java_methods()))
            sections.append("\n".join(self._java_member_declarations()))
            sections = "\n\n".join(indent(' ' * 4, sections))

            service_qname = java_generator.JavaGenerator.Service.java_qname(self)  # @UndefinedVariable

            return """\
@com.google.inject.Singleton
public class %(name)s implements %(service_qname)s {
%(sections)s
}""" % locals()


# Register the Java logging annotation parsers with the shared Parser class.
# These run at import time as module-level side effects.
# NOTE(review): presumably this must happen before any .thrift file using
# these annotations is parsed; confirm.
Parser.register_annotation_parser(JavaLogExceptionStackTraceAnnotationParser())
Parser.register_annotation_parser(JavaLogLevelAnnotationParser())
Example #21
0
class SqlGenerator(Generator):
    """Generator whose construct classes are the SQL-specific implementations.

    Each nested name (BinaryType, Document, Field, ...) is bound, through a
    class-body import, to the matching ``Sql*`` class from
    ``thryft.generators.sql``. Presumably these are looked up by name on the
    generator when constructs are built; confirm against the compiler.
    """
    from thryft.generators.sql.sql_binary_type import SqlBinaryType as BinaryType  # @UnusedImport
    from thryft.generators.sql.sql_bool_type import SqlBoolType as BoolType  # @UnusedImport
    from thryft.generators.sql.sql_byte_type import SqlByteType as ByteType  # @UnusedImport
    from thryft.generators.sql.sql_const import SqlConst as Const  # @UnusedImport
    from thryft.generators.sql.sql_document import SqlDocument as Document  # @UnusedImport
    from thryft.generators.sql.sql_double_type import SqlDoubleType as DoubleType  # @UnusedImport
    from thryft.generators.sql.sql_enum_type import SqlEnumType as EnumType  # @UnusedImport
    from thryft.generators.sql.sql_exception_type import SqlExceptionType as ExceptionType  # @UnusedImport
    from thryft.generators.sql.sql_field import SqlField as Field  # @UnusedImport
    from thryft.generators.sql.sql_function import SqlFunction as Function  # @UnusedImport
    from thryft.generators.sql.sql_i16_type import SqlI16Type as I16Type  # @UnusedImport
    from thryft.generators.sql.sql_i32_type import SqlI32Type as I32Type  # @UnusedImport
    from thryft.generators.sql.sql_i64_type import SqlI64Type as I64Type  # @UnusedImport
    from thryft.generators.sql.sql_include import SqlInclude as Include  # @UnusedImport
    from thryft.generators.sql.sql_list_type import SqlListType as ListType  # @UnusedImport
    from thryft.generators.sql.sql_map_type import SqlMapType as MapType  # @UnusedImport
    from thryft.generators.sql.sql_service import SqlService as Service  # @UnusedImport
    from thryft.generators.sql.sql_set_type import SqlSetType as SetType  # @UnusedImport
    from thryft.generators.sql.sql_string_type import SqlStringType as StringType  # @UnusedImport
    from thryft.generators.sql.sql_struct_type import SqlStructType as StructType  # @UnusedImport
    from thryft.generators.sql.sql_typedef import SqlTypedef as Typedef  # @UnusedImport


# Register the SQL annotation parsers with the shared Parser class at import
# time: ``sql_column`` (given Ast.StructTypeNode), a dedicated foreign-key
# annotation parser, and the valueless ``sql_unique`` (given Ast.FieldNode).
# NOTE(review): ``sql_column`` is registered for struct nodes rather than
# field nodes; confirm that this is intended.
Parser.register_annotation_parser(
    AnnotationParser('sql_column', Ast.StructTypeNode))
Parser.register_annotation_parser(SqlForeignKeyAnnotationParser())
Parser.register_annotation_parser(
    ValuelessAnnotationParser('sql_unique', Ast.FieldNode))
Example #22
0
                if len(function_names) > 0 and cmp(function.name,
                                                   function_names[-1]) < 0:
                    after_function_name = ''
                    for function_name_i in xrange(
                            len(function_names) - 1, -1, -1):
                        test_function_name = function_names[function_name_i]
                        if cmp(function.name, test_function_name) >= 0:
                            after_function_name = test_function_name
                            break
                    self._logger.warn(
                        "function %s in %s is out of lexicographic order (should be after %s)",
                        function.name,
                        self._parent_document().path, after_function_name)
                function_names.append(function.name)

    class SetType(Generator.SetType, _SequenceType):  # @UndefinedVariable
        """Set type combining the base generator's SetType with _SequenceType."""

    class StringType(Generator.StringType, _Type):  # @UndefinedVariable
        """String type combining the base generator's StringType with _Type."""

    class StructType(Generator.StructType, _CompoundType):  # @UndefinedVariable
        """Struct type combining the base generator's StructType with _CompoundType."""


# Register the valueless ``lint_suppress`` annotation with the shared Parser
# class at import time, for enum, exception, service and struct nodes.
# Presumably it lets those definitions opt out of lint warnings; confirm.
Parser.register_annotation_parser(
    ValuelessAnnotationParser('lint_suppress',
                              (Ast.EnumTypeNode, Ast.ExceptionTypeNode,
                               Ast.ServiceNode, Ast.StructTypeNode)))