예제 #1
0
    def from_config(cls, config: GeneratorConfig) -> "CodeWriter":
        """
        Build a writer around the generator registered for the configured
        output format.

        :param config: generator configuration; ``config.output.format`` is
            the key looked up in ``cls.generators``
        :return: a new instance created with the matching generator, which
            is itself instantiated with the given config
        :raises CodeGenerationError: if the format is not registered
        """
        output_format = config.output.format
        registry = cls.generators

        if output_format in registry:
            generator = registry[output_format](config)
            return cls(generator=generator)

        raise CodeGenerationError(
            f"Unknown output format: '{output_format}'")
예제 #2
0
파일: utils.py 프로젝트: neriusmika/xsdata
    def find_inner(cls, source: Class, qname: str) -> Class:
        """
        Look up an inner class of the source by its qualified name.

        :param source: the class whose ``inner`` list is searched
        :param qname: the qualified name to match against ``inner.qname``
        :return: the first inner class with a matching qname
        :raises CodeGenerationError: if no inner class matches
        """
        candidates = (inner for inner in source.inner if inner.qname == qname)
        match = next(candidates, None)

        if match is None:
            raise CodeGenerationError(
                f"Missing inner class {source.qname}.{qname}")

        return match
예제 #3
0
파일: mixins.py 프로젝트: neriusmika/xsdata
    def designate(self, classes: List[Class]):
        """
        Replace the package and module of every class with the names
        produced by ``self.package_name`` and ``self.module_name``.

        When the configured output structure is
        ``OutputStructure.NAMESPACES``, each class is first assigned the
        configured package and uses its target namespace (or an empty
        string) as module. Converted names are remembered per original
        value, so each distinct package/module is converted only once.

        :param classes: a list of codegen class instances
        :raises CodeGenerationError: if a class has no package
        """
        module_names = {}
        package_names = {}
        output = self.config.output
        group_by_namespace = output.structure == OutputStructure.NAMESPACES

        for item in classes:
            if group_by_namespace:
                item.package = output.package
                item.module = item.target_namespace or ""

            raw_package = item.package
            if raw_package is None:
                raise CodeGenerationError(
                    f"Class `{item.name}` has not been assign to a package.")

            raw_module = item.module
            if raw_module not in module_names:
                module_names[raw_module] = self.module_name(raw_module)

            if raw_package not in package_names:
                package_names[raw_package] = self.package_name(raw_package)

            item.module = module_names[raw_module]
            item.package = package_names[raw_package]
예제 #4
0
    def designate(
        self, classes: List[Class], output: str, package: str, ns_struct: bool
    ):
        """
        Replace the package and module of every class with the names
        produced by ``self.package_name`` and ``self.module_name`` for the
        given output format.

        Converted names are remembered per original value, so each distinct
        package/module is converted only once.

        :param classes: a list of codegen class instances
        :param output: target output format
        :param package: the original user provided package name
        :param ns_struct: use the target namespaces to group the classes in
            the same module; every class must then have a namespace
        :raises CodeGenerationError: if ``ns_struct`` is set and a class has
            an empty namespace, or if a class has no package
        """
        module_names = {}
        package_names = {}

        for item in classes:
            if ns_struct:
                namespace = item.qname.namespace
                if not namespace:
                    raise CodeGenerationError(
                        f"Class `{item.name}` target namespace "
                        "is empty, avoid option `--ns-struct`"
                    )

                item.package = package
                item.module = namespace

            raw_package = item.package
            if raw_package is None:
                raise CodeGenerationError(
                    f"Class `{item.name}` has not been assign to a package."
                )

            raw_module = item.module
            if raw_module not in module_names:
                module_names[raw_module] = self.module_name(raw_module, output)

            if raw_package not in package_names:
                package_names[raw_package] = self.package_name(
                    raw_package, output
                )

            item.module = module_names[raw_module]
            item.package = package_names[raw_package]
예제 #5
0
def resolve_source(source: str, wsdl: bool) -> Iterator[str]:
    """
    Yield the uri(s) that should be processed for the given source.

    - A source with a scheme other than ``file://`` is yielded unchanged.
    - A ``file://`` uri is converted back to a filesystem path and then
      handled like any other local path.
    - A local directory yields the uri of every ``*.xsd`` file directly
      inside it (not recursive).
    - Any other local path is treated as a single file and yielded as a
      ``file://`` uri of its resolved location.

    :param source: a uri, a file path or a directory path
    :param wsdl: whether the source is expected to be a wsdl definition
    :raises CodeGenerationError: if ``wsdl`` is set and the source is a
        directory
    """
    # Imported locally so the block does not depend on new file-level imports.
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    is_file_uri = source.startswith("file://")
    if "://" in source and not is_file_uri:
        yield source
        return

    if is_file_uri:
        # Path("file:///x") would be read as a relative path named "file:",
        # so turn the uri into a real filesystem path first.
        source = url2pathname(urlparse(source).path)

    path = Path(source).resolve()
    if not path.is_dir():  # is file
        yield path.as_uri()
        return

    if wsdl:
        raise CodeGenerationError(
            "WSDL mode doesn't support scanning directories.")

    yield from (x.as_uri() for x in path.glob("*.xsd"))
예제 #6
0
    def designate(self, classes: List[Class], output: str):
        """
        Replace the package and module of every class with the names
        produced by ``self.package_name`` and ``self.module_name`` for the
        given output format.

        Converted names are remembered per original value, so each distinct
        package/module is converted only once.

        :param classes: a list of codegen class instances
        :param output: target output format
        :raises CodeGenerationError: if a class has no package
        """
        module_names = {}
        package_names = {}

        for item in classes:
            raw_package = item.package
            if raw_package is None:
                raise CodeGenerationError(
                    f"Class `{item.name}` has not been assign to a package.")

            raw_module = item.module
            if raw_module not in module_names:
                module_names[raw_module] = self.module_name(raw_module, output)

            if raw_package not in package_names:
                package_names[raw_package] = self.package_name(
                    raw_package, output)

            item.module = module_names[raw_module]
            item.package = package_names[raw_package]
예제 #7
0
    def process_enum_extension(cls, source: Class, target: Class,
                               ext: Extension):
        """
        Resolve an extension of the target class by the source class.

        The source decides which strategy applies, checked in this order:
            1. Source is an enumeration: every target attr that shares a
               name with a source attr is replaced by a clone of it, then
               the extension is removed.
            2. Source has exactly one attr: clones of its types are added
               to every target attr and its restrictions are merged in,
               then the extension is removed.
            3. Target has exactly one attr: the target takes clones of the
               source attrs and extensions; attrs without an xml type get
               the previous target default as a fixed value.
            4. Otherwise the schema is considered invalid.

        :param source: the class being extended from
        :param target: the class carrying the extension
        :param ext: the extension of the target that points to the source
        :raises CodeGenerationError: if none of the first three cases apply
        """
        if source.is_enumeration:
            by_name = {attr.name: attr for attr in source.attrs}
            replaced = []
            for current in target.attrs:
                if current.name in by_name:
                    replaced.append(by_name[current.name].clone())
                else:
                    replaced.append(current)

            target.attrs = replaced
            target.extensions.remove(ext)
            return

        if len(source.attrs) == 1:
            only_attr = source.attrs[0]
            for current in target.attrs:
                cloned_types = [x.clone() for x in only_attr.types]
                current.types.extend(cloned_types)
                current.restrictions.merge(only_attr.restrictions)

            target.extensions.remove(ext)
            return

        if len(target.attrs) != 1:
            raise CodeGenerationError(
                "Enumeration class with a complex extension.")

        # The target is not really an enumeration: keep its single default
        # before the attrs are overwritten with the source ones.
        fixed_value = target.attrs[0].default
        target.attrs = [attr.clone() for attr in source.attrs]
        target.extensions = [item.clone() for item in source.extensions]

        for current in target.attrs:
            if current.xml_type is None:
                current.default = fixed_value
                current.fixed = True
예제 #8
0
    def process_classes(self):
        """
        Analyze every class collected in ``self.class_map`` and hand the
        result to a code writer.

        The class counts are logged before and after the analysis. The
        writer prints the classes when ``self.print`` is truthy and writes
        them otherwise.

        :raises CodeGenerationError: if there are no main classes to process
        """
        classes = []
        for group in self.class_map.values():
            classes.extend(group)

        class_num, inner_num = self.count_classes(classes)
        if not class_num:
            raise CodeGenerationError("Nothing to generate.")

        logger.info("Analyzer input: %d main and %d inner classes",
                    class_num, inner_num)
        self.assign_packages()

        classes = self.analyze_classes(classes)
        class_num, inner_num = self.count_classes(classes)
        logger.info("Analyzer output: %d main and %d inner classes",
                    class_num, inner_num)

        writer = CodeWriter.from_config(self.config)
        emit = writer.print if self.print else writer.write
        emit(classes)