def test_submodule_reuse():
    """Reuse one `Dense` layer's parameters inside two `Sequential` networks.

    Asserts the output shape of each network, that the first parameter entry
    of each network equals the shared layer's parameters, and that
    `parameters_from` yields the freshly initialized layer parameters as
    `dense0` while `dense1` equals the one from the first network.
    """
    # NOTE(review): a second `test_submodule_reuse` appears later in this
    # file; if both live in one module the later one shadows this - confirm.
    batch = np.zeros((1, 2))
    shared = Dense(5)
    narrow_net = Sequential(shared, Dense(2))
    wide_net = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(batch, key=PRNGKey(0))
    reuse = {shared: shared_params}
    narrow_params = narrow_net.init_parameters(
        batch, key=PRNGKey(1), reuse=reuse)
    wide_params = wide_net.init_parameters(
        batch, key=PRNGKey(2), reuse=reuse)

    assert narrow_net.apply(narrow_params, batch).shape == (1, 2)
    assert wide_net.apply(wide_params, batch).shape == (1, 3)

    # Both parents must carry the shared layer's parameters in first position.
    assert_dense_parameters_equal(shared_params, narrow_params[0])
    assert_dense_parameters_equal(shared_params, wide_params[0])

    fresh_shared_params = shared.init_parameters(batch, key=PRNGKey(3))
    sources = {narrow_net: narrow_params, shared: fresh_shared_params}
    merged = narrow_net.parameters_from(sources, batch)

    assert_dense_parameters_equal(fresh_shared_params, merged.dense0)
    assert_dense_parameters_equal(narrow_params.dense1, merged.dense1)
def test_submodule_reuse_top_level():
    """Re-initialize a `Dense` module while reusing its own parameters.

    Asserts that the parameters returned with `reuse={net: params}` equal the
    original ones and that applying either set gives equal outputs.
    """
    dense = Dense(2)
    batch = np.zeros((1, 3))

    original = dense.init_parameters(batch, key=PRNGKey(0))
    expected = dense.apply(original, batch)

    # A different key is passed; the assertions below expect it not to matter.
    reused = dense.init_parameters(
        batch, key=PRNGKey(1), reuse={dense: original})
    assert_dense_parameters_equal(original, reused)

    actual = dense.apply(reused, batch)
    assert np.array_equal(expected, actual)
def test_reuse_api():
    """Build a transfer network on top of an already initialized `Dense`.

    Asserts that the `dense0` entry of the transfer network's parameters
    compares equal to the parameters passed in through `reuse`.
    """
    batch = np.zeros((1, 2))
    base = Dense(5)
    base_params = base.init_parameters(batch, key=PRNGKey(0))

    # train net params...

    transfer = Sequential(base, relu, Dense(2))
    transfer_params = transfer.init_parameters(
        batch, key=PRNGKey(1), reuse={base: base_params})

    # NOTE(review): plain `==` on parameter tuples - presumably passes because
    # the reused entries are the very same objects; confirm.
    assert base_params == transfer_params.dense0
def test_nested_module_without_inputs():
    """Initialize and apply a single `Dense(2)` on a `(1, 3)` input.

    Asserts kernel and bias shapes, that the module's string form starts with
    ``'dense'``, the output shape, and that `jit=True` gives a close result.
    """
    layer = Dense(2)
    batch = np.zeros((1, 3))

    layer_params = layer.init_parameters(batch, key=PRNGKey(0))
    assert (3, 2) == layer_params.kernel.shape
    assert (2, ) == layer_params.bias.shape
    assert str(layer).startswith('dense')

    eager = layer.apply(layer_params, batch)
    assert (1, 2) == eager.shape

    jitted = layer.apply(layer_params, batch, jit=True)
    assert np.allclose(eager, jitted)
def test_Dense_shape(Dense=Dense):
    """Check a zero-initialized `Dense(2)` on a `(1, 3)` zero input.

    Asserts the parameters equal a zero `(3, 2)` kernel and zero `(2,)` bias,
    that the output is a zero `(1, 2)` array both with and without `jit`, and
    that initializing through `shaped(inputs)` gives equal parameters.

    Args:
        Dense: the layer factory under test; defaults to the module-level
            `Dense`.
    """
    layer = Dense(2, kernel_init=zeros, bias_init=zeros)
    batch = np.zeros((1, 3))

    # NOTE(review): key-first positional order here differs from tests that
    # call `init_parameters(inputs, key=...)` - confirm the targeted API.
    layer_params = layer.init_parameters(PRNGKey(0), batch)
    expected_params = (np.zeros((3, 2)), np.zeros(2))
    assert_parameters_equal(expected_params, layer_params)

    eager = layer.apply(layer_params, batch)
    assert np.array_equal(np.zeros((1, 2)), eager)

    jitted = jit(layer.apply)(layer_params, batch)
    assert np.array_equal(eager, jitted)

    shaped_params = layer.shaped(batch).init_parameters(PRNGKey(0))
    assert_parameters_equal(layer_params, shaped_params)
def test_parameters_from_top_level():
    """Use `parameters_from` / `apply_from` with the module itself as key.

    Asserts that `parameters_from({net: params}, inputs)` equals `params`, and
    that `apply` with those parameters and `apply_from` (with and without
    `jit=True`) all reproduce the original output.
    """
    dense = Dense(2)
    batch = np.zeros((1, 3))

    original = dense.init_parameters(batch, key=PRNGKey(0))
    expected = dense.apply(original, batch)

    rebuilt = dense.parameters_from({dense: original}, batch)
    assert_dense_parameters_equal(original, rebuilt)
    assert np.array_equal(expected, dense.apply(rebuilt, batch))

    # Same check through `apply_from`, first plain and then with jit enabled.
    for options in ({}, {'jit': True}):
        actual = dense.apply_from({dense: original}, batch, **options)
        assert np.array_equal(expected, actual)
def test_parameters_from():
    """Derive a `Sequential`'s parameters from those of its `Dense` layer.

    Asserts that `parameters_from` returns a one-element tuple holding the
    layer parameters, and that `apply_from` (with and without `jit=True`)
    matches `apply` on those parameters.
    """
    dense = Dense(2)
    model = Sequential(dense, relu)
    batch = np.zeros((1, 3))

    dense_params = dense.init_parameters(batch, key=PRNGKey(0))
    sources = {dense: dense_params}

    model_params = model.parameters_from(sources, batch)
    assert_parameters_equal((dense_params, ), model_params)

    expected = model.apply(model_params, batch)

    # Same comparison twice: plain call first, then with jit enabled.
    for options in ({}, {'jit': True}):
        actual = model.apply_from({dense: dense_params}, batch, **options)
        assert np.array_equal(expected, actual)
def test_Parameter_dense():
    """Define a dense layer from two raw `parameter` calls and exercise it.

    Asserts that the initialized parameters expose `parameter0` with shape
    `(3, 2)` and `parameter1` with shape `(2,)`, and that applying the layer
    with `jit=True` to a `(1, 3)` input gives a `(1, 2)` output.
    """

    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        """Return a parametrized function computing `inputs @ kernel + bias`."""

        @parametrized
        def dense(inputs):
            # Keep kernel first, bias second: the assertions below address
            # them as `parameter0` and `parameter1` respectively.
            in_dim = inputs.shape[-1]
            kernel = parameter((in_dim, out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return np.dot(inputs, kernel) + bias

        return dense

    layer = Dense(2)
    batch = np.zeros((1, 3))

    layer_params = layer.init_parameters(batch, key=PRNGKey(0))
    assert (3, 2) == layer_params.parameter0.shape
    assert (2,) == layer_params.parameter1.shape

    result = layer.apply(layer_params, batch, jit=True)
    assert (1, 2) == result.shape
def test_parameters_from_sharing_between_multiple_parents():
    """Share one `Dense` between a network and a `Sequential` inside it.

    The network calls the layer directly and also through a `Sequential`
    wrapping the same layer. Asserts that `parameters_from` puts the layer's
    parameters under `dense`, leaves `sequential` empty, yields exactly two
    entries, and that the first network output equals the layer's own output.
    """
    shared = Dense(2)
    wrapped = Sequential(shared, np.sum)

    @parametrized
    def net(inputs):
        return shared(inputs), wrapped(inputs)

    batch = np.zeros((1, 3))
    shared_params = shared.init_parameters(batch, key=PRNGKey(0))
    expected = shared.apply(shared_params, batch)

    net_params = net.parameters_from({shared: shared_params}, batch)
    assert_dense_parameters_equal(shared_params, net_params.dense)
    assert_parameters_equal((), net_params.sequential)
    assert 2 == len(net_params)

    # Only the direct-layer output is compared; the summed one is ignored.
    direct_out, _ = net.apply(net_params, batch)
    assert np.array_equal(expected, direct_out)
def test_submodule_reuse():
    """Reuse one `Dense` layer's parameters inside two `Sequential` networks.

    Asserts the output shape of each network and that the first parameter
    entry of each network equals the shared layer's parameters.
    """
    # NOTE(review): an earlier `test_submodule_reuse` exists in this file; if
    # both live in one module this definition shadows it - confirm. The
    # key-first `init_parameters` order and `assert_dense_params_equal` also
    # differ from that earlier test - verify which API version is targeted.
    batch = np.zeros((1, 2))
    shared = Dense(5)
    narrow_net = Sequential(shared, Dense(2))
    wide_net = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(PRNGKey(0), batch)
    reuse = {shared: shared_params}
    narrow_params = narrow_net.init_parameters(PRNGKey(1), batch, reuse=reuse)
    wide_params = wide_net.init_parameters(PRNGKey(2), batch, reuse=reuse)

    assert narrow_net.apply(narrow_params, batch).shape == (1, 2)
    assert wide_net.apply(wide_params, batch).shape == (1, 3)

    # Both parents must carry the shared layer's parameters in first position.
    assert_dense_params_equal(shared_params, narrow_params[0])
    assert_dense_params_equal(shared_params, wide_params[0])
def test_parameters_from_subsubmodule():
    """Supply parameters for a `Dense` nested two `Sequential` levels deep.

    Asserts that `parameters_from` places the given parameters at `[0][0]` of
    the result, and that `apply` and `apply_from` (with and without
    `jit=True`) give outputs of the same shape as the reference output.
    """
    innermost = Dense(2)
    middle = Sequential(innermost, relu)
    model = Sequential(middle, np.sum)
    batch = np.zeros((1, 3))

    model_params = model.init_parameters(batch, key=PRNGKey(0))
    reference = model.apply(model_params, batch)

    innermost_params = innermost.init_parameters(batch, key=PRNGKey(0))
    rebuilt = model.parameters_from({innermost: innermost_params}, batch)
    assert_dense_parameters_equal(innermost_params, rebuilt[0][0])

    # Only shapes are compared against the reference output.
    assert reference.shape == model.apply(rebuilt, batch).shape

    for options in ({}, {'jit': True}):
        actual = model.apply_from(
            {innermost: innermost_params}, batch, **options)
        assert reference.shape == actual.shape