def test_lookup_no_synonyms(self):
    """Test looking up classes without auto-synonym.

    Builds a resolver with ``synonym_attribute=None`` and checks that the class
    is still found by its (case-insensitive) name, while the synonym
    ``"a_synonym_1"`` raises :class:`KeyError`.
    """
    # NOTE(review): this module-level function duplicates
    # TestResolver.test_lookup_no_synonyms; taking ``self`` outside a TestCase
    # means unittest will not run it -- confirm whether it should be removed.
    resolver = Resolver([A], base=Base, synonym_attribute=None)
    self.assertEqual(A, resolver.lookup("a"))
    self.assertEqual(A, resolver.lookup("A"))
    # The lookup itself is what must raise; wrapping it in assertEqual was dead code.
    with self.assertRaises(KeyError):
        resolver.lookup("a_synonym_1")
class TestResolver(unittest.TestCase):
    """Tests for the resolver."""

    def setUp(self) -> None:
        """Set up the resolver class."""
        self.resolver = Resolver([A, B, C, E], base=Base)

    def test_contents(self):
        """Test that a registered class is contained in the resolver."""
        self.assertIn(A, set(self.resolver))

    def test_iterator(self):
        """Test iterating over classes."""
        self.assertEqual([A, B, C, E], list(self.resolver))

    def test_lookup(self):
        """Test looking up classes by name, synonym, and instance."""
        self.assertEqual(A, self.resolver.lookup("a"))
        self.assertEqual(A, self.resolver.lookup("A"))
        self.assertEqual(A, self.resolver.lookup("a_synonym_1"))
        self.assertEqual(A, self.resolver.lookup("a_synonym_2"))
        with self.assertRaises(ValueError):
            self.resolver.lookup(None)
        with self.assertRaises(KeyError):
            self.resolver.lookup("missing")
        with self.assertRaises(TypeError):
            self.resolver.lookup(3)
        # Looking up an instance resolves to its class.
        self.assertEqual(A, self.resolver.lookup(A(name="max")))

    def test_docdata(self):
        """Test retrieving docdata, optionally narrowed by a path of nested keys."""
        full = {
            "k1": "v1",
            "k2": {
                "k21": "v21",
            },
        }
        self.assertEqual(full, self.resolver.docdata("a"))
        self.assertEqual("v1", self.resolver.docdata("a", "k1"))
        self.assertEqual({"k21": "v21"}, self.resolver.docdata("a", "k2"))
        self.assertEqual("v21", self.resolver.docdata("a", "k2", "k21"))

    def test_lookup_no_synonyms(self):
        """Test looking up classes without auto-synonym."""
        resolver = Resolver([A], base=Base, synonym_attribute=None)
        self.assertEqual(A, resolver.lookup("a"))
        self.assertEqual(A, resolver.lookup("A"))
        # The lookup itself is what must raise; no comparison is needed.
        with self.assertRaises(KeyError):
            resolver.lookup("a_synonym_1")

    def test_passthrough(self):
        """Test instances are passed through unmodified."""
        a = A(name="charlie")
        self.assertEqual(a, self.resolver.make(a))

    def test_make(self):
        """Test making classes."""
        name = "charlie"
        # Test instantiating with positional dict into kwargs
        self.assertEqual(A(name=name), self.resolver.make("a", {"name": name}))
        # Test instantiating with kwargs
        self.assertEqual(A(name=name), self.resolver.make("a", name=name))

    def test_make_safe(self):
        """Test the make_safe function, which always returns None on None input."""
        self.assertIsNone(self.resolver.make_safe(None))
        self.assertIsNone(Resolver.from_subclasses(Base, default=A).make_safe(None))
        name = "charlie"
        # Test instantiating with positional dict into kwargs
        self.assertEqual(A(name=name), self.resolver.make_safe("a", {"name": name}))
        # Test instantiating with kwargs
        self.assertEqual(A(name=name), self.resolver.make_safe("a", name=name))

    def test_registration_synonym(self):
        """Test successful registration of a class with an extra synonym."""
        self.assertNotIn(D, self.resolver.lookup_dict.values())
        self.resolver.register(D, synonyms={"dope"})
        name = "charlie"
        self.assertEqual(D(name=name), self.resolver.make("d", name=name))
        # The explicitly registered synonym must resolve as well.
        self.assertEqual(D, self.resolver.lookup("dope"))

    def test_registration_empty_synonym_failure(self):
        """Test that registering an empty synonym fails."""
        self.assertNotIn(D, self.resolver.lookup_dict.values())
        with self.assertRaises(ValueError):
            self.resolver.register(D, synonyms={""})

    def test_registration_name_failure(self):
        """Test registration failure on conflicts with already-registered names."""
        with self.assertRaises(RegistrationNameConflict) as e:
            self.resolver.register(A)
        self.assertEqual("name", e.exception.label)
        self.assertIn("name", str(e.exception))

        with self.assertRaises(RegistrationNameConflict) as e:
            self.resolver.register(D, synonyms={"a"})
        self.assertEqual("synonym", e.exception.label)
        self.assertIn("synonym", str(e.exception))

    def test_registration_synonym_failure(self):
        """Test registration failure on conflicts with already-registered synonyms."""
        resolver = Resolver([], base=Base)
        resolver.register(A, synonyms={"B"})
        with self.assertRaises(RegistrationSynonymConflict) as e:
            resolver.register(B)
        self.assertEqual("name", e.exception.label)
        self.assertIn("name", str(e.exception))

        class F(Base):
            """Extra class for testing."""

        with self.assertRaises(RegistrationSynonymConflict) as e:
            resolver.register(F, synonyms={"B"})
        self.assertEqual("synonym", e.exception.label)
        self.assertIn("synonym", str(e.exception))

    def test_make_from_kwargs(self):
        """Test making classes from kwargs."""
        name = "charlie"
        self.assertEqual(
            A(name=name),
            self.resolver.make_from_kwargs(
                key="magic",
                data=dict(
                    ignored_entry=...,
                    magic="a",
                    magic_kwargs=dict(name=name),
                ),
            ),
        )

    @unittest.skipIf(tune is None, "ray[tune] was not installed properly")
    def test_variant_generation(self):
        """Test whether ray tune can generate variants from the search space."""
        search_space = self.resolver.ray_tune_search_space(
            kwargs_search_space=dict(name=tune.choice(["charlie", "max"])),
        )
        for spec in itertools.islice(
            tune.suggest.variant_generator.generate_variants(search_space), 2
        ):
            # presumably spec[0] maps key-path tuples to sampled values; keep the
            # first path element as the config key -- TODO confirm against ray docs
            config = {k[0]: v for k, v in spec[0].items()}
            query = config.pop("query")
            instance = self.resolver.make(query=query, pos_kwargs=config)
            self.assertIsInstance(instance, Base)

    def test_bad_click_option(self):
        """Test failure to get a click option."""
        with self.assertRaises(ValueError):
            self.resolver.get_option("--opt")  # no default given

    def test_required_click_option(self):
        """Test non-failure to get a required click option without default."""
        self.resolver.get_option("--opt", as_string=True, required=True)

    def test_click_option(self):
        """Test the click option."""

        @click.command()
        @self.resolver.get_option("--opt", default="a")
        def cli(opt):
            """Run the test CLI."""
            self.assertIsInstance(opt, type)
            click.echo(opt.__name__, nl=False)

        self._test_cli(cli)

    def _test_cli(self, cli):
        """Invoke ``cli`` with default, canonical and unnormalized ``--opt`` values.

        Each invocation must exit cleanly and echo exactly ``A.__name__``.
        """
        runner = CliRunner()
        cases = [
            ("default", []),
            ("canonical name", ["--opt", "A"]),
            ("normalizing name", ["--opt", "a"]),
        ]
        for label, args in cases:
            with self.subTest(label):
                result: Result = runner.invoke(cli, args)
                # CliRunner swallows exceptions (including failed assertions inside
                # the CLI), so surface them instead of only seeing empty output.
                self.assertEqual(0, result.exit_code, msg=repr(result.exception))
                self.assertEqual(A.__name__, result.output)

    def test_click_option_str(self):
        """Test the click option."""

        @click.command()
        @self.resolver.get_option("--opt", default="a", as_string=True)
        def cli(opt):
            """Run the test CLI."""
            self.assertIsInstance(opt, str)
            click.echo(self.resolver.lookup(opt).__name__, nl=False)

        self._test_cli(cli)

    def test_click_option_default(self):
        """Test generating an option with a default."""
        resolver = Resolver([A, B, C, E], base=Base, default=A)

        @click.command()
        @resolver.get_option("--opt", as_string=True)
        def cli(opt):
            """Run the test CLI."""
            self.assertIsInstance(opt, str)
            # Use the resolver under test, not the fixture from setUp.
            click.echo(resolver.lookup(opt).__name__, nl=False)

        self._test_cli(cli)

    def test_click_option_multiple(self):
        """Test the click option with multiple arguments."""

        @click.command()
        @self.resolver.get_option("--opt", default="a", as_string=True, multiple=True)
        def cli(opt):
            """Run the test CLI."""
            self.assertIsInstance(opt, Sequence)
            for opt_ in opt:
                self.assertIsInstance(opt_, str)
                click.echo(self.resolver.lookup(opt_).__name__, nl=False)

        self._test_cli(cli)

    def test_signature(self):
        """Check that supports_argument reports whether a class accepts a keyword."""
        self.assertTrue(self.resolver.supports_argument("A", "name"))
        self.assertFalse(self.resolver.supports_argument("A", "nope"))

    def test_no_arguments(self):
        """Check that the unexpected keyword error is thrown properly."""
        resolver = Resolver.from_subclasses(AltBase)
        with self.assertRaises(UnexpectedKeywordError) as e:
            resolver.make("A", nope="nopppeeee")
        # ``e`` is the assertRaises context; the raised error is ``e.exception``.
        self.assertEqual(
            "AAltBase did not expect any keyword arguments", str(e.exception)
        )

    def test_base_suffix(self):
        """Check how ``suffix`` and ``base_as_suffix`` affect name normalization."""
        resolver = Resolver.from_subclasses(AltBase, suffix=None, base_as_suffix=True)
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        self.assertEqual(AAltBase, resolver.lookup("A"))

        resolver = Resolver.from_subclasses(AltBase, suffix="nope", base_as_suffix=True)
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        with self.assertRaises(KeyError):
            resolver.lookup("A")

        resolver = Resolver.from_subclasses(AltBase, suffix="")
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        with self.assertRaises(KeyError):
            resolver.lookup("A")

        resolver = Resolver.from_subclasses(AltBase, base_as_suffix=False)
        self.assertEqual(AAltBase, resolver.lookup("AAltBase"))
        with self.assertRaises(KeyError):
            resolver.lookup("A")

    def test_make_many(self):
        """Test the make_many function."""
        with self.assertRaises(ValueError):
            # no default is given
            self.resolver.make_many(None)
        with self.assertRaises(ValueError):
            # wrong number of kwargs is given
            self.resolver.make_many([], [{}, {}])
        with self.assertRaises(ValueError):
            # wrong number of kwargs is given
            self.resolver.make_many(["a", "a", "a"], [{}, {}])

        # One class, one kwarg
        instances = self.resolver.make_many("a", dict(name="name"))
        self.assertEqual([A(name="name")], instances)
        instances = self.resolver.make_many("a", [dict(name="name")])
        self.assertEqual([A(name="name")], instances)
        instances = self.resolver.make_many(["a"], dict(name="name"))
        self.assertEqual([A(name="name")], instances)
        instances = self.resolver.make_many(["a"], [dict(name="name")])
        self.assertEqual([A(name="name")], instances)

        # Single class, multiple kwargs
        instances = self.resolver.make_many(
            "a", [dict(name="name1"), dict(name="name2")]
        )
        self.assertEqual([A(name="name1"), A(name="name2")], instances)
        instances = self.resolver.make_many(
            ["a"], [dict(name="name1"), dict(name="name2")]
        )
        self.assertEqual([A(name="name1"), A(name="name2")], instances)

        # Multiple classes, one kwarg
        instances = self.resolver.make_many(["a", "b", "c"], dict(name="name"))
        self.assertEqual([A(name="name"), B(name="name"), C(name="name")], instances)
        instances = self.resolver.make_many(["a", "b", "c"], [dict(name="name")])
        self.assertEqual([A(name="name"), B(name="name"), C(name="name")], instances)

        # Multiple classes, multiple kwargs
        instances = self.resolver.make_many(
            ["a", "b", "c"],
            [dict(name="name1"), dict(name="name2"), dict(name="name3")],
        )
        self.assertEqual(
            [A(name="name1"), B(name="name2"), C(name="name3")], instances
        )

        # One class, no kwargs
        instances = self.resolver.make_many("e")
        self.assertEqual([E()], instances)
        instances = self.resolver.make_many(["e"])
        self.assertEqual([E()], instances)
        instances = self.resolver.make_many("e", None)
        self.assertEqual([E()], instances)
        instances = self.resolver.make_many(["e"], None)
        self.assertEqual([E()], instances)
        instances = self.resolver.make_many(["e"], [None])
        self.assertEqual([E()], instances)

        # No class: falls back to the resolver's default
        resolver = Resolver.from_subclasses(Base, default=A)
        instances = resolver.make_many(None, dict(name="name"))
        self.assertEqual([A(name="name")], instances)