Code example #1
def build_model_file(
    parsed_document: str,
    model_id: str,
    section_name: str,
    model_info: ModelInfo,
    code_formatter: CodeFormatter,
):
    """
    :param parsed_document: OpenApi parsed document
    :param model_id: instance or shared
    :param section_name: init or instances
    :param model_info: Information to build the model file
    :param code_formatter: Formatter used to apply black to the generated file contents
    """
    # Whether or not there are options with default values
    options_with_defaults = len(model_info.defaults_file_lines) > 0
    model_file_lines = parsed_document.splitlines()
    _add_imports(model_file_lines, options_with_defaults,
                 len(model_info.deprecation_data))

    if model_id in model_info.deprecation_data:
        model_file_lines += _define_deprecation_functions(
            model_id, section_name)

    model_file_lines += _define_validator_functions(model_id,
                                                    model_info.validator_data)

    model_file_lines.append('')
    model_file_contents = '\n'.join(model_file_lines)
    if any(len(line) > 120 for line in model_file_lines):
        model_file_contents = code_formatter.apply_black(model_file_contents)
    return model_file_contents
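
A hypothetical invocation is sketched below; the parsed_openapi_text, model_info, code_formatter, and models_dir names are illustrative and assumed to come from the surrounding code-generation pipeline.

# Hypothetical usage sketch; every input here is supplied by the surrounding pipeline.
contents = build_model_file(
    parsed_document=parsed_openapi_text,  # model source text generated from the OpenAPI document
    model_id='shared',                    # 'instance' or 'shared', per the docstring
    section_name='instances',             # 'init' or 'instances', per the docstring
    model_info=model_info,                # carries defaults, validator and deprecation data
    code_formatter=code_formatter,        # exposes apply_black()
)
(models_dir / 'shared.py').write_text(contents)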
Code example #2
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True,
        settings_path: Optional[Path] = None,
    ) -> Union[str, Dict[Tuple[str, ...], Result]]:

        self.parse_raw()

        if with_import:
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: Optional[CodeFormatter] = CodeFormatter(
                self.target_python_version, settings_path)
        else:
            code_formatter = None

        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], Result] = {}

        module_key = lambda x: x.module_path

        # process in reverse order to correctly establish module levels
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=module_key, reverse=True),
            key=module_key,
        )

        module_models: List[Tuple[Tuple[str, ...], List[DataModel]]] = []

        for module, models in ((k, [*v]) for k, v in grouped_models
                               ):  # type: Tuple[str, ...], List[DataModel]

            for model in models:
                if isinstance(model, self.data_model_root_type):
                    root_data_type = model.fields[0].data_type

                    # Backward compatibility: remove the duplicated root model
                    if (root_data_type.reference and not root_data_type.is_dict
                            and not root_data_type.is_list
                            and root_data_type.reference.source in models
                            and root_data_type.reference.name
                            == self.model_resolver.get_class_name(
                                model.reference.original_name, unique=False)):
                        # Replace references to the duplicated model with the original model
                        for child in model.reference.children[:]:
                            child.replace_reference(root_data_type.reference)
                        models.remove(model)
                        continue

                    # Custom root models can't be inherited because of a Pydantic restriction
                    for child in model.reference.children:
                        # strip the custom root model from the base classes of inheriting child models
                        if isinstance(child, DataModel):
                            for base_class in child.base_classes:
                                if base_class.reference == model.reference:
                                    child.base_classes.remove(base_class)

            module_models.append((
                module,
                models,
            ))

            scoped_model_resolver = ModelResolver(
                exclude_names={
                    i.alias or i.import_
                    for m in models for i in m.imports
                },
                duplicate_name_suffix='Model',
            )

            for model in models:
                class_name: str = model.class_name
                generated_name: str = scoped_model_resolver.add(
                    model.path, class_name, unique=True, class_name=True).name
                if class_name != generated_name:
                    if '.' in model.reference.name:
                        model.reference.name = (
                            f"{model.reference.name.rsplit('.', 1)[0]}.{generated_name}"
                        )
                    else:
                        model.reference.name = generated_name

        for module, models in module_models:
            init = False
            if module:
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = Result(body='')
                if (*module, '__init__.py') in results:
                    module = (*module, '__init__.py')
                    init = True
                else:
                    module = (*module[:-1], f'{module[-1]}.py')
            else:
                module = ('__init__.py', )

            result: List[str] = []
            imports = Imports()
            scoped_model_resolver = ModelResolver()

            for model in models:
                imports.append(model.imports)
                for data_type in model.all_data_types:
                    # Work out the from/import pair needed to reference this data type

                    if not data_type.reference or data_type.reference.source in models:
                        # A non-reference data type needs no import, and neither does
                        # a referenced model defined in the same file.
                        continue

                    if isinstance(data_type, BaseClassDataType):
                        from_ = ''.join(
                            relative(model.module_name, data_type.full_name))
                        import_ = data_type.reference.short_name
                        full_path = from_, import_
                    else:
                        from_, import_ = full_path = relative(
                            model.module_name, data_type.full_name)

                    alias = scoped_model_resolver.add(full_path, import_).name

                    name = data_type.reference.short_name
                    if from_ and import_ and alias != name:
                        data_type.alias = f'{alias}.{name}'

                    if init:
                        from_ += "."
                    imports.append(
                        Import(from_=from_, import_=import_, alias=alias))

            if self.reuse_model:
                model_cache: Dict[Tuple[str, ...], Reference] = {}
                duplicates = []
                for model in models:
                    model_key = tuple(
                        to_hashable(v) for v in (
                            model.base_classes,
                            model.extra_template_data,
                            model.fields,
                        ))
                    cached_model_reference = model_cache.get(model_key)
                    if cached_model_reference:
                        if isinstance(model, Enum):
                            for child in model.reference.children[:]:
                                # child is a data type resolved through the reference
                                data_model = get_most_of_parent(child)
                                # TODO: replace reference in all modules
                                if data_model in models:  # pragma: no cover
                                    child.replace_reference(
                                        cached_model_reference)
                            duplicates.append(model)
                        else:
                            index = models.index(model)
                            inherited_model = model.__class__(
                                fields=[],
                                base_classes=[cached_model_reference],
                                description=model.description,
                                reference=Reference(
                                    name=model.name,
                                    path=model.reference.path + '/reuse',
                                ),
                            )
                            if (cached_model_reference.path
                                    in require_update_action_models):
                                require_update_action_models.append(
                                    inherited_model.path)
                            models.insert(index, inherited_model)
                            models.remove(model)

                    else:
                        model_cache[model_key] = model.reference

                for duplicate in duplicates:
                    models.remove(duplicate)

            if self.set_default_enum_member:
                for model in models:
                    for model_field in model.fields:
                        if not model_field.default:
                            continue
                        for data_type in model_field.data_type.all_data_types:
                            if data_type.reference and isinstance(
                                    data_type.reference.source,
                                    Enum):  # pragma: no cover
                                enum_member = data_type.reference.source.find_member(
                                    model_field.default)
                                if enum_member:
                                    model_field.default = enum_member
            if with_import:
                result += [str(self.imports), str(imports), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(
                        m.reference.short_name for m in models
                        if m.path in require_update_action_models),
                ]

            body = '\n'.join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            results[module] = Result(body=body, source=models[0].file_path)

        # Retain existing behaviour: single-file output is returned as a plain string
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )].body

        return results
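
parse() returns either a plain string (single-file output) or a dict keyed by module-path tuples whose Result.body holds the rendered code. A minimal consumption sketch, assuming a parser instance such as the OpenAPIParser seen in example #5 and pathlib.Path in scope:

# Minimal sketch (not from the source) handling both return shapes of parse().
output = parser.parse()
if isinstance(output, str):
    Path('model.py').write_text(output)
else:
    for module_path, result in output.items():
        target = Path('models').joinpath(*module_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(result.body)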
Code example #3
    def parse(
        self,
        with_import: Optional[bool] = True,
        format_: Optional[bool] = True,
        settings_path: Optional[Path] = None,
    ) -> Union[str, Dict[Tuple[str, ...], Result]]:

        self.parse_raw()

        if with_import:
            if self.target_python_version != PythonVersion.PY_36:
                self.imports.append(IMPORT_ANNOTATIONS)

        if format_:
            code_formatter: Optional[CodeFormatter] = CodeFormatter(
                self.target_python_version, settings_path)
        else:
            code_formatter = None

        _, sorted_data_models, require_update_action_models = sort_data_models(
            self.results)

        results: Dict[Tuple[str, ...], Result] = {}

        module_key = lambda x: x.module_path

        # process in reverse order to correctly establish module levels
        grouped_models = groupby(
            sorted(sorted_data_models.values(), key=module_key, reverse=True),
            key=module_key,
        )
        for module, models in ((k, [*v]) for k, v in grouped_models
                               ):  # type: Tuple[str, ...], List[DataModel]
            module_path = '.'.join(module)

            init = False
            if module:
                parent = (*module[:-1], '__init__.py')
                if parent not in results:
                    results[parent] = Result(body='')
                if (*module, '__init__.py') in results:
                    module = (*module, '__init__.py')
                    init = True
                else:
                    module = (*module[:-1], f'{module[-1]}.py')
            else:
                module = ('__init__.py', )

            result: List[str] = []
            imports = Imports()
            models_to_update: List[str] = []
            scoped_model_resolver = ModelResolver()
            import_map: Dict[str, Tuple[str, str]] = {}
            model_names: Set[str] = {m.name for m in models}
            processed_models: Set[str] = set()
            model_cache: Dict[Tuple[str, ...], str] = {}
            duplicated_model_names: Dict[str, str] = {}
            for model in models:
                alias_map: Dict[str, Optional[str]] = {}
                if model.name in require_update_action_models:
                    ref_names = {
                        d.reference.name
                        for d in model.all_data_types if d.reference
                    }
                    if model.name in ref_names or (not ref_names - model_names
                                                   and ref_names -
                                                   processed_models):
                        models_to_update += [model.name]
                        processed_models.add(model.name)
                imports.append(model.imports)
                for data_type in model.all_data_types:
                    if not data_type.is_modular:
                        continue
                    full_name = data_type.full_name
                    from_, import_ = relative(module_path, full_name)
                    name = data_type.name
                    if data_type.reference:
                        reference = self.model_resolver.get(
                            data_type.reference.path)
                        if reference and (
                            (isinstance(self.source, Path)
                             and self.source.is_file() and self.source.name
                             == reference.path.split('#/')[0]) or
                                reference.actual_module_name == module_path):
                            if name in model.reference_classes:  # pragma: no cover
                                model.reference_classes.remove(name)
                            continue
                    full_path = f'{from_}/{import_}'
                    if full_path in alias_map:
                        alias = alias_map[full_path] or import_
                    else:
                        alias = scoped_model_resolver.add(full_path.split('/'),
                                                          import_,
                                                          unique=True).name
                        alias_map[
                            full_path] = None if alias == import_ else alias
                    new_name = f'{alias}.{name}' if from_ and import_ else name
                    if data_type.module_name and not full_name.startswith(
                            from_):
                        import_map[new_name] = (
                            f'.{full_name[:-len(new_name) - 1]}',
                            new_name.split('.')[0],
                        )
                    if name in model.reference_classes:
                        model.reference_classes.remove(name)
                        model.reference_classes.add(new_name)
                    data_type.type = new_name

                for ref_name in model.reference_classes:
                    if ref_name in model_names:
                        continue
                    elif ref_name in import_map:
                        from_, import_ = import_map[ref_name]
                    else:
                        from_, import_ = relative(module_path, ref_name)
                    if init:
                        from_ += "."
                    imports.append(
                        Import(
                            from_=from_,
                            import_=import_,
                            alias=alias_map.get(f'{from_}/{import_}'),
                        ))

                if self.reuse_model:
                    model_key = tuple(
                        to_hashable(v) for v in (
                            model.base_classes,
                            model.extra_template_data,
                            model.fields,
                        ))
                    cached_model_name = model_cache.get(model_key)
                    if cached_model_name:
                        if isinstance(model, Enum):
                            duplicated_model_names[
                                model.name] = cached_model_name
                        else:
                            index = models.index(model)
                            inherited_model = model.__class__(
                                name=model.name,
                                fields=[],
                                base_classes=[cached_model_name],
                                description=model.description,
                            )
                            models.insert(index, inherited_model)
                            models.remove(model)

                    else:
                        model_cache[model_key] = model.name

            if self.reuse_model:
                for model in models:
                    for data_type in model.all_data_types:  # pragma: no cover
                        if data_type.type:
                            duplicated_model_name = duplicated_model_names.get(
                                data_type.type)
                            if duplicated_model_name:
                                data_type.type = duplicated_model_name

            if with_import:
                result += [str(imports), str(self.imports), '\n']

            code = dump_templates(models)
            result += [code]

            if self.dump_resolve_reference_action is not None:
                result += [
                    '\n',
                    self.dump_resolve_reference_action(models_to_update)
                ]

            body = '\n'.join(result)
            if code_formatter:
                body = code_formatter.format_code(body)

            results[module] = Result(body=body, source=models[0].path)

        # Retain existing behaviour: single-file output is returned as a plain string
        if [*results] == [('__init__.py', )]:
            return results[('__init__.py', )].body

        return results
Code example #4
def create_code_formatter():
    return CodeFormatter(PYTHON_VERSION)
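
A minimal usage sketch, assuming the format_code method used by the other examples in this listing:

# Hypothetical usage; the source string is illustrative.
formatter = create_code_formatter()
print(formatter.format_code('x=1\ny =  2\n'))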
Code example #5
def generate_code(
    input_name: str,
    input_text: str,
    output_dir: Path,
    template_dir: Optional[Path],
    model_path: Optional[Path] = None,
    enum_field_as_literal: Optional[str] = None,
) -> None:
    if not model_path:
        model_path = MODEL_PATH
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if not template_dir:
        template_dir = BUILTIN_TEMPLATE_DIR
    if enum_field_as_literal:
        parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal)
    else:
        parser = OpenAPIParser(input_text)
    with chdir(output_dir):
        models = parser.parse()
    if not models:
        return
    elif isinstance(models, str):
        output = output_dir / model_path
        modules = {output: (models, input_name)}
    else:
        raise Exception('Modular references are not supported in this version')

    environment: Environment = Environment(
        loader=FileSystemLoader(
            template_dir if template_dir else f"{Path(__file__).parent}/template",
            encoding="utf8",
        ),
    )
    imports = Imports()
    imports.update(parser.imports)
    for data_type in parser.data_types:
        reference = _get_most_of_reference(data_type)
        if reference:
            imports.append(data_type.all_imports)
            imports.append(
                Import.from_full_path(f'.{model_path.stem}.{reference.name}')
            )
    for from_, imports_ in parser.imports_for_fastapi.items():
        imports[from_].update(imports_)
    results: Dict[Path, str] = {}
    code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
    sorted_operations: List[Operation] = sorted(
        parser.operations.values(), key=lambda m: m.path
    )
    for target in template_dir.rglob("*"):
        relative_path = target.relative_to(template_dir)
        result = environment.get_template(str(relative_path)).render(
            operations=sorted_operations, imports=imports, info=parser.parse_info(),
        )
        results[relative_path] = code_formatter.format_code(result)

    timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
    header = f"""\
# generated by fastapi-codegen:
#   filename:  {Path(input_name).name}
#   timestamp: {timestamp}"""

    for path, code in results.items():
        with output_dir.joinpath(path.with_suffix(".py")).open("wt") as file:
            print(header, file=file)
            print("", file=file)
            print(code.rstrip(), file=file)

    header = f'''\
# generated by fastapi-codegen:
#   filename:  {{filename}}'''
    #     if not disable_timestamp:
    header += f'\n#   timestamp: {timestamp}'

    for path, body_and_filename in modules.items():
        body, filename = body_and_filename
        if path is None:
            file = None
        else:
            if not path.parent.exists():
                path.parent.mkdir(parents=True)
            file = path.open('wt', encoding='utf8')

        print(header.format(filename=filename), file=file)
        if body:
            print('', file=file)
            print(body.rstrip(), file=file)

        if file is not None:
            file.close()
Code example #6
def generate_code(input_name: str, input_text: str, output_dir: Path,
                  template_dir: Optional[Path]) -> None:
    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    if not template_dir:
        template_dir = BUILTIN_TEMPLATE_DIR

    model_parser = OpenAPIModelParser(source=input_text, )

    parser = OpenAPIParser(input_name,
                           input_text,
                           openapi_model_parser=model_parser)
    parsed_object: ParsedObject = parser.parse()

    environment: Environment = Environment(loader=FileSystemLoader(
        template_dir if template_dir else f"{Path(__file__).parent}/template",
        encoding="utf8",
    ), )
    results: Dict[Path, str] = {}
    code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
    for target in template_dir.rglob("*"):
        relative_path = target.relative_to(template_dir)
        result = environment.get_template(str(relative_path)).render(
            operations=parsed_object.operations,
            imports=parsed_object.imports,
            info=parsed_object.info,
        )
        results[relative_path] = code_formatter.format_code(result)

    timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
    header = f"""\
# generated by fastapi-codegen:
#   filename:  {Path(input_name).name}
#   timestamp: {timestamp}"""

    for path, code in results.items():
        with output_dir.joinpath(path.with_suffix(".py")).open("wt") as file:
            print(header, file=file)
            print("", file=file)
            print(code.rstrip(), file=file)

    with chdir(output_dir):
        results = model_parser.parse()
    if not results:
        return
    elif isinstance(results, str):
        output = output_dir / MODEL_PATH
        modules = {output: (results, input_name)}
    else:
        raise Exception('Modular references are not supported in this version')

    header = f'''\
# generated by fastapi-codegen:
#   filename:  {{filename}}'''
    #     if not disable_timestamp:
    header += f'\n#   timestamp: {timestamp}'

    for path, body_and_filename in modules.items():
        body, filename = body_and_filename
        if path is None:
            file = None
        else:
            if not path.parent.exists():
                path.parent.mkdir(parents=True)
            file = path.open('wt', encoding='utf8')

        print(header.format(filename=filename), file=file)
        if body:
            print('', file=file)
            print(body.rstrip(), file=file)

        if file is not None:
            file.close()
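
As a closing illustration, a hypothetical driver for the generate_code function from example #6; the spec path and output directory are illustrative only.

# Hypothetical driver; paths are illustrative.
spec_path = Path('api/openapi.yaml')
generate_code(
    input_name=spec_path.name,
    input_text=spec_path.read_text(encoding='utf8'),
    output_dir=Path('generated_app'),
    template_dir=None,  # falls back to BUILTIN_TEMPLATE_DIR
)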