示例#1
0
  def test_nested_add(self):
    """`state.spec`/`state.init` trace through a function that calls a lambda.

    The inner lambda doubles its argument, so the output spec must equal the
    input spec and the initialized network must compute ``2 * x``.
    """
    def add(x):
      # The lambda's own `x` shadows the outer one; the nested call is the
      # construct under test.
      return (lambda x: jnp.add(x, x))(x)

    for shape in (50, (100, 20)):
      in_spec = state.Shape(shape)
      out_spec = state.spec(add)(in_spec)
      self.assertEqual(out_spec, in_spec)

    # Initialization is driven by a concrete example array, not by a spec.
    net = state.init(add)(self._seed, jnp.ones(2))
    np.testing.assert_array_equal(net(jnp.arange(5)), 2 * jnp.arange(5))
示例#2
0
    def test_add(self):
        """Doubling via `np.add` keeps the spec and evaluates to ``2 * x``."""
        def add(x):
            return np.add(x, x)

        for shape in (50, (100, 20)):
            spec_in = state.Shape(shape)
            self.assertEqual(state.spec(add)(spec_in), spec_in)

        spec_in = state.Shape(2)
        net = state.init(add)(self._seed, spec_in)
        inputs = np.arange(5)
        onp.testing.assert_array_equal(net(inputs), 2 * np.arange(5))
示例#3
0
  def test_multiple_input_single_output(self):
    """Binary `jnp.add`: mismatched input specs fail, matching ones pass."""
    def add(x, y):
      return jnp.add(x, y)

    # (100, 20) and (50,) do not broadcast, so spec inference must raise.
    mismatched = (state.Shape((100, 20)), state.Shape(50))
    with self.assertRaises(ValueError):
      state.spec(add)(*mismatched)

    matching = (state.Shape(50), state.Shape(50))
    self.assertEqual(state.spec(add)(*matching), matching[0])

    net = state.init(add)(self._seed, jnp.ones(50), jnp.ones(50))
    x = jnp.arange(5)
    np.testing.assert_array_equal(net(x, x), 2 * jnp.arange(5))
示例#4
0
    def test_add_one_combinator(self):
        """Two `AddOne` layers chained with ``>>`` behave like one layer."""
        def add_two(x, init_key=None):
            chained = AddOne() >> AddOne()
            return chained(x, name='add_one', init_key=init_key)

        for shape in (20, (5, 50)):
            spec_in = state.Shape(shape)
            self.assertEqual(state.spec(add_two)(spec_in), spec_in)

        net = state.init(add_two)(self._seed, state.Shape(2))
        # Adding one twice turns a vector of ones into threes.
        onp.testing.assert_allclose(net(np.ones(2)), 3 * np.ones(2))
        # The initialized network must agree with a direct call that uses the
        # same key.
        onp.testing.assert_array_equal(
            net(np.ones(2)),
            add_two(np.ones(2), init_key=self._seed),
        )
示例#5
0
  def test_identity(self):
    """An identity function leaves both the spec and the values untouched."""
    def identity(x):
      return x

    spec_in = state.Shape(50)
    self.assertEqual(state.spec(identity)(spec_in), spec_in)

    net = state.init(identity)(self._seed, jnp.ones(50))
    inputs = jnp.arange(5)
    np.testing.assert_array_equal(net(inputs), jnp.arange(5))
示例#6
0
    def test_add_one_imperative(self):
        """Two `AddOne` layers applied one after another with split keys."""
        def add_two(x, init_key=None):
            first_key, second_key = random.split(init_key)
            hidden = AddOne()(x, name='add_one_1', init_key=first_key)
            return AddOne()(hidden, name='add_one_2', init_key=second_key)

        for shape in (20, (5, 50)):
            spec_in = state.Shape(shape)
            self.assertEqual(state.spec(add_two)(spec_in), spec_in)

        net = state.init(add_two)(self._seed, state.Shape(2))
        onp.testing.assert_array_equal(net(np.ones(2)), 3 * np.ones(2))
        # The initialized network must agree with a direct call that uses the
        # same key.
        onp.testing.assert_array_equal(
            net(np.ones(2)),
            add_two(np.ones(2), init_key=self._seed),
        )
示例#7
0
    def test_single_input_multiple_output(self):
        """A function returning ``(x, x)`` yields a pair of identical outputs."""
        def dup(x):
            return x, x

        in_spec = state.Shape(50)
        out_spec = state.spec(dup)(in_spec)
        self.assertEqual(out_spec, (in_spec, in_spec))

        net = state.init(dup)(self._seed, in_spec)
        outputs = net(np.arange(5))
        expected = (np.arange(5), np.arange(5))
        # `zip` stops at the shorter sequence, so check the number of outputs
        # explicitly; otherwise a missing or extra output would go unnoticed.
        self.assertEqual(len(outputs), len(expected))
        for actual, want in zip(outputs, expected):
            onp.testing.assert_array_equal(actual, want)
示例#8
0
  def test_dense_combinator(self):
    """Two `nn.Dense` layers composed with ``>>`` produce a 20-wide output."""
    def dense(x, init_key=None):
      stack = nn.Dense(50) >> nn.Dense(20)
      return stack(x, init_key=init_key, name='dense')

    in_spec = state.Shape(50)
    out_spec = state.spec(dense)(in_spec)
    expected_spec = state.Shape(20, dtype=in_spec.dtype)
    self.assertEqual(out_spec, expected_spec)

    net = state.init(dense)(self._seed, jnp.ones(2))
    self.assertTupleEqual(net(jnp.ones(2)).shape, out_spec.shape)
    # A direct call with the same key must match the initialized network.
    np.testing.assert_allclose(
        dense(jnp.ones(2), init_key=self._seed),
        net(jnp.ones(2)),
        rtol=1e-5,
    )
示例#9
0
  def test_dense_function(self):
    """`nn.Dense` called without an `init_key` must make spec/init fail."""
    def dense_no_rng(x):
      return nn.Dense(20)(x, name='dense')

    # The test expects ValueError from both entry points when no key is
    # threaded through to the layer.
    with self.assertRaises(ValueError):
      state.spec(dense_no_rng)(state.Shape(50))
    with self.assertRaises(ValueError):
      state.init(dense_no_rng)(self._seed, state.Shape(2))

    def dense(x, init_key=None):
      return nn.Dense(20)(x, init_key=init_key, name='dense')

    self.assertEqual(state.spec(dense)(state.Shape(2)), state.Shape(20))

    net = state.init(dense)(self._seed, jnp.ones(2))
    self.assertTupleEqual(net(jnp.ones(2)).shape, (20,))
    np.testing.assert_allclose(
        net(jnp.ones(2)),
        dense(jnp.ones(2), init_key=self._seed),
        rtol=1e-5,
    )
示例#10
0
  def test_multiple_input_multiple_output(self):
    """A two-in/two-out swap returns its specs and values in reverse order."""
    def swap(x, y):
      return y, x

    in_spec = (state.Shape(50), state.Shape(20))
    out_spec = state.spec(swap)(*in_spec)
    self.assertEqual(out_spec, (in_spec[1], in_spec[0]))

    net = state.init(swap)(self._seed, jnp.ones(50), jnp.ones(20))
    outputs = net(jnp.zeros(50), jnp.ones(20))
    expected = (jnp.ones(20), jnp.zeros(50))
    # `zip` truncates silently, so pin the number of outputs first; otherwise
    # a missing or extra output would go unnoticed.
    self.assertEqual(len(outputs), len(expected))
    for actual, want in zip(outputs, expected):
      np.testing.assert_array_equal(actual, want)
示例#11
0
  def test_dense_imperative(self):
    """Two `nn.Dense` layers applied step by step with split keys."""
    def dense(x, init_key=None):
      first_key, second_key = random.split(init_key)
      hidden = nn.Dense(50)(x, init_key=first_key, name='dense1')
      return nn.Dense(20)(hidden, init_key=second_key, name='dense2')

    in_spec = state.Shape(50)
    out_spec = state.spec(dense)(in_spec)
    self.assertEqual(out_spec, state.Shape(20, dtype=in_spec.dtype))

    net = state.init(dense)(self._seed, jnp.ones(2))
    self.assertTupleEqual(net(jnp.ones(2)).shape, out_spec.shape)
    # A direct call with the same key must match the initialized network.
    np.testing.assert_allclose(
        dense(jnp.ones(2), init_key=self._seed),
        net(jnp.ones(2)),
        rtol=1e-5,
    )
示例#12
0
 def spec(cls, *args):
     """Return the spec obtained by applying ``state.spec`` to the layer inits.

     The last positional argument is the collection of layer inits; it is
     converted to a list and passed to ``state.spec``. All preceding
     positional arguments are forwarded as input specs to the result.

     Raises:
       IndexError: if called with no positional arguments.
     """
     layer_inits = args[-1]
     in_specs = args[:-1]
     spec_fn = state.spec(list(layer_inits))
     return spec_fn(*in_specs)