Esempio n. 1
0
def test_dump(inputs: Sequence[Tuple[Optional[str], str]], value):
    """Check that an ``Imports`` collection renders to the expected text.

    Every ``(from_, import_)`` pair in *inputs* is turned into an ``Import``;
    the whole batch is appended at once and ``str()`` of the collection must
    equal *value*.
    """
    expected_text = value
    collected = Imports()
    batch = [
        Import(from_=source, import_=target) for source, target in inputs
    ]
    collected.append(batch)

    assert str(collected) == expected_text
Esempio n. 2
0
 def __init__(self, parsed_operations: List[Operation]):
     """Keep the operations ordered by ``path`` and gather their imports.

     Before ``operation.imports`` is collected, the ``arguments``,
     ``snake_case_arguments``, ``request`` and ``response`` attributes are
     read (in that order) so that they create their imports first.
     """
     self.operations: List[Operation] = sorted(
         parsed_operations, key=lambda op: op.path)
     self.imports: Imports = Imports()
     for op in self.operations:
         # create imports: the values are discarded, only the access matters
         for attr_name in ('arguments', 'snake_case_arguments', 'request',
                           'response'):
             getattr(op, attr_name)
         self.imports.append(op.imports)
Esempio n. 3
0
def generate_app_code(environment, parsed_object) -> str:
    """Render the application entry module from the ``main.jinja2`` template.

    Operations are bucketed by the first segment of their URL path (leading
    and trailing ``/`` stripped, so ``/pets/{id}`` -> ``pets``). Each distinct
    segment ``name``, in first-seen order, contributes the import
    ``from <CONTROLLERS_DIR_NAME>.<name> import <name>_router`` and the router
    name ``<name>_router``.

    :param environment: Jinja2 environment used to load ``main.jinja2``.
    :param parsed_object: object exposing an iterable ``operations`` whose
        items have a string ``path`` attribute.
    :return: the rendered template text.
    """
    template_path = Path('main.jinja2')

    # Only the distinct prefixes are needed; dict.fromkeys keeps them unique
    # in first-seen order (the same key order the former groupby+defaultdict
    # produced) without building operation lists that were never read.
    router_names = list(
        dict.fromkeys(operation.path.strip('/').split('/')[0]
                      for operation in parsed_object.operations))

    imports = Imports()
    routers = []
    for name in router_names:
        router = name + '_router'
        imports.append(
            Import(from_=CONTROLLERS_DIR_NAME + '.' + name, import_=router))
        routers.append(router)

    return environment.get_template(str(template_path)).render(
        imports=imports,
        routers=routers,
    )
Esempio n. 4
0
 def __init__(
     self,
     parsed_operations: List[Operation],
     info: Optional[List[Dict[str, Any]]] = None,
 ):
     """Keep the operations ordered by ``path``, store *info*, gather imports.

     :param parsed_operations: operations to keep; sorted by ``path``.
     :param info: stored unchanged on ``self.info``.

     For each operation the ``arguments``, ``snake_case_arguments``,
     ``request`` and ``response`` attributes are read (in that order) to
     create their imports, then ``operation.imports`` is collected.
     """
     self.operations: List[Operation] = sorted(
         parsed_operations, key=lambda op: op.path)
     self.imports: Imports = Imports()
     self.info = info
     for op in self.operations:
         # create imports: the values are discarded, only the access matters
         for attr_name in ('arguments', 'snake_case_arguments', 'request',
                           'response'):
             getattr(op, attr_name)
         self.imports.append(op.imports)
Esempio n. 5
0
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True,
        settings_path: Optional[Path] = None,
    ) -> Union[str, Dict[Tuple[str, ...], Result]]:
        """Parse the source and render the generated model modules.

        Calls ``parse_raw()`` to fill ``self.results``, groups the sorted data
        models by ``module_path`` and then, per module:

        * drops root models that merely duplicate a model in the same module,
        * renames classes that clash with imported names,
        * adds imports for models referenced from other modules,
        * optionally de-duplicates identical models (``reuse_model``),
        * optionally replaces field defaults by enum members
          (``set_default_enum_member``),
        * renders (and optionally formats) the module body.

        :param with_import: prepend import lines to every rendered module.
        :param format_: run rendered code through ``CodeFormatter``.
        :param settings_path: forwarded to ``CodeFormatter``.
        :return: the body string when the only output is the root
            ``__init__.py``; otherwise a mapping from module path tuple
            (e.g. ``('pkg', 'mod.py')``) to its ``Result``.
        """

        self.parse_raw()

        if with_import:
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: Optional[CodeFormatter] = CodeFormatter(
                self.target_python_version, settings_path)
        else:
            code_formatter = None

        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], Result] = {}

        module_key = lambda x: x.module_path

        # process in reverse order to correctly establish module levels
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=module_key, reverse=True),
            key=module_key,
        )

        module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []

        for module, models in (
            (k, [*v]) for k, v in grouped_models
        ):  # type: Tuple[str, ...], List[DataModel]

            # Iterate over a copy: models.remove() below would otherwise make
            # the loop skip the model that follows each removed one.
            for model in models[:]:
                if isinstance(model, self.data_model_root_type):
                    root_data_type = model.fields[0].data_type

                    # backward compatible
                    # Remove duplicated root model
                    if (root_data_type.reference and not root_data_type.is_dict
                            and not root_data_type.is_list
                            and root_data_type.reference.source in models
                            and root_data_type.reference.name
                            == self.model_resolver.get_class_name(
                                model.reference.original_name, unique=False)):
                        # Replace referenced duplicate model to original model
                        for child in model.reference.children[:]:
                            child.replace_reference(root_data_type.reference)
                        models.remove(model)
                        continue

                    #  Custom root model can't be inherited on restriction of Pydantic
                    for child in model.reference.children:
                        # inheritance model
                        if isinstance(child, DataModel):
                            # copy: base_classes is mutated inside the loop
                            for base_class in child.base_classes[:]:
                                if base_class.reference == model.reference:
                                    child.base_classes.remove(base_class)

            module_models.append((
                module,
                models,
            ))

            # Names imported by this module's models are reserved so that a
            # generated class never shadows them.
            scoped_model_resolver = ModelResolver(
                exclude_names={
                    i.alias or i.import_
                    for m in models for i in m.imports
                },
                duplicate_name_suffix='Model',
            )

            for model in models:
                class_name: str = model.class_name
                generated_name: str = scoped_model_resolver.add(
                    model.path, class_name, unique=True, class_name=True).name
                if class_name != generated_name:
                    # Keep any dotted module prefix; replace only the last part.
                    if '.' in model.reference.name:
                        model.reference.name = (
                            f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                        )
                    else:
                        model.reference.name = generated_name

        for module, models in module_models:
            # init: this module is emitted as a package __init__.py, which
            # shifts relative imports one level (see the from_ += "." below).
            init = False
            if module:
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = Result(body='')
                if (*module, '__init__.py') in results:
                    module = (*module, '__init__.py')
                    init = True
                else:
                    module = (*module[:-1], f'{module[-1]}.py')
            else:
                module = ('__init__.py', )

            result: List[str] = []
            imports = Imports()
            scoped_model_resolver = ModelResolver()

            for model in models:
                imports.append(model.imports)
                for data_type in model.all_data_types:
                    # To change from/import

                    if not data_type.reference or data_type.reference.source in models:
                        # No need to import non-reference model.
                        # Or, Referenced model is in the same file. we don't need to import the model
                        continue

                    if isinstance(data_type, BaseClassDataType):
                        from_ = ''.join(
                            relative(model.module_name, data_type.full_name))
                        import_ = data_type.reference.short_name
                        full_path = from_, import_
                    else:
                        from_, import_ = full_path = relative(
                            model.module_name, data_type.full_name)

                    alias = scoped_model_resolver.add(full_path, import_).name

                    name = data_type.reference.short_name
                    if from_ and import_ and alias != name:
                        data_type.alias = f'{alias}.{name}'

                    if init:
                        from_ += "."
                    imports.append(
                        Import(from_=from_, import_=import_, alias=alias))

            if self.reuse_model:
                # Models with identical base classes, template data and fields
                # are collapsed onto the first one seen.
                model_cache: Dict[Tuple[str, ...], Reference] = {}
                duplicates = []
                for model in models:
                    model_key = tuple(
                        to_hashable(v) for v in (
                            model.base_classes,
                            model.extra_template_data,
                            model.fields,
                        ))
                    cached_model_reference = model_cache.get(model_key)
                    if cached_model_reference:
                        if isinstance(model, Enum):
                            for child in model.reference.children[:]:
                                # child is resolved data_type by reference
                                data_model = get_most_of_parent(child)
                                # TODO: replace reference in all modules
                                if data_model in models:  # pragma: no cover
                                    child.replace_reference(
                                        cached_model_reference)
                            duplicates.append(model)
                        else:
                            # Swap the duplicate for an empty subclass of the
                            # cached model. Insert + remove keeps the list
                            # length and the position of the next element.
                            index = models.index(model)
                            inherited_model = model.__class__(
                                fields=[],
                                base_classes=[cached_model_reference],
                                description=model.description,
                                reference=Reference(
                                    name=model.name,
                                    path=model.reference.path + '/reuse',
                                ),
                            )
                            if (cached_model_reference.path
                                    in require_update_action_models):
                                require_update_action_models.append(
                                    inherited_model.path)
                            models.insert(index, inherited_model)
                            models.remove(model)

                    else:
                        model_cache[model_key] = model.reference

                # Enum duplicates are removed only after the loop finishes.
                for duplicate in duplicates:
                    models.remove(duplicate)

            if self.set_default_enum_member:
                for model in models:
                    for model_field in model.fields:
                        if not model_field.default:
                            continue
                        for data_type in model_field.data_type.all_data_types:
                            if data_type.reference and isinstance(
                                    data_type.reference.source,
                                    Enum):  # pragma: no cover
                                enum_member = data_type.reference.source.find_member(
                                    model_field.default)
                                if enum_member:
                                    model_field.default = enum_member
            if with_import:
                result += [str(self.imports), str(imports), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(
                        m.reference.short_name for m in models
                        if m.path in require_update_action_models),
                ]

            body = '\n'.join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            # NOTE(review): raises IndexError if every model of this module
            # was removed above (duplicate root / enum removal).
            results[module] = Result(body=body, source=models[0].file_path)

        # retain existing behaviour
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )].body

        return results
Esempio n. 6
0
    def __init__(
        self,
        source: Union[str, Path, List[Path], ParseResult],
        *,
        data_model_type: Type[DataModel] = pydantic_model.BaseModel,
        data_model_root_type: Type[DataModel] = pydantic_model.CustomRootType,
        data_type_manager_type: Type[DataTypeManager] = pydantic_model.
        DataTypeManager,
        data_model_field_type: Type[DataModelFieldBase] = pydantic_model.
        DataModelField,
        base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
        target_python_version: PythonVersion = PythonVersion.PY_37,
        dump_resolve_reference_action: Optional[Callable[[Iterable[str]],
                                                         str]] = None,
        validation: bool = False,
        field_constraints: bool = False,
        snake_case_field: bool = False,
        strip_default_none: bool = False,
        aliases: Optional[Mapping[str, str]] = None,
        allow_population_by_field_name: bool = False,
        apply_default_values_for_required_fields: bool = False,
        force_optional_for_required_fields: bool = False,
        class_name: Optional[str] = None,
        use_standard_collections: bool = False,
        base_path: Optional[Path] = None,
        use_schema_description: bool = False,
        reuse_model: bool = False,
        encoding: str = 'utf-8',
        enum_field_as_literal: Optional[LiteralType] = None,
        set_default_enum_member: bool = False,
        strict_nullable: bool = False,
        use_generic_container_types: bool = False,
        enable_faux_immutability: bool = False,
        remote_text_cache: Optional[DefaultPutDict[str, str]] = None,
        disable_appending_item_suffix: bool = False,
        strict_types: Optional[Sequence[StrictTypes]] = None,
        empty_enum_field_name: Optional[str] = None,
        custom_class_name_generator: Optional[Callable[
            [str], str]] = title_to_class_name,
        field_extra_keys: Optional[Set[str]] = None,
        field_include_all_keys: bool = False,
    ):
        # --- type machinery (built first, from the positional arguments) ---
        self.data_type_manager: DataTypeManager = data_type_manager_type(
            target_python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
        )
        self.data_model_type: Type[DataModel] = data_model_type
        self.data_model_root_type: Type[DataModel] = data_model_root_type
        self.data_model_field_type: Type[DataModelFieldBase] = (
            data_model_field_type)

        # --- parse state ---
        self.imports: Imports = Imports()
        self.results: List[DataModel] = []
        self.current_source_path: Optional[Path] = None

        # --- generation options, stored verbatim ---
        self.base_class: Optional[str] = base_class
        self.target_python_version: PythonVersion = target_python_version
        self.dump_resolve_reference_action: Optional[Callable[
            [Iterable[str]], str]] = dump_resolve_reference_action
        self.validation: bool = validation
        self.field_constraints: bool = field_constraints
        self.snake_case_field: bool = snake_case_field
        self.strip_default_none: bool = strip_default_none
        self.apply_default_values_for_required_fields: bool = (
            apply_default_values_for_required_fields)
        self.force_optional_for_required_fields: bool = (
            force_optional_for_required_fields)
        self.use_schema_description: bool = use_schema_description
        self.reuse_model: bool = reuse_model
        self.encoding: str = encoding
        self.enum_field_as_literal: Optional[LiteralType] = (
            enum_field_as_literal)
        self.set_default_enum_member: bool = set_default_enum_member
        self.strict_nullable: bool = strict_nullable
        self.use_generic_container_types: bool = use_generic_container_types
        self.enable_faux_immutability: bool = enable_faux_immutability
        self.custom_class_name_generator: Optional[Callable[
            [str], str]] = custom_class_name_generator
        self.field_include_all_keys: bool = field_include_all_keys

        # --- options with fallbacks when falsy ---
        self.field_extra_keys: Set[str] = field_extra_keys or set()
        self.remote_text_cache: DefaultPutDict[str, str] = (
            remote_text_cache or DefaultPutDict())

        # Base path: explicit value, else the source directory (or the
        # directory containing the source file), else the working directory.
        if base_path:
            resolved_base_path = base_path
        elif isinstance(source, Path):
            absolute_source = source.absolute()
            resolved_base_path = (absolute_source if source.is_dir() else
                                  absolute_source.parent)
        else:
            resolved_base_path = Path.cwd()
        self.base_path = resolved_base_path

        self.source: Union[str, Path, List[Path], ParseResult] = source
        self.custom_template_dir = custom_template_dir

        # Template data shared by every model (ALL_MODEL key).
        template_data: DefaultDict[str, Any] = (extra_template_data
                                                or defaultdict(dict))
        if allow_population_by_field_name:
            template_data[ALL_MODEL]['allow_population_by_field_name'] = True
        if enable_faux_immutability:
            template_data[ALL_MODEL]['allow_mutation'] = False
        self.extra_template_data: DefaultDict[str, Any] = template_data

        # A ParseResult source is a URL; its string form becomes the base URL.
        resolver_base_url = (source.geturl()
                             if isinstance(source, ParseResult) else None)
        singular_suffix = '' if disable_appending_item_suffix else None
        self.model_resolver = ModelResolver(
            base_url=resolver_base_url,
            singular_name_suffix=singular_suffix,
            aliases=aliases,
            empty_field_name=empty_enum_field_name,
            snake_case_field=snake_case_field,
            custom_class_name_generator=custom_class_name_generator,
            base_path=self.base_path,
        )
        self.class_name: Optional[str] = class_name
Esempio n. 7
0
 def __init__(
     self,
     source: Union[str, pathlib.Path, List[pathlib.Path], ParseResult],
     *,
     data_model_type: Type[DataModel] = pydantic_model.BaseModel,
     data_model_root_type: Type[DataModel] = pydantic_model.CustomRootType,
     data_type_manager_type: Type[DataTypeManager] = pydantic_model.
     DataTypeManager,
     data_model_field_type: Type[DataModelFieldBase] = pydantic_model.
     DataModelField,
     base_class: Optional[str] = None,
     custom_template_dir: Optional[pathlib.Path] = None,
     extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
     target_python_version: PythonVersion = PythonVersion.PY_37,
     dump_resolve_reference_action: Optional[Callable[[Iterable[str]],
                                                      str]] = None,
     validation: bool = False,
     field_constraints: bool = False,
     snake_case_field: bool = False,
     strip_default_none: bool = False,
     aliases: Optional[Mapping[str, str]] = None,
     allow_population_by_field_name: bool = False,
     apply_default_values_for_required_fields: bool = False,
     force_optional_for_required_fields: bool = False,
     class_name: Optional[str] = None,
     use_standard_collections: bool = False,
     base_path: Optional[pathlib.Path] = None,
     use_schema_description: bool = False,
     reuse_model: bool = False,
     encoding: str = 'utf-8',
     enum_field_as_literal: Optional[LiteralType] = None,
     set_default_enum_member: bool = False,
     strict_nullable: bool = False,
     use_generic_container_types: bool = False,
     enable_faux_immutability: bool = False,
     remote_text_cache: Optional[DefaultPutDict[str, str]] = None,
     disable_appending_item_suffix: bool = False,
     strict_types: Optional[Sequence[StrictTypes]] = None,
     empty_enum_field_name: Optional[str] = None,
     custom_class_name_generator: Optional[Callable[[str], str]] = None,
     field_extra_keys: Optional[Set[str]] = None,
     field_include_all_keys: bool = False,
 ):
     # Every option is forwarded to the base parser unchanged; this parser
     # only pins the OpenAPI scopes it processes (schemas and paths).
     super().__init__(
         source=source,
         openapi_scopes=[OpenAPIScope.Schemas, OpenAPIScope.Paths],
         # model / type classes
         data_model_type=data_model_type,
         data_model_root_type=data_model_root_type,
         data_type_manager_type=data_type_manager_type,
         data_model_field_type=data_model_field_type,
         base_class=base_class,
         # templates and rendering
         custom_template_dir=custom_template_dir,
         extra_template_data=extra_template_data,
         target_python_version=target_python_version,
         dump_resolve_reference_action=dump_resolve_reference_action,
         class_name=class_name,
         custom_class_name_generator=custom_class_name_generator,
         # field handling
         validation=validation,
         field_constraints=field_constraints,
         snake_case_field=snake_case_field,
         strip_default_none=strip_default_none,
         aliases=aliases,
         allow_population_by_field_name=allow_population_by_field_name,
         apply_default_values_for_required_fields=(
             apply_default_values_for_required_fields),
         force_optional_for_required_fields=(
             force_optional_for_required_fields),
         field_extra_keys=field_extra_keys,
         field_include_all_keys=field_include_all_keys,
         empty_enum_field_name=empty_enum_field_name,
         # typing options
         use_standard_collections=use_standard_collections,
         use_generic_container_types=use_generic_container_types,
         strict_types=strict_types,
         strict_nullable=strict_nullable,
         enum_field_as_literal=enum_field_as_literal,
         set_default_enum_member=set_default_enum_member,
         enable_faux_immutability=enable_faux_immutability,
         # model generation
         use_schema_description=use_schema_description,
         reuse_model=reuse_model,
         disable_appending_item_suffix=disable_appending_item_suffix,
         # source loading
         base_path=base_path,
         encoding=encoding,
         remote_text_cache=remote_text_cache,
     )
     # State specific to this parser (operations and FastAPI imports).
     self.data_types: List[DataType] = []
     self.operations: Dict[str, Operation] = {}
     self._temporary_operation: Dict[str, Any] = {}
     self.imports_for_fastapi: Imports = Imports()
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True
    ) -> Union[str, Dict[Tuple[str, ...], str]]:
        """Parse ``components/schemas`` and render one module per name prefix.

        Each schema is dispatched by shape (object / array / enum / allOf /
        anything else as a root type). The resulting models are grouped by
        the dotted prefix of their name (``a.b.Model`` -> module ``a/b.py``).

        :param with_import: prepend import lines to every rendered module.
        :param format_: run rendered code through ``format_code``.
        :return: the body string when the only output is the root
            ``__init__.py``; otherwise a mapping from module path tuple to
            module body (parent ``__init__.py`` entries default to ``''``).
        """
        for obj_name, raw_obj in self.base_parser.specification['components'][
                'schemas'].items():  # type: str, Dict
            obj = JsonSchemaObject.parse_obj(raw_obj)
            if obj.is_object:
                self.parse_object(obj_name, obj)
            elif obj.is_array:
                self.parse_array(obj_name, obj)
            elif obj.enum:
                self.parse_enum(obj_name, obj)
            elif obj.allOf:
                self.parse_all_of(obj_name, obj)
            else:
                self.parse_root_type(obj_name, obj)

        if with_import:
            # The postponed-annotations import is valid from Python 3.7 on,
            # so add it for every target except 3.6 (previously only 3.7,
            # which left out newer targets).
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], str] = {}

        # Module of a model = its dotted name without the final component.
        module_key = lambda x: (*x.name.split('.')[:-1], )

        grouped_models = groupby(sorted(sorted_data_models.values(),
                                        key=module_key),
                                 key=module_key)
        for module, models in ((k, [*v]) for k, v in grouped_models):
            module_path = '.'.join(module)

            result: List[str] = []
            imports = Imports()
            models_to_update: List[str] = []

            for model in models:
                if model.name in require_update_action_models:
                    models_to_update += [model.name]
                imports.append(model.imports)
                for ref_name in model.reference_classes:
                    if '.' not in ref_name:
                        continue
                    ref_path = ref_name.rsplit('.', 1)[0]
                    # References within the same module need no import.
                    if ref_path == module_path:
                        continue
                    imports.append(Import(from_='.', import_=ref_path))

            if with_import:
                # Parser-wide imports go first: they may hold the __future__
                # import, which is a SyntaxError anywhere but the top of a
                # module, and formatting (which could reorder) is optional.
                result += [self.imports.dump(), imports.dump(), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(models_to_update)
                ]

            body = '\n'.join(result)
            if format_:
                body = format_code(body, self.target_python_version)

            if module:
                module = (*module[:-1], f'{module[-1]}.py')
                # Make sure the enclosing package has an __init__.py entry.
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = ''
            else:
                module = ('__init__.py', )

            results[module] = body

        # retain existing behaviour
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )]

        return results
Esempio n. 9
0
def generate_code(
    input_name: str,
    input_text: str,
    output_dir: Path,
    template_dir: Optional[Path],
    model_path: Optional[Path] = None,
    enum_field_as_literal: Optional[str] = None,
) -> None:
    """Generate FastAPI application code and models from an OpenAPI document.

    The models are parsed with ``OpenAPIParser`` and written to
    ``output_dir / model_path``. Every template file under *template_dir* is
    rendered with the sorted operations, collected imports and
    ``parser.parse_info()``, formatted, and written below *output_dir* with a
    ``.py`` suffix. Each written file starts with a generated-by header.

    :param input_name: name of the source document, echoed into the headers.
    :param input_text: the OpenAPI document text.
    :param output_dir: destination directory (created if missing).
    :param template_dir: Jinja2 template directory; ``BUILTIN_TEMPLATE_DIR``
        when not given.
    :param model_path: model module path relative to *output_dir*;
        ``MODEL_PATH`` when not given.
    :param enum_field_as_literal: forwarded to ``OpenAPIParser`` when truthy.
    :raises Exception: when the parser returns modular (multi-file) models.
    """
    if not model_path:
        model_path = MODEL_PATH
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if not template_dir:
        template_dir = BUILTIN_TEMPLATE_DIR
    if enum_field_as_literal:
        parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal)
    else:
        parser = OpenAPIParser(input_text)
    with chdir(output_dir):
        models = parser.parse()
    if not models:
        return
    elif isinstance(models, str):
        output = output_dir / model_path
        modules = {output: (models, input_name)}
    else:
        raise Exception('Modular references are not supported in this version')

    environment: Environment = Environment(
        loader=FileSystemLoader(
            template_dir if template_dir else f"{Path(__file__).parent}/template",
            encoding="utf8",
        ),
    )
    imports = Imports()
    imports.update(parser.imports)
    for data_type in parser.data_types:
        reference = _get_most_of_reference(data_type)
        if reference:
            imports.append(data_type.all_imports)
            imports.append(
                Import.from_full_path(f'.{model_path.stem}.{reference.name}')
            )
    for from_, imports_ in parser.imports_for_fastapi.items():
        imports[from_].update(imports_)
    results: Dict[Path, str] = {}
    code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
    sorted_operations: List[Operation] = sorted(
        parser.operations.values(), key=lambda m: m.path
    )
    for target in template_dir.rglob("*"):
        # rglob also yields directories, which are not loadable templates.
        if not target.is_file():
            continue
        relative_path = target.relative_to(template_dir)
        # Jinja2 template names are always '/'-separated, also on Windows.
        result = environment.get_template(relative_path.as_posix()).render(
            operations=sorted_operations, imports=imports, info=parser.parse_info(),
        )
        results[relative_path] = code_formatter.format_code(result)

    timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
    header = f"""\
# generated by fastapi-codegen:
#   filename:  {Path(input_name).name}
#   timestamp: {timestamp}"""

    for path, code in results.items():
        output_path = output_dir.joinpath(path.with_suffix(".py"))
        # Templates in sub-directories produce files in matching ones.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding, matching the model file below (the default is
        # locale-dependent).
        with output_path.open("wt", encoding="utf8") as file:
            print(header, file=file)
            print("", file=file)
            print(code.rstrip(), file=file)

    # ``{{filename}}`` survives the f-string as ``{filename}`` and is filled
    # per module by ``str.format`` below. The timestamp is always included.
    model_header = f'''\
# generated by fastapi-codegen:
#   filename:  {{filename}}'''
    model_header += f'\n#   timestamp: {timestamp}'

    for path, (body, filename) in modules.items():
        # A None path means "print to stdout" (print(file=None)).
        file = None
        if path is not None:
            path.parent.mkdir(parents=True, exist_ok=True)
            file = path.open('wt', encoding='utf8')
        try:
            print(model_header.format(filename=filename), file=file)
            if body:
                print('', file=file)
                print(body.rstrip(), file=file)
        finally:
            # Close the file even if writing fails.
            if file is not None:
                file.close()
Esempio n. 10
0
 def dump_imports(self) -> str:
     """Render ``self.imports`` as import source text.

     The imports are appended to a freshly created ``Imports`` collection,
     and that collection's ``dump()`` output is returned.
     """
     merged = Imports()
     merged.append(self.imports)
     rendered = merged.dump()
     return rendered