def test_init_and_apply():
    """Initialise and apply a network built from ``layer`` on a single input.

    A length-2 random vector is used both to initialise the parameters via
    ``core.init_fun`` and as the input to ``core.apply_fun``; the network's
    output is expected to have shape ``(3,)``.
    """

    def net_fun(inputs):
        # The network is ``layer`` followed by a scalar multiplication, so the
        # test also covers mixing ``layer`` output with plain array arithmetic.
        return 2 * layer(inputs)

    sample = random.normal(RNG, (2,))

    # The same RNG and sample drive parameter creation and the forward pass.
    variables = core.init_fun(net_fun, RNG, sample)
    result = core.apply_fun(net_fun, variables, sample)

    assert result.shape == (3,)
def test_apply_batch():
    """Apply parameters initialised on one example to a vmapped batch.

    Parameters come from a single length-2 input. The network is then wrapped
    in ``vmap`` and applied to a stack of four copies of that input. The
    output is expected to have shape ``(4, 3)``.
    """

    def net_fun(inputs):
        # Same doubled-``layer`` network as the unbatched test.
        return 2 * layer(inputs)

    single = random.normal(RNG, (2,))
    params = core.init_fun(net_fun, RNG, single)

    # Four identical copies stacked along a new leading axis form the batch.
    batch = np.stack([single] * 4)

    # vmap maps ``net_fun`` over the leading axis, so the unbatched params
    # can be reused unchanged for every batch element.
    batched_net = vmap(net_fun)
    out = core.apply_fun(batched_net, params, batch)

    assert out.shape == (4, 3)