示例#1
0
def test_Parameter(Parameter=Parameter):
    """Check init and apply of a ``Parameter`` whose initializer returns a scalar zero.

    The asserts require ``init_parameters`` to yield a value equal to
    ``np.zeros(())`` and ``apply`` to return a value equal to those params.
    """
    zero_param = Parameter(lambda _: np.zeros(()))
    initial = zero_param.init_parameters(key=PRNGKey(0))

    assert np.zeros(()) == initial

    # Applying with the freshly initialized params must give them back.
    applied = zero_param.apply(initial)
    assert initial == applied
示例#2
0
def test_Parameter_with_multiple_arrays(Parameter=Parameter):
    """Check a ``Parameter`` whose initializer returns a pair of scalar zeros.

    Both elements of the initialized params must equal ``np.zeros(())`` and
    ``apply`` must return a value equal to the params it was given.
    """
    pair_param = Parameter(lambda _: (np.zeros(()), np.zeros(())))
    initial = pair_param.init_parameters(key=PRNGKey(0))

    first, second = initial
    assert np.zeros(()) == first
    assert np.zeros(()) == second

    applied = pair_param.apply(initial)
    assert initial == applied
示例#3
0
    def conv_or_conv_transpose(inputs):
        """Run the enclosing ``apply`` with ``V`` and data-dependent ``g`` and ``b``.

        ``V`` is created with ``normal(.05)``. ``g`` and ``b`` are initialized
        from the variance and mean (reduced over axes 0, 1, 2) of an example
        output computed with unit ``g`` and zero ``b``.

        Args:
            inputs: input array; its last dimension sizes ``V``.

        Returns:
            The result of ``apply(inputs, V=V, g=g, b=b)``.
        """
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        # Reference output with unit g and zero b; its statistics seed g and b.
        example_out = apply(inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                                  np.sqrt(np.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: np.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        # Bind by keyword, exactly as the example-output call above does, so
        # that g and b cannot be swapped by a positional-order mismatch.
        return apply(inputs, V=V, g=g, b=b)
示例#4
0
    def dense(inputs):
        """Run the enclosing ``apply`` with ``V`` and data-dependent ``g`` and ``b``.

        ``g`` and ``b`` are initialized from the variance and mean (over
        axis 0) of an example output computed with unit ``g`` and zero ``b``.
        """
        weight_shape = (out_chan, inputs.shape[1])
        V = parameter(weight_shape, randn(stddev=.05), inputs, 'V')

        # TODO apply = vmap(apply, (0, None, None, None))
        def example_output():
            # Evaluated lazily, only when an initializer below is invoked.
            return apply(inputs, V, g=np.ones(out_chan), b=np.zeros(out_chan))

        init_g = lambda rng: init_scale / np.sqrt(
            np.var(example_output(), 0) + 1e-10)
        g = Parameter(init_g, 'g')(inputs)

        # The initializer for b reads g, so it has to be created after g.
        init_b = lambda rng: np.mean(example_output(), 0) * g
        b = Parameter(init_b, 'b')(inputs)

        return apply(inputs, V, g, b)
示例#5
0
    def conv_or_conv_transpose(inputs):
        """Run the enclosing ``apply`` with ``V`` and data-dependent ``g`` and ``b``.

        ``V`` is drawn by ``randn(.05)``. ``g`` and ``b`` are initialized from
        the variance and mean (reduced over axes 0, 1, 2) of an example
        output computed with unit ``g`` and zero ``b``.

        Args:
            inputs: input array; its last dimension sizes ``V``.

        Returns:
            The result of ``apply(inputs, V=V, g=g, b=b)``.
        """
        V = Parameter(
            lambda rng: randn(.05)(rng, tuple(filter_shape) +
                                   (inputs.shape[-1], out_chan)), 'V')(inputs)

        # TODO apply = vmap(apply, (0, None, None, None))
        # Lazy reference output with unit g and zero b; only evaluated when an
        # initializer below is invoked.
        example_output = lambda: apply(
            inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))

        g = Parameter(
            lambda rng: init_scale / np.sqrt(
                np.var(example_output(), (0, 1, 2)) + 1e-10), 'g')(inputs)
        b = Parameter(lambda rng: np.mean(example_output(), (0, 1, 2)) * g,
                      'b')(inputs)

        # Bind by keyword, exactly as the example-output call above does, so
        # that g and b cannot be swapped by a positional-order mismatch.
        return apply(inputs, V=V, g=g, b=b)
示例#6
0
def test_deep_nested_inline_submodule():
    """Check that a zero-scalar parameter nested four modules deep is returned by apply.

    Each factory builds a fresh ``parametrized`` module on every call; every
    level creates the next one inline and calls it on the inputs.
    """

    def Net():
        return parametrized(
            lambda inputs: Parameter(lambda key: np.zeros(()))(), name='net')

    def Net2():
        return parametrized(lambda inputs: Net()(inputs), name='net2')

    def Net3():
        return parametrized(lambda inputs: Net2()(inputs), name='net3')

    def Net4():
        return parametrized(lambda inputs: Net3()(inputs), name='net4')

    outermost = Net4()
    initial = outermost.init_parameters(np.zeros(()), key=PRNGKey(0))
    result = outermost.apply(initial, np.zeros(()))
    assert 0 == result
示例#7
0
    def net():
        """Sum three concatenations of zero parameters plus a reused parameter.

        One ``Parameter`` instance is called twice, before and after five
        ``parameter(...)`` calls with shapes (2,) through (6,).
        """
        shared = Parameter(lambda key: np.zeros((1, )))
        first_use = shared()
        size2 = parameter((2, ), zeros)
        size3 = parameter((3, ), zeros)
        size4 = parameter((4, ), zeros)
        size5 = parameter((5, ), zeros)
        size6 = parameter((6, ), zeros)

        # must not mess up order (decided by first submodule call):
        second_use = shared()

        # Shapes pair up as 1+6, 2+5 and 3+4.
        outer = np.concatenate([first_use, size6])
        middle = np.concatenate([size2, size5])
        inner = np.concatenate([size3, size4])
        return outer + middle + inner + second_use
示例#8
0
def test_diamond_shared_submodules():
    """Check one ``Parameter`` reached through two ``Sequential`` wrappers.

    The asserts require the params to have length 1 and equal a scalar one,
    and both outputs of ``apply`` to equal a scalar one.
    """
    shared = Parameter(lambda rng: np.ones(()))
    left = Sequential(shared)
    right = Sequential(shared)

    @parametrized
    def net(inputs):
        return left(inputs), right(inputs)

    initial = net.init_parameters(PRNGKey(0), np.zeros(()))
    assert 1 == len(initial)
    assert np.array_equal(np.ones(()), initial)

    out_left, out_right = net.apply(initial, np.zeros(()))
    assert np.array_equal(np.ones(()), out_left)
    assert np.array_equal(np.ones(()), out_right)
0
def test_parameter_sharing_between_multiple_parents():
    """Check one ``Parameter`` used both directly and through a wrapper module.

    The asserts require the params to have length 1, with the value stored
    under ``params.wrapped.parameter``, and both outputs to equal a scalar one.
    """
    shared = Parameter(lambda key: np.ones(()))

    @parametrized
    def wrapped():
        return shared()

    @parametrized
    def net():
        return wrapped(), shared()

    initial = net.init_parameters(key=PRNGKey(0))
    assert 1 == len(initial)
    assert np.array_equal(np.ones(()), initial.wrapped.parameter)

    via_wrapper, direct = net.apply(initial)
    assert np.array_equal(np.ones(()), via_wrapper)
    assert np.array_equal(np.ones(()), direct)
示例#10
0
 def dense(inputs):
     """Return ``np.dot(inputs, kernel) + bias`` with both operands created as parameters."""
     kernel_param = Parameter(
         lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))
     kernel = kernel_param()

     bias_param = Parameter(lambda key: bias_init(key, (out_dim,)))
     bias = bias_param()

     projected = np.dot(inputs, kernel)
     return projected + bias
示例#11
0
 def wrapper():
     """Return the value of a ``Parameter`` whose initializer yields two scalar zeros."""
     zero_pair = Parameter(lambda _: (np.zeros(()), np.zeros(())))
     return zero_pair()
示例#12
0
 def scalar():
     """Return the value of a ``Parameter`` whose initializer yields a scalar zero."""
     zero = Parameter(lambda key: np.zeros(()))
     return zero()
示例#13
0
 def dense(inputs):
     """Return ``np.dot(inputs, kernel) + bias``; both parameters are called with ``inputs``."""
     kernel_param = Parameter(
         lambda rng: kernel_init(rng, (inputs.shape[-1], out_dim)))
     kernel = kernel_param(inputs)

     bias_param = Parameter(lambda rng: bias_init(rng, (out_dim, )))
     bias = bias_param(inputs)

     projected = np.dot(inputs, kernel)
     return projected + bias
示例#14
0
 def wrapper(dummy_inputs):
     """Return the value of a two-scalar-zero ``Parameter`` called with ``dummy_inputs``."""
     zero_pair = Parameter(lambda _: (np.zeros(()), np.zeros(())))
     return zero_pair(dummy_inputs)