def test_Reparametrized():
    """Check Reparametrized with the ``Scaled`` factory on a one-parameter net.

    The wrapped net's only parameter is a scalar (shape ``()``) initialized
    to 2. The test expects the reparametrized net to return 4 for a scalar
    zero input.
    """

    @parametrized
    def net(inputs):
        # Single scalar parameter filled with 2.
        return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)

    example_input = np.zeros(())
    scaled_net = Reparametrized(net, reparametrization_factory=Scaled)
    params = scaled_net.init_parameters(PRNGKey(0), example_input)

    # NOTE(review): presumably Scaled multiplies the raw parameter by a scale
    # that starts at 2 (2 * 2 == 4) — confirm against Scaled's definition.
    assert 4 == scaled_net.apply(params, example_input)
def test_Reparametrized_unparametrized_transform_without_inputs():
    """Check Reparametrized with a parameter-free transform on an input-less net.

    Renamed from ``test_Reparametrized_unparametrized_transform``. Another
    test in this file defines a function with that same name. That later
    definition replaced this one at import time, so pytest never collected
    or ran this test.
    """

    def doubled(params):
        # Plain (non-@parametrized) transform: multiplies the raw value by 2.
        return 2 * params

    @parametrized
    def net():
        # Single scalar parameter filled with 2; the net takes no inputs.
        return parameter((), lambda key, shape: 2 * np.ones(shape))

    # The factory takes no arguments and returns the plain `doubled` function.
    doubled_net = Reparametrized(net, reparametrization_factory=lambda: doubled)
    params = doubled_net.init_parameters(key=PRNGKey(0))
    out = doubled_net.apply(params)

    # Parameter initialized to 2, then doubled by the transform.
    assert 4 == out
def test_Reparametrized_unparametrized_transform():
    """Check Reparametrized with a parameter-free transform on a net with inputs.

    The factory returns a plain function that doubles the raw parameter.
    The net's parameter is a scalar initialized to 2, so the expected
    output is 4.
    """

    def doubled(params):
        return 2 * params

    @parametrized
    def net(inputs):
        # Single scalar parameter filled with 2.
        return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)

    # No-argument factory handing back the plain `doubled` function.
    doubled_net = Reparametrized(net, reparametrization_factory=lambda: doubled)
    example_input = np.zeros(())
    params = doubled_net.init_parameters(PRNGKey(0), example_input)

    assert 4 == doubled_net.apply(params, example_input)
def test_reparametrized_submodule():
    """Check that a Reparametrized block works as a submodule inside Sequential.

    The test checks two things:
    - the parameter tree exposes the wrapped model under
      ``params.reparametrized.model``;
    - the whole net maps a ``(1, 3, 3, 1)`` input to an output of shape
      ``(1, 2)``.
    """
    net = Sequential(
        Conv(2, (3, 3)), relu,
        Conv(2, (3, 3)), relu,
        flatten,
        Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled),
    )
    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = np.ones((1, 3, 3, 1))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    # NOTE(review): presumably `dense1` is the second Dense(2) layer, whose
    # input width is 2 (output of the first Dense(2) + relu), hence a
    # (2, 2) kernel — confirm the layer naming scheme.
    assert (2, 2) == params.reparametrized.model.dense1.kernel.shape

    out = net.apply(params, inputs)
    assert (1, 2) == out.shape