def test_tuple(self):
    """A tuple of identity functions fans one input out into three copies."""
    identity = lambda x: x
    scalar_shape = api.Shape(())

    # spec/init each receive their own tuple of callables, as in a fresh build.
    result_spec = api.spec((identity,) * 3)(scalar_shape)
    fanned_out = api.init((identity,) * 3)(self._seed, scalar_shape)

    expected_spec = (scalar_shape,) * 3
    self.assertTupleEqual(result_spec, expected_spec)
    self.assertTupleEqual(api.call(fanned_out, 1.0), (1.0, 1.0, 1.0))
def test_list(self):
    """A list of functions composes: three increments applied to 1. give 4."""
    increment = lambda x: x + 1
    scalar_shape = api.Shape(())

    # Build a separate list for each call so neither sees the other's object.
    chained_spec = api.spec([increment] * 3)(scalar_shape)
    chained = api.init([increment] * 3)(self._seed, scalar_shape)

    # Adding a scalar keeps the scalar shape.
    self.assertEqual(chained_spec, scalar_shape)
    self.assertEqual(api.call(chained, 1.0), 4.0)
def test_func(self):
    """A bare function passes its input through spec/init/call unchanged."""
    identity = lambda x: x
    scalar_shape = api.Shape(())

    passthrough_spec = api.spec(identity)(scalar_shape)
    passthrough = api.init(identity)(self._seed, scalar_shape)

    self.assertEqual(passthrough_spec, scalar_shape)
    self.assertEqual(api.call(passthrough, 1.0), 1.0)
def test_list_multiple_args(self):
    """Tuple outputs of one list stage feed the next stage as positional args.

    Expected trace for input 1.: (1, 2) -> (2, 2) -> (3, 4, 4).
    """
    fork = lambda x: (x, x + 1)
    swap_bump = lambda x, y: (y, x + 1)
    combine = lambda x, y: (x + 1, y + 2, x + y)
    scalar_shape = api.Shape(())

    pipeline_spec = api.spec([fork, swap_bump, combine])(scalar_shape)
    pipeline = api.init([fork, swap_bump, combine])(self._seed, scalar_shape)

    self.assertEqual(pipeline_spec, (api.Shape(()),) * 3)
    self.assertTupleEqual(api.call(pipeline, 1.0), (3.0, 4.0, 4.0))