예제 #1
0
 def setUp(self) -> None:
     """Create an empty registry and the bindings the registry tests register.

     ``my_type_binding``, ``my_type_named_binding`` and ``other_type_binding``
     wrap autospec'd ``Binding`` mocks whose only configured attribute is
     ``target``; ``my_type_binding_2`` wraps a real ``InstanceBinding`` for
     ``MyType``.
     """

     def registered_mock(target):
         # spec_set=True makes the mock refuse attributes that Binding does
         # not define, so only real Binding attributes can be configured.
         fake = create_autospec(Binding, spec_set=True)
         fake.target = target
         return RegisteredBinding(fake)

     self.binding_registry = BindingRegistry()
     self.my_type_binding = registered_mock(FrozenTarget(MyType))
     self.my_type_named_binding = registered_mock(FrozenTarget(MyType, "my_name"))
     self.my_type_binding_2 = RegisteredBinding(InstanceBinding(MyType, MyType()))
     self.other_type_binding = registered_mock(FrozenTarget(OtherType))
 def setUp(self):
     """Build a registry holding scope bindings and a ``List[MyType]`` context."""
     registry = BindingRegistry()
     self.binding_registry = registry
     # Each scope class the tests rely on is bound to an instance of itself.
     for scope_cls in (SingletonScope, ThreadScope):
         registry.register(RegisteredBinding(InstanceBinding(scope_cls, scope_cls())))
     self.adapter = MultiBindingToProviderAdapter(FromRegisteredBindingProviderFactory())
     self.my_instance = MyType()
     self.state = InjectionState(ProviderCreator(), registry)
     self.context = InjectionContext(Target(List[MyType]), self.state)
예제 #3
0
 def setUp(self) -> None:
     """Wire a ``ProviderCreator`` to a fresh registry and prepare fixtures.

     Only the ``SingletonScope`` binding is registered here; the instance
     bindings created below are left for individual tests to register.
     """
     registry = BindingRegistry()
     self.binding_registry = registry
     registry.register(RegisteredBinding(InstanceBinding(SingletonScope, SingletonScope())))

     creator = ProviderCreator()
     self.provider_creator = creator
     state = InjectionState(creator, registry)
     self.state = state

     def context_for(*target_args):
         # Every context shares the same injection state.
         return InjectionContext(Target(*target_args), state)

     self.context = context_for(MyType)
     self.other_context = context_for(MyOtherType)
     self.named_context = context_for(MyType, "my_name")

     # Separate MyType() objects so tests can tell the unnamed and the
     # "my_name" bindings apart by identity.
     self.my_instance = MyType()
     self.my_instance_binding = InstanceBinding(MyType, self.my_instance)
     self.named_instance = MyType()
     self.my_named_instance_binding = InstanceBinding(MyType, self.named_instance, named="my_name")
     self.my_other_instance = MyOtherType()
     self.my_other_instance_binding = InstanceBinding(MyOtherType, self.my_other_instance)
예제 #4
0
class TestMultiBindingToProviderAdapter(unittest.TestCase):
    """Tests for ``MultiBindingToProviderAdapter.create``.

    Each test hands the adapter a ``RegisteredMultiBinding`` together with a
    ``List[MyType]`` injection context and inspects the provider it returns.
    """

    def setUp(self):
        self.binding_registry = BindingRegistry()
        # Bind each scope class used by these tests to an instance of itself.
        # NOTE(review): presumably scoped providers obtain their scope object
        # through this registry -- confirm against the adapter implementation.
        self.binding_registry.register(
            RegisteredBinding(InstanceBinding(SingletonScope,
                                              SingletonScope())))
        self.binding_registry.register(
            RegisteredBinding(InstanceBinding(ThreadScope, ThreadScope())))
        self.adapter = MultiBindingToProviderAdapter(
            FromRegisteredBindingProviderFactory())
        self.my_instance = MyType()
        self.state = InjectionState(
            ProviderCreator(),
            self.binding_registry,
        )
        self.context = InjectionContext(Target(List[MyType]), self.state)

    def test_create_from_instance_binding(self):
        """An instance item yields a list containing exactly that instance."""
        binding = RegisteredMultiBinding(
            MultiBinding(MyType,
                         [ItemBinding(bound_instance=self.my_instance)]),
            item_bindings=[
                RegisteredBinding(InstanceBinding(MyType, self.my_instance))
            ])
        list_provider = self.adapter.create(binding, self.context)

        list_instance = list_provider.get()
        self.assertEqual([self.my_instance], list_instance)

    def test_create_from_class_binding(self):
        """A class item yields a one-element list holding a MyType instance."""
        binding = RegisteredMultiBinding(
            MultiBinding(MyType, [ItemBinding(bound_class=MyType)]),
            item_bindings=[RegisteredBinding(SelfBinding(MyType))])

        list_provider = self.adapter.create(binding, self.context)

        list_instance = list_provider.get()
        self.assertEqual(1, len(list_instance))
        self.assertIsInstance(list_instance[0], MyType)

    def test_create_from_provider_binding(self):
        """A provider item yields a list holding what that provider returns."""
        mock_provider = create_autospec(Provider, spec_set=True)
        instance = MyType()
        mock_provider.get.return_value = instance
        binding = RegisteredMultiBinding(
            MultiBinding(MyType, [ItemBinding(bound_provider=mock_provider)]),
            item_bindings=[
                RegisteredBinding(ProviderBinding(MyType, mock_provider))
            ])

        # Kept distinct from ``mock_provider``: this is the list provider
        # under test, not the item provider it wraps.
        list_provider = self.adapter.create(binding, self.context)

        list_instance = list_provider.get()
        self.assertEqual([instance], list_instance)

    def test_create_scoped_provider(self):
        """A ThreadScope multi-binding yields a ThreadScopedProvider of a list."""
        list_provider = self.adapter.create(
            RegisteredMultiBinding(
                MultiBinding(MyType, [ItemBinding(bound_class=MyType)],
                             scope=ThreadScope),
                item_bindings=[
                    RegisteredBinding(SelfBinding(MyType, scope=ThreadScope))
                ]), self.context)

        self.assertIsInstance(list_provider, ThreadScopedProvider)
        list_instance = list_provider.get()
        self.assertIsInstance(list_instance, list)
        self.assertEqual(1, len(list_instance))
        self.assertIsInstance(list_instance[0], MyType)

    def test_non_injectable_scope_raises_exception(self):
        """A multi-binding declaring ImmediateScope is rejected.

        ImmediateScope is not among the scopes registered in setUp.
        """
        with self.assertRaises(NonInjectableTypeError):
            self.adapter.create(
                RegisteredMultiBinding(
                    MultiBinding(MyType, [], scope=ImmediateScope)),
                self.context)
예제 #5
0
class TestProviderCreator(unittest.TestCase):
    """Tests for ``ProviderCreator.get_provider`` backed by a real ``BindingRegistry``."""

    def setUp(self) -> None:
        self.binding_registry = BindingRegistry()
        # Only the SingletonScope binding is registered up front; each test
        # registers the bindings it needs.
        self.binding_registry.register(RegisteredBinding(InstanceBinding(SingletonScope, SingletonScope())))
        self.provider_creator = ProviderCreator()
        self.state = InjectionState(
            self.provider_creator,
            self.binding_registry,
        )
        self.context = InjectionContext(Target(MyType), self.state)
        self.other_context = InjectionContext(Target(MyOtherType), self.state)
        self.named_context = InjectionContext(Target(MyType, "my_name"), self.state)
        self.my_instance = MyType()
        self.my_instance_binding = InstanceBinding(MyType, self.my_instance)
        self.named_instance = MyType()
        self.my_named_instance_binding = InstanceBinding(MyType, self.named_instance, named="my_name")
        self.my_other_instance = MyOtherType()
        self.my_other_instance_binding = InstanceBinding(MyOtherType, self.my_other_instance)

    def _register_my_type_multi_binding(self) -> None:
        """Register a MyType multi-binding: ``self.my_instance`` then a class-bound MyType.

        Shared by the List/Set/Tuple collection tests. The leading underscore
        keeps unittest from collecting it as a test.
        """
        self.binding_registry.register(
            RegisteredMultiBinding(
                MultiBinding(
                    MyType,
                    [
                        ItemBinding(bound_instance=self.my_instance),
                        ItemBinding(bound_class=MyType),
                    ],
                ),
                item_bindings=[
                    RegisteredBinding(InstanceBinding(MyType, self.my_instance)),
                    RegisteredBinding(SelfBinding(MyType)),
                ]
            )
        )

    def test_get_provider_with_instance_bindings(self):
        """Each target resolves to a FromInstanceProvider of its own type."""
        self.binding_registry.register(RegisteredBinding(self.my_instance_binding))
        self.binding_registry.register(RegisteredBinding(self.my_other_instance_binding))

        provider = self.provider_creator.get_provider(self.context)
        self.assertIsInstance(provider, FromInstanceProvider)
        instance = provider.get()
        self.assertIsInstance(instance, MyType)

        provider = self.provider_creator.get_provider(self.other_context)
        self.assertIsInstance(provider, FromInstanceProvider)
        instance = provider.get()
        self.assertIsInstance(instance, MyOtherType)

    def test_get_provider_caches_providers(self):
        """Repeated lookups for the same context return the same provider object."""
        self.binding_registry.register(RegisteredBinding(self.my_instance_binding))

        provider_1 = self.provider_creator.get_provider(self.context)
        provider_2 = self.provider_creator.get_provider(self.context)
        self.assertIs(provider_1, provider_2)

    def test_get_provider_with_named_bindings(self):
        """A named binding satisfies only the named target, not the unnamed one."""
        self.binding_registry.register(RegisteredBinding(self.my_named_instance_binding))

        with self.assertRaises(NoBindingFound):
            self.provider_creator.get_provider(self.context)

        provider = self.provider_creator.get_provider(self.named_context)
        self.assertIsInstance(provider, FromInstanceProvider)
        instance = provider.get()
        self.assertIs(self.named_instance, instance)

    def test_missing_binding_raises_exception(self):
        """A class whose constructor needs an unbound MyType is not injectable."""
        class MyParentClass:
            def __init__(self, my_param: MyType):
                self.my_param = my_param

        my_parent_binding = SelfBinding(MyParentClass)
        context = InjectionContext(Target(MyParentClass), self.state)
        self.binding_registry.register(RegisteredBinding(my_parent_binding))

        with self.assertRaises(NonInjectableTypeError):
            self.provider_creator.get_provider(context)

    def test_list_binding_with_multi_binding(self):
        """List[MyType] yields the multi-binding items in registration order."""
        self._register_my_type_multi_binding()
        context = InjectionContext(Target(List[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        list_instance = provider.get()
        self.assertEqual([self.my_instance, ANY], list_instance)
        self.assertIsInstance(list_instance[1], MyType)

    def test_list_binding_with_named_arguments(self):
        """A named List target picks the named multi-binding, ignoring unnamed ones."""
        self.binding_registry.register(
            RegisteredMultiBinding(
                MultiBinding(
                    MyType,
                    [
                        ItemBinding(bound_instance=self.named_instance),
                    ],
                    named="my_name",
                ),
                item_bindings=[
                    RegisteredBinding(InstanceBinding(MyType, self.named_instance, named="my_name"))
                ]
            )
        )
        self.binding_registry.register(
            RegisteredMultiBinding(
                MultiBinding(
                    MyType,
                    [
                        ItemBinding(bound_instance=self.my_instance),
                    ],
                ),
                item_bindings=[
                    RegisteredBinding(InstanceBinding(MyType, self.my_instance))
                ]
            )
        )
        self.binding_registry.register(RegisteredBinding(self.my_named_instance_binding))
        self.binding_registry.register(RegisteredBinding(SelfBinding(MyType)))

        context = InjectionContext(Target(List[MyType], "my_name"), self.state)
        provider = self.provider_creator.get_provider(context)
        list_instance = provider.get()
        self.assertEqual([self.named_instance], list_instance)

    def test_set_binding_with_multi_binding(self):
        """Set[MyType] is built from the multi-binding items."""
        self._register_my_type_multi_binding()
        context = InjectionContext(Target(Set[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        self.assertIsInstance(provider, FromClassProvider)
        set_instance = provider.get()
        self.assertIn(self.my_instance, set_instance)
        self.assertEqual(2, len(set_instance))

    def test_tuple_binding_with_multi_binding(self):
        """Tuple[MyType] is built from the multi-binding items, in order."""
        self._register_my_type_multi_binding()
        context = InjectionContext(Target(Tuple[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        self.assertIsInstance(provider, FromClassProvider)
        tuple_instance = provider.get()
        self.assertEqual((self.my_instance, ANY), tuple_instance)
        self.assertIsInstance(tuple_instance[1], MyType)

    def test_optional_binding(self):
        """Optional[MyType] resolves through the plain MyType binding."""
        self.binding_registry.register(RegisteredBinding(self.my_instance_binding))

        context = InjectionContext(Target(Optional[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        self.assertIsInstance(provider, FromInstanceProvider)
        instance = provider.get()
        self.assertIs(self.my_instance, instance)

    def test_type_binding(self):
        """Type[MyType] yields the class bound by the latest ClassBinding."""
        class SubType(MyType):
            pass

        self.binding_registry.register(RegisteredBinding(self.my_instance_binding))
        self.binding_registry.register(RegisteredBinding(ClassBinding(MyType, SubType)))

        context = InjectionContext(Target(Type[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        self.assertIsInstance(provider, FromInstanceProvider)
        instance = provider.get()
        self.assertIs(SubType, instance)

    def test_type_binding_with_explicit_binding(self):
        """An explicit Type[MyType] binding takes precedence over the ClassBinding."""
        class SubType(MyType):
            pass

        self.binding_registry.register(RegisteredBinding(InstanceBinding(Type[MyType], MyType)))
        self.binding_registry.register(RegisteredBinding(ClassBinding(MyType, SubType)))

        context = InjectionContext(Target(Type[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        self.assertIsInstance(provider, FromInstanceProvider)
        instance = provider.get()
        self.assertIs(MyType, instance)

    def test_type_binding_without_class_binding(self):
        """With only an instance binding for MyType, Type[MyType] is not injectable."""
        class MyParentClass:
            def __init__(self, my_param: Type[MyType]):
                self.my_param = my_param

        parent_binding = SelfBinding(MyParentClass)
        self.binding_registry.register(RegisteredBinding(self.my_instance_binding))
        self.binding_registry.register(RegisteredBinding(parent_binding))
        # NOTE(review): MyParentClass is registered but the context targets
        # Type[MyType] directly, so the parent binding is likely not exercised;
        # confirm whether Target(MyParentClass) was intended.
        context = InjectionContext(Target(Type[MyType]), self.state)

        with self.assertRaises(NonInjectableTypeError):
            self.provider_creator.get_provider(context)

    def test_provider_binding(self):
        """A provider class with an injectable dependency yields the injectee."""
        class MyInjectee:
            pass

        class MyProvider(Provider[MyInjectee]):
            def __init__(self, my_param: MyType):
                self.my_param = my_param

            def get(self) -> MyInjectee:
                return MyInjectee()

        provider_binding = ProviderBinding(MyInjectee, MyProvider)
        self.binding_registry.register(RegisteredBinding(self.my_instance_binding))
        self.binding_registry.register(RegisteredBinding(provider_binding))
        context = InjectionContext(Target(MyInjectee), self.state)

        provider = self.provider_creator.get_provider(context)
        # Previously the test only checked that get_provider did not raise;
        # also check that the provider hands out MyProvider.get()'s product.
        self.assertIsInstance(provider.get(), MyInjectee)

    def test_list_implicit_binding(self):
        """Without a multi-binding, List[MyType] wraps the single MyType binding."""
        instance = MyType()
        self.binding_registry.register(RegisteredBinding(InstanceBinding(MyType, instance)))
        context = InjectionContext(Target(List[MyType]), self.state)
        provider = self.provider_creator.get_provider(context)
        self.assertIsInstance(provider, ListProvider)
        list_instance = provider.get()
        self.assertEqual([instance], list_instance)
예제 #6
0
class TestBindingRegistry(unittest.TestCase):
    """Tests for ``BindingRegistry`` registration and lookup."""

    def setUp(self) -> None:
        self.binding_registry = BindingRegistry()
        self.my_type_binding = self._registered_mock_binding(
            FrozenTarget(MyType))
        self.my_type_named_binding = self._registered_mock_binding(
            FrozenTarget(MyType, "my_name"))
        # A real (non-mock) binding for the same unnamed MyType target; used
        # to check that a later registration replaces an earlier one.
        self.my_type_binding_2 = RegisteredBinding(
            InstanceBinding(MyType, MyType()))
        self.other_type_binding = self._registered_mock_binding(
            FrozenTarget(OtherType))

    @staticmethod
    def _registered_mock_binding(target):
        """Wrap an autospec'd ``Binding`` mock whose ``target`` is *target*."""
        # spec_set=True makes the mock refuse attributes Binding does not define.
        mock_binding = create_autospec(Binding, spec_set=True)
        mock_binding.target = target
        return RegisteredBinding(mock_binding)

    def test_register_saves_binding_to_new_type(self):
        """Each distinct target, including a named one, gets its own entry."""
        self.binding_registry.register(self.my_type_binding)
        self.binding_registry.register(self.my_type_named_binding)
        self.binding_registry.register(self.other_type_binding)
        self.assertEqual(
            {
                FrozenTarget(MyType): self.my_type_binding,
                FrozenTarget(MyType, "my_name"): self.my_type_named_binding,
                FrozenTarget(OtherType): self.other_type_binding,
            }, self.binding_registry.get_bindings_by_target())

    def test_register_multi_binding_saves_binding_to_known_type_in_order(self):
        """A non-overriding multi-binding appends its items after existing ones."""
        item_binding_1 = ItemBinding(MyType)
        registered_item_binding_1 = RegisteredBinding(SelfBinding(MyType))
        item_binding_2 = ItemBinding(bound_instance=MyType())
        registered_item_binding_2 = RegisteredBinding(
            InstanceBinding(MyType, item_binding_2.bound_instance))
        binding_1 = RegisteredMultiBinding(
            MultiBinding(MyType, [item_binding_1]),
            item_bindings=[registered_item_binding_1],
        )
        binding_2 = RegisteredMultiBinding(
            MultiBinding(MyType, [item_binding_2], override_bindings=False),
            item_bindings=[registered_item_binding_2],
        )
        self.binding_registry.register(binding_1)
        self.binding_registry.register(binding_2)
        # Multi-bindings for MyType are looked up through the List[MyType] target.
        registered_binding = self.binding_registry.get_binding(
            Target(List[MyType]))
        self.assertIsInstance(registered_binding, RegisteredMultiBinding)
        self.assertIsInstance(registered_binding.raw_binding, MultiBinding)
        self.assertEqual(
            [registered_item_binding_1, registered_item_binding_2],
            registered_binding.item_bindings)

    def test_register_multi_binding_with_override(self):
        """override_bindings=True replaces the previously registered items."""
        item_binding_1 = ItemBinding(MyType)
        item_binding_2 = ItemBinding(bound_instance=MyType())
        binding_1 = RegisteredMultiBinding(
            MultiBinding(MyType, [item_binding_1]))
        binding_2 = RegisteredMultiBinding(
            MultiBinding(MyType, [item_binding_2], override_bindings=True))
        self.binding_registry.register(binding_1)
        self.binding_registry.register(binding_2)
        binding = self.binding_registry.get_binding(
            Target(List[MyType])).raw_binding
        self.assertIsInstance(binding, MultiBinding)
        self.assertEqual([item_binding_2], binding.item_bindings)

    def test_get_binding_returns_binding(self):
        """The latest binding for the unnamed target wins over earlier ones."""
        self.binding_registry.register(self.my_type_binding)
        self.binding_registry.register(self.my_type_named_binding)
        self.binding_registry.register(self.my_type_binding_2)
        binding = self.binding_registry.get_binding(Target(MyType))

        self.assertEqual(self.my_type_binding_2, binding)

    def test_get_binding_returns_named_binding(self):
        """A named target resolves to the named binding."""
        self.binding_registry.register(self.my_type_binding)
        self.binding_registry.register(self.my_type_named_binding)
        self.binding_registry.register(self.my_type_binding_2)
        binding = self.binding_registry.get_binding(Target(MyType, "my_name"))

        self.assertEqual(self.my_type_named_binding, binding)

    def test_get_binding_for_unknown_type_returns_none(self):
        """Looking up an unregistered type returns None instead of raising."""
        binding = self.binding_registry.get_binding(Target(MyType))

        self.assertIsNone(binding)

    def test_get_binding_from_string(self):
        """A string target resolves to the binding of the class with that name."""
        self.binding_registry.register(self.my_type_binding)
        self.binding_registry.register(self.other_type_binding)
        binding = self.binding_registry.get_binding(Target("MyType"))

        self.assertEqual(self.my_type_binding, binding)

    def test_get_named_binding_from_string(self):
        """A string target with a name resolves to the matching named binding."""
        self.binding_registry.register(self.my_type_binding)
        self.binding_registry.register(self.my_type_named_binding)
        self.binding_registry.register(self.other_type_binding)
        binding = self.binding_registry.get_binding(
            Target("MyType", "my_name"))

        self.assertEqual(self.my_type_named_binding, binding)

    def test_get_binding_from_unknown_string(self):
        """An unknown class name yields None."""
        binding = self.binding_registry.get_binding(Target("MyUnknownType"))
        self.assertIsNone(binding)

    def test_get_binding_from_string_with_name_conflict_raises_exception(self):
        """Two distinct classes sharing a name make the string lookup ambiguous."""
        class MyNewType:
            pass

        binding_1 = create_autospec(Binding, spec_set=True)
        binding_1.target = FrozenTarget(MyNewType)

        # Redefining the class creates a new, distinct type object that has
        # the same __name__ as the first one.
        # pylint: disable=function-redefined
        class MyNewType:
            pass

        binding_2 = create_autospec(Binding, spec_set=True)
        binding_2.target = FrozenTarget(MyNewType)

        self.binding_registry.register(RegisteredBinding(binding_1))
        self.binding_registry.register(RegisteredBinding(binding_2))

        with self.assertRaises(NonInjectableTypeError):
            self.binding_registry.get_binding(Target("MyNewType"))

    def test_register_provider_binding_with_instance_creates_additional_binding(
            self):
        """A provider instance is also registered as a named InstanceBinding."""
        class MyProvider(Provider[str]):
            def get(self) -> str:
                return "hello"

        provider_instance = MyProvider()
        provider_binding = RegisteredBinding(
            ProviderBinding(str, provider_instance, named="my_name"))
        self.binding_registry.register(provider_binding)

        self.assertEqual(
            {
                FrozenTarget(str, "my_name"):
                provider_binding,
                FrozenTarget(MyProvider, "my_name"):
                RegisteredBinding(
                    InstanceBinding(MyProvider, provider_instance, "my_name")),
            }, self.binding_registry.get_bindings_by_target())

    def test_register_provider_binding_with_class_creates_self_binding(self):
        """A provider class is also self-bound, keeping its scope and name."""
        class MyProvider(Provider[str]):
            def get(self) -> str:
                return "hello"

        provider_binding = RegisteredBinding(
            ProviderBinding(str, MyProvider, PerLookupScope, "my_name"))
        self.binding_registry.register(provider_binding)

        self.assertEqual(
            provider_binding,
            self.binding_registry.get_binding(Target(str, "my_name")))
        # The auto-created binding for the provider class itself; named apart
        # from ``provider_binding`` because it is a different binding.
        self_binding = self.binding_registry.get_binding(
            Target(MyProvider, "my_name"))
        self.assertIsInstance(self_binding.raw_binding, SelfBinding)
        self.assertEqual(MyProvider, self_binding.raw_binding.target_type)
        self.assertEqual(PerLookupScope, self_binding.raw_binding.scope)

    def test_register_multi_binding_with_provider_binding_creates_self_binding(
            self):
        """Provider-class items inside a multi-binding are self-bound too."""
        class MyProvider(Provider[str]):
            def get(self) -> str:
                return "hello"

        provider_binding = RegisteredBinding(ProviderBinding(str, MyProvider))
        multi_binding = RegisteredMultiBinding(
            MultiBinding(str, [provider_binding.raw_binding]),
            item_bindings=[provider_binding])
        self.binding_registry.register(multi_binding)

        self_binding = self.binding_registry.get_binding(Target(MyProvider))
        self.assertIsInstance(self_binding.raw_binding, SelfBinding)
        self.assertEqual(MyProvider, self_binding.raw_binding.target_type)

    def test_register_class_binding_creates_self_binding_if_target_does_not_exist(
            self):
        """Binding MyType to a subclass self-binds the subclass when it is unbound."""
        class MySubType(MyType):
            pass

        class_binding = RegisteredBinding(ClassBinding(MyType, MySubType))
        self.binding_registry.register(class_binding)
        self.assertIs(class_binding,
                      self.binding_registry.get_binding(Target(MyType)))
        self_binding = self.binding_registry.get_binding(Target(MySubType))
        self.assertIsInstance(self_binding.raw_binding, SelfBinding)
        self.assertEqual(MySubType, self_binding.raw_binding.target_type)

    def test_register_class_binding_does_not_create_self_binding_if_target_exists(
            self):
        """An existing binding for the subclass is kept rather than replaced."""
        class MySubType(MyType):
            pass

        my_instance = MySubType()
        class_binding = RegisteredBinding(ClassBinding(MyType, MySubType))
        instance_binding = RegisteredBinding(
            InstanceBinding(MySubType, my_instance))
        self.binding_registry.register(instance_binding)
        self.binding_registry.register(class_binding)
        self.assertIs(class_binding,
                      self.binding_registry.get_binding(Target(MyType)))
        self.assertIs(instance_binding,
                      self.binding_registry.get_binding(Target(MySubType)))