Exemplo n.º 1
0
    def get_data_type(self, schema: JsonSchemaObject, suffix: str = '') -> DataType:
        """Resolve *schema* to a DataType, registering model imports as needed.

        $ref schemas import the referenced model; arrays recurse into their
        item schemas; inline objects are parsed into a new model named from
        the operation path/method plus *suffix*; everything else falls back
        to the parser's plain type mapping.
        """
        if schema.ref:
            # Referenced model: import it from the generated models module.
            ref_type = self.openapi_model_parser.get_ref_data_type(schema.ref)
            self.imports.append(
                Import(
                    # TODO: Improve import statements
                    from_=model_module_name_var.get(),
                    import_=ref_type.type_hint,
                )
            )
            return ref_type

        if schema.is_array:
            # TODO: Improve handling array
            item_schemas = (
                schema.items if isinstance(schema.items, list) else [schema.items]
            )
            item_types = [self.get_data_type(item, suffix) for item in item_schemas]
            return self.openapi_model_parser.data_type(
                data_types=item_types, is_list=True
            )

        if schema.is_object:
            # Build a deterministic model name from the operation path/method.
            camel_path = stringcase.camelcase(self.path[1:].replace("/", "_"))
            cap_suffix = suffix.capitalize()
            model_name: str = f'{camel_path}{self.type.capitalize()}{cap_suffix}'
            json_path = ['paths', self.path, self.type, cap_suffix]

            object_type = self.openapi_model_parser.parse_object(
                model_name, schema, json_path
            )

            self.imports.append(
                Import(from_=model_module_name_var.get(), import_=object_type.type_hint)
            )
            return object_type

        return self.openapi_model_parser.get_data_type(schema)
Exemplo n.º 2
0
    def get_parameter_type(
        self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
    ) -> Argument:
        """Build a function Argument for an OpenAPI operation parameter.

        Resolves a ``$ref`` parameter to its body, optionally snake_cases the
        name (adding a FastAPI ``Query(..., alias=...)`` default so the wire
        name is preserved), and derives the type from the parameter schema.
        """
        # A parameter may itself be a reference object; resolve it first.
        ref: Optional[str] = parameter.get('$ref')  # type: ignore
        if ref:
            parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)
        schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])

        field = DataModelField(
            name=name,
            data_type=self.get_data_type(schema),
            # Path parameters are always required per the OpenAPI spec.
            required=parameter.get("required") or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            # Renamed by snake_casing: alias back to the original wire name.
            default: Optional[
                str
            ] = f"Query({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
            self.imports.append(Import(from_='fastapi', import_='Query'))
        else:
            default = repr(schema.default) if 'default' in parameter["schema"] else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
 def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
     """Return the DataTypes for every alternative in ``obj.anyOf``.

     Trivial alternatives (only a ``type`` set) map straight to a data type;
     non-trivial ones are parsed into a new model named after the singular
     form of *name* as a side effect.
     """
     any_of_data_types: List[DataType] = []
     for any_of_item in obj.anyOf:
         if any_of_item.ref:  # $ref
             any_of_data_types.append(
                 self.data_type(
                     type=any_of_item.ref_object_name,
                     ref=True,
                     version_compatible=True,
                 ))
         elif not any(v
                      for k, v in vars(any_of_item).items() if k != 'type'):
             # trivial types
             any_of_data_types.append(self.get_data_type(any_of_item))
         elif (any_of_item.is_array
               and isinstance(any_of_item.items, JsonSchemaObject)
               and not any(v for k, v in vars(any_of_item.items).items()
                           if k != 'type')):
             # trivial item types
             any_of_data_types.append(
                 self.data_type(
                     type=
                     f"List[{self.get_data_type(any_of_item.items).type_hint}]",
                     imports_=[Import(from_='typing', import_='List')],
                 ))
         else:
             # Complex inline schema: materialise it as its own model.
             singular_name = get_singular_name(name)
             self.parse_object(singular_name, any_of_item)
             any_of_data_types.append(
                 self.data_type(type=singular_name,
                                ref=True,
                                version_compatible=True))
     return any_of_data_types
Exemplo n.º 4
0
    def get_parameter_type(
        self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
    ) -> Argument:
        """Build a function Argument for an OpenAPI operation parameter.

        The Python type is looked up from the schema's ``type``/``format``
        pair via ``json_schema_data_formats`` and ``type_map``.
        """
        schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
        # A missing "format" falls back to the "default" mapping for the type.
        format_ = schema.format or "default"
        type_ = json_schema_data_formats[schema.type][format_]
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)

        field = DataModelField(
            name=name,
            data_type=type_map[type_],
            # Path parameters are always required per the OpenAPI spec.
            required=parameter.get("required") or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            # Renamed by snake_casing: alias back to the original wire name.
            default: Optional[
                str
            ] = f"Query({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
            self.imports.append(Import(from_='fastapi', import_='Query'))
        else:
            default = repr(schema.default) if 'default' in parameter["schema"] else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
Exemplo n.º 5
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
        imports: Optional[List[Import]] = None,
    ):
        """Create the model, enabling pydantic ``Extra.allow`` when the
        source schema declared ``additionalProperties``."""

        super().__init__(
            name=name,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            auto_import=auto_import,
            reference_classes=reference_classes,
            imports=imports,
        )

        # additionalProperties -> render Config(extra=Extra.allow) and make
        # sure Extra is importable in the generated module.
        if 'additionalProperties' in self.extra_template_data:
            self.extra_template_data['config'] = Config(extra='Extra.allow')
            self.imports.append(Import(from_='pydantic', import_='Extra'))
Exemplo n.º 6
0
def test_dump(inputs: Sequence[Tuple[Optional[str], str]], value):
    """Imports built from (from_, import_) pairs should render as *value*."""
    imports = Imports()
    import_list = [Import(from_=f, import_=i) for f, i in inputs]
    imports.append(import_list)
    assert str(imports) == value
Exemplo n.º 7
0
    def __init__(
        self,
        name: str,
        fields: List[DataModelField],
        decorators: Optional[List[str]] = None,
        base_classes: Optional[List[str]] = None,
        custom_base_class: Optional[str] = None,
        custom_template_dir: Optional[Path] = None,
        extra_template_data: Optional[DefaultDict[str, Any]] = None,
        auto_import: bool = True,
        reference_classes: Optional[List[str]] = None,
        imports: Optional[List[Import]] = None,
    ):
        """Create the model: collect per-field methods, enable
        ``Extra.allow`` for ``additionalProperties`` schemas, and import
        pydantic ``Field`` when any field declares one."""

        # Gather method bodies declared on the fields (e.g. validators).
        methods: List[str] = [field.method for field in fields if field.method]

        super().__init__(
            name=name,
            fields=fields,  # type: ignore
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            auto_import=auto_import,
            reference_classes=reference_classes,
            imports=imports,
            methods=methods,
        )

        config_parameters: Dict[str, Any] = {}

        # additionalProperties -> pydantic Config(extra=Extra.allow)
        if 'additionalProperties' in self.extra_template_data:
            config_parameters['extra'] = 'Extra.allow'
            self.imports.append(Import(from_='pydantic', import_='Extra'))

        if config_parameters:
            # Local import — presumably avoids a circular import at module
            # load time; confirm before hoisting to the top of the file.
            from datamodel_code_generator.model.pydantic import Config

            self.extra_template_data['config'] = Config.parse_obj(
                config_parameters)

        for field in fields:
            if field.field:
                self.imports.append(Import(from_='pydantic', import_='Field'))
Exemplo n.º 8
0
def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    """OpenAPIParser.get_data_type maps (type, format) to the expected DataType."""
    imports_: Optional[List[Import]] = None
    if from_ and import_:
        imports_ = [Import(from_=from_, import_=import_)]

    parser = OpenAPIParser(BaseModel, CustomRootType)
    actual = parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    )
    assert actual == DataType(type=result_type, imports_=imports_)
Exemplo n.º 9
0
def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    """JsonSchemaParser.get_data_type maps (type, format) to the expected DataType."""
    imports: Optional[List[Import]] = []
    if from_ and import_:
        imports = [Import(from_=from_, import_=import_)]

    parser = JsonSchemaParser('')
    actual = parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    )
    assert actual == DataType(type=result_type, imports=imports)
Exemplo n.º 10
0
    def get_parameter_type(
        self,
        parameters: ParameterObject,
        snake_case: bool,
        path: List[str],
    ) -> Optional[Argument]:
        """Build a function Argument for an OpenAPI parameter object.

        Returns None when no schema can be resolved for the parameter.
        Renamed (snake_cased) parameters get a FastAPI default like
        ``Query(..., alias='origName')`` so the wire name is preserved.
        """
        orig_name = parameters.name
        if snake_case:
            name = stringcase.snakecase(parameters.name)
        else:
            name = parameters.name

        schema: Optional[JsonSchemaObject] = None
        data_type: Optional[DataType] = None
        # Only the first media type under "content" is considered.
        for content in parameters.content.values():
            if isinstance(content.schema_, ReferenceObject):
                data_type = self.get_ref_data_type(content.schema_.ref)
                ref_model = self.get_ref_model(content.schema_.ref)
                schema = JsonSchemaObject.parse_obj(ref_model)
            else:
                schema = content.schema_
            break
        if not data_type:
            if not schema:
                schema = parameters.schema_
            data_type = self.parse_schema(name, schema, [*path, name])
        if not schema:
            return None

        field = DataModelField(
            name=name,
            data_type=data_type,
            # Path parameters are always required per the OpenAPI spec.
            required=parameters.required
            or parameters.in_ == ParameterLocation.path,
        )

        # BUG FIX: "default" was previously unbound (UnboundLocalError at the
        # Argument(...) call) when the name was snake_cased but the parameter
        # had no "in" location; initialise it to None up front.
        default: Optional[str] = None
        if orig_name != name:
            if parameters.in_:
                # Alias the renamed argument back to its original wire name
                # via the matching FastAPI parameter class (Query, Path, ...).
                param_is = parameters.in_.value.lower().capitalize()
                self.imports_for_fastapi.append(
                    Import(from_='fastapi', import_=param_is))
                default = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
        else:
            default = repr(schema.default) if schema.has_default else None
        self.imports_for_fastapi.append(field.imports)
        self.data_types.append(field.data_type)
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
Exemplo n.º 11
0
    def response(self) -> str:
        """Return the response type hint built from the 2xx JSON responses.

        Registers model imports as a side effect. Several response models
        become ``Union[...]``; no JSON 2xx response yields ``"None"``.
        """
        models: List[str] = []
        for response in self.response_objects:
            # expect 2xx
            if response.status_code.startswith("2"):
                for content_type, schema in response.contents.items():
                    if content_type == "application/json":
                        if schema.is_array:
                            if isinstance(schema.items, list):
                                # Tuple-style items: join every ref name.
                                type_ = f'List[{",".join(i.ref_object_name for i in schema.items)}]'
                                self.imports.extend(
                                    Import(
                                        from_=model_path_var.get(),
                                        import_=i.ref_object_name,
                                    )
                                    for i in schema.items
                                )
                            else:
                                type_ = f'List[{schema.items.ref_object_name}]'
                                self.imports.append(
                                    Import(
                                        from_=model_path_var.get(),
                                        import_=schema.items.ref_object_name,
                                    )
                                )
                            self.imports.append(IMPORT_LIST)
                        else:
                            type_ = schema.ref_object_name
                            self.imports.append(
                                Import(
                                    from_=model_path_var.get(),
                                    import_=schema.ref_object_name,
                                )
                            )
                        models.append(type_)

        if not models:
            return "None"
        if len(models) > 1:
            return f'Union[{",".join(models)}]'
        return models[0]
def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    """get_data_type maps (type, format) to the expected DataType (dict compare)."""
    expected_import: Optional[Import] = (
        Import(from_=from_, import_=import_) if from_ and import_ else None
    )

    parser = JsonSchemaParser('')
    actual = parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    )
    assert actual.dict() == DataType(type=result_type, import_=expected_import).dict()
Exemplo n.º 13
0
    def get_parameter_type(self, parameter: Dict[str, Union[str, Dict[str,
                                                                      Any]]],
                           snake_case: bool) -> Argument:
        """Build a function Argument for an OpenAPI operation parameter.

        Resolves $ref parameters, prefers a schema found under "content"
        over the plain "schema" key, and aliases snake_cased names back to
        the original wire name via the FastAPI class matching "in".

        Raises:
            TypeError: if the parameter's "in" value is not a string.
        """
        # A parameter may itself be a reference object; resolve it first.
        ref: Optional[str] = parameter.get('$ref')  # type: ignore
        if ref:
            parameter = get_ref_body(ref, self.openapi_model_parser,
                                     self.components)
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)
        # "content" style parameters: take the first media type's schema.
        content = parameter.get('content')
        schema: Optional[JsonSchemaObject] = None
        if content and isinstance(content, dict):
            content_schema = [
                c.get("schema") for c in content.values()
                if isinstance(c.get("schema"), dict)
            ]
            if content_schema:
                schema = JsonSchemaObject.parse_obj(content_schema[0])
        if not schema:
            schema = JsonSchemaObject.parse_obj(parameter["schema"])

        field = DataModelField(
            name=name,
            data_type=self.get_data_type(schema, 'parameter'),
            # Path parameters are always required per the OpenAPI spec.
            required=parameter.get("required")
            or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            has_in = parameter.get('in')
            if has_in and isinstance(has_in, str):
                # Alias the renamed argument back to its original wire name
                # via the matching FastAPI parameter class (Query, Path, ...).
                param_is = has_in.lower().capitalize()
                self.imports.append(Import(from_='fastapi', import_=param_is))
                default: Optional[
                    str] = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
            else:
                # https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#parameterObject
                # the spec says 'in' is a str type
                raise TypeError(
                    f'Issue processing parameter for "in", expected a str, but got something else: {str(parameter)}'
                )
        else:
            default = repr(schema.default) if schema.has_default else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
Exemplo n.º 14
0
 def get_data_type(self, schema: JsonSchemaObject) -> DataType:
     """Resolve *schema* to a DataType, registering model imports for $refs.

     Arrays recurse into their item schemas; everything else falls back to
     the parser's plain type mapping.
     """
     if schema.ref:
         ref_type = self.open_api_model_parser.get_ref_data_type(schema.ref)
         ref_type.imports_.append(
             Import(
                 # TODO: Improve import statements
                 from_=model_path_var.get(),
                 import_=ref_type.type,
             )
         )
         return ref_type

     if schema.is_array:
         # TODO: Improve handling array
         item_schemas = (
             schema.items if isinstance(schema.items, list) else [schema.items]
         )
         item_types = [self.get_data_type(item) for item in item_schemas]
         return self.open_api_model_parser.data_type(
             data_types=item_types, is_list=True
         )

     return self.open_api_model_parser.get_data_type(schema)
Exemplo n.º 15
0
    def get_parameter_type(
        self, parameter: Dict[str, Union[str, Dict[str, Any]]], snake_case: bool
    ) -> Argument:
        """Build a function Argument for an OpenAPI operation parameter.

        Resolves $ref parameters, prefers a schema found under "content"
        over the plain "schema" key, and aliases snake_cased names back to
        the original wire name via a FastAPI ``Query`` default.
        """
        # A parameter may itself be a reference object; resolve it first.
        ref: Optional[str] = parameter.get('$ref')  # type: ignore
        if ref:
            parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
        name: str = parameter["name"]  # type: ignore
        orig_name = name
        if snake_case:
            name = stringcase.snakecase(name)
        # "content" style parameters: take the first media type's schema.
        content = parameter.get('content')
        schema: Optional[JsonSchemaObject] = None
        if content and isinstance(content, dict):
            content_schema = [
                c.get("schema")
                for c in content.values()
                if isinstance(c.get("schema"), dict)
            ]
            if content_schema:
                schema = JsonSchemaObject.parse_obj(content_schema[0])
        if not schema:
            schema = JsonSchemaObject.parse_obj(parameter["schema"])

        field = DataModelField(
            name=name,
            data_type=self.get_data_type(schema, 'parameter'),
            # Path parameters are always required per the OpenAPI spec.
            required=parameter.get("required") or parameter.get("in") == "path",
        )
        self.imports.extend(field.imports)
        if orig_name != name:
            # Renamed by snake_casing: alias back to the original wire name.
            default: Optional[
                str
            ] = f"Query({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
            self.imports.append(Import(from_='fastapi', import_='Query'))
        else:
            default = repr(schema.default) if schema.has_default else None
        return Argument(
            name=field.name,
            type_hint=field.type_hint,
            default=default,  # type: ignore
            default_value=schema.default,
            required=field.required,
        )
Exemplo n.º 16
0
 def request(self) -> Optional[Argument]:
     """Return the JSON request-body argument, or None when there is none.

     Registers the body model's import as a side effect.
     """
     body_arguments: List[Argument] = []
     for request_object in self.request_objects:
         for media_type, schema in request_object.contents.items():
             # TODO: support other content-types
             if media_type != "application/json":
                 continue
             # TODO: support multiple body
             body_arguments.append(
                 Argument(
                     name='body',  # type: ignore
                     type_hint=schema.ref_object_name,
                     required=request_object.required,
                 )
             )
             self.imports.append(
                 Import(
                     from_=model_path_var.get(),
                     import_=schema.ref_object_name,
                 )
             )
     if body_arguments:
         return body_arguments[0]
     return None
Exemplo n.º 17
0
 def parse_list_item(
     self, name: str, target_items: List[JsonSchemaObject], path: List[str]
 ) -> List[DataType]:
     """Return the DataTypes for each schema in a list of array items.

     Trivial items (only a ``type`` set) map straight to data types;
     non-trivial inline items are parsed into their own singular-named
     models as a side effect.
     """
     data_types: List[DataType] = []
     for item in target_items:
         if item.ref:  # $ref
             data_types.append(
                 self.data_type(
                     type=self.model_resolver.add_ref(item.ref).name,
                     ref=True,
                     version_compatible=True,
                 )
             )
         elif not any(v for k, v in vars(item).items() if k != 'type'):
             # trivial types
             data_types.extend(self.get_data_type(item))
         elif (
             item.is_array
             and isinstance(item.items, JsonSchemaObject)
             and not any(v for k, v in vars(item.items).items() if k != 'type')
         ):
             # trivial item types
             types = [t.type_hint for t in self.get_data_type(item.items)]
             data_types.append(
                 self.data_type(
                     type=f"List[Union[{', '.join(types)}]]"
                     if len(types) > 1
                     else f"List[{types[0]}]",
                     imports_=[Import(from_='typing', import_='List')],
                 )
             )
         else:
             # Complex inline item: parse it as a nested (singular) model.
             data_types.append(
                 self.data_type(
                     type=self.parse_object(
                         name, item, path, singular_name=True
                     ).name,
                     ref=True,
                     version_compatible=True,
                 )
             )
     return data_types
Exemplo n.º 18
0
def generate_app_code(environment, parsed_object) -> str:
    """Render main.jinja2 with one router import per top-level path segment.

    Operations are grouped by the first segment of their path
    (e.g. ``/users/{id}`` -> ``users``) and each group contributes a
    ``<name>_router`` import from the controllers package.
    """
    grouped_operations = defaultdict(list)
    for operation in parsed_object.operations:
        segment = operation.path.strip('/').split('/')[0]
        grouped_operations[segment].append(operation)

    imports = Imports()
    routers = []
    for segment in grouped_operations:
        router_name = segment + '_router'
        imports.append(
            Import(from_=CONTROLLERS_DIR_NAME + '.' + segment,
                   import_=router_name))
        routers.append(router_name)

    template = environment.get_template(str(Path('main.jinja2')))
    return template.render(imports=imports, routers=routers)
)
from datamodel_code_generator.types import DataType, Types

type_map: Dict[Types, DataType] = {
    Types.integer: DataType(type='int'),
    Types.int32: DataType(type='int'),
    Types.int64: DataType(type='int'),
    Types.number: DataType(type='float'),
    Types.float: DataType(type='float'),
    Types.double: DataType(type='float'),
    Types.time: DataType(type='time'),
    Types.string: DataType(type='str'),
    Types.byte: DataType(type='str'),  # base64 encoded string
    Types.binary: DataType(type='bytes'),
    Types.date: DataType(
        type='date', imports_=[Import(from_='datetime', import_='date')]
    ),
    Types.date_time: DataType(
        type='datetime', imports_=[Import(from_='datetime', import_='datetime')]
    ),
    Types.password: DataType(
        type='SecretStr', imports_=[Import(from_='pydantic', import_='SecretStr')]
    ),
    Types.email: DataType(
        type='EmailStr', imports_=[Import(from_='pydantic', import_='EmailStr')]
    ),
    Types.uuid: DataType(type='UUID', imports_=[Import(from_='uuid', import_='UUID')]),
    Types.uuid1: DataType(
        type='UUID1', imports_=[Import(from_='pydantic', import_='UUID1')]
    ),
    Types.uuid2: DataType(
Exemplo n.º 20
0
    def parse_object_fields(self, obj: JsonSchemaObject) -> List[DataModelFieldBase]:
        """Convert the properties of *obj* into data-model fields.

        Each property is dispatched on its schema shape ($ref, array,
        anyOf/oneOf/allOf, object, enum, or plain type); nested schemas are
        parsed into their own models as a side effect.
        """
        properties: Dict[str, JsonSchemaObject] = (
            obj.properties if obj.properties is not None else {}
        )
        requires: Set[str] = {*obj.required} if obj.required is not None else {*()}
        fields: List[DataModelFieldBase] = []

        for field_name, field in properties.items():
            is_list: bool = False
            is_union: bool = False
            field_types: List[DataType]
            # BUG FIX: "required" membership must be checked against the raw
            # JSON property name, not the sanitised field_name —
            # get_valid_field_name_and_alias may rename the field, which
            # previously made renamed required properties appear optional.
            original_field_name: str = field_name
            field_name, alias = get_valid_field_name_and_alias(field_name)
            if field.ref:
                # Reference to an existing model.
                field_types = [
                    self.data_type(
                        type=self.get_class_name(field.ref_object_name, unique=False),
                        ref=True,
                        version_compatible=True,
                    )
                ]
            elif field.is_array:
                class_name = self.get_class_name(field_name)
                array_fields, array_field_classes = self.parse_array_fields(
                    class_name, field
                )
                field_types = array_fields[0].data_types
                is_list = True
                is_union = True
            elif field.anyOf:
                field_types = self.parse_any_of(field_name, field)
            elif field.oneOf:
                field_types = self.parse_one_of(field_name, field)
            elif field.allOf:
                class_name = self.get_class_name(field_name)
                field_types = self.parse_all_of(class_name, field)
            elif field.is_object:
                if field.properties:
                    # Inline object with properties: emit a nested model.
                    class_name = self.get_class_name(field_name)
                    self.parse_object(class_name, field)
                    field_types = [
                        self.data_type(
                            type=class_name, ref=True, version_compatible=True
                        )
                    ]
                else:
                    # Free-form object: fall back to Dict[str, Any].
                    field_types = [
                        self.data_type(
                            type='Dict[str, Any]',
                            imports_=[
                                Import(from_='typing', import_='Any'),
                                Import(from_='typing', import_='Dict'),
                            ],
                        )
                    ]
            elif field.enum:
                enum = self.parse_enum(field_name, field)
                field_types = [
                    self.data_type(type=enum.name, ref=True, version_compatible=True)
                ]
            else:
                field_types = self.get_data_type(field)
            required: bool = original_field_name in requires
            fields.append(
                self.data_model_field_type(
                    name=field_name,
                    example=field.examples,
                    description=field.description,
                    default=field.default,
                    title=field.title,
                    data_types=field_types,
                    required=required,
                    is_list=is_list,
                    is_union=is_union,
                    alias=alias,
                )
            )
        return fields
Exemplo n.º 21
0
    def parse_object_fields(self, obj: JsonSchemaObject,
                            path: List[str]) -> List[DataModelFieldBase]:
        """Convert the properties of *obj* into data-model fields.

        Each property is dispatched on its schema shape ($ref, array,
        anyOf/oneOf/allOf, object, enum, or plain type); nested schemas are
        parsed into their own models under ``path`` as a side effect.
        """
        properties: Dict[str,
                         JsonSchemaObject] = (obj.properties if obj.properties
                                              is not None else {})
        requires: Set[str] = {*obj.required
                              } if obj.required is not None else {*()}
        fields: List[DataModelFieldBase] = []

        for field_name, field in properties.items():
            is_list: bool = False
            is_union: bool = False
            field_types: List[DataType]
            # Keep the raw JSON name: "required" membership is checked
            # against it below, since the resolver may rename the field.
            original_field_name: str = field_name
            constraints: Optional[Mapping[str, Any]] = None
            field_name, alias = self.model_resolver.get_valid_field_name_and_alias(
                field_name)
            if field.ref:
                # Reference to an existing model.
                field_types = [
                    self.data_type(
                        type=self.model_resolver.add_ref(field.ref).name,
                        ref=True,
                        version_compatible=True,
                    )
                ]
            elif field.is_array:
                array_field, array_field_classes = self.parse_array_fields(
                    field_name, field, [*path, field_name])
                field_types = array_field.data_types
                is_list = True
                is_union = True
            elif field.anyOf:
                field_types = self.parse_any_of(field_name, field,
                                                [*path, field_name])
            elif field.oneOf:
                field_types = self.parse_one_of(field_name, field,
                                                [*path, field_name])
            elif field.allOf:
                field_types = self.parse_all_of(field_name, field,
                                                [*path, field_name])
            elif field.is_object:
                if field.properties:
                    # Inline object with properties: emit a nested model.
                    field_types = [
                        self.data_type(
                            type=self.parse_object(field_name,
                                                   field, [*path, field_name],
                                                   unique=True).name,
                            ref=True,
                            version_compatible=True,
                        )
                    ]
                else:
                    # Free-form object: fall back to Dict[str, Any].
                    field_types = [
                        self.data_type(
                            type='Dict[str, Any]',
                            imports_=[
                                Import(from_='typing', import_='Any'),
                                Import(from_='typing', import_='Dict'),
                            ],
                        )
                    ]
            elif field.enum:
                enum = self.parse_enum(field_name,
                                       field, [*path, field_name],
                                       unique=True)
                field_types = [
                    self.data_type(type=enum.name,
                                   ref=True,
                                   version_compatible=True)
                ]
            else:
                field_types = self.get_data_type(field)
                if self.field_constraints:
                    constraints = field.dict()
            required: bool = original_field_name in requires
            fields.append(
                self.data_model_field_type(
                    name=field_name,
                    example=field.examples,
                    description=field.description,
                    default=field.default,
                    title=field.title,
                    data_types=field_types,
                    required=required,
                    is_list=is_list,
                    is_union=is_union,
                    alias=alias,
                    constraints=constraints,
                ))
        return fields
Exemplo n.º 22
0
            DataType(type='NegativeFloat', imports=[IMPORT_NEGATIVE_FLOAT]),
        ),
    ],
)
def test_get_data_float_type(types, params, data_type):
    """DataTypeManager maps a float-family Types value (+constraints) to data_type."""
    assert DataTypeManager().get_data_float_type(types, **params) == data_type


@pytest.mark.parametrize(
    'types,params,data_type',
    [
        (
            Types.decimal,
            {},
            DataType(type='Decimal',
                     imports=[Import(from_='decimal', import_='Decimal')]),
        ),
        (
            Types.decimal,
            {
                'maximum': 10
            },
            DataType(
                type='condecimal',
                is_func=True,
                kwargs={'le': 10},
                imports=[IMPORT_CONDECIMAL],
            ),
        ),
        (
            # NOTE(review): this case supplies only 2 values for the 3
            # declared params ('types,params,data_type') — a params dict
            # appears to be missing; confirm against the upstream suite.
            Types.decimal,
            DataType(type='NegativeFloat', import_=IMPORT_NEGATIVE_FLOAT),
        ),
    ],
)
def test_get_data_float_type(types, params, data_type):
    """DataTypeManager maps decimal Types (+constraints) to the expected DataType."""
    assert DataTypeManager().get_data_float_type(types, **params) == data_type


@pytest.mark.parametrize(
    'types,params,data_type',
    [
        (
            Types.decimal,
            {},
            DataType(
                type='Decimal', import_=Import(from_='decimal', import_='Decimal')
            ),
        ),
        (
            Types.decimal,
            {'maximum': 10},
            DataType(
                type='condecimal',
                is_func=True,
                kwargs={'le': 10},
                import_=IMPORT_CONDECIMAL,
            ),
        ),
        (
            Types.decimal,
            {'exclusiveMaximum': 10},
Exemplo n.º 24
0
type_map: Dict[Types, DataType] = {
    Types.integer:
    DataType(type='int'),
    Types.int32:
    DataType(type='int'),
    Types.int64:
    DataType(type='int'),
    Types.number:
    DataType(type='float'),
    Types.float:
    DataType(type='float'),
    Types.double:
    DataType(type='float'),
    Types.decimal:
    DataType(type='Decimal',
             imports_=[Import(from_='decimal', import_='Decimal')]),
    Types.time:
    DataType(type='time'),
    Types.string:
    DataType(type='str'),
    Types.byte:
    DataType(type='str'),  # base64 encoded string
    Types.binary:
    DataType(type='bytes'),
    Types.date:
    DataType(type='date', imports_=[Import(from_='datetime', import_='date')]),
    Types.date_time:
    DataType(type='datetime',
             imports_=[Import(from_='datetime', import_='datetime')]),
    Types.password:
    DataType(type='SecretStr',
Exemplo n.º 25
0
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True
    ) -> Union[str, Dict[Tuple[str, ...], str]]:
        """Parse every ``components/schemas`` entry and render model source.

        Each raw schema is dispatched to the matching ``parse_*`` handler by
        its shape; the resulting data models are then grouped by their dotted
        module name and rendered into one source body per module.

        Args:
            with_import: When true, prepend the collected import statements to
                each rendered module (plus ``IMPORT_ANNOTATIONS`` when the
                target is Python 3.7).
            format_: When true, run each rendered body through
                ``format_code`` for the configured target Python version.

        Returns:
            A single source string when everything lands in the root module
            (backward-compatible behaviour), otherwise a mapping from module
            path tuples (ending in a file name) to rendered source.
        """
        for obj_name, raw_obj in self.base_parser.specification['components'][
                'schemas'].items():  # type: str, Dict
            obj = JsonSchemaObject.parse_obj(raw_obj)
            # Dispatch on schema shape; anything unmatched becomes a root type.
            if obj.is_object:
                self.parse_object(obj_name, obj)
            elif obj.is_array:
                self.parse_array(obj_name, obj)
            elif obj.enum:
                self.parse_enum(obj_name, obj)
            elif obj.allOf:
                self.parse_all_of(obj_name, obj)
            else:
                self.parse_root_type(obj_name, obj)

        if with_import:
            if self.target_python_version == PythonVersion.PY_37:
                self.imports.append(IMPORT_ANNOTATIONS)

        # Sort models so definitions precede their references, and collect the
        # names that need the post-definition resolve-reference action.
        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], str] = {}

        # Module key: every dotted-name component except the class name itself.
        module_key = lambda x: (*x.name.split('.')[:-1], )

        # groupby only groups consecutive items, so sort by the same key first.
        grouped_models = groupby(sorted(sorted_data_models.values(),
                                        key=module_key),
                                 key=module_key)
        for module, models in ((k, [*v]) for k, v in grouped_models):
            module_path = '.'.join(module)

            result: List[str] = []
            imports = Imports()
            models_to_update: List[str] = []

            for model in models:
                if model.name in require_update_action_models:
                    models_to_update += [model.name]
                imports.append(model.imports)
                # Add relative imports for referenced models that live in a
                # different module than this one.
                for ref_name in model.reference_classes:
                    if '.' not in ref_name:
                        # Unqualified reference: same module, nothing to import.
                        continue
                    ref_path = ref_name.rsplit('.', 1)[0]
                    if ref_path == module_path:
                        # Reference resolves inside this very module.
                        continue
                    imports.append(Import(from_='.', import_=ref_path))

            if with_import:
                result += [imports.dump(), self.imports.dump(), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(models_to_update)
                ]

            body = '\n'.join(result)
            if format_:
                body = format_code(body, self.target_python_version)

            if module:
                # Nested module: emit "<name>.py" and make sure the enclosing
                # package gets an (empty) __init__.py entry.
                module = (*module[:-1], f'{module[-1]}.py')
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = ''
            else:
                module = ('__init__.py', )

            results[module] = body

        # retain existing behaviour
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )]

        return results
Exemplo n.º 26
0
    def parse_object_fields(self,
                            obj: JsonSchemaObject) -> List[DataModelField]:
        """Build one DataModelField per property declared on ``obj``.

        Each property is dispatched on its schema shape ($ref, array, object,
        enum, anyOf, allOf, or plain type) to determine its data types; nested
        objects, arrays and enums are parsed into their own models on the way.
        """
        properties: Dict[str, JsonSchemaObject] = (
            obj.properties if obj.properties is not None else {})
        required_names: Set[str] = (set(obj.required)
                                    if obj.required is not None else set())

        fields: List[DataModelField] = []
        for field_name, field in properties.items():  # type: ignore
            is_list = False
            field_types: List[DataType]
            if field.ref:
                # Direct $ref: point at the referenced model by name.
                field_types = [
                    self.data_type(
                        type=field.ref_object_name,
                        ref=True,
                        version_compatible=True,
                    )
                ]
            elif field.is_array:
                nested_class = self.get_class_name(field_name)
                array_fields, _ = self.parse_array_fields(nested_class, field)
                field_types = array_fields[0].data_types
                is_list = True
            elif field.is_object:
                if field.properties:
                    # Inline object with properties: parse it as its own model
                    # and reference it from this field.
                    nested_class = self.get_class_name(field_name)
                    self.parse_object(nested_class, field)
                    field_types = [
                        self.data_type(
                            type=nested_class,
                            ref=True,
                            version_compatible=True,
                        )
                    ]
                else:
                    # Untyped object: fall back to a generic string-keyed dict.
                    field_types = [
                        self.data_type(
                            type='Dict[str, Any]',
                            imports_=[
                                Import(from_='typing', import_='Any'),
                                Import(from_='typing', import_='Dict'),
                            ],
                        )
                    ]
            elif field.enum:
                parsed_enum = self.parse_enum(field_name, field)
                field_types = [
                    self.data_type(
                        type=parsed_enum.name,
                        ref=True,
                        version_compatible=True,
                    )
                ]
            elif field.anyOf:
                field_types = self.parse_any_of(field_name, field)
            elif field.allOf:
                field_types = self.parse_all_of(field_name, field)
            else:
                field_types = [self.get_data_type(field)]

            fields.append(
                self.data_model_field_type(
                    name=field_name,
                    data_types=field_types,
                    required=field_name in required_names,
                    is_list=is_list,
                ))
        return fields
Exemplo n.º 27
0
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True,
        settings_path: Optional[Path] = None,
    ) -> Union[str, Dict[Tuple[str, ...], Result]]:
        """Render all parsed data models into source code, one Result per module.

        Runs ``parse_raw`` first, then performs two passes over the models:
        a normalization pass (deduplicate custom-root models, strip forbidden
        root-model inheritance, resolve class-name collisions per module) and
        a rendering pass (compute cross-module imports, optionally merge
        duplicate models and set enum-member defaults, then dump templates).

        Args:
            with_import: When true, prepend collected imports to each module
                (plus ``IMPORT_ANNOTATIONS`` unless targeting Python 3.6).
            format_: When true, format each body with ``CodeFormatter``.
            settings_path: Optional formatter settings location, forwarded to
                ``CodeFormatter``.

        Returns:
            A single source string when everything lands in the root module
            (backward-compatible behaviour), otherwise a mapping from module
            path tuples (ending in a file name) to ``Result`` objects.
        """

        self.parse_raw()

        if with_import:
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: Optional[CodeFormatter] = CodeFormatter(
                self.target_python_version, settings_path)
        else:
            code_formatter = None

        # Sort models so definitions precede their references, and collect
        # paths that need the post-definition resolve-reference action.
        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], Result] = {}

        module_key = lambda x: x.module_path

        # process in reverse order to correctly establish module levels
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=module_key, reverse=True),
            key=module_key,
        )

        module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []

        # Pass 1: per-module normalization of the model list.
        for module, models in ((k, [*v]) for k, v in grouped_models
                               ):  # type: Tuple[str, ...], List[DataModel]

            for model in models:
                if isinstance(model, self.data_model_root_type):
                    root_data_type = model.fields[0].data_type

                    # backward compatible
                    # Remove duplicated root model
                    if (root_data_type.reference and not root_data_type.is_dict
                            and not root_data_type.is_list
                            and root_data_type.reference.source in models
                            and root_data_type.reference.name
                            == self.model_resolver.get_class_name(
                                model.reference.original_name, unique=False)):
                        # Replace referenced duplicate model to original model
                        for child in model.reference.children[:]:
                            child.replace_reference(root_data_type.reference)
                        models.remove(model)
                        continue

                    #  Custom root model can't be inherited on restriction of Pydantic
                    for child in model.reference.children:
                        # inheritance model
                        if isinstance(child, DataModel):
                            for base_class in child.base_classes:
                                if base_class.reference == model.reference:
                                    child.base_classes.remove(base_class)

            module_models.append((
                module,
                models,
            ))

            # Per-module resolver that avoids class names colliding with any
            # imported name in this module; collisions get a 'Model' suffix.
            scoped_model_resolver = ModelResolver(
                exclude_names={
                    i.alias or i.import_
                    for m in models for i in m.imports
                },
                duplicate_name_suffix='Model',
            )

            for model in models:
                class_name: str = model.class_name
                generated_name: str = scoped_model_resolver.add(
                    model.path, class_name, unique=True, class_name=True).name
                if class_name != generated_name:
                    # Rename the reference, preserving any dotted prefix.
                    if '.' in model.reference.name:
                        model.reference.name = (
                            f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                        )
                    else:
                        model.reference.name = generated_name

        # Pass 2: render each module's models to source.
        for module, models in module_models:
            init = False
            if module:
                # Ensure the enclosing package has an __init__.py entry.
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = Result(body='')
                if (*module, '__init__.py') in results:
                    # This module is itself a package: render into its
                    # __init__.py and remember to adjust relative imports.
                    module = (*module, '__init__.py')
                    init = True
                else:
                    module = (*module[:-1], f'{module[-1]}.py')
            else:
                module = ('__init__.py', )

            result: List[str] = []
            imports = Imports()
            scoped_model_resolver = ModelResolver()

            for model in models:
                imports.append(model.imports)
                for data_type in model.all_data_types:
                    # To change from/import

                    if not data_type.reference or data_type.reference.source in models:
                        # No need to import non-reference model.
                        # Or, Referenced model is in the same file. we don't need to import the model
                        continue

                    if isinstance(data_type, BaseClassDataType):
                        from_ = ''.join(
                            relative(model.module_name, data_type.full_name))
                        import_ = data_type.reference.short_name
                        full_path = from_, import_
                    else:
                        from_, import_ = full_path = relative(
                            model.module_name, data_type.full_name)

                    # Deduplicate the imported name within this module;
                    # a clashing import gets a resolver-generated alias.
                    alias = scoped_model_resolver.add(full_path, import_).name

                    name = data_type.reference.short_name
                    if from_ and import_ and alias != name:
                        data_type.alias = f'{alias}.{name}'

                    if init:
                        from_ += "."
                    imports.append(
                        Import(from_=from_, import_=import_, alias=alias))

            if self.reuse_model:
                # Merge structurally identical models: later duplicates become
                # either a removal (enums) or a thin subclass of the original.
                model_cache: Dict[Tuple[str, ...], Reference] = {}
                duplicates = []
                for model in models:
                    model_key = tuple(
                        to_hashable(v) for v in (
                            model.base_classes,
                            model.extra_template_data,
                            model.fields,
                        ))
                    cached_model_reference = model_cache.get(model_key)
                    if cached_model_reference:
                        if isinstance(model, Enum):
                            for child in model.reference.children[:]:
                                # child is resolved data_type by reference
                                data_model = get_most_of_parent(child)
                                # TODO: replace reference in all modules
                                if data_model in models:  # pragma: no cover
                                    child.replace_reference(
                                        cached_model_reference)
                            duplicates.append(model)
                        else:
                            index = models.index(model)
                            inherited_model = model.__class__(
                                fields=[],
                                base_classes=[cached_model_reference],
                                description=model.description,
                                reference=Reference(
                                    name=model.name,
                                    path=model.reference.path + '/reuse',
                                ),
                            )
                            if (cached_model_reference.path
                                    in require_update_action_models):
                                require_update_action_models.append(
                                    inherited_model.path)
                            models.insert(index, inherited_model)
                            models.remove(model)

                    else:
                        model_cache[model_key] = model.reference

                for duplicate in duplicates:
                    models.remove(duplicate)

            if self.set_default_enum_member:
                # Rewrite raw default values into enum members when the field
                # type resolves to an Enum model that defines that value.
                for model in models:
                    for model_field in model.fields:
                        if not model_field.default:
                            continue
                        for data_type in model_field.data_type.all_data_types:
                            if data_type.reference and isinstance(
                                    data_type.reference.source,
                                    Enum):  # pragma: no cover
                                enum_member = data_type.reference.source.find_member(
                                    model_field.default)
                                if enum_member:
                                    model_field.default = enum_member
            if with_import:
                result += [str(self.imports), str(imports), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(
                        m.reference.short_name for m in models
                        if m.path in require_update_action_models),
                ]

            body = '\n'.join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            results[module] = Result(body=body, source=models[0].file_path)

        # retain existing behaviour
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )].body

        return results
Exemplo n.º 28
0
 Types.number:
 DataType(type='float'),
 Types.float:
 DataType(type='float'),
 Types.double:
 DataType(type='float'),
 Types.time:
 DataType(type='time'),
 Types.string:
 DataType(type='str'),
 Types.byte:
 DataType(type='str'),  # base64 encoded string
 Types.binary:
 DataType(type='bytes'),
 Types.date:
 DataType(type='date', import_=Import(from_='datetime', import_='date')),
 Types.date_time:
 DataType(type='datetime',
          import_=Import(from_='datetime', import_='datetime')),
 Types.password:
 DataType(type='SecretStr',
          import_=Import(from_='pydantic', import_='SecretStr')),
 Types.email:
 DataType(type='EmailStr',
          import_=Import(from_='pydantic', import_='EmailStr')),
 Types.uuid:
 DataType(type='UUID', import_=Import(from_='pydantic', import_='UUID')),
 Types.uuid1:
 DataType(type='UUID1', import_=Import(from_='pydantic', import_='UUID1')),
 Types.uuid2:
 DataType(type='UUID2', import_=Import(from_='pydantic', import_='UUID2')),