Exemplo n.º 1
0
def test_parametrized_jit(jitted_fun):
    """Wrap ``jitted_fun`` in ``parametrized`` and exercise init and apply.

    Verifies the structure of the initialized parameters, that initializing
    twice with the same key gives equal parameters, and that the plain,
    ``jit=True`` and ``apply_from`` call paths all produce matching outputs.
    """
    net = parametrized(jitted_fun)
    sample = random_inputs((2, ))

    # Parameter structure produced by initialization.
    weights = net.init_parameters(sample, key=PRNGKey(0))
    assert type(weights).__name__ == 'fun'
    assert weights.dense.bias.shape == (3, )

    # Same key, same inputs -> same parameters.
    weights_again = net.init_parameters(sample, key=PRNGKey(0))
    assert_parameters_equal(weights, weights_again)

    # First plain call: check shape and the pinned reference values.
    expected = [0.84194356, -1.5927866, -1.7411114]
    plain_first = net.apply(weights, sample)
    assert plain_first.shape == (3, )
    assert np.allclose(expected, plain_first)

    # Second plain call covers the cached jit path.
    plain_second = net.apply(weights, sample)
    assert np.allclose(plain_first, plain_second)

    # Explicitly jitted apply, compared against the last plain result.
    jitted_first = net.apply(weights, sample, jit=True)
    assert np.allclose(jitted_first, plain_second)

    # Repeat the jitted apply to hit its cache as well.
    jitted_second = net.apply(weights, sample, jit=True)
    assert np.allclose(jitted_first, jitted_second)

    # apply_from takes a mapping from the network to its parameters.
    via_apply_from = net.apply_from({net: weights}, sample, jit=True)
    assert np.allclose(jitted_first, via_apply_from)
Exemplo n.º 2
0
def test_Conv1DTranspose_runs(channels, filter_shape, padding, strides,
                              input_shape):
    """Smoke test: a ``Conv1DTranspose`` layer initializes and applies.

    Nothing is asserted about the output; the test passes as long as
    construction, ``init_parameters`` and ``apply`` do not raise.
    """
    layer = Conv1DTranspose(
        channels, filter_shape, strides=strides, padding=padding)
    sample = random_inputs(input_shape)
    weights = layer.init_parameters(PRNGKey(0), sample)
    layer.apply(weights, sample)
Exemplo n.º 3
0
def test_Conv_runs(channels, filter_shape, padding, strides, input_shape,
                   dilation):
    """Smoke test: a ``Conv`` layer initializes and applies without error.

    Nothing is asserted about the output; the test passes as long as
    construction, ``init_parameters`` and ``apply`` do not raise.
    """
    layer = Conv(
        channels,
        filter_shape,
        strides=strides,
        padding=padding,
        dilation=dilation,
    )
    sample = random_inputs(input_shape)
    weights = layer.init_parameters(PRNGKey(0), sample)
    layer.apply(weights, sample)
Exemplo n.º 4
0
def test_external_submodule_partial_jit():
    """A layer defined outside the network can feed a jitted sub-expression.

    The ``Dense(3)`` layer is created outside ``net_fun`` and its output is
    passed through a jitted doubling function; the result must have shape
    ``(3,)``.
    """
    layer = Dense(3)

    @parametrized
    def net_fun(inputs):
        # Only the doubling step is jitted, not the layer call itself.
        doubled = jit(lambda x: 2 * x)
        return doubled(layer(inputs))

    sample = random_inputs((2,))
    weights = net_fun.init_parameters(PRNGKey(0), sample)
    result = net_fun.apply(weights, sample)
    assert result.shape == (3,)
Exemplo n.º 5
0
def test_BatchNorm_shape_NCHW(center, scale):
    """``BatchNorm`` over axes ``(0, 2, 3)`` keeps the input shape.

    For a ``(4, 5, 6, 7)`` input, the output shape must match the input and,
    when enabled via ``center`` / ``scale``, ``beta`` / ``gamma`` must have
    shape ``(5,)``.
    """
    shape = (4, 5, 6, 7)
    layer = BatchNorm(axis=(0, 2, 3), center=center, scale=scale)

    sample = random_inputs(shape)
    weights = layer.init_parameters(PRNGKey(0), sample)
    result = layer.apply(weights, sample)
    assert result.shape == shape

    # beta / gamma are only inspected when the matching flag is set.
    per_channel = (5, )
    if center:
        assert weights.beta.shape == per_channel
    if scale:
        assert weights.gamma.shape == per_channel
Exemplo n.º 6
0
def test_default_argument_submodule():
    """A submodule supplied as a default argument works inside a network.

    Checks the output shape, that repeated plain calls give identical
    results, and that ``jit=True`` gives a numerically close result.
    """
    @parametrized
    def net(inputs, layer=Dense(3)):
        return 2 * layer(inputs)

    sample = random_inputs((2, ))
    weights = net.init_parameters(sample, key=PRNGKey(0))

    first = net.apply(weights, sample)
    assert first.shape == (3, )

    # A second plain call must reproduce the first exactly.
    repeated = net.apply(weights, sample)
    assert np.array_equal(first, repeated)

    # The jitted call only needs to agree up to floating-point tolerance.
    jitted = net.apply(weights, sample, jit=True)
    assert np.allclose(first, jitted)
Exemplo n.º 7
0
def test_inline_submodule():
    """A submodule constructed inside the network body works as expected.

    Checks the output shape, that repeated plain calls give identical
    results, and that ``jit=True`` gives a numerically close result.
    """
    @parametrized
    def net_fun(inputs):
        layer = Dense(3)
        return 2 * layer(inputs)

    sample = random_inputs((2,))
    weights = net_fun.init_parameters(PRNGKey(0), sample)

    first = net_fun.apply(weights, sample)
    assert first.shape == (3,)

    # A second plain call must reproduce the first exactly.
    repeated = net_fun.apply(weights, sample)
    assert np.array_equal(first, repeated)

    # The jitted call only needs to agree up to floating-point tolerance.
    jitted = net_fun.apply(weights, sample, jit=True)
    assert np.allclose(first, jitted)