예제 #1
0
 def request(self) -> Optional[Argument]:
     """Return the first request argument found in the request objects.

     JSON content types produce a ``body`` argument whose type hint comes
     from ``self.get_data_type``; ``application/x-www-form-urlencoded``
     produces a ``request`` argument hinted as ``Request``. Imports are
     registered on ``self.imports`` for every matching content type, but
     only the first collected argument is returned. ``None`` is returned
     when no content type matched.
     """
     candidates: List[Argument] = []
     for request_object in self.request_objects:
         for media_type, media_schema in request_object.contents.items():
             # TODO: support other content-types
             if RE_APPLICATION_JSON_PATTERN.match(media_type):
                 body_type = self.get_data_type(media_schema, 'request')
                 # TODO: support multiple body
                 body_argument = Argument(
                     name='body',  # type: ignore
                     type_hint=body_type.type_hint,
                     required=request_object.required,
                 )
                 candidates.append(body_argument)
                 self.imports.extend(body_type.imports)
                 continue
             if media_type != 'application/x-www-form-urlencoded':
                 continue
             # TODO: support form with `Form()`
             form_argument = Argument(
                 name='request',  # type: ignore
                 type_hint='Request',  # type: ignore
                 required=True,
             )
             candidates.append(form_argument)
             self.imports.append(
                 Import.from_full_path('starlette.requests.Request'))
     return candidates[0] if candidates else None
    def __init__(
        self,
        *,
        reference: Reference,
        fields: List[DataModelFieldBase],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[Reference]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        path: Optional[Path] = None,
        description: Optional[str] = None,
    ):
        """Initialise the model via the parent class, then add the dataclass import.

        All arguments are forwarded unchanged, by keyword, to the parent
        constructor. Afterwards an import for
        ``pydantic.dataclasses.dataclass`` is appended to
        ``self._additional_imports``.
        """
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            path=path,
            description=description,
        )
        dataclass_import = Import.from_full_path(
            'pydantic.dataclasses.dataclass')
        self._additional_imports.append(dataclass_import)
예제 #3
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
    ):
        """Initialise the model via the parent class, then add the dataclass import.

        ``name``, ``fields``, ``decorators`` and ``base_classes`` are passed
        positionally to the parent constructor; the remaining options are
        passed by keyword. Afterwards an import for
        ``pydantic.dataclasses.dataclass`` is appended to ``self.imports``.
        """
        super().__init__(
            name,
            fields,
            decorators,
            base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            auto_import=auto_import,
            reference_classes=reference_classes,
        )
        dataclass_import = Import.from_full_path(
            'pydantic.dataclasses.dataclass')
        self.imports.append(dataclass_import)
예제 #4
0
    def __init__(
        self,
        *,
        reference: Reference,
        fields: List[DataModelFieldBase],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[Reference]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
        methods: Optional[List[str]] = None,
        path: Optional[Path] = None,
        description: Optional[str] = None,
    ) -> None:
        """Set up a data model bound to ``reference``.

        Raises:
            Exception: if the class does not define ``TEMPLATE_FILE_PATH``.

        A template with the same file name inside ``custom_template_dir``
        takes precedence over the built-in one when it exists. When no
        (truthy) base classes are given, the import for ``custom_base_class``
        or, failing that, ``BASE_CLASS`` is added to
        ``self._additional_imports``. Template data stored under
        ``ALL_MODEL`` is merged into this model's own template data.
        """
        if not self.TEMPLATE_FILE_PATH:
            raise Exception('TEMPLATE_FILE_PATH is undefined')

        template_file_path = Path(self.TEMPLATE_FILE_PATH)
        if custom_template_dir is not None:
            override_path = custom_template_dir / template_file_path.name
            if override_path.exists():
                template_file_path = override_path

        self.fields: List[DataModelFieldBase] = fields or []
        self.decorators: List[str] = decorators or []
        self._additional_imports: List[Import] = []
        # Falsy entries are dropped from the supplied base classes.
        self.base_classes: List[Reference] = list(
            filter(None, base_classes or []))
        self.custom_base_class = custom_base_class
        self.file_path: Optional[Path] = path
        self.reference: Reference = reference

        # Link the reference back to the model that defines it.
        self.reference.source = self

        if extra_template_data is None:
            self.extra_template_data = defaultdict(dict)
        else:
            self.extra_template_data = extra_template_data[self.name]

        if not self.base_classes:
            default_base_path = custom_base_class or self.BASE_CLASS
            if default_base_path:
                default_base_import = Import.from_full_path(default_base_path)
                self._additional_imports.append(default_base_import)

        shared_template_data = (
            extra_template_data.get(ALL_MODEL) if extra_template_data else None
        )
        if shared_template_data:
            self.extra_template_data.update(shared_template_data)

        self.methods: List[str] = methods or []

        self.description = description
        for model_field in self.fields:
            model_field.parent = self

        super().__init__(template_file_path=template_file_path)
예제 #5
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        imports: Optional[List[Import]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
    ) -> None:
        """Set up a named data model with its fields, bases and imports.

        Raises:
            Exception: if the class does not define ``TEMPLATE_FILE_PATH``.

        When base classes are supplied they are joined with ``', '`` into
        ``self.base_class``; otherwise the last dotted component of
        ``custom_base_class`` (or ``BASE_CLASS``) is used and, with
        ``auto_import`` enabled, its import is recorded. Reference classes
        are the non-default base classes, the explicitly supplied
        ``reference_classes`` and every unresolved type of the fields.
        """
        if not self.TEMPLATE_FILE_PATH:
            raise Exception('TEMPLATE_FILE_PATH is undefined')

        self.name: str = name
        self.fields: List[DataModelField] = fields or []
        self.decorators: List[str] = decorators or []
        self.imports: List[Import] = imports or []
        self.base_class: Optional[str] = None
        # Falsy entries are dropped from the supplied base classes.
        base_classes = list(filter(None, base_classes or []))
        self.base_classes: List[str] = base_classes

        self.reference_classes: List[str] = [
            base for base in base_classes if base != self.BASE_CLASS
        ]
        if reference_classes:
            self.reference_classes += reference_classes

        if self.base_classes:
            self.base_class = ', '.join(self.base_classes)
        else:
            base_class_full_path = custom_base_class or self.BASE_CLASS
            if auto_import and base_class_full_path:
                self.imports.append(
                    Import.from_full_path(base_class_full_path))
            # Only the final dotted component is used as the class name.
            self.base_class = base_class_full_path.rpartition('.')[2]

        unresolved_types: Set[str] = set().union(
            *(model_field.unresolved_types for model_field in self.fields))

        self.reference_classes = list(
            set(self.reference_classes) | unresolved_types)

        if auto_import:
            for model_field in self.fields:
                self.imports.extend(model_field.imports)
        super().__init__(template_file_path=self.TEMPLATE_FILE_PATH)
예제 #6
0
 def parse_request_body(
     self,
     name: str,
     request_body: RequestBodyObject,
     path: List[str],
 ) -> None:
     """Parse a request body and record the operation's request argument.

     After delegating to the parent implementation, every media type whose
     schema is a ``JsonSchemaObject`` or ``ReferenceObject`` is inspected.
     JSON media types yield a ``body`` argument and their data type is
     appended to ``self.data_types``; ``application/x-www-form-urlencoded``
     yields a ``request`` argument and registers the starlette ``Request``
     import. The first collected argument (or ``None``) is stored under
     ``self._temporary_operation['_request']``.
     """
     super().parse_request_body(name, request_body, path)
     collected: List[Argument] = []
     for media_type, media_obj in request_body.content.items():  # type: str, MediaObject
         schema = media_obj.schema_
         if isinstance(schema,
                       (JsonSchemaObject, ReferenceObject)):  # pragma: no cover
             # TODO: support other content-types
             if RE_APPLICATION_JSON_PATTERN.match(media_type):
                 data_type = (
                     self.get_ref_data_type(schema.ref)
                     if isinstance(schema, ReferenceObject)
                     else self.parse_schema(name, schema, [*path, media_type])
                 )
                 # TODO: support multiple body
                 body_argument = Argument(
                     name='body',  # type: ignore
                     type_hint=data_type.type_hint,
                     required=request_body.required,
                 )
                 collected.append(body_argument)
                 self.data_types.append(data_type)
                 continue
             if media_type != 'application/x-www-form-urlencoded':
                 continue
             # TODO: support form with `Form()`
             form_argument = Argument(
                 name='request',  # type: ignore
                 type_hint='Request',  # type: ignore
                 required=True,
             )
             collected.append(form_argument)
             self.imports_for_fastapi.append(
                 Import.from_full_path('starlette.requests.Request'))
     first_argument = collected[0] if collected else None
     self._temporary_operation['_request'] = first_argument
예제 #7
0
 def get_data_type_from_full_path(self, full_path: str,
                                  is_custom_type: bool) -> DataType:
     """Build a data type from the import described by ``full_path``.

     ``full_path`` is turned into an ``Import`` which is handed, together
     with ``is_custom_type``, to ``self.data_type.from_import``.
     """
     build_from_import = self.data_type.from_import
     resolved_import = Import.from_full_path(full_path)
     return build_from_import(resolved_import, is_custom_type=is_custom_type)
예제 #8
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelFieldBase],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
        imports: Optional[List[Import]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
        methods: Optional[List[str]] = None,
    ) -> None:
        """Set up a named data model with its template, bases and imports.

        Raises:
            Exception: if the class does not define ``TEMPLATE_FILE_PATH``.

        A template with the same file name inside ``custom_template_dir``
        takes precedence over the built-in one when it exists. ``name`` may
        be dotted: the part after the last dot becomes ``self.class_name``
        and a leading ``<module>.`` is stripped from ``self.base_class``.
        Template data stored under ``ALL_MODEL`` is merged into this model's
        own template data.
        """
        if not self.TEMPLATE_FILE_PATH:
            raise Exception('TEMPLATE_FILE_PATH is undefined')

        template_file_path = Path(self.TEMPLATE_FILE_PATH)
        if custom_template_dir is not None:
            override_path = custom_template_dir / template_file_path.name
            if override_path.exists():
                template_file_path = override_path

        self.name: str = name
        self.fields: List[DataModelFieldBase] = fields or []
        self.decorators: List[str] = decorators or []
        self.imports: List[Import] = imports or []
        self.base_class: Optional[str] = None
        # Falsy entries are dropped from the supplied base classes.
        base_classes = list(filter(None, base_classes or []))
        self.base_classes: List[str] = base_classes

        self.reference_classes: List[str] = [
            base for base in base_classes if base != self.BASE_CLASS
        ]
        if reference_classes:
            self.reference_classes += reference_classes

        if self.base_classes:
            self.base_class = ', '.join(self.base_classes)
        else:
            base_class_full_path = custom_base_class or self.BASE_CLASS
            if auto_import and base_class_full_path:
                self.imports.append(
                    Import.from_full_path(base_class_full_path))
            # Only the final dotted component is used as the class name.
            self.base_class = base_class_full_path.rpartition('.')[2]

        module, separator, class_name = name.rpartition('.')
        if separator:
            # Drop the model's own module prefix from the base class name.
            prefix = f'{module}.'
            if self.base_class.startswith(prefix):
                self.base_class = self.base_class[len(prefix):]

        self.class_name: str = class_name

        if extra_template_data is None:
            self.extra_template_data = defaultdict(dict)
        else:
            self.extra_template_data = extra_template_data[self.name]
        shared_template_data = (
            extra_template_data.get(ALL_MODEL) if extra_template_data else None
        )
        if shared_template_data:
            self.extra_template_data.update(shared_template_data)

        unresolved_types: Set[str] = set().union(
            *(model_field.unresolved_types for model_field in self.fields))

        self.reference_classes = list(
            set(self.reference_classes) | unresolved_types)

        if auto_import:
            for model_field in self.fields:
                self.imports.extend(model_field.imports)

        self.methods: List[str] = methods or []

        super().__init__(template_file_path=template_file_path)
예제 #9
0
from datamodel_code_generator.imports import Import

# Module-level ``Import`` constants, each built from a dotted full path.
# Names taken from the ``pydantic`` package whose identifiers start with
# ``con`` (constr, conint, confloat, condecimal, conbytes).
IMPORT_CONSTR = Import.from_full_path('pydantic.constr')
IMPORT_CONINT = Import.from_full_path('pydantic.conint')
IMPORT_CONFLOAT = Import.from_full_path('pydantic.confloat')
IMPORT_CONDECIMAL = Import.from_full_path('pydantic.condecimal')
IMPORT_CONBYTES = Import.from_full_path('pydantic.conbytes')
# Sign-restricted int and float names from ``pydantic``.
IMPORT_POSITIVE_INT = Import.from_full_path('pydantic.PositiveInt')
IMPORT_NEGATIVE_INT = Import.from_full_path('pydantic.NegativeInt')
IMPORT_NON_POSITIVE_INT = Import.from_full_path('pydantic.NonPositiveInt')
IMPORT_NON_NEGATIVE_INT = Import.from_full_path('pydantic.NonNegativeInt')
IMPORT_POSITIVE_FLOAT = Import.from_full_path('pydantic.PositiveFloat')
IMPORT_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NegativeFloat')
IMPORT_NON_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NonNegativeFloat')
IMPORT_NON_POSITIVE_FLOAT = Import.from_full_path('pydantic.NonPositiveFloat')
# String, UUID and URL names from ``pydantic``.
IMPORT_SECRET_STR = Import.from_full_path('pydantic.SecretStr')
IMPORT_EMAIL_STR = Import.from_full_path('pydantic.EmailStr')
IMPORT_UUID1 = Import.from_full_path('pydantic.UUID1')
IMPORT_UUID2 = Import.from_full_path('pydantic.UUID2')
IMPORT_UUID3 = Import.from_full_path('pydantic.UUID3')
IMPORT_UUID4 = Import.from_full_path('pydantic.UUID4')
IMPORT_UUID5 = Import.from_full_path('pydantic.UUID5')
IMPORT_ANYURL = Import.from_full_path('pydantic.AnyUrl')
# The IP address names point at the ``ipaddress`` module, not ``pydantic``.
IMPORT_IPV4ADDRESS = Import.from_full_path('ipaddress.IPv4Address')
IMPORT_IPV6ADDRESS = Import.from_full_path('ipaddress.IPv6Address')
# ``Extra`` and ``Field`` from ``pydantic``.
IMPORT_EXTRA = Import.from_full_path('pydantic.Extra')
IMPORT_FIELD = Import.from_full_path('pydantic.Field')
# ``Strict*`` names from ``pydantic``.
IMPORT_STRICT_INT = Import.from_full_path('pydantic.StrictInt')
IMPORT_STRICT_FLOAT = Import.from_full_path('pydantic.StrictFloat')
IMPORT_STRICT_STR = Import.from_full_path('pydantic.StrictStr')
IMPORT_STRICT_BOOL = Import.from_full_path('pydantic.StrictBool')
예제 #10
0
 def set_base_class(self) -> None:
     """Make the custom (or default) base class the model's only base class.

     The import for ``self.custom_base_class`` - or ``self.BASE_CLASS`` when
     no custom class is set - is appended to ``self._additional_imports``,
     and ``self.base_classes`` is replaced by a single-element list built
     from that import.
     """
     full_path = self.custom_base_class or self.BASE_CLASS
     import_ = Import.from_full_path(full_path)
     self._additional_imports.append(import_)
     base_class_type = BaseClassDataType.from_import(import_)
     self.base_classes = [base_class_type]
예제 #11
0
def generate_code(
    input_name: str,
    input_text: str,
    output_dir: Path,
    template_dir: Optional[Path],
    model_path: Optional[Path] = None,
    enum_field_as_literal: Optional[str] = None,
) -> None:
    """Generate application code and a models module from an OpenAPI document.

    Args:
        input_name: Name of the input document; written into file headers.
        input_text: The OpenAPI document text handed to ``OpenAPIParser``.
        output_dir: Directory that receives the generated files; created
            when missing.
        template_dir: Directory with the templates to render; falls back to
            ``BUILTIN_TEMPLATE_DIR`` when falsy.
        model_path: Relative path of the models module; falls back to
            ``MODEL_PATH`` when falsy.
        enum_field_as_literal: Forwarded to ``OpenAPIParser`` when truthy.

    Raises:
        Exception: if the parser returns anything other than a single
            string of model code (modular output is not supported).

    Every file found under ``template_dir`` is rendered, formatted and
    written to ``output_dir`` under the same relative path with a ``.py``
    suffix. Nothing is written when the parser produces no models.
    """
    if not model_path:
        model_path = MODEL_PATH
    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_dir.mkdir(parents=True, exist_ok=True)
    if not template_dir:
        template_dir = BUILTIN_TEMPLATE_DIR
    if enum_field_as_literal:
        parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal)
    else:
        parser = OpenAPIParser(input_text)
    with chdir(output_dir):
        models = parser.parse()
    if not models:
        return
    if not isinstance(models, str):
        raise Exception('Modular references are not supported in this version')
    modules = {output_dir / model_path: (models, input_name)}

    # template_dir is always set at this point, so no further fallback is needed.
    environment: Environment = Environment(
        loader=FileSystemLoader(template_dir, encoding="utf8"),
    )
    imports = Imports()
    imports.update(parser.imports)
    for data_type in parser.data_types:
        reference = _get_most_of_reference(data_type)
        if reference:
            imports.append(data_type.all_imports)
            imports.append(
                Import.from_full_path(f'.{model_path.stem}.{reference.name}')
            )
    for from_, imports_ in parser.imports_for_fastapi.items():
        imports[from_].update(imports_)
    results: Dict[Path, str] = {}
    code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
    sorted_operations: List[Operation] = sorted(
        parser.operations.values(), key=lambda m: m.path
    )
    for target in template_dir.rglob("*"):
        # rglob also yields directories, which are not templates.
        if not target.is_file():
            continue
        relative_path = target.relative_to(template_dir)
        # Jinja template names always use forward slashes, even on Windows.
        result = environment.get_template(relative_path.as_posix()).render(
            operations=sorted_operations, imports=imports, info=parser.parse_info(),
        )
        results[relative_path] = code_formatter.format_code(result)

    timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
    header = f"""\
# generated by fastapi-codegen:
#   filename:  {Path(input_name).name}
#   timestamp: {timestamp}"""

    for path, code in results.items():
        destination = output_dir.joinpath(path.with_suffix(".py"))
        # Templates in sub-directories need their output directory created.
        destination.parent.mkdir(parents=True, exist_ok=True)
        # Explicit utf8 keeps the output independent of the platform default
        # and consistent with the template loader and the models module.
        with destination.open("wt", encoding="utf8") as file:
            print(header, file=file)
            print("", file=file)
            print(code.rstrip(), file=file)

    for path, (body, filename) in modules.items():
        model_header = (
            '# generated by fastapi-codegen:\n'
            f'#   filename:  {filename}\n'
            f'#   timestamp: {timestamp}'
        )
        path.parent.mkdir(parents=True, exist_ok=True)
        # The context manager closes the file even if a write fails.
        with path.open('wt', encoding='utf8') as file:
            print(model_header, file=file)
            if body:
                print('', file=file)
                print(body.rstrip(), file=file)
예제 #12
0
from datamodel_code_generator.imports import Import

# Module-level ``Import`` constants, each built from a dotted full path.
IMPORT_CONSTR = Import.from_full_path('pydantic.constr')
IMPORT_CONINT = Import.from_full_path('pydantic.conint')
IMPORT_CONFLOAT = Import.from_full_path('pydantic.confloat')
# Fixed typo: the pydantic name is ``condecimal`` (was ``condeciaml``).
IMPORT_CONDECIMAL = Import.from_full_path('pydantic.condecimal')
IMPORT_POSITIVE_INT = Import.from_full_path('pydantic.PositiveInt')
IMPORT_NEGATIVE_INT = Import.from_full_path('pydantic.NegativeInt')
IMPORT_POSITIVE_FLOAT = Import.from_full_path('pydantic.PositiveFloat')
IMPORT_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NegativeFloat')
IMPORT_SECRET_STR = Import.from_full_path('pydantic.SecretStr')
IMPORT_EMAIL_STR = Import.from_full_path('pydantic.EmailStr')
IMPORT_UUID1 = Import.from_full_path('pydantic.UUID1')
IMPORT_UUID2 = Import.from_full_path('pydantic.UUID2')
IMPORT_UUID3 = Import.from_full_path('pydantic.UUID3')
IMPORT_UUID4 = Import.from_full_path('pydantic.UUID4')
IMPORT_UUID5 = Import.from_full_path('pydantic.UUID5')
IMPORT_ANYURL = Import.from_full_path('pydantic.AnyUrl')
# IPv4Address / IPv6Address are standard-library classes from ``ipaddress``;
# they are imported from there rather than from ``pydantic``.
IMPORT_IPV4ADDRESS = Import.from_full_path('ipaddress.IPv4Address')
IMPORT_IPV6ADDRESS = Import.from_full_path('ipaddress.IPv6Address')
IMPORT_EXTRA = Import.from_full_path('pydantic.Extra')
IMPORT_FIELD = Import.from_full_path('pydantic.Field')