Example #1
def main(batch_size=256, env_name="CartPole-v1"):
    env = gym.make(env_name)

    policy = Sequential(Dense(64), relu, Dense(env.action_space.n))

    @parametrized
    def loss(observations, actions, rewards_to_go):
        logprobs = log_softmax(policy(observations))
        action_logprobs = logprobs[np.arange(logprobs.shape[0]), actions]
        return -np.mean(action_logprobs * rewards_to_go, axis=0)

    opt = Adam()

    shaped_loss = loss.shaped(np.zeros((1, ) + env.observation_space.shape),
                              np.array([0]), np.array([0]))

    @jit
    def sample_action(state, key, observation):
        loss_params = opt.get_parameters(state)
        logits = policy.apply_from({shaped_loss: loss_params}, observation)
        return sample_categorical(key, logits)

    rng_init, rng = random.split(PRNGKey(0))
    state = opt.init(shaped_loss.init_parameters(key=rng_init))
    returns, observations, actions, rewards_to_go = [], [], [], []

    for i in range(250):
        while len(observations) < batch_size:
            observation = env.reset()
            episode_done = False
            rewards = []

            while not episode_done:
                rng_step, rng = random.split(rng)
                action = sample_action(state, rng_step, observation)
                observations.append(observation)
                actions.append(action)

                observation, reward, episode_done, info = env.step(int(action))
                rewards.append(reward)

            returns.append(onp.sum(rewards))
            rewards_to_go += list(onp.flip(onp.cumsum(onp.flip(rewards))))

        print(f'Batch {i}, recent mean return: {onp.mean(returns[-100:]):.1f}')

        state = opt.update(loss.apply,
                           state,
                           np.array(observations[:batch_size]),
                           np.array(actions[:batch_size]),
                           np.array(rewards_to_go[:batch_size]),
                           jit=True)

        observations = observations[batch_size:]
        actions = actions[batch_size:]
        rewards_to_go = rewards_to_go[batch_size:]
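
The snippet above omits its file-level imports. A hedged preamble that would plausibly make it runnable, assuming the jaxnet library and OpenAI gym (the module paths and the sample_categorical helper are assumptions inferred from the API names used, not taken from the source):

import gym
import numpy as onp                       # host-side NumPy for return statistics
from jax import numpy as np, random, jit
from jax.random import PRNGKey
from jaxnet import Sequential, Dense, relu, log_softmax, parametrized
from jaxnet.optimizers import Adam

def sample_categorical(key, logits):
    # Hypothetical helper (not shown above): draw one action from the
    # categorical distribution defined by unnormalized log-probabilities.
    return random.categorical(key, logits)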
Example #2
def test_regularized_submodule():
    net = Sequential(Conv(2, (1, 1)), relu, Conv(2, (1, 1)), relu, flatten,
                     L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), .1))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(input, key=PRNGKey(0))
    assert (2, 2) == params.regularized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert () == out.shape
Example #3
def test_reparametrized_submodule():
    net = Sequential(
        Conv(2, (3, 3)), relu, Conv(2, (3, 3)), relu, flatten,
        Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(PRNGKey(0), input)
    assert (2, 2) == params.reparametrized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert (1, 2) == out.shape
Example #4
def test_reuse_api():
    inputs = np.zeros((1, 2))
    net = Dense(5)
    net_params = net.init_parameters(inputs, key=PRNGKey(0))

    # train net params...

    transfer_net = Sequential(net, relu, Dense(2))
    transfer_net_params = transfer_net.init_parameters(inputs, key=PRNGKey(1),
                                                       reuse={net: net_params})

    assert net_params == transfer_net_params.dense0
Example #5
def test_L2Regularized_sequential():
    loss = Sequential(Dense(1, ones, ones), relu, Dense(1, ones, ones), sum)

    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.ones(1)
    params = reg_loss.init_parameters(PRNGKey(0), inputs)
    assert np.array_equal(np.ones((1, 1)), params.model.dense0.kernel)
    assert np.array_equal(np.ones((1, 1)), params.model.dense1.kernel)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 7 == reg_loss_out
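
The asserted value can be reproduced by hand. A worked breakdown, assuming the wrapper adds scale * sum(theta ** 2) / 2 on top of the wrapped loss (the penalty formula is inferred from the asserted value, not from documented API):

# base loss:  dense0 -> 1 * 1 + 1 = 2,  relu -> 2,  dense1 -> 2 * 1 + 1 = 3
# penalty:    four unit parameters -> 2 * (1 + 1 + 1 + 1) / 2 = 4
# total:      3 + 4 = 7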
Example #6
def test_nested_module_without_inputs():
    dense = Dense(2)
    inputs = np.zeros((1, 3))
    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.kernel.shape
    assert (2, ) == params.bias.shape
    assert str(dense).startswith('dense')

    out = dense.apply(params, inputs)
    assert (1, 2) == out.shape

    out_ = dense.apply(params, inputs, jit=True)
    assert np.allclose(out, out_)
Example #7
def test_Sequential_graceful_update_message():
    message = 'Call like Sequential(Dense(10), relu), without "[" and "]". ' \
              '(Or pass iterables with Sequential(*layers).)'
    try:
        Sequential([Dense(2), relu])
        assert False
    except ValueError as e:
        assert message == str(e)

    try:
        Sequential(Dense(2) for _ in range(3))
        assert False
    except ValueError as e:
        assert message == str(e)
Example #8
def test_no_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    p1 = net1.init_parameters(inputs, key=PRNGKey(0))

    net2 = Sequential(layer, Dense(3))
    p2 = net2.init_parameters(inputs, key=PRNGKey(1))

    assert p1[0].kernel.shape == p2[0].kernel.shape
    assert p1[0].bias.shape == p2[0].bias.shape
    assert not np.array_equal(p1[0][0], p2[0][0])
    assert not np.array_equal(p1[0][1], p2[0][1])
Example #9
def test_Dense_shape(Dense=Dense):
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    inputs = np.zeros((1, 3))

    params = net.init_parameters(PRNGKey(0), inputs)
    assert_parameters_equal((np.zeros((3, 2)), np.zeros(2)), params)

    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = jit(net.apply)(params, inputs)
    assert np.array_equal(out, out_)

    params_ = net.shaped(inputs).init_parameters(PRNGKey(0))
    assert_parameters_equal(params, params_)
Example #10
def test_ocr_rnn():
    length = 5
    carry_size = 3
    class_count = 4
    inputs = np.zeros((1, length, 4))

    def rnn():
        return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(PRNGKey(0), inputs)

    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    assert np.array_equal(np.zeros((7, 3)), cell.update_kernel)
    assert np.array_equal(np.zeros((7, 3)), cell.reset_kernel)
    assert np.array_equal(np.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)
    assert np.array_equal(.25 * np.ones((1, 5, 4)), out)
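
The asserted (7, 3) kernel shapes follow from the standard GRU layout, assuming the cell acts on the concatenation of input and carry:

x_size, carry_size = 4, 3
# rows = input size + carry size, columns = carry size (assumed GRU layout)
assert (x_size + carry_size, carry_size) == (7, 3)
# softmax over class_count = 4 all-zero logits is uniform, hence the 0.25
# values in the final assertion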
Example #11
def test_parameters_from():
    layer = Dense(2)
    net = Sequential(layer, relu)
    inputs = np.zeros((1, 3))
    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({layer: layer_params}, inputs)
    assert_parameters_equal((layer_params, ), params_)

    out = net.apply(params_, inputs)

    out_ = net.apply_from({layer: layer_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({layer: layer_params}, inputs, jit=True)
    assert np.array_equal(out, out_)
Example #12
def test_external_sequential_submodule():
    layer = Sequential(Conv(4, (2, 2)), flatten, relu, Dense(3), relu,
                       Dense(2), Sequential(Dense(2), relu))
    inputs = np.zeros((1, 5, 5, 2))

    params = layer.init_parameters(inputs, key=PRNGKey(0))
    assert (4, ) == params.conv.bias.shape
    assert (3, ) == params.dense0.bias.shape
    assert (3, 2) == params.dense1.kernel.shape
    assert (2, ) == params.dense1.bias.shape
    assert (2, ) == params.sequential.dense.bias.shape

    out = layer.apply(params, inputs)
    assert (1, 2) == out.shape

    out_ = layer.apply(params, inputs, jit=True)
    assert np.allclose(out, out_)
Example #13
def test_Parameter_dense():
    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return np.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = np.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert (3, 2) == params.parameter0.shape
    assert (2,) == params.parameter1.shape

    out = net.apply(params, inputs, jit=True)
    assert (1, 2) == out.shape
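
Since apply is a pure function of the parameter pytree, plain jax transforms should compose with it. A minimal follow-up sketch (grad-through-apply is assumed to behave as for any jax pytree; the test above does not assert it):

from jax import grad

def scalar_loss(params, inputs):
    return np.sum(net.apply(params, inputs))

grads = grad(scalar_loss)(params, inputs)  # same pytree structure as params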
Example #14
def test_save_and_load_params():
    params = Dense(2).init_parameters(np.zeros((1, 2)), key=PRNGKey(0))

    from pathlib import Path
    path = Path('/') / 'tmp' / 'net.params'
    save(params, path)
    params_ = load(path)

    assert_dense_parameters_equal(params, params_)
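
If the save/load helpers are unavailable, plain pickling of the parameter pytree is a plausible stand-in (only the names and call signatures come from the test above; the pickle-based bodies are an assumption):

import pickle

def save(parameters, path):
    with open(path, 'wb') as f:
        pickle.dump(parameters, f)

def load(path):
    with open(path, 'rb') as f:
        return pickle.load(f)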
Example #15
def test_mnist_vae():
    @parametrized
    def encode(input):
        input = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(input)
        variance = Sequential(Dense(10), softplus)(input)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), np.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
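
The sampling helper called inside elbo is not shown. A minimal sketch of the reparameterization trick it presumably implements (name and signature taken from the call site; the body is an assumption):

def gaussian_sample(key, mu, sigmasq):
    return mu + np.sqrt(sigmasq) * random.normal(key, mu.shape)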
Example #16
def test_parameters_from_sharing_between_multiple_parents():
    a = Dense(2)
    b = Sequential(a, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)
    out_, _ = net.apply(params, inputs)
    assert np.array_equal(out, out_)
Example #17
def test_submodule_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(PRNGKey(0), inputs)
    net1_params = net1.init_parameters(PRNGKey(1), inputs, reuse={layer: layer_params})
    net2_params = net2.init_parameters(PRNGKey(2), inputs, reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_params_equal(layer_params, net1_params[0])
    assert_dense_params_equal(layer_params, net2_params[0])
Example #18
def test_parameters_from_subsubmodule():
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, np.sum)
    inputs = np.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert out.shape == out_.shape
Example #19
def test_external_submodule_partial_jit():
    layer = Dense(3)

    @parametrized
    def net_fun(inputs):
        return jit(lambda x: 2 * x)(layer(inputs))

    inputs = random_inputs((2,))
    params = net_fun.init_parameters(PRNGKey(0), inputs)
    out = net_fun.apply(params, inputs)
    assert out.shape == (3,)
Example #20
def test_readme():
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -np.mean(net(inputs) * targets)

    def next_batch(): return np.zeros((3, 784)), np.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365,  0.00493979]

    assert np.allclose([-0.01101029, -0.00749435, -0.00952365, 0.00493979],
                       params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape
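
A hedged sketch of how training could continue from these parameters, mirroring the optimizer usage in Example #1 (the Adam signatures are taken from there, not from this test):

opt = Adam()
state = opt.init(params)
for _ in range(10):
    state = opt.update(loss.apply, state, *next_batch(), jit=True)
trained_params = opt.get_parameters(state)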
Example #21
def test_input_dependent_nested_modules():
    @parametrized
    def layer(inputs):
        return Dense(inputs.shape[0])(inputs)

    net = Sequential(Dense(3), layer)

    inputs = np.zeros((5, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    out = net.apply(params, inputs)
    assert (5, 5) == out.shape
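
Because the inner Dense's width is read from inputs.shape[0], re-initializing on a different batch size should change both the parameter and output shapes. A sketch of the implied behavior (not asserted by the test above):

inputs_ = np.zeros((7, 3))
params_ = net.init_parameters(inputs_, key=PRNGKey(0))
assert (7, 7) == net.apply(params_, inputs_).shape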
Example #22
def test_submodule_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs,
                                       key=PRNGKey(1),
                                       reuse={layer: layer_params})
    net2_params = net2.init_parameters(inputs,
                                       key=PRNGKey(2),
                                       reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_parameters_equal(layer_params, net1_params[0])
    assert_dense_parameters_equal(layer_params, net2_params[0])

    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from(
        {
            net1: net1_params,
            layer: new_layer_params
        }, inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1)
Example #23
def test_external_param_sharing():
    layer = Dense(2, zeros, zeros)
    shared_net = Sequential(layer, layer)

    inputs = np.zeros((1, 2))
    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2)), ), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(np.zeros((1, 2)), out)
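
Sharing here hinges on layer being the same Python object in both positions, which is why params contains a single (kernel, bias) pair. For contrast, a sketch with two distinct Dense objects, which should yield two parameter entries (expected behavior inferred from Example #8's no-reuse test):

unshared = Sequential(Dense(2, zeros, zeros), Dense(2, zeros, zeros))
unshared_params = unshared.init_parameters(inputs, key=PRNGKey(0))
assert 2 == len(unshared_params)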
Example #24
def ResNet50(num_classes):
    return Sequential(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), relu, MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]), IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]), IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]), IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]), IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]), IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]), AvgPool((7, 7)), flatten,
        Dense(num_classes), logsoftmax)
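
Hypothetical usage of the constructor above; the first GeneralConv declares an 'HWCN' input spec, so the batch axis comes last (the 224x224 ImageNet-style shape and the keyword-style init call are assumptions):

net = ResNet50(num_classes=1000)
images = np.zeros((224, 224, 3, 1))
params = net.init_parameters(images, key=PRNGKey(0))
logits = net.apply(params, images)  # expected shape: (1, 1000)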
Example #25
def test_parameters_from_top_level():
    net = Dense(2)
    inputs = np.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    params_ = net.parameters_from({net: params}, inputs)
    assert_dense_parameters_equal(params, params_)
    out_ = net.apply(params_, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({net: params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({net: params}, inputs, jit=True)
    assert np.array_equal(out, out_)
Example #26
def test_readme():
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4),
                     logsoftmax)

    @parametrized
    def loss(inputs, targets):
        return -np.mean(net(inputs) * targets)

    def next_batch():
        return np.zeros((3, 784)), np.zeros((3, 4))

    params = loss.init_parameters(PRNGKey(0), *next_batch())

    print(params.sequential.dense2.bias)  # [0.00376661 0.01038619 0.00920947 0.00792002]

    assert np.allclose([0.00376661, 0.01038619, 0.00920947, 0.00792002],
                       params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape
Example #27
def main():
    # TODO https://github.com/JuliusKunze/jaxnet/issues/4
    print("Sorry, this example does not work yet work with the new jax version.")
    return

    train, test = read_dataset()
    _, length, x_size = train.data.shape
    class_count = train.target.shape[2]
    carry_size = 200
    batch_size = 10

    def rnn():
        return Rnn(*GRUCell(carry_size=carry_size,
                            param_init=lambda rng, shape: random.normal(rng, shape) * 0.01))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(out_dim=class_count),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return np.mean(-np.sum(targets * np.log(prediction), (1, 2)))

    @parametrized
    def error(inputs, targets):
        prediction = net(inputs)
        return np.mean(np.not_equal(np.argmax(targets, 2), np.argmax(prediction, 2)))

    opt = optimizers.RmsProp(0.003)

    batch = train.sample(batch_size)
    params = cross_entropy.init_parameters(random.PRNGKey(0), batch.data, batch.target)
    state = opt.init(params)
    for epoch in range(10):
        params = opt.get_parameters(state)
        e = error.apply_from({cross_entropy: params}, test.data, test.target, jit=True)
        print(f'Epoch {epoch} error {e * 100:.1f}')

        break  # TODO https://github.com/JuliusKunze/jaxnet/issues/2
        for _ in range(100):
            batch = train.sample(batch_size)
            state = opt.update(cross_entropy.apply, state, batch.data, batch.target, jit=True)
Example #28
def test_submodule_reuse_top_level():
    net = Dense(2)
    inputs = np.zeros((1, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    params_ = net.init_parameters(inputs, key=PRNGKey(1), reuse={net: params})
    assert_dense_parameters_equal(params, params_)

    out_ = net.apply(params_, inputs)
    assert np.array_equal(out, out_)
Example #29
def test_external_submodule2():
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = np.zeros((1, 2))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2)), ), params)

    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out_ = net.apply(params, inputs, jit=True)
    assert np.array_equal(out, out_)
Example #30
def test_external_submodule():
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2, ))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert out.shape == (3, )

    out_ = net.apply(params, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply(params, inputs, jit=True)
    assert np.allclose(out, out_)