def test_Parameter(Parameter=Parameter):
    """Check a Parameter whose init function yields a single 0-d zero array.

    The initialized parameters must equal ``np.zeros(())``. Applying the
    Parameter to those parameters must give back a value equal to them.
    """
    # The init function ignores its single argument and always returns a scalar zero.
    scalar_param = Parameter(lambda _: np.zeros(()))
    initial = scalar_param.init_parameters(key=PRNGKey(0))

    assert np.zeros(()) == initial
    assert initial == scalar_param.apply(initial)
def test_Parameter_with_multiple_arrays(Parameter=Parameter):
    """Check a Parameter whose init function yields a pair of 0-d zero arrays.

    The initialized parameters must unpack into exactly two components, each
    equal to ``np.zeros(())``. Applying the Parameter to them must give back a
    value equal to the initialized parameters.
    """
    # The init function ignores its argument and returns a 2-tuple of scalar zeros.
    pair_param = Parameter(lambda _: (np.zeros(()), np.zeros(())))
    initial = pair_param.init_parameters(key=PRNGKey(0))

    # Unpacking also enforces that there are exactly two components.
    first, second = initial
    for component in (first, second):
        assert np.zeros(()) == component

    assert initial == pair_param.apply(initial)