def test_base(self):
    """Exercise TemplateBase registration and lookup.

    Covers single- and multi-type parameters, duplicate and unknown
    parameter errors, default lookup, one instantiation shared by several
    parameter sets, list/tuple indices, one-shot bulk registration via
    add_instantiations, and the TypeError raised when pickling a template.
    """
    tpl = m.TemplateBase("BaseTpl")
    expected_repr = "<TemplateBase {}.BaseTpl>".format(_TEST_MODULE)
    self.assertEqual(str(tpl), expected_repr)

    # A single type parameter maps to exactly one instantiation.
    tpl.add_instantiation(int, 1)
    self.assertEqual(tpl[int], 1)
    self.assertEqual(tpl.get_instantiation(int), (1, (int, )))
    self.assertEqual(tpl.get_param_set(1), {(int, )})
    self.assertTrue(tpl.is_instantiation(1))
    self.assertFalse(tpl.is_instantiation(10))

    # Registering a parameter that is already taken is rejected.
    with self.assertRaises(RuntimeError):
        tpl.add_instantiation(int, 4)

    # Looking up a parameter that was never registered is rejected.
    with self.assertRaises(RuntimeError):
        _ = tpl[float]

    # A distinct parameter gets its own instantiation.
    tpl.add_instantiation(float, 2)
    self.assertEqual(tpl[float], 2)

    # Lookup with None / no parameters resolves to the int instantiation
    # (presumably because it was the first one registered).
    self.assertEqual(tpl[None], 1)
    self.assertEqual(tpl.get_instantiation(), (1, (int, )))

    # Parameter tuples containing several types.
    tpl.add_instantiation((int, int), 3)
    self.assertEqual(tpl[int, int], 3)

    # The same instantiation may be registered under several parameter
    # sets; get_param_set reports all of them.
    tpl.add_instantiation((float, float), 1)
    self.assertEqual(tpl.get_param_set(1), {(int, ), (float, float)})

    # An explicit tuple or a list works as the index too.
    self.assertEqual(tpl[(int, int)], 3)
    self.assertEqual(tpl[[int, int]], 3)

    # Bulk registration: the factory is called with each parameter tuple.
    def make_instantiation(params):
        return 100 + len(params)

    five_strs = (str, ) * 5
    ten_strs = (str, ) * 10
    tpl.add_instantiations(make_instantiation, [five_strs, ten_strs])
    self.assertEqual(tpl[five_strs], 105)
    self.assertEqual(tpl[ten_strs], 110)

    # add_instantiations may only be invoked once per template.
    seven_strs = (str, ) * 7
    with self.assertRaises(RuntimeError):
        tpl.add_instantiations(make_instantiation, [seven_strs])

    # Pickling fails because a module object is involved; CPython changed
    # the wording of this TypeError in 3.8.
    pickle_error = (
        "cannot pickle 'module' object"
        if sys.version_info[:2] >= (3, 8)
        else "can't pickle module objects")
    with self.assertRaises(TypeError) as raised:
        assert_pickle(self, tpl)
    self.assertIn(pickle_error, str(raised.exception))
def test_deprecation(self):
    """Exercise TemplateBase.deprecate_instantiation.

    Deprecating the int instantiation returns the (instantiation, param)
    pair. Looking it up then emits exactly one warning carrying the given
    message, and other instantiations are unaffected. Deprecating the same
    parameter twice raises RuntimeError.
    """
    tpl = m.TemplateBase("BaseTpl")
    for param_type, value in ((int, 1), (float, 2)):
        tpl.add_instantiation(param_type, value)

    deprecated, params = tpl.deprecate_instantiation(
        int, "Example deprecation")
    self.assertEqual(deprecated, 1)
    self.assertEqual(params, (int, ))

    # Looking up the deprecated instantiation warns exactly once.
    with catch_drake_warnings(expected_count=1) as caught:
        self.assertEqual(tpl[int], 1)
    self.assertEqual(str(caught[0].message), "Example deprecation")

    # Lookups of non-deprecated parameters stay silent.
    self.assertEqual(tpl[float], 2)

    # A second deprecation of the same parameter is an error.
    with self.assertRaises(RuntimeError) as raised:
        tpl.deprecate_instantiation(int, "Double-deprecate")
    self.assertEqual(
        str(raised.exception),
        "Deprecation already registered: BaseTpl[int]")