def __init__(self, include_dir_paths=None):
    """Create a compiler with a normalized, de-duplicated include search path.

    include_dir_paths may be None, a single path, or a list/tuple of paths.
    When no path is given the current working directory is searched.  The
    lib/thrift/src directory (resolved relative to this file) is always
    added to the end of the search path if it is not already present.
    """
    object.__init__(self)

    # Work on a private list so the caller's sequence is never mutated.
    if isinstance(include_dir_paths, (list, tuple)):
        search_paths = list(include_dir_paths)
    elif include_dir_paths is None:
        search_paths = []
    else:
        search_paths = [include_dir_paths]

    if len(search_paths) == 0:
        search_paths.append(os.getcwd())

    here = os.path.dirname(os.path.realpath(__file__))
    lib_thrift_src_dir_path = os.path.abspath(
        os.path.join(here, '..', '..', '..', '..', 'lib', 'thrift', 'src')
    )
    if lib_thrift_src_dir_path not in search_paths:
        search_paths.append(lib_thrift_src_dir_path)

    # Drop duplicates while keeping first-seen order, then freeze.
    unique_paths = []
    for search_path in search_paths:
        if search_path not in unique_paths:
            unique_paths.append(search_path)
    self.__include_dir_paths = tuple(unique_paths)

    self.__parsed_thrift_files_by_path = {}
    self.__scanner = Scanner()
    self.__parser = Parser()
def __init__(self, document_root_dir_path, include_dir_paths):
    """Create a compiler rooted at document_root_dir_path.

    include_dir_paths may be a single value or a list/tuple of paths; a
    single value is wrapped in a list.  The lib/thrift/src directory
    (resolved relative to this file) is added to the end of the search
    path if it is not already present, and duplicates are removed.
    """
    object.__init__(self)
    self.__document_root_dir_path = document_root_dir_path

    # Work on a private list so the caller's sequence is never mutated.
    if not isinstance(include_dir_paths, (list, tuple)):
        search_paths = [include_dir_paths]
    else:
        search_paths = list(include_dir_paths)

    here = os.path.dirname(os.path.realpath(__file__))
    lib_thrift_src_dir_path = os.path.abspath(
        os.path.join(here, '..', '..', '..', '..', 'lib', 'thrift', 'src')
    )
    if lib_thrift_src_dir_path not in search_paths:
        search_paths.append(lib_thrift_src_dir_path)

    # Drop duplicates while keeping first-seen order, then freeze.
    unique_paths = []
    for search_path in search_paths:
        if search_path not in unique_paths:
            unique_paths.append(search_path)
    self.__include_dir_paths = tuple(unique_paths)

    self.__document_cache = {}
    self.__scanner = Scanner()
    self.__parser = Parser()
def _runTest(self, thrift_file_path):
    """Scan and parse one .thrift file, dumping diagnostics to stderr on failure.

    On any exception the file path, the traceback and every token scanned
    so far are written to stderr, and the exception is re-raised.
    """
    # Debugging aids: enable logging.basicConfig(level=logging.DEBUG), or
    # return early unless os.path.split(thrift_file_path)[1] matches the
    # one file under investigation; pprint the parse result's to_dict().
    tokens = []
    try:
        tokens = Scanner().tokenize(thrift_file_path)
        Parser().parse(tokens)
    except:
        # Deliberately broad: whatever went wrong is re-raised below.
        print >> sys.stderr, 'Error parsing', thrift_file_path
        traceback.print_exc()
        for tok in tokens:
            print >> sys.stderr, tok.index, tok.type, ':', len(tok.text), ':', tok.text
        print >> sys.stderr
        raise
for function in self.functions: methods.extend(function.java_definitions()) return methods def java_repr(self): delegate_name = self.__java_delegate_name() name = self.java_name() sections = [] sections.append(self.__java_markers()) sections.append("""public final static com.google.inject.name.Named DELEGATE_NAME = com.google.inject.name.Names.named("%(delegate_name)s");""" % locals()) sections.append("\n\n".join([self._java_constructor()] + self._java_methods())) sections.append("\n".join(self._java_member_declarations())) sections.append(self.__java_log_messages()) sections.append(self.__java_log_return_values()) sections = "\n\n".join(indent(' ' * 4, sections)) service_qname = java_generator.JavaGenerator.Service.java_qname(self) # @UndefinedVariable return """\ @com.google.inject.Singleton public class %(name)s implements %(service_qname)s { %(sections)s }""" % locals() Parser.register_annotation_parser(JavaLogExceptionStackTraceAnnotationParser()) Parser.register_annotation_parser(JavaLogLevelAnnotationParser('java_log_level', (Ast.ExceptionTypeNode, Ast.FieldNode, Ast.FunctionNode))) Parser.register_annotation_parser(JavaLogLevelAnnotationParser('java_log_level_post', (Ast.FunctionNode,))) Parser.register_annotation_parser(JavaLogLevelAnnotationParser('java_log_level_pre', (Ast.FunctionNode,)))
for function in self.functions: methods.extend(function.java_definitions()) return methods def java_repr(self): delegate_name = self.__java_delegate_name() name = self.java_name() sections = [] sections.append(self.__java_markers()) sections.append( """public final static com.google.inject.name.Named DELEGATE_NAME = com.google.inject.name.Names.named("%(delegate_name)s");""" % locals()) sections.append("\n\n".join([self._java_constructor()] + self._java_methods())) sections.append("\n".join(self._java_member_declarations())) sections = "\n\n".join(indent(' ' * 4, sections)) service_qname = java_generator.JavaGenerator.Service.java_qname( self) # @UndefinedVariable return """\ @com.google.inject.Singleton public class %(name)s implements %(service_qname)s { %(sections)s }""" % locals() Parser.register_annotation_parser(JavaLogExceptionStackTraceAnnotationParser()) Parser.register_annotation_parser(JavaLogLevelAnnotationParser())
# NOTE(review): the import aliases below look like members of an enclosing
# generator class whose header is outside this view -- re-indent them to
# match that class when applying this change.
from thryft.generators.sql.sql_i64_type import SqlI64Type as I64Type  # @UnusedImport
from thryft.generators.sql.sql_include import SqlInclude as Include  # @UnusedImport
from thryft.generators.sql.sql_list_type import SqlListType as ListType  # @UnusedImport
from thryft.generators.sql.sql_map_type import SqlMapType as MapType  # @UnusedImport
from thryft.generators.sql.sql_native_type import SqlNativeType as NativeType  # @UnusedImport
from thryft.generators.sql.sql_service import SqlService as Service  # @UnusedImport
from thryft.generators.sql.sql_set_type import SqlSetType as SetType  # @UnusedImport
from thryft.generators.sql.sql_string_type import SqlStringType as StringType  # @UnusedImport
from thryft.generators.sql.sql_struct_type import SqlStructType as StructType  # @UnusedImport
from thryft.generators.sql.sql_typedef import SqlTypedef as Typedef  # @UnusedImport


def __parse_sql_foreign_key_annotation(ast_node, name, value, **kwds):
    """Parse a 'table.column' foreign key annotation and attach it to ast_node.

    The appended annotation's value is the (table_name, column_name) tuple.

    Raises ValueError when the annotation has no value or the value is not
    exactly two non-empty, dot-separated parts.
    """
    # A valueless annotation arrives as None (see the sql_unique parser);
    # report it as a ValueError rather than failing inside value.split().
    if value is None:
        raise ValueError("@%s must specify table.column" % name)

    value_parts = value.split('.')
    if len(value_parts) != 2:
        raise ValueError("@%s must specify table.column: '%s'" % (name, value))
    table_name, column_name = value_parts
    if not table_name or not column_name:
        raise ValueError("@%s must specify table.column: '%s'" % (name, value))

    annotation = Ast.AnnotationNode(name=name, value=(table_name, column_name), **kwds)
    ast_node.annotations.append(annotation)


Parser.register_annotation(Ast.FieldNode, 'sql_foreign_key', __parse_sql_foreign_key_annotation)


def __parse_sql_unique_annotation(ast_node, name, value, **kwds):
    """Attach a valueless sql_unique annotation to ast_node.

    Raises ValueError if a value was supplied.
    """
    if value is not None:
        raise ValueError("@%(name)s does not take a value" % locals())
    ast_node.annotations.append(Ast.AnnotationNode(name=name, **kwds))


Parser.register_annotation(Ast.FieldNode, 'sql_unique', __parse_sql_unique_annotation)
class MapType(Generator.MapType, _ContainerType):  # @UndefinedVariable
    """Map type construct; adds nothing beyond its bases."""
    pass


class Service(Generator.Service, _NamedConstruct):  # @UndefinedVariable
    """Service construct that lints the ordering of its functions."""

    def lint(self):
        """Warn about out-of-order functions, then lint every function.

        Each function name is passed to check_lexicographic_order together
        with the names seen so far; a non-None result is reported as the
        name the function should come after.
        """
        seen_names = []
        for fn in self.functions:
            predecessor = check_lexicographic_order(seen_names, fn.name)
            if predecessor is not None:
                self._logger.warn(
                    "function %s in %s is out of lexicographic order (should be after %s)",
                    fn.name,
                    self._parent_document().path,
                    predecessor,
                )
            seen_names.append(fn.name)
            fn.lint()


class SetType(Generator.SetType, _SequenceType):  # @UndefinedVariable
    """Set type construct; adds nothing beyond its bases."""
    pass


class StringType(Generator.StringType, _Type):  # @UndefinedVariable
    """String type construct; adds nothing beyond its bases."""
    pass


class StructType(Generator.StructType, _CompoundType):  # @UndefinedVariable
    """Struct type construct; adds nothing beyond its bases."""
    pass


# Valueless lint annotations and the AST node classes they are registered for.
Parser.register_annotation_parser(
    ValuelessAnnotationParser(
        'lint_require_field_ids_recursive',
        (Ast.ExceptionTypeNode, Ast.StructTypeNode),
    )
)
Parser.register_annotation_parser(
    ValuelessAnnotationParser(
        'lint_suppress',
        (
            Ast.EnumTypeNode,
            Ast.ExceptionTypeNode,
            Ast.FunctionNode,
            Ast.ServiceNode,
            Ast.StructTypeNode,
        ),
    )
)
{ '_all': {'enabled': False}, 'dynamic': 'strict', } for annotation in self.annotations: if annotation.name == 'elastic_search_mappings_base': mappings[document_type].update(annotation.value) if 'properties' in mappings[document_type]: updated_properties = OrderedDict() updated_properties.update( mappings[document_type]['properties']) updated_properties.update(properties) properties = updated_properties mappings[document_type]['properties'] = properties return mappings def __init__(self, settings=None, template=None): Generator.__init__(self) if settings is not None and not isinstance(settings, dict): raise TypeError('settings must be a dict') self._settings = settings if template is not None and not isinstance(template, str): raise TypeError('template must be a str') self._template = template Parser.register_annotation_parser(ElasticSearchDocumentTypeAnnotationParser()) Parser.register_annotation_parser(ElasticSearchMappingAnnotationParser()) Parser.register_annotation_parser(ElasticSearchMappingsBaseAnnotationParser())
def __init__(
    self,
    default_methods=False,
    function_overloads=False,
    mutable_compound_types=False,
    namespace_prefix=None,
    **kwds
):
    """Store the generator options; remaining keywords go to Generator.__init__."""
    Generator.__init__(self, **kwds)
    self.__namespace_prefix = namespace_prefix
    self.__mutable_compound_types = mutable_compound_types
    self.__function_overloads = function_overloads
    self.__default_methods = default_methods


@property
def default_methods(self):
    """The default_methods option given to the constructor."""
    return self.__default_methods


@property
def function_overloads(self):
    """The function_overloads option given to the constructor."""
    return self.__function_overloads


@property
def mutable_compound_types(self):
    """The mutable_compound_types option given to the constructor."""
    return self.__mutable_compound_types


@property
def namespace_prefix(self):
    """The namespace_prefix option given to the constructor."""
    return self.__namespace_prefix


# Java-specific annotations and the AST node classes they are registered for.
Parser.register_annotation_parser(JavaFinalAnnotationParser())
Parser.register_annotation_parser(
    AnnotationParser(
        'java_extends',
        (Ast.ExceptionTypeNode, Ast.StructTypeNode),
    )
)
Parser.register_annotation_parser(
    AnnotationParser(
        'java_implements',
        (Ast.EnumTypeNode, Ast.ExceptionTypeNode, Ast.ServiceNode, Ast.StructTypeNode),
    )
)
Parser.register_annotation_parser(
    ValuelessAnnotationParser('java_escape_to_string', Ast.FieldNode)
)
Parser.register_annotation_parser(
    ValuelessAnnotationParser('java_exclude_from_to_string', Ast.FieldNode)
)
mappings[document_type] = \ { '_all': {'enabled': False}, 'dynamic': 'strict', } for annotation in self.annotations: if annotation.name == 'elastic_search_mappings_base': mappings[document_type].update(annotation.value) if 'properties' in mappings[document_type]: updated_properties = OrderedDict() updated_properties.update(mappings[document_type]['properties']) updated_properties.update(properties) properties = updated_properties mappings[document_type]['properties'] = properties return mappings def __init__(self, settings=None, template=None): Generator.__init__(self) if settings is not None and not isinstance(settings, dict): raise TypeError('settings must be a dict') self._settings = settings if template is not None and not isinstance(template, str): raise TypeError('template must be a str') self._template = template Parser.register_annotation_parser(ElasticSearchDocumentTypeAnnotationParser()) Parser.register_annotation_parser(ElasticSearchMappingAnnotationParser()) Parser.register_annotation_parser(ElasticSearchMappingsBaseAnnotationParser())
class Compiler(object):
    """Scan, parse and (optionally) generate constructs for .thrift files.

    Parsed document nodes are cached per absolute file path.  When a
    generator is supplied to compile(), each document node is walked by a
    private AST visitor that instantiates the generator's construct classes.
    """

    class __AstVisitor(object):
        """AST visitor that builds generator constructs from AST nodes."""

        __ROOT_GENERATOR = Generator()

        def __init__(self, compiler, generator, include_dir_paths):
            object.__init__(self)
            self.__compiler = compiler
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Constructs currently being populated, innermost last.
            self.__scope_stack = []
            # Types visited so far, keyed by name / thrift qname.
            self.__type_cache = {}

        def __construct(self, class_name, **kwds):
            """Instantiate the generator's class_name construct with kwds.

            The construct's parent is the innermost open scope, or the
            generator itself when no scope is open.  For a named construct
            carrying a 'native' annotation, a '.py' overrides module next to
            an enclosing document is searched for a subclass of the
            generator's NativeType (or of one of its first bases); if one is
            found it is instantiated instead.
            """
            if len(self.__scope_stack) > 0:
                parent = self.__scope_stack[-1]
            else:
                parent = self.__generator
            kwds['parent'] = parent

            annotations = kwds.get('annotations')
            name = kwds.get('name')
            native = (
                annotations is not None
                and name is not None
                and 'native' in annotations
            )

            if native:
                for scope in reversed(self.__scope_stack):
                    if not isinstance(scope, Document):
                        continue
                    document = scope

                    overrides_module_file_path = \
                        os.path.splitext(document.path)[0] + '.py'
                    if not os.path.isfile(overrides_module_file_path):
                        continue
                    overrides_module_dir_path, overrides_module_file_name = \
                        os.path.split(overrides_module_file_path)
                    overrides_module_name = \
                        os.path.splitext(overrides_module_file_name)[0]

                    try:
                        module_file, module_pathname, module_description = \
                            imp.find_module(
                                overrides_module_name,
                                [overrides_module_dir_path]
                            )
                        try:
                            overrides_module = \
                                imp.load_module(
                                    overrides_module_name,
                                    module_file,
                                    module_pathname,
                                    module_description
                                )
                        finally:
                            # find_module hands back an open file that the
                            # caller must close, even if load_module raises.
                            if module_file is not None:
                                module_file.close()
                    except ImportError:
                        logging.error(
                            "error importing overrides module %s",
                            overrides_module_file_path,
                            exc_info=True
                        )
                        continue

                    # Find an override implementation corresponding to our generator
                    # For example, find JavaDateTime by looking in the module for a
                    # class that inherits JavaNativeType.
                    # The first_bases algorithm below covers the case where we want
                    # JavaDateTime but the generator gave us something that itself
                    # inherits from JavaStructType e.g., SubJavaStructType.
                    # In this case there is no SubJavaDateTime(SubJavaStructType), so
                    # we need to consider SubJavaStructType's parent (JavaStructType)
                    # and look for a subclass of that.
                    root_construct_class = getattr(self.__ROOT_GENERATOR, 'NativeType')
                    generator_construct_class = getattr(self.__generator, 'NativeType')
                    parent_construct_classes = [generator_construct_class]

                    def __get_first_bases(class_, first_bases):
                        # Walk the chain of first base classes, collecting
                        # the public ones, until object or the root
                        # generator's NativeType is reached.
                        if len(class_.__bases__) == 0:
                            return
                        first_base = class_.__bases__[0]
                        if first_base is object or first_base is root_construct_class:
                            return
                        if first_base.__name__[0] != '_':
                            first_bases.append(first_base)
                        __get_first_bases(first_base, first_bases)
                    __get_first_bases(generator_construct_class, parent_construct_classes)

                    for parent_construct_class in parent_construct_classes:
                        for attr in dir(overrides_module):
                            value = getattr(overrides_module, attr)
                            if isclass(value) and \
                               issubclass(value, parent_construct_class) and \
                               value != parent_construct_class:
                                return value(**kwds)

            # No override found (or not native): use the generator's class.
            return getattr(self.__generator, class_name)(**kwds)

        def visit_annotation_node(self, annotation_node):
            return annotation_node.value

        def __visit_annotation_nodes(self, annotation_nodes):
            """Return a {name: visited value} dict; empty for None."""
            if annotation_nodes is None:
                return {}
            return dict(
                (annotation_node.name, annotation_node.accept(self))
                for annotation_node in annotation_nodes
            )

        def visit_base_type_node(self, base_type_node):
            """Return the cached base type, creating '<Name>Type' on a miss."""
            try:
                return self.__type_cache[base_type_node.name]
            except KeyError:
                base_type = \
                    getattr(
                        self.__generator,
                        base_type_node.name.capitalize() + 'Type'
                    )(name=base_type_node.name)
                self.__type_cache[base_type_node.name] = base_type
                return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an enum/exception/struct construct and its members.

            Raises CompileException when an enumerator without a value
            follows an enumerator with one.
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=self.__visit_annotation_nodes(compound_type_node.annotations),
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            if isinstance(compound_type, NativeType):
                # Native overrides are returned as constructed, without
                # visiting the node's members.
                return compound_type

            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type_map here to allow recursive
            # definitions
            self.__type_cache[compound_type.thrift_qname()] = compound_type

            if construct_class_name == 'EnumType':
                have_enumerator_with_value = False
                for enumerator_node in compound_type_node.enumerators:
                    if enumerator_node.value is not None:
                        have_enumerator_with_value = True
                    elif have_enumerator_with_value:
                        raise CompileException("%s has mix of enumerators with and without values, must be one or the other" % compound_type_node.name, ast_node=compound_type_node)

                for enumerator_i, enumerator_node in enumerate(compound_type_node.enumerators):
                    if enumerator_node.value is None:
                        # Enumerators without a value take their position.
                        value = enumerator_i
                    else:
                        assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(enumerator_node.value)
                        value = enumerator_node.value.value
                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotations=self.__visit_annotation_nodes(enumerator_node.annotations),
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value
                        )
                    )
            else:
                for field in compound_type_node.fields:
                    compound_type.fields.append(field.accept(self))

            self.__scope_stack.pop()

            return compound_type

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotations=self.__visit_annotation_nodes(const_node.annotations),
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc node's text, or None when there is no doc node."""
            if doc_node is not None:
                return doc_node.text
            else:
                return None

        def visit_document_node(self, document_node):
            document = self.__construct('Document', path=document_node.path)
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop()
            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType', exception_type_node)

        def visit_field_node(self, field_node):
            return \
                self.__construct(
                    'Field',
                    annotations=self.__visit_annotation_nodes(field_node.annotations),
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=field_node.value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            function = \
                self.__construct(
                    'Function',
                    annotations=self.__visit_annotation_nodes(function_node.annotations),
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                function.parameters.append(parameter_node.accept(self))
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop()
            return function

        def visit_include_node(self, include_node):
            """Compile the included file and register its types in the cache.

            The configured include directories are searched first, then the
            directory of the innermost enclosing document.

            Raises CompileException when the file is found in none of them.
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(include_dir_path, include_file_relpath)
                if os.path.exists(include_file_path):
                    include_file_path = os.path.abspath(include_file_path)
                    included_document = \
                        self.__compiler.compile(
                            (include_file_path,),
                            generator=self.__generator
                        )[0]
                    include = \
                        self.__construct(
                            'Include',
                            annotations=self.__visit_annotation_nodes(include_node.annotations),
                            doc=self.__visit_doc_node(include_node.doc),
                            document=included_document,
                            path=include_file_relpath
                        )
                    for definition in included_document.definitions:
                        if isinstance(definition, _Type):
                            self.__type_cache[definition.thrift_qname()] = definition
                        elif isinstance(definition, Typedef):
                            self.__type_cache[definition.thrift_qname()] = definition  # .type
                    return include
            raise CompileException("include path not found: %s" % include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(element_value.accept(self) for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            return dict(
                (key_value.accept(self), value_value.accept(self))
                for key_value, value_value in map_literal_node.value.iteritems()
            )

        def visit_map_type_node(self, map_type_node):
            """Return the cached map type, constructing it on a miss."""
            try:
                return self.__type_cache[map_type_node.name]
            except KeyError:
                map_type = \
                    self.__construct(
                        'MapType',
                        key_type=map_type_node.key_type.accept(self),
                        value_type=map_type_node.value_type.accept(self)
                    )
                self.__type_cache[map_type_node.name] = map_type
                return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotations=self.__visit_annotation_nodes(namespace_node.annotations),
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the cached list/set type, constructing it on a miss."""
            try:
                return self.__type_cache[sequence_type_node.name]
            except KeyError:
                sequence_type = \
                    self.__construct(
                        construct_class_name,
                        element_type=sequence_type_node.element_type.accept(self)
                    )
                self.__type_cache[sequence_type_node.name] = sequence_type
                return sequence_type

        def visit_service_node(self, service_node):
            service = \
                self.__construct(
                    'Service',
                    annotations=self.__visit_annotation_nodes(service_node.annotations),
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            for function_node in service_node.functions:
                service.functions.append(function_node.accept(self))

            self.__scope_stack.pop()
            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a type reference through the type cache.

            An unqualified name that misses is retried qualified with the
            outermost scope's name.

            Raises CompileException when the type cannot be resolved.
            """
            try:
                try:
                    return self.__type_cache[type_node.qname]
                except KeyError:
                    if type_node.qname == type_node.name:
                        document = self.__scope_stack[0]
                        return self.__type_cache[document.name + '.' + type_node.qname]
                    else:
                        raise
            except KeyError:
                raise CompileException("unrecognized type '%s'" % type_node.qname, ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            typedef = \
                self.__construct(
                    'Typedef',
                    annotations=self.__visit_annotation_nodes(typedef_node.annotations),
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )
            # The typedef's name resolves to the aliased type.
            self.__type_cache[typedef.thrift_qname()] = typedef.type
            return typedef

    def __init__(self, include_dir_paths=None):
        """Create a compiler with a de-duplicated include search path.

        include_dir_paths may be None, a single path, or a list/tuple of
        paths.  When no path is given the current working directory is
        searched.  The lib/thrift/src directory (resolved relative to this
        file) is added to the end if it is not already present.
        """
        object.__init__(self)

        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            # Copy so the caller's sequence is never mutated.
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]

        if len(include_dir_paths) == 0:
            include_dir_paths.append(os.getcwd())

        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        if lib_thrift_src_dir_path not in include_dir_paths:
            include_dir_paths.append(lib_thrift_src_dir_path)

        # Drop duplicates while keeping first-seen order, then freeze.
        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, thrift_file_paths, generator=None):
        """Shorthand for compile()."""
        return self.compile(thrift_file_paths, generator=generator)

    def compile(self, thrift_file_paths, generator=None):
        """Compile one or more .thrift files.

        thrift_file_paths may be a single path or a list/tuple of paths.
        Returns a tuple with one entry per path: the parsed document node
        when generator is None, otherwise the result of walking that node
        with a visitor bound to the generator.
        """
        if not isinstance(thrift_file_paths, (list, tuple)):
            thrift_file_paths = (thrift_file_paths,)

        documents = []
        for thrift_file_path in thrift_file_paths:
            thrift_file_path = os.path.abspath(thrift_file_path)

            # Each file is scanned and parsed at most once per compiler.
            document_node = self.__parsed_thrift_files_by_path.get(thrift_file_path)
            if document_node is None:
                tokens = self.__scanner.tokenize(thrift_file_path)
                document_node = self.__parser.parse(tokens)
                self.__parsed_thrift_files_by_path[thrift_file_path] = document_node

            if generator is not None:
                ast_visitor = \
                    self.__AstVisitor(
                        compiler=self,
                        generator=generator,
                        include_dir_paths=self.__include_dir_paths
                    )
                document = document_node.accept(ast_visitor)
                documents.append(document)
            else:
                documents.append(document_node)

        return tuple(documents)

    @property
    def include_dir_paths(self):
        """The tuple of include directories computed by the constructor."""
        return self.__include_dir_paths
if len(function_names) > 0 and cmp(function.name, function_names[-1]) < 0: after_function_name = "" for function_name_i in xrange(len(function_names) - 1, -1, -1): test_function_name = function_names[function_name_i] if cmp(function.name, test_function_name) >= 0: after_function_name = test_function_name break self._logger.warn( "function %s in %s is out of lexicographic order (should be after %s)", function.name, self._parent_document().path, after_function_name, ) function_names.append(function.name) class SetType(Generator.SetType, _SequenceType): # @UndefinedVariable pass class StringType(Generator.StringType, _Type): # @UndefinedVariable pass class StructType(Generator.StructType, _CompoundType): # @UndefinedVariable pass Parser.register_annotation_parser( ValuelessAnnotationParser( "lint_suppress", (Ast.EnumTypeNode, Ast.ExceptionTypeNode, Ast.ServiceNode, Ast.StructTypeNode) ) )
from thryft.compiler.parser import Parser
from thryft.generator.generator import Generator
from thryft.generators.js.js_view_metadata_annotation_parser import JsViewMetadataAnnotationParser


class JsGenerator(Generator):
    """Generator whose construct classes are the Js* implementations.

    Each class-level import binds a Js* class to the generic construct
    name (BinaryType, BoolType, ...), making it an attribute of this class.
    """

    from thryft.generators.js.js_binary_type import JsBinaryType as BinaryType  # @UnusedImport
    from thryft.generators.js.js_bool_type import JsBoolType as BoolType  # @UnusedImport
    from thryft.generators.js.js_byte_type import JsByteType as ByteType  # @UnusedImport
    from thryft.generators.js.js_const import JsConst as Const  # @UnusedImport
    from thryft.generators.js.js_document import JsDocument as Document  # @UnusedImport
    from thryft.generators.js.js_double_type import JsDoubleType as DoubleType  # @UnusedImport
    from thryft.generators.js.js_enum_type import JsEnumType as EnumType  # @UnusedImport
    from thryft.generators.js.js_exception_type import JsExceptionType as ExceptionType  # @UnusedImport
    from thryft.generators.js.js_field import JsField as Field  # @UnusedImport
    from thryft.generators.js.js_function import JsFunction as Function  # @UnusedImport
    from thryft.generators.js.js_i16_type import JsI16Type as I16Type  # @UnusedImport
    from thryft.generators.js.js_i32_type import JsI32Type as I32Type  # @UnusedImport
    from thryft.generators.js.js_i64_type import JsI64Type as I64Type  # @UnusedImport
    from thryft.generators.js.js_include import JsInclude as Include  # @UnusedImport
    from thryft.generators.js.js_list_type import JsListType as ListType  # @UnusedImport
    from thryft.generators.js.js_map_type import JsMapType as MapType  # @UnusedImport
    from thryft.generators.js.js_service import JsService as Service  # @UnusedImport
    from thryft.generators.js.js_set_type import JsSetType as SetType  # @UnusedImport
    from thryft.generators.js.js_string_type import JsStringType as StringType  # @UnusedImport
    from thryft.generators.js.js_struct_type import JsStructType as StructType  # @UnusedImport
    from thryft.generators.js.js_typedef import JsTypedef as Typedef  # @UnusedImport


# Registered as a side effect of importing this module.
Parser.register_annotation_parser(JsViewMetadataAnnotationParser())
# NOTE(review): the import aliases below look like members of an enclosing
# generator class whose header is outside this view -- re-indent them to
# match that class when applying this change.
from thryft.generators.js.js_i64_type import JsI64Type as I64Type  # @UnusedImport
from thryft.generators.js.js_include import JsInclude as Include  # @UnusedImport
from thryft.generators.js.js_list_type import JsListType as ListType  # @UnusedImport
from thryft.generators.js.js_map_type import JsMapType as MapType  # @UnusedImport
from thryft.generators.js.js_native_type import JsNativeType as NativeType  # @UnusedImport
from thryft.generators.js.js_service import JsService as Service  # @UnusedImport
from thryft.generators.js.js_set_type import JsSetType as SetType  # @UnusedImport
from thryft.generators.js.js_string_type import JsStringType as StringType  # @UnusedImport
from thryft.generators.js.js_struct_type import JsStructType as StructType  # @UnusedImport


def __parse_js_view_metadata_annotation(ast_node, name, value, **kwds):
    """Parse a JSON-object annotation value and attach it to a field node.

    The appended annotation's value is the decoded dict.  Keys other than
    'displayFormat' and 'editControl' are accepted but logged as warnings.

    Raises ValueError when the value is missing, is not valid JSON, or does
    not decode to a JSON object.
    """
    assert isinstance(ast_node, Ast.FieldNode)

    try:
        value = json.loads(value)
    except (TypeError, ValueError) as e:
        # TypeError covers a missing (None) value; ValueError covers
        # malformed JSON.  Both are reported uniformly as ValueError.
        raise ValueError("@%s contains invalid JSON: '%s', exception: %s" % (name, value, e))
    if not isinstance(value, dict):
        raise ValueError("expected @%s to contain a JSON object, found '%s'" % (name, value))

    for subname in value:
        if subname not in ('displayFormat', 'editControl'):
            # The previous "%(name)s ... %s" % locals() mixed keyed and
            # positional specifiers, which raises TypeError when formatted.
            logging.warning("unknown %s property '%s'", name, subname)

    annotation = Ast.AnnotationNode(name=name, value=value, **kwds)
    ast_node.annotations.append(annotation)


Parser.register_annotation(Ast.FieldNode, 'js_view_metadata', __parse_js_view_metadata_annotation)
def __parse_sql_foreign_key_annotation(ast_node, name, value, **kwds):
    """Parse a 'table.column' foreign key annotation and attach it to ast_node.

    The appended annotation's value is the (table_name, column_name) tuple.

    Raises ValueError when the annotation has no value or the value is not
    exactly two non-empty, dot-separated parts.
    """
    # A valueless annotation arrives as None (see the sql_unique parser);
    # report it as a ValueError rather than failing inside value.split().
    if value is None:
        raise ValueError("@%s must specify table.column" % name)

    value_parts = value.split('.')
    if len(value_parts) != 2:
        raise ValueError("@%s must specify table.column: '%s'" % (name, value))
    table_name, column_name = value_parts
    if not table_name or not column_name:
        raise ValueError("@%s must specify table.column: '%s'" % (name, value))

    annotation = Ast.AnnotationNode(name=name, value=(table_name, column_name), **kwds)
    ast_node.annotations.append(annotation)


Parser.register_annotation(Ast.FieldNode, 'sql_foreign_key', __parse_sql_foreign_key_annotation)


def __parse_sql_unique_annotation(ast_node, name, value, **kwds):
    """Attach a valueless sql_unique annotation to ast_node.

    Raises ValueError if a value was supplied.
    """
    if value is not None:
        raise ValueError("@%(name)s does not take a value" % locals())
    ast_node.annotations.append(Ast.AnnotationNode(name=name, **kwds))


Parser.register_annotation(Ast.FieldNode, 'sql_unique', __parse_sql_unique_annotation)
from thryft.generators.sql.sql_foreign_key_annotation_parser import SqlForeignKeyAnnotationParser


class SqlGenerator(Generator):
    """Generator whose construct classes are the Sql* implementations.

    Each class-level import binds a Sql* class to the generic construct
    name (BinaryType, BoolType, ...), making it an attribute of this class.
    """

    from thryft.generators.sql.sql_binary_type import SqlBinaryType as BinaryType  # @UnusedImport
    from thryft.generators.sql.sql_bool_type import SqlBoolType as BoolType  # @UnusedImport
    from thryft.generators.sql.sql_byte_type import SqlByteType as ByteType  # @UnusedImport
    from thryft.generators.sql.sql_const import SqlConst as Const  # @UnusedImport
    from thryft.generators.sql.sql_document import SqlDocument as Document  # @UnusedImport
    from thryft.generators.sql.sql_double_type import SqlDoubleType as DoubleType  # @UnusedImport
    from thryft.generators.sql.sql_enum_type import SqlEnumType as EnumType  # @UnusedImport
    from thryft.generators.sql.sql_exception_type import SqlExceptionType as ExceptionType  # @UnusedImport
    from thryft.generators.sql.sql_field import SqlField as Field  # @UnusedImport
    from thryft.generators.sql.sql_function import SqlFunction as Function  # @UnusedImport
    from thryft.generators.sql.sql_i16_type import SqlI16Type as I16Type  # @UnusedImport
    from thryft.generators.sql.sql_i32_type import SqlI32Type as I32Type  # @UnusedImport
    from thryft.generators.sql.sql_i64_type import SqlI64Type as I64Type  # @UnusedImport
    from thryft.generators.sql.sql_include import SqlInclude as Include  # @UnusedImport
    from thryft.generators.sql.sql_list_type import SqlListType as ListType  # @UnusedImport
    from thryft.generators.sql.sql_map_type import SqlMapType as MapType  # @UnusedImport
    from thryft.generators.sql.sql_service import SqlService as Service  # @UnusedImport
    from thryft.generators.sql.sql_set_type import SqlSetType as SetType  # @UnusedImport
    from thryft.generators.sql.sql_string_type import SqlStringType as StringType  # @UnusedImport
    from thryft.generators.sql.sql_struct_type import SqlStructType as StructType  # @UnusedImport
    from thryft.generators.sql.sql_typedef import SqlTypedef as Typedef  # @UnusedImport


# SQL annotations, registered as a side effect of importing this module.
Parser.register_annotation_parser(AnnotationParser('sql_column', Ast.StructTypeNode))
Parser.register_annotation_parser(SqlForeignKeyAnnotationParser())
Parser.register_annotation_parser(ValuelessAnnotationParser('sql_unique', Ast.FieldNode))
class Compiler(object):
    """Compile Thrift files into ASTs or, given a generator, into its constructs.

    Parsed ASTs are cached by absolute file path, so a file that is compiled
    (or included) several times is only scanned and parsed once.
    """

    class __AstVisitor(object):
        """Visitor that converts a parsed AST into the constructs of one generator.

        Constructs are instantiated from the classes the generator exposes as
        attributes (Document, StructType, Field, ...). Types are cached by
        name so that later references resolve to the same object.
        """

        # A plain Generator, used only to recognise the root NativeType class
        # at which the search for override base classes stops.
        __ROOT_GENERATOR = Generator()

        def __init__(self, compiler, generator, include_dir_paths):
            object.__init__(self)
            self.__compiler = compiler
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Constructs currently being built, innermost last.
            self.__scope_stack = []
            # Type name (or Thrift qualified name) -> type construct.
            self.__type_cache = {}

        def __construct(self, class_name, **kwds):
            """Instantiate the generator's class_name construct with kwds.

            The innermost open scope (or the generator itself, at the top
            level) is passed as parent. When the construct is named and
            carries a 'native' annotation, a matching class from an overrides
            module next to the enclosing document is instantiated instead, if
            one can be found.
            """
            if len(self.__scope_stack) > 0:
                parent = self.__scope_stack[-1]
            else:
                parent = self.__generator
            kwds['parent'] = parent

            annotations = kwds.get('annotations')
            native = \
                annotations is not None and \
                kwds.get('name') is not None and \
                'native' in annotations
            if native:
                native_construct_class = self.__find_native_construct_class()
                if native_construct_class is not None:
                    return native_construct_class(**kwds)

            return getattr(self.__generator, class_name)(**kwds)

        def __find_native_construct_class(self):
            """Return an override class for a native construct, or None.

            For every enclosing Document (innermost first) the module
            <document path without extension>.py is loaded, if it exists, and
            searched for a proper subclass of one of the candidate NativeType
            base classes.
            """
            for scope in reversed(self.__scope_stack):
                if not isinstance(scope, Document):
                    continue
                overrides_module_file_path = \
                    os.path.splitext(scope.path)[0] + '.py'
                if not os.path.isfile(overrides_module_file_path):
                    continue
                overrides_module = \
                    self.__load_overrides_module(overrides_module_file_path)
                if overrides_module is None:
                    continue

                for parent_construct_class in self.__get_native_parent_construct_classes():
                    for attr in dir(overrides_module):
                        value = getattr(overrides_module, attr)
                        if isclass(value) and \
                           issubclass(value, parent_construct_class) and \
                           value != parent_construct_class:
                            return value
                # No override class found in this module; try the next
                # enclosing document, if any.
            return None

        def __get_native_parent_construct_classes(self):
            """Return the base classes an override class may inherit from.

            Find an override implementation corresponding to our generator.
            For example, find JavaDateTime by looking in the module for a
            class that inherits JavaNativeType. Walking the chain of first
            bases covers the case where the generator's NativeType itself
            inherits from another generator's NativeType (e.g. a
            SubJavaNativeType): there may be no override for the subclass, so
            its parents are candidates too. The walk stops at object or at the
            root Generator's NativeType; classes whose name starts with '_'
            are skipped but the walk continues through them.
            """
            root_construct_class = getattr(self.__ROOT_GENERATOR, 'NativeType')
            generator_construct_class = getattr(self.__generator, 'NativeType')
            parent_construct_classes = [generator_construct_class]
            base_class = generator_construct_class
            while len(base_class.__bases__) > 0:
                base_class = base_class.__bases__[0]
                if base_class is object or base_class is root_construct_class:
                    break
                if not base_class.__name__.startswith('_'):
                    parent_construct_classes.append(base_class)
            return parent_construct_classes

        @staticmethod
        def __load_overrides_module(overrides_module_file_path):
            """Import the overrides module at the given path.

            Returns the module, or None (after logging the error) when the
            import fails with ImportError.
            """
            overrides_module_dir_path, overrides_module_file_name = \
                os.path.split(overrides_module_file_path)
            overrides_module_name = \
                os.path.splitext(overrides_module_file_name)[0]
            try:
                module_file, module_path, module_description = \
                    imp.find_module(
                        overrides_module_name,
                        [overrides_module_dir_path]
                    )
                try:
                    return \
                        imp.load_module(
                            overrides_module_name,
                            module_file,
                            module_path,
                            module_description
                        )
                finally:
                    # find_module hands back an open file that the caller
                    # must close, even when load_module raises.
                    if module_file is not None:
                        module_file.close()
            except ImportError:
                logging.error("error importing overrides module %s",
                              overrides_module_file_path, exc_info=True)
                return None

        def visit_annotation_node(self, annotation_node):
            return annotation_node.value

        def __visit_annotation_nodes(self, annotation_nodes):
            """Return a dict of annotation name -> value ({} for None)."""
            if annotation_nodes is None:
                return {}
            return dict((annotation_node.name, annotation_node.accept(self))
                        for annotation_node in annotation_nodes)

        def visit_base_type_node(self, base_type_node):
            """Return the cached base type, creating it on first use."""
            try:
                return self.__type_cache[base_type_node.name]
            except KeyError:
                # e.g. 'i32' -> generator.I32Type
                base_type = getattr(
                    self.__generator,
                    base_type_node.name.capitalize() + 'Type'
                )(name=base_type_node.name)
                self.__type_cache[base_type_node.name] = base_type
                return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an enum, exception or struct construct and its members.

            Raises CompileException when an enumerator without a value follows
            an enumerator with a value.
            """
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=self.__visit_annotation_nodes(compound_type_node.annotations),
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            if isinstance(compound_type, NativeType):
                # Native overrides are used as-is; their members are not
                # visited and they are not put in the type cache here.
                return compound_type

            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type_map here to allow
            # recursive definitions
            self.__type_cache[compound_type.thrift_qname()] = compound_type

            if construct_class_name == 'EnumType':
                # NOTE(review): only "value, then no value" is rejected; an
                # enumerator without a value followed by one with a value
                # passes, despite the wording of the message.
                have_enumerator_with_value = False
                for enumerator_node in compound_type_node.enumerators:
                    if enumerator_node.value is not None:
                        have_enumerator_with_value = True
                    elif have_enumerator_with_value:
                        raise CompileException(
                            "%s has mix of enumerators with and without values, must be one or the other" % compound_type_node.name,
                            ast_node=compound_type_node)

                for enumerator_i, enumerator_node in enumerate(compound_type_node.enumerators):
                    if enumerator_node.value is None:
                        # Enumerators without an explicit value get their
                        # position.
                        value = enumerator_i
                    else:
                        assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(enumerator_node.value)
                        value = enumerator_node.value.value
                    compound_type.enumerators.append(
                        self.__construct(
                            'Field',
                            annotations=self.__visit_annotation_nodes(enumerator_node.annotations),
                            doc=self.__visit_doc_node(enumerator_node.doc),
                            id=enumerator_i,
                            name=enumerator_node.name,
                            type=Ast.BaseTypeNode('i32').accept(self),
                            value=value
                        )
                    )
            else:
                for field in compound_type_node.fields:
                    compound_type.fields.append(field.accept(self))

            self.__scope_stack.pop(-1)

            return compound_type

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotations=self.__visit_annotation_nodes(const_node.annotations),
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc node's text, or None when there is no doc node."""
            if doc_node is not None:
                return doc_node.text
            else:
                return None

        def visit_document_node(self, document_node):
            """Build a Document and visit its headers, then its definitions."""
            document = self.__construct('Document', path=document_node.path)
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)
            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType', exception_type_node)

        def visit_field_node(self, field_node):
            # NOTE(review): unlike visit_const_node, the value is passed on
            # without accept(self) -- confirm that this is intended.
            return \
                self.__construct(
                    'Field',
                    annotations=self.__visit_annotation_nodes(field_node.annotations),
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=field_node.value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function with its parameters, return field and throws."""
            function = \
                self.__construct(
                    'Function',
                    annotations=self.__visit_annotation_nodes(function_node.annotations),
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                function.parameters.append(parameter_node.accept(self))
            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)
            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Compile the included file and make its types resolvable.

            The file is searched for in the include directories and then in
            the directory of the enclosing document; the first existing path
            wins. Raises CompileException when the file is not found.
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(include_dir_path, include_file_relpath)
                if not os.path.exists(include_file_path):
                    continue

                include_file_path = os.path.abspath(include_file_path)
                included_document = \
                    self.__compiler.compile(
                        (include_file_path,),
                        generator=self.__generator
                    )[0]
                include = \
                    self.__construct(
                        'Include',
                        annotations=self.__visit_annotation_nodes(include_node.annotations),
                        doc=self.__visit_doc_node(include_node.doc),
                        document=included_document,
                        path=include_file_relpath
                    )
                for definition in included_document.definitions:
                    # Typedefs are cached as the typedef itself here, not as
                    # their underlying type.
                    if isinstance(definition, (_Type, Typedef)):
                        self.__type_cache[definition.thrift_qname()] = definition
                return include

            raise CompileException("include path not found: %s" % include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(element_value.accept(self)
                         for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            return dict((key_value.accept(self), value_value.accept(self))
                        for key_value, value_value in map_literal_node.value.iteritems())

        def visit_map_type_node(self, map_type_node):
            """Return the cached map type, creating it on first use."""
            try:
                return self.__type_cache[map_type_node.name]
            except KeyError:
                map_type = \
                    self.__construct(
                        'MapType',
                        key_type=map_type_node.key_type.accept(self),
                        value_type=map_type_node.value_type.accept(self)
                    )
                self.__type_cache[map_type_node.name] = map_type
                return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotations=self.__visit_annotation_nodes(namespace_node.annotations),
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the cached list/set type, creating it on first use."""
            try:
                return self.__type_cache[sequence_type_node.name]
            except KeyError:
                sequence_type = \
                    self.__construct(
                        construct_class_name,
                        element_type=sequence_type_node.element_type.accept(self)
                    )
                self.__type_cache[sequence_type_node.name] = sequence_type
                return sequence_type

        def visit_service_node(self, service_node):
            """Build a Service and visit its functions."""
            service = \
                self.__construct(
                    'Service',
                    annotations=self.__visit_annotation_nodes(service_node.annotations),
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            for function_node in service_node.functions:
                service.functions.append(function_node.accept(self))

            self.__scope_stack.pop(-1)
            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a reference to a previously defined or included type.

            An unqualified name that is not cached as-is is retried with the
            outermost scope's name as qualifier. Raises CompileException when
            the type cannot be resolved.
            """
            try:
                try:
                    return self.__type_cache[type_node.qname]
                except KeyError:
                    if type_node.qname == type_node.name:
                        document = self.__scope_stack[0]
                        return self.__type_cache[document.name + '.' + type_node.qname]
                    else:
                        raise
            except KeyError:
                raise CompileException("unrecognized type '%s'" % type_node.qname, ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            """Build a Typedef and cache its underlying type under its qname."""
            typedef = \
                self.__construct(
                    'Typedef',
                    annotations=self.__visit_annotation_nodes(typedef_node.annotations),
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )
            self.__type_cache[typedef.thrift_qname()] = typedef.type
            return typedef

    def __init__(self, include_dir_paths=None):
        """Create a compiler.

        include_dir_paths may be None, a single path, or a list/tuple of
        paths. When none are given the current working directory is used. The
        lib/thrift/src directory located four levels above this file's
        directory is always appended, and duplicates are dropped while keeping
        the order.
        """
        object.__init__(self)

        if include_dir_paths is None:
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            # Copy, so the caller's sequence is never modified.
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]
        if len(include_dir_paths) == 0:
            include_dir_paths.append(os.getcwd())

        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        include_dir_paths.append(lib_thrift_src_dir_path)

        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        # Absolute .thrift file path -> parsed document node.
        self.__parsed_thrift_files_by_path = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, thrift_file_paths, generator=None):
        """Shorthand for compile()."""
        return self.compile(thrift_file_paths, generator=generator)

    def compile(self, thrift_file_paths, generator=None):
        """Compile one or more Thrift files.

        thrift_file_paths is a single path or a list/tuple of paths. Returns a
        tuple with one entry per path, in order: the generator's document when
        a generator is given, otherwise the parsed document node.
        """
        if not isinstance(thrift_file_paths, (list, tuple)):
            thrift_file_paths = (thrift_file_paths,)

        documents = []
        for thrift_file_path in thrift_file_paths:
            thrift_file_path = os.path.abspath(thrift_file_path)

            document_node = self.__parsed_thrift_files_by_path.get(thrift_file_path)
            if document_node is None:
                tokens = self.__scanner.tokenize(thrift_file_path)
                document_node = self.__parser.parse(tokens)
                self.__parsed_thrift_files_by_path[thrift_file_path] = document_node

            if generator is None:
                documents.append(document_node)
                continue

            # A fresh visitor per file: scope stack and type cache are not
            # shared between files.
            ast_visitor = \
                self.__AstVisitor(
                    compiler=self,
                    generator=generator,
                    include_dir_paths=self.__include_dir_paths
                )
            documents.append(document_node.accept(ast_visitor))

        return tuple(documents)

    @property
    def include_dir_paths(self):
        """Tuple of include directories, without duplicates."""
        return self.__include_dir_paths
class Compiler(object):
    """Compile Thrift files into the documents of a generator.

    Results are cached per absolute file path and generator, so each file is
    parsed once and visited once per generator.
    """

    class __AstVisitor(object):
        """Visitor that converts a parsed AST into the constructs of one generator.

        Besides building the constructs it rejects duplicate names and ids,
        resolves type references (also into included documents) and records
        which includes were actually used.
        """

        def __init__(self, compiler, document_root_dir_path, generator, include_dir_paths):
            object.__init__(self)
            self.__compiler = compiler
            self.__document_root_dir_path = document_root_dir_path
            self.__generator = generator
            self.__include_dir_paths = include_dir_paths
            # Constructs currently being built, innermost last.
            self.__scope_stack = []
            self.__type_by_thrift_qname_cache = {}
            # Absolute paths of includes from which a type was resolved.
            self.__used_include_abspaths = {}
            self.__visited_includes = []

        def __construct(self, class_name, annotation_nodes=None, **kwds):
            """Instantiate the generator's class_name construct with kwds.

            The innermost open scope (or the generator itself, at the top
            level) is passed as parent. Annotation nodes are converted to
            instances of <UpperCamelName>Annotation, looked up on the
            construct class, falling back to its generic Annotation class.

            A 'native' annotation is not converted. Instead the construct is
            wrapped by the class named like the construct in the module
            <document path without extension>.py. When that module does not
            exist the plain construct is returned.
            """
            if len(self.__scope_stack) > 0:
                parent = self.__scope_stack[-1]
            else:
                parent = self.__generator
            kwds['parent'] = parent

            construct_class = getattr(self.__generator, class_name)

            native = False
            if annotation_nodes:
                annotations = []
                for annotation_node in annotation_nodes:
                    if annotation_node.name == 'native':
                        native = True
                        continue

                    try:
                        annotation_class_name = \
                            upper_camelize(annotation_node.name) + 'Annotation'
                        annotation_class = \
                            getattr(construct_class, annotation_class_name)
                    except AttributeError:
                        annotation_class = getattr(construct_class, 'Annotation')

                    annotations.append(
                        annotation_class(
                            name=annotation_node.name,
                            value=annotation_node.value
                        )
                    )
                kwds['annotations'] = tuple(annotations)

            construct = construct_class(**kwds)
            if not native:
                return construct

            parent_document = construct._parent_document()
            native_module_file_path = \
                os.path.splitext(parent_document.path)[0] + '.py'
            if not os.path.isfile(native_module_file_path):
                # Without an overrides module fall back to the ordinary
                # construct; returning None would crash every caller.
                return construct

            native_module_dir_path, native_module_file_name = \
                os.path.split(native_module_file_path)
            native_module_name = \
                os.path.splitext(native_module_file_name)[0]
            module_file, module_path, module_description = \
                imp.find_module(native_module_name, [native_module_dir_path])
            try:
                native_module = \
                    imp.load_module(
                        native_module_name,
                        module_file,
                        module_path,
                        module_description
                    )
            finally:
                # find_module hands back an open file that the caller must
                # close, even when load_module raises.
                if module_file is not None:
                    module_file.close()

            native_construct_class = getattr(native_module, construct.name)
            return native_construct_class(overridden_construct=construct, **kwds)

        def __get_type(self, type_thrift_qname, resolve=True):
            """Return the type with the given Thrift qualified name.

            The name is e.g. struct_include_file.Struct. With resolve=True a
            cache miss is followed by a search through the definitions of the
            visited includes; a hit is cached and marks the include as used.
            Raises KeyError when the type is not found.
            """
            try:
                return self.__type_by_thrift_qname_cache[type_thrift_qname]
            except KeyError:
                if not resolve:
                    raise

                for include in self.__visited_includes:
                    for definition in include.document.definitions:
                        if not isinstance(definition, (_Type, Typedef)):
                            continue
                        definition_thrift_qname = definition.thrift_qname()
                        if definition_thrift_qname == type_thrift_qname:
                            self.__type_by_thrift_qname_cache[definition_thrift_qname] = definition
                            self.__used_include_abspaths[include.abspath] = True
                            return definition

                raise KeyError(type_thrift_qname)

        def __put_type(self, type_thrift_qname, type_):
            """Cache a type under its qualified name.

            A Typedef is stored as its underlying type. Raises
            CompileException when the name is already cached.
            """
            if type_thrift_qname in self.__type_by_thrift_qname_cache:
                raise CompileException("duplicate type %s" % type_thrift_qname)
            if isinstance(type_, Typedef):
                type_ = type_.type
            self.__type_by_thrift_qname_cache[type_thrift_qname] = type_

        def visit_base_type_node(self, base_type_node):
            """Return the cached base type, creating it on first use."""
            try:
                return self.__get_type(base_type_node.name, resolve=False)
            except KeyError:
                # e.g. 'i32' -> generator.I32Type
                base_type = getattr(
                    self.__generator,
                    base_type_node.name.capitalize() + 'Type'
                )(name=base_type_node.name)
                self.__put_type(base_type_node.name, base_type)
                return base_type

        def visit_bool_literal_node(self, bool_literal_node):
            return bool_literal_node.value

        def __visit_compound_type_node(self, construct_class_name, compound_type_node):
            """Build an enum, exception or struct construct and its members.

            Raises CompileException for duplicate enumerator names, an
            enumerator without a value after one with a value, field names
            that collide (exactly, lower-cased or lower-camelized), a required
            field after an optional one, duplicate field ids, and a mix of
            fields with and without ids.
            """
            # NOTE(review): every other caller passes annotation_nodes=...;
            # with annotations=... the raw nodes bypass __construct's
            # conversion and its 'native' handling -- confirm this is intended.
            compound_type = \
                self.__construct(
                    construct_class_name,
                    annotations=compound_type_node.annotations,
                    doc=self.__visit_doc_node(compound_type_node.doc),
                    name=compound_type_node.name
                )
            self.__scope_stack.append(compound_type)

            # Insert the compound type into the type_map here to allow
            # recursive definitions
            self.__put_type(compound_type.thrift_qname(), compound_type)

            if construct_class_name == 'EnumType':
                self.__visit_enumerator_nodes(compound_type, compound_type_node)
            else:
                self.__visit_field_nodes(compound_type, compound_type_node)

            self.__scope_stack.pop(-1)

            return compound_type

        def __visit_enumerator_nodes(self, enum_type, enum_type_node):
            """Validate the enumerators of an enum and append them as Fields."""
            have_enumerator_with_value = False
            enumerator_node_names = set()
            for enumerator_i, enumerator_node in enumerate(enum_type_node.enumerators):
                if enumerator_node.name in enumerator_node_names:
                    raise CompileException(
                        "%s has a duplicate enumerator name, %s" % (
                            enum_type_node.name, enumerator_node.name),
                        ast_node=enumerator_node)
                enumerator_node_names.add(enumerator_node.name)

                if enumerator_node.value is not None:
                    have_enumerator_with_value = True
                    assert isinstance(enumerator_node.value, Ast.IntLiteralNode), type(enumerator_node.value)
                    value = enumerator_node.value.value
                else:
                    if have_enumerator_with_value:
                        raise CompileException(
                            "%s has mix of enumerators with and without values, must be one or the other" % enum_type_node.name,
                            ast_node=enum_type_node)
                    # Enumerators without an explicit value get their
                    # position.
                    value = enumerator_i

                enum_type.enumerators.append(
                    self.__construct(
                        'Field',
                        annotation_nodes=enumerator_node.annotations,
                        doc=self.__visit_doc_node(enumerator_node.doc),
                        id=enumerator_i,
                        name=enumerator_node.name,
                        type=Ast.BaseTypeNode('i32').accept(self),
                        value=value
                    )
                )

        def __visit_field_nodes(self, compound_type, compound_type_node):
            """Validate the fields of a struct/exception and append them."""
            field_name_variations = set()
            id_count = 0
            for field_node in compound_type_node.fields:
                field_name = field_node.name
                # Names that only differ in case or in snake/camel style clash
                # in some target languages, so all variations must be unique.
                variations = (
                    field_name,
                    field_name.lower(),
                    lower_camelize(field_name)
                )
                if any(variation in field_name_variations for variation in variations):
                    raise CompileException(
                        "compound type %s has a duplicate field %s" % (
                            compound_type_node.name, field_name),
                        ast_node=field_node)
                field_name_variations.update(variations)

                field = field_node.accept(self)

                if field.required and \
                   len(compound_type.fields) > 0 and \
                   not compound_type.fields[-1].required:
                    raise CompileException(
                        "compound type %s has a required field %s after an optional field %s" % (
                            compound_type_node.name, field.name, compound_type.fields[-1].name),
                        ast_node=compound_type_node)

                if field.id is not None:
                    id_count += 1
                    for existing_field in compound_type.fields:
                        if existing_field.id == field.id:
                            raise CompileException(
                                "compound type %s has duplicate field id %d (%s and %s fields)" % (
                                    compound_type_node.name, field.id, field.name, existing_field.name),
                                ast_node=compound_type_node)

                compound_type.fields.append(field)

            # Either all fields have ids or none has.
            if len(compound_type.fields) > 0 and \
               id_count != 0 and \
               id_count != len(compound_type_node.fields):
                raise CompileException(
                    "compound type %s has some fields with ids and some fields without" % compound_type_node.name,
                    ast_node=compound_type_node)

        def visit_const_node(self, const_node):
            return \
                self.__construct(
                    'Const',
                    annotation_nodes=const_node.annotations,
                    doc=self.__visit_doc_node(const_node.doc),
                    name=const_node.name,
                    type=const_node.type.accept(self),
                    value=const_node.value.accept(self)
                )

        def __visit_doc_node(self, doc_node):
            """Return the doc node's text, or None when there is no doc node."""
            if doc_node is not None:
                return doc_node.text
            else:
                return None

        def visit_document_node(self, document_node):
            """Build a Document, visit its contents and flag the used includes."""
            document = \
                self.__construct(
                    'Document',
                    document_root_dir_path=self.__document_root_dir_path,
                    path=document_node.path
                )
            self.__scope_stack.append(document)

            for header_node in document_node.headers:
                document.headers.append(header_node.accept(self))
            for definition_node in document_node.definitions:
                document.definitions.append(definition_node.accept(self))

            self.__scope_stack.pop(-1)

            # Only known once all definitions have been visited.
            for include in self.__visited_includes:
                include.used = include.abspath in self.__used_include_abspaths

            return document

        def visit_enum_type_node(self, enum_node):
            return self.__visit_compound_type_node('EnumType', enum_node)

        def visit_exception_type_node(self, exception_type_node):
            return self.__visit_compound_type_node('ExceptionType', exception_type_node)

        def visit_field_node(self, field_node):
            if field_node.value is not None:
                value = field_node.value.accept(self)
            else:
                value = None

            return \
                self.__construct(
                    'Field',
                    annotation_nodes=field_node.annotations,
                    doc=self.__visit_doc_node(field_node.doc),
                    id=field_node.id,
                    name=field_node.name,
                    required=field_node.required,
                    type=field_node.type.accept(self),
                    value=value
                )

        def visit_float_literal_node(self, float_literal_node):
            return float_literal_node.value

        def visit_function_node(self, function_node):
            """Build a Function with its parameters, return field and throws.

            Raises CompileException when a required parameter follows an
            optional one.
            """
            function = \
                self.__construct(
                    'Function',
                    annotation_nodes=function_node.annotations,
                    doc=self.__visit_doc_node(function_node.doc),
                    name=function_node.name,
                    oneway=function_node.oneway
                )
            self.__scope_stack.append(function)

            for parameter_node in function_node.parameters:
                parameter = parameter_node.accept(self)
                if parameter.required and \
                   len(function.parameters) > 0 and \
                   not function.parameters[-1].required:
                    raise CompileException(
                        "function %s has a required parameter %s after an optional parameter %s" % (
                            function.name, parameter.name, function.parameters[-1].name),
                        ast_node=function_node)
                function.parameters.append(parameter)

            if function_node.return_field is not None:
                function.return_field = function_node.return_field.accept(self)

            for throws_node in function_node.throws:
                function.throws.append(throws_node.accept(self))

            self.__scope_stack.pop(-1)
            return function

        def visit_include_node(self, include_node):
            """Compile the included file and remember the resulting Include.

            The file is searched for in the include directories and then in
            the directory of the enclosing document; the first existing path
            wins. Raises CompileException when the file is not found.
            """
            include_dir_paths = list(self.__include_dir_paths)
            for scope in reversed(self.__scope_stack):
                if isinstance(scope, Document):
                    include_dir_paths.append(os.path.dirname(scope.path))
                    break

            include_file_relpath = include_node.path.replace('/', os.path.sep)
            for include_dir_path in include_dir_paths:
                include_file_path = os.path.join(include_dir_path, include_file_relpath)
                if not os.path.exists(include_file_path):
                    continue

                include_file_path = os.path.abspath(include_file_path)
                included_document = \
                    self.__compiler.compile(
                        generator=self.__generator,
                        thrift_file_path=include_file_path
                    )
                include = \
                    self.__construct(
                        'Include',
                        abspath=include_file_path,
                        annotation_nodes=include_node.annotations,
                        doc=self.__visit_doc_node(include_node.doc),
                        document=included_document,
                        relpath=include_file_relpath
                    )
                self.__visited_includes.append(include)
                return include

            raise CompileException("include path not found: %s" % include_file_relpath, ast_node=include_node)

        def visit_int_literal_node(self, int_literal_node):
            return int_literal_node.value

        def visit_list_literal_node(self, list_literal_node):
            return tuple(element_value.accept(self)
                         for element_value in list_literal_node.value)

        def visit_list_type_node(self, list_type_node):
            return self.__visit_sequence_type_node('ListType', list_type_node)

        def visit_map_literal_node(self, map_literal_node):
            # A tuple of (key, value) pairs, in the order of the literal.
            return tuple((key_value.accept(self), value_value.accept(self))
                         for key_value, value_value in map_literal_node.value)

        def visit_map_type_node(self, map_type_node):
            """Return the cached map type, creating it on first use."""
            try:
                return self.__get_type(map_type_node.name, resolve=False)
            except KeyError:
                map_type = \
                    self.__construct(
                        'MapType',
                        key_type=map_type_node.key_type.accept(self),
                        value_type=map_type_node.value_type.accept(self)
                    )
                self.__put_type(map_type_node.name, map_type)
                return map_type

        def visit_namespace_node(self, namespace_node):
            return \
                self.__construct(
                    'Namespace',
                    annotation_nodes=namespace_node.annotations,
                    doc=self.__visit_doc_node(namespace_node.doc),
                    name=namespace_node.name,
                    scope=namespace_node.scope
                )

        def __visit_sequence_type_node(self, construct_class_name, sequence_type_node):
            """Return the cached list/set type, creating it on first use."""
            try:
                return self.__get_type(sequence_type_node.name, resolve=False)
            except KeyError:
                sequence_type = \
                    self.__construct(
                        construct_class_name,
                        element_type=sequence_type_node.element_type.accept(self)
                    )
                self.__put_type(sequence_type_node.name, sequence_type)
                return sequence_type

        def visit_service_node(self, service_node):
            """Build a Service and visit its functions.

            Raises CompileException when two function names are equal ignoring
            case.
            """
            service = \
                self.__construct(
                    'Service',
                    annotation_nodes=service_node.annotations,
                    doc=self.__visit_doc_node(service_node.doc),
                    name=service_node.name
                )
            self.__scope_stack.append(service)

            function_names_lower = set()
            for function_node in service_node.functions:
                function = function_node.accept(self)
                function_name_lower = function.name.lower()
                if function_name_lower in function_names_lower:
                    raise CompileException(
                        "duplicate (case-insensitive) function name '%s'" % function.name,
                        ast_node=function_node)
                function_names_lower.add(function_name_lower)
                service.functions.append(function)

            self.__scope_stack.pop(-1)
            return service

        def visit_set_type_node(self, set_type_node):
            return self.__visit_sequence_type_node('SetType', set_type_node)

        def visit_string_literal_node(self, string_literal_node):
            return string_literal_node.value

        def visit_struct_type_node(self, struct_type_node):
            return \
                self.__visit_compound_type_node(
                    'StructType',
                    struct_type_node,
                )

        def visit_type_node(self, type_node):
            """Resolve a reference to a previously defined or included type.

            An unqualified name that cannot be resolved as-is is retried with
            the outermost scope's name as qualifier. Raises CompileException
            when the type cannot be resolved.
            """
            try:
                try:
                    return self.__get_type(type_node.qname)
                except KeyError:
                    if type_node.qname == type_node.name:
                        document = self.__scope_stack[0]
                        return self.__get_type(document.name + '.' + type_node.qname)
                    else:
                        raise
            except KeyError:
                raise CompileException(
                    "unrecognized type '%s'" % type_node.qname,
                    ast_node=type_node)

        def visit_typedef_node(self, typedef_node):
            typedef = \
                self.__construct(
                    'Typedef',
                    annotation_nodes=typedef_node.annotations,
                    doc=self.__visit_doc_node(typedef_node.doc),
                    name=typedef_node.name,
                    type=typedef_node.type.accept(self)
                )
            self.__put_type(typedef.thrift_qname(), typedef)
            return typedef

    def __init__(self, document_root_dir_path, include_dir_paths):
        """Create a compiler.

        document_root_dir_path is handed to every Document construct.
        include_dir_paths may be None, a single path, or a list/tuple of
        paths. The lib/thrift/src directory located four levels above this
        file's directory is always appended, and duplicates are dropped while
        keeping the order.
        """
        object.__init__(self)

        self.__document_root_dir_path = document_root_dir_path

        if include_dir_paths is None:
            # Treat None as "no extra directories" rather than as a path.
            include_dir_paths = []
        elif isinstance(include_dir_paths, (list, tuple)):
            # Copy, so the caller's sequence is never modified.
            include_dir_paths = list(include_dir_paths)
        else:
            include_dir_paths = [include_dir_paths]

        my_dir_path = os.path.dirname(os.path.realpath(__file__))
        lib_thrift_src_dir_path = \
            os.path.abspath(os.path.join(
                my_dir_path,
                '..', '..', '..', '..',
                'lib', 'thrift', 'src'
            ))
        include_dir_paths.append(lib_thrift_src_dir_path)

        unique_include_dir_paths = []
        for include_dir_path in include_dir_paths:
            if include_dir_path not in unique_include_dir_paths:
                unique_include_dir_paths.append(include_dir_path)
        self.__include_dir_paths = tuple(unique_include_dir_paths)

        # Absolute .thrift path -> list of (generator, document) pairs; the
        # pair whose generator is None holds the parsed AST.
        self.__document_cache = {}
        self.__scanner = Scanner()
        self.__parser = Parser()

    def __call__(self, *args, **kwds):
        """Shorthand for compile()."""
        return self.compile(*args, **kwds)

    def compile(self, generator, thrift_file_path):
        """Compile one Thrift file with the given generator.

        Returns the generator's document for the file. A document that was
        already produced for an equal generator is returned from the cache.
        Raises ValueError when generator is None.
        """
        if generator is None:
            raise ValueError('generator must not be None')

        thrift_file_path = os.path.abspath(thrift_file_path)

        path_document_cache = self.__document_cache.get(thrift_file_path)
        if path_document_cache is not None:
            document_node = None
            for other_generator, generator_document in path_document_cache:
                if other_generator is None:
                    # Special placeholder for the AST
                    document_node = generator_document
                elif generator == other_generator:
                    return generator_document
            assert document_node is not None
        else:
            tokens = self.__scanner.tokenize(thrift_file_path)
            document_node = \
                self.__parser.parse(
                    thrift_file_path=thrift_file_path,
                    tokens=tokens
                )
            # Can't hash the generator easily so just use its __eq__
            path_document_cache = [(None, document_node)]
            self.__document_cache[thrift_file_path] = path_document_cache

        ast_visitor = \
            self.__AstVisitor(
                compiler=self,
                document_root_dir_path=self.__document_root_dir_path,
                generator=generator,
                include_dir_paths=self.__include_dir_paths
            )
        generator_document = document_node.accept(ast_visitor)
        path_document_cache.append((generator, generator_document))
        return generator_document

    @property
    def document_root_dir_path(self):
        """The document root directory given to the constructor."""
        return self.__document_root_dir_path

    @property
    def include_dir_paths(self):
        """Tuple of include directories, without duplicates."""
        return self.__include_dir_paths
]

    def _java_methods(self):
        """Return the Java definitions of all functions, in function order."""
        methods = []
        for function in self.functions:
            # Each function contributes a sequence of definitions.
            methods.extend(function.java_definitions())
        return methods

    def java_repr(self):
        """Return the Java source of the singleton class implementing this service.

        The class body is assembled from four sections: the markers, the
        DELEGATE_NAME constant, the constructor followed by the methods, and
        the member declarations.
        """
        # The templates below are filled in from locals(), so these local
        # names must match the %(...)s placeholders.
        delegate_name = self.__java_delegate_name()
        name = self.java_name()
        sections = []
        sections.append(self.__java_markers())
        sections.append("""public final static com.google.inject.name.Named DELEGATE_NAME = com.google.inject.name.Names.named("%(delegate_name)s");""" % locals())
        sections.append("\n\n".join([self._java_constructor()] + self._java_methods()))
        sections.append("\n".join(self._java_member_declarations()))
        # Indent every section and separate the sections by a blank line.
        sections = "\n\n".join(indent(' ' * 4, sections))
        service_qname = java_generator.JavaGenerator.Service.java_qname(self)  # @UndefinedVariable
        # NOTE(review): the template is reproduced exactly as it appears in
        # this copy of the file; its line breaks look like they were collapsed
        # into spaces -- confirm against the canonical source.
        return """\ @com.google.inject.Singleton public class %(name)s implements %(service_qname)s { %(sections)s }""" % locals()


# Register the Java logging annotation parsers with the Thrift parser as a
# side effect of importing this module.
Parser.register_annotation_parser(JavaLogExceptionStackTraceAnnotationParser())
Parser.register_annotation_parser(JavaLogLevelAnnotationParser())
class SqlGenerator(Generator):
    """Generator whose construct classes are the Sql* implementations.

    The class body consists only of imports: each one binds a Sql* class as a
    class attribute under a generic construct name (BinaryType, Document,
    Field, StructType, ...). Code that looks a construct class up by name on
    the generator, e.g. getattr(generator, 'StructType'), therefore receives
    the SQL variant.
    """

    # The "@UnusedImport" markers silence IDE warnings: the names are not
    # referenced in this module, they only need to exist as class attributes.
    from thryft.generators.sql.sql_binary_type import SqlBinaryType as BinaryType  # @UnusedImport
    from thryft.generators.sql.sql_bool_type import SqlBoolType as BoolType  # @UnusedImport
    from thryft.generators.sql.sql_byte_type import SqlByteType as ByteType  # @UnusedImport
    from thryft.generators.sql.sql_const import SqlConst as Const  # @UnusedImport
    from thryft.generators.sql.sql_document import SqlDocument as Document  # @UnusedImport
    from thryft.generators.sql.sql_double_type import SqlDoubleType as DoubleType  # @UnusedImport
    from thryft.generators.sql.sql_enum_type import SqlEnumType as EnumType  # @UnusedImport
    from thryft.generators.sql.sql_exception_type import SqlExceptionType as ExceptionType  # @UnusedImport
    from thryft.generators.sql.sql_field import SqlField as Field  # @UnusedImport
    from thryft.generators.sql.sql_function import SqlFunction as Function  # @UnusedImport
    from thryft.generators.sql.sql_i16_type import SqlI16Type as I16Type  # @UnusedImport
    from thryft.generators.sql.sql_i32_type import SqlI32Type as I32Type  # @UnusedImport
    from thryft.generators.sql.sql_i64_type import SqlI64Type as I64Type  # @UnusedImport
    from thryft.generators.sql.sql_include import SqlInclude as Include  # @UnusedImport
    from thryft.generators.sql.sql_list_type import SqlListType as ListType  # @UnusedImport
    from thryft.generators.sql.sql_map_type import SqlMapType as MapType  # @UnusedImport
    from thryft.generators.sql.sql_service import SqlService as Service  # @UnusedImport
    from thryft.generators.sql.sql_set_type import SqlSetType as SetType  # @UnusedImport
    from thryft.generators.sql.sql_string_type import SqlStringType as StringType  # @UnusedImport
    from thryft.generators.sql.sql_struct_type import SqlStructType as StructType  # @UnusedImport
    from thryft.generators.sql.sql_typedef import SqlTypedef as Typedef  # @UnusedImport


# Register the SQL-specific annotation parsers with the Thrift parser as a
# side effect of importing this module.
# NOTE(review): the second argument of AnnotationParser/ValuelessAnnotationParser
# is presumably the AST node type the annotation may be attached to -- confirm
# against the annotation parser classes.
Parser.register_annotation_parser(
    AnnotationParser('sql_column', Ast.StructTypeNode))
Parser.register_annotation_parser(SqlForeignKeyAnnotationParser())
Parser.register_annotation_parser(
    ValuelessAnnotationParser('sql_unique', Ast.FieldNode))
if len(function_names) > 0 and cmp(function.name, function_names[-1]) < 0: after_function_name = '' for function_name_i in xrange( len(function_names) - 1, -1, -1): test_function_name = function_names[function_name_i] if cmp(function.name, test_function_name) >= 0: after_function_name = test_function_name break self._logger.warn( "function %s in %s is out of lexicographic order (should be after %s)", function.name, self._parent_document().path, after_function_name) function_names.append(function.name) class SetType(Generator.SetType, _SequenceType): # @UndefinedVariable pass class StringType(Generator.StringType, _Type): # @UndefinedVariable pass class StructType(Generator.StructType, _CompoundType): # @UndefinedVariable pass Parser.register_annotation_parser( ValuelessAnnotationParser('lint_suppress', (Ast.EnumTypeNode, Ast.ExceptionTypeNode, Ast.ServiceNode, Ast.StructTypeNode)))