def _Init(self, params):
  """Initializes the InputGenerator.

  Runs the parent initializer, then calls the wrapped dataset callable
  `func` through `inspect_utils.CallWithParams`, using `p.args` as its
  params. Every entry of `map_args` (callable argument name -> name of a
  param on this generator) is passed as an explicit keyword override. A
  one-shot iterator over the returned dataset is stored on `self.iterator`.

  Args:
    params: Params object forwarded to the parent class initializer.

  Raises:
    AssertionError: If `func` does not return a `tf1.data.Dataset` or
      `tf2.data.Dataset`. The check is a bare `assert`, so it is skipped
      when Python runs with -O.
  """
  super(generated_cls, self).__init__(params)
  p = self.params
  # The one-shot iterator is built here, exactly once, because _InputBatch
  # will be called repeatedly in TFv2.
  kwargs = {}
  for arg_name, param_name in map_args.items():
    kwargs[arg_name] = p.Get(param_name)
  dataset = inspect_utils.CallWithParams(func, p.args, **kwargs)
  assert isinstance(dataset, (tf1.data.Dataset, tf2.data.Dataset)), (
      'DefineTFDataInput must take a callable which returns a '
      '`tf.data.Dataset`. The given callable `%s` returned `%s`' %
      (func, dataset))
  self.iterator = tf1.data.make_one_shot_iterator(dataset)
def testFunctionWithOverrides(self):
  """An explicit keyword to CallWithParams wins over the stored param value."""
  def fn(a, b=3):
    return a + 1, b + 2

  params = hyperparams.Params()
  inspect_utils.DefineParams(fn, params)
  for name in ('a', 'b'):
    self.assertIn(name, params)
  self.assertIsNone(params.a)
  self.assertEqual(params.b, 3)

  # params.a is set to 6, but the a=7 override is what the call must see.
  params.a = 6
  got_a, got_b = inspect_utils.CallWithParams(fn, params, a=7)
  self.assertEqual(got_a, 7 + 1)
  self.assertEqual(got_b, 3 + 2)
def testBareFunction(self):
  """Arguments without defaults become None-valued params passed to the call."""
  def fn(a, b):
    return a + 1, b + 2

  params = hyperparams.Params()
  inspect_utils.DefineParams(fn, params)
  for name in ('a', 'b'):
    self.assertIn(name, params)
  for name in ('a', 'b'):
    self.assertIsNone(getattr(params, name))

  params.a = 5
  params.b = 6
  res_a, res_b = inspect_utils.CallWithParams(fn, params)
  self.assertEqual(res_a, 5 + 1)
  self.assertEqual(res_b, 6 + 2)
def testFunctionWithIgnore(self):
  """Ignored arguments get no param but can still be supplied explicitly."""
  def fn(a, b=3, c=4):
    return a + 1, b + 2, c + 3

  params = hyperparams.Params()
  inspect_utils.DefineParams(fn, params, ignore=['c'])
  for defined in ('a', 'b'):
    self.assertIn(defined, params)
  self.assertNotIn('c', params)
  self.assertIsNone(params.a)
  self.assertEqual(params.b, 3)

  params.a = 6
  out_a, out_b, out_c = inspect_utils.CallWithParams(fn, params, c=9)
  expected = (6 + 1, 3 + 2, 9 + 3)
  for actual, want in zip((out_a, out_b, out_c), expected):
    self.assertEqual(actual, want)
def testMethod(self):
  """With bound=True, `self` gets no param and the bound method is callable."""
  class Splitter(object):

    def __init__(self):
      self._s = 'a/b'

    def split(self, sep):
      return self._s.split(sep)

  params = hyperparams.Params()
  inspect_utils.DefineParams(Splitter.split, params, bound=True)
  self.assertNotIn('self', params)
  self.assertIn('sep', params)
  self.assertIsNone(params.sep)

  params.sep = '/'
  instance = Splitter()
  self.assertEqual(['a', 'b'],
                   inspect_utils.CallWithParams(instance.split, params))
def testClassInit(self):
  """A class yields params from its __init__ and is constructed from them."""
  class Pair(object):

    def __init__(self, a, b=3):
      self.a = a
      self.b = b

  params = hyperparams.Params()
  inspect_utils.DefineParams(Pair, params)
  for name in ('a', 'b'):
    self.assertIn(name, params)
  self.assertIsNone(params.a)
  self.assertEqual(params.b, 3)

  params.a, params.b = 9, 5
  obj = inspect_utils.CallWithParams(Pair, params)
  self.assertEqual((obj.a, obj.b), (9, 5))
def testFunctionWithVarArgs(self):
  """*args and **kwargs get no params, while keyword-only arguments do."""
  def fn(a, *args, b=3, **kwargs):
    del args, kwargs  # Unused.
    return a + 1, b + 2

  params = hyperparams.Params()
  inspect_utils.DefineParams(fn, params)
  for present in ('a', 'b'):
    self.assertIn(present, params)
  for absent in ('args', 'kwargs'):
    self.assertNotIn(absent, params)
  self.assertIsNone(params.a)
  self.assertEqual(params.b, 3)

  params.a = 6
  res_a, res_b = inspect_utils.CallWithParams(fn, params)
  self.assertEqual(res_a, 6 + 1)
  self.assertEqual(res_b, 3 + 2)