class ClassExtensionHandlerTests(FactoryTestCase): def setUp(self): super().setUp() container = ClassContainer() self.processor = ClassExtensionHandler(container=container) @mock.patch.object(ClassExtensionHandler, "process_native_extension") def test_process_extension_with_native_type(self, mock_flatten_extension_native): extension = ExtensionFactory.create(type=AttrTypeFactory.xs_string()) target = ClassFactory.elements(1, extensions=[extension]) self.processor.process_extension(target, extension) mock_flatten_extension_native.assert_called_once_with( target, extension) @mock.patch.object(ClassExtensionHandler, "process_dependency_extension") def test_process_extension_with_dependency_type( self, mock_process_dependency_extension): extension = ExtensionFactory.create(type=AttrTypeFactory.create("foo")) target = ClassFactory.elements(1, extensions=[extension]) self.processor.process_extension(target, extension) mock_process_dependency_extension.assert_called_once_with( target, extension) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_absent_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) mock_find_dependency.return_value = None self.processor.process_extension(target, extension) self.assertEqual(0, len(target.extensions)) self.assertEqual(0, mock_process_simple_extension.call_count) self.assertEqual(0, mock_process_complex_extension.call_count) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_simple_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create(type=SimpleType) mock_find_dependency.return_value = source self.processor.process_extension(target, extension) self.assertEqual(0, mock_process_complex_extension.call_count) mock_process_simple_extension.assert_called_once_with( source, target, extension) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_enum_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create(type=ComplexType) source.attrs.append(AttrFactory.enumeration()) mock_find_dependency.return_value = source self.processor.process_extension(target, extension) self.assertEqual(0, mock_process_complex_extension.call_count) mock_process_simple_extension.assert_called_once_with( source, target, extension) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_complex_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create(type=ComplexType) mock_find_dependency.return_value = source self.processor.process_extension(target, extension) self.assertEqual(0, mock_process_simple_extension.call_count) mock_process_complex_extension.assert_called_once_with( source, target, extension) @mock.patch.object(ClassExtensionHandler, "create_default_attribute") def test_process_extension_native(self, mock_create_default_attribute): extension = ExtensionFactory.create() target = ClassFactory.elements(1) self.processor.process_native_extension(target, extension) mock_create_default_attribute.assert_called_once_with( target, extension) @mock.patch.object(ClassExtensionHandler, "copy_extension_type") def test_process_native_extension_with_enumeration_target( self, mock_copy_extension_type): extension = ExtensionFactory.create() target = ClassFactory.enumeration(1) self.processor.process_native_extension(target, extension) mock_copy_extension_type.assert_called_once_with(target, extension) def test_process_simple_extension_with_circular_refence(self): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) self.processor.process_simple_extension(target, target, extension) self.assertEqual(0, len(target.extensions)) @mock.patch.object(ClassExtensionHandler, "create_default_attribute") def test_process_simple_extension_when_source_is_enumeration_and_target_is_not( self, mock_create_default_attribute): source = ClassFactory.enumeration(2) target = ClassFactory.elements(1) extension = ExtensionFactory.create() self.processor.process_simple_extension(source, target, extension) mock_create_default_attribute.assert_called_once_with( target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "create_default_attribute") def test_process_simple_extension_when_target_is_enumeration_and_source_is_not( self, mock_create_default_attribute, mock_copy_attributes): extension = ExtensionFactory.create() source = ClassFactory.elements(2) target = ClassFactory.enumeration(1, extensions=[extension]) self.processor.process_simple_extension(source, target, extension) self.assertEqual(0, mock_create_default_attribute.call_count) self.assertEqual(0, mock_copy_attributes.call_count) self.assertEqual(0, len(target.extensions)) @mock.patch.object(ClassUtils, "copy_attributes") def test_process_simple_extension_when_source_and_target_are_enumerations( self, mock_copy_attributes): source = ClassFactory.enumeration(2) target = ClassFactory.enumeration(1) extension = ExtensionFactory.create() self.processor.process_simple_extension(source, target, extension) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") def test_process_simple_extension_when_source_and_target_are_not_enumerations( self, mock_copy_attributes): source = ClassFactory.elements(2) target = ClassFactory.elements(1) extension = ExtensionFactory.create() self.processor.process_simple_extension(source, target, extension) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "compare_attributes") def test_process_complex_extension_when_target_includes_all_source_attrs( self, mock_compare_attributes, mock_copy_attributes): mock_compare_attributes.return_value = ClassExtensionHandler.INCLUDES_ALL extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create() self.processor.process_complex_extension(source, target, extension) self.assertEqual(0, len(target.extensions)) mock_compare_attributes.assert_called_once_with(source, target) self.assertEqual(0, mock_copy_attributes.call_count) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "compare_attributes") def test_process_complex_extension_when_target_includes_some_source_attrs( self, mock_compare_attributes, mock_copy_attributes): mock_compare_attributes.return_value = ClassExtensionHandler.INCLUDES_SOME extension = ExtensionFactory.create() target = ClassFactory.create() source = ClassFactory.create() self.processor.process_complex_extension(source, target, extension) mock_compare_attributes.assert_called_once_with(source, target) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "compare_attributes") def test_process_complex_extension_when_target_includes_no_source_attrs( self, mock_compare_attributes, mock_copy_attributes): mock_compare_attributes.return_value = ClassExtensionHandler.INCLUDES_NONE extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create() self.processor.process_complex_extension(source, target, extension) mock_compare_attributes.assert_called_once_with(source, target) self.assertEqual(0, mock_copy_attributes.call_count) self.assertEqual(1, len(target.extensions)) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "compare_attributes") def test_process_complex_extension_when_source_is_strict_type( self, mock_compare_attributes, mock_copy_attributes): mock_compare_attributes.return_value = ClassExtensionHandler.INCLUDES_NONE extension = ExtensionFactory.create() target = ClassFactory.create() source = ClassFactory.create(strict_type=True) self.processor.process_complex_extension(source, target, extension) mock_compare_attributes.assert_called_once_with(source, target) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "compare_attributes") def test_process_complex_extension_when_source_has_suffix_attr( self, mock_compare_attributes, mock_copy_attributes): mock_compare_attributes.return_value = ClassExtensionHandler.INCLUDES_NONE extension = ExtensionFactory.create() target = ClassFactory.create() source = ClassFactory.create() source.attrs.append(AttrFactory.create(index=sys.maxsize)) self.processor.process_complex_extension(source, target, extension) self.assertEqual(0, len(target.attrs)) target.attrs.append(AttrFactory.create()) self.processor.process_complex_extension(source, target, extension) self.assertEqual(2, mock_compare_attributes.call_count) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "compare_attributes") def test_process_complex_extension_when_target_has_suffix_attr( self, mock_compare_attributes, mock_copy_attributes): mock_compare_attributes.return_value = ClassExtensionHandler.INCLUDES_NONE extension = ExtensionFactory.create() target = ClassFactory.create() source = ClassFactory.create() target.attrs.append(AttrFactory.create(index=sys.maxsize)) self.processor.process_complex_extension(source, target, extension) mock_compare_attributes.assert_called_once_with(source, target) mock_copy_attributes.assert_called_once_with(source, target, extension) def test_find_dependency(self): attr_type = AttrTypeFactory.create(qname="a") self.assertIsNone(self.processor.find_dependency(attr_type)) complex = ClassFactory.create(qname="a", type=ComplexType) self.processor.container.add(complex) self.assertEqual(complex, self.processor.find_dependency(attr_type)) simple = ClassFactory.create(qname="a", type=SimpleType) self.processor.container.add(simple) self.assertEqual(simple, self.processor.find_dependency(attr_type)) def test_compare_attributes(self): source = ClassFactory.elements(2) self.assertEqual(2, self.processor.compare_attributes(source, source)) target = ClassFactory.create() self.assertEqual( 0, ClassExtensionHandler.compare_attributes(source, target)) target.attrs = [attr.clone() for attr in source.attrs] self.assertEqual( 2, ClassExtensionHandler.compare_attributes(source, target)) source.attrs.append(AttrFactory.element()) self.assertEqual( 1, ClassExtensionHandler.compare_attributes(source, target)) source.attrs = AttrFactory.list(3) self.assertEqual( 0, ClassExtensionHandler.compare_attributes(source, target)) self.assertEqual(0, ClassExtensionHandler.INCLUDES_NONE) self.assertEqual(1, ClassExtensionHandler.INCLUDES_SOME) self.assertEqual(2, ClassExtensionHandler.INCLUDES_ALL) def test_copy_extension_type(self): extension = ExtensionFactory.create() target = ClassFactory.elements(2) target.extensions.append(extension) ClassExtensionHandler.copy_extension_type(target, extension) self.assertEqual(extension.type, target.attrs[0].types[1]) self.assertEqual(extension.type, target.attrs[1].types[1]) self.assertEqual(0, len(target.extensions)) def test_create_default_attribute(self): extension = ExtensionFactory.create() item = ClassFactory.create(extensions=[extension]) ClassExtensionHandler.create_default_attribute(item, extension) expected = AttrFactory.create( name="value", index=0, default=None, types=[extension.type], tag=Tag.EXTENSION, ) self.assertEqual(1, len(item.attrs)) self.assertEqual(0, len(item.extensions)) self.assertEqual(expected, item.attrs[0]) def test_create_default_attribute_with_any_type(self): extension = ExtensionFactory.create( type=AttrTypeFactory.xs_any(), restrictions=Restrictions(min_occurs=1, max_occurs=1, required=True), ) item = ClassFactory.create(extensions=[extension]) ClassExtensionHandler.create_default_attribute(item, extension) expected = AttrFactory.create( name="any_element", index=0, default=None, types=[extension.type.clone()], tag=Tag.ANY, namespace="##any", restrictions=Restrictions(min_occurs=1, max_occurs=1, required=True), ) self.assertEqual(1, len(item.attrs)) self.assertEqual(0, len(item.extensions)) self.assertEqual(expected, item.attrs[0])
class ClassExtensionHandlerTests(FactoryTestCase): def setUp(self): super().setUp() container = ClassContainer() self.processor = ClassExtensionHandler(container=container) @mock.patch.object(ClassExtensionHandler, "process_extension") def test_process_with_no_extensions(self, mock_process_extension): target = ClassFactory.elements(1) self.processor.process(target) self.assertEqual(0, mock_process_extension.call_count) @mock.patch.object(ClassExtensionHandler, "process_native_extension") def test_process_extension_with_native_type(self, mock_flatten_extension_native): extension = ExtensionFactory.native(DataType.STRING) target = ClassFactory.elements(1, extensions=[extension]) self.processor.process_extension(target, extension) mock_flatten_extension_native.assert_called_once_with( target, extension) @mock.patch.object(ClassExtensionHandler, "process_dependency_extension") def test_process_extension_with_dependency_type( self, mock_process_dependency_extension): extension = ExtensionFactory.create(AttrTypeFactory.create("foo")) target = ClassFactory.elements(1, extensions=[extension]) self.processor.process_extension(target, extension) mock_process_dependency_extension.assert_called_once_with( target, extension) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_absent_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) mock_find_dependency.return_value = None self.processor.process_extension(target, extension) self.assertEqual(0, len(target.extensions)) self.assertEqual(0, mock_process_simple_extension.call_count) self.assertEqual(0, mock_process_complex_extension.call_count) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_simple_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create(tag=Tag.SIMPLE_TYPE) mock_find_dependency.return_value = source self.processor.process_extension(target, extension) self.assertEqual(0, mock_process_complex_extension.call_count) mock_process_simple_extension.assert_called_once_with( source, target, extension) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_enum_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create(tag=Tag.COMPLEX_TYPE) source.attrs.append(AttrFactory.enumeration()) mock_find_dependency.return_value = source self.processor.process_extension(target, extension) self.assertEqual(0, mock_process_complex_extension.call_count) mock_process_simple_extension.assert_called_once_with( source, target, extension) def test_process_enum_extension_with_enum_source(self): source = ClassFactory.enumeration(3) target = ClassFactory.enumeration(2) target.attrs[1].name = source.attrs[2].name extension = ExtensionFactory.reference(source.qname) target.extensions.append(extension) self.processor.container.add(source) self.processor.container.add(target) self.processor.process_dependency_extension(target, extension) self.assertEqual(2, len(target.attrs)) self.assertEqual(0, len(target.extensions)) self.assertEqual(source.attrs[2], target.attrs[1]) def test_process_enum_extension_with_simple_source(self): qname_type = AttrTypeFactory.native(DataType.QNAME) source = ClassFactory.create( tag=Tag.SIMPLE_TYPE, attrs=[ AttrFactory.create(types=[qname_type], restrictions=Restrictions(length=10)) ], ) target = ClassFactory.enumeration(2) extension = ExtensionFactory.reference(source.qname) target.extensions.append(extension) self.processor.container.add(source) self.processor.container.add(target) self.processor.process_dependency_extension(target, extension) for attr in target.attrs: self.assertIn(qname_type, attr.types) self.assertEqual(10, attr.restrictions.length) def test_process_enum_extension_with_complex_source(self): source = ClassFactory.create( tag=Tag.COMPLEX_TYPE, attrs=[ AttrFactory.create(tag=Tag.ATTRIBUTE), AttrFactory.create(tag=Tag.RESTRICTION), ], extensions=ExtensionFactory.list(2), status=Status.PROCESSED, ) target = ClassFactory.enumeration(1) target.attrs[0].default = "Yes" extension = ExtensionFactory.reference(source.qname) target.extensions.append(extension) expected = target.clone() self.processor.container.add(source) self.processor.container.add(target) self.processor.process_dependency_extension(target, extension) expected.attrs = [attr.clone() for attr in source.attrs] expected.extensions = [ext.clone() for ext in source.extensions] expected.attrs[1].default = "Yes" expected.attrs[1].fixed = True self.assertEqual(expected, target) self.assertIsNone(target.attrs[0].default) self.assertFalse(target.attrs[0].fixed) self.assertEqual("Yes", target.attrs[1].default) self.assertTrue(target.attrs[1].fixed) def test_process_enum_extension_raises_exception(self): source = ClassFactory.elements(2) target = ClassFactory.enumeration(2) extension = ExtensionFactory.reference(source.qname) target.extensions.append(extension) self.processor.container.add(source) self.processor.container.add(target) with self.assertRaises(CodeGenerationError): self.processor.process_dependency_extension(target, extension) @mock.patch.object(ClassExtensionHandler, "process_complex_extension") @mock.patch.object(ClassExtensionHandler, "process_simple_extension") @mock.patch.object(ClassExtensionHandler, "find_dependency") def test_process_dependency_extension_with_complex_type( self, mock_find_dependency, mock_process_simple_extension, mock_process_complex_extension, ): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create(tag=Tag.COMPLEX_TYPE) mock_find_dependency.return_value = source self.processor.process_extension(target, extension) self.assertEqual(0, mock_process_simple_extension.call_count) mock_process_complex_extension.assert_called_once_with( source, target, extension) @mock.patch.object(ClassExtensionHandler, "add_default_attribute") def test_process_extension_native(self, mock_add_default_attribute): extension = ExtensionFactory.create() target = ClassFactory.elements(1) self.processor.process_native_extension(target, extension) mock_add_default_attribute.assert_called_once_with(target, extension) @mock.patch.object(ClassExtensionHandler, "replace_attributes_type") def test_process_native_extension_with_enumeration_target( self, mock_replace_attributes_type): extension = ExtensionFactory.create() target = ClassFactory.enumeration(1) self.processor.process_native_extension(target, extension) mock_replace_attributes_type.assert_called_once_with(target, extension) def test_process_simple_extension_with_circular_refence(self): extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) self.processor.process_simple_extension(target, target, extension) self.assertEqual(0, len(target.extensions)) @mock.patch.object(ClassExtensionHandler, "add_default_attribute") def test_process_simple_extension_when_source_is_enumeration_and_target_is_not( self, mock_add_default_attribute): source = ClassFactory.enumeration(2) target = ClassFactory.elements(1) extension = ExtensionFactory.create() self.processor.process_simple_extension(source, target, extension) mock_add_default_attribute.assert_called_once_with(target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "add_default_attribute") def test_process_simple_extension_when_target_is_enumeration_and_source_is_not( self, mock_add_default_attribute, mock_copy_attributes): extension = ExtensionFactory.create() source = ClassFactory.elements(2) target = ClassFactory.enumeration(1, extensions=[extension]) self.processor.process_simple_extension(source, target, extension) self.assertEqual(0, mock_add_default_attribute.call_count) self.assertEqual(0, mock_copy_attributes.call_count) self.assertEqual(0, len(target.extensions)) @mock.patch.object(ClassUtils, "copy_attributes") def test_process_simple_extension_when_source_and_target_are_enumerations( self, mock_copy_attributes): source = ClassFactory.enumeration(2) target = ClassFactory.enumeration(1) extension = ExtensionFactory.create() self.processor.process_simple_extension(source, target, extension) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") def test_process_simple_extension_when_source_and_target_are_not_enumerations( self, mock_copy_attributes): source = ClassFactory.elements(2) target = ClassFactory.elements(1) extension = ExtensionFactory.create() self.processor.process_simple_extension(source, target, extension) mock_copy_attributes.assert_called_once_with(source, target, extension) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "should_remove_extension") def test_process_complex_extension_removes_extension( self, mock_should_remove_extension, mock_copy_attributes): mock_should_remove_extension.return_value = True extension = ExtensionFactory.create() target = ClassFactory.elements(1, extensions=[extension]) source = ClassFactory.elements(5) self.processor.process_complex_extension(source, target, extension) self.assertEqual(0, len(target.extensions)) self.assertEqual(1, len(target.attrs)) mock_should_remove_extension.assert_called_once_with(source, target) self.assertEqual(0, mock_copy_attributes.call_count) @mock.patch.object(ClassUtils, "copy_attributes") @mock.patch.object(ClassExtensionHandler, "should_flatten_extension") def test_process_complex_extension_copies_attributes( self, mock_compare_attributes, mock_should_flatten_extension): mock_should_flatten_extension.return_value = True extension = ExtensionFactory.create() target = ClassFactory.create() source = ClassFactory.create() self.processor.process_complex_extension(source, target, extension) mock_compare_attributes.assert_called_once_with(source, target) mock_should_flatten_extension.assert_called_once_with( source, target, extension) @mock.patch.object(ClassExtensionHandler, "should_flatten_extension") @mock.patch.object(ClassExtensionHandler, "should_remove_extension") def test_process_complex_extension_ignores_extension( self, mock_should_remove_extension, mock_should_flatten_extension): mock_should_remove_extension.return_value = False mock_should_flatten_extension.return_value = False extension = ExtensionFactory.create() target = ClassFactory.create(extensions=[extension]) source = ClassFactory.create() self.processor.process_complex_extension(source, target, extension) self.assertEqual(1, len(target.extensions)) def test_find_dependency(self): attr_type = AttrTypeFactory.create(qname="a") self.assertIsNone(self.processor.find_dependency(attr_type)) complex = ClassFactory.create(qname="a", tag=Tag.COMPLEX_TYPE) self.processor.container.add(complex) self.assertEqual(complex, self.processor.find_dependency(attr_type)) simple = ClassFactory.create(qname="a", tag=Tag.SIMPLE_TYPE) self.processor.container.add(simple) self.assertEqual(simple, self.processor.find_dependency(attr_type)) def test_should_remove_extension(self): source = ClassFactory.create() target = ClassFactory.create() # source is target self.assertTrue(self.processor.should_remove_extension(source, source)) self.assertFalse(self.processor.should_remove_extension( source, target)) # Source is parent class source.inner.append(target) self.assertTrue(self.processor.should_remove_extension(target, target)) # MRO Violation source.inner.clear() target.extensions.append(ExtensionFactory.reference("foo")) target.extensions.append(ExtensionFactory.reference("bar")) self.assertFalse(self.processor.should_remove_extension( source, target)) source.extensions.append(ExtensionFactory.reference("bar")) self.assertTrue(self.processor.should_remove_extension(source, target)) def test_should_flatten_extension(self): source = ClassFactory.create() target = ClassFactory.create() self.assertFalse( self.processor.should_flatten_extension(source, target)) # Source has suffix attr and target has its own attrs source = ClassFactory.elements(1) source.attrs[0].index = sys.maxsize target.attrs.append(AttrFactory.create()) self.assertTrue(self.processor.should_flatten_extension( source, target)) # Target has suffix attr source = ClassFactory.create() target = ClassFactory.elements(1) target.attrs[0].index = sys.maxsize self.assertTrue(self.processor.should_flatten_extension( source, target)) # Source is a simple type source = ClassFactory.create( attrs=[AttrFactory.create(tag=Tag.SIMPLE_TYPE)]) target = ClassFactory.elements(1) self.assertTrue(self.processor.should_flatten_extension( source, target)) # Sequential violation source = ClassFactory.elements(3) target = source.clone() self.assertFalse( self.processor.should_flatten_extension(source, target)) for attr in target.attrs: attr.restrictions.sequential = True self.assertFalse( self.processor.should_flatten_extension(source, target)) target.attrs = [target.attrs[1], target.attrs[0], target.attrs[2]] self.assertTrue(self.processor.should_flatten_extension( source, target)) # Types violation target = source.clone() target.attrs[1].types = [ AttrTypeFactory.native(DataType.INT), AttrTypeFactory.native(DataType.FLOAT), ] source.attrs[1].types = [ AttrTypeFactory.native(DataType.INT), AttrTypeFactory.native(DataType.FLOAT), AttrTypeFactory.native(DataType.DECIMAL), ] self.assertFalse( self.processor.should_flatten_extension(source, target)) target.attrs[1].types.append(AttrTypeFactory.native(DataType.QNAME)) self.assertTrue(self.processor.should_flatten_extension( source, target)) def test_replace_attributes_type(self): extension = ExtensionFactory.create() target = ClassFactory.elements(2) target.extensions.append(extension) ClassExtensionHandler.replace_attributes_type(target, extension) self.assertEqual(1, len(target.attrs[0].types)) self.assertEqual(1, len(target.attrs[1].types)) self.assertEqual(extension.type, target.attrs[0].types[0]) self.assertEqual(extension.type, target.attrs[1].types[0]) self.assertEqual(0, len(target.extensions)) def test_add_default_attribute(self): xs_string = AttrTypeFactory.native(DataType.STRING) extension = ExtensionFactory.create(xs_string, Restrictions(required=True)) item = ClassFactory.elements(1, extensions=[extension]) ClassExtensionHandler.add_default_attribute(item, extension) expected = AttrFactory.create(name="@value", default=None, types=[xs_string], tag=Tag.EXTENSION) self.assertEqual(2, len(item.attrs)) self.assertEqual(0, len(item.extensions)) self.assertEqual(expected, item.attrs[0]) xs_int = AttrTypeFactory.native(DataType.INT) extension = ExtensionFactory.create(xs_int, Restrictions(tokens=True)) item.extensions.append(extension) ClassExtensionHandler.add_default_attribute(item, extension) expected.types.append(xs_int) expected_restrictions = Restrictions(tokens=True, required=True, min_occurs=1, max_occurs=1) self.assertEqual(2, len(item.attrs)) self.assertEqual(0, len(item.extensions)) self.assertEqual(expected, item.attrs[0]) self.assertEqual(expected_restrictions, item.attrs[0].restrictions) def test_add_default_attribute_with_any_type(self): extension = ExtensionFactory.create( AttrTypeFactory.native(DataType.ANY_TYPE), Restrictions(min_occurs=1, max_occurs=1, required=True), ) item = ClassFactory.create(extensions=[extension]) ClassExtensionHandler.add_default_attribute(item, extension) expected = AttrFactory.create( name="@any_element", default=None, types=[extension.type.clone()], tag=Tag.ANY, namespace="##any", restrictions=Restrictions(min_occurs=1, max_occurs=1, required=True), ) self.assertEqual(1, len(item.attrs)) self.assertEqual(0, len(item.extensions)) self.assertEqual(expected, item.attrs[0])