def test_nested_add(self):
    """A function that routes through an inner lambda should trace like add.

    Checks that output specs equal input specs for two shapes, and that the
    initialized network doubles its input.
    """

    def add(x):
        # Calling through an inner lambda exercises nested-function tracing;
        # the lambda parameter is named `y` so it does not shadow `x`.
        return (lambda y: jnp.add(y, y))(x)

    for dims in (50, (100, 20)):
        in_spec = state.Shape(dims)
        out_spec = state.spec(add)(in_spec)
        self.assertEqual(out_spec, in_spec)

    # Initialize from a concrete example array rather than a Shape spec.
    net = state.init(add)(self._seed, jnp.ones(2))
    np.testing.assert_array_equal(net(jnp.arange(5)), 2 * jnp.arange(5))
def test_add(self):
    """Elementwise self-add keeps specs unchanged and computes 2 * x."""

    def add(x):
        return np.add(x, x)

    # Spec inference should pass each input shape straight through.
    for dims in (50, (100, 20)):
        shape = state.Shape(dims)
        self.assertEqual(state.spec(add)(shape), shape)

    net = state.init(add)(self._seed, state.Shape(2))
    xs = np.arange(5)
    onp.testing.assert_array_equal(net(xs), 2 * xs)
def test_multiple_input_single_output(self):
    """Two inputs added together: mismatched shapes raise, matched ones pass."""

    def add(x, y):
        return jnp.add(x, y)

    mismatched = (state.Shape((100, 20)), state.Shape(50))
    with self.assertRaises(ValueError):
        state.spec(add)(*mismatched)

    matched = (state.Shape(50), state.Shape(50))
    self.assertEqual(state.spec(add)(*matched), matched[0])

    net = state.init(add)(self._seed, jnp.ones(50), jnp.ones(50))
    xs = jnp.arange(5)
    np.testing.assert_array_equal(net(xs, xs), 2 * xs)
def test_add_one_combinator(self):
    """Chaining two AddOne layers with `>>` adds two to the input."""

    def add_two(x, init_key=None):
        chained = AddOne() >> AddOne()
        return chained(x, name='add_one', init_key=init_key)

    for dims in (20, (5, 50)):
        shape = state.Shape(dims)
        self.assertEqual(state.spec(add_two)(shape), shape)

    net = state.init(add_two)(self._seed, state.Shape(2))
    ones = np.ones(2)
    onp.testing.assert_allclose(net(ones), 3 * ones)
    # The initialized network must match calling add_two with the same seed.
    onp.testing.assert_array_equal(
        net(ones), add_two(ones, init_key=self._seed))
def test_identity(self):
    """The identity function keeps its spec and returns inputs unchanged."""

    def identity(x):
        return x

    shape = state.Shape(50)
    self.assertEqual(state.spec(identity)(shape), shape)

    net = state.init(identity)(self._seed, jnp.ones(50))
    xs = jnp.arange(5)
    np.testing.assert_array_equal(net(xs), xs)
def test_add_one_imperative(self):
    """Applying two named AddOne layers one after another adds two."""

    def add_two(x, init_key=None):
        first_key, second_key = random.split(init_key)
        once = AddOne()(x, name='add_one_1', init_key=first_key)
        return AddOne()(once, name='add_one_2', init_key=second_key)

    for dims in (20, (5, 50)):
        shape = state.Shape(dims)
        self.assertEqual(state.spec(add_two)(shape), shape)

    net = state.init(add_two)(self._seed, state.Shape(2))
    ones = np.ones(2)
    onp.testing.assert_array_equal(net(ones), 3 * ones)
    # The initialized network must match calling add_two with the same seed.
    onp.testing.assert_array_equal(
        net(ones), add_two(ones, init_key=self._seed))
def test_single_input_multiple_output(self):
    """A function returning a pair yields a pair of specs and of outputs."""

    def dup(x):
        return x, x

    in_spec = state.Shape(50)
    out_spec = state.spec(dup)(in_spec)
    self.assertEqual(out_spec, (in_spec, in_spec))

    net = state.init(dup)(self._seed, in_spec)
    outputs = net(np.arange(5))
    expected = (np.arange(5), np.arange(5))
    # zip() stops at the shorter sequence, so check the arity explicitly;
    # otherwise a network returning too few outputs would pass silently.
    self.assertEqual(len(outputs), len(expected))
    for actual, want in zip(outputs, expected):
        onp.testing.assert_array_equal(actual, want)
def test_dense_combinator(self):
    """Dense(50) >> Dense(20) infers a width-20 spec and matches a direct call."""

    def dense(x, init_key=None):
        stack = nn.Dense(50) >> nn.Dense(20)
        return stack(x, init_key=init_key, name='dense')

    in_spec = state.Shape(50)
    out_spec = state.spec(dense)(in_spec)
    self.assertEqual(out_spec, state.Shape(20, dtype=in_spec.dtype))

    ones = jnp.ones(2)
    net = state.init(dense)(self._seed, ones)
    self.assertTupleEqual(net(ones).shape, out_spec.shape)
    # Calling `dense` directly with the same seed must reproduce the network.
    np.testing.assert_allclose(
        dense(ones, init_key=self._seed), net(ones), rtol=1e-5)
def test_dense_function(self):
    """Dense without an init_key is expected to raise; with one it works."""

    def dense_no_rng(x):
        return nn.Dense(20)(x, name='dense')

    # Both spec inference and init are expected to reject the keyless layer.
    with self.assertRaises(ValueError):
        state.spec(dense_no_rng)(state.Shape(50))
    with self.assertRaises(ValueError):
        state.init(dense_no_rng)(self._seed, state.Shape(2))

    def dense(x, init_key=None):
        return nn.Dense(20)(x, init_key=init_key, name='dense')

    self.assertEqual(state.spec(dense)(state.Shape(2)), state.Shape(20))

    ones = jnp.ones(2)
    net = state.init(dense)(self._seed, ones)
    self.assertTupleEqual(net(ones).shape, (20,))
    np.testing.assert_allclose(
        net(ones), dense(ones, init_key=self._seed), rtol=1e-5)
def test_multiple_input_multiple_output(self):
    """Swapping two inputs swaps both their specs and their values."""

    def swap(x, y):
        return y, x

    in_spec = (state.Shape(50), state.Shape(20))
    out_spec = state.spec(swap)(*in_spec)
    self.assertEqual(out_spec, (in_spec[1], in_spec[0]))

    net = state.init(swap)(self._seed, jnp.ones(50), jnp.ones(20))
    outputs = net(jnp.zeros(50), jnp.ones(20))
    expected = (jnp.ones(20), jnp.zeros(50))
    # zip() stops at the shorter sequence, so check the arity explicitly;
    # otherwise a network returning too few outputs would pass silently.
    self.assertEqual(len(outputs), len(expected))
    for actual, want in zip(outputs, expected):
        np.testing.assert_array_equal(actual, want)
def test_dense_imperative(self):
    """Two Dense layers applied in sequence with split keys give width 20."""

    def dense(x, init_key=None):
        k_first, k_second = random.split(init_key)
        hidden = nn.Dense(50)(x, init_key=k_first, name='dense1')
        return nn.Dense(20)(hidden, init_key=k_second, name='dense2')

    in_spec = state.Shape(50)
    out_spec = state.spec(dense)(in_spec)
    self.assertEqual(out_spec, state.Shape(20, dtype=in_spec.dtype))

    ones = jnp.ones(2)
    net = state.init(dense)(self._seed, ones)
    self.assertTupleEqual(net(ones).shape, out_spec.shape)
    # Calling `dense` directly with the same seed must reproduce the network.
    np.testing.assert_allclose(
        dense(ones, init_key=self._seed), net(ones), rtol=1e-5)
def spec(cls, *args):
    """Compute output specs by delegating to `state.spec` on a layer list.

    The last positional argument, `layer_inits`, is converted to a list and
    handed to `state.spec`; all preceding positional arguments are forwarded
    as the input specs.
    """
    layer_inits = args[-1]
    in_specs = args[:-1]
    layers = list(layer_inits)
    spec_fn = state.spec(layers)
    return spec_fn(*in_specs)