Exemplo n.º 1
0
def test_submodule_reuse():
    """Share one ``Dense`` layer between two ``Sequential`` nets.

    The shared layer is initialised on its own and injected into both nets
    via ``reuse``. Each net must still produce its own output width, and both
    must hold exactly the injected parameters for the shared layer.
    Afterwards, ``parameters_from`` is given both ``net1``'s full parameters
    and a fresh set for the shared layer. The result must take the shared
    layer's parameters from the fresh set and the head's from ``net1``.
    """
    inputs = np.zeros((1, 2))

    shared = Dense(5)
    net1 = Sequential(shared, Dense(2))
    net2 = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(
        inputs, key=PRNGKey(1), reuse={shared: shared_params})
    net2_params = net2.init_parameters(
        inputs, key=PRNGKey(2), reuse={shared: shared_params})

    # Each net keeps its own head: output widths 2 and 3 respectively.
    for net, params, width in ((net1, net1_params, 2),
                               (net2, net2_params, 3)):
        assert net.apply(params, inputs).shape == (1, width)

    # Slot 0 of each net's parameters is the shared layer's injected copy.
    for params in (net1_params, net2_params):
        assert_dense_parameters_equal(shared_params, params[0])

    replacement = shared.init_parameters(inputs, key=PRNGKey(3))
    sources = {net1: net1_params, shared: replacement}
    merged = net1.parameters_from(sources, inputs)
    # Shared layer comes from ``replacement``; the head is kept from net1.
    assert_dense_parameters_equal(replacement, merged.dense0)
    assert_dense_parameters_equal(net1_params.dense1, merged.dense1)
Exemplo n.º 2
0
def test_parameters_from():
    """Build a net's parameters from those of a known submodule.

    ``parameters_from`` must place the supplied ``Dense`` parameters as the
    net's only parameter entry. ``apply_from`` must give the same output as
    ``apply`` on those parameters, both without and with ``jit=True``.
    """
    dense = Dense(2)
    net = Sequential(dense, relu)
    inputs = np.zeros((1, 3))
    dense_params = dense.init_parameters(inputs, key=PRNGKey(0))

    built = net.parameters_from({dense: dense_params}, inputs)
    assert_parameters_equal((dense_params,), built)

    expected = net.apply(built, inputs)

    # First call omits ``jit`` entirely, second passes ``jit=True``.
    for extra_kwargs in ({}, {'jit': True}):
        actual = net.apply_from({dense: dense_params}, inputs, **extra_kwargs)
        assert np.array_equal(expected, actual)
Exemplo n.º 3
0
def test_parameters_from_subsubmodule():
    """Inject parameters for a layer nested two levels deep.

    ``subsublayer`` lives inside ``sublayer``, which lives inside ``net``.
    ``parameters_from`` must place the supplied parameters at
    ``params_[0][0]``. ``apply_from`` must give exactly the output obtained by
    applying those parameters, both without and with ``jit=True``.
    """
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, np.sum)
    inputs = np.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    # net -> sublayer (index 0) -> subsublayer (index 0).
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    # ``params`` was initialised independently of ``subsublayer_params``, so
    # only the output shape is comparable with ``out``.
    assert out.shape == out_.shape

    # apply_from with the same reuse mapping uses the same parameters as
    # ``params_``, so its output must match ``out_`` exactly, not just in shape.
    out_from = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_from.shape
    assert np.array_equal(out_, out_from)

    out_from_jit = net.apply_from({subsublayer: subsublayer_params}, inputs,
                                  jit=True)
    assert out.shape == out_from_jit.shape
    assert np.array_equal(out_, out_from_jit)