Exemple #1
0
def test_base_model_reserved_name():
    """Field names that are Python keywords render with a trailing underscore.

    When no alias is given the original name becomes the alias; an explicit
    alias is rendered unchanged.
    """
    # (extra DataModelField kwargs, expected rendered field line)
    cases = [
        ({'name': 'except'}, "    except_: str = Field(..., alias='except')"),
        (
            {'name': 'def', 'alias': 'def-field'},
            "    def_: str = Field(..., alias='def-field')",
        ),
    ]

    for field_kwargs, expected_field_line in cases:
        field = DataModelField(
            data_type=DataType(type='str'), required=True, **field_kwargs
        )
        base_model = BaseModel(
            fields=[field],
            reference=Reference(name='test_model', path='test_model'),
        )

        assert base_model.name == 'test_model'
        assert base_model.fields == [field]
        assert base_model.decorators == []
        assert base_model.render() == (
            'class test_model(BaseModel):\n' + expected_field_line
        )
def test_data_model():
    """Render a template-backed model ``B`` with a decorator and a base class.

    A throwaway template file is written for ``B``. When the test finishes,
    the file is deleted and ``B.TEMPLATE_FILE_PATH`` is restored, so nothing
    leaks onto disk or into other tests.
    """
    import os

    field = DataModelFieldBase(name='a',
                               data_type=DataType(type='str'),
                               default='abc',
                               required=True)

    original_template_path = B.TEMPLATE_FILE_PATH
    # delete=False keeps the file after close() so it can be reopened by name
    # (an open NamedTemporaryFile cannot be reopened on Windows); it is
    # therefore removed explicitly in the ``finally`` below.
    with NamedTemporaryFile('w', delete=False) as dummy_template:
        dummy_template.write(template)
    try:
        B.TEMPLATE_FILE_PATH = dummy_template.name
        data_model = B(
            fields=[field],
            decorators=['@validate'],
            base_classes=[
                Reference(path='base', original_name='base', name='Base')
            ],
            reference=Reference(path='test_model', name='test_model'),
        )

        assert data_model.name == 'test_model'
        assert data_model.fields == [field]
        assert data_model.decorators == ['@validate']
        assert data_model.base_class == 'Base'
        assert (data_model.render() == '@validate\n'
                '@dataclass\n'
                'class test_model:\n'
                '    a: str')
    finally:
        B.TEMPLATE_FILE_PATH = original_template_path
        os.unlink(dummy_template.name)
def test_sort_data_models():
    """Models are ordered so each one follows the models it depends on.

    A uses itself and C, C uses B, and B uses itself. B and A, the two
    models that reference themselves, are reported as needing an update
    action.
    """
    ref_a, ref_b, ref_c = (
        Reference(path=name, original_name=name, name=name)
        for name in ('A', 'B', 'C')
    )
    type_a, type_b, type_c = (
        DataType(reference=ref) for ref in (ref_a, ref_b, ref_c)
    )

    model_a = BaseModel(
        fields=[
            DataModelField(data_type=type_a),
            DataModelFieldBase(data_type=type_c),
        ],
        reference=ref_a,
    )
    model_b = BaseModel(fields=[DataModelField(data_type=type_b)],
                        reference=ref_b)
    model_c = BaseModel(fields=[DataModelField(data_type=type_b)],
                        reference=ref_c)
    models = [model_a, model_b, model_c]

    unresolved, resolved, require_update_action_models = sort_data_models(
        models)

    assert resolved == OrderedDict(
        [('B', model_b), ('C', model_c), ('A', model_a)]
    )
    assert unresolved == []
    assert require_update_action_models == ['B', 'A']
Exemple #4
0
def test_data_class_base_class():
    """A DataClass with a base class names that base in its class header."""
    field = DataModelFieldBase(name='a', data_type=DataType(type='str'), required=True)
    base = Reference(name='Base', original_name='Base', path='Base')
    own_reference = Reference(name='test_model', path='test_model')

    data_class = DataClass(
        fields=[field], base_classes=[base], reference=own_reference
    )

    expected = '\n'.join(['@dataclass', 'class test_model(Base):', '    a: str'])

    assert data_class.name == 'test_model'
    assert data_class.fields == [field]
    assert data_class.decorators == []
    assert data_class.render() == expected
Exemple #5
0
def test_data_model_exception():
    """Constructing ``C`` raises an error saying TEMPLATE_FILE_PATH is undefined."""
    field = DataModelFieldBase(
        name='a', data_type=DataType(type='str'), default='abc', required=True
    )
    model_reference = Reference(path='abc', original_name='abc', name='abc')

    with pytest.raises(Exception, match='TEMPLATE_FILE_PATH is undefined'):
        C(fields=[field], reference=model_reference)
def test_custom_root_type_decorator():
    """A custom root type renders its decorator and base class above ``__root__``."""

    def make_root_field():
        # A fresh, required ``str`` field; fields compare by value.
        return DataModelFieldBase(data_type=DataType(type='str'), required=True)

    custom_root_type = CustomRootType(
        fields=[make_root_field()],
        decorators=['@validate'],
        base_classes=[Reference(name='Base', original_name='Base', path='Base')],
        reference=Reference(name='test_model', path='test_model'),
    )

    expected = '\n'.join(
        ['@validate', 'class test_model(Base):', '    __root__: str']
    )

    assert custom_root_type.name == 'test_model'
    assert custom_root_type.fields == [make_root_field()]
    assert custom_root_type.base_class == 'Base'
    assert custom_root_type.render() == expected
def test_base_model():
    """A plain BaseModel with one required ``str`` field renders ``a: str``."""
    field = DataModelField(name='a', data_type=DataType(type='str'), required=True)
    model_reference = Reference(name='test_model', path='test_model')

    base_model = BaseModel(fields=[field], reference=model_reference)

    expected = '\n'.join(['class test_model(BaseModel):', '    a: str'])

    assert base_model.name == 'test_model'
    assert base_model.fields == [field]
    assert base_model.decorators == []
    assert base_model.render() == expected
Exemple #8
0
def test_base_model_decorator():
    """An optional field with a decorator and a base class renders all three."""
    field = DataModelField(
        name='a', data_type=DataType(type='str'), default='abc', required=False
    )
    base = Reference(name='Base', original_name='Base', path='Base')

    base_model = BaseModel(
        fields=[field],
        decorators=['@validate'],
        base_classes=[base],
        reference=Reference(name='test_model', path='test_model'),
    )

    expected_lines = [
        '@validate',
        'class test_model(Base):',
        "    a: Optional[str] = 'abc'",
    ]

    assert base_model.name == 'test_model'
    assert base_model.fields == [field]
    assert base_model.base_class == 'Base'
    assert base_model.decorators == ['@validate']
    assert base_model.render() == '\n'.join(expected_lines)
def test_custom_root_type_required():
    """A required custom root field renders as a bare ``__root__: str``."""

    def make_root_field():
        # Build a new equal-by-value field for each use.
        return DataModelFieldBase(data_type=DataType(type='str'), required=True)

    custom_root_type = CustomRootType(
        fields=[make_root_field()],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert custom_root_type.name == 'test_model'
    assert custom_root_type.fields == [make_root_field()]

    expected = '\n'.join(['class test_model(BaseModel):', '    __root__: str'])
    assert custom_root_type.render() == expected
Exemple #10
0
def test_base_model_optional():
    """A non-required field with a default renders as ``Optional[...] = default``."""
    field = DataModelField(
        name='a',
        data_type=DataType(type='str'),
        default='abc',
        required=False,
    )
    model_reference = Reference(name='test_model', path='test_model')

    base_model = BaseModel(fields=[field], reference=model_reference)

    expected = '\n'.join(
        ['class test_model(BaseModel):', "    a: Optional[str] = 'abc'"]
    )

    assert base_model.name == 'test_model'
    assert base_model.fields == [field]
    assert base_model.decorators == []
    assert base_model.render() == expected
def test_data_class_optional():
    """A DataClass field with a pre-quoted default renders ``= 'abc'`` verbatim."""
    field = DataModelFieldBase(
        name='a',
        data_type=DataType(type='str'),
        default="'abc'",
        required=True,
    )
    model_reference = Reference(name='test_model', path='test_model')

    data_class = DataClass(fields=[field], reference=model_reference)

    expected = '\n'.join(
        ['@dataclass', 'class test_model:', "    a: str = 'abc'"]
    )

    assert data_class.name == 'test_model'
    assert data_class.fields == [field]
    assert data_class.decorators == []
    assert data_class.render() == expected
def test_base_model_strict_non_nullable_required():
    """A required, non-nullable field renders without its default value."""
    field_options = {
        'name': 'a',
        'data_type': DataType(type='str'),
        'default': 'abc',
        'required': True,
        'nullable': False,
    }
    field = DataModelField(**field_options)

    base_model = BaseModel(
        fields=[field], reference=Reference(name='test_model', path='test_model'),
    )

    expected = '\n'.join(['class test_model(BaseModel):', '    a: str'])

    assert base_model.name == 'test_model'
    assert base_model.fields == [field]
    assert base_model.decorators == []
    assert base_model.render() == expected
def test_custom_root_type():
    """An optional custom root field renders as ``Optional[...]`` with its default."""

    def make_root_field():
        # Fresh optional ``str`` field named ``a`` with default ``'abc'``.
        return DataModelFieldBase(
            name='a',
            data_type=DataType(type='str'),
            default='abc',
            required=False,
        )

    custom_root_type = CustomRootType(
        fields=[make_root_field()],
        reference=Reference(name='test_model', path='test_model'),
    )

    assert custom_root_type.name == 'test_model'
    assert custom_root_type.fields == [make_root_field()]

    expected = '\n'.join(
        ['class test_model(BaseModel):', "    __root__: Optional[str] = 'abc'"]
    )
    assert custom_root_type.render() == expected
def test_sort_data_models_unresolved():
    """sort_data_models is expected to raise when a reference has no model.

    Z depends on V, but no model for V is supplied; D in turn depends on Z.
    """
    refs = {
        name: Reference(path=name, original_name=name, name=name)
        for name in ('A', 'B', 'C', 'D', 'V', 'Z')
    }
    types = {
        name: DataType(reference=refs[name])
        for name in ('A', 'B', 'C', 'V', 'Z')
    }

    models = [
        BaseModel(
            fields=[
                DataModelField(data_type=types['A']),
                DataModelFieldBase(data_type=types['C']),
            ],
            reference=refs['A'],
        ),
        BaseModel(fields=[DataModelField(data_type=types['B'])],
                  reference=refs['B']),
        BaseModel(fields=[DataModelField(data_type=types['B'])],
                  reference=refs['C']),
        BaseModel(
            fields=[
                DataModelField(data_type=types[name])
                for name in ('A', 'C', 'Z')
            ],
            reference=refs['D'],
        ),
        BaseModel(fields=[DataModelField(data_type=types['V'])],
                  reference=refs['Z']),
    ]

    with pytest.raises(Exception):
        sort_data_models(models)
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True,
        settings_path: Optional[Path] = None,
    ) -> Union[str, Dict[Tuple[str, ...], Result]]:
        """Parse the raw input and render the generated models as source code.

        Models are grouped by ``module_path``. Each group is rendered into one
        module: ``__init__.py`` for the root module or for a package that
        already has an ``__init__.py`` entry, otherwise ``<name>.py``.

        Args:
            with_import: when truthy, emit ``self.imports`` plus the
                per-module imports at the top of every module body, and add
                ``IMPORT_ANNOTATIONS`` unless the target is Python 3.6.
            format_: when truthy, run every module body through
                ``CodeFormatter``.
            settings_path: forwarded to ``CodeFormatter``.

        Returns:
            The body of the ``('__init__.py',)`` module when it is the only
            result (the historical single-file behaviour). Otherwise, a
            mapping from module path tuples to ``Result`` objects.
        """

        self.parse_raw()

        if with_import:
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: Optional[CodeFormatter] = CodeFormatter(
                self.target_python_version, settings_path)
        else:
            code_formatter = None

        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], Result] = {}

        module_key = lambda x: x.module_path

        # process in reverse order to correctly establish module levels
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=module_key, reverse=True),
            key=module_key,
        )

        module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []

        # First pass: prune duplicated root models and make class names
        # unique within each module.
        for module, models in ((k, [*v]) for k, v in grouped_models
                               ):  # type: Tuple[str, ...], List[DataModel]

            # Iterate over a copy: duplicated root models are removed from
            # ``models`` inside this loop, and removing items from the list
            # being iterated would silently skip the model that follows.
            for model in models[:]:
                if isinstance(model, self.data_model_root_type):
                    root_data_type = model.fields[0].data_type

                    # backward compatible
                    # Remove duplicated root model
                    if (root_data_type.reference and not root_data_type.is_dict
                            and not root_data_type.is_list
                            and root_data_type.reference.source in models
                            and root_data_type.reference.name
                            == self.model_resolver.get_class_name(
                                model.reference.original_name, unique=False)):
                        # Replace referenced duplicate model to original model
                        for child in model.reference.children[:]:
                            child.replace_reference(root_data_type.reference)
                        models.remove(model)
                        continue

                    #  Custom root model can't be inherited on restriction of Pydantic
                    for child in model.reference.children:
                        # inheritance model
                        if isinstance(child, DataModel):
                            # Iterate over a copy for the same reason as
                            # above: base classes are removed in the loop.
                            for base_class in child.base_classes[:]:
                                if base_class.reference == model.reference:
                                    child.base_classes.remove(base_class)

            module_models.append((
                module,
                models,
            ))

            # Names already bound by this module's imports are excluded so
            # generated classes cannot shadow them; clashes use the 'Model'
            # duplicate-name suffix.
            scoped_model_resolver = ModelResolver(
                exclude_names={
                    i.alias or i.import_
                    for m in models for i in m.imports
                },
                duplicate_name_suffix='Model',
            )

            for model in models:
                class_name: str = model.class_name
                generated_name: str = scoped_model_resolver.add(
                    model.path, class_name, unique=True, class_name=True).name
                if class_name != generated_name:
                    # Keep any dotted prefix and swap only the last segment.
                    if '.' in model.reference.name:
                        model.reference.name = (
                            f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                        )
                    else:
                        model.reference.name = generated_name

        # Second pass: pick each module's output file, collect its imports
        # and render its body.
        for module, models in module_models:
            init = False
            if module:
                # Make sure the parent package has an __init__.py entry.
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = Result(body='')
                if (*module, '__init__.py') in results:
                    module = (*module, '__init__.py')
                    init = True
                else:
                    module = (*module[:-1], f'{module[-1]}.py')
            else:
                module = ('__init__.py', )

            result: List[str] = []
            imports = Imports()
            scoped_model_resolver = ModelResolver()

            for model in models:
                imports.append(model.imports)
                for data_type in model.all_data_types:
                    # To change from/import

                    if not data_type.reference or data_type.reference.source in models:
                        # No need to import non-reference model.
                        # Or, Referenced model is in the same file. we don't need to import the model
                        continue

                    if isinstance(data_type, BaseClassDataType):
                        from_ = ''.join(
                            relative(model.module_name, data_type.full_name))
                        import_ = data_type.reference.short_name
                        full_path = from_, import_
                    else:
                        from_, import_ = full_path = relative(
                            model.module_name, data_type.full_name)

                    alias = scoped_model_resolver.add(full_path, import_).name

                    name = data_type.reference.short_name
                    if from_ and import_ and alias != name:
                        data_type.alias = f'{alias}.{name}'

                    # Modules rendered as a package __init__.py get one more
                    # '.' appended to the relative ``from_`` part.
                    if init:
                        from_ += "."
                    imports.append(
                        Import(from_=from_, import_=import_, alias=alias))

            if self.reuse_model:
                # Deduplicate models with identical base classes, template
                # data and fields. Enum duplicates are dropped and their
                # references repointed; other duplicates become empty
                # subclasses of the first occurrence.
                model_cache: Dict[Tuple[str, ...], Reference] = {}
                duplicates = []
                for model in models:
                    model_key = tuple(
                        to_hashable(v) for v in (
                            model.base_classes,
                            model.extra_template_data,
                            model.fields,
                        ))
                    cached_model_reference = model_cache.get(model_key)
                    if cached_model_reference:
                        if isinstance(model, Enum):
                            for child in model.reference.children[:]:
                                # child is resolved data_type by reference
                                data_model = get_most_of_parent(child)
                                # TODO: replace reference in all modules
                                if data_model in models:  # pragma: no cover
                                    child.replace_reference(
                                        cached_model_reference)
                            duplicates.append(model)
                        else:
                            index = models.index(model)
                            inherited_model = model.__class__(
                                fields=[],
                                base_classes=[cached_model_reference],
                                description=model.description,
                                reference=Reference(
                                    name=model.name,
                                    path=model.reference.path + '/reuse',
                                ),
                            )
                            if (cached_model_reference.path
                                    in require_update_action_models):
                                require_update_action_models.append(
                                    inherited_model.path)
                            # Insert-then-remove keeps the list length, so
                            # the ongoing iteration continues with the next
                            # original model.
                            models.insert(index, inherited_model)
                            models.remove(model)

                    else:
                        model_cache[model_key] = model.reference

                for duplicate in duplicates:
                    models.remove(duplicate)

            if self.set_default_enum_member:
                # Swap a field's default for the matching Enum member when
                # the field's type references an Enum model.
                for model in models:
                    for model_field in model.fields:
                        if not model_field.default:
                            continue
                        for data_type in model_field.data_type.all_data_types:
                            if data_type.reference and isinstance(
                                    data_type.reference.source,
                                    Enum):  # pragma: no cover
                                enum_member = data_type.reference.source.find_member(
                                    model_field.default)
                                if enum_member:
                                    model_field.default = enum_member
            if with_import:
                result += [str(self.imports), str(imports), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(
                        m.reference.short_name for m in models
                        if m.path in require_update_action_models),
                ]

            body = '\n'.join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            results[module] = Result(body=body, source=models[0].file_path)

        # retain existing behaviour
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )].body

        return results
def test_openapi_model_resolver():
    """Parsing api.yaml registers every schema reference, each marked loaded."""
    parser = OpenAPIParser(source=(DATA_PATH / 'api.yaml'))
    parser.parse()

    # (path, original_name, name) for every reference the resolver should
    # hold; ``name`` differs from ``original_name`` where the class was renamed.
    expected_rows = [
        ('#/components/schemas/Event', 'Event', 'Event'),
        ('#/components/schemas/Pet', 'Pet', 'Pet'),
        ('api.yaml#/components/schemas/Error', 'Error', 'Error'),
        ('api.yaml#/components/schemas/Event', 'Event', 'Event'),
        ('api.yaml#/components/schemas/Id', 'Id', 'Id'),
        ('api.yaml#/components/schemas/Pet', 'Pet', 'Pet'),
        ('api.yaml#/components/schemas/Pets', 'Pets', 'Pets'),
        ('api.yaml#/components/schemas/Result', 'Result', 'Result'),
        ('api.yaml#/components/schemas/Rules', 'Rules', 'Rules'),
        ('api.yaml#/components/schemas/Users', 'Users', 'Users'),
        ('api.yaml#/components/schemas/Users/Users', 'Users', 'User'),
        ('api.yaml#/components/schemas/apis', 'apis', 'Apis'),
        ('api.yaml#/components/schemas/apis/Apis', 'Apis', 'Api'),
    ]
    expected = {
        path: Reference(
            path=path, original_name=original_name, name=name, loaded=True
        )
        for path, original_name, name in expected_rows
    }

    assert parser.model_resolver.references == expected
 def __init__(self, filename: str, data: str,
              fields: List[DataModelFieldBase]):
     super().__init__(fields=fields, reference=Reference(''))
     self._data = data