Example #1
0
def test_parametrized_jit(jitted_fun):
    """Wrap `jitted_fun` with `parametrized` and check init and apply.

    Parameters are initialized twice with the same key and must be equal.
    The module is then applied without `jit`, with `jit=True`, and through
    `apply_from` with `jit=True`; all outputs must agree.
    """
    net = parametrized(jitted_fun)
    sample = random_inputs((2, ))
    parameters = net.init_parameters(sample, key=PRNGKey(0))

    assert type(parameters).__name__ == 'fun'
    assert parameters.dense.bias.shape == (3, )

    # Same key, same inputs: the second initialization must match the first.
    parameters_again = net.init_parameters(sample, key=PRNGKey(0))
    assert_parameters_equal(parameters, parameters_again)

    plain_out = net.apply(parameters, sample)
    assert (3, ) == plain_out.shape
    assert np.allclose([0.84194356, -1.5927866, -1.7411114], plain_out)

    # run twice to cover cached jit call
    plain_out_again = net.apply(parameters, sample)
    assert np.allclose(plain_out, plain_out_again)

    jit_out = net.apply(parameters, sample, jit=True)
    assert np.allclose(jit_out, plain_out_again)

    # Second call with jit=True, compared against the first one.
    jit_out_again = net.apply(parameters, sample, jit=True)
    assert np.allclose(jit_out, jit_out_again)

    from_out = net.apply_from({net: parameters}, sample, jit=True)
    assert np.allclose(jit_out, from_out)
Example #2
0
def test_Batched():
    """Check that `Batched` agrees with `vmap` over the unbatched module.

    An unbatched dense module with all-ones kernel and bias is applied to a
    single input, then mapped over a batch with `vmap`, and finally wrapped
    in `Batched`. The wrapped module must reuse the unbatched parameters and
    produce the same batched output as the `vmap` version.
    """
    out_dim = 1

    @parametrized
    def unbatched_dense(input):
        kernel = parameter((out_dim, input.shape[-1]), ones)
        bias = parameter((out_dim, ), ones)
        return np.dot(kernel, input) + bias

    batch_size = 4

    unbatched_params = unbatched_dense.init_parameters(np.zeros(2),
                                                       key=PRNGKey(0))
    out = unbatched_dense.apply(unbatched_params, np.ones(2))
    # Use array_equal rather than `==`: it yields a single bool and also
    # compares shapes, so a broadcast or multi-element result cannot slip by.
    assert np.array_equal(np.array([3.]), out)

    # Reference result: map the unbatched apply over the leading input axis
    # while sharing the parameters (in_axes None for params, 0 for inputs).
    dense_apply = vmap(unbatched_dense.apply, (None, 0))
    out_batched_ = dense_apply(unbatched_params, np.ones((batch_size, 2)))
    assert np.array_equal(np.stack([out] * batch_size), out_batched_)

    dense = Batched(unbatched_dense)
    params = dense.init_parameters(np.ones((batch_size, 2)), key=PRNGKey(0))
    assert_parameters_equal((unbatched_params, ), params)
    out_batched = dense.apply(params, np.ones((batch_size, 2)))
    assert np.array_equal(out_batched_, out_batched)
Example #3
0
def test_internal_param_sharing2():
    """Apply one `Sequential` default-argument layer twice inside a module.

    The initialized parameters must contain a single zero kernel/bias pair,
    and applying the module to zero inputs must return zeros.
    """
    @parametrized
    def shared_net(inputs, layer=Sequential(Dense(2, zeros, zeros), relu)):
        hidden = layer(inputs)
        return layer(hidden)

    zero_inputs = np.zeros((1, 2))
    parameters = shared_net.init_parameters(zero_inputs, key=PRNGKey(0))

    expected_parameters = (((np.zeros((2, 2)), np.zeros(2)), ), )
    assert_parameters_equal(expected_parameters, parameters)

    result = shared_net.apply(parameters, zero_inputs)
    assert np.array_equal(np.zeros((1, 2)), result)
Example #4
0
def test_submodule_without_inputs():
    """Initialize and apply a module that takes no inputs.

    The module returns the value of a single zero scalar `Parameter`; the
    result must be the same with and without `jit=True`.
    """
    @parametrized
    def scalar():
        zero_parameter = Parameter(lambda key: np.zeros(()))
        return zero_parameter()

    parameters = scalar.init_parameters(key=PRNGKey(0))
    assert_parameters_equal((np.zeros(()), ), parameters)

    plain_out = scalar.apply(parameters)
    assert np.zeros(()) == plain_out

    jit_out = scalar.apply(parameters, jit=True)
    assert plain_out == jit_out
Example #5
0
def test_external_param_sharing():
    """Use the same `Dense` instance twice in a `Sequential`.

    Only one zero kernel/bias pair is expected in the parameters, and the
    output on zero inputs must be zeros both without and with `jit=True`.
    """
    shared_layer = Dense(2, zeros, zeros)
    shared_net = Sequential(shared_layer, shared_layer)

    zero_inputs = np.zeros((1, 2))
    parameters = shared_net.init_parameters(zero_inputs, key=PRNGKey(0))
    expected_parameters = ((np.zeros((2, 2)), np.zeros(2)), )
    assert_parameters_equal(expected_parameters, parameters)

    plain_out = shared_net.apply(parameters, zero_inputs)
    assert np.array_equal(np.zeros((1, 2)), plain_out)

    jit_out = shared_net.apply(parameters, zero_inputs, jit=True)
    assert np.array_equal(np.zeros((1, 2)), jit_out)
Example #6
0
def test_submodule_without_inputs():
    """Initialize and apply an input-free module built from a `Parameter`.

    The parameters are expected to be empty, the output a zero scalar, and
    the output with `jit=True` equal to the one without.
    """
    @parametrized
    def scalar():
        return Parameter(lambda: np.zeros(()))

    parameters = scalar.init_parameters(PRNGKey(0))
    assert_parameters_equal((), parameters)

    plain_out = scalar.apply(parameters)
    assert np.array_equal(np.zeros(()), plain_out)

    jit_out = scalar.apply(parameters, jit=True)
    assert np.array_equal(plain_out, jit_out)
Example #7
0
def test_no_params():
    """Check a parameter-free module that doubles its inputs.

    Initialization must yield empty parameters; the output on zero inputs
    must be zeros and must not change when `jit=True` is passed.
    """
    @parametrized
    def double(inputs):
        return 2 * inputs

    zero_inputs = np.zeros((1, 3))
    parameters = double.init_parameters(zero_inputs, key=PRNGKey(0))
    assert_parameters_equal((), parameters)

    plain_out = double.apply(parameters, zero_inputs)
    assert np.array_equal(np.zeros((1, 3)), plain_out)

    jit_out = double.apply(parameters, zero_inputs, jit=True)
    assert np.array_equal(plain_out, jit_out)
Example #8
0
def test_internal_param_sharing():
    """Apply one default-argument `Dense` layer twice inside a module.

    A single zero kernel/bias pair is expected in the parameters; the output
    on zero inputs must be zeros and identical with `jit=True`.
    """
    @parametrized
    def shared_net(inputs, layer=Dense(2, zeros, zeros)):
        return layer(layer(inputs))

    zero_inputs = np.zeros((1, 2))
    parameters = shared_net.init_parameters(PRNGKey(0), zero_inputs)
    expected_parameters = ((np.zeros((2, 2)), np.zeros(2), ), )
    assert_parameters_equal(expected_parameters, parameters)

    plain_out = shared_net.apply(parameters, zero_inputs)
    assert np.array_equal(np.zeros((1, 2)), plain_out)

    jit_out = shared_net.apply(parameters, zero_inputs, jit=True)
    assert np.array_equal(plain_out, jit_out)
Example #9
0
def test_Dense_shape(Dense=Dense):
    """Check parameter and output shapes of a zero-initialized `Dense`.

    For (1, 3) inputs and 2 output units, the kernel is expected to be
    (3, 2) and the bias (2,). The output must match under `jit`, and
    initializing through `shaped(inputs)` must give the same parameters.
    """
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    zero_inputs = np.zeros((1, 3))

    parameters = net.init_parameters(PRNGKey(0), zero_inputs)
    expected_parameters = (np.zeros((3, 2)), np.zeros(2))
    assert_parameters_equal(expected_parameters, parameters)

    plain_out = net.apply(parameters, zero_inputs)
    assert np.array_equal(np.zeros((1, 2)), plain_out)

    jitted_apply = jit(net.apply)
    jit_out = jitted_apply(parameters, zero_inputs)
    assert np.array_equal(plain_out, jit_out)

    shaped_net = net.shaped(zero_inputs)
    shaped_parameters = shaped_net.init_parameters(PRNGKey(0))
    assert_parameters_equal(parameters, shaped_parameters)
Example #10
0
def test_parameters_from():
    """Build a network's parameters from those of one of its layers.

    `parameters_from` must wrap the layer parameters in a one-element tuple,
    and `apply_from` (with and without `jit=True`) must give the same output
    as `apply` with those parameters.
    """
    layer = Dense(2)
    net = Sequential(layer, relu)
    zero_inputs = np.zeros((1, 3))
    layer_params = layer.init_parameters(zero_inputs, key=PRNGKey(0))

    net_params = net.parameters_from({layer: layer_params}, zero_inputs)
    assert_parameters_equal((layer_params, ), net_params)

    expected = net.apply(net_params, zero_inputs)

    from_out = net.apply_from({layer: layer_params}, zero_inputs)
    assert np.array_equal(expected, from_out)

    from_out_jit = net.apply_from({layer: layer_params}, zero_inputs,
                                  jit=True)
    assert np.array_equal(expected, from_out_jit)
Example #11
0
def test_external_submodule2():
    """Call a `Dense` layer defined outside the `parametrized` function.

    The parameters must hold one zero kernel/bias pair; the output on zero
    inputs must be zeros and identical with `jit=True`.
    """
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    zero_inputs = np.zeros((1, 2))

    parameters = net.init_parameters(zero_inputs, key=PRNGKey(0))
    expected_parameters = ((np.zeros((2, 2)), np.zeros(2)), )
    assert_parameters_equal(expected_parameters, parameters)

    plain_out = net.apply(parameters, zero_inputs)
    assert np.array_equal(np.zeros((1, 2)), plain_out)

    jit_out = net.apply(parameters, zero_inputs, jit=True)
    assert np.array_equal(plain_out, jit_out)
Example #12
0
def test_parameters_from_sharing_between_multiple_parents():
    """Reuse one layer's parameters when it appears under two parents.

    Layer `a` is called directly and through `Sequential` `b`. After
    `parameters_from`, the `dense` entry must match the parameters of `a`,
    the `sequential` entry must be empty, and the first output of the
    network must equal the output of `a` on its own.
    """
    a = Dense(2)
    b = Sequential(a, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    zero_inputs = np.zeros((1, 3))
    a_params = a.init_parameters(zero_inputs, key=PRNGKey(0))
    expected = a.apply(a_params, zero_inputs)

    net_params = net.parameters_from({a: a_params}, zero_inputs)
    assert_dense_parameters_equal(a_params, net_params.dense)
    assert_parameters_equal((), net_params.sequential)
    assert len(net_params) == 2

    # Only the first element of the output pair is compared.
    first_out, _ = net.apply(net_params, zero_inputs)
    assert np.array_equal(expected, first_out)
Example #13
0
def test_scan_parametrized_cell_without_params():
    """Run `lax.scan` over a `parametrized` cell that has no parameters.

    The enclosing module's parameters must be a one-element tuple holding an
    empty tuple, and scanning 3 inputs with a carry of shape (2,) must give
    outputs of shape (3, 2).
    """
    @parametrized
    def cell(carry, x):
        return np.array([2]) * carry * x, np.array([2]) * carry * x

    @parametrized
    def rnn(inputs):
        initial_carry = np.zeros((2, ))
        _, outs = lax.scan(cell, initial_carry, inputs)
        return outs

    zero_inputs = np.zeros((3, ))

    parameters = rnn.init_parameters(zero_inputs, key=PRNGKey(0))
    assert_parameters_equal(((), ), parameters)

    scanned = rnn.apply(parameters, zero_inputs)

    assert scanned.shape == (3, 2)
Example #14
0
def test_parameters_from_shared_submodules():
    """Reuse parameters of a submodule that shares a layer with a sibling.

    `a` and `b` are two `Sequential` modules built on the same `Dense`
    sublayer. After `parameters_from`, the kernel under `sequential0` must
    match the kernel of `a`, and `sequential1` must be empty. `apply_from`
    is then checked against `apply` for every combination of: plain or
    `shaped` network, plain or `shaped` key module, with or without
    `jit=True`.
    """
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    # The result of this call is not compared against anything below.
    a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_parameters_equal(a_params.dense.kernel,
                            params.sequential0.dense.kernel)
    assert_parameters_equal((), params.sequential1)
    expected = net.apply(params, inputs)

    # Loop nesting reproduces the call order: network variant outermost,
    # then key variant, then without/with jit.
    for net_is_shaped in (False, True):
        for key_is_shaped in (False, True):
            for use_jit in (False, True):
                module = net.shaped(inputs) if net_is_shaped else net
                source = a.shaped(inputs) if key_is_shaped else a
                # A shaped network takes no inputs at apply time.
                args = () if net_is_shaped else (inputs, )
                kwargs = {'jit': True} if use_jit else {}
                actual = module.apply_from({source: a_params},
                                           *args, **kwargs)
                assert np.array_equal(expected, actual)