def _Params(cls):
    """Returns the Params used to configure the generated InputGenerator.

    The signature of the wrapped callable `func` is inspected and a field is
    defined for each of its parameters inside a nested `args` group.
    Parameters listed in `ignore_args`, as well as those that appear as keys of
    `map_args`, are not turned into fields.

    Returns:
      An `InstantiableParams` object representing the InputGenerator. Its
      `args` field holds the (non-ignored) parameters of `func`.
    """
    # Arguments supplied through `map_args` are skipped just like the ones in
    # `ignore_args`.
    skipped = ignore_args | set(map_args)
    params = super(generated_cls, cls).Params()

    # `func`'s parameters live in their own `args` group so they cannot be
    # confused with params already defined by the super classes.
    # TODO(oday): For better UX, consider removing this nested field and adding
    # `func`'s parameters to `params` directly. Before doing so, make sure that
    # merging them causes no side effects with:
    #   - BaseInputGenerator.Params()
    #   - BaseLayer.Params()
    #   - InstantiableParams.cls
    params.Define('args', hyperparams.Params(), 'Parameter list of the pipeline.')
    inspect_utils.DefineParams(func, params.args, skipped)
    return params
def testFunctionWithOverrides(self):
    """A keyword passed to CallWithParams overrides the stored param value."""
    def my_function(a, b=3):
        return a + 1, b + 2

    p = hyperparams.Params()
    inspect_utils.DefineParams(my_function, p)
    for name in ('a', 'b'):
        self.assertIn(name, p)
    self.assertIsNone(p.a)
    self.assertEqual(p.b, 3)

    # `a` is stored as 6 but explicitly overridden with 7 at call time.
    p.a = 6
    first, second = inspect_utils.CallWithParams(my_function, p, a=7)
    self.assertEqual(first, 7 + 1)
    self.assertEqual(second, 3 + 2)
def testBareFunction(self):
    """Parameters without defaults become fields initialised to None."""
    def my_function(a, b):
        return a + 1, b + 2

    p = hyperparams.Params()
    inspect_utils.DefineParams(my_function, p)
    for name in ('a', 'b'):
        self.assertIn(name, p)
    for name in ('a', 'b'):
        self.assertIsNone(getattr(p, name))

    p.a, p.b = 5, 6
    first, second = inspect_utils.CallWithParams(my_function, p)
    self.assertEqual(first, 5 + 1)
    self.assertEqual(second, 6 + 2)
def testFunctionWithIgnore(self):
    """Ignored parameters get no field but can still be passed at call time."""
    def my_function(a, b=3, c=4):
        return a + 1, b + 2, c + 3

    p = hyperparams.Params()
    inspect_utils.DefineParams(my_function, p, ignore=['c'])
    for name in ('a', 'b'):
        self.assertIn(name, p)
    self.assertNotIn('c', p)
    self.assertIsNone(p.a)
    self.assertEqual(p.b, 3)

    # `c` has no field, so its value is supplied explicitly as a keyword.
    p.a = 6
    out_a, out_b, out_c = inspect_utils.CallWithParams(my_function, p, c=9)
    self.assertEqual(out_a, 6 + 1)
    self.assertEqual(out_b, 3 + 2)
    self.assertEqual(out_c, 9 + 3)
def testMethod(self):
    """With bound=True, `self` of an unbound method is not made a field."""
    class MyClass(object):

        def __init__(self):
            self._s = 'a/b'

        def split(self, sep):
            return self._s.split(sep)

    p = hyperparams.Params()
    inspect_utils.DefineParams(MyClass.split, p, bound=True)
    self.assertNotIn('self', p)
    self.assertIn('sep', p)
    self.assertIsNone(p.sep)

    # The fields are then applied to a bound method of a fresh instance.
    p.sep = '/'
    self.assertEqual(['a', 'b'],
                     inspect_utils.CallWithParams(MyClass().split, p))
def testClassInit(self):
    """Passing a class defines fields from its constructor's parameters."""
    class MyClass(object):

        def __init__(self, a, b=3):
            self.a = a
            self.b = b

    p = hyperparams.Params()
    inspect_utils.DefineParams(MyClass, p)
    for name in ('a', 'b'):
        self.assertIn(name, p)
    self.assertIsNone(p.a)
    self.assertEqual(p.b, 3)

    p.a, p.b = 9, 5
    instance = inspect_utils.CallWithParams(MyClass, p)
    self.assertEqual(instance.a, 9)
    self.assertEqual(instance.b, 5)
def testFunctionWithVarArgs(self):
    """*args/**kwargs are skipped; keyword-only params keep their defaults."""
    def my_function(a, *args, b=3, **kwargs):
        del args, kwargs  # Unused.
        return a + 1, b + 2

    p = hyperparams.Params()
    inspect_utils.DefineParams(my_function, p)
    for name in ('a', 'b'):
        self.assertIn(name, p)
    for name in ('args', 'kwargs'):
        self.assertNotIn(name, p)
    self.assertIsNone(p.a)
    self.assertEqual(p.b, 3)

    p.a = 6
    first, second = inspect_utils.CallWithParams(my_function, p)
    self.assertEqual(first, 6 + 1)
    self.assertEqual(second, 3 + 2)
def testClassInit2(self):
    """Fields defined from __init__ (bound=True) feed ConstructWithParams."""
    class MyClass(object):

        def __init__(self, a, b=3):
            self.a = a
            self.b = b

    p = hyperparams.Params()
    inspect_utils.DefineParams(MyClass.__init__, p, bound=True)
    self.assertNotIn('self', p)
    for name in ('a', 'b'):
        self.assertIn(name, p)
    self.assertIsNone(p.a)
    self.assertEqual(p.b, 3)

    p.a, p.b = 9, 5
    instance = inspect_utils.ConstructWithParams(MyClass, p)
    self.assertEqual(instance.a, 9)
    self.assertEqual(instance.b, 5)