예제 #1
0
    def test_group_fields(self):
        """Attrs sharing one choice id are merged into a single compound attr.

        The merged attr is expected to list the originals as choices and to
        carry min_occurs=5 / max_occurs=20, i.e. the smallest minimum and the
        largest maximum configured on the two inputs below.
        """
        target = ClassFactory.create(attrs=AttrFactory.list(2))
        first, second = target.attrs

        # (min_occurs, max_occurs) for each attr; both are tagged with choice "1".
        occurs = ((10, 15), (5, 20))
        for attr, (lower, upper) in zip((first, second), occurs):
            attr.restrictions.choice = "1"
            attr.restrictions.min_occurs = lower
            attr.restrictions.max_occurs = upper

        # Factory calls stay in the original order: the compound type first,
        # then both choices, then the compound attr itself.
        compound_type = AttrTypeFactory.native(DataType.ANY_TYPE)
        choice_b = AttrFactory.create(
            tag=first.tag,
            name="attr_B",
            types=first.types,
        )
        choice_c = AttrFactory.create(
            tag=second.tag,
            name="attr_C",
            types=second.types,
        )
        expected = AttrFactory.create(
            name="attr_B_Or_attr_C",
            tag="Choice",
            index=0,
            types=[compound_type],
            choices=[choice_b, choice_c],
        )
        expected_res = Restrictions(min_occurs=5, max_occurs=20)

        self.sanitizer.group_fields(target, list(target.attrs))

        self.assertEqual(1, len(target.attrs))
        compound = target.attrs[0]
        self.assertEqual(expected, compound)
        self.assertEqual(expected_res, compound.restrictions)
예제 #2
0
    def test_process_attribute(self, mock_find):
        """Processing an attr with registered substitutions adds copies of them.

        After the first attr is processed the class is expected to hold four
        attrs, with equal-but-distinct copies of the substitution attrs at
        positions 0 and 3, each carrying the source attr's max_occurs of 2.
        Processing the second attr, for which this test registers no
        substitutions, must leave the attr count unchanged.

        ``mock_find`` is injected by a patch decorator outside this view; it
        is configured to return -1 and then 2 on successive calls.
        """
        foo_attr = AttrFactory.create(types=[AttrTypeFactory.create("foo")])
        bar_attr = AttrFactory.create(types=[AttrTypeFactory.create("bar")])
        target = ClassFactory.create(attrs=[foo_attr, bar_attr])
        mock_find.side_effect = [-1, 2]

        substituted, untouched = target.attrs[0], target.attrs[1]
        substituted.restrictions.max_occurs = 2

        substituted_qname = substituted.types[0].qname
        replacements = AttrFactory.list(2)

        self.processor.create_substitutions()
        self.processor.substitutions[substituted_qname] = replacements
        self.processor.process_attribute(target, substituted)

        self.assertEqual(4, len(target.attrs))

        # (index into replacements, index into target.attrs)
        placements = ((0, 0), (1, 3))
        for source_idx, target_idx in placements:
            self.assertEqual(replacements[source_idx], target.attrs[target_idx])
            self.assertIsNot(replacements[source_idx], target.attrs[target_idx])
        for _, target_idx in placements:
            self.assertEqual(2, target.attrs[target_idx].restrictions.max_occurs)

        self.processor.process_attribute(target, untouched)
        self.assertEqual(4, len(target.attrs))
예제 #3
0
    def test_create_substitutions(self, mock_create_substitution):
        """Substitution attrs are grouped by qname in container order.

        ``mock_create_substitution`` (patched outside this view) hands back
        three prepared attrs, one per call; the resulting mapping must file
        the first and third under "foo" and the second under "bar".
        """
        ns = "xsdata"
        foo_qname = build_qname(ns, "foo")
        bar_qname = build_qname(ns, "bar")

        both_subs = ClassFactory.create(
            substitutions=[foo_qname, bar_qname],
            abstract=True,
        )
        foo_only = ClassFactory.create(substitutions=[foo_qname], abstract=True)
        classes = [both_subs, foo_only]

        reference_attrs = AttrFactory.list(3)
        mock_create_substitution.side_effect = reference_attrs

        self.processor.container.extend(classes)
        self.processor.create_substitutions()

        first, second, third = reference_attrs
        expected = {
            foo_qname: [first, third],
            bar_qname: [second],
        }
        self.assertEqual(expected, self.processor.substitutions)

        # One call per declared substitution: two for the first class, one
        # for the second.
        expected_calls = [
            mock.call(both_subs),
            mock.call(both_subs),
            mock.call(foo_only),
        ]
        mock_create_substitution.assert_has_calls(expected_calls)
예제 #4
0
    def test_map_binding_message_parts_with_original_message(
            self, mock_find_message, mock_create_message_attributes):
        """The mapper yields the attrs built from the looked-up message parts.

        Both mocks are injected by patch decorators outside this view. The
        message is looked up by "bar" for the prefixed name "foo:bar", and
        all of its parts are forwarded to ``create_message_attributes``.
        """
        # NOTE(review): the Definitions class itself (not an instance) is
        # passed on purpose here; presumably this works because the lookup
        # is patched -- confirm against the decorators.
        definitions = Definitions
        message_name = "foo:bar"
        ns_map = {}
        parts = [
            Part(name="token", element="foo:token"),
            Part(name="messageId", type="id"),
            Part(name="another", type="id"),
        ]
        message = Message(name="session", parts=parts)
        extended = AnyElement()
        mock_create_message_attributes.return_value = AttrFactory.list(2)
        mock_find_message.return_value = message

        actual = DefinitionsMapper.map_binding_message_parts(
            definitions, message_name, extended, ns_map)

        # Check the lazy type before draining the generator.
        self.assertIsInstance(actual, Generator)
        produced = list(actual)
        self.assertEqual(2, len(produced))

        mock_create_message_attributes.assert_called_once_with(
            message.parts, ns_map)
        mock_find_message.assert_called_once_with("bar")
예제 #5
0
    def test_dependencies(self):
        """dependencies() yields each expected qname exactly once.

        The fixture mixes native types, a forward reference, choice types,
        a duplicated extension and an inner class. The expected list below
        omits the native types and the forward "annotated" reference, and
        contains "foobar" and "{xsdata}foo" only once each although both
        are referenced twice.
        """
        xs_uri = Namespace.XS.uri

        # Fixtures are created in the same order as a single nested literal
        # would evaluate them, so factory-generated names stay unchanged.
        native_attr = AttrFactory.create(
            types=[AttrTypeFactory.native(DataType.DECIMAL)])

        forward_type = AttrTypeFactory.create(
            qname=build_qname(xs_uri, "annotated"),
            forward=True,
        )
        mixed_choice = AttrFactory.create(
            name="x",
            types=[
                AttrTypeFactory.create(qname="choiceAttr"),
                AttrTypeFactory.native(DataType.STRING),
            ],
        )
        plain_choice = AttrFactory.create(
            name="x",
            types=[
                AttrTypeFactory.create(qname="choiceAttrTwo"),
                AttrTypeFactory.create(qname="choiceAttrEnum"),
            ],
        )
        compound_attr = AttrFactory.create(
            types=[forward_type],
            choices=[mixed_choice, plain_choice],
        )

        schema_attr = AttrFactory.create(types=[
            AttrTypeFactory.create(qname=build_qname(xs_uri, "openAttrs")),
            AttrTypeFactory.create(qname=build_qname(xs_uri, "localAttribute")),
        ])

        extensions = [
            ExtensionFactory.reference(build_qname(xs_uri, "foobar")),
            ExtensionFactory.reference(build_qname(xs_uri, "foobar")),
        ]

        inner_types = AttrTypeFactory.list(1, qname="{xsdata}foo")
        inner_class = ClassFactory.create(
            attrs=AttrFactory.list(2, types=inner_types))

        obj = ClassFactory.create(
            attrs=[native_attr, compound_attr, schema_attr],
            extensions=extensions,
            inner=[inner_class],
        )

        expected = [
            "choiceAttr",
            "choiceAttrTwo",
            "choiceAttrEnum",
            "{http://www.w3.org/2001/XMLSchema}openAttrs",
            "{http://www.w3.org/2001/XMLSchema}localAttribute",
            "{http://www.w3.org/2001/XMLSchema}foobar",
            "{xsdata}foo",
        ]
        actual = list(obj.dependencies())
        self.assertCountEqual(expected, actual)
예제 #6
0
    def test_group_fields_limit_name(self):
        """The compound attr name falls back to "choice" in two scenarios.

        Three grouped attrs produce a name joined with "_Or_"; four grouped
        attrs, and two attrs where one is a clone of the other, are both
        expected to be named "choice".
        """

        def grouped_target(size):
            # Build a class whose attrs all belong to choice "1" and group them.
            built = ClassFactory.create(attrs=AttrFactory.list(size))
            for member in built.attrs:
                member.restrictions.choice = "1"
            self.sanitizer.group_fields(built, list(built.attrs))
            return built

        three = grouped_target(3)
        self.assertEqual(1, len(three.attrs))
        self.assertEqual("attr_B_Or_attr_C_Or_attr_D", three.attrs[0].name)

        four = grouped_target(4)
        self.assertEqual("choice", four.attrs[0].name)

        duplicated = ClassFactory.create()
        element = AttrFactory.element(restrictions=Restrictions(choice="1"))
        duplicated.attrs.extend([element, element.clone()])
        self.sanitizer.group_fields(duplicated, list(duplicated.attrs))
        self.assertEqual("choice", duplicated.attrs[0].name)
예제 #7
0
    def test_group_fields_with_effective_choices_sums_occurs(self):
        """Grouping an "effective_1" choice adds up the occurrence bounds.

        Inputs of (1, 2) and (3, 4) are expected to yield min_occurs=4 and
        max_occurs=6 on the single grouped attr.
        """
        target = ClassFactory.create(attrs=AttrFactory.list(2))

        # (min_occurs, max_occurs) for each attr, in attr order.
        bounds = ((1, 2), (3, 4))
        for attr, (lower, upper) in zip(target.attrs, bounds):
            attr.restrictions.choice = "effective_1"
            attr.restrictions.min_occurs = lower
            attr.restrictions.max_occurs = upper

        expected_res = Restrictions(min_occurs=4, max_occurs=6)

        self.sanitizer.group_fields(target, list(target.attrs))

        self.assertEqual(1, len(target.attrs))
        grouped = target.attrs[0]
        self.assertEqual(expected_res, grouped.restrictions)
예제 #8
0
    def test_field_metadata_choices(self):
        """field_metadata reports every choice as a name/type mapping.

        The "choices" entry is expected to be a tuple with one dict per
        choice attr, in order.
        """
        choices = AttrFactory.list(2, tag=Tag.ELEMENT)
        attr = AttrFactory.create(choices=choices)

        metadata = self.filters.field_metadata(attr, "foo", ["cls"])

        expected = tuple(
            {"name": choice_name, "type": "Type[str]"}
            for choice_name in ("attr_B", "attr_C")
        )
        self.assertEqual(expected, metadata["choices"])
예제 #9
0
    def test_class_references(self):
        """class_references lists the target's own id first, 13 ids in total."""
        inner_attrs = AttrFactory.list(1)
        inner_classes = ClassFactory.list(2, attrs=inner_attrs)
        extensions = ExtensionFactory.list(1)
        target = ClassFactory.elements(
            2,
            inner=inner_classes,
            extensions=extensions,
        )

        references = ClassAnalyzer.class_references(target)

        expected_total = sum(
            (
                1,  # target
                2,  # attrs
                2,  # attr types
                1,  # extension
                1,  # extension type
                2,  # inner classes
                2,  # inner classes attrs
                2,  # inner classes attr types
            )
        )
        self.assertEqual(13, expected_total)
        self.assertEqual(expected_total, len(references))
        self.assertEqual(id(target), references[0])
예제 #10
0
    def test_build_message_class(self, mock_create_message_attributes):
        """build_message_class wraps the message attrs in a processed element.

        ``mock_create_message_attributes`` is injected by a patch decorator
        outside this view and supplies the attrs of the resulting class. The
        expected class takes its qname from the definitions' target namespace
        and the message name, and reuses the message's namespace map.
        """
        message = Message(name="bar", parts=[Part()])
        message.ns_map["foo"] = "bar"
        definitions = Definitions(
            messages=[message],
            target_namespace="xsdata",
            location="foo.wsdl",
        )
        port_type_message = PortTypeMessage(message="foo:bar")

        message_attrs = AttrFactory.list(2)
        mock_create_message_attributes.return_value = message_attrs

        built = DefinitionsMapper.build_message_class(
            definitions, port_type_message)

        expected = Class(
            qname=build_qname("xsdata", "bar"),
            status=Status.PROCESSED,
            tag=Tag.ELEMENT,
            # presumably derived from the "foo.wsdl" location -- verify
            module="foo",
            ns_map=message.ns_map,
            attrs=message_attrs,
        )
        self.assertEqual(expected, built)
예제 #11
0
    def test_flatten(self):
        """flatten lazily yields three classes and tidies the attr types.

        Every yielded class must carry the requested module name. On the
        root class each attr, seeded here with a duplicated forward type,
        must end up with exactly one type whose forward flag is cleared.
        """
        target = ClassFactory.create(
            qname="{xsdata}root",
            attrs=AttrFactory.list(3),
            inner=ClassFactory.list(2),
        )

        # Duplicate every type and mark all of them as forward references.
        for attr in target.attrs:
            duplicates = [tp.clone() for tp in attr.types]
            attr.types.extend(duplicates)
            for tp in attr.types:
                tp.forward = True

        result = ClassUtils.flatten(target, "xsdata")
        flattened = list(result)

        self.assertIsInstance(result, Generator)
        self.assertEqual(3, len(flattened))

        for klass in flattened:
            self.assertEqual("xsdata", klass.module)

        for attr in target.attrs:
            self.assertEqual(1, len(attr.types))
            only_type = attr.types[0]
            self.assertFalse(only_type.forward)
예제 #12
0
def mock_create_attr(*args, **kwargs):
    """Return a fresh one-element attr list, ignoring all arguments."""
    single_attr_list = AttrFactory.list(1)
    return single_attr_list