def net():
    """Mixes Parameter-object and parameter-function creation styles.

    NOTE: parameter registration order is decided by the first submodule
    call, so the sequence of calls below must not be reordered.
    """
    shared = Parameter(lambda key: np.zeros((1, )))
    a = shared()
    b = parameter((2, ), zeros)
    c = parameter((3, ), zeros)
    d = parameter((4, ), zeros)
    e = parameter((5, ), zeros)
    f = parameter((6, ), zeros)
    # must not mess up order (decided by first submodule call):
    k = shared()
    total = np.concatenate([a, f])
    total = total + np.concatenate([b, e])
    total = total + np.concatenate([c, d])
    return total + k
def conv_or_conv_transpose(inputs):
    """Convolution with data-dependent (weight-norm-style) initialization.

    Runs one example forward pass with unit scale / zero shift, then
    initializes `g` and `b` from the example output's statistics.
    """
    V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')
    example_out = apply(inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))
    # TODO remove need for `.aval.val` when capturing variables in initializer function:
    g = Parameter(
        lambda key: init_scale / np.sqrt(
            np.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
    b = Parameter(
        lambda key: np.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val,
        'b')()
    # NOTE(review): positional order (V, b, g) differs from the keyword call
    # above (V=..., g=..., b=...) — confirm `apply`'s signature really takes
    # b before g here.
    return apply(inputs, V, b, g)
def dense(inputs):
    """Dense layer with data-dependent initialization of scale and bias."""
    V = parameter((out_chan, inputs.shape[1]), randn(stddev=.05), inputs, 'V')

    # TODO apply = vmap(apply, (0, None, None, None))
    def example_output():
        return apply(inputs, V, g=np.ones(out_chan), b=np.zeros(out_chan))

    g = Parameter(
        lambda rng: init_scale / np.sqrt(np.var(example_output(), 0) + 1e-10),
        'g')(inputs)
    b = Parameter(lambda rng: np.mean(example_output(), 0) * g, 'b')(inputs)
    return apply(inputs, V, g, b)
def wrapper(dummy_inputs):
    """Creates six parameters of increasing size and combines them pairwise.

    The comprehension evaluates left to right, so parameter creation order
    matches the original explicit sequence a..f.
    """
    a, b, c, d, e, f = [
        parameter((size,), zeros, dummy_inputs) for size in range(1, 7)]
    return (np.concatenate([a, f])
            + np.concatenate([b, e])
            + np.concatenate([c, d]))
def net(input_dict):
    """Product of the 'a' and 'b' entries, scaled by a learnable scalar."""
    product = input_dict['a'] * input_dict['b']
    return product * parameter((), zeros)
def net(inputs):
    """Product of the first two entries of `inputs`, times a learnable scalar.

    NOTE(review): `type` here is presumably a closure-captured container
    class (e.g. the expected tuple/list type), not the builtin — confirm
    against the enclosing scope.
    """
    assert isinstance(inputs, type)
    first, second = inputs[0], inputs[1]
    return first * second * parameter((), zeros)
def dense(inputs):
    """Applies the shared linear map and adds a learnable bias."""
    bias = parameter((2, ), zeros, 'bias')
    return linear_map(inputs) + bias
def dense(inputs):
    """Affine layer: inputs @ kernel + bias.

    Bias is registered before the kernel — kept in that order since
    parameter registration order can be significant in this framework.
    """
    bias = parameter((2, ), zeros, 'bias')
    kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
    return np.dot(inputs, kernel) + bias
def dense(inputs):
    """Sum of two scalar parameters 'a' and 'b' (inputs only seed creation)."""
    first = parameter((), randn(), inputs, 'a')
    second = parameter((), randn(), inputs, 'b')
    return first + second
def linear_map(inputs):
    """Pure linear projection of `inputs` to 2 output features (no bias)."""
    kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
    return np.dot(inputs, kernel)
def dense(inputs):
    """Affine layer with configurable initializers: inputs @ kernel + bias."""
    kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
    bias = parameter((out_dim,), bias_init)
    return np.dot(inputs, kernel) + bias
def net(inputs):
    """Returns a scalar parameter whose initializer fills with 2s.

    `inputs` is accepted but unused — the network is a bare parameter.
    """
    def init(key, shape):
        return 2 * jnp.ones(shape)

    return parameter((), init)
def learnable_scale(params):
    """Scales `params` by twice a learnable scalar initialized to one."""
    scale = parameter((), ones, params)
    return 2 * scale * params
def net(inputs):
    """Returns a scalar parameter initialized to 2 (seeded by `inputs`)."""
    def init(rng, shape):
        return 2 * np.ones(shape)

    return parameter((), init, inputs)
def loss(inputs):
    """Sum of two scalar parameters: 'a' initialized to 1, 'b' to 2."""
    a = parameter((), ones, inputs, 'a')
    b = parameter((), lambda rng, shape: 2 * np.ones(shape), inputs, 'b')
    total = a + b
    return total
def unbatched_dense(input):
    """Affine map for a single (unbatched) input vector: kernel @ input + bias.

    Parameter name `input` shadows the builtin but is kept — it is part of
    the public signature for keyword callers.
    """
    kernel = parameter((out_dim, input.shape[-1]), ones)
    bias = parameter((out_dim, ), ones)
    return np.dot(kernel, input) + bias
def net():
    """Zero-argument network: a single scalar parameter initialized to 2."""
    def init(key, shape):
        return 2 * np.ones(shape)

    return parameter((), init)
def net(inputs):
    """Returns the inputs unchanged alongside a scaled copy (scale starts at 0)."""
    scaled = inputs * parameter((), zeros)
    return inputs, scaled
def cell(carry, x):
    """Scan cell with a dict carry: both carry['a'] and output get the same
    scaled value."""
    scale = parameter((2, ), zeros)
    value = scale * np.array([2]) * carry['a'] * x
    return {'a': value}, value
def dense():
    """Sum of two normally-initialized scalar parameters 'a' and 'b'."""
    first = parameter((), normal(), 'a')
    second = parameter((), normal(), 'b')
    return first + second
def net(input_dict):
    """Product of the first two entries, times a learnable scalar seeded by
    the first entry."""
    product = input_dict[0] * input_dict[1]
    return product * parameter((), zeros, input_dict[0])
def cell(carry, x):
    """Scan cell: the new carry and the output are the same scaled value."""
    scale = parameter((2, ), zeros)
    value = scale * np.array([2]) * carry * x
    return value, value
def loss(inputs):
    """Sum of two scalar parameters: 'a' initialized to 1, 'b' to 2."""
    a = parameter((), ones, 'a')

    def init_b(key, shape):
        return 2 * jnp.ones(shape)

    b = parameter((), init_b, 'b')
    return a + b