def test_parameters_from():
    """A top-level submodule's parameters can seed the parameters of its parent.

    Initialises a ``Dense`` layer on its own, builds the enclosing
    ``Sequential`` network's parameters from it, and checks that
    ``apply`` on those parameters matches ``apply_from`` (plain and jitted).
    """
    dense = Dense(2)
    model = Sequential(dense, relu)
    x = np.zeros((1, 3))

    dense_params = dense.init_parameters(x, key=PRNGKey(0))
    model_params = model.parameters_from({dense: dense_params}, x)
    # The network has one parametrised child, so its parameters are a 1-tuple.
    assert_parameters_equal((dense_params, ), model_params)

    expected = model.apply(model_params, x)
    # First call without any jit flag, then with jit=True, exactly as before.
    for call_kwargs in ({}, {'jit': True}):
        actual = model.apply_from({dense: dense_params}, x, **call_kwargs)
        assert np.array_equal(expected, actual)
def test_parameters_from_subsubmodule():
    """Parameters of a module nested two levels deep can seed the whole network.

    The network is ``Sequential(Sequential(Dense, relu), np.sum)``. The inner
    ``Dense`` is initialised separately and its parameters are injected via
    ``parameters_from`` / ``apply_from``. Only output shapes are compared
    against a freshly initialised network. Presumably this is because the
    independently initialised ``Dense`` weights need not equal the ones
    ``net.init_parameters`` draws.
    """
    inner = Dense(2)
    middle = Sequential(inner, relu)
    model = Sequential(middle, np.sum)
    x = np.zeros((1, 3))

    reference = model.apply(model.init_parameters(x, key=PRNGKey(0)), x)

    inner_params = inner.init_parameters(x, key=PRNGKey(0))
    rebuilt = model.parameters_from({inner: inner_params}, x)
    # [0] -> params of `middle`, [0][0] -> params of `inner` inside it.
    assert_dense_parameters_equal(inner_params, rebuilt[0][0])

    assert reference.shape == model.apply(rebuilt, x).shape

    # First without any jit flag, then with jit=True, matching the original calls.
    for call_kwargs in ({}, {'jit': True}):
        actual = model.apply_from({inner: inner_params}, x, **call_kwargs)
        assert reference.shape == actual.shape
def test_flatten_shape():
    """``flatten`` collapses a convolution's output to ``(batch, features)``.

    A zero-initialised ``Conv`` with 'SAME' padding maps a (1, 5, 5, 2) input
    to an all-zero (1, 5, 5, 2) output. Appending ``flatten`` must then give
    an all-zero (1, 50) array.
    """
    conv = Conv(2, filter_shape=(3, 3), padding='SAME',
                kernel_init=zeros, bias_init=zeros)
    inputs = np.zeros((1, 5, 5, 2))
    # Same convention as the other tests: inputs positional, PRNG key by
    # keyword. Passing the key first would make it look like a network input.
    params = conv.init_parameters(inputs, key=PRNGKey(0))
    out = conv.apply(params, inputs)
    # Zero kernel and zero bias give an all-zero output of unchanged spatial size.
    assert np.array_equal(np.zeros((1, 5, 5, 2)), out)

    flattened = Sequential(conv, flatten)
    out = flattened.apply_from({conv: params}, inputs)
    # 5 * 5 * 2 = 50 features per example.
    assert np.array_equal(np.zeros((1, 50)), out)