示例#1
0
def test_Reparametrized():
    """Check that wrapping a one-parameter net with ``Scaled`` yields 4.

    The wrapped net owns a single scalar parameter (shape ``()``) whose
    initializer returns 2. After reparametrization with the ``Scaled``
    factory, applying the net to a scalar zero input must produce 4.
    NOTE(review): the extra factor of 2 comes from ``Scaled``'s own
    initialization, which is not visible here — confirm against ``Scaled``.
    """
    @parametrized
    def net(inputs):
        return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)

    # Scalar dummy input; the net's output depends only on its parameter.
    x = np.zeros(())

    wrapped = Reparametrized(net, reparametrization_factory=Scaled)
    weights = wrapped.init_parameters(PRNGKey(0), x)

    assert 4 == wrapped.apply(weights, x)
示例#2
0
def test_Reparametrized_unparametrized_transform():
    """Reparametrize an input-free net with a parameter-free transform.

    ``reparametrization_factory`` returns the plain function ``doubled``,
    so the transform contributes no parameters of its own. The net's only
    parameter is a scalar initialized to 2. The expected output of 4
    corresponds to ``doubled`` applied to that initial value.

    NOTE(review): another test with this exact name exists alongside this
    one (a variant whose net takes ``inputs``). If both end up in the same
    module, the later definition shadows this one and pytest collects only
    one of them. Consider giving this variant a distinct name.
    """
    def doubled(params):
        return 2 * params

    @parametrized
    def net():
        return parameter((), lambda key, shape: 2 * np.ones(shape))

    # The factory hands back the bare function: an "unparametrized" transform.
    doubled_net = Reparametrized(net,
                                 reparametrization_factory=lambda: doubled)
    params = doubled_net.init_parameters(key=PRNGKey(0))
    output = doubled_net.apply(params)
    assert 4 == output
示例#3
0
def test_Reparametrized_unparametrized_transform():
    """Reparametrize a net with a parameter-free transform.

    ``reparametrization_factory`` returns the plain function ``doubled``,
    so the transform contributes no parameters of its own. The net's only
    parameter is a scalar initialized to 2. The expected output of 4
    corresponds to ``doubled`` applied to that initial value.

    NOTE(review): another test with this exact name exists alongside this
    one (an input-free variant). If both end up in the same module, the
    later definition shadows the earlier one and pytest collects only one
    of them. Consider giving each variant a distinct name.
    """
    def doubled(params):
        return 2 * params

    @parametrized
    def net(inputs):
        return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)

    # The factory hands back the bare function: an "unparametrized" transform.
    doubled_net = Reparametrized(net,
                                 reparametrization_factory=lambda: doubled)

    # Scalar dummy input; the net's output depends only on its parameter.
    inputs = np.zeros(())
    params = doubled_net.init_parameters(PRNGKey(0), inputs)

    output = doubled_net.apply(params, inputs)

    assert 4 == output