def get_data_type(self, schema: JsonSchemaObject, suffix: str = '') -> DataType:
    if schema.ref:
        data_type = self.openapi_model_parser.get_ref_data_type(schema.ref)
        self.imports.append(
            Import(  # TODO: Improve import statements
                from_=model_module_name_var.get(),
                import_=data_type.type_hint,
            )
        )
        return data_type
    elif schema.is_array:
        # TODO: Improve handling array
        items = schema.items if isinstance(schema.items, list) else [schema.items]
        return self.openapi_model_parser.data_type(
            data_types=[self.get_data_type(i, suffix) for i in items], is_list=True
        )
    elif schema.is_object:
        camelcase_path = stringcase.camelcase(self.path[1:].replace("/", "_"))
        capitalized_suffix = suffix.capitalize()
        name: str = f'{camelcase_path}{self.type.capitalize()}{capitalized_suffix}'
        path = ['paths', self.path, self.type, capitalized_suffix]
        data_type = self.openapi_model_parser.parse_object(name, schema, path)
        self.imports.append(
            Import(from_=model_module_name_var.get(), import_=data_type.type_hint)
        )
        return data_type
    return self.openapi_model_parser.get_data_type(schema)

def get_parameter_type(
    self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
) -> Argument:
    schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
    format_ = schema.format or "default"
    type_ = json_schema_data_formats[schema.type][format_]
    name: str = parameter["name"]  # type: ignore
    orig_name = name
    if snake_case:
        name = stringcase.snakecase(name)
    field = DataModelField(
        name=name,
        data_type=type_map[type_],
        required=parameter.get("required") or parameter.get("in") == "path",
    )
    self.imports.extend(field.imports)
    if orig_name != name:
        default: Optional[str] = f"Query({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
        self.imports.append(Import(from_='fastapi', import_='Query'))
    else:
        default = repr(schema.default) if 'default' in parameter["schema"] else None
    return Argument(
        name=field.name,
        type_hint=field.type_hint,
        default=default,  # type: ignore
        default_value=schema.default,
        required=field.required,
    )

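# A minimal, hedged illustration of the alias/default rendering above: when the
# original parameter name is not already snake_case, the generated argument gets
# a FastAPI `Query(..., alias=...)` default string. `stringcase` is the package
# the generator itself uses; the concrete values below are made up for
# demonstration only.
import stringcase

orig_name = 'userId'
name = stringcase.snakecase(orig_name)  # 'user_id'
required = False
schema_default = 10
default = f"Query({'...' if required else repr(schema_default)}, alias='{orig_name}')"
print(name, '->', default)  # user_id -> Query(10, alias='userId')
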
def __init__(
    self,
    *,
    reference: Reference,
    fields: List[DataModelFieldBase],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[Reference]] = None,
    custom_base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Any]] = None,
    path: Optional[Path] = None,
    description: Optional[str] = None,
):
    super().__init__(
        reference=reference,
        fields=fields,
        decorators=decorators,
        base_classes=base_classes,
        custom_base_class=custom_base_class,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        path=path,
        description=description,
    )
    self._additional_imports.append(
        Import.from_full_path('pydantic.dataclasses.dataclass')
    )

def __init__(
    self,
    name: str,
    fields: List[DataModelField],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[str]] = None,
    custom_base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Any]] = None,
    auto_import: bool = True,
    reference_classes: Optional[List[str]] = None,
):
    super().__init__(
        name,
        fields,
        decorators,
        base_classes,
        custom_base_class=custom_base_class,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        auto_import=auto_import,
        reference_classes=reference_classes,
    )
    self.imports.append(Import.from_full_path('pydantic.dataclasses.dataclass'))

def request(self) -> Optional[Argument]:
    arguments: List[Argument] = []
    for requests in self.request_objects:
        for content_type, schema in requests.contents.items():
            # TODO: support other content-types
            if RE_APPLICATION_JSON_PATTERN.match(content_type):
                data_type = self.get_data_type(schema, 'request')
                arguments.append(
                    # TODO: support multiple body
                    Argument(
                        name='body',  # type: ignore
                        type_hint=data_type.type_hint,
                        required=requests.required,
                    )
                )
                self.imports.extend(data_type.imports)
            elif content_type == 'application/x-www-form-urlencoded':
                arguments.append(
                    # TODO: support form with `Form()`
                    Argument(
                        name='request',  # type: ignore
                        type_hint='Request',  # type: ignore
                        required=True,
                    )
                )
                self.imports.append(
                    Import.from_full_path('starlette.requests.Request')
                )
    if not arguments:
        return None
    return arguments[0]

def __init__(
    self,
    name: str,
    fields: List[DataModelField],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[str]] = None,
    custom_base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Any]] = None,
    auto_import: bool = True,
    reference_classes: Optional[List[str]] = None,
    imports: Optional[List[Import]] = None,
):
    super().__init__(
        name=name,
        fields=fields,
        decorators=decorators,
        base_classes=base_classes,
        custom_base_class=custom_base_class,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        auto_import=auto_import,
        reference_classes=reference_classes,
        imports=imports,
    )
    if 'additionalProperties' in self.extra_template_data:
        self.extra_template_data['config'] = Config(extra='Extra.allow')
        self.imports.append(Import(from_='pydantic', import_='Extra'))

def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
    any_of_data_types: List[DataType] = []
    for any_of_item in obj.anyOf:
        if any_of_item.ref:  # $ref
            any_of_data_types.append(
                self.data_type(
                    type=any_of_item.ref_object_name,
                    ref=True,
                    version_compatible=True,
                )
            )
        elif not any(v for k, v in vars(any_of_item).items() if k != 'type'):
            # trivial types
            any_of_data_types.append(self.get_data_type(any_of_item))
        elif (
            any_of_item.is_array
            and isinstance(any_of_item.items, JsonSchemaObject)
            and not any(
                v for k, v in vars(any_of_item.items).items() if k != 'type'
            )
        ):
            # trivial item types
            any_of_data_types.append(
                self.data_type(
                    type=f"List[{self.get_data_type(any_of_item.items).type_hint}]",
                    imports_=[Import(from_='typing', import_='List')],
                )
            )
        else:
            singular_name = get_singular_name(name)
            self.parse_object(singular_name, any_of_item)
            any_of_data_types.append(
                self.data_type(
                    type=singular_name, ref=True, version_compatible=True
                )
            )
    return any_of_data_types

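# Hedged, self-contained sketch of the "trivial type" test used above: an anyOf
# item counts as trivial when, apart from `type`, none of its attributes carry a
# truthy value. The stand-in class below is hypothetical; the real check runs on
# JsonSchemaObject instances.
class FakeItem:
    def __init__(self, type=None, ref=None, items=None):
        self.type = type
        self.ref = ref
        self.items = items

trivial = FakeItem(type='string')
nested = FakeItem(type='object', items=FakeItem(type='integer'))
for item in (trivial, nested):
    is_trivial = not any(v for k, v in vars(item).items() if k != 'type')
    print(item.type, is_trivial)  # string True / object False
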
def get_parameter_type(
    self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
) -> Argument:
    ref: Optional[str] = parameter.get('$ref')  # type: ignore
    if ref:
        parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
    name: str = parameter["name"]  # type: ignore
    orig_name = name
    if snake_case:
        name = stringcase.snakecase(name)
    schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
    field = DataModelField(
        name=name,
        data_type=self.get_data_type(schema),
        required=parameter.get("required") or parameter.get("in") == "path",
    )
    self.imports.extend(field.imports)
    if orig_name != name:
        default: Optional[str] = f"Query({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
        self.imports.append(Import(from_='fastapi', import_='Query'))
    else:
        default = repr(schema.default) if 'default' in parameter["schema"] else None
    return Argument(
        name=field.name,
        type_hint=field.type_hint,
        default=default,  # type: ignore
        default_value=schema.default,
        required=field.required,
    )

def __init__(
    self,
    *,
    reference: Reference,
    fields: List[DataModelFieldBase],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[Reference]] = None,
    custom_base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
    methods: Optional[List[str]] = None,
    path: Optional[Path] = None,
    description: Optional[str] = None,
) -> None:
    if not self.TEMPLATE_FILE_PATH:
        raise Exception('TEMPLATE_FILE_PATH is undefined')
    template_file_path = Path(self.TEMPLATE_FILE_PATH)
    if custom_template_dir is not None:
        custom_template_file_path = custom_template_dir / template_file_path.name
        if custom_template_file_path.exists():
            template_file_path = custom_template_file_path
    self.fields: List[DataModelFieldBase] = fields or []
    self.decorators: List[str] = decorators or []
    self._additional_imports: List[Import] = []
    self.base_classes: List[Reference] = [
        base_class for base_class in base_classes or [] if base_class
    ]
    self.custom_base_class = custom_base_class
    self.file_path: Optional[Path] = path
    self.reference: Reference = reference
    self.reference.source = self
    self.extra_template_data = (
        extra_template_data[self.name]
        if extra_template_data is not None
        else defaultdict(dict)
    )
    if not self.base_classes:
        base_class_full_path = custom_base_class or self.BASE_CLASS
        if base_class_full_path:
            self._additional_imports.append(
                Import.from_full_path(base_class_full_path)
            )
    if extra_template_data:
        all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
        if all_model_extra_template_data:
            self.extra_template_data.update(all_model_extra_template_data)
    self.methods: List[str] = methods or []
    self.description = description
    for field in self.fields:
        field.parent = self
    super().__init__(template_file_path=template_file_path)

def test_dump(inputs: Sequence[Tuple[Optional[str], str]], value):
    """Test creating import lines."""
    imports = Imports()
    imports.append(
        [Import(from_=from_, import_=import_) for from_, import_ in inputs]
    )
    assert str(imports) == value

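# A hedged example of one (inputs, value) pair this test could be parametrized
# with: a `from_` of None should render as a plain `import` statement, a string
# `from_` as a `from ... import ...` line. The exact ordering and grouping is up
# to Imports.__str__, so the expected string below is an assumption rather than
# a verified fixture.
example_inputs = [(None, 'pathlib'), ('typing', 'List')]
example_value = 'import pathlib\nfrom typing import List'
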
def __init__(
    self,
    name: str,
    fields: List[DataModelField],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[str]] = None,
    custom_base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Any]] = None,
    auto_import: bool = True,
    reference_classes: Optional[List[str]] = None,
    imports: Optional[List[Import]] = None,
):
    methods: List[str] = [field.method for field in fields if field.method]
    super().__init__(
        name=name,
        fields=fields,  # type: ignore
        decorators=decorators,
        base_classes=base_classes,
        custom_base_class=custom_base_class,
        custom_template_dir=custom_template_dir,
        extra_template_data=extra_template_data,
        auto_import=auto_import,
        reference_classes=reference_classes,
        imports=imports,
        methods=methods,
    )
    config_parameters: Dict[str, Any] = {}
    if 'additionalProperties' in self.extra_template_data:
        config_parameters['extra'] = 'Extra.allow'
        self.imports.append(Import(from_='pydantic', import_='Extra'))
    if config_parameters:
        from datamodel_code_generator.model.pydantic import Config

        self.extra_template_data['config'] = Config.parse_obj(config_parameters)
    for field in fields:
        if field.field:
            self.imports.append(Import(from_='pydantic', import_='Field'))

def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    if from_ and import_:
        imports_: Optional[List[Import]] = [Import(from_=from_, import_=import_)]
    else:
        imports_ = None
    parser = OpenAPIParser(BaseModel, CustomRootType)
    assert parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    ) == DataType(type=result_type, imports_=imports_)

def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    if from_ and import_:
        # the else branch assigns an empty list, never None, so the annotation
        # is List[Import] rather than Optional[List[Import]]
        imports: List[Import] = [Import(from_=from_, import_=import_)]
    else:
        imports = []
    parser = JsonSchemaParser('')
    assert parser.get_data_type(
        JsonSchemaObject(type=schema_type, format=schema_format)
    ) == DataType(type=result_type, imports=imports)

def get_parameter_type(
    self,
    parameters: ParameterObject,
    snake_case: bool,
    path: List[str],
) -> Optional[Argument]:
    orig_name = parameters.name
    if snake_case:
        name = stringcase.snakecase(parameters.name)
    else:
        name = parameters.name
    schema: Optional[JsonSchemaObject] = None
    data_type: Optional[DataType] = None
    for content in parameters.content.values():
        if isinstance(content.schema_, ReferenceObject):
            data_type = self.get_ref_data_type(content.schema_.ref)
            ref_model = self.get_ref_model(content.schema_.ref)
            schema = JsonSchemaObject.parse_obj(ref_model)
        else:
            schema = content.schema_
        break
    if not data_type:
        if not schema:
            schema = parameters.schema_
        data_type = self.parse_schema(name, schema, [*path, name])
    if not schema:
        return None
    field = DataModelField(
        name=name,
        data_type=data_type,
        required=parameters.required or parameters.in_ == ParameterLocation.path,
    )
    if orig_name != name:
        if parameters.in_:
            param_is = parameters.in_.value.lower().capitalize()
            self.imports_for_fastapi.append(
                Import(from_='fastapi', import_=param_is)
            )
            default: Optional[str] = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
    else:
        default = repr(schema.default) if schema.has_default else None
    self.imports_for_fastapi.append(field.imports)
    self.data_types.append(field.data_type)
    return Argument(
        name=field.name,
        type_hint=field.type_hint,
        default=default,  # type: ignore
        default_value=schema.default,
        required=field.required,
    )

def response(self) -> str:
    models: List[str] = []
    for response in self.response_objects:
        # expect 2xx
        if response.status_code.startswith("2"):
            for content_type, schema in response.contents.items():
                if content_type == "application/json":
                    if schema.is_array:
                        if isinstance(schema.items, list):
                            type_ = f'List[{",".join(i.ref_object_name for i in schema.items)}]'
                            self.imports.extend(
                                Import(
                                    from_=model_path_var.get(),
                                    import_=i.ref_object_name,
                                )
                                for i in schema.items
                            )
                        else:
                            type_ = f'List[{schema.items.ref_object_name}]'
                            self.imports.append(
                                Import(
                                    from_=model_path_var.get(),
                                    import_=schema.items.ref_object_name,
                                )
                            )
                        self.imports.append(IMPORT_LIST)
                    else:
                        type_ = schema.ref_object_name
                        self.imports.append(
                            Import(
                                from_=model_path_var.get(),
                                import_=schema.ref_object_name,
                            )
                        )
                    models.append(type_)
    if not models:
        return "None"
    if len(models) > 1:
        return f'Union[{",".join(models)}]'
    return models[0]

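# Hedged illustration of the return-type assembly at the end of `response`,
# re-expressed as a standalone function with made-up model names: zero matching
# responses yield "None", one yields the bare name, and several are folded into
# a Union[...] type-hint string.
def render_response_type(models):
    if not models:
        return "None"
    if len(models) > 1:
        return f'Union[{",".join(models)}]'
    return models[0]

print(render_response_type([]))                # None
print(render_response_type(['Pet']))           # Pet
print(render_response_type(['Pet', 'Error']))  # Union[Pet,Error]
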
def test_get_data_type(schema_type, schema_format, result_type, from_, import_):
    if from_ and import_:
        import_: Optional[Import] = Import(from_=from_, import_=import_)
    else:
        import_ = None
    parser = JsonSchemaParser('')
    assert (
        parser.get_data_type(
            JsonSchemaObject(type=schema_type, format=schema_format)
        ).dict()
        == DataType(type=result_type, import_=import_).dict()
    )

def get_parameter_type(
    self, parameter: Dict[str, Union[str, Dict[str, Any]]], snake_case: bool
) -> Argument:
    ref: Optional[str] = parameter.get('$ref')  # type: ignore
    if ref:
        parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
    name: str = parameter["name"]  # type: ignore
    orig_name = name
    if snake_case:
        name = stringcase.snakecase(name)
    content = parameter.get('content')
    schema: Optional[JsonSchemaObject] = None
    if content and isinstance(content, dict):
        content_schema = [
            c.get("schema")
            for c in content.values()
            if isinstance(c.get("schema"), dict)
        ]
        if content_schema:
            schema = JsonSchemaObject.parse_obj(content_schema[0])
    if not schema:
        schema = JsonSchemaObject.parse_obj(parameter["schema"])
    field = DataModelField(
        name=name,
        data_type=self.get_data_type(schema, 'parameter'),
        required=parameter.get("required") or parameter.get("in") == "path",
    )
    self.imports.extend(field.imports)
    if orig_name != name:
        has_in = parameter.get('in')
        if has_in and isinstance(has_in, str):
            param_is = has_in.lower().capitalize()
            self.imports.append(Import(from_='fastapi', import_=param_is))
            default: Optional[str] = f"{param_is}({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
        else:
            # https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#parameterObject
            # the spec says 'in' is a str type
            raise TypeError(
                f'Issue processing parameter for "in", expected a str, but got something else: {str(parameter)}'
            )
    else:
        default = repr(schema.default) if schema.has_default else None
    return Argument(
        name=field.name,
        type_hint=field.type_hint,
        default=default,  # type: ignore
        default_value=schema.default,
        required=field.required,
    )

def __init__(
    self,
    name: str,
    fields: List[DataModelField],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[str]] = None,
    custom_base_class: Optional[str] = None,
    imports: Optional[List[Import]] = None,
    auto_import: bool = True,
    reference_classes: Optional[List[str]] = None,
) -> None:
    if not self.TEMPLATE_FILE_PATH:
        raise Exception('TEMPLATE_FILE_PATH is undefined')
    self.name: str = name
    self.fields: List[DataModelField] = fields or []
    self.decorators: List[str] = decorators or []
    self.imports: List[Import] = imports or []
    self.base_class: Optional[str] = None
    base_classes = [base_class for base_class in base_classes or [] if base_class]
    self.base_classes: List[str] = base_classes
    self.reference_classes: List[str] = (
        [r for r in base_classes if r != self.BASE_CLASS] if base_classes else []
    )
    if reference_classes:
        self.reference_classes.extend(reference_classes)
    if self.base_classes:
        self.base_class = ', '.join(self.base_classes)
    else:
        base_class_full_path = custom_base_class or self.BASE_CLASS
        if auto_import:
            if base_class_full_path:
                self.imports.append(Import.from_full_path(base_class_full_path))
        self.base_class = base_class_full_path.split('.')[-1]
    unresolved_types: Set[str] = set()
    for field in self.fields:
        unresolved_types.update(set(field.unresolved_types))
    self.reference_classes = list(set(self.reference_classes) | unresolved_types)
    if auto_import:
        for field in self.fields:
            self.imports.extend(field.imports)
    super().__init__(template_file_path=self.TEMPLATE_FILE_PATH)

def get_data_type(self, schema: JsonSchemaObject) -> DataType:
    if schema.ref:
        data_type = self.open_api_model_parser.get_ref_data_type(schema.ref)
        data_type.imports_.append(
            Import(  # TODO: Improve import statements
                from_=model_path_var.get(),
                import_=data_type.type,
            )
        )
        return data_type
    elif schema.is_array:
        # TODO: Improve handling array
        items = schema.items if isinstance(schema.items, list) else [schema.items]
        return self.open_api_model_parser.data_type(
            data_types=[self.get_data_type(i) for i in items], is_list=True
        )
    return self.open_api_model_parser.get_data_type(schema)

def parse_request_body(
    self,
    name: str,
    request_body: RequestBodyObject,
    path: List[str],
) -> None:
    super().parse_request_body(name, request_body, path)
    arguments: List[Argument] = []
    for (
        media_type,
        media_obj,
    ) in request_body.content.items():  # type: str, MediaObject
        if isinstance(
            media_obj.schema_, (JsonSchemaObject, ReferenceObject)
        ):  # pragma: no cover
            # TODO: support other content-types
            if RE_APPLICATION_JSON_PATTERN.match(media_type):
                if isinstance(media_obj.schema_, ReferenceObject):
                    data_type = self.get_ref_data_type(media_obj.schema_.ref)
                else:
                    data_type = self.parse_schema(
                        name, media_obj.schema_, [*path, media_type]
                    )
                arguments.append(
                    # TODO: support multiple body
                    Argument(
                        name='body',  # type: ignore
                        type_hint=data_type.type_hint,
                        required=request_body.required,
                    )
                )
                self.data_types.append(data_type)
            elif media_type == 'application/x-www-form-urlencoded':
                arguments.append(
                    # TODO: support form with `Form()`
                    Argument(
                        name='request',  # type: ignore
                        type_hint='Request',  # type: ignore
                        required=True,
                    )
                )
                self.imports_for_fastapi.append(
                    Import.from_full_path('starlette.requests.Request')
                )
    self._temporary_operation['_request'] = arguments[0] if arguments else None

def get_parameter_type(
    self, parameter: Dict[str, Union[str, Dict[str, Any]]], snake_case: bool
) -> Argument:
    ref: Optional[str] = parameter.get('$ref')  # type: ignore
    if ref:
        parameter = get_ref_body(ref, self.openapi_model_parser, self.components)
    name: str = parameter["name"]  # type: ignore
    orig_name = name
    if snake_case:
        name = stringcase.snakecase(name)
    content = parameter.get('content')
    schema: Optional[JsonSchemaObject] = None
    if content and isinstance(content, dict):
        content_schema = [
            c.get("schema")
            for c in content.values()
            if isinstance(c.get("schema"), dict)
        ]
        if content_schema:
            schema = JsonSchemaObject.parse_obj(content_schema[0])
    if not schema:
        schema = JsonSchemaObject.parse_obj(parameter["schema"])
    field = DataModelField(
        name=name,
        data_type=self.get_data_type(schema, 'parameter'),
        required=parameter.get("required") or parameter.get("in") == "path",
    )
    self.imports.extend(field.imports)
    if orig_name != name:
        default: Optional[str] = f"Query({'...' if field.required else repr(schema.default)}, alias='{orig_name}')"
        self.imports.append(Import(from_='fastapi', import_='Query'))
    else:
        default = repr(schema.default) if schema.has_default else None
    return Argument(
        name=field.name,
        type_hint=field.type_hint,
        default=default,  # type: ignore
        default_value=schema.default,
        required=field.required,
    )

def request(self) -> Optional[Argument]:
    arguments: List[Argument] = []
    for requests in self.request_objects:
        for content_type, schema in requests.contents.items():
            # TODO: support other content-types
            if content_type == "application/json":
                arguments.append(
                    # TODO: support multiple body
                    Argument(
                        name='body',  # type: ignore
                        type_hint=schema.ref_object_name,
                        required=requests.required,
                    )
                )
                self.imports.append(
                    Import(
                        from_=model_path_var.get(),
                        import_=schema.ref_object_name,
                    )
                )
    if not arguments:
        return None
    return arguments[0]

def parse_list_item(
    self, name: str, target_items: List[JsonSchemaObject], path: List[str]
) -> List[DataType]:
    data_types: List[DataType] = []
    for item in target_items:
        if item.ref:  # $ref
            data_types.append(
                self.data_type(
                    type=self.model_resolver.add_ref(item.ref).name,
                    ref=True,
                    version_compatible=True,
                )
            )
        elif not any(v for k, v in vars(item).items() if k != 'type'):
            # trivial types
            data_types.extend(self.get_data_type(item))
        elif (
            item.is_array
            and isinstance(item.items, JsonSchemaObject)
            and not any(v for k, v in vars(item.items).items() if k != 'type')
        ):
            # trivial item types
            types = [t.type_hint for t in self.get_data_type(item.items)]
            data_types.append(
                self.data_type(
                    type=f"List[Union[{', '.join(types)}]]"
                    if len(types) > 1
                    else f"List[{types[0]}]",
                    imports_=[Import(from_='typing', import_='List')],
                )
            )
        else:
            data_types.append(
                self.data_type(
                    type=self.parse_object(
                        name, item, path, singular_name=True
                    ).name,
                    ref=True,
                    version_compatible=True,
                )
            )
    return data_types

def generate_app_code(environment, parsed_object) -> str:
    template_path = Path('main.jinja2')
    grouped_operations = defaultdict(list)
    for k, g in itertools.groupby(
        parsed_object.operations,
        key=lambda x: x.path.strip('/').split('/')[0],
    ):
        grouped_operations[k] += list(g)
    imports = Imports()
    routers = []
    for name, operations in grouped_operations.items():
        imports.append(
            Import(
                from_=CONTROLLERS_DIR_NAME + '.' + name,
                import_=name + '_router',
            )
        )
        routers.append(name + '_router')
    result = environment.get_template(str(template_path)).render(
        imports=imports,
        routers=routers,
    )
    return result

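# Hedged, self-contained sketch of the grouping step above: itertools.groupby
# only merges *consecutive* items, which is why the results are accumulated into
# a defaultdict instead of being used directly. The paths here are made up.
import itertools
from collections import defaultdict

paths = ['/users/{id}', '/users', '/items', '/users/me']
grouped = defaultdict(list)
for key, group in itertools.groupby(
    paths, key=lambda p: p.strip('/').split('/')[0]
):
    grouped[key] += list(group)
print(dict(grouped))
# {'users': ['/users/{id}', '/users', '/users/me'], 'items': ['/items']}
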
def parse(
    self,
    with_import: Optional[bool] = True,
    format_: Optional[bool] = True,
    settings_path: Optional[Path] = None,
) -> Union[str, Dict[Tuple[str, ...], Result]]:
    self.parse_raw()
    if with_import:
        if self.target_python_version != PythonVersion.PY_36:
            self.imports.append(IMPORT_ANNOTATIONS)
    if format_:
        code_formatter: Optional[CodeFormatter] = CodeFormatter(
            self.target_python_version, settings_path
        )
    else:
        code_formatter = None
    _, sorted_data_models, require_update_action_models = sort_data_models(
        self.results
    )
    results: Dict[Tuple[str, ...], Result] = {}
    module_key = lambda x: x.module_path
    # process in reverse order to correctly establish module levels
    grouped_models = groupby(
        sorted(sorted_data_models.values(), key=module_key, reverse=True),
        key=module_key,
    )
    module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []
    for module, models in (
        (k, [*v]) for k, v in grouped_models
    ):  # type: Tuple[str, ...], List[DataModel]
        for model in models:
            if isinstance(model, self.data_model_root_type):
                root_data_type = model.fields[0].data_type
                # backward compatible
                # Remove duplicated root model
                if (
                    root_data_type.reference
                    and not root_data_type.is_dict
                    and not root_data_type.is_list
                    and root_data_type.reference.source in models
                    and root_data_type.reference.name
                    == self.model_resolver.get_class_name(
                        model.reference.original_name, unique=False
                    )
                ):
                    # Replace referenced duplicate model to original model
                    for child in model.reference.children[:]:
                        child.replace_reference(root_data_type.reference)
                    models.remove(model)
                    continue
            # Custom root model can't be inherited on restriction of Pydantic
            for child in model.reference.children:
                # inheritance model
                if isinstance(child, DataModel):
                    for base_class in child.base_classes:
                        if base_class.reference == model.reference:
                            child.base_classes.remove(base_class)
        module_models.append((module, models))
        scoped_model_resolver = ModelResolver(
            exclude_names={
                i.alias or i.import_ for m in models for i in m.imports
            },
            duplicate_name_suffix='Model',
        )
        for model in models:
            class_name: str = model.class_name
            generated_name: str = scoped_model_resolver.add(
                model.path, class_name, unique=True, class_name=True
            ).name
            if class_name != generated_name:
                if '.' in model.reference.name:
                    model.reference.name = (
                        f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                    )
                else:
                    model.reference.name = generated_name
    for module, models in module_models:
        init = False
        if module:
            parent = (*module[:-1], '__init__.py')
            if parent not in results:
                results[parent] = Result(body='')
            if (*module, '__init__.py') in results:
                module = (*module, '__init__.py')
                init = True
            else:
                module = (*module[:-1], f'{module[-1]}.py')
        else:
            module = ('__init__.py',)
        result: List[str] = []
        imports = Imports()
        scoped_model_resolver = ModelResolver()
        for model in models:
            imports.append(model.imports)
            for data_type in model.all_data_types:
                # To change from/import
                if not data_type.reference or data_type.reference.source in models:
                    # No need to import non-reference model.
                    # Or, the referenced model is in the same file;
                    # we don't need to import the model.
                    continue
                if isinstance(data_type, BaseClassDataType):
                    from_ = ''.join(
                        relative(model.module_name, data_type.full_name)
                    )
                    import_ = data_type.reference.short_name
                    full_path = from_, import_
                else:
                    from_, import_ = full_path = relative(
                        model.module_name, data_type.full_name
                    )
                alias = scoped_model_resolver.add(full_path, import_).name
                name = data_type.reference.short_name
                if from_ and import_ and alias != name:
                    data_type.alias = f'{alias}.{name}'
                if init:
                    from_ += "."
                imports.append(Import(from_=from_, import_=import_, alias=alias))
        if self.reuse_model:
            model_cache: Dict[Tuple[str, ...], Reference] = {}
            duplicates = []
            for model in models:
                model_key = tuple(
                    to_hashable(v)
                    for v in (
                        model.base_classes,
                        model.extra_template_data,
                        model.fields,
                    )
                )
                cached_model_reference = model_cache.get(model_key)
                if cached_model_reference:
                    if isinstance(model, Enum):
                        for child in model.reference.children[:]:
                            # child is resolved data_type by reference
                            data_model = get_most_of_parent(child)
                            # TODO: replace reference in all modules
                            if data_model in models:  # pragma: no cover
                                child.replace_reference(cached_model_reference)
                        duplicates.append(model)
                    else:
                        index = models.index(model)
                        inherited_model = model.__class__(
                            fields=[],
                            base_classes=[cached_model_reference],
                            description=model.description,
                            reference=Reference(
                                name=model.name,
                                path=model.reference.path + '/reuse',
                            ),
                        )
                        if cached_model_reference.path in require_update_action_models:
                            require_update_action_models.append(
                                inherited_model.path
                            )
                        models.insert(index, inherited_model)
                        models.remove(model)
                else:
                    model_cache[model_key] = model.reference
            for duplicate in duplicates:
                models.remove(duplicate)
        if self.set_default_enum_member:
            for model in models:
                for model_field in model.fields:
                    if not model_field.default:
                        continue
                    for data_type in model_field.data_type.all_data_types:
                        if data_type.reference and isinstance(
                            data_type.reference.source, Enum
                        ):  # pragma: no cover
                            enum_member = data_type.reference.source.find_member(
                                model_field.default
                            )
                            if enum_member:
                                model_field.default = enum_member
        if with_import:
            result += [str(self.imports), str(imports), '\n']
        code = dump_templates(models)
        result += [code]
        if self.dump_resolve_reference_action is not None:
            result += [
                '\n',
                self.dump_resolve_reference_action(
                    m.reference.short_name
                    for m in models
                    if m.path in require_update_action_models
                ),
            ]
        body = '\n'.join(result)
        if code_formatter:
            body = code_formatter.format_code(body)
        results[module] = Result(body=body, source=models[0].file_path)
    # retain existing behaviour
    if [*results] == [('__init__.py',)]:
        return results[('__init__.py',)].body
    return results

def get_data_type_from_full_path(
    self, full_path: str, is_custom_type: bool
) -> DataType:
    return self.data_type.from_import(
        Import.from_full_path(full_path), is_custom_type=is_custom_type
    )

def __init__(
    self,
    name: str,
    fields: List[DataModelFieldBase],
    decorators: Optional[List[str]] = None,
    base_classes: Optional[List[str]] = None,
    custom_base_class: Optional[str] = None,
    custom_template_dir: Optional[Path] = None,
    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
    imports: Optional[List[Import]] = None,
    auto_import: bool = True,
    reference_classes: Optional[List[str]] = None,
    methods: Optional[List[str]] = None,
) -> None:
    if not self.TEMPLATE_FILE_PATH:
        raise Exception('TEMPLATE_FILE_PATH is undefined')
    template_file_path = Path(self.TEMPLATE_FILE_PATH)
    if custom_template_dir is not None:
        custom_template_file_path = custom_template_dir / template_file_path.name
        if custom_template_file_path.exists():
            template_file_path = custom_template_file_path
    self.name: str = name
    self.fields: List[DataModelFieldBase] = fields or []
    self.decorators: List[str] = decorators or []
    self.imports: List[Import] = imports or []
    self.base_class: Optional[str] = None
    base_classes = [base_class for base_class in base_classes or [] if base_class]
    self.base_classes: List[str] = base_classes
    self.reference_classes: List[str] = (
        [r for r in base_classes if r != self.BASE_CLASS] if base_classes else []
    )
    if reference_classes:
        self.reference_classes.extend(reference_classes)
    if self.base_classes:
        self.base_class = ', '.join(self.base_classes)
    else:
        base_class_full_path = custom_base_class or self.BASE_CLASS
        if auto_import:
            if base_class_full_path:
                self.imports.append(Import.from_full_path(base_class_full_path))
        self.base_class = base_class_full_path.rsplit('.', 1)[-1]
    if '.' in name:
        module, class_name = name.rsplit('.', 1)
        prefix = f'{module}.'
        if self.base_class.startswith(prefix):
            self.base_class = self.base_class.replace(prefix, '', 1)
    else:
        class_name = name
    self.class_name: str = class_name
    self.extra_template_data = (
        extra_template_data[self.name]
        if extra_template_data is not None
        else defaultdict(dict)
    )
    if extra_template_data:
        all_model_extra_template_data = extra_template_data.get(ALL_MODEL)
        if all_model_extra_template_data:
            self.extra_template_data.update(all_model_extra_template_data)
    unresolved_types: Set[str] = set()
    for field in self.fields:
        unresolved_types.update(set(field.unresolved_types))
    self.reference_classes = list(set(self.reference_classes) | unresolved_types)
    if auto_import:
        for field in self.fields:
            self.imports.extend(field.imports)
    self.methods: List[str] = methods or []
    super().__init__(template_file_path=template_file_path)

            DataType(type='NegativeFloat', imports=[IMPORT_NEGATIVE_FLOAT]),
        ),
    ],
)
def test_get_data_float_type(types, params, data_type):
    assert DataTypeManager().get_data_float_type(types, **params) == data_type


@pytest.mark.parametrize(
    'types,params,data_type',
    [
        (
            Types.decimal,
            {},
            DataType(
                type='Decimal',
                imports=[Import(from_='decimal', import_='Decimal')],
            ),
        ),
        (
            Types.decimal,
            {'maximum': 10},
            DataType(
                type='condecimal',
                is_func=True,
                kwargs={'le': 10},
                imports=[IMPORT_CONDECIMAL],
            ),
        ),
        (
            Types.decimal,

from datamodel_code_generator.imports import Import

IMPORT_CONSTR = Import.from_full_path('pydantic.constr')
IMPORT_CONINT = Import.from_full_path('pydantic.conint')
IMPORT_CONFLOAT = Import.from_full_path('pydantic.confloat')
IMPORT_CONDECIMAL = Import.from_full_path('pydantic.condecimal')
IMPORT_CONBYTES = Import.from_full_path('pydantic.conbytes')
IMPORT_POSITIVE_INT = Import.from_full_path('pydantic.PositiveInt')
IMPORT_NEGATIVE_INT = Import.from_full_path('pydantic.NegativeInt')
IMPORT_NON_POSITIVE_INT = Import.from_full_path('pydantic.NonPositiveInt')
IMPORT_NON_NEGATIVE_INT = Import.from_full_path('pydantic.NonNegativeInt')
IMPORT_POSITIVE_FLOAT = Import.from_full_path('pydantic.PositiveFloat')
IMPORT_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NegativeFloat')
IMPORT_NON_NEGATIVE_FLOAT = Import.from_full_path('pydantic.NonNegativeFloat')
IMPORT_NON_POSITIVE_FLOAT = Import.from_full_path('pydantic.NonPositiveFloat')
IMPORT_SECRET_STR = Import.from_full_path('pydantic.SecretStr')
IMPORT_EMAIL_STR = Import.from_full_path('pydantic.EmailStr')
IMPORT_UUID1 = Import.from_full_path('pydantic.UUID1')
IMPORT_UUID2 = Import.from_full_path('pydantic.UUID2')
IMPORT_UUID3 = Import.from_full_path('pydantic.UUID3')
IMPORT_UUID4 = Import.from_full_path('pydantic.UUID4')
IMPORT_UUID5 = Import.from_full_path('pydantic.UUID5')
IMPORT_ANYURL = Import.from_full_path('pydantic.AnyUrl')
IMPORT_IPV4ADDRESS = Import.from_full_path('ipaddress.IPv4Address')
IMPORT_IPV6ADDRESS = Import.from_full_path('ipaddress.IPv6Address')
IMPORT_EXTRA = Import.from_full_path('pydantic.Extra')
IMPORT_FIELD = Import.from_full_path('pydantic.Field')
IMPORT_STRICT_INT = Import.from_full_path('pydantic.StrictInt')
IMPORT_STRICT_FLOAT = Import.from_full_path('pydantic.StrictFloat')
IMPORT_STRICT_STR = Import.from_full_path('pydantic.StrictStr')
IMPORT_STRICT_BOOL = Import.from_full_path('pydantic.StrictBool')

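# A hedged check of what `from_full_path` is expected to do: split on the last
# dot, so the module becomes `from_` and the final segment becomes `import_`
# (rendering as `from pydantic import constr`). This mirrors the
# `rsplit('.', 1)` handling seen elsewhere in this code; treat it as an
# assumption rather than the library's documented contract.
assert Import.from_full_path('pydantic.constr') == Import(
    from_='pydantic', import_='constr'
)
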
            DataType(type='NegativeFloat', import_=IMPORT_NEGATIVE_FLOAT),
        ),
    ],
)
def test_get_data_float_type(types, params, data_type):
    assert DataTypeManager().get_data_float_type(types, **params) == data_type


@pytest.mark.parametrize(
    'types,params,data_type',
    [
        (
            Types.decimal,
            {},
            DataType(
                type='Decimal', import_=Import(from_='decimal', import_='Decimal')
            ),
        ),
        (
            Types.decimal,
            {'maximum': 10},
            DataType(
                type='condecimal',
                is_func=True,
                kwargs={'le': 10},
                import_=IMPORT_CONDECIMAL,
            ),
        ),
        (
            Types.decimal,
            {'exclusiveMaximum': 10},