def test_parametrized_jit(jitted_fun):
    """Init/apply of a jitted parametrized function: reinit is deterministic,
    repeated (cached) applications and `jit=True` all agree, including via
    `apply_from` with the net's own parameters."""
    net = parametrized(jitted_fun)
    inputs = random_inputs((2, ))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert type(params).__name__ == 'fun'
    assert params.dense.bias.shape == (3, )

    # Re-initializing with the same key must reproduce the same parameters.
    params_ = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(params, params_)

    out = net.apply(params, inputs)
    assert out.shape == (3, )
    assert np.allclose([0.84194356, -1.5927866, -1.7411114], out)

    # run twice to cover cached jit call
    out_ = net.apply(params, inputs)
    assert np.allclose(out, out_)

    out = net.apply(params, inputs, jit=True)
    assert np.allclose(out, out_)
    out_ = net.apply(params, inputs, jit=True)
    assert np.allclose(out, out_)

    out_ = net.apply_from({net: params}, inputs, jit=True)
    assert np.allclose(out, out_)
def test_Batched():
    """The Batched combinator matches a manual vmap over the unbatched apply,
    while reusing the unbatched module's parameters."""
    out_dim = 1

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim, ), ones)
        return np.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(np.zeros(2), key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, np.ones(2))
    assert np.array([3.]) == out

    # Reference: batch by vmapping apply over the leading input axis only.
    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, np.ones((batch_size, 2)))
    assert np.array_equal(np.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(np.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params, ), params)

    out_batched = dense.apply(params, np.ones((batch_size, 2)))
    assert np.array_equal(out_batched_, out_batched)
def test_internal_param_sharing2():
    """A default-argument submodule applied twice contributes its parameters once."""
    @parametrized
    def shared_net(inputs, layer=Sequential(Dense(2, zeros, zeros), relu)):
        return layer(layer(inputs))

    inputs = np.zeros((1, 2))

    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal((((np.zeros((2, 2)), np.zeros(2)), ), ), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)
def test_submodule_without_inputs():
    """A parametrized function with no inputs still initializes and applies."""
    @parametrized
    def scalar():
        return Parameter(lambda key: np.zeros(()))()

    params = scalar.init_parameters(key=PRNGKey(0))
    assert_parameters_equal((np.zeros(()), ), params)

    out = scalar.apply(params)
    assert out == np.zeros(())

    out_ = scalar.apply(params, jit=True)
    assert out_ == out
def test_external_param_sharing():
    """The same layer object used twice in a Sequential shares one parameter set."""
    layer = Dense(2, zeros, zeros)
    shared_net = Sequential(layer, layer)

    inputs = np.zeros((1, 2))
    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2)), ), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(np.zeros((1, 2)), out)
def test_submodule_without_inputs2():
    """Jitted apply of a no-input parametrized function.

    NOTE(review): this test previously reused the name
    `test_submodule_without_inputs`, which shadowed the earlier definition so
    that one of the two tests silently never ran under pytest. It also used a
    stale API: the `Parameter` module was returned without being called, its
    initializer took no PRNG key, and `init_parameters` was given the key
    positionally — unlike every other test in this file. Renamed and aligned
    with the file-wide convention.
    """
    @parametrized
    def scalar():
        # Parameter initializers receive a PRNG key, and the resulting module
        # must be called to produce a value.
        return Parameter(lambda key: np.zeros(()))()

    params = scalar.init_parameters(key=PRNGKey(0))
    assert_parameters_equal((np.zeros(()), ), params)

    out = scalar.apply(params)
    assert np.array_equal(np.zeros(()), out)

    out_ = scalar.apply(params, jit=True)
    assert np.array_equal(out, out_)
def test_no_params():
    """A parametrized function without any parameters yields an empty tuple."""
    @parametrized
    def double(inputs):
        return 2 * inputs

    inputs = np.zeros((1, 3))

    params = double.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal((), params)

    out = double.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 3)), out)

    out_ = double.apply(params, inputs, jit=True)
    assert np.array_equal(out, out_)
def test_internal_param_sharing():
    """A default-argument layer applied to its own output shares one parameter set.

    Fixed: `init_parameters` was called as `init_parameters(PRNGKey(0), inputs)`,
    inconsistent with the `init_parameters(inputs, key=PRNGKey(0))` signature
    used by every other test in this file.
    """
    @parametrized
    def shared_net(inputs, layer=Dense(2, zeros, zeros)):
        return layer(layer(inputs))

    inputs = np.zeros((1, 2))

    # Inputs positionally, key as keyword — the convention used file-wide.
    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2),),), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(out, out_)
def test_Dense_shape(Dense=Dense):
    """Dense yields kernel (in, out) and bias (out,); shaped() init matches.

    Fixed: both `init_parameters` calls passed `PRNGKey(0)` positionally
    (`init_parameters(PRNGKey(0), inputs)` / `shaped(...).init_parameters(PRNGKey(0))`),
    inconsistent with the `key=` keyword convention used by every other test
    in this file.
    """
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    inputs = np.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal((np.zeros((3, 2)), np.zeros(2)), params)

    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = jit(net.apply)(params, inputs)
    assert np.array_equal(out, out_)

    # Initializing from a shaped (input-specialized) module must give the
    # same parameters as initializing with concrete inputs.
    params_ = net.shaped(inputs).init_parameters(key=PRNGKey(0))
    assert_parameters_equal(params, params_)
def test_parameters_from():
    """parameters_from / apply_from reuse externally initialized submodule parameters."""
    layer = Dense(2)
    net = Sequential(layer, relu)
    inputs = np.zeros((1, 3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({layer: layer_params}, inputs)
    assert_parameters_equal((layer_params, ), params_)

    out = net.apply(params_, inputs)

    out_ = net.apply_from({layer: layer_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({layer: layer_params}, inputs, jit=True)
    assert np.array_equal(out, out_)
def test_external_submodule2():
    """A module captured from the enclosing scope is tracked as a submodule."""
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = np.zeros((1, 2))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2)), ), params)

    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = net.apply(params, inputs, jit=True)
    assert np.array_equal(out, out_)
def test_parameters_from_sharing_between_multiple_parents():
    """Parameters supplied for `a` are reused where `a` appears inside `b` as well."""
    a = Dense(2)
    b = Sequential(a, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))

    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    # `b`'s slot is empty because its only submodule (`a`) is shared.
    assert_parameters_equal((), params.sequential)
    assert len(params) == 2

    out_, _ = net.apply(params, inputs)
    assert np.array_equal(out, out_)
def test_scan_parametrized_cell_without_params():
    """lax.scan over a parameterless parametrized cell inits and applies."""
    @parametrized
    def cell(carry, x):
        scaled = np.array([2]) * carry * x
        return scaled, scaled

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, np.zeros((2, )), inputs)
        return outs

    inputs = np.zeros((3, ))

    params = rnn.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((), ), params)

    outs = rnn.apply(params, inputs)
    assert outs.shape == (3, 2)
def test_parameters_from_shared_submodules():
    """apply_from agrees across every combination of shaped/unshaped net and
    submodule keys, with and without jit, when a submodule is shared."""
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    inputs = np.zeros((1, 3))

    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_parameters_equal(a_params.dense.kernel, params.sequential0.dense.kernel)
    # The second Sequential's slot is empty: its Dense is shared with the first.
    assert_parameters_equal((), params.sequential1)
    out = net.apply(params, inputs)

    # Unshaped net, unshaped and shaped submodule keys:
    out_ = net.apply_from({a: a_params}, inputs)
    assert np.array_equal(out, out_)
    out_ = net.apply_from({a: a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)
    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs)
    assert np.array_equal(out, out_)
    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    # Shaped net, unshaped and shaped submodule keys:
    out_ = net.shaped(inputs).apply_from({a: a_params})
    assert np.array_equal(out, out_)
    out_ = net.shaped(inputs).apply_from({a: a_params}, jit=True)
    assert np.array_equal(out, out_)
    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params})
    assert np.array_equal(out, out_)
    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params}, jit=True)
    assert np.array_equal(out, out_)