def test_sequential_params(self):
    """Checks the parameter names reported by a nested `Sequential`.

    The outer `Sequential` wraps two inner ones: the first holds two
    `Linear(2)` layers built up front, the second holds a lambda that builds
    a fresh `Linear(2)` each time it is called. After connecting the outer
    module twice, the keys of `params_dict()` are compared (ignoring order)
    against the expected names.
    """
    eager_branch = basic.Sequential([basic.Linear(2), basic.Linear(2)])
    lazy_branch = basic.Sequential([lambda x: basic.Linear(2)(x * 1)])
    seq = basic.Sequential([eager_branch, lazy_branch])

    # Parameters only exist once the module has been connected. The module is
    # connected twice so that the lambda gets to build its `Linear` on both
    # passes.
    for _ in range(2):
        seq(jnp.zeros([1, 1]))

    # NOTE(review): the intent above mentions two lambda-built `Linear`
    # instances, yet only one `sequential_1/linear` pair is expected here —
    # confirm against the module naming rules.
    expected_names = [
        "linear/w",
        "linear/b",
        "linear_1/w",
        "linear_1/b",
        "sequential_1/linear/w",
        "sequential_1/linear/b",
    ]
    actual_names = list(seq.params_dict())
    self.assertCountEqual(actual_names, expected_names)
def test_sequential(self):
    """Checks the output shape of a `Linear(2)` + ReLU `Sequential`.

    A zero-filled input of shape `[3, 2]` is passed through the stack and the
    result is asserted to have shape `(3, 2)`.
    """
    layers = [basic.Linear(2), jax.nn.relu]
    model = basic.Sequential(layers)
    inputs = jnp.zeros([3, 2])
    result = model(inputs)
    expected_shape = (3, 2)
    self.assertEqual(result.shape, expected_shape)
def f():
    """Runs a `Linear(2)` + ReLU `Sequential` on a zero-filled `[3, 2]` input.

    Returns:
        Whatever the `Sequential` produces for that input.
    """
    layers = [basic.Linear(2), jax.nn.relu]
    model = basic.Sequential(layers)
    inputs = jnp.zeros([3, 2])
    return model(inputs)