def test_list(self):
    """A list of functions is chained: three increments on 1. give 4."""
    add_one = lambda x: x + 1
    scalar_spec = api.Shape(())
    result_spec = api.spec([add_one] * 3)(scalar_spec)
    chained = api.init([add_one] * 3)(self._seed, scalar_spec)
    # Chaining unary increments never changes the (scalar) spec.
    self.assertEqual(result_spec, scalar_spec)
    self.assertEqual(api.call(chained, 1.), 4.)
def test_tuple(self):
    """A tuple of functions applies every element to the same input."""
    identity = lambda x: x
    scalar_spec = api.Shape(())
    branches = (identity,) * 3
    result_spec = api.spec(branches)(scalar_spec)
    fanned_out = api.init(branches)(self._seed, scalar_spec)
    self.assertTupleEqual(result_spec, (scalar_spec,) * 3)
    self.assertTupleEqual(api.call(fanned_out, 1.), (1., 1., 1.))
def test_func(self):
    """A bare function can be spec'd, initialized and called directly."""
    identity = lambda x: x
    scalar_spec = api.Shape(())
    result_spec = api.spec(identity)(scalar_spec)
    module = api.init(identity)(self._seed, scalar_spec)
    self.assertEqual(result_spec, scalar_spec)
    self.assertEqual(api.call(module, 1.), 1.)
def test_spec_works_for_identity_function(self):
    """`api.spec` accepts a concrete array; checks shape (5,) and float32 dtype."""
    def f(x):
        return x

    result_spec = api.spec(f)(np.ones(5))
    self.assertTupleEqual(result_spec.shape, (5,))
    self.assertEqual(result_spec.dtype, np.float32)
def test_list_multiple_args(self):
    """Tuple outputs of one list element feed the next element's arguments."""
    first = lambda x: (x, x + 1)
    second = lambda x, y: (y, x + 1)
    third = lambda x, y: (x + 1, y + 2, x + y)
    scalar_spec = api.Shape(())
    result_spec = api.spec([first, second, third])(scalar_spec)
    pipeline = api.init([first, second, third])(self._seed, scalar_spec)
    self.assertEqual(result_spec, (api.Shape(()),) * 3)
    # 1. -> (1., 2.) -> (2., 2.) -> (3., 4., 4.)
    self.assertTupleEqual(api.call(pipeline, 1.), (3., 4., 4.))
def wrapped(init_key, *args, **kwargs):
    """Initialize every element of `inits` in order, threading specs through.

    Args:
      init_key: Random key that is split with `random.split` into one subkey
        per element, or None, in which case every element receives None.
      *args: Input specs for the first element.
      **kwargs: Forwarded unchanged to every element's `api.init` and
        `api.spec` call.

    Returns:
      A container of the same class as `inits` holding the initialized
      elements in order.
    """
    if init_key is not None:
        elem_keys = random.split(init_key, len(inits))
    else:
        elem_keys = (None,) * len(inits)
    modules = []
    # `elem_key` (not `init_key`) so the loop does not rebind the parameter.
    for i, (elem_init, elem_key) in enumerate(zip(inits, elem_keys)):
        # A previous element may have produced a single spec rather than a
        # tuple; wrap it so it can be splatted as positional arguments.
        if isinstance(args, api.ArraySpec):
            args = (args,)
        elem_name = None if name is None else '{}_{}'.format(name, i)
        elem = api.init(elem_init, name=elem_name)(elem_key, *args, **kwargs)  # pylint: disable=assignment-from-no-return
        # This element's output specs become the next element's inputs.
        args = api.spec(elem_init)(*args, **kwargs)  # pylint: disable=assignment-from-no-return
        modules.append(elem)
    return inits.__class__(modules)
def wrapped(*specs, **kwargs):
    """Propagate input specs through every element of `inits` in order.

    Returns the output spec(s) of the last element, or the input specs
    (as a tuple) unchanged when `inits` is empty.
    """
    current = specs
    for elem_init in inits:
        # Wrap a lone spec so it can be splatted as positional arguments.
        if isinstance(current, api.ArraySpec):
            current = (current,)
        current = api.spec(elem_init)(*current, **kwargs)  # pylint: disable=assignment-from-no-return
    return current
def wrapped(*args, **kwargs):
    """Compute each element's spec on the same inputs, keeping `tupl`'s type.

    `tuple.__new__` builds an instance of `tupl.__class__` from an iterable
    without going through that class's own `__new__`. Presumably this lets
    tuple subclasses such as namedtuples be rebuilt here; confirm against
    callers.
    """
    out_specs = [api.spec(t)(*args, **kwargs) for t in tupl]
    return tuple.__new__(tupl.__class__, out_specs)