def flatten_enumeration_unions(self, target: Class):
    """
    Replace a union of enumerations with the merged enumeration members.

    Applies only to common classes whose single attribute is named
    ``value``. Every type of that attribute must resolve to an enumeration:
    forward references resolve to the sole inner class, other non-native
    types are looked up with ``self.find_class``. When every type
    qualifies, ``target.attrs`` is replaced by clones of all enumeration
    members, in type order.

    Nothing changes when a type is native, unresolved or not an
    enumeration, or when the union yields no members at all (so an empty
    type list can no longer wipe the ``value`` attribute).
    """
    if not target.is_common:
        return

    if len(target.attrs) != 1 or target.attrs[0].name != "value":
        return

    members = []
    for attr_type in target.attrs[0].types:
        source = None
        if attr_type.forward_ref:
            # A forward reference can only be resolved when there is
            # exactly one inner class to point at.
            if len(target.inner) == 1:
                source = target.inner[0]
        elif not attr_type.native:
            source = self.find_class(target.source_qname(attr_type.name))

        if source is None or not source.is_enumeration:
            # At least one member of the union is not an enumeration.
            return

        members.extend(source.attrs)

    if members:
        # Clone so the target does not share mutable Attr objects with the
        # source enumerations; later in-place processing of either class
        # (renaming, default conversion, type extension) must not leak
        # into the other.
        target.attrs = [member.clone() for member in members]
def flatten_extension(self, target: Class, extension: Extension):
    """
    Route one extension of ``target`` to the matching flatten handler.

    The extension type is classified in this order:
    1. native primitive type (int, str, float, ...) -> native handler
    2. simple source class (simpleType, Extension) -> simple handler
    3. complex source class (ComplexType, Element) -> complex handler
    4. anything else is unknown: a warning is logged and the extension
       is removed from the target.
    """
    ext_type = extension.type
    if ext_type.native:
        self.flatten_extension_native(target, extension)
        return

    qname = target.source_qname(ext_type.name)

    simple = self.find_simple_class(qname)
    if simple:
        self.flatten_extension_simple(simple, target, extension)
        return

    # Only consulted when no simple class matched.
    complex_cls = self.find_class(qname)
    if complex_cls:
        self.flatten_extension_complex(complex_cls, target, extension)
        return

    logger.warning("Missing extension type: %s", extension.type.name)
    target.extensions.remove(extension)
def expand_attribute_group(self, target: Class, attr: Attr):
    """
    Replace a group attribute with clones of its source class attributes.

    Non-group attributes are left alone. The clones are produced by
    ``self.clone_attribute`` with the group's restrictions and name prefix,
    and take the group attribute's position in ``target.attrs``; the
    source inner classes are then copied into the target. A group that
    resolves to the target itself is simply removed.

    :raises AnalyzerError: when the group source class cannot be found.
    """
    if not attr.is_group:
        return

    group_qname = target.source_qname(attr.name)
    group_source = self.find_class(group_qname)
    if not group_source:
        raise AnalyzerError(f"Group attribute not found: `{group_qname}`")

    if group_source is target:
        # A self-referencing group contributes nothing new.
        target.attrs.remove(attr)
        return

    position = target.attrs.index(attr)
    name_prefix = text.prefix(attr.name)
    expanded = [
        self.clone_attribute(member, attr.restrictions, name_prefix)
        for member in group_source.attrs
    ]
    # Swap the single group attribute for its expanded members in place.
    target.attrs[position : position + 1] = expanded
    self.copy_inner_classes(group_source, target)
def merge_duplicate_attributes(cls, target: Class):
    """
    Collapse attributes that ``cls.find_attribute`` reports as duplicates.

    The first occurrence is kept. A later duplicate that is an attribute or
    an enumeration member is dropped. Any other duplicate is folded into
    the first occurrence so it becomes a list-like field:

    * min occurs -> the smaller of both (missing counts as 0)
    * max occurs -> the sum of both (missing counts as 1)
    * ``fixed`` is cleared and the ``sequential`` flags are OR-ed.
    """
    if not target.attrs:
        return

    merged: List[Attr] = []
    for candidate in target.attrs:
        found = cls.find_attribute(merged, candidate)
        if found < 0 or not merged[found]:
            merged.append(candidate)
            continue

        if candidate.is_attribute or candidate.is_enumeration:
            # Plain duplicates of attributes/enum members are discarded.
            continue

        first = merged[found]
        kept = first.restrictions
        extra = candidate.restrictions
        kept.min_occurs = min(kept.min_occurs or 0, extra.min_occurs or 0)
        kept.max_occurs = (kept.max_occurs or 1) + (extra.max_occurs or 1)
        first.fixed = False
        kept.sequential = kept.sequential or extra.sequential

    target.attrs = merged
def attr_type_is_missing(self, source: Class, attr_type: AttrType) -> bool:
    """
    Return True when ``attr_type`` is non-native and its qualified name
    (resolved against ``source``) is absent from ``self.class_index``.
    Native types are never reported as missing.
    """
    return (
        not attr_type.native
        and source.source_qname(attr_type.name) not in self.class_index
    )
def find_attr_simple_type(self, source: Class, attr_type: AttrType) -> Optional[Class]:
    """
    Resolve ``attr_type`` from ``source`` to a simple-type class.

    Enumerations, complex classes and ``source`` itself are rejected by the
    lookup condition, so a type pointing back at its own class is not
    returned.
    """

    def accept(candidate):
        return not (
            candidate.is_enumeration or candidate.is_complex or candidate is source
        )

    type_qname = source.source_qname(attr_type.name)
    return self.find_class(type_qname, condition=accept)
def apply_aliases(self, obj: Class):
    """
    Set ``alias`` on every attribute type of ``obj`` and its inner classes.

    Each type's qualified name, resolved against the class that owns it, is
    looked up in ``self.aliases``; types without an entry get ``None``.
    """
    all_types = (attr_type for attr in obj.attrs for attr_type in attr.types)
    for attr_type in all_types:
        qname = obj.source_qname(attr_type.name)
        attr_type.alias = self.aliases.get(qname)

    for inner_class in obj.inner:
        self.apply_aliases(inner_class)
def class_depends_on(self, source: Class, target: Class) -> bool:
    """
    Check if ``target`` is reachable through the dependencies of ``source``.

    Dependencies are followed transitively: every qname yielded by
    ``dependencies()`` is resolved with ``self.find_class(qname,
    condition=None)``, and unresolved names are skipped. ``source`` itself
    only matches if it is reached again through its own dependencies.

    Every class is expanded at most once, so dependency cycles that do not
    pass through ``target`` terminate instead of recursing until
    RecursionError (which the previous recursive version did).
    """
    # Track visited classes by identity; model classes are not guaranteed
    # to be hashable.
    visited = {id(source)}
    pending = [source]
    while pending:
        current = pending.pop()
        for qname in current.dependencies():
            check = self.find_class(qname, condition=None)
            if check is None:
                continue
            if check is target:
                return True
            if id(check) not in visited:
                visited.add(id(check))
                pending.append(check)

    return False
def attr_depends_on(self, dependency: AttrType, target: Class) -> bool:
    """
    Check whether an attribute type refers to ``target``.

    The type is resolved against ``target`` with ``find_class`` and no
    condition. An unresolved type does not depend on anything; otherwise it
    matches when it is ``target`` itself or when ``class_depends_on``
    reports a transitive dependency.
    """
    resolved = self.find_class(target.source_qname(dependency.name), condition=None)
    if resolved is None:
        return False

    return resolved is target or self.class_depends_on(resolved, target)
def build_class_extensions(self, obj: ElementBase, target: Class):
    """
    Build the target class extensions from the given ElementBase.

    The element's own ``raw_type`` (when truthy) becomes an extension with
    index 0 and the element restrictions. The extensions yielded by
    ``self.children_extensions`` follow. Entries are keyed by type name (the
    raw type string for the element's own type), so a later extension with
    the same name replaces an earlier one. The result is stored in
    ``target.extensions`` sorted by the extension type index.
    """
    # ``obj`` is annotated ``ElementBase`` to match ``build_class`` and this
    # docstring; the previous ``BaseElement`` name was inconsistent with both.
    extensions = {}
    raw_type = obj.raw_type
    if raw_type:
        restrictions = obj.get_restrictions()
        extensions[raw_type] = self.build_class_extension(
            target, raw_type, 0, restrictions
        )

    for extension in self.children_extensions(obj, target):
        # NOTE(review): this sets ``forward_ref`` on the Extension object,
        # not on ``extension.type``; confirm that is the intended target.
        extension.forward_ref = False
        extensions[extension.type.name] = extension

    target.extensions = sorted(extensions.values(), key=lambda ext: ext.type.index)
def process_enumerations(cls, obj: Class):
    """
    Normalize the members of an enumeration class in place.

    Members are de-duplicated by the string form of their default value
    (the last occurrence wins) and sorted by it. Every member receives the
    types of all class extensions. Its default is converted through
    ``cls.attribute_default``, and its name is derived from that default
    with surrounding quotes stripped. If two members end up with the same
    name, every member is renamed from the urlsafe base64 encoding of its
    default instead.
    """
    extension_types = {ext.type.name: ext.type for ext in obj.extensions}
    by_default = {str(member.default): member for member in obj.attrs}
    obj.attrs = sorted(by_default.values(), key=lambda member: str(member.default))

    for member in obj.attrs:
        member.types.extend(extension_types.values())
        member.default = cls.attribute_default(member, obj.ns_map)
        member.name = cls.enumeration_name(str(member.default).strip("\"'"))

    if len({member.name for member in obj.attrs}) == len(obj.attrs):
        return

    # Name clash: fall back to an encoding that is unique per default.
    for member in obj.attrs:
        encoded = urlsafe_b64encode(str(member.default).encode()).decode()
        member.name = cls.enumeration_name(encoded)
def process_attributes(cls, obj: Class, parents_list: List[str]):
    """
    Normalize the attributes of a non-enumeration class in place.

    Attributes with an already-seen name are dropped (the first one is
    kept), then every remaining attribute goes through
    ``cls.process_attribute``. If that produces name collisions, every
    attribute is renamed from the urlsafe base64 encoding of its
    ``local_name``.
    """
    first_by_name = {}
    for attr in obj.attrs:
        first_by_name.setdefault(attr.name, attr)
    obj.attrs = list(first_by_name.values())

    processed_names = set()
    for attr in obj.attrs:
        cls.process_attribute(obj, attr, parents_list)
        processed_names.add(attr.name)

    if len(processed_names) == len(obj.attrs):
        return

    for attr in obj.attrs:
        encoded = urlsafe_b64encode(str(attr.local_name).encode()).decode()
        attr.name = cls.attribute_name(encoded)
def process_class(cls, obj: Class, parents: List[str] = None) -> Class:
    """
    Normalize a class in place and return it.

    The class name is normalized first. Inner classes are then processed
    recursively with this class appended to their parents chain. Next
    either the enumeration members or the regular attributes are
    normalized, and finally every extension.
    """
    normalized = cls.class_name(obj.name)
    obj.name = normalized
    lineage = (parents or []) + [normalized]

    for child in obj.inner:
        cls.process_class(child, lineage)

    if not obj.is_enumeration:
        cls.process_attributes(obj, lineage)
    else:
        cls.process_enumerations(obj)

    for ext in obj.extensions:
        cls.process_extension(ext)

    return obj
def add_substitution_attrs(self, target: Class, attr: Attr):
    """
    Find all the substitution attributes for the given attribute and add
    them to the target class.

    Enumerations and wildcard attributes are skipped. Every substitution
    listed in ``self.substitutions_index`` under the attribute's qualified
    name is cloned, gets ``attr``'s restrictions merged into it, and is
    inserted into ``target.attrs``. Each clone is then processed the same
    way, so substitutions of substitutions are added as well.
    """
    if attr.is_enumeration or attr.is_wildcard:
        return

    # Default insertion point: the current position of ``attr``.
    index = target.attrs.index(attr)
    qname = target.source_qname(attr.name)
    # NOTE(review): presumably ``substitutions_index`` yields an empty
    # iterable for names without substitutions (e.g. a defaultdict) - confirm.
    for substitution in self.substitutions_index[qname]:
        # When ``find_attribute`` locates a matching attribute in the target,
        # the clone goes right after it; otherwise the previous index is reused.
        # NOTE(review): ``index`` is not advanced after an insert, so clones
        # without a match are inserted before ``attr`` and before clones
        # inserted earlier in this loop - confirm this ordering is intended.
        pos = self.find_attribute(target.attrs, substitution)
        index = pos + 1 if pos > -1 else index

        clone = substitution.clone()
        clone.restrictions.merge(attr.restrictions)
        target.attrs.insert(index, clone)

        # NOTE(review): there is no guard against cyclic substitution
        # chains; a cycle could keep recursing and inserting clones.
        self.add_substitution_attrs(target, clone)
def build_class(self, obj: ElementBase) -> Class:
    """
    Create the Class model for a schema element.

    Class metadata is taken from the element (name, flags, help text,
    namespace map, substitutions), the current schema (target namespace,
    module) and the builder package. Extensions are built before
    attributes.
    """
    element_name = obj.real_name
    namespace = self.element_namespace(obj)
    schema = self.schema
    fields = dict(
        name=element_name,
        local_name=element_name,
        abstract=obj.is_abstract,
        namespace=namespace,
        mixed=obj.is_mixed,
        nillable=obj.is_nillable,
        type=type(obj),
        help=obj.display_help,
        ns_map=obj.ns_map,
        source_namespace=schema.target_namespace,
        module=schema.module,
        package=self.package,
        substitutions=obj.substitutions,
    )
    built = Class(**fields)

    self.build_class_extensions(obj, built)
    self.build_class_attributes(obj, built)
    return built
def flatten_extension(self, target: Class, extension: Extension):
    """
    Try to remove ``extension`` from ``target`` by flattening it.

    A native extension type on a non-enumeration target becomes a default
    value attribute via ``create_default_attribute``. Otherwise the type is
    resolved with the default ``find_class`` lookup and handled as a simple
    source. Failing that, it is resolved with ``condition=None`` and handled
    as a complex source. The handler's return value is passed through. If
    the type cannot be resolved, nothing happens and ``None`` is returned.
    """
    native_value = extension.type.native and not target.is_enumeration
    if native_value:
        return self.create_default_attribute(target, extension)

    qname = target.source_qname(extension.type.name)

    found = self.find_class(qname)
    if found:
        return self.flatten_extension_simple(found, target, extension)

    found = self.find_class(qname, condition=None)
    if found:
        return self.flatten_extension_complex(found, target, extension)

    return None
def flatten_attribute_types(self, target: Class, attr: Attr):
    """
    Flatten attribute types by using the source attribute type.

    For every type of ``attr``:

    * no source (native type, or ``find_class`` found nothing): keep the
      type and set its ``self_ref`` flag from ``self.attr_depends_on``;
    * qname source (``self.is_qname``): use a clone of the source's first
      extension type;
    * enumeration source: keep the type;
    * source with exactly one attribute: use clones of that attribute's
      types, merge ``attr.restrictions`` into a clone of the source
      restrictions, and copy the source inner classes into ``target``;
    * anything else: fall back to a native string type and log a warning.
    """
    types = []
    for attr_type in attr.types:
        source = None
        if not attr_type.native:
            type_qname = target.source_qname(attr_type.name)
            source = self.find_class(type_qname)

        if source is None:
            attr_type.self_ref = self.attr_depends_on(attr_type, target)
            types.append(attr_type)
        elif self.is_qname(source):
            types.append(source.extensions[0].type.clone())
        elif source.is_enumeration:
            types.append(attr_type)
        elif len(source.attrs) == 1:
            source_attr = source.attrs[0]
            # Clone the borrowed types, like the restrictions below and the
            # qname branch above, so later in-place changes to this
            # attribute's types (e.g. ``self_ref`` or ``alias``) do not leak
            # back into the source class.
            types.extend(source_type.clone() for source_type in source_attr.types)
            restrictions = source_attr.restrictions.clone()
            restrictions.merge(attr.restrictions)
            attr.restrictions = restrictions
            self.copy_inner_classes(source, target)
        else:
            types.append(AttrType(name=DataType.STRING.code, native=True))
            logger.warning("Missing type implementation: %s", source.type.__name__)

    attr.types = types
def find_attr_type(self, source: Class, attr_type: AttrType) -> Optional[Class]:
    """
    Return the class that ``attr_type`` resolves to from ``source``, or
    whatever ``find_class`` returns when nothing matches.
    """
    type_qname = source.source_qname(attr_type.name)
    resolved = self.find_class(type_qname)
    return resolved