示例#1
0
def test_Parameter(Parameter=Parameter):
    """Check the init/apply round trip of a Parameter holding one scalar.

    The init function ignores its argument and returns a zero-dimensional
    array of zeros. The test asserts that ``init_parameters`` produces a
    value equal to that array, and that ``apply`` called with those
    parameters returns a value equal to the parameters themselves.
    """
    init_fn = lambda _: np.zeros(())
    param = Parameter(init_fn)
    initial = param.init_parameters(key=PRNGKey(0))

    expected = np.zeros(())
    assert expected == initial

    result = param.apply(initial)
    assert initial == result
示例#2
0
def test_Parameter_with_multiple_arrays(Parameter=Parameter):
    """Check the init/apply round trip of a Parameter holding two scalars.

    The init function ignores its argument and returns a pair of
    zero-dimensional zero arrays. The test asserts that ``init_parameters``
    yields exactly two values, each equal to a scalar zero, and that
    ``apply`` called with those parameters returns a value equal to them.
    """
    init_fn = lambda _: (np.zeros(()), np.zeros(()))
    pair = Parameter(init_fn)
    initial = pair.init_parameters(key=PRNGKey(0))

    # Unpacking doubles as an arity check: it fails unless exactly two
    # values come back.
    first, second = initial
    for component in (first, second):
        assert np.zeros(()) == component

    result = pair.apply(initial)
    assert initial == result