Ejemplo n.º 1
0
def test_internal_param_sharing2():
    """A submodule given as a default argument is shared across repeated calls.

    `layer` is applied twice inside `shared_net`, yet only a single
    (kernel, bias) pair nested under the one Sequential is expected in the
    parameters. Both the eager and the jitted apply paths must produce the
    same all-zero output.
    """
    @parametrized
    def shared_net(inputs, layer=Sequential(Dense(2, zeros, zeros), relu)):
        inputs = layer(inputs)
        return layer(inputs)

    inputs = np.zeros((1, 2))
    params = shared_net.init_parameters(PRNGKey(0), inputs)

    # One Sequential -> one Dense -> one (kernel, bias) pair, despite two calls.
    assert_params_equal((((np.zeros((2, 2)), np.zeros(2)),),), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    # Consistent with the sibling tests: the jitted path must agree with eager.
    out_ = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(out, out_)
Ejemplo n.º 2
0
def test_submodule_without_inputs():
    """A parametrized function that takes no inputs can be initialized and applied.

    No parameters are expected, and the eager and jitted results must agree.
    """
    @parametrized
    def scalar():
        return Parameter(lambda: np.zeros(()))

    key = PRNGKey(0)
    scalar_params = scalar.init_parameters(key)
    assert_params_equal((), scalar_params)

    eager_result = scalar.apply(scalar_params)
    expected_result = np.zeros(())
    assert np.array_equal(expected_result, eager_result)

    jitted_result = scalar.apply(scalar_params, jit=True)
    assert np.array_equal(eager_result, jitted_result)
Ejemplo n.º 3
0
def test_external_param_sharing():
    """Reusing one layer instance inside Sequential shares its parameters.

    The same Dense object appears twice, so a single (kernel, bias) pair is
    expected. Eager and jitted apply must both yield zeros.
    """
    shared_layer = Dense(2, zeros, zeros)
    shared_net = Sequential(shared_layer, shared_layer)

    x = np.zeros((1, 2))
    net_params = shared_net.init_parameters(PRNGKey(0), x)
    expected_params = ((np.zeros((2, 2)), np.zeros(2)),)
    assert_params_equal(expected_params, net_params)

    eager_out = shared_net.apply(net_params, x)
    assert np.array_equal(np.zeros((1, 2)), eager_out)

    jitted_out = shared_net.apply(net_params, x, jit=True)
    assert np.array_equal(np.zeros((1, 2)), jitted_out)
Ejemplo n.º 4
0
def test_no_params():
    """A parametrized function with no submodules has an empty parameter tuple.

    Doubling a zero input must give zeros, both eagerly and under jit.
    """
    @parametrized
    def double(inputs):
        return 2 * inputs

    x = np.zeros((1, 3))
    empty_params = double.init_parameters(PRNGKey(0), x)
    assert_params_equal((), empty_params)

    eager = double.apply(empty_params, x)
    assert np.array_equal(np.zeros((1, 3)), eager)

    jitted = double.apply(empty_params, x, jit=True)
    assert np.array_equal(eager, jitted)
Ejemplo n.º 5
0
def test_internal_param_sharing():
    """A Dense default argument applied twice contributes one parameter set.

    The nested call `layer(layer(inputs))` must not duplicate parameters, and
    eager and jitted outputs must agree.
    """
    @parametrized
    def shared_net(inputs, layer=Dense(2, zeros, zeros)):
        return layer(layer(inputs))

    x = np.zeros((1, 2))
    net_params = shared_net.init_parameters(PRNGKey(0), x)
    kernel_and_bias = (np.zeros((2, 2)), np.zeros(2))
    assert_params_equal((kernel_and_bias,), net_params)

    eager = shared_net.apply(net_params, x)
    assert np.array_equal(np.zeros((1, 2)), eager)

    jitted = shared_net.apply(net_params, x, jit=True)
    assert np.array_equal(eager, jitted)
Ejemplo n.º 6
0
def test_Dense_shape(Dense=Dense):
    """Dense with zero initializers yields zero parameters of the expected shapes.

    Also checks that a jit-wrapped `apply` matches the eager result, and that
    initializing from a shape-bound module (`shaped(inputs)`) reproduces the
    input-driven initialization.
    """
    layer = Dense(2, kernel_init=zeros, bias_init=zeros)
    x = np.zeros((1, 3))

    layer_params = layer.init_parameters(PRNGKey(0), x)
    expected_params = (np.zeros((3, 2)), np.zeros(2))
    assert_params_equal(expected_params, layer_params)

    eager = layer.apply(layer_params, x)
    assert np.array_equal(np.zeros((1, 2)), eager)

    jitted_apply = jit(layer.apply)
    jitted = jitted_apply(layer_params, x)
    assert np.array_equal(eager, jitted)

    shaped_params = layer.shaped(x).init_parameters(PRNGKey(0))
    assert_params_equal(layer_params, shaped_params)
Ejemplo n.º 7
0
def test_params_from():
    """Parameters supplied for a submodule fully determine the enclosing net.

    `parameters_from` must wrap the given layer parameters. `apply_from`
    (eager and jitted) must match applying those derived parameters directly.
    """
    layer = Dense(2)
    net = Sequential(layer, relu)
    x = np.zeros((1, 3))
    layer_params = layer.init_parameters(PRNGKey(0), x)
    known = {layer: layer_params}

    derived_params = net.parameters_from(known, x)
    assert_params_equal((layer_params,), derived_params)

    reference = net.apply(derived_params, x)

    assert np.array_equal(reference, net.apply_from(known, x))
    assert np.array_equal(reference, net.apply_from(known, x, jit=True))
Ejemplo n.º 8
0
def test_external_submodule2():
    """A module captured by closure is registered as a submodule of `net`.

    The zero-initialized Dense must appear as the single nested parameter
    pair. Eager and jitted outputs must be zeros and agree.
    """
    dense = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return dense(inputs)

    x = np.zeros((1, 2))

    net_params = net.init_parameters(PRNGKey(0), x)
    expected_params = ((np.zeros((2, 2)), np.zeros(2)),)
    assert_params_equal(expected_params, net_params)

    eager = net.apply(net_params, x)
    assert np.array_equal(np.zeros((1, 2)), eager)

    jitted = net.apply(net_params, x, jit=True)
    assert np.array_equal(eager, jitted)
Ejemplo n.º 9
0
def test_scan_parametrized_cell_without_params():
    """A parameter-free parametrized cell can be driven by `lax.scan`.

    The RNN's parameters must contain one empty entry (for the cell). Scanning
    over 3 steps with a length-2 carry must stack outputs of shape (3, 2).
    """
    @parametrized
    def cell(carry, x):
        return np.array([2]) * carry * x, np.array([2]) * carry * x

    @parametrized
    def rnn(inputs):
        _, outs = lax.scan(cell, np.zeros((2,)), inputs)
        return outs

    sequence = np.zeros((3,))

    rnn_params = rnn.init_parameters(PRNGKey(0), sequence)
    assert_params_equal(((),), rnn_params)

    stacked = rnn.apply(rnn_params, sequence)
    assert stacked.shape == (3, 2)
Ejemplo n.º 10
0
def test_params_from_shared_submodules():
    """Parameters given for one module propagate to modules sharing its layer.

    `a` and `b` both wrap the same `sublayer`. Supplying parameters for `a`
    alone must set the Dense kernel of both Sequential submodules of `net`.
    Every `apply_from` variant must reproduce the output of applying the
    derived parameters: plain or `shaped` module key, plain or `shaped` net,
    eager or jitted.
    """
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(PRNGKey(0), inputs)

    params = net.parameters_from({a: a_params}, inputs)
    # Both Sequential submodules must reuse the kernel initialized for `a`.
    assert_params_equal(a_params.dense.kernel, params.sequential0.dense.kernel)
    assert_params_equal(a_params.dense.kernel, params.sequential1.dense.kernel)

    # Reference output that every apply_from variant below must reproduce.
    out = net.apply(params, inputs)

    # Unshaped net, module keyed by `a` itself.
    out_ = net.apply_from({a: a_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a: a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    # Unshaped net, module keyed by a shape-bound copy of `a`.
    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    # Shape-bound net: inputs are baked in, so apply_from takes only the map.
    out_ = net.shaped(inputs).apply_from({a: a_params})
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a: a_params}, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params})
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params}, jit=True)
    assert np.array_equal(out, out_)