def test_Reparametrized():
    """Apply a net wrapped by ``Reparametrized`` with the ``Scaled`` factory.

    The wrapped net returns its only parameter, a scalar whose initializer
    produces ``2 * ones(shape)``. The reparametrized output is expected to
    equal 4.
    """
    inputs = np.zeros(())

    @parametrized
    def net(inputs):
        return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)

    reparametrized_net = Reparametrized(net, reparametrization_factory=Scaled)

    initial_params = reparametrized_net.init_parameters(PRNGKey(0), inputs)
    output = reparametrized_net.apply(initial_params, inputs)

    assert output == 4
def test_Reparametrized_unparametrized_transform_without_inputs():
    """Reparametrize an input-free net with a plain (unparametrized) function.

    The factory passed to ``Reparametrized`` returns ``doubled``, an ordinary
    Python function that multiplies its argument by 2. The net takes no
    inputs and returns its only parameter, a scalar whose initializer
    produces ``2 * ones(shape)``. The reparametrized output is expected to
    equal 4.

    NOTE: this test previously shared its name with the variant whose net
    takes inputs; the later definition rebound the name, so this test was
    never collected. It carries a distinct name so that both variants run.
    """

    def doubled(params):
        return 2 * params

    @parametrized
    def net():
        return parameter((), lambda key, shape: 2 * np.ones(shape))

    doubled_net = Reparametrized(net, reparametrization_factory=lambda: doubled)

    params = doubled_net.init_parameters(key=PRNGKey(0))
    output = doubled_net.apply(params)

    assert output == 4
def test_Reparametrized_unparametrized_transform():
    """Reparametrize a net that takes inputs using a plain function.

    The factory passed to ``Reparametrized`` returns ``doubled``, an ordinary
    Python function that multiplies its argument by 2. The net returns its
    only parameter, a scalar whose initializer produces ``2 * ones(shape)``.
    The reparametrized output is expected to equal 4.
    """
    inputs = np.zeros(())

    def doubled(params):
        return 2 * params

    @parametrized
    def net(inputs):
        return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)

    doubled_net = Reparametrized(net, reparametrization_factory=lambda: doubled)

    initial_params = doubled_net.init_parameters(PRNGKey(0), inputs)
    output = doubled_net.apply(initial_params, inputs)

    assert output == 4