def test_default_lookup(self):
    """Check lookups on a resolver that was built with a default function.

    NOTE(review): this module-level function takes ``self`` and repeats the
    checks of ``TestFunctionResolver.test_default_lookup``; it looks like a
    stray copy -- confirm whether it should be removed.
    """
    with_default = FunctionResolver([add_one, add_two, add_y], default=add_two)
    lookup = with_default.lookup
    # The same function is expected for the lower- and upper-case spelling.
    for name in ("add_one", "ADD_ONE"):
        self.assertEqual(add_one, lookup(name))
    # A ``None`` query is expected to yield the configured default.
    self.assertEqual(add_two, lookup(None))
    self.assertRaises(KeyError, lookup, "missing")
    self.assertRaises(TypeError, lookup, 3)
class TestFunctionResolver(unittest.TestCase):
    """Test lookup, construction, and registration with the function resolver."""

    def setUp(self) -> None:
        """Build a resolver without a default before each test."""
        self.resolver = FunctionResolver([add_one, add_two, add_y])

    def test_contents(self):
        """Test that iterating the resolver yields a registered function."""
        registered = set(self.resolver)
        self.assertIn(add_one, registered)

    def test_lookup(self):
        """Test looking up functions on a resolver that has no default."""
        lookup = self.resolver.lookup
        # The same function is expected for the lower- and upper-case spelling.
        for name in ("add_one", "ADD_ONE"):
            self.assertEqual(add_one, lookup(name))
        # With no default configured, a ``None`` query is expected to fail.
        self.assertRaises(ValueError, lookup, None)
        self.assertRaises(KeyError, lookup, "missing")
        self.assertRaises(TypeError, lookup, 3)

    def test_default_lookup(self):
        """Test looking up functions on a resolver that has a default."""
        with_default = FunctionResolver([add_one, add_two, add_y], default=add_two)
        lookup = with_default.lookup
        for name in ("add_one", "ADD_ONE"):
            self.assertEqual(add_one, lookup(name))
        # A ``None`` query is expected to yield the configured default.
        self.assertEqual(add_two, lookup(None))
        self.assertRaises(KeyError, lookup, "missing")
        self.assertRaises(TypeError, lookup, 3)

    def test_make(self):
        """Test that making ``add_y`` with ``y=1`` behaves like ``add_one``."""
        for x in range(10):
            # Keyword arguments given as a positional dictionary
            from_dict = self.resolver.make("add_y", {"y": 1})
            self.assertEqual(add_one(x), from_dict(x))
            # Keyword arguments given directly
            from_kwargs = self.resolver.make("add_y", y=1)
            self.assertEqual(add_one(x), from_kwargs(x))

    def test_make_safe(self):
        """Test that ``make_safe`` gives ``None`` for ``None``, with or without a default."""
        with_default = FunctionResolver([add_one, add_two], default=add_two)
        for resolver in (self.resolver, with_default):
            self.assertIsNone(resolver.make_safe(None))

    def test_passthrough(self):
        """Test that a function handed to ``make`` still computes the same values."""
        for x in range(10):
            passed = self.resolver.make(add_one)
            self.assertEqual(add_one(x), passed(x))

    def test_registration_synonym(self):
        """Test that a function registered with a synonym can be made via that synonym."""
        self.resolver.register(add_three, synonyms={"add_trio"})
        for x in range(10):
            via_synonym = self.resolver.make("add_trio")
            self.assertEqual(add_three(x), via_synonym(x))

    def test_registration_failure(self):
        """Test that conflicting registrations raise a ``KeyError``."""
        register = self.resolver.register
        # ``add_one`` was already registered in ``setUp``.
        self.assertRaises(KeyError, register, add_one)

        def _new_fn(x: int) -> int:
            return x + 1

        # Presumably the synonym, not the function's own name, causes the conflict here.
        self.assertRaises(KeyError, register, _new_fn, synonyms={"add_one"})

    def test_entrypoints(self):
        """Test building a resolver from the ``class_resolver_demo`` entrypoint group."""
        from_group = FunctionResolver.from_entrypoint("class_resolver_demo")
        loaded_names = set(from_group.lookup_dict)
        self.assertEqual({"add", "sub", "mul"}, loaded_names)
        loaded_synonyms = set(from_group.synonyms)
        self.assertEqual(set(), loaded_synonyms)
        self.assertNotIn("expected_failure", from_group.lookup_dict)

    def test_late_entrypoints(self):
        """Test registering an entrypoint group on an already-built resolver."""
        late = FunctionResolver([operator.add, operator.sub])
        names_before = set(late.lookup_dict)
        self.assertEqual({"add", "sub"}, names_before)
        late.register_entrypoint("class_resolver_demo")
        names_after = set(late.lookup_dict)
        self.assertEqual({"add", "sub", "mul"}, names_after)
        self.assertEqual(set(), set(late.synonyms))
        self.assertNotIn("expected_failure", late.lookup_dict)