def test_add_default_attribute(self):
    """Merging native extensions produces a single ``@value`` attribute.

    The first call creates the default attribute from the extension type.
    A second call with another native type appends that type to the same
    attribute and combines both extensions' restrictions.
    """
    string_type = AttrTypeFactory.native(DataType.STRING)
    first_ext = ExtensionFactory.create(string_type, Restrictions(required=True))
    item = ClassFactory.elements(1, extensions=[first_ext])

    ClassExtensionHandler.add_default_attribute(item, first_ext)
    expected_attr = AttrFactory.create(
        name="@value",
        default=None,
        types=[string_type],
        tag=Tag.EXTENSION,
    )

    self.assertEqual(2, len(item.attrs))
    self.assertEqual(0, len(item.extensions))
    self.assertEqual(expected_attr, item.attrs[0])

    # A second native extension merges into the existing default attr.
    int_type = AttrTypeFactory.native(DataType.INT)
    second_ext = ExtensionFactory.create(int_type, Restrictions(tokens=True))
    item.extensions.append(second_ext)
    ClassExtensionHandler.add_default_attribute(item, second_ext)

    expected_attr.types.append(int_type)
    merged_restrictions = Restrictions(
        tokens=True,
        required=True,
        min_occurs=1,
        max_occurs=1,
    )

    self.assertEqual(2, len(item.attrs))
    self.assertEqual(0, len(item.extensions))
    self.assertEqual(expected_attr, item.attrs[0])
    self.assertEqual(merged_restrictions, item.attrs[0].restrictions)
def test_process_enum_extension_with_complex_source(self):
    """An enum extending a complex type takes over the source's members.

    The target receives clones of the source attrs and extensions. The
    enum's default value ("Yes") ends up as the fixed default of the
    source's restriction attr.
    """
    source = ClassFactory.create(
        tag=Tag.COMPLEX_TYPE,
        attrs=[
            AttrFactory.create(tag=Tag.ATTRIBUTE),
            AttrFactory.create(tag=Tag.RESTRICTION),
        ],
        extensions=ExtensionFactory.list(2),
        status=Status.PROCESSED,
    )
    target = ClassFactory.enumeration(1)
    target.attrs[0].default = "Yes"
    ref_ext = ExtensionFactory.reference(source.qname)
    target.extensions.append(ref_ext)
    expected = target.clone()

    self.processor.container.add(source)
    self.processor.container.add(target)
    self.processor.process_dependency_extension(target, ref_ext)

    expected.attrs = [src_attr.clone() for src_attr in source.attrs]
    expected.extensions = [src_ext.clone() for src_ext in source.extensions]
    expected_restriction = expected.attrs[1]
    expected_restriction.default = "Yes"
    expected_restriction.fixed = True

    self.assertEqual(expected, target)

    first_attr = target.attrs[0]
    self.assertIsNone(first_attr.default)
    self.assertFalse(first_attr.fixed)

    second_attr = target.attrs[1]
    self.assertEqual("Yes", second_attr.default)
    self.assertTrue(second_attr.fixed)
def test_rename_class_dependencies(self):
    """Renaming a qname rewrites it in the class and its inner classes.

    ``{foo}bar`` appears in an outer extension, an outer attr, an inner
    extension and an inner attr. After the rename, only ``thug`` is
    reported among the dependencies.
    """
    old_type = AttrTypeFactory.create("{foo}bar")

    # Built in the same order the original nested expression evaluated.
    outer_extensions = [
        ExtensionFactory.create(),
        ExtensionFactory.create(old_type.clone()),
    ]
    outer_attrs = [
        AttrFactory.create(),
        AttrFactory.create(types=[AttrTypeFactory.create(), old_type.clone()]),
    ]
    nested = ClassFactory.create(
        extensions=[ExtensionFactory.create(old_type.clone())],
        attrs=[
            AttrFactory.create(),
            AttrFactory.create(
                types=[AttrTypeFactory.create(), old_type.clone()]
            ),
        ],
    )
    target = ClassFactory.create(
        extensions=outer_extensions,
        attrs=outer_attrs,
        inner=[nested],
    )

    self.sanitizer.rename_class_dependencies(target, "{foo}bar", "thug")

    found = set(target.dependencies())
    self.assertNotIn("{foo}bar", found)
    self.assertIn("thug", found)
def test_dependencies(self):
    """dependencies() lists each referenced non-native qname once.

    Native types (decimal, string) and the forward-referenced
    ``annotated`` type are not expected in the result. The ``foobar``
    extension is added twice and ``{xsdata}foo`` is used by two inner
    attrs, yet each is expected exactly once.
    """
    xs_uri = Namespace.XS.uri

    # Sub-objects are built in the same order the original nested
    # expression evaluated them.
    decimal_attr = AttrFactory.create(
        types=[AttrTypeFactory.native(DataType.DECIMAL)]
    )
    forward_type = AttrTypeFactory.create(
        qname=build_qname(xs_uri, "annotated"),
        forward=True,
    )
    choice_attrs = [
        AttrFactory.create(
            name="x",
            types=[
                AttrTypeFactory.create(qname="choiceAttr"),
                AttrTypeFactory.native(DataType.STRING),
            ],
        ),
        AttrFactory.create(
            name="x",
            types=[
                AttrTypeFactory.create(qname="choiceAttrTwo"),
                AttrTypeFactory.create(qname="choiceAttrEnum"),
            ],
        ),
    ]
    compound_attr = AttrFactory.create(types=[forward_type], choices=choice_attrs)
    xs_refs_attr = AttrFactory.create(
        types=[
            AttrTypeFactory.create(qname=build_qname(xs_uri, "openAttrs")),
            AttrTypeFactory.create(qname=build_qname(xs_uri, "localAttribute")),
        ]
    )
    foobar = build_qname(xs_uri, "foobar")

    obj = ClassFactory.create(
        attrs=[decimal_attr, compound_attr, xs_refs_attr],
        extensions=[
            ExtensionFactory.reference(foobar),
            ExtensionFactory.reference(foobar),
        ],
        inner=[
            ClassFactory.create(
                attrs=AttrFactory.list(
                    2, types=AttrTypeFactory.list(1, qname="{xsdata}foo")
                )
            )
        ],
    )

    expected = [
        "choiceAttr",
        "choiceAttrTwo",
        "choiceAttrEnum",
        "{http://www.w3.org/2001/XMLSchema}openAttrs",
        "{http://www.w3.org/2001/XMLSchema}localAttribute",
        "{http://www.w3.org/2001/XMLSchema}foobar",
        "{xsdata}foo",
    ]
    self.assertCountEqual(expected, list(obj.dependencies()))
def test_copy_extensions(self):
    """Copied extensions take the link extension's ``max_occurs``.

    The target starts with one extension and the source has two. After
    copying, the two entries after the first must carry max_occurs == 2.
    """
    target = ClassFactory.create(extensions=ExtensionFactory.list(1))
    source = ClassFactory.create(extensions=ExtensionFactory.list(2))
    link = ExtensionFactory.create()
    link.restrictions.max_occurs = 2

    ClassUtils.copy_extensions(source, target, link)

    self.assertEqual(3, len(target.extensions))
    for copied in target.extensions[1:]:
        self.assertEqual(2, copied.restrictions.max_occurs)
def test_copy_attributes(self, mock_clone_attribute, mock_copy_inner_classes):
    """Only source attrs missing from the target are cloned and copied.

    ``a`` and ``b`` already exist on the target, so only ``c`` and ``d``
    are handled. ``c`` was created with ``index=sys.maxsize`` and is
    expected to come last.
    """
    mock_clone_attribute.side_effect = lambda x, y: x.clone()
    target = ClassFactory.create(
        attrs=[AttrFactory.create(name="a"), AttrFactory.create(name="b")]
    )
    source = ClassFactory.create(
        attrs=[
            AttrFactory.create(name="c", index=sys.maxsize),
            AttrFactory.create(name="a"),
            AttrFactory.create(name="b"),
            AttrFactory.create(name="d"),
        ]
    )
    ext = ExtensionFactory.create(AttrTypeFactory.create(qname="foo"))
    target.extensions.append(ext)

    ClassUtils.copy_attributes(source, target, ext)

    resulting_names = [attr.name for attr in target.attrs]
    self.assertEqual(["a", "b", "d", "c"], resulting_names)

    # The source attrs that were absent from the target ("c", "d").
    copied = (source.attrs[0], source.attrs[3])
    mock_copy_inner_classes.assert_has_calls(
        [mock.call(source, target, attr) for attr in copied]
    )
    mock_clone_attribute.assert_has_calls(
        [mock.call(attr, ext.restrictions) for attr in copied]
    )
def test_remove_inherited_fields_with_lists_type(self):
    """Inherited list-type fields are removed only when they truly match.

    The target extends ``BaseClass``, a clone of itself taken before its
    occurs restrictions were changed. The fields are kept while min/max
    occurs differ, and stay after compound fields are enabled. They are
    removed only once every attr on both sides sits in a choice group.
    """
    target = ClassFactory.elements(2)
    target.attrs[0].restrictions.min_occurs = 1
    target.attrs[0].restrictions.max_occurs = 3
    source = target.clone()
    source.qname = "BaseClass"
    target.attrs[0].restrictions.min_occurs = 0
    target.attrs[1].restrictions.max_occurs = 10
    target.extensions.append(ExtensionFactory.reference(source.qname))
    self.processor.container.add(source)

    self.processor.process(target)
    # min/max occurs didn't match
    self.assertEqual(2, len(target.attrs))

    self.processor.container.config.output.compound_fields = True
    self.processor.process(target)
    # They are not part of a choice group
    self.assertEqual(2, len(target.attrs))

    # With compound fields on and every attr (target and base) placed in
    # a choice group, the inherited fields are removed.
    target.attrs[0].restrictions.choice = "123"
    target.attrs[1].restrictions.choice = "123"
    source.attrs[0].restrictions.choice = "456"
    source.attrs[1].restrictions.choice = "456"
    self.processor.process(target)
    self.assertEqual(0, len(target.attrs))
def test_process_native_extension_with_enumeration_target(
    self, mock_replace_attributes_type
):
    """A native extension on an enumeration goes to replace_attributes_type."""
    native_ext = ExtensionFactory.create()
    enum_target = ClassFactory.enumeration(1)

    self.processor.process_native_extension(enum_target, native_ext)

    mock_replace_attributes_type.assert_called_once_with(enum_target, native_ext)
def test_process_extension_with_dependency_type(
    self, mock_process_dependency_extension
):
    """An extension on a plain qname type ("foo") is routed to the
    dependency-extension handler."""
    foo_type = AttrTypeFactory.create("foo")
    dep_ext = ExtensionFactory.create(foo_type)
    target = ClassFactory.elements(1, extensions=[dep_ext])

    self.processor.process_extension(target, dep_ext)

    mock_process_dependency_extension.assert_called_once_with(target, dep_ext)
def test_process_simple_extension_when_source_and_target_are_not_enumerations(
    self, mock_copy_attributes
):
    """With neither side an enumeration, the source attrs are copied."""
    base = ClassFactory.elements(2)
    derived = ClassFactory.elements(1)
    simple_ext = ExtensionFactory.create()

    self.processor.process_simple_extension(base, derived, simple_ext)

    mock_copy_attributes.assert_called_once_with(base, derived, simple_ext)
def test_process_simple_extension_when_source_is_enumeration_and_target_is_not(
    self, mock_add_default_attribute
):
    """Enum source with a non-enum target adds a default attr to the target."""
    enum_source = ClassFactory.enumeration(2)
    plain_target = ClassFactory.elements(1)
    simple_ext = ExtensionFactory.create()

    self.processor.process_simple_extension(enum_source, plain_target, simple_ext)

    mock_add_default_attribute.assert_called_once_with(plain_target, simple_ext)
def test_process_extension_with_native_type(self, mock_flatten_extension_native):
    """A native (xs:string) extension is routed to the patched native handler."""
    string_ext = ExtensionFactory.native(DataType.STRING)
    target = ClassFactory.elements(1, extensions=[string_ext])

    self.processor.process_extension(target, string_ext)

    mock_flatten_extension_native.assert_called_once_with(target, string_ext)
def test_process_complex_extension_ignores_extension(
    self, mock_should_remove_extension, mock_should_flatten_extension
):
    """If the extension is neither removed nor flattened, it stays on the target."""
    for predicate in (mock_should_remove_extension, mock_should_flatten_extension):
        predicate.return_value = False

    complex_ext = ExtensionFactory.create()
    target = ClassFactory.create(extensions=[complex_ext])
    source = ClassFactory.create()

    self.processor.process_complex_extension(source, target, complex_ext)

    self.assertEqual(1, len(target.extensions))
def test_process_simple_extension_when_target_is_enumeration_and_source_is_not(
    self, mock_add_default_attribute, mock_copy_attributes
):
    """An enum target extending a non-enum source just drops the extension.

    Neither the default-attribute nor the copy-attributes path runs.
    """
    simple_ext = ExtensionFactory.create()
    plain_source = ClassFactory.elements(2)
    enum_target = ClassFactory.enumeration(1, extensions=[simple_ext])

    self.processor.process_simple_extension(plain_source, enum_target, simple_ext)

    mock_add_default_attribute.assert_not_called()
    mock_copy_attributes.assert_not_called()
    self.assertEqual(0, len(enum_target.extensions))
def test_process_enum_extension_raises_exception(self):
    """An enumeration extending an element-only class raises CodeGenerationError."""
    element_source = ClassFactory.elements(2)
    enum_target = ClassFactory.enumeration(2)
    ref_ext = ExtensionFactory.reference(element_source.qname)
    enum_target.extensions.append(ref_ext)

    container = self.processor.container
    container.add(element_source)
    container.add(enum_target)

    with self.assertRaises(CodeGenerationError):
        self.processor.process_dependency_extension(enum_target, ref_ext)
def test_process_complex_extension_copies_attributes(
    self, mock_compare_attributes, mock_should_flatten_extension
):
    """When flattening applies, compare_attributes runs on (source, target).

    should_flatten_extension is consulted exactly once, with the extension.
    """
    mock_should_flatten_extension.return_value = True
    complex_ext = ExtensionFactory.create()
    derived = ClassFactory.create()
    base = ClassFactory.create()

    self.processor.process_complex_extension(base, derived, complex_ext)

    mock_compare_attributes.assert_called_once_with(base, derived)
    mock_should_flatten_extension.assert_called_once_with(base, derived, complex_ext)
def test_should_remove_extension(self):
    """Cover each situation in which an extension must be removed.

    1. The source is the target itself.
    2. The source is the parent class of the target (target is nested in
       ``source.inner``).
    3. MRO violation: the source shares an extension type with the target.
    """
    source = ClassFactory.create()
    target = ClassFactory.create()

    # Source is target
    self.assertTrue(self.processor.should_remove_extension(source, source))
    self.assertFalse(self.processor.should_remove_extension(source, target))

    # Source is parent class. The same (source, target) pair that was
    # just rejected must now be flagged. Calling it with (target, target)
    # would only repeat the identity case above and never exercise the
    # inner-class relationship.
    source.inner.append(target)
    self.assertTrue(self.processor.should_remove_extension(source, target))

    # MRO Violation
    source.inner.clear()
    target.extensions.append(ExtensionFactory.reference("foo"))
    target.extensions.append(ExtensionFactory.reference("bar"))
    self.assertFalse(self.processor.should_remove_extension(source, target))

    source.extensions.append(ExtensionFactory.reference("bar"))
    self.assertTrue(self.processor.should_remove_extension(source, target))
def test_build_class_extensions(self, mock_children_extensions):
    """Child extensions are de-duplicated and the element's own type is added.

    Two extensions share the ``bar`` type, and only one is expected. The
    element's ``type="something"`` yields a ``{xsdata}something``
    reference with min/max occurs of 1.
    """
    bar_type = AttrTypeFactory.create(qname="bar")
    foo_type = AttrTypeFactory.create(qname="foo")
    bar_ext = ExtensionFactory.create(bar_type)
    bar_duplicate = ExtensionFactory.create(bar_type)
    foo_ext = ExtensionFactory.create(foo_type)
    mock_children_extensions.return_value = [bar_ext, bar_duplicate, foo_ext]

    own_type_ext = ExtensionFactory.reference(
        qname="{xsdata}something",
        restrictions=Restrictions(min_occurs=1, max_occurs=1),
    )
    item = ClassFactory.create()
    element = Element(type="something")

    SchemaMapper.build_class_extensions(element, item)

    self.assertEqual(3, len(item.extensions))
    self.assertCountEqual([bar_ext, own_type_ext, foo_ext], item.extensions)
def test_replace_attributes_type(self):
    """Each attr ends up with exactly the extension type; the extension is removed."""
    replacement = ExtensionFactory.create()
    target = ClassFactory.elements(2)
    target.extensions.append(replacement)

    ClassExtensionHandler.replace_attributes_type(target, replacement)

    first, second = target.attrs[0], target.attrs[1]
    self.assertEqual(1, len(first.types))
    self.assertEqual(1, len(second.types))
    self.assertEqual(replacement.type, first.types[0])
    self.assertEqual(replacement.type, second.types[0])
    self.assertEqual(0, len(target.extensions))
def test_process_complex_extension_removes_extension(
    self, mock_should_remove_extension, mock_copy_attributes
):
    """A removable extension is dropped and no attributes are copied."""
    mock_should_remove_extension.return_value = True
    doomed_ext = ExtensionFactory.create()
    derived = ClassFactory.elements(1, extensions=[doomed_ext])
    base = ClassFactory.elements(5)

    self.processor.process_complex_extension(base, derived, doomed_ext)

    self.assertEqual(0, len(derived.extensions))
    self.assertEqual(1, len(derived.attrs))
    mock_should_remove_extension.assert_called_once_with(base, derived)
    mock_copy_attributes.assert_not_called()
def test_process_enum_extension_with_enum_source(self):
    """An enum extending an enum keeps its member count.

    The target member whose name matches a source member becomes equal
    to that source member, and the extension is consumed.
    """
    base_enum = ClassFactory.enumeration(3)
    derived_enum = ClassFactory.enumeration(2)
    derived_enum.attrs[1].name = base_enum.attrs[2].name
    ref_ext = ExtensionFactory.reference(base_enum.qname)
    derived_enum.extensions.append(ref_ext)

    container = self.processor.container
    container.add(base_enum)
    container.add(derived_enum)
    self.processor.process_dependency_extension(derived_enum, ref_ext)

    self.assertEqual(2, len(derived_enum.attrs))
    self.assertEqual(0, len(derived_enum.extensions))
    self.assertEqual(base_enum.attrs[2], derived_enum.attrs[1])
def test_is_simple_type(self):
    """A class is simple only with one attr tagged from SIMPLE_TYPES and no extensions."""
    obj = ClassFactory.elements(2)
    self.assertFalse(obj.is_simple_type)

    # Down to a single attr, but its tag is not yet a simple one.
    obj.attrs.pop()
    self.assertFalse(obj.is_simple_type)

    sole_attr = obj.attrs[0]
    for simple_tag in SIMPLE_TYPES:
        sole_attr.tag = simple_tag
        self.assertTrue(obj.is_simple_type)

    # Any extension disqualifies the class.
    obj.extensions.append(ExtensionFactory.create())
    self.assertFalse(obj.is_simple_type)
def test_process_dependency_extension_with_absent_type(
    self,
    mock_find_dependency,
    mock_process_simple_extension,
    mock_process_complex_extension,
):
    """An extension whose dependency cannot be found is dropped unprocessed."""
    orphan_ext = ExtensionFactory.create()
    target = ClassFactory.create(extensions=[orphan_ext])
    mock_find_dependency.return_value = None

    self.processor.process_extension(target, orphan_ext)

    self.assertEqual(0, len(target.extensions))
    mock_process_simple_extension.assert_not_called()
    mock_process_complex_extension.assert_not_called()
def test_process_dependency_extension_with_simple_type(
    self,
    mock_find_dependency,
    mock_process_simple_extension,
    mock_process_complex_extension,
):
    """A simpleType dependency is handled by the simple-extension path only."""
    dep_ext = ExtensionFactory.create()
    target = ClassFactory.create(extensions=[dep_ext])
    simple_source = ClassFactory.create(tag=Tag.SIMPLE_TYPE)
    mock_find_dependency.return_value = simple_source

    self.processor.process_extension(target, dep_ext)

    mock_process_complex_extension.assert_not_called()
    mock_process_simple_extension.assert_called_once_with(
        simple_source, target, dep_ext
    )
def test_class_references(self):
    """class_references collects ids for the class and all its parts.

    The target's own id is expected first.
    """
    target = ClassFactory.elements(
        2,
        inner=ClassFactory.list(2, attrs=AttrFactory.list(1)),
        extensions=ExtensionFactory.list(1),
    )

    refs = ClassAnalyzer.class_references(target)

    expected_total = sum(
        [
            1,  # target
            2,  # attrs
            2,  # attr types
            1,  # extension
            1,  # extension type
            2,  # inner classes
            2,  # inner classes attrs
            2,  # inner classes attr types
        ]
    )
    self.assertEqual(expected_total, len(refs))
    self.assertEqual(id(target), refs[0])
def test_add_default_attribute_with_any_type(self):
    """An xs:anyType extension becomes a wildcard ``@any_element`` attr.

    The new attr has tag ANY, namespace ``##any``, and the extension's
    restrictions; the extension itself is consumed.
    """
    any_type = AttrTypeFactory.native(DataType.ANY_TYPE)
    any_ext = ExtensionFactory.create(
        any_type,
        Restrictions(min_occurs=1, max_occurs=1, required=True),
    )
    item = ClassFactory.create(extensions=[any_ext])

    ClassExtensionHandler.add_default_attribute(item, any_ext)

    wildcard = AttrFactory.create(
        name="@any_element",
        default=None,
        types=[any_ext.type.clone()],
        tag=Tag.ANY,
        namespace="##any",
        restrictions=Restrictions(min_occurs=1, max_occurs=1, required=True),
    )
    self.assertEqual(1, len(item.attrs))
    self.assertEqual(0, len(item.extensions))
    self.assertEqual(wildcard, item.attrs[0])
def test_process_enum_extension_with_simple_source(self):
    """An enum extending a simpleType inherits its type and restrictions.

    Every enum member must gain the source's QName type and its
    ``length=10`` restriction.
    """
    qname_type = AttrTypeFactory.native(DataType.QNAME)
    simple_attr = AttrFactory.create(
        types=[qname_type], restrictions=Restrictions(length=10)
    )
    simple_source = ClassFactory.create(tag=Tag.SIMPLE_TYPE, attrs=[simple_attr])
    enum_target = ClassFactory.enumeration(2)
    ref_ext = ExtensionFactory.reference(simple_source.qname)
    enum_target.extensions.append(ref_ext)

    self.processor.container.add(simple_source)
    self.processor.container.add(enum_target)
    self.processor.process_dependency_extension(enum_target, ref_ext)

    for member in enum_target.attrs:
        self.assertIn(qname_type, member.types)
        self.assertEqual(10, member.restrictions.length)
def test_remove_inherited_fields(self):
    """Only inherited attrs identical to the base class are removed.

    The base's second attr is made to differ in several properties, which
    are then relaxed one at a time. In every step the target is expected
    to keep exactly one attr.
    """
    target = ClassFactory.elements(2)
    source = target.clone()
    source.qname = "BaseClass"
    target.extensions.append(ExtensionFactory.reference(source.qname))
    pristine = target.clone()

    self.processor.container.add(source)

    source.attrs[1].restrictions.choice = "1234"
    source.attrs[1].restrictions.max_length = 1
    source.attrs[1].default = "1"
    source.attrs[1].mixed = True
    source.attrs[1].fixed = True

    self.processor.process(target)
    self.assertEqual(1, len(target.attrs))

    # Relax the base attr one property at a time, starting from a fresh
    # copy of the target each time.
    for field_name, relaxed_value in (
        ("fixed", False),
        ("mixed", False),
        ("default", None),
    ):
        target = pristine.clone()
        setattr(source.attrs[1], field_name, relaxed_value)
        self.processor.process(target)
        self.assertEqual(1, len(target.attrs))

    target = pristine.clone()
    target.attrs.append(AttrFactory.create())
    source.attrs[1].restrictions.max_length = None
    self.processor.process(target)
    self.assertEqual(1, len(target.attrs))
def test_apply_aliases(self):
    """Aliased qnames get their alias on every type that references them.

    ``{xsdata}a`` and ``{xsdata}d`` have aliases. They must be applied to
    attrs, extensions, inner-class attrs and compound choices, while the
    un-aliased ``b`` and ``c`` types keep ``alias`` unset.
    """
    self.resolver.aliases = {
        build_qname("xsdata", "d"): "IamD",
        build_qname("xsdata", "a"): "IamA",
    }
    type_a = AttrTypeFactory.create(qname="{xsdata}a")
    type_b = AttrTypeFactory.create(qname="{xsdata}b")
    type_c = AttrTypeFactory.create(qname="{xsdata}c")
    type_d = AttrTypeFactory.create(qname="{xsdata}d")

    outer_attrs = [
        AttrFactory.create(name="a", types=[type_a]),
        AttrFactory.create(name="b", types=[type_b]),
        AttrFactory.create(name="c", types=[type_a, type_d]),
    ]
    nested = ClassFactory.create(
        qname="b",
        attrs=[
            AttrFactory.create(name="c", types=[type_c]),
            AttrFactory.create(name="d", types=[type_d]),
            AttrFactory.create(
                name="compound",
                types=[AttrTypeFactory.native(DataType.ANY_TYPE)],
                choices=[
                    AttrFactory.create(name="a", types=[type_a, type_d]),
                ],
            ),
        ],
    )
    obj = ClassFactory.create(
        qname="a",
        attrs=outer_attrs,
        inner=[nested],
        extensions=[ExtensionFactory.create(type_a)],
    )

    self.resolver.apply_aliases(obj)

    attrs = obj.attrs
    self.assertEqual(3, len(attrs))
    self.assertEqual([1, 1, 2], [len(attr.types) for attr in attrs])
    self.assertEqual("IamA", attrs[0].types[0].alias)
    self.assertIsNone(attrs[1].types[0].alias)
    self.assertEqual("IamA", attrs[2].types[0].alias)
    self.assertEqual("IamD", attrs[2].types[1].alias)
    self.assertEqual("IamA", obj.extensions[0].type.alias)

    self.assertEqual(1, len(obj.inner))
    inner_attrs = obj.inner[0].attrs
    self.assertEqual(3, len(inner_attrs))
    self.assertEqual(1, len(inner_attrs[0].types))
    self.assertEqual(1, len(inner_attrs[1].types))

    choice = inner_attrs[2].choices[0]
    self.assertEqual("IamA", choice.types[0].alias)
    self.assertEqual("IamD", choice.types[1].alias)
    self.assertIsNone(inner_attrs[0].types[0].alias)
    self.assertEqual("IamD", inner_attrs[1].types[0].alias)
def test_process_simple_extension_with_circular_refence(self):
    """A class passed as both source and target loses the extension."""
    self_ext = ExtensionFactory.create()
    looping = ClassFactory.create(extensions=[self_ext])

    self.processor.process_simple_extension(looping, looping, self_ext)

    self.assertEqual(0, len(looping.extensions))