def test_Parameter(Parameter=Parameter):
    """A Parameter with a constant initializer yields that constant on init and apply."""
    zero_param = Parameter(lambda _: np.zeros(()))
    initialized = zero_param.init_parameters(key=PRNGKey(0))
    assert np.zeros(()) == initialized
    applied = zero_param.apply(initialized)
    assert initialized == applied
def test_Parameter_with_multiple_arrays(Parameter=Parameter):
    """A Parameter may initialize to a tuple of arrays; apply returns it unchanged."""
    pair_param = Parameter(lambda _: (np.zeros(()), np.zeros(())))
    initialized = pair_param.init_parameters(key=PRNGKey(0))
    first, second = initialized
    assert np.zeros(()) == first
    assert np.zeros(()) == second
    applied = pair_param.apply(initialized)
    assert initialized == applied
def conv_or_conv_transpose(inputs):
    """Weight-normalized convolution body: raw filter `V` plus data-dependent
    scale `g` and bias `b` derived from an example forward pass with unit
    scale and zero bias."""
    V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')
    example_out = apply(inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))
    # Statistics are taken over batch and spatial axes of the example output.
    stat_axes = (0, 1, 2)
    # TODO remove need for `.aval.val` when capturing variables in initializer function:
    g = Parameter(
        lambda key: init_scale / np.sqrt(
            np.var(example_out.aval.val, stat_axes) + 1e-10), 'g')()
    b = Parameter(
        lambda key: np.mean(example_out.aval.val, stat_axes) * g.aval.val, 'b')()
    return apply(inputs, V, b, g)
def dense(inputs):
    """Weight-normalized dense layer: scale `g` and bias `b` are initialized
    from the statistics of an example forward pass (unit scale, zero bias)."""
    V = parameter((out_chan, inputs.shape[1]), randn(stddev=.05), inputs, 'V')

    # TODO apply = vmap(apply, (0, None, None, None))
    def example_output():
        return apply(inputs, V, g=np.ones(out_chan), b=np.zeros(out_chan))

    g = Parameter(
        lambda rng: init_scale / np.sqrt(np.var(example_output(), 0) + 1e-10),
        'g')(inputs)
    b = Parameter(lambda rng: np.mean(example_output(), 0) * g, 'b')(inputs)
    return apply(inputs, V, g, b)
def conv_or_conv_transpose(inputs):
    """Weight-normalized convolution body: filter `V` drawn via an inline
    Parameter, scale `g` and bias `b` initialized from an example forward
    pass with unit scale and zero bias."""
    V = Parameter(
        lambda rng: randn(.05)(
            rng, tuple(filter_shape) + (inputs.shape[-1], out_chan)),
        'V')(inputs)

    # TODO apply = vmap(apply, (0, None, None, None))
    def example_output():
        return apply(inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))

    # Statistics over batch and spatial axes of the example output:
    stat_axes = (0, 1, 2)
    g = Parameter(
        lambda rng: init_scale / np.sqrt(
            np.var(example_output(), stat_axes) + 1e-10), 'g')(inputs)
    b = Parameter(
        lambda rng: np.mean(example_output(), stat_axes) * g, 'b')(inputs)
    return apply(inputs, V, b, g)
def test_deep_nested_inline_submodule():
    """Four levels of inline-nested parametrized modules still init and apply."""
    Net = lambda: parametrized(
        lambda inputs: Parameter(lambda key: np.zeros(()))(), name='net')
    Net2 = lambda: parametrized(lambda inputs: Net()(inputs), name='net2')
    Net3 = lambda: parametrized(lambda inputs: Net2()(inputs), name='net3')
    Net4 = lambda: parametrized(lambda inputs: Net3()(inputs), name='net4')

    deepest = Net4()
    params = deepest.init_parameters(np.zeros(()), key=PRNGKey(0))
    result = deepest.apply(params, np.zeros(()))
    assert 0 == result
def net():
    """Mixes one shared Parameter with several inline `parameter` calls.

    The call order below is significant: parameter ordering is decided by
    the first call of each submodule, so the second use of the shared
    Parameter must not disturb it.
    """
    shared = Parameter(lambda key: np.zeros((1, )))
    a = shared()
    b = parameter((2, ), zeros)
    c = parameter((3, ), zeros)
    d = parameter((4, ), zeros)
    e = parameter((5, ), zeros)
    f = parameter((6, ), zeros)
    # must not mess up order (decided by first submodule call):
    k = shared()
    outer = np.concatenate([a, f])
    middle = np.concatenate([b, e])
    inner = np.concatenate([c, d])
    return outer + middle + inner + k
def test_diamond_shared_submodules():
    """One Parameter reached through two Sequential parents is stored once."""
    shared = Parameter(lambda rng: np.ones(()))
    left = Sequential(shared)
    right = Sequential(shared)

    @parametrized
    def net(inputs):
        return left(inputs), right(inputs)

    params = net.init_parameters(PRNGKey(0), np.zeros(()))
    assert len(params) == 1
    assert np.array_equal(np.ones(()), params)

    a, b = net.apply(params, np.zeros(()))
    assert np.array_equal(np.ones(()), a)
    assert np.array_equal(np.ones(()), b)
def test_parameter_sharing_between_multiple_parents():
    """A Parameter used by both a submodule and its parent is stored only once."""
    shared = Parameter(lambda key: np.ones(()))

    @parametrized
    def wrapped():
        return shared()

    @parametrized
    def net():
        return wrapped(), shared()

    params = net.init_parameters(key=PRNGKey(0))
    assert len(params) == 1
    assert np.array_equal(np.ones(()), params.wrapped.parameter)

    first, second = net.apply(params)
    assert np.array_equal(np.ones(()), first)
    assert np.array_equal(np.ones(()), second)
def dense(inputs):
    """Dense layer built from two inline Parameter modules (kernel and bias)."""
    weights = Parameter(
        lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
    offset = Parameter(lambda key: bias_init(key, (out_dim,)))()
    return np.dot(inputs, weights) + offset
def wrapper():
    """Return a pair of scalar zeros produced by a single inline Parameter."""
    zero_pair = Parameter(lambda _: (np.zeros(()), np.zeros(())))
    return zero_pair()
def scalar():
    """Return a scalar zero produced by an inline Parameter module."""
    zero = Parameter(lambda key: np.zeros(()))
    return zero()
def dense(inputs):
    """Dense layer; each Parameter module receives `inputs` as its example argument."""
    weights = Parameter(
        lambda rng: kernel_init(rng, (inputs.shape[-1], out_dim)))(inputs)
    offset = Parameter(lambda rng: bias_init(rng, (out_dim, )))(inputs)
    return np.dot(inputs, weights) + offset
def wrapper(dummy_inputs):
    """Return a pair of scalar zeros; `dummy_inputs` serves only as the
    Parameter module's example argument."""
    zero_pair = Parameter(lambda _: (np.zeros(()), np.zeros(())))
    return zero_pair(dummy_inputs)