def test_parametrized_jit(jitted_fun):
    """Check init and apply consistency for a ``parametrized`` wrapper of ``jitted_fun``.

    Verifies the parameter container's type name and the dense bias shape,
    that two initializations with the same key produce equal parameters, and
    that plain, repeated, ``jit=True`` and ``apply_from`` evaluations agree.
    """
    net = parametrized(jitted_fun)
    inputs = random_inputs((2,))

    # NOTE(review): some tests in this file call init_parameters(PRNGKey(0), inputs)
    # instead — confirm which signature is current.
    def init():
        return net.init_parameters(inputs, key=PRNGKey(0))

    params = init()
    assert type(params).__name__ == 'fun'
    assert params.dense.bias.shape == (3,)

    # Same key twice must give the same parameters.
    params_again = init()
    assert_parameters_equal(params, params_again)

    eager = net.apply(params, inputs)
    assert eager.shape == (3,)
    assert np.allclose([0.84194356, -1.5927866, -1.7411114], eager)

    # run twice to cover cached jit call
    eager_again = net.apply(params, inputs)
    assert np.allclose(eager, eager_again)

    jitted = net.apply(params, inputs, jit=True)
    assert np.allclose(jitted, eager_again)

    jitted_again = net.apply(params, inputs, jit=True)
    assert np.allclose(jitted, jitted_again)

    via_apply_from = net.apply_from({net: params}, inputs, jit=True)
    assert np.allclose(jitted, via_apply_from)
def test_Conv1DTranspose_runs(channels, filter_shape, padding, strides, input_shape):
    """Smoke test: a ``Conv1DTranspose`` layer initializes and applies without raising.

    No property of the output is asserted; the test only exercises
    construction, parameter initialization and a forward application.
    """
    layer = Conv1DTranspose(channels, filter_shape, padding=padding, strides=strides)
    sample = random_inputs(input_shape)
    # NOTE(review): some tests in this file call init_parameters(inputs, key=...)
    # instead — confirm which signature is current.
    layer.apply(layer.init_parameters(PRNGKey(0), sample), sample)
def test_Conv_runs(channels, filter_shape, padding, strides, input_shape, dilation):
    """Smoke test: a ``Conv`` layer initializes and applies without raising.

    No property of the output is asserted; the test only exercises
    construction, parameter initialization and a forward application.
    """
    layer = Conv(
        channels,
        filter_shape,
        padding=padding,
        strides=strides,
        dilation=dilation,
    )
    sample = random_inputs(input_shape)
    # NOTE(review): some tests in this file call init_parameters(inputs, key=...)
    # instead — confirm which signature is current.
    layer.apply(layer.init_parameters(PRNGKey(0), sample), sample)
def test_external_submodule_partial_jit():
    """Check a net that applies a ``jit``-wrapped lambda to an externally defined layer.

    The ``Dense(3)`` layer is created outside the parametrized function and
    only the doubling step is wrapped in ``jit``. The output for a ``(2,)``
    input is expected to have shape ``(3,)``.
    """
    layer = Dense(3)

    @parametrized
    def net_fun(inputs):
        # Only the doubling is jitted; the layer call itself is not.
        double = jit(lambda x: 2 * x)
        hidden = layer(inputs)
        return double(hidden)

    sample = random_inputs((2,))
    # NOTE(review): some tests in this file call init_parameters(inputs, key=...)
    # instead — confirm which signature is current.
    params = net_fun.init_parameters(PRNGKey(0), sample)
    result = net_fun.apply(params, sample)
    assert result.shape == (3,)
def test_BatchNorm_shape_NCHW(center, scale):
    """Check output and parameter shapes of ``BatchNorm`` built with ``axis=(0, 2, 3)``.

    The output must keep the ``(4, 5, 6, 7)`` input shape. When ``center`` is
    truthy ``params.beta`` must have shape ``(5,)``, and when ``scale`` is
    truthy ``params.gamma`` must have shape ``(5,)``.
    """
    input_shape = (4, 5, 6, 7)
    # Axis 1 is the only axis not listed in ``axis`` below; its size is 5.
    per_channel_shape = (input_shape[1],)

    layer = BatchNorm(axis=(0, 2, 3), center=center, scale=scale)
    sample = random_inputs(input_shape)
    # NOTE(review): some tests in this file call init_parameters(inputs, key=...)
    # instead — confirm which signature is current.
    params = layer.init_parameters(PRNGKey(0), sample)
    result = layer.apply(params, sample)

    assert result.shape == input_shape
    # beta / gamma are only inspected when the corresponding flag is set.
    assert not center or params.beta.shape == per_channel_shape
    assert not scale or params.gamma.shape == per_channel_shape
def test_default_argument_submodule():
    """Check a net whose submodule is supplied as a default argument value.

    The output for a ``(2,)`` input must have shape ``(3,)``, a second plain
    application must reproduce it exactly, and a ``jit=True`` application must
    match it up to ``np.allclose`` tolerance.
    """
    # The default-argument layer is the feature under test; keep it as is.
    @parametrized
    def net(inputs, layer=Dense(3)):
        return 2 * layer(inputs)

    sample = random_inputs((2,))
    # NOTE(review): some tests in this file call init_parameters(PRNGKey(0), inputs)
    # instead — confirm which signature is current.
    params = net.init_parameters(sample, key=PRNGKey(0))

    reference = net.apply(params, sample)
    assert reference.shape == (3,)

    repeated = net.apply(params, sample)
    assert np.array_equal(reference, repeated)

    jitted = net.apply(params, sample, jit=True)
    assert np.allclose(reference, jitted)
def test_inline_submodule():
    """Check a net that constructs its ``Dense(3)`` submodule inside its own body.

    The output for a ``(2,)`` input must have shape ``(3,)``, a second plain
    application must reproduce it exactly, and a ``jit=True`` application must
    match it up to ``np.allclose`` tolerance.
    """
    @parametrized
    def net_fun(inputs):
        # The layer is created inline on every call of the function body.
        return 2 * Dense(3)(inputs)

    sample = random_inputs((2,))
    # NOTE(review): some tests in this file call init_parameters(inputs, key=...)
    # instead — confirm which signature is current.
    params = net_fun.init_parameters(PRNGKey(0), sample)

    reference = net_fun.apply(params, sample)
    assert reference.shape == (3,)

    repeated = net_fun.apply(params, sample)
    assert np.array_equal(reference, repeated)

    jitted = net_fun.apply(params, sample, jit=True)
    assert np.allclose(reference, jitted)