def test_kwargs_training(self):
    """Checks that the `training` flag given at call time reaches IsTraining.

    Each template's network is evaluated with training=True and
    training=False.  `template2` never mentions `training` in its signature,
    yet the expected outputs (offset by its `+ 1`) show the flag still
    reaches the layer.
    """

    def template(x, training=False, init_key=None):
        return IsTraining()(
            x, name='training', training=training, init_key=init_key)

    def template1(x, training=True, init_key=None):
        return IsTraining()(
            x, training=training, name='training', init_key=init_key)

    def template2(x, init_key=None):
        return IsTraining()(x, name='training', init_key=init_key) + 1

    # (template, expected output when training, expected output otherwise)
    cases = (
        (template, 1., 0.),
        (template1, 1., 0.),
        (template2, 2., 1.),
    )
    for fn, when_training, when_eval in cases:
        net = state.init(fn)(self._seed, state.Shape(()))
        self.assertEqual(net(np.ones(()), training=True), when_training)
        self.assertEqual(net(np.ones(()), training=False), when_eval)
def test_duplicate_names(self):
    """Initializing two sub-layers under the same name raises ValueError."""

    def template(x, init_key=None):
        first_key, second_key = random.split(init_key)
        dense_a = state.init(nn.Dense(20), name='dense')(first_key, x)
        dense_b = state.init(nn.Dense(20), name='dense')(second_key, x)
        return dense_a(x) + dense_b(x)

    with self.assertRaises(ValueError):
        state.init(template)(self._seed, jnp.ones(5))
def test_kwargs_rng(self):
    """Sampler requires an explicit `rng` argument.

    Without an rng, initialization trips an AssertionError.  With one,
    calling the network with two different PRNG keys yields different values.
    """

    def template(x, init_key=None):
        return Sampler()(x, name='sampler', init_key=init_key)

    with self.assertRaises(AssertionError):
        state.init(template)(self._seed, jnp.ones(()))

    def template1(x, rng, init_key=None):
        return Sampler()(x, rng=rng, init_key=init_key)

    # Placeholder rng of two uint32 words, used only for initialization.
    init_rng = jnp.ones(2, dtype=jnp.uint32)
    net = state.init(template1)(self._seed, jnp.ones(()), init_rng)
    first, second = [net(jnp.ones(()), random.PRNGKey(seed)) for seed in (0, 1)]
    self.assertNotEqual(first, second)
def template(x, init_key=None):
    """Applies one ScalarMul layer twice, threading its updated state.

    The layer is initialized under the name 'scalar_mul'.  It is applied via
    `call_and_update` twice, and the final layer is assigned back under the
    same name.  Returns the first element of the result.
    """
    scalar_mul = state.init(
        ScalarMul(2 * jnp.ones(1)), name='scalar_mul')(init_key, x)
    for _ in range(2):
        x, scalar_mul = scalar_mul.call_and_update(x)
    state.assign(scalar_mul, name='scalar_mul')
    return x[0]
def test_grad_of_stateful_function(self):
    """Gradient steps and state updates compose on a stateful network.

    The network is ScalarMul followed by Counter.  Adding the gradient of
    `sum(net(x))` to the network and calling `net.update` both change the
    output, checked against the expected values below.
    """

    def template(x, init_key=None):
        x = ScalarMul(np.ones(1))(x, init_key=init_key, name='scalar_mul')
        x = Counter(np.zeros(1))(x, init_key=init_key, name='counter')
        return x

    net = state.init(template)(self._seed, state.Shape(5))

    def loss(net, x):
        return net(x).sum()

    def add(x, y):
        return x + y

    # `tree_util.tree_multimap` was deprecated and later removed from JAX;
    # `tree_util.tree_map` accepts multiple trees and is the replacement.
    g = jax.grad(loss)(net, np.ones(5))
    net = tree_util.tree_map(add, net, g)  # w_new = w_old + 5
    onp.testing.assert_array_equal(net(np.ones(5)), 6 * np.ones(5))
    net = net.update(np.ones(5))
    onp.testing.assert_array_equal(net(np.ones(5)), 7 * np.ones(5))
    g = jax.grad(loss)(net, np.ones(5))
    net = tree_util.tree_map(add, net, g)  # w_new = w_old + 5
    onp.testing.assert_array_equal(net(np.ones(5)), 12 * np.ones(5))
def test_serialize(self, template):
    """A network round-tripped through serialize/deserialize is unchanged.

    Both the original and the restored network are vmapped over a batch of
    ones with the same rng, and their outputs must match exactly.
    """
    key = random.PRNGKey(0)
    network = state.init(template)(key, state.Shape(784))
    restored = deserialize(serialize(network))
    rng = key
    batch = np.ones([10, 784])
    original_out = jax.vmap(lambda x: network(x, rng=rng))(batch)
    restored_out = jax.vmap(lambda x: restored(x, rng=rng))(batch)
    onp.testing.assert_array_equal(original_out, restored_out)
def update(params, updates, init_key=None):
    """Applies each update transformation in sequence to `updates`.

    NOTE(review): `args` and `kwargs` are free variables from an enclosing
    scope not visible here; presumably they hold the update transformations
    to chain.  Positional ones are named 'update_<i>'; keyword ones use their
    keyword.  Each step gets its own key split from `init_key`.
    """
    transforms = list(itertools.chain(args, kwargs.values()))
    step_names = [f'update_{i}' for i in range(len(args))]
    step_names.extend(kwargs.keys())
    step_keys = random.split(init_key, len(transforms))
    for step_name, step_key, transform in zip(step_names, step_keys,
                                              transforms):
        step = state.init(transform, name=step_name)(step_key, params, updates)
        updates = step(params, updates)
    return updates
def test_shared_layer(self):
    """One ScalarMul(2) layer applied twice scales the input by 4."""

    def template(x, init_key=None):
        layer = state.init(
            ScalarMul(2 * np.ones(1)), name='scalar_mul')(init_key, x)
        once = layer(x)
        return layer(once)

    net = state.init(template)(self._seed, state.Shape(5))
    onp.testing.assert_array_equal(net(np.ones(5)), 4 * np.ones(5))
def test_function_in_combinator(self):
    """A plain Python function can sit between layers in a `>>` chain."""

    def add_one(x):
        return x + 1

    pipeline = AddOne() >> add_one >> AddOne()
    net = state.init(pipeline)(self._seed, jnp.ones(2))
    np.testing.assert_array_equal(net(jnp.zeros(2)), 3 * jnp.ones(2))
def test_optimizer(self, update):
    """Optimizing sum(x**2) from [3, 4] drives x close to zero.

    `500` is passed straight to `optix.optimize`; presumably it is the
    iteration count (confirm against optix).
    """

    def loss(x):
        return np.sum(x**2)

    start = np.array([3., 4.])
    optimizer = optix.optimize(loss, update, 500)
    opt = state.init(optimizer)(random.PRNGKey(0), start)
    result = jax.jit(opt.call)(start)
    onp.testing.assert_allclose(np.zeros(2), result, atol=1e-1, rtol=1e-1)
def test_trace_should_keep_track_of_momentum(self):
    """optix.trace(0.99, False) accumulates its `trace` across updates.

    The trace starts at zero and equals the ones update after one `update`.
    Calling the optimizer again then yields 1.99.
    """
    params, updates = np.zeros(6), np.ones(6)
    momentum = optix.trace(0.99, False)
    opt = state.init(momentum)(random.PRNGKey(0), params, updates)
    onp.testing.assert_array_equal(opt.trace, np.zeros(6))
    opt = opt.update(params, updates)
    onp.testing.assert_array_equal(opt.trace, np.ones(6))
    onp.testing.assert_array_equal(opt(params, updates), 1.99 * np.ones(6))
def test_identity(self):
    """An identity function keeps its spec and returns its input."""

    def identity(x):
        return x

    spec_in = state.Shape(50)
    self.assertEqual(state.spec(identity)(spec_in), spec_in)
    net = state.init(identity)(self._seed, jnp.ones(50))
    values = jnp.arange(5)
    np.testing.assert_array_equal(net(values), values)
def test_call_list(self):
    """state.call on a list of layers behaves as one stateful network.

    Before an update, a zero input gives 1; after one `update`, it gives 2.
    """

    def template(x, init_key=None):
        layers = [Counter(0.), AddOne()]
        return state.call(layers, x, init_key=init_key, name='counter_add_one')

    zero = jnp.zeros(())
    layer = state.init(template)(self._seed, jnp.ones(()))
    self.assertEqual(layer(zero), 1)
    layer = layer.update(zero)
    self.assertEqual(layer(zero), 2)
def test_function_in_combinator_in_function(self):
    """A `>>` chain containing a plain function works inside a template."""

    def add_one(x):
        return x + 1

    def template(x, init_key=None):
        chain = AddOne() >> add_one >> AddOne()
        return chain(x, init_key=init_key)

    net = state.init(template)(self._seed, jnp.ones(2))
    out = net(jnp.zeros(2), init_key=self._seed)
    np.testing.assert_array_equal(out, 3 * jnp.ones(2))
def test_call_tuple(self):
    """state.call on a tuple of layers returns a tuple of per-layer outputs.

    A zero input gives (0, 1) before an update and (1, 1) after one.
    """

    def template(x, init_key=None):
        layers = (Counter(0.), AddOne())
        return state.call(layers, x, init_key=init_key, name='counter_add_one')

    zero = jnp.zeros(())
    layer = state.init(template)(self._seed, jnp.ones(()))
    self.assertTupleEqual(layer(zero), (0, 1))
    layer = layer.update(zero)
    self.assertTupleEqual(layer(zero), (1, 1))
def test_trace_should_keep_track_of_momentum_with_nesterov(self):
    """optix.trace(0.99, True) tracks momentum and applies the Nesterov term.

    The trace starts at zero and is ones after one `update`.  The next call
    returns (1.99 + 0.99**2) times ones.
    """
    updates, params = jnp.ones(6), jnp.zeros(6)
    nesterov_trace = optix.trace(0.99, True)
    opt = state.init(nesterov_trace)(random.PRNGKey(0), updates, params)
    np.testing.assert_array_equal(opt.trace, jnp.zeros(6))
    opt = opt.update(updates, params)
    np.testing.assert_array_equal(opt.trace, jnp.ones(6))
    actual = opt(updates, params)
    np.testing.assert_array_equal(actual, (1.99 + 0.99**2) * jnp.ones(6))
def test_update_in_combinator(self):
    """state.update reaches Counter layers nested inside nn.Serial.

    A ones input gives 1. before an update and 3. after one.
    """

    def template(x, init_key=None):

        def increment(x, init_key=None):
            return Counter(jnp.zeros(()))(x, init_key=init_key, name='counter')

        serial = nn.Serial([increment, increment])
        return serial(x, init_key=init_key, name='increment')

    one = jnp.ones(())
    net = state.init(template)(self._seed, one)
    self.assertEqual(net(one), 1.)
    net = state.update(net, one)
    self.assertEqual(net(one), 3.)
def test_shared_layer(self):
    """A ScalarMul(2) layer applied twice via call_and_update scales by 4.

    The updated layer is assigned back under its original name.
    """

    def template(x, init_key=None):
        shared = state.init(
            ScalarMul(2 * jnp.ones(1)), name='scalar_mul')(init_key, x)
        for _ in range(2):
            x, shared = shared.call_and_update(x)
        state.assign(shared, name='scalar_mul')
        return x

    net = state.init(template)(self._seed, jnp.ones(5))
    np.testing.assert_array_equal(net(jnp.ones(5)), 4 * jnp.ones(5))
def test_add_noise_should_add_noise(self):
    """optix.add_noise adds a normal sample drawn from its stored rng key.

    Initial state: count 0 and rng_key PRNGKey(0).  The value returned by
    call_and_update equals the update plus random.normal applied to the
    second half of split(PRNGKey(0)).
    """
    params, updates = 0., 1.
    seed_key = random.PRNGKey(0)
    opt = state.init(optix.add_noise(1., 0., 0))(seed_key, params, updates)
    onp.testing.assert_array_equal(opt.count, 0.)
    onp.testing.assert_array_equal(opt.rng_key, seed_key)
    value, _ = opt.call_and_update(params, updates)
    noise_key = random.split(seed_key)[1]
    onp.testing.assert_array_equal(value, 1. + random.normal(noise_key, ()))
def test_scale_by_rms_should_scale_by_rms(self):
    """optix.scale_by_rms(0.5, 0.) tracks `nu` and rescales the updates.

    `nu` starts at zero and is 2 after one update of 2s.  The next call
    returns 2 / sqrt(3).
    """
    params = np.zeros(9)
    updates = 2 * np.ones(9)
    rms = optix.scale_by_rms(0.5, 0.)
    opt = state.init(rms)(random.PRNGKey(0), params, updates)
    onp.testing.assert_array_equal(opt.nu, np.zeros(9))
    opt = opt.update(params, updates)
    onp.testing.assert_array_equal(opt.nu, 2 * np.ones(9))
    scaled = opt(params, updates)
    onp.testing.assert_array_equal(scaled, 2 * np.ones(9) / np.sqrt(3.))
def test_update(self):
    """state.update(net, x) and net.update(x) agree and leave `net` intact.

    Both forms are applied to the same original `net`, one after the other.
    Each must give 2., so the first update cannot have changed `net`.
    """

    def template(x, init_key=None):
        return Counter(jnp.zeros(()))(x, init_key=init_key, name='counter')

    one = jnp.ones(())
    net = state.init(template)(self._seed, one)
    self.assertEqual(net(one), 1.)
    update_forms = (
        lambda n: state.update(n, one),
        lambda n: n.update(one),
    )
    for apply_update in update_forms:
        net2 = apply_update(net)
        self.assertEqual(net2(one), 2.)
def test_apply_every_should_delay_updates(self):
    """optix.apply_every(5) emits zeros for 4 steps, then the summed update.

    Initial state: count 0 and grad_acc 0.  The first four call_and_update
    results are 0.  The fifth call returns 5.
    """
    params, updates = 0., 1.
    opt = state.init(optix.apply_every(5))(random.PRNGKey(0), updates, params)
    np.testing.assert_array_equal(opt.count, 0.)
    np.testing.assert_array_equal(opt.grad_acc, 0.)
    for _step in range(1, 5):
        value, opt = opt.call_and_update(updates, params)
        np.testing.assert_array_equal(value, 0.)
    np.testing.assert_array_equal(opt(updates, params), 5.)
def test_single_input_multiple_output(self):
    """A function returning a pair yields a pair of specs and a pair of outputs."""

    def dup(x):
        return x, x

    in_spec = state.Shape(50)
    self.assertEqual(state.spec(dup)(in_spec), (in_spec, in_spec))
    net = state.init(dup)(self._seed, in_spec)
    outputs = net(np.arange(5))
    for got, want in zip(outputs, (np.arange(5), np.arange(5))):
        onp.testing.assert_array_equal(got, want)
def test_scale_by_stddev_should_scale_by_stddev(self):
    """optix.scale_by_stddev(0.5, 0.) tracks `mu`/`nu` and rescales updates.

    Both accumulators start at zero.  After one update of 2s, mu is 1 and
    nu is 2.  The next call returns 2 / sqrt(0.75).
    """
    params = jnp.zeros(9)
    updates = 2 * jnp.ones(9)
    stddev = optix.scale_by_stddev(0.5, 0.)
    opt = state.init(stddev)(random.PRNGKey(0), updates, params)
    for field, expected in (('mu', jnp.zeros(9)), ('nu', jnp.zeros(9))):
        np.testing.assert_array_equal(getattr(opt, field), expected)
    opt = opt.update(updates, params)
    for field, expected in (('mu', jnp.ones(9)), ('nu', 2 * jnp.ones(9))):
        np.testing.assert_array_equal(getattr(opt, field), expected)
    scaled = opt(updates, params)
    np.testing.assert_array_equal(scaled, 2 * jnp.ones(9) / jnp.sqrt(0.75))
def test_dense_combinator(self):
    """Dense(50) >> Dense(20) has the expected spec, and init matches a direct call.

    The initialized network must agree (rtol 1e-5) with calling the template
    directly with the same seed.
    """

    def dense(x, init_key=None):
        stack = nn.Dense(50) >> nn.Dense(20)
        return stack(x, init_key=init_key, name='dense')

    in_spec = state.Shape(50)
    out_spec = state.spec(dense)(in_spec)
    self.assertEqual(out_spec, state.Shape(20, dtype=in_spec.dtype))
    inputs = jnp.ones(2)
    net = state.init(dense)(self._seed, inputs)
    self.assertTupleEqual(net(inputs).shape, out_spec.shape)
    direct = dense(inputs, init_key=self._seed)
    np.testing.assert_allclose(direct, net(inputs), rtol=1e-5)
def test_dense_function(self):
    """A Dense layer needs an init_key.

    Without one, spec and init raise ValueError.  With one, the spec is
    Shape(20), and the initialized network agrees (rtol 1e-5) with a direct
    call using the same seed.
    """

    def dense_no_rng(x):
        return nn.Dense(20)(x, name='dense')

    with self.assertRaises(ValueError):
        state.spec(dense_no_rng)(state.Shape(50))
    with self.assertRaises(ValueError):
        state.init(dense_no_rng)(self._seed, state.Shape(2))

    def dense(x, init_key=None):
        return nn.Dense(20)(x, init_key=init_key, name='dense')

    self.assertEqual(state.spec(dense)(state.Shape(2)), state.Shape(20))
    inputs = jnp.ones(2)
    net = state.init(dense)(self._seed, inputs)
    self.assertTupleEqual(net(inputs).shape, (20,))
    np.testing.assert_allclose(
        net(inputs), dense(inputs, init_key=self._seed), rtol=1e-5)
def run(params, init_key=None):
    """Runs an optimizer for `num_iters` scan steps and returns the params.

    NOTE(review): `gradient_descent`, `update`, `objective` and `num_iters`
    are free variables from an enclosing scope not visible here.

    The optimizer is initialized under the name 'opt'.  It is stepped with
    `call_and_update` inside `lax.scan`, then its final state is assigned
    back under 'opt' and tied to the returned params.
    """
    opt = state.init(
        gradient_descent(update, objective), name='opt')(init_key, params)

    def body(carry, _):
        current_opt, current_params = carry
        next_params, next_opt = current_opt.call_and_update(current_params)
        return (next_opt, next_params), ()

    final_carry, _ = lax.scan(body, (opt, params), np.arange(num_iters))
    opt, params = final_carry
    opt, params = primitive.tie_all(state.assign(opt, name='opt'), params)
    return params
def test_multiple_input_multiple_output(self):
    """A two-argument function that swaps its inputs swaps specs and values."""

    def swap(x, y):
        return y, x

    spec_x, spec_y = state.Shape(50), state.Shape(20)
    self.assertEqual(state.spec(swap)(spec_x, spec_y), (spec_y, spec_x))
    net = state.init(swap)(self._seed, jnp.ones(50), jnp.ones(20))
    outputs = net(jnp.zeros(50), jnp.ones(20))
    for got, want in zip(outputs, (jnp.ones(20), jnp.zeros(50))):
        np.testing.assert_array_equal(got, want)
def test_grad_of_function_constant(self):
    """A gradient step leaves a weightless function unchanged.

    `template` only adds a constant.  After adding the gradient of
    `sum(net(x))` to `net`, ones must still map to twos (1 + 1), so the
    step changes nothing.
    """

    def template(x):
        return x + jnp.ones_like(x)

    net = state.init(template)(self._seed, jnp.ones(5))

    def loss(net, x):
        return net(x).sum()

    g = jax.grad(loss)(net, jnp.ones(5))

    def add(x, y):
        return x + y

    # `tree_util.tree_multimap` was deprecated and later removed from JAX;
    # `tree_util.tree_map` accepts multiple trees and is the replacement.
    # The template has no weight, so the output must stay at 1 + 1.
    net = tree_util.tree_map(add, net, g)
    np.testing.assert_array_equal(net(jnp.ones(5)), 2 * jnp.ones(5))
def test_grad_of_function(self):
    """Adding the gradient of sum(net(x)) to a ScalarMul(1) network gives weight 6.

    With five ones as input, the gradient with respect to the scalar weight
    is 5.  After adding it, the network maps ones to sixes.
    """

    def template(x, init_key=None):
        # jnp.ones(1) does not behave like a literal when tracing
        return ScalarMul(jnp.ones(1))(x, init_key=init_key, name='scalar_mul')

    net = state.init(template)(self._seed, jnp.ones(5))

    def loss(net, x):
        return net(x).sum()

    g = jax.grad(loss)(net, jnp.ones(5))

    def add(x, y):
        return x + y

    # `tree_util.tree_multimap` was deprecated and later removed from JAX;
    # `tree_util.tree_map` accepts multiple trees and is the replacement.
    net = tree_util.tree_map(add, net, g)  # w_new = w_old + 5
    np.testing.assert_array_equal(net(jnp.ones(5)), 6 * jnp.ones(5))