示例#1
0
def test_Reparametrized():
    """Check ``Reparametrized`` with the ``Scaled`` factory on a one-parameter net.

    The wrapped net returns a single scalar parameter initialised to 2. After
    wrapping it with ``Reparametrized(..., reparametrization_factory=Scaled)``
    the network is expected to evaluate to 4.
    """
    def init_twos(rng, shape):
        # The initializer ignores the RNG key and fills the parameter with 2s.
        return 2 * np.ones(shape)

    @parametrized
    def net(inputs):
        return parameter((), init_twos, inputs)

    dummy_inputs = np.zeros(())
    scaled_net = Reparametrized(net, reparametrization_factory=Scaled)
    net_params = scaled_net.init_parameters(PRNGKey(0), dummy_inputs)

    output = scaled_net.apply(net_params, dummy_inputs)
    assert output == 4
示例#2
0
def test_Reparametrized_unparametrized_transform():
    """Check ``Reparametrized`` with a plain function as transform, without inputs.

    The wrapped net takes no inputs and returns a scalar parameter initialised
    to 2. The reparametrization factory hands back ``doubled``, an ordinary
    function that multiplies its argument by 2, so the result should be 4.
    """
    def doubled(params):
        return 2 * params

    def init_twos(key, shape):
        # The initializer ignores the RNG key and fills the parameter with 2s.
        return 2 * np.ones(shape)

    @parametrized
    def net():
        return parameter((), init_twos)

    doubled_net = Reparametrized(net, reparametrization_factory=lambda: doubled)
    net_params = doubled_net.init_parameters(key=PRNGKey(0))

    output = doubled_net.apply(net_params)
    assert output == 4
示例#3
0
def test_Reparametrized_unparametrized_transform():
    """Check ``Reparametrized`` with a plain function as transform, with inputs.

    The wrapped net receives an input and returns a scalar parameter
    initialised to 2. The reparametrization factory hands back ``doubled``, an
    ordinary function that multiplies its argument by 2, so the result should
    be 4.
    """
    def doubled(params):
        return 2 * params

    def init_twos(rng, shape):
        # The initializer ignores the RNG key and fills the parameter with 2s.
        return 2 * np.ones(shape)

    @parametrized
    def net(inputs):
        return parameter((), init_twos, inputs)

    dummy_inputs = np.zeros(())
    doubled_net = Reparametrized(net, reparametrization_factory=lambda: doubled)
    net_params = doubled_net.init_parameters(PRNGKey(0), dummy_inputs)

    output = doubled_net.apply(net_params, dummy_inputs)
    assert output == 4
示例#4
0
def test_reparametrized_submodule():
    """Check that a ``Reparametrized`` module works as a layer inside ``Sequential``.

    Builds a ``Sequential`` of two conv/relu stages, ``flatten`` and a
    ``Reparametrized`` dense head, initialises it on an all-ones input of shape
    ``(1, 3, 3, 1)`` and verifies that:

    * the wrapped model's parameters are reachable under
      ``params.reparametrized.model`` and ``dense1.kernel`` has shape ``(2, 2)``;
    * applying the network yields an output of shape ``(1, 2)``.
    """
    # Layers are created in the same order as in a single nested call:
    # the feature stages first, then the reparametrized head.
    feature_layers = (Conv(2, (3, 3)), relu, Conv(2, (3, 3)), relu, flatten)
    scaled_head = Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled)
    net = Sequential(*feature_layers, scaled_head)

    # Named ``inputs`` rather than ``input`` so the builtin is not shadowed.
    inputs = np.ones((1, 3, 3, 1))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert params.reparametrized.model.dense1.kernel.shape == (2, 2)

    out = net.apply(params, inputs)
    assert out.shape == (1, 2)