Beispiel #1
0
 def test_sequential_params(self):
   """Checks the parameter names produced by nested Sequential modules.

   The second inner Sequential wraps a lambda that constructs a brand-new
   basic.Linear every time it runs, so this also pins down how parameters
   of a Linear created during a call are named across repeated calls.
   """
   # NOTE(review): the expected names asserted below appear to depend on the
   # order in which modules are constructed -- do not reorder these
   # constructors without re-deriving the expected list.
   seq = basic.Sequential([
       basic.Sequential([basic.Linear(2), basic.Linear(2)]),
       # A new basic.Linear object is built on every invocation of the lambda.
       basic.Sequential([lambda x: basic.Linear(2)(x * 1)])])
   for _ in range(2):
     # Connect seq to ensure params are created. Connect twice so that the
     # lambda builds two separate Linear instances.
     seq(jnp.zeros([1, 1]))
   params = seq.params_dict()
   # Although the lambda built two Linear instances, only one
   # "sequential_1/linear" set of w/b is expected: both instances are
   # expected to resolve to the same parameter names.
   self.assertCountEqual(
       list(params),
       # "linear" / "linear_1": presumably the two Linears in the first inner
       # Sequential; "sequential_1/linear": presumably the lambda's Linear.
       ["linear/w", "linear/b", "linear_1/w", "linear_1/b",
        "sequential_1/linear/w", "sequential_1/linear/b"])
Beispiel #2
0
 def test_sequential(self):
   """A Linear(2) followed by relu maps a [3, 2] input to a (3, 2) output."""
   model = basic.Sequential([basic.Linear(2), jax.nn.relu])
   inputs = jnp.zeros([3, 2])
   self.assertEqual(model(inputs).shape, (3, 2))
Beispiel #3
0
 def f():
   """Applies a basic.Linear(2) -> jax.nn.relu stack to zeros of shape [3, 2].

   Returns the stack's output for that all-zeros input.

   NOTE(review): this constructs modules, so it presumably has to be called
   inside whatever transform context `basic` requires -- confirm with caller.
   """
   layers = [basic.Linear(2), jax.nn.relu]
   return basic.Sequential(layers)(jnp.zeros([3, 2]))