def _Init(self, params):
        """Sets up the generator and its one-shot dataset iterator.

        Looks up every `map_args` value on the instance params, calls `func`
        with `args` plus those looked-up values as keyword overrides, checks
        that the result is a `tf.data.Dataset`, and stores a one-shot iterator
        over it in `self.iterator`.

        Args:
          params: Params object handed to the parent class constructor.
        """
        super(generated_cls, self).__init__(params)
        hparams = self.params

        # Each map_args key becomes a keyword argument whose value is read
        # from the param named by the corresponding map_args value.
        resolved = {}
        for arg_name, param_name in map_args.items():
            resolved[arg_name] = hparams.Get(param_name)

        ds = inspect_utils.CallWithParams(func, hparams.args, **resolved)
        assert isinstance(ds, (tf1.data.Dataset, tf2.data.Dataset)), (
            'DefineTFDataInput must take a callable which returns a '
            '`tf.data.Dataset`. The given callable `%s` returned `%s`' %
            (func, ds))

        # The one-shot iterator has to be made only once, here, because
        # _InputBatch will be called repeatedly in TFv2.
        self.iterator = tf1.data.make_one_shot_iterator(ds)
    def testFunctionWithOverrides(self):
        """Checks that a keyword given to CallWithParams beats the param."""
        def my_function(a, b=3):
            return a + 1, b + 2

        fn_params = hyperparams.Params()
        inspect_utils.DefineParams(my_function, fn_params)
        for arg_name in ('a', 'b'):
            self.assertIn(arg_name, fn_params)
        # `a` has no default in the signature; `b` carries its default of 3.
        self.assertIsNone(fn_params.a)
        self.assertEqual(fn_params.b, 3)

        # The stored a=6 is expected to lose to the explicit a=7 override.
        fn_params.a = 6
        out_a, out_b = inspect_utils.CallWithParams(
            my_function, fn_params, a=7)
        self.assertEqual(out_a, 8)
        self.assertEqual(out_b, 5)
    def testBareFunction(self):
        """Checks define/call round-trip for a function with no defaults."""
        def my_function(a, b):
            return a + 1, b + 2

        fn_params = hyperparams.Params()
        inspect_utils.DefineParams(my_function, fn_params)
        for arg_name in ('a', 'b'):
            self.assertIn(arg_name, fn_params)
        # Neither argument has a default, so both params read back as None.
        self.assertIsNone(fn_params.a)
        self.assertIsNone(fn_params.b)

        fn_params.a = 5
        fn_params.b = 6
        out_a, out_b = inspect_utils.CallWithParams(my_function, fn_params)
        self.assertEqual(out_a, 6)
        self.assertEqual(out_b, 8)
    def testFunctionWithIgnore(self):
        """Checks that ignored arguments are not defined but can be passed."""
        def my_function(a, b=3, c=4):
            return a + 1, b + 2, c + 3

        fn_params = hyperparams.Params()
        inspect_utils.DefineParams(my_function, fn_params, ignore=['c'])
        for arg_name in ('a', 'b'):
            self.assertIn(arg_name, fn_params)
        # `c` was listed in `ignore`, so no param is expected for it.
        self.assertNotIn('c', fn_params)
        self.assertIsNone(fn_params.a)
        self.assertEqual(fn_params.b, 3)

        # `c` is supplied directly as a keyword at call time instead.
        fn_params.a = 6
        out_a, out_b, out_c = inspect_utils.CallWithParams(
            my_function, fn_params, c=9)
        self.assertEqual(out_a, 7)
        self.assertEqual(out_b, 5)
        self.assertEqual(out_c, 12)
    def testMethod(self):
        """Checks that bound=True leaves `self` out of the defined params."""
        class MyClass(object):
            def __init__(self):
                self._text = 'a/b'

            def split(self, sep):
                return self._text.split(sep)

        method_params = hyperparams.Params()
        inspect_utils.DefineParams(MyClass.split, method_params, bound=True)
        self.assertNotIn('self', method_params)
        self.assertIn('sep', method_params)
        self.assertIsNone(method_params.sep)

        # The call goes through a bound method of a fresh instance.
        method_params.sep = '/'
        instance = MyClass()
        pieces = inspect_utils.CallWithParams(instance.split, method_params)
        self.assertEqual(['a', 'b'], pieces)
    def testClassInit(self):
        """Checks that a class can be used as the callable to construct."""
        class MyClass(object):
            def __init__(self, a, b=3):
                self.a = a
                self.b = b

        ctor_params = hyperparams.Params()
        inspect_utils.DefineParams(MyClass, ctor_params)
        for arg_name in ('a', 'b'):
            self.assertIn(arg_name, ctor_params)
        # `a` has no default in __init__; `b` carries its default of 3.
        self.assertIsNone(ctor_params.a)
        self.assertEqual(ctor_params.b, 3)

        ctor_params.a = 9
        ctor_params.b = 5
        instance = inspect_utils.CallWithParams(MyClass, ctor_params)
        self.assertEqual(instance.a, 9)
        self.assertEqual(instance.b, 5)
    def testFunctionWithVarArgs(self):
        """Checks that *args and **kwargs are not turned into params."""
        def my_function(a, *args, b=3, **kwargs):
            del args, kwargs
            return a + 1, b + 2

        fn_params = hyperparams.Params()
        inspect_utils.DefineParams(my_function, fn_params)
        # Only the named arguments are expected to show up as params.
        for arg_name in ('a', 'b'):
            self.assertIn(arg_name, fn_params)
        for variadic_name in ('args', 'kwargs'):
            self.assertNotIn(variadic_name, fn_params)
        self.assertIsNone(fn_params.a)
        self.assertEqual(fn_params.b, 3)

        fn_params.a = 6
        out_a, out_b = inspect_utils.CallWithParams(my_function, fn_params)
        self.assertEqual(out_a, 7)
        self.assertEqual(out_b, 5)