def test_internal_param_sharing2():
    """A Sequential default-argument submodule applied twice is stored once.

    ``shared_net`` feeds its input through the same ``layer`` two times; the
    resulting parameter tree must hold a single Sequential entry wrapping a
    single zero-initialized Dense (kernel, bias) pair.
    """
    @parametrized
    def shared_net(inputs, layer=Sequential(Dense(2, zeros, zeros), relu)):
        hidden = layer(inputs)
        return layer(hidden)

    x = np.zeros((1, 2))
    dense_params = (np.zeros((2, 2)), np.zeros(2))
    expected_params = ((dense_params,),)

    params = shared_net.init_parameters(PRNGKey(0), x)
    assert_params_equal(expected_params, params)

    result = shared_net.apply(params, x)
    assert np.array_equal(np.zeros((1, 2)), result)
def test_submodule_without_inputs():
    """A parametrized function that takes no inputs can still own a parameter.

    ``scalar`` creates and invokes a single zero-valued scalar Parameter, so
    its parameter tree must contain exactly that one entry, and applying it
    (eagerly or under jit) must return the stored zero scalar.
    """
    @parametrized
    def scalar():
        # The Parameter must be *called* so it is exercised as a submodule;
        # returning the bare object would make apply() return the Parameter
        # itself instead of an array.
        # NOTE(review): assumes this jaxnet version passes the PRNG key to the
        # Parameter initializer (as for Dense's `zeros` initializers) — confirm.
        return Parameter(lambda rng: np.zeros(()))()

    params = scalar.init_parameters(PRNGKey(0))
    # One parameter is owned, so the tree is not empty (contrast test_no_params).
    assert_params_equal((np.zeros(()),), params)

    out = scalar.apply(params)
    assert np.array_equal(np.zeros(()), out)

    out_ = scalar.apply(params, jit=True)
    assert np.array_equal(out, out_)
def test_external_param_sharing():
    """Passing one Dense instance twice to Sequential stores its parameters once.

    Both eager and jit-compiled application must yield zeros, since the
    shared layer is zero-initialized.
    """
    layer = Dense(2, zeros, zeros)
    shared_net = Sequential(layer, layer)
    x = np.zeros((1, 2))

    params = shared_net.init_parameters(PRNGKey(0), x)
    assert_params_equal(((np.zeros((2, 2)), np.zeros(2)),), params)

    # First pass is eager (no extra kwargs), second passes jit=True.
    for apply_kwargs in ({}, {'jit': True}):
        result = shared_net.apply(params, x, **apply_kwargs)
        assert np.array_equal(np.zeros((1, 2)), result)
def test_no_params():
    """A parametrized function without submodules has an empty parameter tree.

    Eager and jit-compiled application must agree.
    """
    @parametrized
    def double(inputs):
        return 2 * inputs

    x = np.zeros((1, 3))
    params = double.init_parameters(PRNGKey(0), x)
    assert_params_equal((), params)

    eager = double.apply(params, x)
    assert np.array_equal(np.zeros((1, 3)), eager)

    compiled = double.apply(params, x, jit=True)
    assert np.array_equal(eager, compiled)
def test_internal_param_sharing():
    """A Dense default-argument submodule applied twice is stored once.

    The parameter tree holds a single (kernel, bias) entry, and eager and
    jit-compiled application must agree.
    """
    @parametrized
    def shared_net(inputs, layer=Dense(2, zeros, zeros)):
        return layer(layer(inputs))

    x = np.zeros((1, 2))
    expected_params = ((np.zeros((2, 2)), np.zeros(2),),)

    params = shared_net.init_parameters(PRNGKey(0), x)
    assert_params_equal(expected_params, params)

    eager = shared_net.apply(params, x)
    assert np.array_equal(np.zeros((1, 2)), eager)

    compiled = shared_net.apply(params, x, jit=True)
    assert np.array_equal(eager, compiled)
def test_Dense_shape(Dense=Dense):
    """Dense with zero initializers produces zero kernel/bias of the right shapes.

    Also checks that ``jax.jit`` over ``net.apply`` matches eager application
    and that initializing via ``net.shaped(inputs)`` gives the same parameters.
    """
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    x = np.zeros((1, 3))

    params = net.init_parameters(PRNGKey(0), x)
    assert_params_equal((np.zeros((3, 2)), np.zeros(2)), params)

    eager = net.apply(params, x)
    assert np.array_equal(np.zeros((1, 2)), eager)

    jitted_apply = jit(net.apply)
    assert np.array_equal(eager, jitted_apply(params, x))

    shaped_params = net.shaped(x).init_parameters(PRNGKey(0))
    assert_params_equal(params, shaped_params)
def test_params_from():
    """Parameters supplied for a submodule are reused by the enclosing network.

    ``parameters_from`` must embed the given layer parameters, and
    ``apply_from`` (eager and jit) must match ``apply`` on those parameters.
    """
    layer = Dense(2)
    net = Sequential(layer, relu)
    x = np.zeros((1, 3))
    layer_params = layer.init_parameters(PRNGKey(0), x)

    def reuse():
        # A fresh mapping per call, matching the original's inline literals.
        return {layer: layer_params}

    params_ = net.parameters_from(reuse(), x)
    assert_params_equal((layer_params,), params_)

    reference = net.apply(params_, x)
    assert np.array_equal(reference, net.apply_from(reuse(), x))
    assert np.array_equal(reference, net.apply_from(reuse(), x, jit=True))
def test_external_submodule2():
    """A Dense created outside a parametrized function is included in its params.

    The layer is referenced from the closure, not passed as an argument, yet
    its zero (kernel, bias) must appear in ``net``'s parameter tree.
    """
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    x = np.zeros((1, 2))
    params = net.init_parameters(PRNGKey(0), x)
    assert_params_equal(((np.zeros((2, 2)), np.zeros(2)),), params)

    eager = net.apply(params, x)
    assert np.array_equal(np.zeros((1, 2)), eager)

    compiled = net.apply(params, x, jit=True)
    assert np.array_equal(eager, compiled)
def test_scan_parametrized_cell_without_params():
    """``lax.scan`` over a parameter-free parametrized cell inside a parametrized rnn."""
    @parametrized
    def cell(carry, x):
        factor = np.array([2])
        return factor * carry * x, factor * carry * x

    @parametrized
    def rnn(inputs):
        # lax.scan returns (final_carry, stacked_outputs); keep the outputs.
        return lax.scan(cell, np.zeros((2,)), inputs)[1]

    xs = np.zeros((3,))
    params = rnn.init_parameters(PRNGKey(0), xs)
    # One empty entry, presumably for the parameter-free cell.
    assert_params_equal(((),), params)

    outs = rnn.apply(params, xs)
    # Three scanned steps, each emitting a length-2 output.
    assert outs.shape == (3, 2)
def test_params_from_shared_submodules():
    """Parameters given for one Sequential propagate to a sibling sharing a sublayer.

    ``a`` and ``b`` share ``sublayer``; supplying only ``a``'s parameters must
    fill the shared Dense kernel in both. Every combination of
    {plain / shaped net} x {plain / shaped key} x {eager / jit} for
    ``apply_from`` must then match ``apply``.
    """
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    x = np.zeros((1, 3))
    a_params = a.init_parameters(PRNGKey(0), x)
    # Result is discarded; kept so a.apply is still exercised.
    a.apply(a_params, x)

    params = net.parameters_from({a: a_params}, x)
    assert_params_equal(a_params.dense.kernel, params.sequential0.dense.kernel)
    assert_params_equal(a_params.dense.kernel, params.sequential1.dense.kernel)

    reference = net.apply(params, x)

    # Same order as the original: net variant, then key variant, then jit flag.
    for use_shaped_net in (False, True):
        for use_shaped_key in (False, True):
            for extra in ({}, {'jit': True}):
                if use_shaped_net:
                    target = net.shaped(x)
                    key = a.shaped(x) if use_shaped_key else a
                    result = target.apply_from({key: a_params}, **extra)
                else:
                    key = a.shaped(x) if use_shaped_key else a
                    result = net.apply_from({key: a_params}, x, **extra)
                assert np.array_equal(reference, result)