def test_dump(inputs: Sequence[Tuple[Optional[str], str]], value):
    """Verify that the rendered import lines match the expected text."""
    collected = Imports()
    entries = []
    for source_module, imported_name in inputs:
        entries.append(Import(from_=source_module, import_=imported_name))
    collected.append(entries)
    assert str(collected) == value
def __init__(self, parsed_operations: List[Operation]):
    """Store operations sorted by path and aggregate their imports."""
    self.operations: List[Operation] = sorted(
        parsed_operations, key=lambda op: op.path)
    self.imports: Imports = Imports()
    for op in self.operations:
        # Evaluate these properties purely for their side effects; per the
        # original "create imports" note, they populate `op.imports`.
        _ = (op.arguments, op.snake_case_arguments, op.request, op.response)
        self.imports.append(op.imports)
def generate_app_code(environment, parsed_object) -> str:
    """Render 'main.jinja2' with one router import per top-level path segment.

    Operations are bucketed by the first segment of their path; each bucket
    contributes a `<name>_router` import from the controllers package and a
    matching entry in the template's `routers` list.
    """

    def first_segment(operation):
        # Group key: leading path component, e.g. '/users/{id}' -> 'users'.
        return operation.path.strip('/').split('/')[0]

    grouped_operations = defaultdict(list)
    # groupby only yields contiguous runs; merging into a defaultdict makes
    # the grouping correct even if equal keys are not adjacent.
    for segment, run in itertools.groupby(parsed_object.operations,
                                          key=first_segment):
        grouped_operations[segment].extend(run)

    imports = Imports()
    routers = []
    for name in grouped_operations:
        router_name = name + '_router'
        imports.append(
            Import(from_=CONTROLLERS_DIR_NAME + '.' + name,
                   import_=router_name))
        routers.append(router_name)

    template = environment.get_template(str(Path('main.jinja2')))
    return template.render(imports=imports, routers=routers)
def __init__(
    self,
    parsed_operations: List[Operation],
    info: Optional[List[Dict[str, Any]]] = None,
):
    """Store operations sorted by path, optional API info, and merged imports."""
    self.operations: List[Operation] = sorted(
        parsed_operations, key=lambda op: op.path)
    self.imports: Imports = Imports()
    self.info = info
    for op in self.operations:
        # Evaluate these properties purely for their side effects; per the
        # original "create imports" note, they populate `op.imports`.
        _ = (op.arguments, op.snake_case_arguments, op.request, op.response)
        self.imports.append(op.imports)
def parse(
    self,
    with_import: Optional[bool] = True,
    format_: Optional[bool] = True,
    settings_path: Optional[Path] = None,
) -> Union[str, Dict[Tuple[str, ...], Result]]:
    """Parse the raw source into data models and render the generated code.

    Args:
        with_import: when true, prepend the collected import statements
            (plus ``from __future__ import annotations`` for any target
            other than Python 3.6).
        format_: when true, run the rendered code through ``CodeFormatter``.
        settings_path: formatter settings location, forwarded to
            ``CodeFormatter``.

    Returns:
        The generated source as a single string when everything lands in the
        root ``__init__.py`` module, otherwise a mapping from module path
        tuples (e.g. ``('pkg', 'models.py')``) to ``Result`` objects.
    """
    self.parse_raw()
    if with_import:
        if self.target_python_version != PythonVersion.PY_36:
            self.imports.append(IMPORT_ANNOTATIONS)
    if format_:
        code_formatter: Optional[CodeFormatter] = CodeFormatter(
            self.target_python_version, settings_path)
    else:
        code_formatter = None
    # Topologically sorted models plus the set of models that need a
    # post-definition reference-update action.
    _, sorted_data_models, require_update_action_models = sort_data_models(
        self.results)
    results: Dict[Tuple[str, ...], Result] = {}
    module_key = lambda x: x.module_path
    # process in reverse order to correctly establish module levels
    grouped_models = groupby(
        sorted(sorted_data_models.values(), key=module_key, reverse=True),
        key=module_key,
    )
    module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []
    # First pass: per module, drop duplicated custom-root models, detach
    # root models from inheritance chains, and de-duplicate class names.
    for module, models in ((k, [*v]) for k, v in grouped_models
                           ):  # type: Tuple[str, ...], List[DataModel]
        for model in models:
            if isinstance(model, self.data_model_root_type):
                root_data_type = model.fields[0].data_type
                # backward compatible
                # Remove duplicated root model
                if (root_data_type.reference
                        and not root_data_type.is_dict
                        and not root_data_type.is_list
                        and root_data_type.reference.source in models
                        and root_data_type.reference.name
                        == self.model_resolver.get_class_name(
                            model.reference.original_name, unique=False)):
                    # Replace referenced duplicate model to original model
                    for child in model.reference.children[:]:
                        child.replace_reference(root_data_type.reference)
                    models.remove(model)
                    continue
            # Custom root model can't be inherited on restriction of Pydantic
            for child in model.reference.children:
                # inheritance model
                if isinstance(child, DataModel):
                    for base_class in child.base_classes:
                        if base_class.reference == model.reference:
                            child.base_classes.remove(base_class)
        module_models.append((
            module,
            models,
        ))
        # Resolver scoped to this module; pre-excluding the names already
        # taken by imports so generated class names cannot collide with them.
        scoped_model_resolver = ModelResolver(
            exclude_names={
                i.alias or i.import_
                for m in models for i in m.imports
            },
            duplicate_name_suffix='Model',
        )
        for model in models:
            class_name: str = model.class_name
            generated_name: str = scoped_model_resolver.add(
                model.path, class_name, unique=True, class_name=True).name
            if class_name != generated_name:
                # Keep any dotted prefix; only the last component is renamed.
                if '.' in model.reference.name:
                    model.reference.name = (
                        f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                    )
                else:
                    model.reference.name = generated_name
    # Second pass: emit one Result per module.
    for module, models in module_models:
        init = False
        if module:
            # Ensure the parent package has an __init__.py entry.
            parent = (*module[:-1], '__init__.py')
            if parent not in results:
                results[parent] = Result(body='')
            if (*module, '__init__.py') in results:
                module = (*module, '__init__.py')
                init = True
            else:
                module = (*module[:-1], f'{module[-1]}.py')
        else:
            module = ('__init__.py', )
        result: List[str] = []
        imports = Imports()
        scoped_model_resolver = ModelResolver()
        for model in models:
            imports.append(model.imports)
            for data_type in model.all_data_types:
                # To change from/import
                if not data_type.reference or data_type.reference.source in models:
                    # No need to import non-reference model.
                    # Or, Referenced model is in the same file. we don't need to import the model
                    continue
                if isinstance(data_type, BaseClassDataType):
                    from_ = ''.join(
                        relative(model.module_name, data_type.full_name))
                    import_ = data_type.reference.short_name
                    full_path = from_, import_
                else:
                    from_, import_ = full_path = relative(
                        model.module_name, data_type.full_name)
                # Alias the import if its name clashes within this module.
                alias = scoped_model_resolver.add(full_path, import_).name
                name = data_type.reference.short_name
                if from_ and import_ and alias != name:
                    data_type.alias = f'{alias}.{name}'
                if init:
                    # Imports in an __init__.py need one extra relative level.
                    from_ += "."
                imports.append(
                    Import(from_=from_, import_=import_, alias=alias))
        if self.reuse_model:
            # Replace structurally identical models with a reference to the
            # first occurrence (or an inheriting stub for non-enums).
            model_cache: Dict[Tuple[str, ...], Reference] = {}
            duplicates = []
            for model in models:
                model_key = tuple(
                    to_hashable(v) for v in (
                        model.base_classes,
                        model.extra_template_data,
                        model.fields,
                    ))
                cached_model_reference = model_cache.get(model_key)
                if cached_model_reference:
                    if isinstance(model, Enum):
                        for child in model.reference.children[:]:
                            # child is resolved data_type by reference
                            data_model = get_most_of_parent(child)
                            # TODO: replace reference in all modules
                            if data_model in models:  # pragma: no cover
                                child.replace_reference(
                                    cached_model_reference)
                        duplicates.append(model)
                    else:
                        index = models.index(model)
                        inherited_model = model.__class__(
                            fields=[],
                            base_classes=[cached_model_reference],
                            description=model.description,
                            reference=Reference(
                                name=model.name,
                                path=model.reference.path + '/reuse',
                            ),
                        )
                        if (cached_model_reference.path
                                in require_update_action_models):
                            require_update_action_models.append(
                                inherited_model.path)
                        models.insert(index, inherited_model)
                        models.remove(model)
                else:
                    model_cache[model_key] = model.reference
            for duplicate in duplicates:
                models.remove(duplicate)
        if self.set_default_enum_member:
            # Rewrite raw default values into enum member accesses.
            for model in models:
                for model_field in model.fields:
                    if not model_field.default:
                        continue
                    for data_type in model_field.data_type.all_data_types:
                        if data_type.reference and isinstance(
                                data_type.reference.source,
                                Enum):  # pragma: no cover
                            enum_member = data_type.reference.source.find_member(
                                model_field.default)
                            if enum_member:
                                model_field.default = enum_member
        if with_import:
            result += [str(self.imports), str(imports), '\n']
        code = dump_templates(models)
        result += [code]
        if self.dump_resolve_reference_action is not None:
            result += [
                '\n',
                self.dump_resolve_reference_action(
                    m.reference.short_name for m in models
                    if m.path in require_update_action_models),
            ]
        body = '\n'.join(result)
        if code_formatter:
            body = code_formatter.format_code(body)
        results[module] = Result(body=body, source=models[0].file_path)
    # retain existing behaviour
    if [*results] == [('__init__.py', )]:
        return results[('__init__.py', )].body
    return results
def __init__(
    self,
    source: Union[str, Path, List[Path], ParseResult],
    *,
    data_model_type: Type[DataModel] = pydantic_model.BaseModel,
    data_model_root_type: Type[DataModel] = pydantic_model.CustomRootType,
    data_type_manager_type: Type[DataTypeManager] = pydantic_model.
    DataTypeManager,
    data_model_field_type: Type[DataModelFieldBase] = pydantic_model.
    DataModelField,
    base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
    target_python_version: PythonVersion = PythonVersion.PY_37,
    dump_resolve_reference_action: Optional[Callable[[Iterable[str]],
                                                     str]] = None,
    validation: bool = False,
    field_constraints: bool = False,
    snake_case_field: bool = False,
    strip_default_none: bool = False,
    aliases: Optional[Mapping[str, str]] = None,
    allow_population_by_field_name: bool = False,
    apply_default_values_for_required_fields: bool = False,
    force_optional_for_required_fields: bool = False,
    class_name: Optional[str] = None,
    use_standard_collections: bool = False,
    base_path: Optional[Path] = None,
    use_schema_description: bool = False,
    reuse_model: bool = False,
    encoding: str = 'utf-8',
    enum_field_as_literal: Optional[LiteralType] = None,
    set_default_enum_member: bool = False,
    strict_nullable: bool = False,
    use_generic_container_types: bool = False,
    enable_faux_immutability: bool = False,
    remote_text_cache: Optional[DefaultPutDict[str, str]] = None,
    disable_appending_item_suffix: bool = False,
    strict_types: Optional[Sequence[StrictTypes]] = None,
    empty_enum_field_name: Optional[str] = None,
    custom_class_name_generator: Optional[Callable[
        [str], str]] = title_to_class_name,
    field_extra_keys: Optional[Set[str]] = None,
    field_include_all_keys: bool = False,
):
    """Configure the parser.

    `source` may be raw text, a file/directory path (or list of paths), or a
    parsed URL (``ParseResult``).  Most keyword options are stored verbatim
    on the instance and consumed later by ``parse``; the constructor itself
    only builds the data-type manager, derives ``base_path``, seeds the
    template data, and creates the ``ModelResolver``.
    """
    self.data_type_manager: DataTypeManager = data_type_manager_type(
        target_python_version,
        use_standard_collections,
        use_generic_container_types,
        strict_types,
    )
    self.data_model_type: Type[DataModel] = data_model_type
    self.data_model_root_type: Type[DataModel] = data_model_root_type
    self.data_model_field_type: Type[
        DataModelFieldBase] = data_model_field_type
    self.imports: Imports = Imports()
    self.base_class: Optional[str] = base_class
    self.target_python_version: PythonVersion = target_python_version
    self.results: List[DataModel] = []
    self.dump_resolve_reference_action: Optional[Callable[
        [Iterable[str]], str]] = dump_resolve_reference_action
    self.validation: bool = validation
    self.field_constraints: bool = field_constraints
    self.snake_case_field: bool = snake_case_field
    self.strip_default_none: bool = strip_default_none
    self.apply_default_values_for_required_fields: bool = (
        apply_default_values_for_required_fields)
    self.force_optional_for_required_fields: bool = (
        force_optional_for_required_fields)
    self.use_schema_description: bool = use_schema_description
    self.reuse_model: bool = reuse_model
    self.encoding: str = encoding
    self.enum_field_as_literal: Optional[
        LiteralType] = enum_field_as_literal
    self.set_default_enum_member: bool = set_default_enum_member
    self.strict_nullable: bool = strict_nullable
    self.use_generic_container_types: bool = use_generic_container_types
    self.enable_faux_immutability: bool = enable_faux_immutability
    self.custom_class_name_generator: Optional[Callable[
        [str], str]] = custom_class_name_generator
    # Normalize optional containers so downstream code never sees None.
    self.field_extra_keys: Set[str] = field_extra_keys or set()
    self.field_include_all_keys: bool = field_include_all_keys
    self.remote_text_cache: DefaultPutDict[str, str] = (remote_text_cache
                                                        or DefaultPutDict())
    self.current_source_path: Optional[Path] = None
    # base_path precedence: explicit argument > directory of a Path source
    # (the path itself when it is a directory) > current working directory.
    if base_path:
        self.base_path = base_path
    elif isinstance(source, Path):
        self.base_path = (source.absolute()
                          if source.is_dir() else source.absolute().parent)
    else:
        self.base_path = Path.cwd()
    self.source: Union[str, Path, List[Path], ParseResult] = source
    self.custom_template_dir = custom_template_dir
    self.extra_template_data: DefaultDict[
        str, Any] = extra_template_data or defaultdict(dict)
    # Flags applied to every generated model via the ALL_MODEL template key.
    if allow_population_by_field_name:
        self.extra_template_data[ALL_MODEL][
            'allow_population_by_field_name'] = True
    if enable_faux_immutability:
        self.extra_template_data[ALL_MODEL]['allow_mutation'] = False
    self.model_resolver = ModelResolver(
        # URL sources resolve references relative to their own URL.
        base_url=source.geturl() if isinstance(source, ParseResult) else None,
        singular_name_suffix='' if disable_appending_item_suffix else None,
        aliases=aliases,
        empty_field_name=empty_enum_field_name,
        snake_case_field=snake_case_field,
        custom_class_name_generator=custom_class_name_generator,
        base_path=self.base_path,
    )
    self.class_name: Optional[str] = class_name
def __init__(
    self,
    source: Union[str, pathlib.Path, List[pathlib.Path], ParseResult],
    *,
    data_model_type: Type[DataModel] = pydantic_model.BaseModel,
    data_model_root_type: Type[DataModel] = pydantic_model.CustomRootType,
    data_type_manager_type: Type[DataTypeManager] = pydantic_model.
    DataTypeManager,
    data_model_field_type: Type[DataModelFieldBase] = pydantic_model.
    DataModelField,
    base_class: Optional[str] = None,
    custom_template_dir: Optional[pathlib.Path] = None,
    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
    target_python_version: PythonVersion = PythonVersion.PY_37,
    dump_resolve_reference_action: Optional[Callable[[Iterable[str]],
                                                     str]] = None,
    validation: bool = False,
    field_constraints: bool = False,
    snake_case_field: bool = False,
    strip_default_none: bool = False,
    aliases: Optional[Mapping[str, str]] = None,
    allow_population_by_field_name: bool = False,
    apply_default_values_for_required_fields: bool = False,
    force_optional_for_required_fields: bool = False,
    class_name: Optional[str] = None,
    use_standard_collections: bool = False,
    base_path: Optional[pathlib.Path] = None,
    use_schema_description: bool = False,
    reuse_model: bool = False,
    encoding: str = 'utf-8',
    enum_field_as_literal: Optional[LiteralType] = None,
    set_default_enum_member: bool = False,
    strict_nullable: bool = False,
    use_generic_container_types: bool = False,
    enable_faux_immutability: bool = False,
    remote_text_cache: Optional[DefaultPutDict[str, str]] = None,
    disable_appending_item_suffix: bool = False,
    strict_types: Optional[Sequence[StrictTypes]] = None,
    empty_enum_field_name: Optional[str] = None,
    custom_class_name_generator: Optional[Callable[[str], str]] = None,
    field_extra_keys: Optional[Set[str]] = None,
    field_include_all_keys: bool = False,
):
    """Initialize the OpenAPI-specific parser.

    Forwards every option unchanged to the base parser, fixing the scopes to
    schemas *and* paths, then adds the state this subclass accumulates while
    visiting operations (``operations``, FastAPI-only imports, data types).
    """
    super().__init__(
        source=source,
        data_model_type=data_model_type,
        data_model_root_type=data_model_root_type,
        data_type_manager_type=data_type_manager_type,
        data_model_field_type=data_model_field_type,
        base_class=base_class,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        target_python_version=target_python_version,
        dump_resolve_reference_action=dump_resolve_reference_action,
        validation=validation,
        field_constraints=field_constraints,
        snake_case_field=snake_case_field,
        strip_default_none=strip_default_none,
        aliases=aliases,
        allow_population_by_field_name=allow_population_by_field_name,
        apply_default_values_for_required_fields=
        apply_default_values_for_required_fields,
        force_optional_for_required_fields=
        force_optional_for_required_fields,
        class_name=class_name,
        use_standard_collections=use_standard_collections,
        base_path=base_path,
        use_schema_description=use_schema_description,
        reuse_model=reuse_model,
        encoding=encoding,
        enum_field_as_literal=enum_field_as_literal,
        set_default_enum_member=set_default_enum_member,
        strict_nullable=strict_nullable,
        use_generic_container_types=use_generic_container_types,
        enable_faux_immutability=enable_faux_immutability,
        remote_text_cache=remote_text_cache,
        disable_appending_item_suffix=disable_appending_item_suffix,
        strict_types=strict_types,
        empty_enum_field_name=empty_enum_field_name,
        custom_class_name_generator=custom_class_name_generator,
        field_extra_keys=field_extra_keys,
        field_include_all_keys=field_include_all_keys,
        # Parse both component schemas and path operations.
        openapi_scopes=[OpenAPIScope.Schemas, OpenAPIScope.Paths],
    )
    # Operations collected while parsing paths, keyed by operation id/name.
    self.operations: Dict[str, Operation] = {}
    # Scratch space for the operation currently being assembled.
    self._temporary_operation: Dict[str, Any] = {}
    # Imports needed only by the generated FastAPI app (not the models).
    self.imports_for_fastapi: Imports = Imports()
    self.data_types: List[DataType] = []
def parse(
    self, with_import: Optional[bool] = True, format_: Optional[bool] = True
) -> Union[str, Dict[Tuple[str, ...], str]]:
    """Parse ``components/schemas`` into models and render the code.

    Args:
        with_import: when true, prepend collected import statements (and
            ``from __future__ import annotations`` for a 3.7 target).
        format_: when true, run the rendered code through ``format_code``.

    Returns:
        The generated source as one string when everything lands in the
        root ``__init__.py`` module, otherwise a mapping from module path
        tuples to rendered source strings.
    """
    # Dispatch each schema object to the matching parse_* handler.
    for obj_name, raw_obj in self.base_parser.specification['components'][
            'schemas'].items():  # type: str, Dict
        obj = JsonSchemaObject.parse_obj(raw_obj)
        if obj.is_object:
            self.parse_object(obj_name, obj)
        elif obj.is_array:
            self.parse_array(obj_name, obj)
        elif obj.enum:
            self.parse_enum(obj_name, obj)
        elif obj.allOf:
            self.parse_all_of(obj_name, obj)
        else:
            self.parse_root_type(obj_name, obj)
    if with_import:
        # NOTE(review): only a PY_37 target gets the annotations import here,
        # while the newer parse() variant uses `!= PY_36` — confirm whether
        # later targets should also receive it.
        if self.target_python_version == PythonVersion.PY_37:
            self.imports.append(IMPORT_ANNOTATIONS)
    _, sorted_data_models, require_update_action_models = sort_data_models(
        self.results)
    results: Dict[Tuple[str, ...], str] = {}
    # Module key: dotted-name prefix of the model, split into a tuple.
    module_key = lambda x: (*x.name.split('.')[:-1], )
    grouped_models = groupby(
        sorted(sorted_data_models.values(), key=module_key), key=module_key)
    for module, models in ((k, [*v]) for k, v in grouped_models):
        module_path = '.'.join(module)
        result: List[str] = []
        imports = Imports()
        models_to_update: List[str] = []
        for model in models:
            if model.name in require_update_action_models:
                models_to_update += [model.name]
            imports.append(model.imports)
            for ref_name in model.reference_classes:
                # Only dotted references living in *other* modules need a
                # relative import.
                if '.' not in ref_name:
                    continue
                ref_path = ref_name.rsplit('.', 1)[0]
                if ref_path == module_path:
                    continue
                imports.append(Import(from_='.', import_=ref_path))
        if with_import:
            result += [imports.dump(), self.imports.dump(), '\n']
        code = dump_templates(models)
        result += [code]
        if self.dump_resolve_reference_action is not None:
            result += [
                '\n', self.dump_resolve_reference_action(models_to_update)
            ]
        body = '\n'.join(result)
        if format_:
            body = format_code(body, self.target_python_version)
        if module:
            module = (*module[:-1], f'{module[-1]}.py')
            # Ensure the parent package has an (empty) __init__.py entry.
            parent = (*module[:-1], '__init__.py')
            if parent not in results:
                results[parent] = ''
        else:
            module = ('__init__.py', )
        results[module] = body
    # retain existing behaviour
    if [*results] == [('__init__.py', )]:
        return results[('__init__.py', )]
    return results
def generate_code(
    input_name: str,
    input_text: str,
    output_dir: Path,
    template_dir: Optional[Path],
    model_path: Optional[Path] = None,
    enum_field_as_literal: Optional[str] = None,
) -> None:
    """Generate a FastAPI application from an OpenAPI document.

    Parses ``input_text``, renders every template under ``template_dir``
    with the parsed operations/imports/info, and writes one ``.py`` file
    per template into ``output_dir`` with a generation header.

    Raises:
        Exception: when the parser returns modular (multi-file) results,
            which this version does not support.
    """
    if not model_path:
        model_path = MODEL_PATH
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if not template_dir:
        template_dir = BUILTIN_TEMPLATE_DIR
    if enum_field_as_literal:
        parser = OpenAPIParser(input_text,
                               enum_field_as_literal=enum_field_as_literal)
    else:
        parser = OpenAPIParser(input_text)
    # Parse with the output directory as CWD so relative refs resolve there.
    with chdir(output_dir):
        models = parser.parse()
    if not models:
        return
    elif isinstance(models, str):
        output = output_dir / model_path
        modules = {output: (models, input_name)}
    else:
        raise Exception('Modular references are not supported in this version')
    environment: Environment = Environment(
        loader=FileSystemLoader(
            # NOTE(review): template_dir is always truthy here (defaulted
            # above), so the fallback branch looks unreachable.
            template_dir if template_dir else f"{Path(__file__).parent}/template",
            encoding="utf8",
        ),
    )
    imports = Imports()
    imports.update(parser.imports)
    # Re-export every referenced model from the generated models module.
    for data_type in parser.data_types:
        reference = _get_most_of_reference(data_type)
        if reference:
            imports.append(data_type.all_imports)
            imports.append(
                Import.from_full_path(f'.{model_path.stem}.{reference.name}')
            )
    for from_, imports_ in parser.imports_for_fastapi.items():
        imports[from_].update(imports_)
    results: Dict[Path, str] = {}
    code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
    sorted_operations: List[Operation] = sorted(
        parser.operations.values(), key=lambda m: m.path
    )
    # Render every file in the template tree, keyed by its relative path.
    for target in template_dir.rglob("*"):
        relative_path = target.relative_to(template_dir)
        result = environment.get_template(str(relative_path)).render(
            operations=sorted_operations,
            imports=imports,
            info=parser.parse_info(),
        )
        results[relative_path] = code_formatter.format_code(result)
    timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
    header = f"""\
# generated by fastapi-codegen:
# filename: {Path(input_name).name}
# timestamp: {timestamp}"""
    for path, code in results.items():
        with output_dir.joinpath(path.with_suffix(".py")).open("wt") as file:
            print(header, file=file)
            print("", file=file)
            print(code.rstrip(), file=file)
    # NOTE(review): everything below looks like a second, older write-out
    # path (it re-writes `modules`, which the loop above already covered via
    # `results`) — confirm whether it is dead code that should be removed.
    # NOTE(review): `{(unknown)}` appears to be a corrupted placeholder;
    # given `header.format(filename=filename)` below, the original was
    # probably a literal `{{filename}}` — confirm against upstream.
    header = f'''\
# generated by fastapi-codegen:
# filename: {(unknown)}'''
    # if not disable_timestamp:
    header += f'\n# timestamp: {timestamp}'
    for path, body_and_filename in modules.items():
        body, filename = body_and_filename
        if path is None:
            # print(..., file=None) falls back to stdout.
            file = None
        else:
            if not path.parent.exists():
                path.parent.mkdir(parents=True)
            file = path.open('wt', encoding='utf8')
        print(header.format(filename=filename), file=file)
        if body:
            print('', file=file)
            print(body.rstrip(), file=file)
        if file is not None:
            file.close()
def dump_imports(self) -> str:
    """Return the accumulated imports rendered as import statements."""
    rendered = Imports()
    rendered.append(self.imports)
    dumped = rendered.dump()
    return dumped