def test_func(self):
  """A plain function keeps its input spec and call semantics after init."""

  def identity(x):
    return x

  in_spec = api.Shape(())
  out_spec = api.spec(identity)(in_spec)
  initialized = api.init(identity)(self._seed, in_spec)
  self.assertEqual(out_spec, in_spec)
  self.assertEqual(api.call(initialized, 1.), 1.)
def test_tuple(self):
  """A tuple of functions yields a tuple of specs and a tuple of outputs."""

  def identity(x):
    return x

  triple = (identity,) * 3
  in_spec = api.Shape(())
  out_spec = api.spec(triple)(in_spec)
  modules = api.init(triple)(self._seed, in_spec)
  self.assertTupleEqual(out_spec, (in_spec,) * 3)
  self.assertTupleEqual(api.call(modules, 1.), (1., 1., 1.))
def test_list(self):
  """A list of three increments chains them, mapping 1. to 4."""

  def increment(x):
    return x + 1

  pipeline = [increment] * 3
  in_spec = api.Shape(())
  out_spec = api.spec(pipeline)(in_spec)
  chained = api.init(pipeline)(self._seed, in_spec)
  self.assertEqual(out_spec, in_spec)
  self.assertEqual(api.call(chained, 1.), 4.)
def test_vmap_of_init_should_return_ensemble(self):
  """vmap over api.init builds a 2-member ensemble matching direct calls of f."""

  def f(x, init_key=None):
    w = module.variable(random.normal(init_key, x.shape), name='w')
    return np.dot(w, x)

  def run_directly(key, x):
    return f(x, init_key=key)

  ensemble = jax.vmap(api.init(f))(random.split(random.PRNGKey(0)),
                                   np.ones([2, 5]))
  self.assertTupleEqual(ensemble.w.shape, (2, 5))

  # First an input shared by every member (in_axes None), then one input
  # row per member (in_axes 0). The same keys are regenerated for the
  # direct-call reference.
  cases = [
      ((0, None), np.ones(5)),
      ((0, 0), np.arange(10.).reshape((2, 5))),
  ]
  for in_axes, inputs in cases:
    keys = random.split(random.PRNGKey(0))
    onp.testing.assert_allclose(
        jax.vmap(api.call, in_axes=in_axes)(ensemble, inputs),
        jax.vmap(run_directly, in_axes=in_axes)(keys, inputs),
        rtol=1e-5,
        atol=1e-5)
def test_should_pass_kwarg_into_primitive(self):
  """Call-time keyword arguments reach the bound primitive.

  With no keyword the result matches `training=True`.
  """

  def f(x):
    return training_add_p.bind(x)

  m = api.init(f)(random.PRNGKey(0), 1.)
  expectations = [
      ({}, 2.),
      ({'training': True}, 2.),
      ({'training': False}, 1.),
  ]
  for call_kwargs, expected in expectations:
    self.assertEqual(m(1., **call_kwargs), expected)
def test_init_nonstateful_function(self):
  """A function that declares no variables initializes to an empty-state module."""

  def f(x):
    return x

  m = api.init(f)(random.PRNGKey(0), 1.)
  self.assertIsInstance(m, function.FunctionModule)
  self.assertDictEqual(m.variables(), {})
  self.assertEqual(m.call(1.), 1.)
  updated = m.update(1.)
  self.assertEqual(updated.variables(), {})
def test_list_multiple_args(self):
  """Tuple outputs of one list element are passed positionally to the next.

  With input 1.: func1 -> (1., 2.), func2 -> (2., 2.), func3 -> (3., 4., 4.).
  """

  def func1(x):
    return x, x + 1

  def func2(x, y):
    return y, x + 1

  def func3(x, y):
    return x + 1, y + 2, x + y

  chain = [func1, func2, func3]
  in_spec = api.Shape(())
  out_spec = api.spec(chain)(in_spec)
  lst = api.init(chain)(self._seed, in_spec)
  self.assertEqual(out_spec, (api.Shape(()),) * 3)
  self.assertTupleEqual(api.call(lst, 1.), (3., 4., 4.))
def test_assign_with_no_matching_variable_should_error(self):
  """Assigning to an undeclared name succeeds at init but fails when called."""

  def f(x, init_key=None):
    y = module.variable(np.zeros(x.shape), name='y', key=init_key)
    # Only 'y' was declared via module.variable; 'z' has no matching variable.
    next_y = module.assign(y + 1., name='z')
    return x + next_y

  m = api.init(f)(random.PRNGKey(0), 1.)
  expected_message = 'No variable declared for assign: z'
  with self.assertRaisesRegex(ValueError, expected_message):
    m(1.)
def test_init_stateful_function(self):
  """A function that declares a variable initializes to a stateful module.

  `f` never calls `module.assign`, so the state expected after `update`
  equals the initial state `{'y': 1.}`.
  """

  def f(x, init_key=None):
    y = module.variable(np.ones(x.shape), name='y', key=init_key)
    return x + y

  m = api.init(f)(random.PRNGKey(0), 1.)
  self.assertIsInstance(m, function.FunctionModule)
  self.assertDictEqual(m.variables(), {'y': 1.})
  self.assertEqual(m.call(1.), 2.)
  # Same dict-specific assertion and float literal as the other state checks.
  self.assertDictEqual(m.update(1.).variables(), {'y': 1.})
def wrapped(init_key, *args, **kwargs):
  """Initialize every element of `tupl` with the same arguments.

  Each element gets its own key split from `init_key` (or None when
  `init_key` is None). When `name` is set, element i is named
  '{name}_{i}'. The result is built with `tupl`'s own class through
  `tuple.__new__`, so a tuple subclass is preserved.
  """
  count = len(tupl)
  if init_key is not None:
    element_keys = random.split(init_key, count)
  else:
    element_keys = (None,) * count
  if name is None:
    element_names = [None] * count
  else:
    element_names = [f'{name}_{index}' for index in range(count)]
  modules = tuple(
      api.init(element, name=element_name)(element_key, *args, **kwargs)
      for element, element_name, element_key in zip(
          tupl, element_names, element_keys))
  return tuple.__new__(tupl.__class__, modules)
def test_init_stateful_function_with_tied_in_assign(self):
  """An assign tied into the output is reflected in the state after update."""

  def f(x, init_key=None):
    y = module.variable(np.zeros(x.shape), name='y', key=init_key)
    next_y = module.assign(y + 1., name='y')
    return primitive.tie_in(next_y, x) + y

  m = api.init(f)(random.PRNGKey(0), 1.)
  self.assertIsInstance(m, function.FunctionModule)
  initial_state = m.variables()
  self.assertDictEqual(initial_state, {'y': 0.})
  # With y == 0., calling on 1. is expected to return 1.
  self.assertEqual(m.call(1.), 1.)
  # After update, the expected state holds the assigned value y + 1.
  self.assertDictEqual(m.update(1.).variables(), {'y': 1.})
def wrapped(init_key, *args, **kwargs):
  """Initialize the elements of `inits` in order, chaining their specs.

  Each element gets its own key split from `init_key` (or None). The
  spec produced by one element becomes the positional arguments of the
  next; `kwargs` goes to every element unchanged. When `name` is set,
  element i is named '{name}_{i}'. Returns the modules in a container of
  `inits`'s class.
  """
  if init_key is None:
    element_keys = (None,) * len(inits)
  else:
    element_keys = random.split(init_key, len(inits))
  current_args = args
  modules = []
  for index, (element_init, element_key) in enumerate(
      zip(inits, element_keys)):
    # A single ArraySpec from the previous element is wrapped in a 1-tuple
    # so it can be splatted as positional arguments.
    if isinstance(current_args, api.ArraySpec):
      current_args = (current_args,)
    element_name = None if name is None else '{}_{}'.format(name, index)
    modules.append(
        api.init(element_init, name=element_name)(  # pylint: disable=assignment-from-no-return
            element_key, *current_args, **kwargs))
    current_args = api.spec(element_init)(*current_args, **kwargs)  # pylint: disable=assignment-from-no-return
  return inits.__class__(modules)
def test_init_of_nested_init_without_name_should_have_flat_params(self):
  """Variables of an unnamed nested init appear unprefixed in the outer module."""

  def f(x, init_key=None):
    y = module.variable(np.zeros(x.shape), name='y', key=init_key)
    next_y = module.assign(y + 1., name='y')
    return primitive.tie_in(next_y, x) + y

  def g(x, init_key=None):
    inner = api.init(f)(init_key, x)
    return inner(x)

  m = api.init(g)(random.PRNGKey(0), 1.)
  self.assertIsInstance(m, function.FunctionModule)
  # The key is 'y', not a name-prefixed key, because the inner init is unnamed.
  self.assertDictEqual(m.variables(), {'y': 0.})
  self.assertEqual(m.call(1.), 1.)
  self.assertDictEqual(m.update(1.).variables(), {'y': 1.})
def g(x, init_key=None):
  """Initialize `f` (unnamed) on `x` with `init_key`, then apply it to `x`."""
  inner = api.init(f)(init_key, x)
  return inner(x)
def g(x, init_key=None):
  """Initialize `f` under the name 'f' on `x`, then apply it to `x`."""
  inner = api.init(f, name='f')(init_key, x)
  return inner(x)
def test_empty_tuple_init(self):
  """Initializing an empty tuple of functions raises ValueError."""
  with self.assertRaises(ValueError):
    initializer = api.init(())
    initializer(self._seed, api.Shape((50,)))
def f_(init_key, x):
  """Return the module produced by initializing `f` with `init_key` on `x`."""
  initializer = api.init(f)
  return initializer(init_key, x)