def get_data_type(self, schema: JsonSchemaObject, suffix: str = '') -> DataType:
        """Resolve the ``DataType`` used for *schema* in an operation signature.

        ``$ref`` schemas and inline object schemas both yield a model that
        lives in the models module, so an ``Import`` for it is recorded on
        ``self.imports``.  Array schemas are resolved item by item and wrapped
        as a list type; every other schema is delegated to the model parser.

        :param schema: the schema to convert.
        :param suffix: extra name component (e.g. ``'request'``) used when an
            inline object schema must be turned into a named model.
        :returns: the resolved data type.
        """
        model_parser = self.openapi_model_parser

        if schema.ref:
            resolved = model_parser.get_ref_data_type(schema.ref)
            # TODO: Improve import statements
            self.imports.append(
                Import(from_=model_module_name_var.get(), import_=resolved.type_hint)
            )
            return resolved

        if schema.is_array:
            # TODO: Improve handling array
            if isinstance(schema.items, list):
                item_schemas = schema.items
            else:
                item_schemas = [schema.items]
            item_types = [self.get_data_type(item, suffix) for item in item_schemas]
            return model_parser.data_type(data_types=item_types, is_list=True)

        if not schema.is_object:
            return model_parser.get_data_type(schema)

        # Inline object: name the model <camelCasedPath><Method><Suffix>.
        title_suffix = suffix.capitalize()
        model_name: str = ''.join(
            (
                stringcase.camelcase(self.path[1:].replace("/", "_")),
                self.type.capitalize(),
                title_suffix,
            )
        )
        object_type = model_parser.parse_object(
            model_name, schema, ['paths', self.path, self.type, title_suffix]
        )
        self.imports.append(
            Import(from_=model_module_name_var.get(), import_=object_type.type_hint)
        )
        return object_type
    def get_parameter_type(
        self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
    ) -> Argument:
        """Build the ``Argument`` for one OpenAPI parameter object.

        :param parameter: the raw parameter object (``name``, ``in``,
            ``schema`` and optionally ``required``).
        :param snake_case: when true the Python argument name is converted to
            snake_case and the original name is kept as an alias.
        :returns: the argument to render in the generated signature.
        """
        schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
        format_ = schema.format or "default"
        type_ = json_schema_data_formats[schema.type][format_]
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)

        field = DataModelField(
            name=name,
            data_type=type_map[type_],
            # Path parameters are always required.
            required=parameter.get("required") or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            # The alias must be declared with the FastAPI helper matching the
            # parameter location: `Query(...)` would make FastAPI read a
            # renamed path/header/cookie parameter from the query string.
            # Unknown or missing locations keep the previous `Query` behavior.
            param_is = {
                'path': 'Path',
                'header': 'Header',
                'cookie': 'Cookie',
            }.get(str(parameter.get("in", "")).lower(), 'Query')
            default: Optional[str] = (
                f"{param_is}({'...' if field.required else repr(schema.default)}, "
                f"alias='{orig_name}')"
            )
            self.imports.append(Import(from_='fastapi', import_=param_is))
        else:
            default = repr(schema.default) if 'default' in parameter["schema"] else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
    def __init__(
        self,
        *,
        reference: Reference,
        fields: List[DataModelFieldBase],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[Reference]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        path: Optional[Path] = None,
        description: Optional[str] = None,
    ):
        """Initialise the dataclass model and register its decorator import.

        Every argument is forwarded by keyword to the parent initialiser;
        afterwards ``pydantic.dataclasses.dataclass`` is added to the model's
        additional imports.
        """
        forwarded = dict(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            path=path,
            description=description,
        )
        super().__init__(**forwarded)
        dataclass_decorator = Import.from_full_path('pydantic.dataclasses.dataclass')
        self._additional_imports.append(dataclass_decorator)
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
    ):
        """Initialise the dataclass model and import its decorator.

        ``name``, ``fields``, ``decorators`` and ``base_classes`` are passed
        positionally, the remaining options by keyword; then
        ``pydantic.dataclasses.dataclass`` is appended to ``self.imports``.
        """
        options = {
            'custom_base_class': custom_base_class,
            'custom_template_dir': custom_template_dir,
            'extra_template_data': extra_template_data,
            'auto_import': auto_import,
            'reference_classes': reference_classes,
        }
        super().__init__(name, fields, decorators, base_classes, **options)
        dataclass_decorator = Import.from_full_path('pydantic.dataclasses.dataclass')
        self.imports.append(dataclass_decorator)
 def request(self) -> Optional[Argument]:
     """Return the argument describing the operation's request body, if any.

     JSON media types become a ``body`` argument typed with the resolved
     schema.  ``application/x-www-form-urlencoded`` becomes a raw starlette
     ``Request`` argument.  Every content type of every request object is
     processed, so all their imports are registered, but only the first
     collected argument is returned.
     """
     collected: List[Argument] = []
     for request_object in self.request_objects:
         for media_type, body_schema in request_object.contents.items():
             # TODO: support other content-types
             if RE_APPLICATION_JSON_PATTERN.match(media_type):
                 body_type = self.get_data_type(body_schema, 'request')
                 # TODO: support multiple body
                 collected.append(
                     Argument(
                         name='body',  # type: ignore
                         type_hint=body_type.type_hint,
                         required=request_object.required,
                     ))
                 self.imports.extend(body_type.imports)
                 continue
             if media_type == 'application/x-www-form-urlencoded':
                 # TODO: support form with `Form()`
                 collected.append(
                     Argument(
                         name='request',  # type: ignore
                         type_hint='Request',  # type: ignore
                         required=True,
                     ))
                 self.imports.append(
                     Import.from_full_path('starlette.requests.Request'))
     return collected[0] if collected else None
Exemple #6
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
        imports: Optional[List[Import]] = None,
    ):
        """Initialise the model; allow extra fields when requested.

        All arguments are forwarded by keyword to the parent initialiser.  If
        the model's extra template data contains ``additionalProperties``, a
        ``Config(extra='Extra.allow')`` is stored under ``'config'`` and
        ``pydantic.Extra`` is imported.
        """
        base_kwargs = dict(
            name=name,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            auto_import=auto_import,
            reference_classes=reference_classes,
            imports=imports,
        )
        super().__init__(**base_kwargs)

        if 'additionalProperties' not in self.extra_template_data:
            return
        self.extra_template_data['config'] = Config(extra='Extra.allow')
        self.imports.append(Import(from_='pydantic', import_='Extra'))
 def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
     """Convert each ``anyOf`` member of *obj* into a ``DataType``.

     Members are classified in order:

     * a ``$ref`` becomes a reference to the named model;
     * a schema with nothing set besides ``type`` goes through
       ``get_data_type``;
     * an array whose single ``items`` schema sets only ``type`` becomes
       ``List[<item type>]``;
     * anything else is parsed as an object model named
       ``get_singular_name(name)`` and referenced by that name.
     """

     def sets_only_type(candidate) -> bool:
         # No attribute other than ``type`` carries a truthy value.
         return not any(
             value for key, value in vars(candidate).items() if key != 'type')

     converted: List[DataType] = []
     for member in obj.anyOf:
         if member.ref:  # $ref
             converted.append(
                 self.data_type(type=member.ref_object_name,
                                ref=True,
                                version_compatible=True))
             continue
         if sets_only_type(member):
             # trivial types
             converted.append(self.get_data_type(member))
             continue
         if (member.is_array and isinstance(member.items, JsonSchemaObject)
                 and sets_only_type(member.items)):
             # trivial item types
             item_hint = self.get_data_type(member.items).type_hint
             converted.append(
                 self.data_type(
                     type=f"List[{item_hint}]",
                     imports_=[Import(from_='typing', import_='List')],
                 ))
             continue
         model_name = get_singular_name(name)
         self.parse_object(model_name, member)
         converted.append(
             self.data_type(type=model_name, ref=True,
                            version_compatible=True))
     return converted
    def get_parameter_type(
        self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
    ) -> Argument:
        """Build the ``Argument`` for one OpenAPI parameter object.

        :param parameter: the raw parameter object, or a ``$ref`` to one (the
            reference is replaced by the body ``get_ref_body`` returns).
        :param snake_case: when true the Python argument name is converted to
            snake_case and the original name is kept as an alias.
        :returns: the argument to render in the generated signature.
        """
        ref: Optional[str] = parameter.get('$ref')  # type: ignore
        if ref:
            parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)
        schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])

        field = DataModelField(
            name=name,
            data_type=self.get_data_type(schema),
            # Path parameters are always required.
            required=parameter.get("required") or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            # The alias must be declared with the FastAPI helper matching the
            # parameter location: `Query(...)` would make FastAPI read a
            # renamed path/header/cookie parameter from the query string.
            # Unknown or missing locations keep the previous `Query` behavior.
            param_is = {
                'path': 'Path',
                'header': 'Header',
                'cookie': 'Cookie',
            }.get(str(parameter.get("in", "")).lower(), 'Query')
            default: Optional[str] = (
                f"{param_is}({'...' if field.required else repr(schema.default)}, "
                f"alias='{orig_name}')"
            )
            self.imports.append(Import(from_='fastapi', import_=param_is))
        else:
            default = repr(schema.default) if 'default' in parameter["schema"] else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
Exemple #9
0
    def __init__(
        self,
        *,
        reference: Reference,
        fields: List[DataModelFieldBase],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[Reference]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
        methods: Optional[List[str]] = None,
        path: Optional[Path] = None,
        description: Optional[str] = None,
    ) -> None:
        """Initialise a data model bound to *reference* and resolve its template.

        :param reference: reference for this model; its ``source`` is pointed
            back at the new model.
        :param fields: the model's fields; each field's ``parent`` is set to
            this model.
        :param decorators: stored as ``self.decorators`` (``[]`` if None).
        :param base_classes: explicit base classes; falsy entries are dropped.
        :param custom_base_class: dotted path of the base class imported when
            no explicit base classes are given (takes precedence over
            ``BASE_CLASS``).
        :param custom_template_dir: directory searched for a file with the same
            name as ``TEMPLATE_FILE_PATH``; used instead when it exists.
        :param extra_template_data: extra template data keyed by model name,
            with optional data for every model under ``ALL_MODEL``.
        :param methods: stored as ``self.methods`` (``[]`` if None).
        :param path: stored as ``self.file_path``.
        :param description: stored as ``self.description``.
        :raises Exception: if the class does not define ``TEMPLATE_FILE_PATH``.
        """
        if not self.TEMPLATE_FILE_PATH:
            raise Exception('TEMPLATE_FILE_PATH is undefined')

        # A template with the same file name in the custom directory overrides
        # the built-in one.
        template_file_path = Path(self.TEMPLATE_FILE_PATH)
        if custom_template_dir is not None:
            custom_template_file_path = custom_template_dir / template_file_path.name
            if custom_template_file_path.exists():
                template_file_path = custom_template_file_path

        self.fields: List[DataModelFieldBase] = fields or []
        self.decorators: List[str] = decorators or []
        self._additional_imports: List[Import] = []
        self.base_classes: List[Reference] = [
            base_class for base_class in base_classes or [] if base_class
        ]
        self.custom_base_class = custom_base_class
        self.file_path: Optional[Path] = path
        self.reference: Reference = reference

        # Back-link so the reference knows which model it resolves to.
        self.reference.source = self

        # `self.name` is not assigned in this method; presumably it is derived
        # from `self.reference` (set just above), so this must stay after
        # that assignment -- TODO confirm.
        # The per-model dict is the same object stored in the caller's mapping
        # (and, with a defaultdict, subscripting inserts a missing entry), so
        # later updates to it are visible through `extra_template_data`.
        self.extra_template_data = (
            extra_template_data[self.name]
            if extra_template_data is not None
            else defaultdict(dict)
        )

        # Only models without explicit base classes import a base class.
        if not self.base_classes:
            base_class_full_path = custom_base_class or self.BASE_CLASS
            if base_class_full_path:
                self._additional_imports.append(
                    Import.from_full_path(base_class_full_path)
                )

        # NOTE(review): `update` lets ALL_MODEL values overwrite same-named
        # model-specific keys; confirm that shared data is meant to win.
        if extra_template_data:
            all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
            if all_model_extra_template_data:
                self.extra_template_data.update(all_model_extra_template_data)

        self.methods: List[str] = methods or []

        self.description = description
        for field in self.fields:
            field.parent = self

        super().__init__(template_file_path=template_file_path)
Exemple #10
0
def test_dump(inputs: Sequence[Tuple[Optional[str], str]], value):
    """Rendering an ``Imports`` built from the (from_, import_) pairs yields *value*."""
    imports = Imports()
    pending = []
    for source, target in inputs:
        pending.append(Import(from_=source, import_=target))
    imports.append(pending)

    assert str(imports) == value
Exemple #11
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
        imports: Optional[List[Import]] = None,
    ):
        """Initialise the pydantic model, deriving template data from its fields.

        Steps, in order:

        * truthy ``method`` values of the fields are passed to the parent as
          ``methods``;
        * ``additionalProperties`` in the model's extra template data produces
          ``Config(extra='Extra.allow')`` under ``'config'`` and imports
          ``pydantic.Extra``;
        * ``pydantic.Field`` is imported once if any field has a truthy
          ``field`` attribute.
        """
        methods: List[str] = [field.method for field in fields if field.method]

        super().__init__(
            name=name,
            fields=fields,  # type: ignore
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            auto_import=auto_import,
            reference_classes=reference_classes,
            imports=imports,
            methods=methods,
        )

        config_parameters: Dict[str, Any] = {}

        if 'additionalProperties' in self.extra_template_data:
            config_parameters['extra'] = 'Extra.allow'
            self.imports.append(Import(from_='pydantic', import_='Extra'))

        if config_parameters:
            # Function-local import, presumably to avoid an import cycle with
            # the pydantic model package -- TODO confirm.
            from datamodel_code_generator.model.pydantic import Config

            self.extra_template_data['config'] = Config.parse_obj(
                config_parameters)

        # A single import covers every field that needs `Field`; appending it
        # per field only produced duplicate entries.
        if any(field.field for field in fields):
            self.imports.append(Import(from_='pydantic', import_='Field'))
Exemple #12
0
def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    """``OpenAPIParser.get_data_type`` maps (type, format) to the expected ``DataType``."""
    expected_imports: Optional[List[Import]] = (
        [Import(from_=from_, import_=import_)] if from_ and import_ else None
    )

    parser = OpenAPIParser(BaseModel, CustomRootType)
    actual = parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    )
    assert actual == DataType(type=result_type, imports_=expected_imports)
Exemple #13
0
def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    """``JsonSchemaParser.get_data_type`` maps (type, format) to the expected ``DataType``."""
    expected_imports: Optional[List[Import]] = (
        [Import(from_=from_, import_=import_)] if from_ and import_ else []
    )

    parser = JsonSchemaParser('')
    actual = parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    )
    assert actual == DataType(type=result_type, imports=expected_imports)
Exemple #14
0
    def get_parameter_type(
        self,
        parameters: ParameterObject,
        snake_case: bool,
        path: List[str],
    ) -> Optional[Argument]:
        """Build the ``Argument`` for one OpenAPI parameter object.

        The type comes from the first ``content`` entry when present; a
        ``$ref`` there is resolved to the referenced model.  Otherwise the
        parameter's ``schema`` is parsed.

        :param parameters: the parsed parameter object.
        :param snake_case: convert the argument name to snake_case and keep
            the original name as an alias.
        :param path: location of the parameter in the document, used to name
            any model generated for an inline schema.
        :returns: the argument, or ``None`` when the parameter has neither a
            ``content`` schema nor a ``schema``.
        :raises TypeError: if the name was renamed but the parameter has no
            ``in`` location to build the aliased default with.
        """
        orig_name = parameters.name
        name = stringcase.snakecase(orig_name) if snake_case else orig_name

        schema: Optional[JsonSchemaObject] = None
        data_type: Optional[DataType] = None
        for content in parameters.content.values():
            if isinstance(content.schema_, ReferenceObject):
                data_type = self.get_ref_data_type(content.schema_.ref)
                ref_model = self.get_ref_model(content.schema_.ref)
                schema = JsonSchemaObject.parse_obj(ref_model)
            else:
                schema = content.schema_
            # Only the first media type is considered.
            break
        if not data_type:
            if not schema:
                schema = parameters.schema_
            if not schema:
                # Nothing describes the type: bail out instead of handing
                # `None` to parse_schema.
                return None
            data_type = self.parse_schema(name, schema, [*path, name])
        if not schema:
            return None

        field = DataModelField(
            name=name,
            data_type=data_type,
            # Path parameters are always required.
            required=parameters.required
            or parameters.in_ == ParameterLocation.path,
        )

        if orig_name != name:
            if not parameters.in_:
                # Without a location there is no FastAPI helper to carry the
                # alias; fail clearly instead of leaving `default` unbound.
                raise TypeError(
                    f'Issue processing parameter for "in", expected a value, '
                    f'but got none: {parameters!r}'
                )
            param_is = parameters.in_.value.lower().capitalize()
            self.imports_for_fastapi.append(
                Import(from_='fastapi', import_=param_is))
            default: Optional[str] = (
                f"{param_is}({'...' if field.required else repr(schema.default)}, "
                f"alias='{orig_name}')"
            )
        else:
            default = repr(schema.default) if schema.has_default else None
        self.imports_for_fastapi.append(field.imports)
        self.data_types.append(field.data_type)
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
    def response(self) -> str:
        """Return the type hint for the operation's successful JSON response.

        Only 2xx responses with an ``application/json`` body are considered.
        Each contributes a model name, or ``List[...]`` of model names for an
        array schema, and registers the needed imports on ``self.imports``.

        :returns: ``"None"`` when no such response exists, the single hint
            when there is exactly one, and ``Union[...]`` of all hints
            otherwise.
        """
        models: List[str] = []
        for response in self.response_objects:
            # expect 2xx
            if not response.status_code.startswith("2"):
                continue
            for content_type, schema in response.contents.items():
                if content_type != "application/json":
                    continue
                if schema.is_array:
                    items = (
                        schema.items
                        if isinstance(schema.items, list)
                        else [schema.items]
                    )
                    item_names = [item.ref_object_name for item in items]
                    self.imports.extend(
                        Import(from_=model_path_var.get(), import_=item_name)
                        for item_name in item_names
                    )
                    if len(item_names) > 1:
                        # `List` takes exactly one type argument, so several
                        # item schemas must be combined with `Union`.
                        type_ = f'List[Union[{", ".join(item_names)}]]'
                        self.imports.append(Import(from_='typing', import_='Union'))
                    else:
                        type_ = f'List[{",".join(item_names)}]'
                    self.imports.append(IMPORT_LIST)
                else:
                    type_ = schema.ref_object_name
                    self.imports.append(
                        Import(
                            from_=model_path_var.get(),
                            import_=schema.ref_object_name,
                        )
                    )
                models.append(type_)

        if not models:
            return "None"
        if len(models) > 1:
            # The returned hint uses `Union`, so it has to be imported as well.
            self.imports.append(Import(from_='typing', import_='Union'))
            return f'Union[{",".join(models)}]'
        return models[0]
def test_get_data_type(schema_type, schema_format, result_type, from_,
                       import_):
    """``JsonSchemaParser.get_data_type`` yields the expected type and import."""
    expected_import: Optional[Import] = (
        Import(from_=from_, import_=import_) if from_ and import_ else None
    )

    parser = JsonSchemaParser('')
    actual = parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format))
    expected = DataType(type=result_type, import_=expected_import)
    assert actual.dict() == expected.dict()
    def get_parameter_type(self, parameter: Dict[str, Union[str, Dict[str,
                                                                      Any]]],
                           snake_case: bool) -> Argument:
        """Build the ``Argument`` for one OpenAPI parameter object.

        A ``$ref`` parameter is first replaced by the body ``get_ref_body``
        returns.  The type comes from the first ``content`` entry whose
        ``schema`` is a dict, falling back to the parameter's own ``schema``.

        :param parameter: the raw parameter object.
        :param snake_case: convert the argument name to snake_case and keep
            the original name as an alias.
        :returns: the argument to render in the generated signature.
        :raises TypeError: if the name was renamed but ``in`` is missing or
            not a string.
        """
        ref: Optional[str] = parameter.get('$ref')  # type: ignore
        if ref:
            parameter = get_ref_body(ref, self.openapi_model_parser,
                                     self.components)
        orig_name: str = parameter["name"]  # type: ignore
        name = stringcase.snakecase(orig_name) if snake_case else orig_name

        schema: Optional[JsonSchemaObject] = None
        content = parameter.get('content')
        if content and isinstance(content, dict):
            candidates = [
                entry.get("schema") for entry in content.values()
                if isinstance(entry.get("schema"), dict)
            ]
            if candidates:
                schema = JsonSchemaObject.parse_obj(candidates[0])
        if not schema:
            schema = JsonSchemaObject.parse_obj(parameter["schema"])

        field = DataModelField(
            name=name,
            data_type=self.get_data_type(schema, 'parameter'),
            required=parameter.get("required")
            or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)

        default: Optional[str]
        if orig_name == name:
            default = repr(schema.default) if schema.has_default else None
        else:
            location = parameter.get('in')
            if not (location and isinstance(location, str)):
                # https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#parameterObject
                # the spec says 'in' is a str type
                raise TypeError(
                    f'Issue processing parameter for "in", expected a str, but got something else: {str(parameter)}'
                )
            param_is = location.lower().capitalize()
            self.imports.append(Import(from_='fastapi', import_=param_is))
            default = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        imports: Optional[List[Import]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
    ) -> None:
        """Initialise a data model and collect its imports and referenced classes.

        :param name: class name of the model.
        :param fields: the model's fields.
        :param decorators: stored as ``self.decorators`` (``[]`` if None).
        :param base_classes: explicit base class names; falsy entries are
            dropped and the rest are joined into ``base_class``.
        :param custom_base_class: dotted path of the base class used when no
            explicit base classes are given (takes precedence over
            ``BASE_CLASS``).
        :param imports: initial import list; it is appended to in place.
        :param auto_import: when true, import the base class and every
            field's imports.
        :param reference_classes: extra class names this model refers to.
        :raises Exception: if the class does not define ``TEMPLATE_FILE_PATH``.
        """
        if not self.TEMPLATE_FILE_PATH:
            raise Exception('TEMPLATE_FILE_PATH is undefined')

        self.name: str = name
        self.fields: List[DataModelField] = fields or []
        self.decorators: List[str] = decorators or []
        self.imports: List[Import] = imports or []
        self.base_class: Optional[str] = None
        base_classes = [
            base_class for base_class in base_classes or [] if base_class
        ]
        self.base_classes: List[str] = base_classes

        # Explicit base classes other than the default one are references too.
        self.reference_classes: List[str] = [
            r for r in base_classes if r != self.BASE_CLASS
        ]
        if reference_classes:
            self.reference_classes.extend(reference_classes)

        if self.base_classes:
            self.base_class = ', '.join(self.base_classes)
        else:
            base_class_full_path = custom_base_class or self.BASE_CLASS
            if auto_import and base_class_full_path:
                self.imports.append(
                    Import.from_full_path(base_class_full_path))
            # Keep only the last component of the dotted path; guarded so a
            # subclass with `BASE_CLASS = None` does not crash on `None.split`.
            if base_class_full_path is not None:
                self.base_class = base_class_full_path.split('.')[-1]

        unresolved_types: Set[str] = set()
        for field in self.fields:
            unresolved_types.update(field.unresolved_types)

        # Sorted so the generated output does not depend on set iteration
        # order, which changes between runs for strings (hash randomization).
        self.reference_classes = sorted(
            set(self.reference_classes) | unresolved_types)

        if auto_import:
            for field in self.fields:
                self.imports.extend(field.imports)
        super().__init__(template_file_path=self.TEMPLATE_FILE_PATH)
 def get_data_type(self, schema: JsonSchemaObject) -> DataType:
     """Resolve the ``DataType`` for *schema* in an operation signature.

     For a ``$ref``, the referenced data type gets an ``Import`` of its type
     name from the models module appended to its own ``imports_``.  Arrays
     are resolved item by item and wrapped as a list type.  Anything else is
     delegated to the model parser.
     """
     model_parser = self.open_api_model_parser
     if schema.ref:
         ref_type = model_parser.get_ref_data_type(schema.ref)
         # TODO: Improve import statements
         ref_type.imports_.append(
             Import(from_=model_path_var.get(), import_=ref_type.type)
         )
         return ref_type
     if not schema.is_array:
         return model_parser.get_data_type(schema)
     # TODO: Improve handling array
     members = schema.items if isinstance(schema.items, list) else [schema.items]
     return model_parser.data_type(
         data_types=list(map(self.get_data_type, members)), is_list=True
     )
Exemple #20
0
 def parse_request_body(
     self,
     name: str,
     request_body: RequestBodyObject,
     path: List[str],
 ) -> None:
     """Record the request-body argument for the operation being parsed.

     After the parent implementation runs, each media type whose schema is a
     ``JsonSchemaObject`` or ``ReferenceObject`` is examined.  JSON media
     types yield a ``body`` argument whose data type is also appended to
     ``self.data_types``.  Form-urlencoded yields a raw starlette
     ``Request`` argument.  The first argument, or ``None``, is stored under
     ``self._temporary_operation['_request']``.
     """
     super().parse_request_body(name, request_body, path)
     arguments: List[Argument] = []
     for media_type, media_obj in request_body.content.items():  # type: str, MediaObject
         body_schema = media_obj.schema_
         if not isinstance(body_schema,
                           (JsonSchemaObject, ReferenceObject)):  # pragma: no cover
             continue
         # TODO: support other content-types
         if RE_APPLICATION_JSON_PATTERN.match(media_type):
             if isinstance(body_schema, ReferenceObject):
                 data_type = self.get_ref_data_type(body_schema.ref)
             else:
                 data_type = self.parse_schema(name, body_schema,
                                               [*path, media_type])
             # TODO: support multiple body
             arguments.append(
                 Argument(
                     name='body',  # type: ignore
                     type_hint=data_type.type_hint,
                     required=request_body.required,
                 ))
             self.data_types.append(data_type)
         elif media_type == 'application/x-www-form-urlencoded':
             # TODO: support form with `Form()`
             arguments.append(
                 Argument(
                     name='request',  # type: ignore
                     type_hint='Request',  # type: ignore
                     required=True,
                 ))
             self.imports_for_fastapi.append(
                 Import.from_full_path('starlette.requests.Request'))
     first_argument = arguments[0] if arguments else None
     self._temporary_operation['_request'] = first_argument
    def get_parameter_type(
        self, parameter: Dict[str, Union[str, Dict[str, Any]]], snake_case: bool
    ) -> Argument:
        """Build the ``Argument`` for one OpenAPI parameter object.

        A ``$ref`` parameter is first replaced by the body ``get_ref_body``
        returns.  The type comes from the first ``content`` entry whose
        ``schema`` is a dict, falling back to the parameter's own ``schema``.

        :param parameter: the raw parameter object.
        :param snake_case: when true the Python argument name is converted to
            snake_case and the original name is kept as an alias.
        :returns: the argument to render in the generated signature.
        """
        ref: Optional[str] = parameter.get('$ref')  # type: ignore
        if ref:
            parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)
        content = parameter.get('content')
        schema: Optional[JsonSchemaObject] = None
        if content and isinstance(content, dict):
            content_schema = [
                c.get("schema")
                for c in content.values()
                if isinstance(c.get("schema"), dict)
            ]
            if content_schema:
                schema = JsonSchemaObject.parse_obj(content_schema[0])
        if not schema:
            schema = JsonSchemaObject.parse_obj(parameter["schema"])

        field = DataModelField(
            name=name,
            data_type=self.get_data_type(schema, 'parameter'),
            # Path parameters are always required.
            required=parameter.get("required") or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            # The alias must be declared with the FastAPI helper matching the
            # parameter location: `Query(...)` would make FastAPI read a
            # renamed path/header/cookie parameter from the query string.
            # Unknown or missing locations keep the previous `Query` behavior.
            param_is = {
                'path': 'Path',
                'header': 'Header',
                'cookie': 'Cookie',
            }.get(str(parameter.get("in", "")).lower(), 'Query')
            default: Optional[str] = (
                f"{param_is}({'...' if field.required else repr(schema.default)}, "
                f"alias='{orig_name}')"
            )
            self.imports.append(Import(from_='fastapi', import_=param_is))
        else:
            default = repr(schema.default) if schema.has_default else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
 def request(self) -> Optional[Argument]:
     """Return the ``body`` argument for the first JSON request body, if any.

     Every ``application/json`` content of every request object yields a
     ``body`` argument typed with the schema's referenced model name, and
     that model is imported from the models module.  Only the first
     argument is returned.
     """
     found: List[Argument] = []
     for request_object in self.request_objects:
         for media_type, body_schema in request_object.contents.items():
             # TODO: support other content-types
             if media_type != "application/json":
                 continue
             model_name = body_schema.ref_object_name
             # TODO: support multiple body
             found.append(
                 Argument(
                     name='body',  # type: ignore
                     type_hint=model_name,
                     required=request_object.required,
                 ))
             self.imports.append(
                 Import(from_=model_path_var.get(), import_=model_name))
     return found[0] if found else None
Exemple #23
0
 def parse_list_item(
     self, name: str, target_items: List[JsonSchemaObject], path: List[str]
 ) -> List[DataType]:
     """Convert each array item schema in *target_items* into data types.

     Items are handled as follows:

     * ``$ref`` items are registered with the model resolver and referenced
       by name;
     * items that set nothing besides ``type`` contribute every data type
       ``get_data_type`` returns for them;
     * arrays whose ``items`` schema sets only ``type`` become
       ``List[T]``, or ``List[Union[...]]`` for several types;
     * anything else is parsed as a singular-named object model and
       referenced.
     """

     def only_type_is_set(candidate) -> bool:
         # No attribute other than ``type`` carries a truthy value.
         return not any(
             value for key, value in vars(candidate).items() if key != 'type'
         )

     collected: List[DataType] = []
     for item in target_items:
         if item.ref:  # $ref
             collected.append(
                 self.data_type(
                     type=self.model_resolver.add_ref(item.ref).name,
                     ref=True,
                     version_compatible=True,
                 )
             )
             continue
         if only_type_is_set(item):
             # trivial types
             collected.extend(self.get_data_type(item))
             continue
         if (
             item.is_array
             and isinstance(item.items, JsonSchemaObject)
             and only_type_is_set(item.items)
         ):
             # trivial item types
             hints = [t.type_hint for t in self.get_data_type(item.items)]
             inner = f"Union[{', '.join(hints)}]" if len(hints) > 1 else hints[0]
             collected.append(
                 self.data_type(
                     type=f"List[{inner}]",
                     imports_=[Import(from_='typing', import_='List')],
                 )
             )
             continue
         model = self.parse_object(name, item, path, singular_name=True)
         collected.append(
             self.data_type(type=model.name, ref=True, version_compatible=True)
         )
     return collected
def generate_app_code(environment, parsed_object) -> str:
    """Render the application entry module from the ``main.jinja2`` template.

    Operations are grouped by the first segment of their URL path
    (``'/pets/{id}'`` -> ``'pets'``). For every group, an import of
    ``<name>_router`` from ``CONTROLLERS_DIR_NAME + '.' + <name>`` is added,
    and ``'<name>_router'`` is appended to the ``routers`` list passed to the
    template. Groups keep first-appearance order.

    :param environment: template environment providing ``get_template``.
    :param parsed_object: object exposing an iterable ``operations`` whose
        items have a ``path`` string.
    :return: the rendered template text.
    """
    template_path = Path('main.jinja2')

    # Group with a plain defaultdict loop. ``itertools.groupby`` only groups
    # *consecutive* items unless the input is pre-sorted by the same key.
    grouped_operations = defaultdict(list)
    for operation in parsed_object.operations:
        first_segment = operation.path.strip('/').split('/')[0]
        grouped_operations[first_segment].append(operation)

    imports = Imports()
    routers = []
    for name in grouped_operations:
        router_name = name + '_router'
        imports.append(
            Import(from_=CONTROLLERS_DIR_NAME + '.' + name,
                   import_=router_name))
        routers.append(router_name)

    result = environment.get_template(str(template_path)).render(
        imports=imports,
        routers=routers,
    )

    return result
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True,
        settings_path: Optional[Path] = None,
    ) -> Union[str, Dict[Tuple[str, ...], Result]]:
        """Parse the source schema and render the generated model modules.

        Outline:

        1. ``self.parse_raw()`` fills ``self.results``; the models are then
           ordered with ``sort_data_models``.
        2. Models are grouped by ``module_path``. Within each module,
           duplicated root models are dropped, inheritance from custom root
           models is removed, and class names clashing with imported names
           are renamed through a scoped ``ModelResolver``
           (``duplicate_name_suffix='Model'``).
        3. Each module is mapped to an output file. References to models in
           other modules become relative imports. The optional
           ``reuse_model`` / ``set_default_enum_member`` post-processing is
           applied, and the code is rendered (and formatted if requested).

        :param with_import: when truthy, prepend import statements to every
            module. Unless the target is Python 3.6, ``IMPORT_ANNOTATIONS``
            is also added to ``self.imports``.
        :param format_: when truthy, run each body through ``CodeFormatter``.
        :param settings_path: forwarded to ``CodeFormatter``.
        :return: the body string of ``('__init__.py',)`` when that is the only
            output file. Otherwise, a mapping from output path tuples
            (e.g. ``('pkg', 'mod.py')``) to ``Result``.
        """

        self.parse_raw()

        if with_import:
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: Optional[CodeFormatter] = CodeFormatter(
                self.target_python_version, settings_path)
        else:
            code_formatter = None

        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], Result] = {}

        module_key = lambda x: x.module_path

        # process in reverse order to correctly establish module levels:
        # a child such as ('a', 'b') sorts before ('a',), so it registers
        # ('a', '__init__.py') before ('a',) is mapped to an output file.
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=module_key, reverse=True),
            key=module_key,
        )

        module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []

        for module, models in ((k, [*v]) for k, v in grouped_models
                               ):  # type: Tuple[str, ...], List[DataModel]

            # Iterate over a copy: duplicated root models are removed from
            # ``models`` inside this loop, and removing from the list being
            # iterated would skip the model that follows each removed one.
            for model in models[:]:
                if isinstance(model, self.data_model_root_type):
                    root_data_type = model.fields[0].data_type

                    # backward compatible
                    # Remove duplicated root model
                    if (root_data_type.reference and not root_data_type.is_dict
                            and not root_data_type.is_list
                            and root_data_type.reference.source in models
                            and root_data_type.reference.name
                            == self.model_resolver.get_class_name(
                                model.reference.original_name, unique=False)):
                        # Replace referenced duplicate model to original model
                        for child in model.reference.children[:]:
                            child.replace_reference(root_data_type.reference)
                        models.remove(model)
                        continue

                    #  Custom root model can't be inherited on restriction of Pydantic
                    for child in model.reference.children:
                        # inheritance model
                        if isinstance(child, DataModel):
                            # Iterate over a copy: matching base classes are
                            # removed from this same list while scanning it.
                            for base_class in child.base_classes[:]:
                                if base_class.reference == model.reference:
                                    child.base_classes.remove(base_class)

            module_models.append((
                module,
                models,
            ))

            # Names already taken by imports may not be reused as class names.
            scoped_model_resolver = ModelResolver(
                exclude_names={
                    i.alias or i.import_
                    for m in models for i in m.imports
                },
                duplicate_name_suffix='Model',
            )

            for model in models:
                class_name: str = model.class_name
                generated_name: str = scoped_model_resolver.add(
                    model.path, class_name, unique=True, class_name=True).name
                if class_name != generated_name:
                    # Keep a dotted module prefix, if any; rename only the class.
                    if '.' in model.reference.name:
                        model.reference.name = (
                            f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                        )
                    else:
                        model.reference.name = generated_name

        for module, models in module_models:
            # ``init`` marks a module rendered into a package ``__init__.py``.
            init = False
            if module:
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = Result(body='')
                if (*module, '__init__.py') in results:
                    module = (*module, '__init__.py')
                    init = True
                else:
                    module = (*module[:-1], f'{module[-1]}.py')
            else:
                module = ('__init__.py', )

            result: List[str] = []
            imports = Imports()
            scoped_model_resolver = ModelResolver()

            for model in models:
                imports.append(model.imports)
                for data_type in model.all_data_types:
                    # To change from/import

                    if not data_type.reference or data_type.reference.source in models:
                        # No need to import non-reference model.
                        # Or, Referenced model is in the same file. we don't need to import the model
                        continue

                    if isinstance(data_type, BaseClassDataType):
                        from_ = ''.join(
                            relative(model.module_name, data_type.full_name))
                        import_ = data_type.reference.short_name
                        full_path = from_, import_
                    else:
                        from_, import_ = full_path = relative(
                            model.module_name, data_type.full_name)

                    alias = scoped_model_resolver.add(full_path, import_).name

                    name = data_type.reference.short_name
                    if from_ and import_ and alias != name:
                        data_type.alias = f'{alias}.{name}'

                    if init:
                        from_ += "."
                    imports.append(
                        Import(from_=from_, import_=import_, alias=alias))

            if self.reuse_model:
                # Models with identical base classes, template data and fields
                # share a cache key; later ones are collapsed onto the first.
                model_cache: Dict[Tuple[str, ...], Reference] = {}
                duplicates = []
                for model in models:
                    model_key = tuple(
                        to_hashable(v) for v in (
                            model.base_classes,
                            model.extra_template_data,
                            model.fields,
                        ))
                    cached_model_reference = model_cache.get(model_key)
                    if cached_model_reference:
                        if isinstance(model, Enum):
                            for child in model.reference.children[:]:
                                # child is resolved data_type by reference
                                data_model = get_most_of_parent(child)
                                # TODO: replace reference in all modules
                                if data_model in models:  # pragma: no cover
                                    child.replace_reference(
                                        cached_model_reference)
                            duplicates.append(model)
                        else:
                            index = models.index(model)
                            inherited_model = model.__class__(
                                fields=[],
                                base_classes=[cached_model_reference],
                                description=model.description,
                                reference=Reference(
                                    name=model.name,
                                    path=model.reference.path + '/reuse',
                                ),
                            )
                            if (cached_model_reference.path
                                    in require_update_action_models):
                                require_update_action_models.append(
                                    inherited_model.path)
                            # Swap ``model`` for ``inherited_model`` at the same
                            # position; the list length is unchanged.
                            models.insert(index, inherited_model)
                            models.remove(model)

                    else:
                        model_cache[model_key] = model.reference

                for duplicate in duplicates:
                    models.remove(duplicate)

            if self.set_default_enum_member:
                for model in models:
                    for model_field in model.fields:
                        if not model_field.default:
                            continue
                        for data_type in model_field.data_type.all_data_types:
                            if data_type.reference and isinstance(
                                    data_type.reference.source,
                                    Enum):  # pragma: no cover
                                enum_member = data_type.reference.source.find_member(
                                    model_field.default)
                                if enum_member:
                                    model_field.default = enum_member
            if with_import:
                result += [str(self.imports), str(imports), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(
                        m.reference.short_name for m in models
                        if m.path in require_update_action_models),
                ]

            body = '\n'.join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            results[module] = Result(body=body, source=models[0].file_path)

        # retain existing behaviour: a single root module is returned as a str
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )].body

        return results
 def get_data_type_from_full_path(self, full_path: str,
                                  is_custom_type: bool) -> DataType:
     """Return the ``DataType`` for a dotted import path.

     ``full_path`` is turned into an ``Import`` with ``Import.from_full_path``.
     It is then handed to ``self.data_type.from_import`` along with
     ``is_custom_type``.
     """
     import_ = Import.from_full_path(full_path)
     return self.data_type.from_import(import_, is_custom_type=is_custom_type)
Exemple #27
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelFieldBase],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
        imports: Optional[List[Import]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
        methods: Optional[List[str]] = None,
    ) -> None:
        """Initialise the model, resolving its template, base class and imports.

        :param name: model name. For a dotted name (``module.Class``),
            ``class_name`` is the last segment, and a leading ``module.``
            prefix is stripped from ``base_class``.
        :param fields: the model's fields (a falsy value becomes ``[]``).
        :param decorators: stored on ``self.decorators`` (default ``[]``).
        :param base_classes: explicit base classes; falsy entries are dropped.
            If none remain, ``custom_base_class`` or ``self.BASE_CLASS`` is
            used (and imported when ``auto_import`` is true).
        :param custom_base_class: dotted path that overrides ``self.BASE_CLASS``.
        :param custom_template_dir: directory checked for a file with the
            same name as ``self.TEMPLATE_FILE_PATH``; used if it exists.
        :param extra_template_data: template data keyed by model name. The
            entry for ``name`` is used, and values under ``ALL_MODEL`` are
            merged into it.
        :param imports: initial imports. The list is copied, so the imports
            added here do not leak into the caller's list.
        :param auto_import: when true, add the base-class import and every
            field's imports.
        :param reference_classes: extra class names added to
            ``self.reference_classes``.
        :param methods: stored on ``self.methods`` (default ``[]``).
        :raises Exception: if ``TEMPLATE_FILE_PATH`` is not set.
        """
        if not self.TEMPLATE_FILE_PATH:
            raise Exception('TEMPLATE_FILE_PATH is undefined')

        template_file_path = Path(self.TEMPLATE_FILE_PATH)
        if custom_template_dir is not None:
            custom_template_file_path = custom_template_dir / template_file_path.name
            if custom_template_file_path.exists():
                template_file_path = custom_template_file_path

        self.name: str = name
        self.fields: List[DataModelFieldBase] = fields or []
        self.decorators: List[str] = decorators or []
        # Copy: imports are appended/extended below, and that must not
        # mutate a list owned by the caller.
        self.imports: List[Import] = list(imports) if imports else []
        self.base_class: Optional[str] = None
        base_classes = [
            base_class for base_class in base_classes or [] if base_class
        ]
        self.base_classes: List[str] = base_classes

        self.reference_classes: List[str] = [
            r for r in base_classes if r != self.BASE_CLASS
        ] if base_classes else []
        if reference_classes:
            self.reference_classes.extend(reference_classes)

        if self.base_classes:
            self.base_class = ', '.join(self.base_classes)
        else:
            base_class_full_path = custom_base_class or self.BASE_CLASS
            if auto_import:
                if base_class_full_path:
                    self.imports.append(
                        Import.from_full_path(base_class_full_path))
            # ``base_class_full_path`` may be falsy (guarded just above), so
            # avoid calling ``rsplit`` on ``None``.
            self.base_class = (base_class_full_path or '').rsplit('.', 1)[-1]

        if '.' in name:
            module, class_name = name.rsplit('.', 1)
            prefix = f'{module}.'
            if self.base_class.startswith(prefix):
                self.base_class = self.base_class.replace(prefix, '', 1)
        else:
            class_name = name

        self.class_name: str = class_name

        self.extra_template_data = (extra_template_data[self.name]
                                    if extra_template_data is not None else
                                    defaultdict(dict))
        if extra_template_data:
            all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
            if all_model_extra_template_data:
                self.extra_template_data.update(all_model_extra_template_data)

        unresolved_types: Set[str] = set()
        for field in self.fields:
            unresolved_types.update(field.unresolved_types)

        # Sorted for a deterministic order. Iterating a bare set of strings
        # yields an order that can change between interpreter runs.
        self.reference_classes = sorted(
            set(self.reference_classes) | unresolved_types)

        if auto_import:
            for field in self.fields:
                self.imports.extend(field.imports)

        self.methods: List[str] = methods or []

        super().__init__(template_file_path=template_file_path)
Exemple #28
0
            DataType(type='NegativeFloat', imports=[IMPORT_NEGATIVE_FLOAT]),
        ),
    ],
)
def test_get_data_float_type(types, params, data_type):
    assert DataTypeManager().get_data_float_type(types, **params) == data_type


@pytest.mark.parametrize(
    'types,params,data_type',
    [
        (
            Types.decimal,
            {},
            DataType(type='Decimal',
                     imports=[Import(from_='decimal', import_='Decimal')]),
        ),
        (
            Types.decimal,
            {
                'maximum': 10
            },
            DataType(
                type='condecimal',
                is_func=True,
                kwargs={'le': 10},
                imports=[IMPORT_CONDECIMAL],
            ),
        ),
        (
            Types.decimal,
from datamodel_code_generator.imports import Import

# Pre-built ``Import`` objects for names that generated model code may use.
# Each one is created once via ``Import.from_full_path('<module>.<name>')``.

# pydantic constrained-type helpers (``con*``).
IMPORT_CONSTR = Import.from_full_path('pydantic.constr')
IMPORT_CONINT = Import.from_full_path('pydantic.conint')
IMPORT_CONFLOAT = Import.from_full_path('pydantic.confloat')
IMPORT_CONDECIMAL = Import.from_full_path('pydantic.condecimal')
IMPORT_CONBYTES = Import.from_full_path('pydantic.conbytes')
# pydantic sign-restricted int / float types.
IMPORT_POSITIVE_INT = Import.from_full_path('pydantic.PositiveInt')
IMPORT_NEGATIVE_INT = Import.from_full_path('pydantic.NegativeInt')
IMPORT_NON_POSITIVE_INT = Import.from_full_path('pydantic.NonPositiveInt')
IMPORT_NON_NEGATIVE_INT = Import.from_full_path('pydantic.NonNegativeInt')
IMPORT_POSITIVE_FLOAT = Import.from_full_path('pydantic.PositiveFloat')
IMPORT_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NegativeFloat')
IMPORT_NON_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NonNegativeFloat')
IMPORT_NON_POSITIVE_FLOAT = Import.from_full_path('pydantic.NonPositiveFloat')
# pydantic string, UUID and URL types.
IMPORT_SECRET_STR = Import.from_full_path('pydantic.SecretStr')
IMPORT_EMAIL_STR = Import.from_full_path('pydantic.EmailStr')
IMPORT_UUID1 = Import.from_full_path('pydantic.UUID1')
IMPORT_UUID2 = Import.from_full_path('pydantic.UUID2')
IMPORT_UUID3 = Import.from_full_path('pydantic.UUID3')
IMPORT_UUID4 = Import.from_full_path('pydantic.UUID4')
IMPORT_UUID5 = Import.from_full_path('pydantic.UUID5')
IMPORT_ANYURL = Import.from_full_path('pydantic.AnyUrl')
# Standard-library IP address types.
IMPORT_IPV4ADDRESS = Import.from_full_path('ipaddress.IPv4Address')
IMPORT_IPV6ADDRESS = Import.from_full_path('ipaddress.IPv6Address')
# pydantic model-configuration and field helpers.
IMPORT_EXTRA = Import.from_full_path('pydantic.Extra')
IMPORT_FIELD = Import.from_full_path('pydantic.Field')
# pydantic ``Strict*`` scalar types.
IMPORT_STRICT_INT = Import.from_full_path('pydantic.StrictInt')
IMPORT_STRICT_FLOAT = Import.from_full_path('pydantic.StrictFloat')
IMPORT_STRICT_STR = Import.from_full_path('pydantic.StrictStr')
IMPORT_STRICT_BOOL = Import.from_full_path('pydantic.StrictBool')
            DataType(type='NegativeFloat', import_=IMPORT_NEGATIVE_FLOAT),
        ),
    ],
)
def test_get_data_float_type(types, params, data_type):
    assert DataTypeManager().get_data_float_type(types, **params) == data_type


@pytest.mark.parametrize(
    'types,params,data_type',
    [
        (
            Types.decimal,
            {},
            DataType(
                type='Decimal', import_=Import(from_='decimal', import_='Decimal')
            ),
        ),
        (
            Types.decimal,
            {'maximum': 10},
            DataType(
                type='condecimal',
                is_func=True,
                kwargs={'le': 10},
                import_=IMPORT_CONDECIMAL,
            ),
        ),
        (
            Types.decimal,
            {'exclusiveMaximum': 10},