def test_base_model_reserved_name():
    """Python keywords used as field names render as ``<name>_`` with an alias.

    Without an explicit alias the keyword itself becomes the alias; with an
    explicit alias that alias is used instead.
    """
    keyword_field = DataModelField(
        name='except', data_type=DataType(type='str'), required=True
    )
    model = BaseModel(
        fields=[keyword_field],
        reference=Reference(name='test_model', path='test_model'),
    )
    assert model.name == 'test_model'
    assert model.fields == [keyword_field]
    assert model.decorators == []
    assert model.render() == (
        'class test_model(BaseModel):\n'
        "    except_: str = Field(..., alias='except')"
    )

    aliased_field = DataModelField(
        name='def', data_type=DataType(type='str'), required=True, alias='def-field'
    )
    model = BaseModel(
        fields=[aliased_field],
        reference=Reference(name='test_model', path='test_model'),
    )
    assert model.name == 'test_model'
    assert model.fields == [aliased_field]
    assert model.decorators == []
    assert model.render() == (
        'class test_model(BaseModel):\n'
        "    def_: str = Field(..., alias='def-field')"
    )
def test_data_model():
    """Render the test model ``B`` through a temporary template file.

    The module-level ``template`` text is written to a throw-away file and
    ``B.TEMPLATE_FILE_PATH`` is pointed at it. Afterwards the class attribute is
    restored and the file is deleted, so neither the attribute nor the file
    outlives this test.
    """
    import os

    field = DataModelFieldBase(
        name='a', data_type=DataType(type='str'), default="" 'abc' "", required=True
    )

    # delete=False keeps the file on disk after the handle is closed, so the
    # template can be read by path later; we remove it ourselves in `finally`.
    with NamedTemporaryFile('w', delete=False) as dummy_template:
        dummy_template.write(template)

    original_template_path = B.TEMPLATE_FILE_PATH
    B.TEMPLATE_FILE_PATH = dummy_template.name
    try:
        data_model = B(
            fields=[field],
            decorators=['@validate'],
            base_classes=[Reference(path='base', original_name='base', name='Base')],
            reference=Reference(path='test_model', name='test_model'),
        )

        assert data_model.name == 'test_model'
        assert data_model.fields == [field]
        assert data_model.decorators == ['@validate']
        assert data_model.base_class == 'Base'
        assert (
            data_model.render() == '@validate\n'
            '@dataclass\n'
            'class test_model:\n'
            '    a: str'
        )
    finally:
        # Undo the class-level mutation and remove the temp file.
        B.TEMPLATE_FILE_PATH = original_template_path
        os.unlink(dummy_template.name)
def test_sort_data_models():
    """``sort_data_models`` orders models after the models they reference.

    ``C`` references ``B``, and ``A`` references ``C`` (and itself), so the
    resolved order is B, C, A. The two self-referencing models, ``B`` and
    ``A``, are expected in ``require_update_action_models``.
    """
    ref_a, ref_b, ref_c = (
        Reference(path=label, original_name=label, name=label) for label in 'ABC'
    )
    type_a, type_b, type_c = (
        DataType(reference=ref) for ref in (ref_a, ref_b, ref_c)
    )
    models = [
        BaseModel(
            fields=[
                DataModelField(data_type=type_a),
                DataModelFieldBase(data_type=type_c),
            ],
            reference=ref_a,
        ),
        BaseModel(fields=[DataModelField(data_type=type_b)], reference=ref_b),
        BaseModel(fields=[DataModelField(data_type=type_b)], reference=ref_c),
    ]

    unresolved, resolved, require_update_action_models = sort_data_models(models)

    assert resolved == OrderedDict(
        [('B', models[1]), ('C', models[2]), ('A', models[0])]
    )
    assert unresolved == []
    assert require_update_action_models == ['B', 'A']
def test_data_class_base_class():
    """A ``DataClass`` with a base class renders ``class <name>(<Base>)``."""
    str_field = DataModelFieldBase(
        name='a', data_type=DataType(type='str'), required=True
    )
    base = Reference(name='Base', original_name='Base', path='Base')
    model = DataClass(
        fields=[str_field],
        base_classes=[base],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    assert model.fields == [str_field]
    assert model.decorators == []
    expected = '@dataclass\n' 'class test_model(Base):\n' '    a: str'
    assert model.render() == expected
def test_data_model_exception():
    """Building ``C`` raises when its ``TEMPLATE_FILE_PATH`` is undefined."""
    str_field = DataModelFieldBase(
        name='a',
        data_type=DataType(type='str'),
        default="" 'abc' "",
        required=True,
    )
    expected_message = 'TEMPLATE_FILE_PATH is undefined'
    with pytest.raises(Exception, match=expected_message):
        C(
            fields=[str_field],
            reference=Reference(path='abc', original_name='abc', name='abc'),
        )
def test_custom_root_type_decorator():
    """A custom root model renders its decorators and base class."""
    model = CustomRootType(
        fields=[DataModelFieldBase(data_type=DataType(type='str'), required=True)],
        decorators=['@validate'],
        base_classes=[Reference(name='Base', original_name='Base', path='Base')],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    # Compared against a fresh instance to exercise field equality.
    expected_field = DataModelFieldBase(data_type=DataType(type='str'), required=True)
    assert model.fields == [expected_field]
    assert model.base_class == 'Base'
    expected = '@validate\n' 'class test_model(Base):\n' '    __root__: str'
    assert model.render() == expected
def test_base_model():
    """A plain required ``str`` field renders as ``a: str``."""
    str_field = DataModelField(name='a', data_type=DataType(type='str'), required=True)
    model = BaseModel(
        fields=[str_field],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    assert model.fields == [str_field]
    assert model.decorators == []
    expected = 'class test_model(BaseModel):\n' '    a: str'
    assert model.render() == expected
def test_base_model_decorator():
    """Decorators and base classes render above an optional defaulted field."""
    optional_field = DataModelField(
        name='a', data_type=DataType(type='str'), default='abc', required=False
    )
    base = Reference(name='Base', original_name='Base', path='Base')
    model = BaseModel(
        fields=[optional_field],
        decorators=['@validate'],
        base_classes=[base],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    assert model.fields == [optional_field]
    assert model.base_class == 'Base'
    assert model.decorators == ['@validate']
    expected = (
        '@validate\n'
        'class test_model(Base):\n'
        "    a: Optional[str] = 'abc'"
    )
    assert model.render() == expected
def test_custom_root_type_required():
    """A required custom root field renders as ``__root__: str``."""
    model = CustomRootType(
        fields=[DataModelFieldBase(data_type=DataType(type='str'), required=True)],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    expected_field = DataModelFieldBase(data_type=DataType(type='str'), required=True)
    assert model.fields == [expected_field]
    expected = 'class test_model(BaseModel):\n' '    __root__: str'
    assert model.render() == expected
def test_base_model_optional():
    """A non-required field with a default renders as ``Optional[...] = default``."""
    optional_field = DataModelField(
        name='a', data_type=DataType(type='str'), default='abc', required=False
    )
    model = BaseModel(
        fields=[optional_field],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    assert model.fields == [optional_field]
    assert model.decorators == []
    expected = 'class test_model(BaseModel):\n' "    a: Optional[str] = 'abc'"
    assert model.render() == expected
def test_data_class_optional():
    """A ``DataClass`` field renders its default even when marked required."""
    defaulted_field = DataModelFieldBase(
        name='a', data_type=DataType(type='str'), default="'abc'", required=True
    )
    model = DataClass(
        fields=[defaulted_field],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    assert model.fields == [defaulted_field]
    assert model.decorators == []
    expected = '@dataclass\n' 'class test_model:\n' "    a: str = 'abc'"
    assert model.render() == expected
def test_base_model_strict_non_nullable_required():
    """A required, non-nullable field renders without its default value."""
    strict_field = DataModelField(
        name='a',
        data_type=DataType(type='str'),
        default='abc',
        required=True,
        nullable=False,
    )
    model = BaseModel(
        fields=[strict_field],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    assert model.fields == [strict_field]
    assert model.decorators == []
    expected = 'class test_model(BaseModel):\n' '    a: str'
    assert model.render() == expected
def test_custom_root_type():
    """An optional custom root field renders as ``__root__: Optional[...] = default``.

    The field's own name (``a``) is not used in the rendered output.
    """
    model = CustomRootType(
        fields=[
            DataModelFieldBase(
                name='a',
                data_type=DataType(type='str'),
                default='abc',
                required=False,
            )
        ],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert model.name == 'test_model'
    expected_field = DataModelFieldBase(
        name='a', data_type=DataType(type='str'), default='abc', required=False
    )
    assert model.fields == [expected_field]
    expected = (
        'class test_model(BaseModel):\n'
        "    __root__: Optional[str] = 'abc'"
    )
    assert model.render() == expected
def test_sort_data_models_unresolved():
    """``sort_data_models`` raises when a referenced model is never defined.

    ``Z`` references ``V``, but no model in the list carries reference ``V``.
    """
    ref_a, ref_b, ref_c, ref_d, ref_v, ref_z = (
        Reference(path=label, original_name=label, name=label) for label in 'ABCDVZ'
    )
    type_a, type_b, type_c, type_v, type_z = (
        DataType(reference=ref) for ref in (ref_a, ref_b, ref_c, ref_v, ref_z)
    )
    models = [
        BaseModel(
            fields=[
                DataModelField(data_type=type_a),
                DataModelFieldBase(data_type=type_c),
            ],
            reference=ref_a,
        ),
        BaseModel(fields=[DataModelField(data_type=type_b)], reference=ref_b),
        BaseModel(fields=[DataModelField(data_type=type_b)], reference=ref_c),
        BaseModel(
            fields=[
                DataModelField(data_type=type_a),
                DataModelField(data_type=type_c),
                DataModelField(data_type=type_z),
            ],
            reference=ref_d,
        ),
        BaseModel(fields=[DataModelField(data_type=type_v)], reference=ref_z),
    ]

    with pytest.raises(Exception):
        sort_data_models(models)
def parse(
    self,
    with_import: Optional[bool] = True,
    format_: Optional[bool] = True,
    settings_path: Optional[Path] = None,
) -> Union[str, Dict[Tuple[str, ...], Result]]:
    """Parse the raw source and render the generated modules.

    Args:
        with_import: when truthy, prepend the parser-wide imports and the
            per-module imports to each rendered module. Unless targeting
            Python 3.6, ``IMPORT_ANNOTATIONS`` is also added to the
            parser-wide imports.
        format_: when truthy, run each module body through ``CodeFormatter``.
        settings_path: forwarded to ``CodeFormatter``.

    Returns:
        The body of ``__init__.py`` as a plain string when that is the only
        module produced. Otherwise, a mapping from module path tuples
        (e.g. ``('pkg', 'mod.py')``) to ``Result`` objects.
    """
    self.parse_raw()

    if with_import:
        if self.target_python_version != PythonVersion.PY_36:
            self.imports.append(IMPORT_ANNOTATIONS)

    if format_:
        code_formatter: Optional[CodeFormatter] = CodeFormatter(
            self.target_python_version, settings_path
        )
    else:
        code_formatter = None

    _, sorted_data_models, require_update_action_models = sort_data_models(
        self.results
    )

    results: Dict[Tuple[str, ...], Result] = {}

    module_key = lambda x: x.module_path

    # process in reverse order to correctly establish module levels
    grouped_models = groupby(
        sorted(sorted_data_models.values(), key=module_key, reverse=True),
        key=module_key,
    )

    module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []

    for module, models in (
        (k, [*v]) for k, v in grouped_models
    ):  # type: Tuple[str, ...], List[DataModel]

        # Iterate over a copy: `models.remove(model)` below would otherwise
        # make the loop skip the model that follows each removed one.
        for model in models[:]:
            if isinstance(model, self.data_model_root_type):
                root_data_type = model.fields[0].data_type

                # backward compatible
                # Remove duplicated root model
                if (
                    root_data_type.reference
                    and not root_data_type.is_dict
                    and not root_data_type.is_list
                    and root_data_type.reference.source in models
                    and root_data_type.reference.name
                    == self.model_resolver.get_class_name(
                        model.reference.original_name, unique=False
                    )
                ):
                    # Replace referenced duplicate model to original model
                    for child in model.reference.children[:]:
                        child.replace_reference(root_data_type.reference)
                    models.remove(model)
                    continue

                # Custom root model can't be inherited on restriction of Pydantic
                for child in model.reference.children:
                    # inheritance model
                    if isinstance(child, DataModel):
                        # Copy for the same reason: we remove while iterating.
                        for base_class in child.base_classes[:]:
                            if base_class.reference == model.reference:
                                child.base_classes.remove(base_class)

        module_models.append(
            (
                module,
                models,
            )
        )

        # Rename classes whose names collide with names imported into this
        # module (or with each other); colliding names get a 'Model' suffix.
        scoped_model_resolver = ModelResolver(
            exclude_names={i.alias or i.import_ for m in models for i in m.imports},
            duplicate_name_suffix='Model',
        )

        for model in models:
            class_name: str = model.class_name
            generated_name: str = scoped_model_resolver.add(
                model.path, class_name, unique=True, class_name=True
            ).name
            if class_name != generated_name:
                if '.' in model.reference.name:
                    # Keep the dotted prefix, replace only the final class name.
                    model.reference.name = (
                        f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                    )
                else:
                    model.reference.name = generated_name

    for module, models in module_models:
        init = False
        if module:
            # Make sure the parent package has an __init__.py entry.
            parent = (*module[:-1], '__init__.py')
            if parent not in results:
                results[parent] = Result(body='')
            if (*module, '__init__.py') in results:
                # This module is also a package: write it into its __init__.py.
                module = (*module, '__init__.py')
                init = True
            else:
                module = (*module[:-1], f'{module[-1]}.py')
        else:
            module = ('__init__.py',)

        result: List[str] = []
        imports = Imports()
        scoped_model_resolver = ModelResolver()

        for model in models:
            imports.append(model.imports)
            for data_type in model.all_data_types:
                # To change from/import

                if not data_type.reference or data_type.reference.source in models:
                    # No need to import non-reference model.
                    # Or, Referenced model is in the same file. we don't need to import the model
                    continue

                if isinstance(data_type, BaseClassDataType):
                    from_ = ''.join(relative(model.module_name, data_type.full_name))
                    import_ = data_type.reference.short_name
                    full_path = from_, import_
                else:
                    from_, import_ = full_path = relative(
                        model.module_name, data_type.full_name
                    )

                alias = scoped_model_resolver.add(full_path, import_).name

                name = data_type.reference.short_name
                if from_ and import_ and alias != name:
                    data_type.alias = f'{alias}.{name}'

                if init:
                    # The code lives in <module>/__init__.py rather than
                    # <module>.py, one package level deeper, so relative
                    # imports need one more leading dot.
                    from_ += "."
                imports.append(Import(from_=from_, import_=import_, alias=alias))

        if self.reuse_model:
            # Models with identical base classes, extra template data and
            # fields are de-duplicated within this module.
            model_cache: Dict[Tuple[str, ...], Reference] = {}
            duplicates = []
            for model in models:
                model_key = tuple(
                    to_hashable(v)
                    for v in (
                        model.base_classes,
                        model.extra_template_data,
                        model.fields,
                    )
                )
                cached_model_reference = model_cache.get(model_key)
                if cached_model_reference:
                    if isinstance(model, Enum):
                        for child in model.reference.children[:]:
                            # child is resolved data_type by reference
                            data_model = get_most_of_parent(child)
                            # TODO: replace reference in all modules
                            if data_model in models:  # pragma: no cover
                                child.replace_reference(cached_model_reference)
                        # Removed after the loop to avoid mutating `models`
                        # while iterating it.
                        duplicates.append(model)
                    else:
                        # Replace the duplicate in place with an empty subclass
                        # of the cached model (insert + remove keeps the list
                        # length, so the iteration position is unaffected).
                        index = models.index(model)
                        inherited_model = model.__class__(
                            fields=[],
                            base_classes=[cached_model_reference],
                            description=model.description,
                            reference=Reference(
                                name=model.name,
                                path=model.reference.path + '/reuse',
                            ),
                        )
                        if cached_model_reference.path in require_update_action_models:
                            require_update_action_models.append(inherited_model.path)
                        models.insert(index, inherited_model)
                        models.remove(model)

                else:
                    model_cache[model_key] = model.reference

            for duplicate in duplicates:
                models.remove(duplicate)

        if self.set_default_enum_member:
            # Replace a field's default with the matching member of a
            # referenced Enum, when `find_member` finds one.
            for model in models:
                for model_field in model.fields:
                    if not model_field.default:
                        continue
                    for data_type in model_field.data_type.all_data_types:
                        if data_type.reference and isinstance(
                            data_type.reference.source, Enum
                        ):  # pragma: no cover
                            enum_member = data_type.reference.source.find_member(
                                model_field.default
                            )
                            if enum_member:
                                model_field.default = enum_member
        if with_import:
            result += [str(self.imports), str(imports), '\n']

        code = dump_templates(models)
        result += [code]

        if self.dump_resolve_reference_action is not None:
            result += [
                '\n',
                self.dump_resolve_reference_action(
                    m.reference.short_name
                    for m in models
                    if m.path in require_update_action_models
                ),
            ]

        body = '\n'.join(result)
        if code_formatter:
            body = code_formatter.format_code(body)

        results[module] = Result(body=body, source=models[0].file_path)

    # retain existing behaviour
    if [*results] == [('__init__.py',)]:
        return results[('__init__.py',)].body

    return results
def test_openapi_model_resolver():
    """Parsing ``api.yaml`` registers a loaded ``Reference`` for every schema.

    Nested item schemas get their own class names (e.g. ``Users/Users`` is
    named ``User`` and ``apis/Apis`` is named ``Api``).
    """
    parser = OpenAPIParser(source=(DATA_PATH / 'api.yaml'))
    parser.parse()
    actual = parser.model_resolver.references

    # (path, original_name, name) for every reference expected to be loaded.
    expected_entries = [
        ('#/components/schemas/Event', 'Event', 'Event'),
        ('#/components/schemas/Pet', 'Pet', 'Pet'),
        ('api.yaml#/components/schemas/Error', 'Error', 'Error'),
        ('api.yaml#/components/schemas/Event', 'Event', 'Event'),
        ('api.yaml#/components/schemas/Id', 'Id', 'Id'),
        ('api.yaml#/components/schemas/Pet', 'Pet', 'Pet'),
        ('api.yaml#/components/schemas/Pets', 'Pets', 'Pets'),
        ('api.yaml#/components/schemas/Result', 'Result', 'Result'),
        ('api.yaml#/components/schemas/Rules', 'Rules', 'Rules'),
        ('api.yaml#/components/schemas/Users', 'Users', 'Users'),
        ('api.yaml#/components/schemas/Users/Users', 'Users', 'User'),
        ('api.yaml#/components/schemas/apis', 'apis', 'Apis'),
        ('api.yaml#/components/schemas/apis/Apis', 'Apis', 'Api'),
    ]
    expected = {
        path: Reference(
            path=path, original_name=original_name, name=name, loaded=True
        )
        for path, original_name, name in expected_entries
    }

    assert actual == expected
def __init__(self, filename: str, data: str, fields: List[DataModelFieldBase]):
    """Build the model from ``fields`` with an empty reference and keep ``data``.

    Args:
        filename: accepted for signature compatibility; not used here.
        data: stored on ``self._data``.
        fields: forwarded to the parent constructor.
    """
    placeholder_reference = Reference('')
    super().__init__(reference=placeholder_reference, fields=fields)
    self._data = data