예제 #1
0
    def test_kwargs_training(self):
        """Check that the call-time `training` flag decides the layer output.

        Three templates wrap an `IsTraining` layer: one whose `training`
        parameter defaults to False, one whose default is True, and one that
        does not declare `training` at all. Each initialized network is then
        called with `training=True` and with `training=False`.
        """
        scalar = np.ones(())

        # Variant 1: `training` defaults to False.
        def template(x, training=False, init_key=None):
            layer = IsTraining()
            return layer(x, name='training', training=training,
                         init_key=init_key)

        module = state.init(template)(self._seed, state.Shape(()))
        self.assertEqual(module(scalar, training=True), 1.)
        self.assertEqual(module(scalar, training=False), 0.)

        # Variant 2: `training` defaults to True; expectations are unchanged.
        def template1(x, training=True, init_key=None):
            layer = IsTraining()
            return layer(x, training=training, name='training',
                         init_key=init_key)

        module = state.init(template1)(self._seed, state.Shape(()))
        self.assertEqual(module(scalar, training=True), 1.)
        self.assertEqual(module(scalar, training=False), 0.)

        # Variant 3: the template never mentions `training`; its output is
        # the layer result plus one.
        def template2(x, init_key=None):
            flag = IsTraining()(x, name='training', init_key=init_key)
            return flag + 1

        module = state.init(template2)(self._seed, state.Shape(()))
        self.assertEqual(module(scalar, training=True), 2.)
        self.assertEqual(module(scalar, training=False), 1.)
예제 #2
0
 def test_duplicate_names(self):
   """Expect a ValueError when two sub-layers share the name 'dense'."""
   def template(x, init_key=None):
     first_key, second_key = random.split(init_key)
     first = state.init(nn.Dense(20), name='dense')(first_key, x)
     second = state.init(nn.Dense(20), name='dense')(second_key, x)
     return first(x) + second(x)

   # Everything stays inside the context manager so the error is caught
   # wherever it is raised.
   with self.assertRaises(ValueError):
     state.init(template)(self._seed, jnp.ones(5))
예제 #3
0
 def test_kwargs_rng(self):
   """Check that a `Sampler` layer requires an `rng` argument and uses it.

   A template that never supplies `rng` is expected to fail with an
   AssertionError during initialization. A template that forwards `rng`
   must initialize, and calling the result with two different PRNG keys
   must give different outputs.
   """
   def template(x, init_key=None):
     return Sampler()(x, name='sampler', init_key=init_key)

   # The result is deliberately not bound: a name assigned here would never
   # be set when the expected exception fires.
   with self.assertRaises(AssertionError):
     state.init(template)(self._seed, jnp.ones(()))

   def template1(x, rng, init_key=None):
     return Sampler()(x, rng=rng, init_key=init_key)

   # NOTE(review): the uint32 pair presumably stands in for a raw PRNG key
   # during initialization -- confirm against `Sampler`.
   net = state.init(template1)(self._seed, jnp.ones(()),
                               jnp.ones(2, dtype=jnp.uint32))
   x1 = net(jnp.ones(()), random.PRNGKey(0))
   x2 = net(jnp.ones(()), random.PRNGKey(1))
   self.assertNotEqual(x1, x2)
예제 #4
0
 def template(x, init_key=None):
   """Run one named `ScalarMul` layer twice and return the first element.

   The layer is initialized under the name 'scalar_mul', applied twice
   through `call_and_update`, and the final layer object is assigned back
   under the same name before returning `x[0]`.
   """
   build_layer = state.init(ScalarMul(2 * jnp.ones(1)), name='scalar_mul')
   layer = build_layer(init_key, x)
   for _ in range(2):
     x, layer = layer.call_and_update(x)
   state.assign(layer, name='scalar_mul')
   return x[0]
예제 #5
0
    def test_grad_of_stateful_function(self):
        """Differentiate through a network that mixes a weight and a counter.

        The template chains `ScalarMul` (built from ones) and `Counter`
        (built from zeros). The gradient of the summed output is added to
        the network, and outputs are checked before and after a state
        update.
        """
        def template(x, init_key=None):
            x = ScalarMul(np.ones(1))(x, init_key=init_key, name='scalar_mul')
            x = Counter(np.zeros(1))(x, init_key=init_key, name='counter')
            return x

        net = state.init(template)(self._seed, state.Shape(5))

        def loss(net, x):
            return net(x).sum()

        def add(x, y):
            return x + y

        g = jax.grad(loss)(net, np.ones(5))
        # `tree_multimap` has been removed from JAX; `tree_map` accepts
        # several trees and applies `add` leaf-wise across `net` and `g`.
        net = tree_util.tree_map(add, net, g)
        # w_new = w_old + 5
        onp.testing.assert_array_equal(net(np.ones(5)), 6 * np.ones(5))
        # Advancing the state raises the asserted output by one.
        net = net.update(np.ones(5))
        onp.testing.assert_array_equal(net(np.ones(5)), 7 * np.ones(5))

        g = jax.grad(loss)(net, np.ones(5))
        net = tree_util.tree_map(add, net, g)
        # w_new = w_old + 5
        onp.testing.assert_array_equal(net(np.ones(5)), 12 * np.ones(5))
예제 #6
0
  def test_serialize(self, template):
    """Round-trip a network through serialize/deserialize and compare outputs.

    Both the original and the restored network are vmapped over the same
    batch of ones with the same PRNG key, and the results must be equal.
    """
    original = state.init(template)(random.PRNGKey(0), state.Shape(784))
    restored = deserialize(serialize(original))

    rng = random.PRNGKey(0)

    def batched(net):
      # Map the network over the leading axis of a (10, 784) batch.
      return jax.vmap(lambda x: net(x, rng=rng))(np.ones([10, 784]))

    onp.testing.assert_array_equal(batched(original), batched(restored))
예제 #7
0
 def update(params, updates, init_key=None):
   """Thread `updates` through every update function from the outer scope.

   `args` and `kwargs` are read from the enclosing scope. Positional
   functions are named 'update_<i>'; keyword functions keep their keyword
   as name. Each function is initialized with its own split of `init_key`
   and then applied, feeding its output to the next one.
   """
   update_fns = list(itertools.chain(args, kwargs.values()))
   names = [f'update_{i}' for i, _ in enumerate(args)] + list(kwargs)
   keys = random.split(init_key, len(update_fns))
   for name, key, update_fn in zip(names, keys, update_fns):
     step = state.init(update_fn, name=name)(key, params, updates)
     updates = step(params, updates)
   return updates
예제 #8
0
    def test_shared_layer(self):
        """Apply one `ScalarMul` layer twice; ones(5) must map to 4 * ones(5)."""
        def template(x, init_key=None):
            build_layer = state.init(ScalarMul(2 * np.ones(1)),
                                     name='scalar_mul')
            layer = build_layer(init_key, x)
            once = layer(x)
            return layer(once)

        net = state.init(template)(self._seed, state.Shape(5))
        actual = net(np.ones(5))
        onp.testing.assert_array_equal(actual, 4 * np.ones(5))
예제 #9
0
  def test_function_in_combinator(self):
    """Compose two `AddOne` layers around a plain function using `>>`.

    Starting from zeros, the asserted output is three everywhere.
    """
    def add_one(x):
      return x + 1

    # `>>` groups left to right, so this equals AddOne() >> add_one >> AddOne().
    front = AddOne() >> add_one
    template = front >> AddOne()

    net = state.init(template)(self._seed, jnp.ones(2))
    result = net(jnp.zeros(2))
    np.testing.assert_array_equal(result, 3 * jnp.ones(2))
예제 #10
0
    def test_optimizer(self, update):
        """Run `optix.optimize` on a sum-of-squares loss and expect ~zeros.

        The result of the jitted optimizer call is compared against zeros
        with absolute and relative tolerances of 0.1.
        """
        def loss(x):
            return np.sum(x**2)

        start = np.array([3., 4.])
        key = random.PRNGKey(0)
        opt = state.init(optix.optimize(loss, update, 500))(key, start)
        optimized = jax.jit(opt.call)(start)
        onp.testing.assert_allclose(np.zeros(2), optimized,
                                    atol=1e-1, rtol=1e-1)
예제 #11
0
 def test_trace_should_keep_track_of_momentum(self):
     """Check the `trace` state and output of `optix.trace(0.99, False)`.

     The trace starts at zeros, equals the updates after one update step,
     and the next call is asserted to return 1.99 per element.
     """
     params, updates = np.zeros(6), np.ones(6)
     key = random.PRNGKey(0)
     tracer = state.init(optix.trace(0.99, False))(key, params, updates)
     # Freshly initialized: nothing accumulated yet.
     onp.testing.assert_array_equal(tracer.trace, np.zeros(6))
     tracer = tracer.update(params, updates)
     onp.testing.assert_array_equal(tracer.trace, np.ones(6))
     output = tracer(params, updates)
     onp.testing.assert_array_equal(output, 1.99 * np.ones(6))
예제 #12
0
  def test_identity(self):
    """An identity function keeps its spec and returns its input unchanged."""
    def identity(x):
      return x

    # Spec propagation: the output spec must equal the input spec.
    in_spec = state.Shape(50)
    self.assertEqual(state.spec(identity)(in_spec), in_spec)

    # Value propagation: the initialized module echoes its argument.
    net = state.init(identity)(self._seed, jnp.ones(50))
    sample = jnp.arange(5)
    np.testing.assert_array_equal(net(sample), jnp.arange(5))
예제 #13
0
  def test_call_list(self):
    """`state.call` on a list of layers: asserted output is 1, then 2 after update."""
    def template(x, init_key=None):
      stages = [Counter(0.), AddOne()]
      return state.call(stages, x, init_key=init_key,
                        name='counter_add_one')

    zero = jnp.zeros(())
    layer = state.init(template)(self._seed, jnp.ones(()))
    self.assertEqual(layer(zero), 1)

    # After one state update the asserted output grows by one.
    layer = layer.update(zero)
    self.assertEqual(layer(zero), 2)
예제 #14
0
  def test_function_in_combinator_in_function(self):
    """Build an `AddOne >> add_one >> AddOne` chain inside a template.

    The initialized network is called on zeros (passing `init_key` again at
    call time) and the asserted output is three everywhere.
    """
    def add_one(x):
      return x + 1

    def template(x, init_key=None):
      chain = AddOne() >> add_one >> AddOne()
      return chain(x, init_key=init_key)

    net = state.init(template)(self._seed, jnp.ones(2))
    result = net(jnp.zeros(2), init_key=self._seed)
    np.testing.assert_array_equal(result, 3 * jnp.ones(2))
예제 #15
0
  def test_call_tuple(self):
    """`state.call` on a tuple of layers returns a tuple of outputs.

    The asserted outputs are (0, 1) at first and (1, 1) after one update.
    """
    def template(x, init_key=None):
      branches = (Counter(0.), AddOne())
      return state.call(branches, x, init_key=init_key,
                        name='counter_add_one')

    zero = jnp.zeros(())
    layer = state.init(template)(self._seed, jnp.ones(()))
    self.assertTupleEqual(layer(zero), (0, 1))

    layer = layer.update(zero)
    self.assertTupleEqual(layer(zero), (1, 1))
예제 #16
0
 def test_trace_should_keep_track_of_momentum_with_nesterov(self):
   """Check the `trace` state and output of `optix.trace(0.99, True)`.

   The trace starts at zeros and equals the updates after one update step;
   the next call is asserted to return 1.99 + 0.99**2 per element.
   """
   params, updates = jnp.zeros(6), jnp.ones(6)
   key = random.PRNGKey(0)
   tracer = state.init(optix.trace(0.99, True))(key, updates, params)
   np.testing.assert_array_equal(tracer.trace, jnp.zeros(6))
   tracer = tracer.update(updates, params)
   np.testing.assert_array_equal(tracer.trace, jnp.ones(6))
   output = tracer(updates, params)
   np.testing.assert_array_equal(output, (1.99 + 0.99**2) * jnp.ones(6))
예제 #17
0
 def test_update_in_combinator(self):
   """Update a network whose `nn.Serial` holds the same counter function twice.

   The asserted output is 1 before the update and 3 after it.
   """
   def template(x, init_key=None):
     def increment(x, init_key=None):
       counter = Counter(jnp.zeros(()))
       return counter(x, init_key=init_key, name='counter')
     serial = nn.Serial([increment, increment])
     return serial(x, init_key=init_key, name='increment')

   one = jnp.ones(())
   net = state.init(template)(self._seed, jnp.ones(()))
   self.assertEqual(net(one), 1.)
   net = state.update(net, one)
   self.assertEqual(net(one), 3.)
예제 #18
0
 def test_shared_layer(self):
   """Apply one named `ScalarMul` layer twice via `call_and_update`.

   The final layer object is assigned back under 'scalar_mul', and ones(5)
   is asserted to map to 4 * ones(5).
   """
   def template(x, init_key=None):
     build_layer = state.init(ScalarMul(2 * jnp.ones(1)), name='scalar_mul')
     layer = build_layer(init_key, x)
     for _ in range(2):
       x, layer = layer.call_and_update(x)
     state.assign(layer, name='scalar_mul')
     return x

   net = state.init(template)(self._seed, jnp.ones(5))
   result = net(jnp.ones(5))
   np.testing.assert_array_equal(result, 4 * jnp.ones(5))
예제 #19
0
 def test_add_noise_should_add_noise(self):
     """Check the initial state and first output of `optix.add_noise(1., 0., 0)`.

     The module must start with `count == 0` and `rng_key == PRNGKey(0)`;
     the first output must equal the update plus a normal draw taken with
     the second half of that key's split.
     """
     params, updates = 0., 1.
     noisy = state.init(optix.add_noise(1., 0., 0))(random.PRNGKey(0), params,
                                                    updates)
     onp.testing.assert_array_equal(noisy.count, 0.)
     onp.testing.assert_array_equal(noisy.rng_key, random.PRNGKey(0))
     # The updated module is not inspected further, so it is discarded.
     value, _ = noisy.call_and_update(params, updates)
     noise_key = random.split(random.PRNGKey(0))[1]
     expected = 1. + random.normal(noise_key, ())
     onp.testing.assert_array_equal(value, expected)
예제 #20
0
 def test_scale_by_rms_should_scale_by_rms(self):
     """Check the `nu` state and output of `optix.scale_by_rms(0.5, 0.)`.

     `nu` starts at zeros and is asserted to be 2 after one update with
     updates of 2; the next call is asserted to return 2 / sqrt(3).
     """
     params = np.zeros(9)
     updates = 2 * np.ones(9)
     key = random.PRNGKey(0)
     scaler = state.init(optix.scale_by_rms(0.5, 0.))(key, params, updates)
     onp.testing.assert_array_equal(scaler.nu, np.zeros(9))
     scaler = scaler.update(params, updates)
     onp.testing.assert_array_equal(scaler.nu, 2 * np.ones(9))
     output = scaler(params, updates)
     onp.testing.assert_array_equal(output, 2 * np.ones(9) / np.sqrt(3.))
예제 #21
0
  def test_update(self):
    """`state.update(net, x)` and `net.update(x)` must both advance a counter.

    The asserted output is 1 before updating and 2 after either form.
    """
    def template(x, init_key=None):
      counter = Counter(jnp.zeros(()))
      return counter(x, init_key=init_key, name='counter')

    one = jnp.ones(())
    net = state.init(template)(self._seed, jnp.ones(()))
    self.assertEqual(net(one), 1.)

    # Functional form.
    via_function = state.update(net, one)
    self.assertEqual(via_function(one), 2.)

    # Method form, starting again from the original network.
    via_method = net.update(one)
    self.assertEqual(via_method(one), 2.)
예제 #22
0
 def test_apply_every_should_delay_updates(self):
   """Check that `optix.apply_every(5)` withholds output for four steps.

   `count` and `grad_acc` start at zero. The first four calls are asserted
   to return 0, and the fifth call is asserted to return 5.
   """
   params, updates = 0., 1.
   key = random.PRNGKey(0)
   delayed = state.init(optix.apply_every(5))(key, updates, params)
   np.testing.assert_array_equal(delayed.count, 0.)
   np.testing.assert_array_equal(delayed.grad_acc, 0.)
   for _ in range(4):
     step_value, delayed = delayed.call_and_update(updates, params)
     np.testing.assert_array_equal(step_value, 0.)
   final_value = delayed(updates, params)
   np.testing.assert_array_equal(final_value, 5.)
예제 #23
0
    def test_single_input_multiple_output(self):
        """A function returning its input twice yields a pair of equal specs/values."""
        def dup(x):
            return x, x

        in_spec = state.Shape(50)
        self.assertEqual(state.spec(dup)(in_spec), (in_spec, in_spec))

        net = state.init(dup)(self._seed, in_spec)
        outputs = net(np.arange(5))
        expected = (np.arange(5), np.arange(5))
        # zip stops at the shorter sequence, exactly like the pairwise check
        # it replaces.
        for actual, wanted in zip(outputs, expected):
            onp.testing.assert_array_equal(actual, wanted)
예제 #24
0
 def test_scale_by_stddev_should_scale_by_stddev(self):
   """Check the `mu`/`nu` state and output of `optix.scale_by_stddev(0.5, 0.)`.

   Both moments start at zeros. After one update with updates of 2, `mu` is
   asserted to be 1 and `nu` to be 2; the next call is asserted to return
   2 / sqrt(0.75).
   """
   params = jnp.zeros(9)
   updates = 2 * jnp.ones(9)
   key = random.PRNGKey(0)
   scaler = state.init(optix.scale_by_stddev(0.5, 0.))(key, updates, params)
   np.testing.assert_array_equal(scaler.mu, jnp.zeros(9))
   np.testing.assert_array_equal(scaler.nu, jnp.zeros(9))
   scaler = scaler.update(updates, params)
   np.testing.assert_array_equal(scaler.mu, jnp.ones(9))
   np.testing.assert_array_equal(scaler.nu, 2 * jnp.ones(9))
   output = scaler(updates, params)
   np.testing.assert_array_equal(output, 2 * jnp.ones(9) / jnp.sqrt(0.75))
예제 #25
0
  def test_dense_combinator(self):
    """Check spec, shape and values of a `Dense(50) >> Dense(20)` chain.

    The initialized network must agree (rtol 1e-5) with calling the
    template directly using the same `init_key`.
    """
    def dense(x, init_key=None):
      chain = nn.Dense(50) >> nn.Dense(20)
      return chain(x, init_key=init_key, name='dense')

    in_spec = state.Shape(50)
    out_spec = state.spec(dense)(in_spec)
    self.assertEqual(out_spec, state.Shape(20, dtype=in_spec.dtype))

    net = state.init(dense)(self._seed, jnp.ones(2))
    self.assertTupleEqual(net(jnp.ones(2)).shape, out_spec.shape)
    direct = dense(jnp.ones(2), init_key=self._seed)
    np.testing.assert_allclose(direct, net(jnp.ones(2)), rtol=1e-5)
예제 #26
0
  def test_dense_function(self):
    """Check that a `Dense` template needs `init_key`, and works when it has one.

    A template that does not accept or forward `init_key` is expected to
    raise ValueError from both `state.spec` and `state.init`. A template
    that forwards `init_key` must report a `Shape(20)` output spec, produce
    a (20,) output, and agree (rtol 1e-5) with a direct call using the same
    key.
    """
    def dense_no_rng(x):
      return nn.Dense(20)(x, name='dense')

    # Results are deliberately not bound in the two blocks below: a name
    # assigned there would never be set when the expected error fires.
    with self.assertRaises(ValueError):
      state.spec(dense_no_rng)(state.Shape(50))

    with self.assertRaises(ValueError):
      state.init(dense_no_rng)(self._seed, state.Shape(2))

    def dense(x, init_key=None):
      return nn.Dense(20)(x, init_key=init_key, name='dense')

    out_spec = state.spec(dense)(state.Shape(2))
    self.assertEqual(out_spec, state.Shape(20))

    net = state.init(dense)(self._seed, jnp.ones(2))
    self.assertTupleEqual(net(jnp.ones(2)).shape, (20,))
    np.testing.assert_allclose(net(jnp.ones(2)),
                               dense(jnp.ones(2), init_key=self._seed),
                               rtol=1e-5)
예제 #27
0
    def run(params, init_key=None):
        """Initialize the 'opt' module and step it over `params` with `lax.scan`.

        `gradient_descent`, `update`, `objective` and `num_iters` are read
        from outside this function. Returns the parameters taken from the
        final scan carry.
        """
        opt = state.init(gradient_descent(update, objective),
                         name='opt')(init_key, params)

        def body(carry, _):
            # The carry is (module, params). Each step produces new params
            # and an updated module; nothing is emitted per step.
            opt, params = carry
            params, opt = opt.call_and_update(params)
            return (opt, params), ()

        # Only the final carry (element 0 of the scan result) is kept.
        opt, params = lax.scan(body, (opt, params), np.arange(num_iters))[0]
        # NOTE(review): presumably `tie_all` ties the assignment of the final
        # module state to the returned params so it is not dropped -- confirm.
        opt, params = primitive.tie_all(state.assign(opt, name='opt'), params)
        return params
예제 #28
0
  def test_multiple_input_multiple_output(self):
    """A two-argument swap function swaps both its specs and its values."""
    def swap(x, y):
      return y, x

    first_spec, second_spec = state.Shape(50), state.Shape(20)
    out_spec = state.spec(swap)(first_spec, second_spec)
    self.assertEqual(out_spec, (second_spec, first_spec))

    net = state.init(swap)(self._seed, jnp.ones(50), jnp.ones(20))
    outputs = net(jnp.zeros(50), jnp.ones(20))
    expected = (jnp.ones(20), jnp.zeros(50))
    for actual, wanted in zip(outputs, expected):
      np.testing.assert_array_equal(actual, wanted)
예제 #29
0
 def test_grad_of_function_constant(self):
   """Add the gradient to a network built from `x + ones_like(x)`.

   After the gradient of the summed output is added to the network, the
   output for ones(5) is asserted to still be 2 * ones(5).
   """
   def template(x):
     return x + jnp.ones_like(x)
   net = state.init(template)(self._seed, jnp.ones(5))
   def loss(net, x):
     return net(x).sum()
   g = jax.grad(loss)(net, jnp.ones(5))
   def add(x, y):
     return x + y
   # `tree_multimap` has been removed from JAX; `tree_map` accepts several
   # trees and applies `add` leaf-wise across `net` and `g`.
   net = tree_util.tree_map(add, net, g)
   # The template shows no weight term, and the asserted output is the same
   # x + 1 as before the gradient step.
   np.testing.assert_array_equal(net(jnp.ones(5)), 2 * jnp.ones(5))
예제 #30
0
 def test_grad_of_function(self):
   """Add the gradient to a `ScalarMul` network and check the new output.

   The layer is built from ones(1). After the gradient of the summed
   output over ones(5) is added to the network, the output for ones(5) is
   asserted to be 6 * ones(5).
   """
   def template(x, init_key=None):
     # jnp.ones(1) does not behave like a literal when tracing
     return ScalarMul(jnp.ones(1))(x, init_key=init_key, name='scalar_mul')
   net = state.init(template)(self._seed, jnp.ones(5))
   def loss(net, x):
     return net(x).sum()
   g = jax.grad(loss)(net, jnp.ones(5))
   def add(x, y):
     return x + y
   # `tree_multimap` has been removed from JAX; `tree_map` accepts several
   # trees and applies `add` leaf-wise across `net` and `g`.
   net = tree_util.tree_map(add, net, g)
   # w_new = w_old + 5
   np.testing.assert_array_equal(net(jnp.ones(5)), 6 * jnp.ones(5))