Example No. 1
def test_ocr_rnn():
    length = 5
    carry_size = 3
    class_count = 4
    inputs = np.zeros((1, length, 4))

    def rnn():
        return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(PRNGKey(0), inputs)

    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    assert np.array_equal(np.zeros((7, 3)), cell.update_kernel)
    assert np.array_equal(np.zeros((7, 3)), cell.reset_kernel)
    assert np.array_equal(np.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)
    assert np.array_equal(.25 * np.ones((1, 5, 4)), out)
Example No. 2
def conv_block(inputs):
    main = Sequential(Conv(filters1, (1, 1), strides), BatchNorm(), relu,
                      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), relu,
                      Conv(filters3, (1, 1)), BatchNorm())
    shortcut = Sequential(Conv(filters3, (1, 1), strides), BatchNorm())
    return relu(sum((main(inputs), shortcut(inputs))))
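conv_block closes over filters1, filters2, filters3, ks and strides, which are supplied by an enclosing factory that is not shown. A minimal sketch of what such a factory might look like, assuming the usual top-level jaxnet exports and mirroring the @parametrized style used in the other examples (names and defaults are illustrative, not taken from the source):

from jaxnet import parametrized, Sequential, Conv, BatchNorm, relu


def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters

    @parametrized
    def conv_block(inputs):
        main = Sequential(Conv(filters1, (1, 1), strides), BatchNorm(), relu,
                          Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), relu,
                          Conv(filters3, (1, 1)), BatchNorm())
        shortcut = Sequential(Conv(filters3, (1, 1), strides), BatchNorm())
        # Sum the main path with the projection shortcut, then apply relu:
        return relu(sum((main(inputs), shortcut(inputs))))

    return conv_block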
Example No. 3
    def res_layer(inputs):
        """
        From original doc string:

        The layer contains a gated filter that connects to dense output
        and to a skip connection:

               |-> [gate]   -|        |-> 1x1 conv -> skip output
               |             |-> (*) -|
        input -|-> [filter] -|        |-> 1x1 conv -|
               |                                    |-> (+) -> dense output
               |------------------------------------|

        Where `[gate]` and `[filter]` are causal convolutions with a
        non-linear activation at the output
        """
        gated = Sequential(
            Conv1D(dilation_channels, (filter_width, ), dilation=(dilation, )),
            sigmoid)(inputs)
        filtered = Sequential(
            Conv1D(dilation_channels, (filter_width, ), dilation=(dilation, )),
            np.tanh)(inputs)
        p = gated * filtered
        out = Conv1D(residual_channels, (1, ), padding='SAME')(p)
        # Add the transformed output of the resblock to the sliced input:
        sliced_inputs = lax.dynamic_slice(
            inputs, [0, inputs.shape[1] - out.shape[1], 0],
            [inputs.shape[0], out.shape[1], inputs.shape[2]])
        new_out = sum((out, sliced_inputs))
        skip = Conv1D(residual_channels, (1, ),
                      padding='SAME')(skip_slice(p, output_width))
        return new_out, skip
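res_layer calls a skip_slice helper defined outside this excerpt. A minimal sketch of such a helper, assuming it simply keeps the last output_width time steps of its (batch, time, channels) input, in the same way the dynamic_slice above trims the residual path (an assumption, not the source's definition):

from jax import lax


def skip_slice(x, output_width):
    # Keep only the trailing `output_width` steps of the time axis (axis 1),
    # matching the dynamic_slice pattern used for sliced_inputs above.
    start = x.shape[1] - output_width
    return lax.dynamic_slice(x, (0, start, 0),
                             (x.shape[0], output_width, x.shape[2]))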
Example No. 4
def test_regularized_submodule():
    net = Sequential(Conv(2, (1, 1)), relu, Conv(2, (1, 1)), relu, flatten,
                     L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), .1))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(input, key=PRNGKey(0))
    assert (2, 2) == params.regularized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert () == out.shape
Example No. 5
def test_submodule_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs,
                                       key=PRNGKey(1),
                                       reuse={layer: layer_params})
    net2_params = net2.init_parameters(inputs,
                                       key=PRNGKey(2),
                                       reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_parameters_equal(layer_params, net1_params[0])
    assert_dense_parameters_equal(layer_params, net2_params[0])

    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from(
        {
            net1: net1_params,
            layer: new_layer_params
        }, inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1)
Example No. 6
def test_reparametrized_submodule():
    net = Sequential(
        Conv(2, (3, 3)), relu, Conv(2, (3, 3)), relu, flatten,
        Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(PRNGKey(0), input)
    assert (2, 2) == params.reparametrized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert (1, 2) == out.shape
Example No. 7
def test_pool_shape(Pool):
    conv = Conv(2,
                filter_shape=(3, 3),
                padding='SAME',
                kernel_init=zeros,
                bias_init=zeros)
    inputs = np.zeros((1, 5, 5, 2))

    pooled = Sequential(conv, Pool(window_shape=(1, 1), strides=(2, 2)))
    params = pooled.init_parameters(PRNGKey(0), inputs)
    out = pooled.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 3, 3, 2)), out)
Example No. 8
def test_input_dependent_nested_modules():
    @parametrized
    def layer(inputs):
        return Dense(inputs.shape[0])(inputs)

    net = Sequential(Dense(3), layer)

    inputs = np.zeros((5, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    out = net.apply(params, inputs)
    assert (5, 5) == out.shape
Example No. 9
def test_reuse_api():
    inputs = np.zeros((1, 2))
    net = Dense(5)
    net_params = net.init_parameters(inputs, key=PRNGKey(0))

    # train net params...

    transfer_net = Sequential(net, relu, Dense(2))
    transfer_net_params = transfer_net.init_parameters(inputs, key=PRNGKey(1),
                                                       reuse={net: net_params})

    assert net_params == transfer_net_params.dense0
Example No. 10
def test_external_param_sharing():
    layer = Dense(2, zeros, zeros)
    shared_net = Sequential(layer, layer)

    inputs = np.zeros((1, 2))
    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2)), ), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(np.zeros((1, 2)), out)
Example No. 11
def test_Sequential_graceful_update_message():
    message = 'Call like Sequential(Dense(10), relu), without "[" and "]". ' \
              '(Or pass iterables with Sequential(*layers).)'
    try:
        Sequential([Dense(2), relu])
        assert False
    except ValueError as e:
        assert message == str(e)

    try:
        Sequential(Dense(2) for _ in range(3))
        assert False
    except ValueError as e:
        assert message == str(e)
Example No. 12
def test_flatten_shape():
    conv = Conv(2,
                filter_shape=(3, 3),
                padding='SAME',
                kernel_init=zeros,
                bias_init=zeros)
    inputs = np.zeros((1, 5, 5, 2))

    params = conv.init_parameters(PRNGKey(0), inputs)
    out = conv.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 5, 5, 2)), out)

    flattened = Sequential(conv, flatten)
    out = flattened.apply_from({conv: params}, inputs)
    assert np.array_equal(np.zeros((1, 50)), out)
Example No. 13
def test_diamond_shared_submodules():
    p = Parameter(lambda rng: np.ones(()))
    a = Sequential(p)
    b = Sequential(p)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    params = net.init_parameters(PRNGKey(0), np.zeros(()))
    assert 1 == len(params)
    assert np.array_equal(np.ones(()), params)
    a, b = net.apply(params, np.zeros(()))
    assert np.array_equal(np.ones(()), a)
    assert np.array_equal(np.ones(()), b)
Example No. 14
def test_collection_input(type):
    @parametrized
    def net(inputs):
        assert isinstance(inputs, type)
        return inputs[0] * inputs[1] * parameter((), zeros)

    inputs = type((np.zeros(2), np.zeros(2)))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros(2), out)

    net = Sequential(net)
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros(2), out)
Example No. 15
    def identity_block(inputs):
        main = Sequential(Conv(filters1, (1, 1)), BatchNorm(), relu,
                          Conv(filters2, (ks, ks), padding='SAME'),
                          BatchNorm(), relu, Conv(inputs.shape[3], (1, 1)),
                          BatchNorm())

        return relu(sum((main(inputs), inputs)))
Example No. 16
def main(batch_size=256, env_name="CartPole-v1"):
    env = gym.make(env_name)

    policy = Sequential(Dense(64), relu, Dense(env.action_space.n))

    @parametrized
    def loss(observations, actions, rewards_to_go):
        logprobs = log_softmax(policy(observations))
        action_logprobs = logprobs[np.arange(logprobs.shape[0]), actions]
        return -np.mean(action_logprobs * rewards_to_go, axis=0)

    opt = Adam()

    shaped_loss = loss.shaped(np.zeros((1, ) + env.observation_space.shape),
                              np.array([0]), np.array([0]))

    @jit
    def sample_action(state, key, observation):
        loss_params = opt.get_parameters(state)
        logits = policy.apply_from({shaped_loss: loss_params}, observation)
        return sample_categorical(key, logits)

    rng_init, rng = random.split(PRNGKey(0))
    state = opt.init(shaped_loss.init_parameters(key=rng_init))
    returns, observations, actions, rewards_to_go = [], [], [], []

    for i in range(250):
        while len(observations) < batch_size:
            observation = env.reset()
            episode_done = False
            rewards = []

            while not episode_done:
                rng_step, rng = random.split(rng)
                action = sample_action(state, rng_step, observation)
                observations.append(observation)
                actions.append(action)

                observation, reward, episode_done, info = env.step(int(action))
                rewards.append(reward)

            returns.append(onp.sum(rewards))
            rewards_to_go += list(onp.flip(onp.cumsum(onp.flip(rewards))))

        print(f'Batch {i}, recent mean return: {onp.mean(returns[-100:]):.1f}')

        state = opt.update(loss.apply,
                           state,
                           np.array(observations[:batch_size]),
                           np.array(actions[:batch_size]),
                           np.array(rewards_to_go[:batch_size]),
                           jit=True)

        observations = observations[batch_size:]
        actions = actions[batch_size:]
        rewards_to_go = rewards_to_go[batch_size:]
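The policy-gradient loop above calls a sample_categorical helper that is defined outside the excerpt. A minimal sketch of one way to implement it (an assumption, not the source's definition), using the Gumbel-max trick so it works across jax versions:

import jax.numpy as jnp
from jax import random


def sample_categorical(key, logits):
    # Gumbel-max trick: adding Gumbel noise to the logits and taking the
    # argmax draws an index from the corresponding categorical distribution.
    return jnp.argmax(logits + random.gumbel(key, logits.shape), axis=-1)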
Example No. 17
def test_params_from_shared_submodules2():
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(PRNGKey(0), inputs)
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_params_equal(a_params.dense, params.sequential0.dense)
    assert_dense_params_equal(a_params.dense, params.sequential1.dense)
    # TODO parameters are duplicated, optimization with weight sharing is wrong:
    # TODO instead: assert 1 == len(params)
    out_, _ = net.apply(params, inputs)
    assert np.array_equal(out, out_)
Example No. 18
def wavenet(inputs):
    hidden = Conv1D(residual_channels, (initial_filter_width,))(inputs)
    out = np.zeros((hidden.shape[0], out_width, residual_channels), 'float32')
    for dilation in dilations:
        res = ResLayer(dilation_channels, residual_channels, filter_width,
                       dilation, out_width)(hidden)
        hidden, out_partial = res
        out += out_partial
    return Sequential(relu, Conv1D(skip_channels, (1,)), relu,
                      Conv1D(3 * nr_mix, (1,)))(out)
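wavenet iterates over a dilations sequence supplied by its enclosing scope. As a hedged illustration of the usual WaveNet convention (not taken from this source), the dilation factors are typically powers of two repeated over several blocks:

# Three stacks of dilations 1, 2, 4, ..., 512 (illustrative values only).
dilations = [2 ** i for i in range(10)] * 3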
Example No. 19
def test_ocr_rnn():
    length = 5
    carry_size = 3
    class_count = 4
    inputs = jnp.zeros((1, length, 4))

    def rnn():
        return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: jnp.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: jnp.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(inputs, key=PRNGKey(0))

    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.update_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.reset_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return jnp.mean(-jnp.sum(targets * jnp.log(prediction), (1, 2)))

    opt = optimizers.RmsProp(0.003)
    state = opt.init(cross_entropy.init_parameters(inputs, out,
                                                   key=PRNGKey(0)))
    state = opt.update(cross_entropy.apply, state, inputs, out)
    opt.update(cross_entropy.apply, state, inputs, out, jit=True)
Example No. 20
def test_L2Regularized_sequential():
    loss = Sequential(Dense(1, ones, ones), relu, Dense(1, ones, ones), sum)

    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.ones(1)
    params = reg_loss.init_parameters(PRNGKey(0), inputs)
    assert np.array_equal(np.ones((1, 1)), params.model.dense0.kernel)
    assert np.array_equal(np.ones((1, 1)), params.model.dense1.kernel)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 7 == reg_loss_out
Example No. 21
def test_parameters_from_shared_submodules():
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_parameters_equal(a_params.dense.kernel,
                            params.sequential0.dense.kernel)
    assert_parameters_equal((), params.sequential1)
    out = net.apply(params, inputs)

    out_ = net.apply_from({a: a_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a: a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a: a_params})
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a: a_params}, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params})
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params},
                                         jit=True)
    assert np.array_equal(out, out_)
Example No. 22
def ResNet50(num_classes):
    return Sequential(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), relu, MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]), IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]), IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]), IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]), IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]), IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]), AvgPool((7, 7)), flatten,
        Dense(num_classes), logsoftmax)
Example No. 23
def test_parameters_from_subsubmodule():
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, np.sum)
    inputs = np.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert out.shape == out_.shape
Example No. 24
def main():
    # TODO https://github.com/JuliusKunze/jaxnet/issues/4
    print("Sorry, this example does not work yet work with the new jax version.")
    return

    train, test = read_dataset()
    _, length, x_size = train.data.shape
    class_count = train.target.shape[2]
    carry_size = 200
    batch_size = 10

    def rnn():
        return Rnn(*GRUCell(carry_size=carry_size,
                            param_init=lambda rng, shape: random.normal(rng, shape) * 0.01))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(out_dim=class_count),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return np.mean(-np.sum(targets * np.log(prediction), (1, 2)))

    @parametrized
    def error(inputs, targets):
        prediction = net(inputs)
        return np.mean(np.not_equal(np.argmax(targets, 2), np.argmax(prediction, 2)))

    opt = optimizers.RmsProp(0.003)

    batch = train.sample(batch_size)
    params = cross_entropy.init_parameters(random.PRNGKey(0), batch.data, batch.target)
    state = opt.init(params)
    for epoch in range(10):
        params = get_params(state)
        e = error.apply_from({cross_entropy: params}, test.data, test.target, jit=True)
        print(f'Epoch {epoch} error {e * 100:.1f}')

        break  # TODO https://github.com/JuliusKunze/jaxnet/issues/2
        for _ in range(100):
            batch = train.sample(batch_size)
            state = opt.update(cross_entropy.apply, state, batch.data, batch.target, jit=True)
Example No. 25
def test_submodule_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(PRNGKey(0), inputs)
    net1_params = net1.init_parameters(PRNGKey(1), inputs, reuse={layer: layer_params})
    net2_params = net2.init_parameters(PRNGKey(2), inputs, reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_params_equal(layer_params, net1_params[0])
    assert_dense_params_equal(layer_params, net2_params[0])
Example No. 26
def test_parameters_from_sharing_between_multiple_parents():
    a = Dense(2)
    b = Sequential(a, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)
    out_, _ = net.apply(params, inputs)
    assert np.array_equal(out, out_)
Example No. 27
def test_mnist_vae():
    @parametrized
    def encode(input):
        input = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(input)
        variance = Sequential(Dense(10), softplus)(input)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), np.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
Example No. 28
def test_no_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    p1 = net1.init_parameters(inputs, key=PRNGKey(0))

    net2 = Sequential(layer, Dense(3))
    p2 = net2.init_parameters(inputs, key=PRNGKey(1))

    assert p1[0].kernel.shape == p2[0].kernel.shape
    assert p1[0].bias.shape == p2[0].bias.shape
    assert not np.array_equal(p1[0][0], p2[0][0])
    assert not np.array_equal(p1[0][1], p2[0][1])
Example No. 29
def main():
    train, test = read_dataset()
    _, length, x_size = train.data.shape
    class_count = train.target.shape[2]
    carry_size = 200
    batch_size = 10

    def rnn():
        return Rnn(*GRUCell(carry_size=carry_size,
                            param_init=lambda key, shape: random.normal(key, shape) * 0.01))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(out_dim=class_count),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return np.mean(-np.sum(targets * np.log(prediction), (1, 2)))

    @parametrized
    def error(inputs, targets):
        prediction = net(inputs)
        return np.mean(np.not_equal(np.argmax(targets, 2), np.argmax(prediction, 2)))

    opt = optimizers.RmsProp(0.003)

    batch = train.sample(batch_size)
    state = opt.init(cross_entropy.init_parameters(batch.data, batch.target, key=PRNGKey(0)))
    for epoch in range(10):
        params = opt.get_parameters(state)
        e = error.apply_from({cross_entropy: params}, test.data, test.target, jit=True)
        print(f'Epoch {epoch} error {e * 100:.1f}')

        for _ in range(100):
            batch = train.sample(batch_size)
            state = opt.update(cross_entropy.apply, state, batch.data, batch.target, jit=True)
Example No. 30
def test_readme():
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -np.mean(net(inputs) * targets)

    def next_batch(): return np.zeros((3, 784)), np.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365,  0.00493979]

    assert np.allclose([-0.01101029, -0.00749435, -0.00952365, 0.00493979],
                       params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape