Exemplo n.º 1
0
class ClassSanitizerTest(FactoryTestCase):
    """Unit tests for ``ClassSanitizer``, built on factory-generated fixtures."""

    def setUp(self):
        """Give every test its own config, container and sanitizer."""
        super().setUp()

        self.config = GeneratorConfig()
        self.container = ClassContainer()
        self.sanitizer = ClassSanitizer(container=self.container, config=self.config)

    # NOTE: stacked ``mock.patch`` decorators are applied bottom-up, so the mock
    # arguments arrive in the reverse of the decorator order.
    @mock.patch.object(ClassSanitizer, "resolve_conflicts")
    @mock.patch.object(ClassSanitizer, "process_class")
    def test_process(self, mock_process_class, mock_resolve_conflicts):
        """``process`` hands each container class to ``process_class``.

        ``resolve_conflicts`` must also be called exactly once, with no arguments.
        """
        classes = ClassFactory.list(2)

        self.sanitizer.container.extend(classes)
        ClassSanitizer.process(self.container, self.config)

        mock_process_class.assert_has_calls(list(map(mock.call, classes)))
        mock_resolve_conflicts.assert_called_once_with()

    @mock.patch.object(ClassSanitizer, "process_duplicate_attribute_names")
    @mock.patch.object(ClassSanitizer, "process_attribute_sequence")
    @mock.patch.object(ClassSanitizer, "process_attribute_restrictions")
    @mock.patch.object(ClassSanitizer, "process_attribute_default")
    def test_process_class(
        self,
        mock_process_attribute_default,
        mock_process_attribute_restrictions,
        mock_process_attribute_sequence,
        mock_process_duplicate_attribute_names,
    ):
        """``process_class`` runs the per-attribute steps, inner class first.

        The expected call lists below put the inner class (and its attr) ahead
        of the target's own attrs.
        """
        target = ClassFactory.elements(2)
        inner = ClassFactory.elements(1)
        target.inner.append(inner)

        self.sanitizer.process_class(target)

        # Steps that receive both the owning class and the attr.
        calls_with_target = [
            mock.call(target.inner[0], target.inner[0].attrs[0]),
            mock.call(target, target.attrs[0]),
            mock.call(target, target.attrs[1]),
        ]

        # Steps that receive the attr only.
        calls_without_target = [
            mock.call(target.inner[0].attrs[0]),
            mock.call(target.attrs[0]),
            mock.call(target.attrs[1]),
        ]

        mock_process_attribute_default.assert_has_calls(calls_with_target)
        mock_process_attribute_restrictions.assert_has_calls(calls_without_target)
        mock_process_attribute_sequence.assert_has_calls(calls_with_target)
        mock_process_duplicate_attribute_names.assert_has_calls(
            [mock.call(target.inner[0].attrs), mock.call(target.attrs)]
        )

    @mock.patch.object(ClassSanitizer, "group_compound_fields")
    def test_process_class_group_compound_fields(self, mock_group_compound_fields):
        """With ``compound_fields`` on, grouping runs for the inner class, then the target."""
        target = ClassFactory.create()
        inner = ClassFactory.create()
        target.inner.append(inner)

        self.config.output.compound_fields = True
        self.sanitizer.process_class(target)

        mock_group_compound_fields.assert_has_calls(
            [
                mock.call(inner),
                mock.call(target),
            ]
        )

    def test_process_attribute_default_with_enumeration(self):
        """An enumeration attr keeps ``fixed`` even though ``max_occurs`` is 2."""
        target = ClassFactory.create()
        attr = AttrFactory.enumeration()
        attr.restrictions.max_occurs = 2
        attr.fixed = True

        self.sanitizer.process_attribute_default(target, attr)
        self.assertTrue(attr.fixed)

    def test_process_attribute_default_with_list_field(self):
        """A fixed attr with ``max_occurs=2`` ends up with ``fixed`` cleared."""
        target = ClassFactory.create()
        attr = AttrFactory.create(fixed=True)
        attr.restrictions.max_occurs = 2
        self.sanitizer.process_attribute_default(target, attr)
        self.assertFalse(attr.fixed)

    def test_process_attribute_default_with_optional_field(self):
        """With ``min_occurs=0`` both ``fixed`` and ``default`` are reset."""
        target = ClassFactory.create()
        attr = AttrFactory.create(fixed=True, default=2)
        attr.restrictions.min_occurs = 0
        self.sanitizer.process_attribute_default(target, attr)
        self.assertFalse(attr.fixed)
        self.assertIsNone(attr.default)

    def test_process_attribute_default_with_xsi_type(self):
        """An attr named ``type`` in the XSI namespace loses ``fixed`` and ``default``."""
        target = ClassFactory.create()
        attr = AttrFactory.create(
            fixed=True, default=2, name="type", namespace=Namespace.XSI.uri
        )
        self.sanitizer.process_attribute_default(target, attr)
        self.assertFalse(attr.fixed)
        self.assertIsNone(attr.default)

    def test_process_attribute_default_with_valid_case(self):
        """A plain fixed attr with a default is left untouched."""
        target = ClassFactory.create()
        attr = AttrFactory.create(fixed=True, default=2)
        self.sanitizer.process_attribute_default(target, attr)
        self.assertTrue(attr.fixed)
        self.assertEqual(2, attr.default)

    @mock.patch("xsdata.codegen.sanitizer.logger.warning")
    @mock.patch.object(ClassSanitizer, "promote_inner_class")
    @mock.patch.object(ClassSanitizer, "find_enum")
    def test_process_attribute_default_enum(
        self, mock_find_enum, mock_promote_inner_class, mock_logger_warning
    ):
        """Defaults that match an enum member become ``@enum@<qname>::<member>``.

        The third attr's default matches no member: it is expected to be reset
        to ``None`` and reported through ``logger.warning``.
        """
        enum_one = ClassFactory.enumeration(1, qname="root")
        enum_one.attrs[0].default = "1"
        enum_one.attrs[0].name = "one"
        enum_two = ClassFactory.enumeration(1, qname="inner")
        enum_two.attrs[0].default = "2"
        enum_two.attrs[0].name = "two"
        enum_three = ClassFactory.enumeration(1, qname="missing_member")

        # One return value per ``find_enum`` call, consumed in call order.
        # Presumably one lookup per attr type (2 + 2 + 1) -- verify against
        # the sanitizer if the lookup strategy changes.
        mock_find_enum.side_effect = [
            None,
            enum_one,
            None,
            enum_two,
            enum_three,
        ]

        target = ClassFactory.create(
            qname="target",
            attrs=[
                AttrFactory.create(
                    types=[
                        AttrTypeFactory.create(),
                        AttrTypeFactory.create(qname="foo"),
                    ],
                    default="1",
                ),
                AttrFactory.create(
                    types=[
                        AttrTypeFactory.create(),
                        AttrTypeFactory.create(qname="bar", forward=True),
                    ],
                    default="2",
                ),
                AttrFactory.create(default="3"),
            ],
        )

        actual = []
        for attr in target.attrs:
            self.sanitizer.process_attribute_default(target, attr)
            actual.append(attr.default)

        self.assertEqual(["@enum@root::one", "@enum@inner::two", None], actual)
        # Only the enum reached through the forward type is expected to be promoted.
        mock_promote_inner_class.assert_called_once_with(target, enum_two)
        mock_logger_warning.assert_called_once_with(
            "No enumeration member matched %s.%s default value `%s`",
            target.name,
            target.attrs[2].local_name,
            "3",
        )

    def test_promote_inner_class(self):
        """Promoting an inner class moves it into the container under a new qname.

        The inner class is removed from ``target.inner``, renamed to
        ``{foo}parent_bar``, and the forward reference pointing at it is
        rewritten and un-forwarded. Unrelated attrs and types must not change.
        """
        target = ClassFactory.elements(2, qname="parent")
        inner = ClassFactory.create(qname="{foo}bar")

        target.inner.append(inner)
        target.attrs[1].types.append(AttrTypeFactory.create(forward=True, qname="bar"))

        # Snapshot taken before promotion, used for the "untouched" checks below.
        clone_target = target.clone()

        self.container.add(target)
        self.sanitizer.promote_inner_class(target, inner)

        self.assertNotIn(inner, target.inner)

        self.assertEqual("{foo}parent_bar", target.attrs[1].types[1].qname)
        self.assertFalse(target.attrs[1].types[1].forward)
        self.assertEqual("{foo}parent_bar", inner.qname)
        self.assertEqual(2, len(self.container.data))
        self.assertIn(inner, self.container["{foo}parent_bar"])

        self.assertEqual(clone_target.attrs[0], target.attrs[0])
        self.assertEqual(clone_target.attrs[1].types[0], target.attrs[1].types[0])

    def test_find_enum(self):
        """``find_enum`` resolves external and inner enumerations, else ``None``.

        Covers five type flavours: native, matching/missing external, and
        matching/missing forward (inner) references.
        """
        native_type = AttrTypeFactory.create()
        matching_external = AttrTypeFactory.create("foo")
        missing_external = AttrTypeFactory.create("bar")
        matching_inner = AttrTypeFactory.create("foobar", forward=True)
        missing_inner = AttrTypeFactory.create("barfoo", forward=True)
        enumeration = ClassFactory.enumeration(1, qname="foo")
        inner = ClassFactory.enumeration(1, qname="foobar")

        target = ClassFactory.create(
            attrs=[
                AttrFactory.create(
                    types=[
                        native_type,
                        matching_external,
                        missing_external,
                        matching_inner,
                        missing_inner,
                    ]
                )
            ],
            inner=[inner],
        )
        self.sanitizer.container.extend([target, enumeration])

        actual = self.sanitizer.find_enum(target, native_type)
        self.assertIsNone(actual)

        actual = self.sanitizer.find_enum(target, matching_external)
        self.assertEqual(enumeration, actual)

        actual = self.sanitizer.find_enum(target, missing_external)
        self.assertIsNone(actual)

        actual = self.sanitizer.find_enum(target, matching_inner)
        self.assertEqual(inner, actual)

        actual = self.sanitizer.find_enum(target, missing_inner)
        self.assertIsNone(actual)

    def test_process_attribute_restrictions(self):
        """Table-driven check of how occurrence restrictions are normalised.

        ``restrictions[i]`` is processed and its ``asdict()`` output must equal
        ``expected[i]``.
        """
        restrictions = [
            Restrictions(min_occurs=0, max_occurs=0, required=True),
            Restrictions(min_occurs=0, max_occurs=1, required=True),
            Restrictions(min_occurs=1, max_occurs=1, required=False),
            Restrictions(max_occurs=2, required=True),
            Restrictions(min_occurs=2, max_occurs=2, required=True),
        ]
        expected = [
            {},
            {},
            {"required": True},
            {"max_occurs": 2},
            {"max_occurs": 2, "min_occurs": 2},
        ]

        for idx, res in enumerate(restrictions):
            attr = AttrFactory.create(restrictions=res)
            self.sanitizer.process_attribute_restrictions(attr)
            self.assertEqual(expected[idx], res.asdict())

    def test_sanitize_duplicate_attribute_names(self):
        """Clashing attr names are disambiguated in place.

        The expected list shows the three outcomes checked here: a tag suffix
        (``a_Attribute``, ``d_Element``), a namespace prefix (``b_e``, ``a_f``)
        and a numeric suffix for the enumeration members (``g[A]_2``, ``g_a_3``).
        """
        attrs = [
            AttrFactory.create(name="a", tag=Tag.ELEMENT),
            AttrFactory.create(name="a", tag=Tag.ATTRIBUTE),
            AttrFactory.create(name="b", tag=Tag.ATTRIBUTE),
            AttrFactory.create(name="c", tag=Tag.ATTRIBUTE),
            AttrFactory.create(name="c", tag=Tag.ELEMENT),
            AttrFactory.create(name="d", tag=Tag.ELEMENT),
            AttrFactory.create(name="d", tag=Tag.ELEMENT),
            AttrFactory.create(name="e", tag=Tag.ELEMENT, namespace="b"),
            AttrFactory.create(name="e", tag=Tag.ELEMENT),
            AttrFactory.create(name="f", tag=Tag.ELEMENT),
            AttrFactory.create(name="f", tag=Tag.ELEMENT, namespace="a"),
            AttrFactory.create(name="gA", tag=Tag.ENUMERATION),
            AttrFactory.create(name="g[A]", tag=Tag.ENUMERATION),
            AttrFactory.create(name="g_a", tag=Tag.ENUMERATION),
            AttrFactory.create(name="g_a_1", tag=Tag.ENUMERATION),
        ]

        self.sanitizer.process_duplicate_attribute_names(attrs)
        expected = [
            "a",
            "a_Attribute",
            "b",
            "c_Attribute",
            "c",
            "d_Element",
            "d",
            "b_e",
            "e",
            "f",
            "a_f",
            "gA",
            "g[A]_2",
            "g_a_3",
            "g_a_1",
        ]
        self.assertEqual(expected, [x.name for x in attrs])

    def test_sanitize_attribute_sequence(self):
        """Walk ``process_attribute_sequence`` through several flag combinations.

        Each step tweaks the attrs' ``sequential`` / ``max_occurs`` values and
        checks how many attrs are still flagged sequential afterwards.
        """

        def len_sequential(target: Class):
            """Count the attrs of ``target`` whose restrictions are sequential."""
            return len([attr for attr in target.attrs if attr.restrictions.sequential])

        restrictions = Restrictions(max_occurs=2, sequential=True)
        target = ClassFactory.create(
            attrs=[
                AttrFactory.create(restrictions=restrictions.clone()),
                AttrFactory.create(restrictions=restrictions.clone()),
            ]
        )

        # Pristine copies, restored further down for the second scenario.
        attrs_clone = [attr.clone() for attr in target.attrs]

        # Both attrs sequential with max_occurs=2: nothing is cleared.
        self.sanitizer.process_attribute_sequence(target, target.attrs[0])
        self.assertEqual(2, len_sequential(target))

        # First attr manually un-flagged: only the second still counts ...
        target.attrs[0].restrictions.sequential = False
        self.sanitizer.process_attribute_sequence(target, target.attrs[0])
        self.assertEqual(1, len_sequential(target))

        # ... until the second one is processed and loses its flag as well.
        self.sanitizer.process_attribute_sequence(target, target.attrs[1])
        self.assertEqual(0, len_sequential(target))

        # Fresh attrs, second one not sequential: the first gets cleared.
        target.attrs = attrs_clone
        target.attrs[1].restrictions.sequential = False
        self.sanitizer.process_attribute_sequence(target, target.attrs[0])
        self.assertEqual(0, len_sequential(target))

        # First attr sequential but max_occurs=0: one flagged attr is left.
        target.attrs[0].restrictions.sequential = True
        target.attrs[0].restrictions.max_occurs = 0
        target.attrs[1].restrictions.sequential = True
        self.sanitizer.process_attribute_sequence(target, target.attrs[0])
        self.assertEqual(1, len_sequential(target))

    @mock.patch.object(ClassSanitizer, "rename_classes")
    def test_resolve_conflicts(self, mock_rename_classes):
        """``resolve_conflicts`` forwards each group of clashing classes for renaming.

        Expected groups: ``{foo}A``/``{foo}a`` and the two ``b`` classes.
        Presumably names are compared case-insensitively -- confirm in the
        sanitizer.
        """
        classes = [
            ClassFactory.create(qname="{foo}A"),
            ClassFactory.create(qname="{foo}a"),
            ClassFactory.create(qname="a"),
            ClassFactory.create(qname="b"),
            ClassFactory.create(qname="b"),
        ]
        self.sanitizer.container.extend(classes)
        self.sanitizer.resolve_conflicts()

        mock_rename_classes.assert_has_calls(
            [
                mock.call(classes[:2]),
                mock.call(classes[3:]),
            ]
        )

    @mock.patch.object(ClassSanitizer, "rename_class")
    def test_rename_classes(self, mock_rename_class):
        """With two elements in the clash, every class is renamed."""
        classes = [
            ClassFactory.create(qname="a", type=Element),
            ClassFactory.create(qname="A", type=Element),
            ClassFactory.create(qname="a", type=ComplexType),
        ]
        self.sanitizer.rename_classes(classes)

        mock_rename_class.assert_has_calls(
            [
                mock.call(classes[0]),
                mock.call(classes[1]),
                mock.call(classes[2]),
            ]
        )

    @mock.patch.object(ClassSanitizer, "rename_class")
    def test_rename_classes_protects_single_element(self, mock_rename_class):
        """With a single element in the clash, only the complex type is renamed."""
        classes = [
            ClassFactory.create(qname="a", type=Element),
            ClassFactory.create(qname="a", type=ComplexType),
        ]
        self.sanitizer.rename_classes(classes)

        mock_rename_class.assert_called_once_with(classes[1])

    @mock.patch.object(ClassSanitizer, "rename_dependency")
    def test_rename_class(self, mock_rename_dependency):
        """``rename_class`` picks a free suffix and re-indexes the container.

        ``{foo}a_1`` and ``{foo}A_2`` already exist, and the target is expected
        to become ``{foo}a_3`` while keeping ``a`` as its ``meta_name``.
        """
        target = ClassFactory.create(qname="{foo}a")
        self.sanitizer.container.add(target)
        self.sanitizer.container.add(ClassFactory.create())
        self.sanitizer.container.add(ClassFactory.create(qname="{foo}a_1"))
        self.sanitizer.container.add(ClassFactory.create(qname="{foo}A_2"))
        self.sanitizer.rename_class(target)

        self.assertEqual("{foo}a_3", target.qname)
        self.assertEqual("a", target.meta_name)

        # Every class in the container must have been offered the rename.
        mock_rename_dependency.assert_has_calls(
            mock.call(item, "{foo}a", "{foo}a_3")
            for item in self.sanitizer.container.iterate()
        )

        self.assertEqual([target], self.container.data["{foo}a_3"])
        self.assertEqual([], self.container.data["{foo}a"])

    def test_rename_dependency(self):
        """``rename_dependency`` rewrites the qname in extensions, attrs and inner classes.

        After the call ``{foo}bar`` must be gone from ``target.dependencies()``
        and ``thug`` must be present.
        """
        attr_type = AttrTypeFactory.create("{foo}bar")

        target = ClassFactory.create(
            extensions=[
                ExtensionFactory.create(),
                ExtensionFactory.create(type=attr_type.clone()),
            ],
            attrs=[
                AttrFactory.create(),
                AttrFactory.create(types=[AttrTypeFactory.create(), attr_type.clone()]),
            ],
            inner=[
                ClassFactory.create(
                    extensions=[ExtensionFactory.create(type=attr_type.clone())],
                    attrs=[
                        AttrFactory.create(),
                        AttrFactory.create(
                            types=[AttrTypeFactory.create(), attr_type.clone()]
                        ),
                    ],
                )
            ],
        )

        self.sanitizer.rename_dependency(target, "{foo}bar", "thug")
        dependencies = set(target.dependencies())
        self.assertNotIn("{foo}bar", dependencies)
        self.assertIn("thug", dependencies)

    @mock.patch.object(ClassSanitizer, "group_fields")
    def test_group_compound_fields(self, mock_group_fields):
        """``group_fields`` is expected for the two choice groups with a repeating attr."""
        target = ClassFactory.elements(8)
        # First group repeating
        target.attrs[0].restrictions.choice = "1"
        target.attrs[1].restrictions.choice = "1"
        target.attrs[1].restrictions.max_occurs = 2
        # Second group repeating
        target.attrs[2].restrictions.choice = "2"
        target.attrs[3].restrictions.choice = "2"
        target.attrs[3].restrictions.max_occurs = 2
        # Third group optional
        target.attrs[4].restrictions.choice = "3"
        target.attrs[5].restrictions.choice = "3"

        self.sanitizer.group_compound_fields(target)
        mock_group_fields.assert_has_calls(
            [
                mock.call(target, target.attrs[0:2]),
                mock.call(target, target.attrs[2:4]),
            ]
        )

    def test_group_fields(self):
        """Two attrs collapse into one ``Choice`` attr carrying both as choices.

        The merged attr is expected to have ``min_occurs=5`` / ``max_occurs=20``.
        ``attr_B`` / ``attr_C`` are presumably the factory's auto-generated
        names -- they depend on the factory counter being reset per test.
        """
        target = ClassFactory.create(attrs=AttrFactory.list(2))
        target.attrs[0].restrictions.min_occurs = 10
        target.attrs[0].restrictions.max_occurs = 15
        target.attrs[1].restrictions.min_occurs = 5
        target.attrs[1].restrictions.max_occurs = 20

        expected = AttrFactory.create(
            name="attr_B_Or_attr_C",
            tag="Choice",
            index=0,
            types=[AttrTypeFactory.xs_any()],
            choices=[
                AttrFactory.create(
                    tag=target.attrs[0].tag,
                    name="attr_B",
                    types=target.attrs[0].types,
                ),
                AttrFactory.create(
                    tag=target.attrs[1].tag,
                    name="attr_C",
                    types=target.attrs[1].types,
                ),
            ],
        )
        expected_res = Restrictions(min_occurs=5, max_occurs=20)

        # A copy of the attr list is passed in; ``target.attrs`` is checked afterwards.
        self.sanitizer.group_fields(target, list(target.attrs))
        self.assertEqual(1, len(target.attrs))
        self.assertEqual(expected, target.attrs[0])
        self.assertEqual(expected_res, target.attrs[0].restrictions)

    def test_group_fields_limit_name(self):
        """Three grouped attrs get an ``_Or_`` joined name, four fall back to ``choice``."""
        target = ClassFactory.create(attrs=AttrFactory.list(3))
        self.sanitizer.group_fields(target, list(target.attrs))

        self.assertEqual(1, len(target.attrs))
        self.assertEqual("attr_B_Or_attr_C_Or_attr_D", target.attrs[0].name)

        target = ClassFactory.create(attrs=AttrFactory.list(4))
        self.sanitizer.group_fields(target, list(target.attrs))
        self.assertEqual("choice", target.attrs[0].name)

    def test_build_attr_choice(self):
        """``build_attr_choice`` copies an attr into a choice, dropping occurrence data.

        The choice takes its name from ``local_name``, keeps namespace, default,
        tag, types and help, is never ``fixed``, and its restrictions match the
        source apart from ``min_occurs``, ``max_occurs`` and ``sequential``.
        """
        attr = AttrFactory.create(
            name="a", namespace="xsdata", default="123", help="help", fixed=True
        )
        attr.local_name = "aaa"
        attr.restrictions = Restrictions(
            required=True,
            prohibited=None,
            min_occurs=1,
            max_occurs=1,
            min_exclusive="1.1",
            min_inclusive="1",
            min_length=1,
            max_exclusive="1",
            max_inclusive="1.1",
            max_length=10,
            total_digits=333,
            fraction_digits=2,
            length=5,
            white_space="collapse",
            pattern=r"[A-Z]",
            explicit_timezone="+1",
            nillable=True,
            choice="abc",
            sequential=True,
        )
        # The only three restriction fields expected to be blanked on the choice.
        expected_res = attr.restrictions.clone()
        expected_res.min_occurs = None
        expected_res.max_occurs = None
        expected_res.sequential = None

        actual = self.sanitizer.build_attr_choice(attr)

        self.assertEqual(attr.local_name, actual.name)
        self.assertEqual(attr.namespace, actual.namespace)
        self.assertEqual(attr.default, actual.default)
        self.assertEqual(attr.tag, actual.tag)
        self.assertEqual(attr.types, actual.types)
        self.assertEqual(expected_res, actual.restrictions)
        self.assertEqual(attr.help, actual.help)
        self.assertFalse(actual.fixed)
Exemplo n.º 2
0
class ClassContainerTests(FactoryTestCase):
    """Exercise ``ClassContainer`` indexing, lookup, processing and filtering."""

    def setUp(self):
        """Start every test from an empty container."""
        super().setUp()
        self.container = ClassContainer()

    def test_initialize(self):
        """Classes are grouped by qname and the processors come in a fixed order."""
        foo_element = ClassFactory.create(qname="{xsdata}foo", tag=Tag.ELEMENT)
        foo_complex = ClassFactory.create(qname="{xsdata}foo", tag=Tag.COMPLEX_TYPE)
        foobar = ClassFactory.create(qname="{xsdata}foobar", tag=Tag.COMPLEX_TYPE)

        container = ClassContainer()
        container.extend([foo_element, foo_complex, foobar])

        grouped_by_qname = {
            "{xsdata}foo": [foo_element, foo_complex],
            "{xsdata}foobar": [foobar],
        }
        self.assertEqual(2, len(container.data))
        self.assertEqual(grouped_by_qname, container.data)

        processor_names = [
            handler.__class__.__name__ for handler in container.processors
        ]
        self.assertEqual(
            [
                "AttributeGroupHandler",
                "ClassExtensionHandler",
                "ClassEnumerationHandler",
                "AttributeSubstitutionHandler",
                "AttributeTypeHandler",
                "AttributeMergeHandler",
                "AttributeMixedContentHandler",
                "AttributeSanitizerHandler",
            ],
            processor_names,
        )

    @mock.patch.object(ClassContainer, "process_class")
    def test_find(self, mock_process_class):
        """``find`` looks up by qname, honours a condition and processes on demand.

        Only ``class_a``, created without an explicit status, is expected to be
        passed to ``process_class``.
        """

        def mark_processed(clazz: Class):
            clazz.status = Status.PROCESSED

        def only_enumerations(clazz):
            return clazz.is_enumeration

        class_a = ClassFactory.create(qname="a")
        class_b = ClassFactory.create(qname="b", status=Status.PROCESSED)
        class_c = ClassFactory.enumeration(2, qname="b", status=Status.PROCESSING)

        mock_process_class.side_effect = mark_processed
        self.container.extend([class_a, class_b, class_c])

        missing = self.container.find("nope")
        self.assertIsNone(missing)

        found_a = self.container.find(class_a.qname)
        self.assertEqual(class_a, found_a)

        found_b = self.container.find(class_b.qname)
        self.assertEqual(class_b, found_b)

        # Same qname as class_b, but the condition narrows it to the enumeration.
        found_c = self.container.find(class_b.qname, only_enumerations)
        self.assertEqual(class_c, found_c)

        mock_process_class.assert_called_once_with(class_a)

    @mock.patch.object(ClassContainer, "process_class")
    def test_find_inner(self, mock_process_class):
        """``find_inner`` returns nested classes and processes only the first one."""

        def mark_processed(clazz: Class):
            clazz.status = Status.PROCESSED

        parent = ClassFactory.create()
        unprocessed = ClassFactory.create(qname="{a}a")
        processed = ClassFactory.create(qname="{a}b", status=Status.PROCESSED)
        parent.inner.extend((unprocessed, processed))

        mock_process_class.side_effect = mark_processed

        found_first = self.container.find_inner(parent, "{a}a")
        self.assertEqual(unprocessed, found_first)

        found_second = self.container.find_inner(parent, "{a}b")
        self.assertEqual(processed, found_second)

        mock_process_class.assert_called_once_with(unprocessed)

    def test_process(self):
        """``process_class`` marks the target and both inner classes as processed."""
        nested = ClassFactory.list(2)
        target = ClassFactory.create(inner=nested)
        self.container.add(target)

        self.container.process_class(target)

        for clazz in (target, target.inner[0], target.inner[1]):
            self.assertEqual(Status.PROCESSED, clazz.status)

    @mock.patch.object(Class, "should_generate", new_callable=mock.PropertyMock)
    def test_filter_classes(self, mock_class_should_generate):
        """``filter_classes`` keeps exactly the classes flagged ``should_generate``."""
        flags = [True, False, False, True, False]
        mock_class_should_generate.side_effect = flags

        classes = ClassFactory.list(5)
        container = ClassContainer()
        container.extend(classes)

        container.filter_classes()

        kept = [classes[idx] for idx in (0, 3)]
        self.assertEqual(kept, container.class_list)

    @mock.patch.object(Class, "should_generate", new_callable=mock.PropertyMock)
    def test_filter_classes_with_only_simple_types(self, mock_class_should_generate):
        """When nothing should be generated, the class list is left intact."""
        mock_class_should_generate.return_value = False

        enumeration = ClassFactory.enumeration(2)
        simple_type = ClassFactory.simple_type()
        classes = [enumeration, simple_type]

        container = ClassContainer()
        container.extend(classes)
        container.filter_classes()

        self.assertEqual(classes, container.class_list)
Exemplo n.º 3
0
class ClassContainerTests(FactoryTestCase):
    """Cover ``ClassContainer``: construction, lookups, processing and filtering."""

    def setUp(self):
        """Provide a fresh, empty container to each test."""
        super().setUp()
        self.container = ClassContainer()

    def test_from_list(self):
        """``from_list`` groups classes by qname and sets up the processors in order."""
        foo_element = ClassFactory.create(qname="{xsdata}foo", type=Element)
        foo_complex = ClassFactory.create(qname="{xsdata}foo", type=ComplexType)
        foobar = ClassFactory.create(qname="{xsdata}foobar", type=ComplexType)

        container = ClassContainer.from_list([foo_element, foo_complex, foobar])

        grouped_by_qname = {
            "{xsdata}foo": [foo_element, foo_complex],
            "{xsdata}foobar": [foobar],
        }
        self.assertEqual(2, len(container))
        self.assertEqual(grouped_by_qname, container)

        processor_names = [
            handler.__class__.__name__ for handler in container.processors
        ]
        self.assertEqual(
            [
                "AttributeGroupHandler",
                "ClassExtensionHandler",
                "AttributeEnumUnionHandler",
                "AttributeSubstitutionHandler",
                "AttributeTypeHandler",
                "AttributeMergeHandler",
                "AttributeMixedContentHandler",
                "AttributeMismatchHandler",
            ],
            processor_names,
        )

    @mock.patch.object(ClassContainer, "process_class")
    def test_find(self, mock_process_class):
        """``find`` looks up by qname, applies the condition and processes on demand.

        Only ``class_a``, created without an explicit status, is expected to be
        passed to ``process_class``.
        """

        def mark_processed(clazz: Class):
            clazz.status = Status.PROCESSED

        def only_enumerations(clazz):
            return clazz.is_enumeration

        class_a = ClassFactory.create(qname="a")
        class_b = ClassFactory.create(qname="b", status=Status.PROCESSED)
        class_c = ClassFactory.enumeration(2, qname="b", status=Status.PROCESSING)

        mock_process_class.side_effect = mark_processed
        self.container.extend([class_a, class_b, class_c])

        missing = self.container.find("nope")
        self.assertIsNone(missing)

        found_a = self.container.find(class_a.qname)
        self.assertEqual(class_a, found_a)

        found_b = self.container.find(class_b.qname)
        self.assertEqual(class_b, found_b)

        # class_c shares class_b's qname; the condition singles it out.
        found_c = self.container.find(class_b.qname, only_enumerations)
        self.assertEqual(class_c, found_c)

        mock_process_class.assert_called_once_with(class_a)

    @mock.patch.object(ClassContainer, "process_class")
    def test_find_inner(self, mock_process_class):
        """``find_inner`` matches nested classes by name, optionally by condition.

        The two inner classes created without a status are the ones expected
        to be sent through ``process_class``.
        """

        def mark_processed(clazz: Class):
            clazz.status = Status.PROCESSED

        def enums_only(clazz: Class):
            return clazz.is_enumeration

        parent = ClassFactory.create()
        plain_a = ClassFactory.create(qname="{a}a")
        enum_a = ClassFactory.enumeration(2, qname="{a}a")
        processed_c = ClassFactory.create(qname="{c}c", status=Status.PROCESSED)
        processing_d = ClassFactory.enumeration(
            2, qname="{d}d", status=Status.PROCESSING
        )
        parent.inner.extend((plain_a, enum_a, processed_c, processing_d))

        mock_process_class.side_effect = mark_processed

        missing = self.container.find_inner(parent, "nope")
        self.assertIsNone(missing)

        found = self.container.find_inner(parent, "a")
        self.assertEqual(plain_a, found)

        found = self.container.find_inner(parent, "a", enums_only)
        self.assertEqual(enum_a, found)

        found = self.container.find_inner(parent, "c")
        self.assertEqual(processed_c, found)

        found = self.container.find_inner(parent, "d", enums_only)
        self.assertEqual(processing_d, found)

        expected_calls = list(map(mock.call, (plain_a, enum_a)))
        mock_process_class.assert_has_calls(expected_calls)

    def test_process(self):
        """``process_class`` marks the target and both inner classes as processed."""
        nested = ClassFactory.list(2)
        target = ClassFactory.create(inner=nested)
        self.container.add(target)

        self.container.process_class(target)

        for clazz in (target, target.inner[0], target.inner[1]):
            self.assertEqual(Status.PROCESSED, clazz.status)

    @mock.patch.object(Class, "should_generate", new_callable=mock.PropertyMock)
    def test_filter_classes(self, mock_class_should_generate):
        """``filter_classes`` keeps exactly the classes flagged ``should_generate``."""
        flags = [True, False, False, True, False]
        mock_class_should_generate.side_effect = flags

        classes = ClassFactory.list(5)
        container = ClassContainer.from_list(classes)

        container.filter_classes()

        kept = [classes[idx] for idx in (0, 3)]
        self.assertEqual(kept, container.class_list)

    @mock.patch.object(Class, "should_generate", new_callable=mock.PropertyMock)
    def test_filter_classes_with_only_simple_types(self, mock_class_should_generate):
        """When nothing should be generated, the class list is left intact."""
        mock_class_should_generate.return_value = False

        enumeration = ClassFactory.enumeration(2)
        simple_type = ClassFactory.create(type=SimpleType)
        classes = [enumeration, simple_type]

        container = ClassContainer.from_list(classes)
        container.filter_classes()

        self.assertEqual(classes, container.class_list)
Exemplo n.º 4
0
class ClassEnumerationHandlerTests(FactoryTestCase):
    """Tests for ``ClassEnumerationHandler``: filtering, flattening and promotion."""

    def setUp(self):
        """Build a target whose single union attr points at a root and an inner enum."""
        super().setUp()

        self.root_enum = ClassFactory.enumeration(2)
        self.inner_enum = ClassFactory.enumeration(2)

        # Fixtures are created innermost-first, in the same order as before,
        # so factory-generated values stay the same.
        union_types = [
            AttrTypeFactory.create(qname=self.root_enum.qname),
            AttrTypeFactory.create(qname=self.inner_enum.qname, forward=True),
        ]
        union_attr = AttrFactory.create(
            name="value",
            tag=Tag.UNION,
            types=union_types,
        )
        self.target = ClassFactory.create(attrs=[union_attr])
        self.target.inner.append(self.inner_enum)

        self.container = ClassContainer()
        for clazz in (self.target, self.root_enum):
            self.container.add(clazz)

        self.processor = ClassEnumerationHandler(container=self.container)

    def test_filter(self):
        """Once an enumeration attr is present, it is the only attr left."""
        clazz = ClassFactory.elements(2)

        self.processor.process(clazz)
        self.assertEqual(2, len(clazz.attrs))

        enum_attr = AttrFactory.enumeration()
        clazz.attrs.append(enum_attr)
        self.processor.process(clazz)

        self.assertEqual(1, len(clazz.attrs))
        self.assertTrue(clazz.attrs[0].is_enumeration)

    def test_flatten_skip_if_class_has_more_than_one_attribute(self):
        """A second attr on the target prevents flattening."""
        extra_attr = AttrFactory.create()
        self.target.attrs.append(extra_attr)

        self.processor.process(self.target)

        self.assertFalse(self.target.is_enumeration)
        self.assertEqual(2, len(self.target.attrs))

    def test_flatten_skip_when_attribute_tag_is_not_union(self):
        """An attr tagged as element instead of union is not flattened."""
        only_attr = self.target.attrs[0]
        only_attr.tag = Tag.ELEMENT

        self.processor.process(self.target)

        self.assertFalse(self.target.is_enumeration)

    def test_flatten_skip_when_types_is_not_enumeration_union(self):
        """A native type mixed into the union prevents flattening."""
        native_int = AttrTypeFactory.native(DataType.INT)
        self.target.attrs[0].types.append(native_int)

        self.processor.process(self.target)

        self.assertFalse(self.target.is_enumeration)

    def test_flatten_merges_enumeration_unions(self):
        """A union of enums turns the target into one enum holding all members."""
        self.processor.process(self.target)
        self.assertTrue(self.target.is_enumeration)

        merged_members = [*self.root_enum.attrs, *self.inner_enum.attrs]
        self.assertEqual(merged_members, self.target.attrs)
        self.assertEqual(0, len(self.target.inner))

    def test_promote(self):
        """An inner enumeration is moved to the container under a prefixed qname.

        The non-enumeration inner class stays where it is, and both forward
        references are re-pointed at the promoted class and un-forwarded.
        """
        target = ClassFactory.elements(2)
        inner = ClassFactory.enumeration(3)
        target.inner.append(inner)

        unrelated = ClassFactory.simple_type()  # Irrelevant
        target.inner.append(unrelated)

        forward_type = AttrTypeFactory.create(qname=inner.qname, forward=True)
        for idx in (0, 1):
            target.attrs[idx].types.append(forward_type.clone())

        self.container.add(target)
        self.assertEqual(3, len(self.container.data))

        self.processor.process(target)

        promoted_name = f"{target.name}_{inner.name}"
        new_qname = build_qname(inner.target_namespace, promoted_name)

        self.assertEqual(4, len(self.container.data))
        new_inner = self.container.find(new_qname)

        self.assertEqual(1, len(target.inner))
        self.assertNotEqual(new_inner.qname, inner.qname)
        self.assertEqual(new_inner.attrs, inner.attrs)

        for idx in (0, 1):
            self.assertEqual(new_inner.qname, target.attrs[idx].types[1].qname)
        for idx in (0, 1):
            self.assertFalse(target.attrs[idx].types[1].forward)