def testArgsAtCallKWArgsAtCall(self, func_or_method, run_mode):
    """A Future with nothing bound accepts positional and keyword args at call.

    The Future wraps `func_or_method` without pre-binding any arguments. At
    call time the inputs go in positionally and the labels go in by keyword.
    The call is expected to yield (inputs, labels, param), with `param`
    evaluating to 0.
    """
    expected_inputs = [[1, 2], [2, 3], [3, 4]]
    expected_labels = [[0], [1], [2]]
    with run_mode():
        features, targets = input_fn()
        deferred = eager_utils.Future(func_or_method)
        # Nothing was bound at construction, so every argument arrives here.
        out_inputs, out_labels, out_param = deferred(features, labels=targets)
        self.assertAllEqual(self.evaluate(out_inputs), expected_inputs)
        self.assertAllEqual(self.evaluate(out_labels), expected_labels)
        # NOTE(review): 0 is presumably the wrapped callable's default for
        # `param`; confirm against the function under test.
        self.assertEqual(self.evaluate(out_param), 0)
def testCreate(self, func_or_method, run_mode):
    """Constructing a Future around `input_fn` yields a callable Future.

    The Future is built directly from `input_fn` with no bound arguments.
    The test checks that it is callable and is an `eager_utils.Future`.
    Invoking it with no arguments must return the (inputs, labels) pair
    produced by `input_fn`.

    `func_or_method` is accepted but not used by this test (presumably it
    comes from shared test parameterization).
    """
    expected_inputs = [[1, 2], [2, 3], [3, 4]]
    expected_labels = [[0], [1], [2]]
    with run_mode():
        wrapped = eager_utils.Future(input_fn)
        self.assertTrue(callable(wrapped))
        self.assertIsInstance(wrapped, eager_utils.Future)
        got_inputs, got_labels = wrapped()
        self.assertAllEqual(self.evaluate(got_inputs), expected_inputs)
        self.assertAllEqual(self.evaluate(got_labels), expected_labels)
def testPartialArgsAtCallRaisesError(self, func_or_method, run_mode):
    """Calling a partially-bound Future with bad positional args raises.

    `inputs` is bound positionally when the Future is constructed. `labels`
    is then passed positionally again at call time. That call is expected to
    raise a TypeError whose message matches 'argument'.
    """
    with run_mode():
        inputs, labels = input_fn()
        future = eager_utils.Future(func_or_method, inputs)
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(TypeError, 'argument'):
            future(labels)