def test_submodule_reuse():
    """Reuse one ``Dense`` layer's parameters in two ``Sequential`` networks.

    The shared layer is initialized on its own and handed to both networks
    through ``reuse``; each network's first parameter entry must then equal
    the shared parameters. Finally ``parameters_from`` is given both the
    network's parameters and a newer set for the shared layer, and the result
    is expected to carry the newer values in ``dense0`` and the network's own
    values in ``dense1``.
    """
    inputs = np.zeros((1, 2))
    shared = Dense(5)
    narrow_net = Sequential(shared, Dense(2))
    wide_net = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(inputs, key=PRNGKey(0))
    narrow_params = narrow_net.init_parameters(
        inputs, key=PRNGKey(1), reuse={shared: shared_params})
    wide_params = wide_net.init_parameters(
        inputs, key=PRNGKey(2), reuse={shared: shared_params})

    # Both networks must be usable with the parameters they were given.
    narrow_out = narrow_net.apply(narrow_params, inputs)
    assert narrow_out.shape == (1, 2)
    wide_out = wide_net.apply(wide_params, inputs)
    assert wide_out.shape == (1, 3)

    # The first entry of each parameter set is the shared layer.
    assert_dense_parameters_equal(shared_params, narrow_params[0])
    assert_dense_parameters_equal(shared_params, wide_params[0])

    fresh_shared_params = shared.init_parameters(inputs, key=PRNGKey(3))
    sources = {
        narrow_net: narrow_params,
        shared: fresh_shared_params,
    }
    merged = narrow_net.parameters_from(sources, inputs)

    assert_dense_parameters_equal(fresh_shared_params, merged.dense0)
    assert_dense_parameters_equal(narrow_params.dense1, merged.dense1)
def main(batch_size=256, env_name="CartPole-v1"):
    """Train a small policy network on a gym environment.

    The loss is the negative mean of the chosen actions' log-probabilities
    weighted by the rewards-to-go. Whole episodes are rolled out until at
    least ``batch_size`` steps are buffered; the first ``batch_size`` steps
    feed one optimizer update and the remainder is kept for the next batch.

    Args:
        batch_size: Number of environment steps used per optimizer update.
        env_name: Identifier passed to ``gym.make``.
    """
    env = gym.make(env_name)
    policy = Sequential(Dense(64), relu, Dense(env.action_space.n))

    @parametrized
    def loss(observations, actions, rewards_to_go):
        logprobs = log_softmax(policy(observations))
        # Pick, for every row, the log-probability of the action taken.
        rows = np.arange(logprobs.shape[0])
        action_logprobs = logprobs[rows, actions]
        return -np.mean(action_logprobs * rewards_to_go, axis=0)

    opt = Adam()

    # Dummy single-sample inputs fix the shapes the loss is initialized with.
    dummy_observations = np.zeros((1, ) + env.observation_space.shape)
    shaped_loss = loss.shaped(dummy_observations, np.array([0]),
                              np.array([0]))

    @jit
    def sample_action(state, key, observation):
        loss_params = opt.get_parameters(state)
        logits = policy.apply_from({shaped_loss: loss_params}, observation)
        return sample_categorical(key, logits)

    rng_init, rng = random.split(PRNGKey(0))
    state = opt.init(shaped_loss.init_parameters(key=rng_init))

    returns = []
    observations = []
    actions = []
    rewards_to_go = []

    for batch_index in range(250):
        # Collect complete episodes until a full batch is available.
        while len(observations) < batch_size:
            observation = env.reset()
            episode_rewards = []
            episode_done = False

            while not episode_done:
                rng_step, rng = random.split(rng)
                action = sample_action(state, rng_step, observation)
                observations.append(observation)
                actions.append(action)

                observation, reward, episode_done, _ = env.step(int(action))
                episode_rewards.append(reward)

            returns.append(onp.sum(episode_rewards))
            # Reverse, accumulate, reverse again: entry t becomes the sum of
            # the rewards from step t to the end of the episode.
            reversed_totals = onp.cumsum(onp.flip(episode_rewards))
            rewards_to_go.extend(onp.flip(reversed_totals))

        recent_mean = onp.mean(returns[-100:])
        print(f'Batch {batch_index}, recent mean return: {recent_mean:.1f}')

        batch = (np.array(observations[:batch_size]),
                 np.array(actions[:batch_size]),
                 np.array(rewards_to_go[:batch_size]))
        state = opt.update(loss.apply, state, *batch, jit=True)

        # Steps beyond the batch carry over into the next iteration.
        observations = observations[batch_size:]
        actions = actions[batch_size:]
        rewards_to_go = rewards_to_go[batch_size:]
def test_regularized_submodule():
    """Use an ``L2Regularized`` model as the last stage of a ``Sequential``.

    Checks that the wrapped model's parameters are reachable under
    ``params.regularized.model`` and that applying the network yields a
    scalar-shaped output.
    """
    net = Sequential(
        Conv(2, (1, 1)),
        relu,
        Conv(2, (1, 1)),
        relu,
        flatten,
        L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), 0.1),
    )

    inputs = np.ones((1, 3, 3, 1))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert params.regularized.model.dense1.kernel.shape == (2, 2)

    result = net.apply(params, inputs)
    assert result.shape == ()
def test_reparametrized_submodule():
    """Use a ``Reparametrized`` model as the last stage of a ``Sequential``.

    Checks that the wrapped model's parameters are reachable under
    ``params.reparametrized.model`` and that the network output has shape
    ``(1, 2)``.
    """
    net = Sequential(
        Conv(2, (3, 3)),
        relu,
        Conv(2, (3, 3)),
        relu,
        flatten,
        Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled),
    )

    inputs = np.ones((1, 3, 3, 1))
    params = net.init_parameters(PRNGKey(0), inputs)
    assert params.reparametrized.model.dense1.kernel.shape == (2, 2)

    result = net.apply(params, inputs)
    assert result.shape == (1, 2)
def test_reuse_api():
    """Carry a layer's parameters over into a larger network via ``reuse``.

    A ``Dense`` layer is initialized alone, then embedded as the first stage
    of a ``Sequential``; the larger network's ``dense0`` parameters must equal
    the ones passed in.
    """
    inputs = np.zeros((1, 2))

    net = Dense(5)
    net_params = net.init_parameters(inputs, key=PRNGKey(0))

    # train net params...

    transfer_net = Sequential(net, relu, Dense(2))
    transfer_net_params = transfer_net.init_parameters(
        inputs,
        key=PRNGKey(1),
        reuse={net: net_params},
    )

    assert net_params == transfer_net_params.dense0
def test_L2Regularized_sequential():
    """Wrap a ``Sequential`` loss in ``L2Regularized`` and check its value.

    Both dense kernels are initialized with ``ones``; the test asserts that
    they come out as ``(1, 1)`` arrays of ones under ``params.model`` and that
    the regularized loss for an input of ones equals 7.
    """
    loss = Sequential(Dense(1, ones, ones), relu, Dense(1, ones, ones), sum)
    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.ones(1)
    params = reg_loss.init_parameters(PRNGKey(0), inputs)

    expected_kernel = np.ones((1, 1))
    assert np.array_equal(expected_kernel, params.model.dense0.kernel)
    assert np.array_equal(expected_kernel, params.model.dense1.kernel)

    reg_loss_out = reg_loss.apply(params, inputs)
    assert reg_loss_out == 7
def test_no_reuse():
    """Without ``reuse``, a layer shared by two networks is initialized twice.

    The same ``Dense`` instance is the first stage of two networks that are
    initialized with different keys. The first-layer parameters must agree in
    shape but differ in value.
    """
    inputs = np.zeros((1, 2))
    shared = Dense(5)

    first_net = Sequential(shared, Dense(2))
    first_params = first_net.init_parameters(inputs, key=PRNGKey(0))

    second_net = Sequential(shared, Dense(3))
    second_params = second_net.init_parameters(inputs, key=PRNGKey(1))

    first_layer, second_layer = first_params[0], second_params[0]

    assert first_layer.kernel.shape == second_layer.kernel.shape
    assert first_layer.bias.shape == second_layer.bias.shape

    # Positional access: entries 0 and 1 of the layer's parameter tuple.
    assert not np.array_equal(first_layer[0], second_layer[0])
    assert not np.array_equal(first_layer[1], second_layer[1])
def test_Sequential_graceful_update_message():
    """``Sequential`` rejects a single iterable argument with a helpful error.

    Passing a list or a generator instead of separate layers must raise a
    ``ValueError`` whose text tells the caller how to fix the call.
    """
    message = ('Call like Sequential(Dense(10), relu), without "[" and "]". '
               '(Or pass iterables with Sequential(*layers).)')

    # A list of layers instead of positional layers.
    try:
        Sequential([Dense(2), relu])
    except ValueError as error:
        assert message == str(error)
    else:
        assert False

    # A generator of layers instead of positional layers.
    try:
        Sequential(Dense(2) for _ in range(3))
    except ValueError as error:
        assert message == str(error)
    else:
        assert False
def test_ocr_rnn():
    """Stack three GRU-based ``Rnn`` layers in front of a per-step classifier.

    All initializers are ``zeros``. The test checks the number of parameter
    groups, the shapes of the first cell's three kernels, and that the softmax
    output is 0.25 everywhere for the four classes.
    """
    length = 5
    carry_size = 3
    class_count = 4
    inputs = np.zeros((1, length, 4))

    def rnn():
        return Rnn(*GRUCell(carry_size, zeros))

    def merge_time_into_batch(x):
        # -> same weights for all time steps
        return np.reshape(x, (-1, carry_size))

    def split_time_from_batch(x):
        return np.reshape(x, (-1, length, class_count))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        merge_time_into_batch,
        Dense(class_count, zeros, zeros),
        softmax,
        split_time_from_batch,
    )

    params = net.init_parameters(PRNGKey(0), inputs)
    assert len(params) == 4

    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    zero_kernel = np.zeros((7, 3))
    assert np.array_equal(zero_kernel, cell.update_kernel)
    assert np.array_equal(zero_kernel, cell.reset_kernel)
    assert np.array_equal(zero_kernel, cell.compute_kernel)

    out = net.apply(params, inputs)
    assert np.array_equal(.25 * np.ones((1, 5, 4)), out)
def test_external_sequential_submodule():
    """Nest a ``Sequential`` inside another ``Sequential``.

    Checks the parameter shapes of each stage, including the nested network
    under ``params.sequential``, and that jitted and non-jitted application
    agree.
    """
    layer = Sequential(
        Conv(4, (2, 2)),
        flatten,
        relu,
        Dense(3),
        relu,
        Dense(2),
        Sequential(Dense(2), relu),
    )
    inputs = np.zeros((1, 5, 5, 2))

    params = layer.init_parameters(inputs, key=PRNGKey(0))
    assert params.conv.bias.shape == (4, )
    assert params.dense0.bias.shape == (3, )
    assert params.dense1.kernel.shape == (3, 2)
    assert params.dense1.bias.shape == (2, )
    assert params.sequential.dense.bias.shape == (2, )

    eager_out = layer.apply(params, inputs)
    assert eager_out.shape == (1, 2)

    jitted_out = layer.apply(params, inputs, jit=True)
    assert np.allclose(eager_out, jitted_out)
def test_save_and_load_params():
    """Round-trip ``Dense`` parameters through ``save`` and ``load``.

    The file is written inside a fresh temporary directory rather than a
    fixed ``/tmp`` path, so the test also works on platforms without ``/tmp``,
    cannot collide with concurrent runs, and leaves nothing behind.
    """
    import tempfile
    from pathlib import Path

    params = Dense(2).init_parameters(np.zeros((1, 2)), key=PRNGKey(0))

    with tempfile.TemporaryDirectory() as directory:
        path = Path(directory) / 'net.params'
        save(params, path)
        params_ = load(path)

        # Compare while the file still exists, in case ``load`` reads lazily.
        assert_dense_parameters_equal(params, params_)
def test_mnist_vae():
    """Initialize a small VAE-style objective built from nested modules.

    ``encode`` returns a mean and a variance computed from a shared hidden
    representation; ``elbo`` combines the Bernoulli log-density of the decoded
    sample with a Gaussian KL term. Only parameter initialization is
    exercised: the kernel shape of the variance branch's dense layer is
    checked.
    """

    @parametrized
    def encode(images):
        hidden = Sequential(Dense(5), relu, Dense(5), relu)(images)
        mean = Dense(10)(hidden)
        variance = Sequential(Dense(10), softplus)(hidden)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        sample = gaussian_sample(key, mu_z, sigmasq_z)
        logits_x = decode(sample)
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    # The first PRNGKey is ``elbo``'s own ``key`` input; the ``key=`` keyword
    # is passed to ``init_parameters`` separately.
    example_images = np.zeros((32, 5 * 5))
    params = elbo.init_parameters(PRNGKey(0), example_images, key=PRNGKey(0))

    assert params.encode.sequential1.dense.kernel.shape == (5, 10)
def test_submodule_reuse():
    """Reuse one ``Dense`` layer's parameters in two ``Sequential`` networks.

    The shared layer is initialized on its own and handed to both networks
    through ``reuse``. Both networks must produce outputs of the expected
    shape, and the first parameter entry of each must equal the shared
    parameters.
    """
    inputs = np.zeros((1, 2))
    shared = Dense(5)
    narrow_net = Sequential(shared, Dense(2))
    wide_net = Sequential(shared, Dense(3))

    shared_params = shared.init_parameters(PRNGKey(0), inputs)
    narrow_params = narrow_net.init_parameters(
        PRNGKey(1), inputs, reuse={shared: shared_params})
    wide_params = wide_net.init_parameters(
        PRNGKey(2), inputs, reuse={shared: shared_params})

    narrow_out = narrow_net.apply(narrow_params, inputs)
    assert narrow_out.shape == (1, 2)
    wide_out = wide_net.apply(wide_params, inputs)
    assert wide_out.shape == (1, 3)

    # The first entry of each parameter set is the shared layer.
    assert_dense_params_equal(shared_params, narrow_params[0])
    assert_dense_params_equal(shared_params, wide_params[0])
def test_submodule_reuse_top_level():
    """``reuse`` also applies when the reused module is the network itself.

    Re-initializing with a different key but with the network's own earlier
    parameters in ``reuse`` must give back equal parameters and the same
    output.
    """
    net = Dense(2)
    inputs = np.zeros((1, 3))

    original_params = net.init_parameters(inputs, key=PRNGKey(0))
    original_out = net.apply(original_params, inputs)

    reused_params = net.init_parameters(
        inputs, key=PRNGKey(1), reuse={net: original_params})
    assert_dense_parameters_equal(original_params, reused_params)

    reused_out = net.apply(reused_params, inputs)
    assert np.array_equal(original_out, reused_out)
def test_external_submodule_partial_jit():
    """Apply a jitted function to the output of an externally defined layer.

    Only the doubling lambda is wrapped in ``jit``; the layer call itself
    happens outside of it. The output must have shape ``(3,)``.
    """
    layer = Dense(3)

    @parametrized
    def net_fun(inputs):
        double = jit(lambda x: 2 * x)
        return double(layer(inputs))

    inputs = random_inputs((2,))
    params = net_fun.init_parameters(PRNGKey(0), inputs)

    result = net_fun.apply(params, inputs)
    assert result.shape == (3,)
def test_readme():
    """Run the README-style example: a classifier wrapped in a loss.

    Checks one bias vector against fixed reference values for
    ``PRNGKey(0)``, that the loss is scalar-shaped, and that jitted
    application returns the same shape.
    """
    net = Sequential(
        Dense(1024), relu,
        Dense(1024), relu,
        Dense(4), log_softmax,
    )

    @parametrized
    def loss(inputs, targets):
        return -np.mean(net(inputs) * targets)

    def next_batch():
        images = np.zeros((3, 784))
        labels = np.zeros((3, 4))
        return images, labels

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    expected_bias = [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    assert np.allclose(expected_bias, params.sequential.dense2.bias)

    eager_out = loss.apply(params, *next_batch())
    assert eager_out.shape == ()

    jitted_out = loss.apply(params, *next_batch(), jit=True)
    assert eager_out.shape == jitted_out.shape
def test_input_dependent_nested_modules():
    """Create a nested module whose size depends on the input shape.

    The inner ``layer`` builds a ``Dense`` whose output dimension is the
    leading dimension of its input; for a ``(5, 3)`` input the network output
    must therefore be ``(5, 5)``.
    """

    @parametrized
    def layer(inputs):
        width = inputs.shape[0]
        return Dense(width)(inputs)

    net = Sequential(Dense(3), layer)
    inputs = np.zeros((5, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    result = net.apply(params, inputs)

    assert result.shape == (5, 5)
def test_external_param_sharing():
    """Use the same layer instance twice inside one ``Sequential``.

    The resulting parameters must contain a single (kernel, bias) entry, and
    with zero initializers the output is all zeros both with and without
    ``jit``.
    """
    layer = Dense(2, zeros, zeros)
    shared_net = Sequential(layer, layer)
    inputs = np.zeros((1, 2))

    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    expected_params = ((np.zeros((2, 2)), np.zeros(2)), )
    assert_parameters_equal(expected_params, params)

    eager_out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), eager_out)

    jitted_out = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(np.zeros((1, 2)), jitted_out)
def test_nested_module_without_inputs():
    """Initialize and apply a stand-alone ``Dense`` layer.

    Checks kernel and bias shapes, that the layer's string form starts with
    ``'dense'``, and that jitted and non-jitted outputs agree.
    """
    dense = Dense(2)
    inputs = np.zeros((1, 3))

    params = dense.init_parameters(inputs, key=PRNGKey(0))
    assert params.kernel.shape == (3, 2)
    assert params.bias.shape == (2, )

    assert str(dense).startswith('dense')

    eager_out = dense.apply(params, inputs)
    assert eager_out.shape == (1, 2)

    jitted_out = dense.apply(params, inputs, jit=True)
    assert np.allclose(eager_out, jitted_out)
def ResNet50(num_classes):
    """Build the ResNet50 network as a ``Sequential``.

    The network is a convolutional stem, four stages that each start with a
    ``ConvBlock`` and continue with 2, 3, 5 and 2 ``IdentityBlock`` layers
    respectively, and a pooling / dense / log-softmax head.

    Args:
        num_classes: Output dimension of the final ``Dense`` layer.

    Returns:
        The assembled ``Sequential`` model.
    """
    layers = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    # Stage 1 is the only one whose ConvBlock gets explicit strides.
    layers.append(ConvBlock(3, [64, 64, 256], strides=(1, 1)))
    layers.extend(IdentityBlock(3, [64, 64]) for _ in range(2))

    layers.append(ConvBlock(3, [128, 128, 512]))
    layers.extend(IdentityBlock(3, [128, 128]) for _ in range(3))

    layers.append(ConvBlock(3, [256, 256, 1024]))
    layers.extend(IdentityBlock(3, [256, 256]) for _ in range(5))

    layers.append(ConvBlock(3, [512, 512, 2048]))
    layers.extend(IdentityBlock(3, [512, 512]) for _ in range(2))

    layers += [AvgPool((7, 7)), flatten, Dense(num_classes), logsoftmax]

    return Sequential(*layers)
def test_readme():
    """Run the README-style example: a classifier wrapped in a loss.

    Checks one bias vector against fixed reference values for
    ``PRNGKey(0)``, that the loss is scalar-shaped, and that jitted
    application returns the same shape.
    """
    net = Sequential(
        Dense(1024), relu,
        Dense(1024), relu,
        Dense(4), logsoftmax,
    )

    @parametrized
    def loss(inputs, targets):
        return -np.mean(net(inputs) * targets)

    def next_batch():
        images = np.zeros((3, 784))
        labels = np.zeros((3, 4))
        return images, labels

    params = loss.init_parameters(PRNGKey(0), *next_batch())

    print(params.sequential.dense2.bias)  # [0.00376661 0.01038619 0.00920947 0.00792002]
    expected_bias = [0.00376661, 0.01038619, 0.00920947, 0.00792002]
    assert np.allclose(expected_bias, params.sequential.dense2.bias)

    eager_out = loss.apply(params, *next_batch())
    assert eager_out.shape == ()

    jitted_out = loss.apply(params, *next_batch(), jit=True)
    assert eager_out.shape == jitted_out.shape
def test_Dense_shape(Dense=Dense):
    """Check parameter and output shapes of a zero-initialized dense layer.

    Also verifies that wrapping ``apply`` in ``jit`` gives the same output and
    that initializing through ``shaped(inputs)`` gives the same parameters.

    Args:
        Dense: Layer factory under test; defaults to the module-level
            ``Dense`` so other implementations can be passed in.
    """
    net = Dense(2, kernel_init=zeros, bias_init=zeros)
    inputs = np.zeros((1, 3))

    params = net.init_parameters(PRNGKey(0), inputs)
    expected_params = (np.zeros((3, 2)), np.zeros(2))
    assert_parameters_equal(expected_params, params)

    eager_out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), eager_out)

    jitted_apply = jit(net.apply)
    jitted_out = jitted_apply(params, inputs)
    assert np.array_equal(eager_out, jitted_out)

    shaped_params = net.shaped(inputs).init_parameters(PRNGKey(0))
    assert_parameters_equal(params, shaped_params)
def main():
    """Entry point of the OCR RNN example (currently disabled).

    The function prints a notice and returns immediately; see the linked
    issue. Everything after the early ``return`` is kept for when the example
    is re-enabled: it builds a three-layer GRU network with a per-time-step
    softmax classifier and sets up an RmsProp training loop on the dataset
    returned by ``read_dataset``.
    """
    # TODO https://github.com/JuliusKunze/jaxnet/issues/4
    print("Sorry, this example does not work yet with the new jax version.")
    return

    # NOTE: intentionally unreachable until the issue above is resolved.
    train, test = read_dataset()
    _, length, x_size = train.data.shape
    class_count = train.target.shape[2]
    carry_size = 200
    batch_size = 10

    def rnn():
        return Rnn(*GRUCell(carry_size=carry_size,
                            param_init=lambda rng, shape: random.normal(rng, shape) * 0.01))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(out_dim=class_count),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return np.mean(-np.sum(targets * np.log(prediction), (1, 2)))

    @parametrized
    def error(inputs, targets):
        prediction = net(inputs)
        return np.mean(np.not_equal(np.argmax(targets, 2), np.argmax(prediction, 2)))

    opt = optimizers.RmsProp(0.003)

    batch = train.sample(batch_size)
    params = cross_entropy.init_parameters(random.PRNGKey(0), batch.data, batch.target)
    state = opt.init(params)
    for epoch in range(10):
        params = get_params(state)
        # Evaluate ``error`` with the parameters trained for ``cross_entropy``.
        e = error.apply_from({cross_entropy: params}, test.data, test.target, jit=True)
        print(f'Epoch {epoch} error {e * 100:.1f}')

        # The ``break`` skips training after the first evaluation.
        break  # TODO https://github.com/JuliusKunze/jaxnet/issues/2
        for _ in range(100):
            batch = train.sample(batch_size)
            state = opt.update(cross_entropy.apply, state, batch.data, batch.target, jit=True)
def test_parameters_from_top_level():
    """``parameters_from`` / ``apply_from`` accept the network itself as key.

    Parameters recovered from ``{net: params}`` must equal the originals, and
    every application route (plain, ``apply_from``, ``apply_from`` with
    ``jit``) must give the same output.
    """
    net = Dense(2)
    inputs = np.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    reference_out = net.apply(params, inputs)

    recovered_params = net.parameters_from({net: params}, inputs)
    assert_dense_parameters_equal(params, recovered_params)

    recovered_out = net.apply(recovered_params, inputs)
    assert np.array_equal(reference_out, recovered_out)

    from_out = net.apply_from({net: params}, inputs)
    assert np.array_equal(reference_out, from_out)

    jitted_from_out = net.apply_from({net: params}, inputs, jit=True)
    assert np.array_equal(reference_out, jitted_from_out)
def test_parameters_from():
    """Build a network's parameters from the parameters of its sublayer.

    ``parameters_from`` must wrap the layer's parameters in a one-element
    tuple, and ``apply_from`` (with and without ``jit``) must match applying
    the network with those parameters.
    """
    layer = Dense(2)
    net = Sequential(layer, relu)
    inputs = np.zeros((1, 3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))

    net_params = net.parameters_from({layer: layer_params}, inputs)
    assert_parameters_equal((layer_params, ), net_params)

    reference_out = net.apply(net_params, inputs)

    from_out = net.apply_from({layer: layer_params}, inputs)
    assert np.array_equal(reference_out, from_out)

    jitted_from_out = net.apply_from({layer: layer_params}, inputs, jit=True)
    assert np.array_equal(reference_out, jitted_from_out)
def test_external_submodule():
    """Call a layer defined outside of a ``parametrized`` function.

    The output must have shape ``(3,)``, be identical across repeated
    applications, and be close to the jitted result.
    """
    layer = Dense(3)

    @parametrized
    def net(inputs):
        return 2 * layer(inputs)

    inputs = random_inputs((2, ))
    params = net.init_parameters(inputs, key=PRNGKey(0))

    first_out = net.apply(params, inputs)
    assert first_out.shape == (3, )

    # Applying twice with the same parameters must give the same result.
    second_out = net.apply(params, inputs)
    assert np.array_equal(first_out, second_out)

    jitted_out = net.apply(params, inputs, jit=True)
    assert np.allclose(first_out, jitted_out)
def test_external_submodule2():
    """Wrap a zero-initialized external layer in a ``parametrized`` function.

    The wrapper's parameters must be a one-element tuple holding the layer's
    zero kernel and bias, and the output must be all zeros with and without
    ``jit``.
    """
    layer = Dense(2, zeros, zeros)

    @parametrized
    def net(inputs):
        return layer(inputs)

    inputs = np.zeros((1, 2))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    expected_params = ((np.zeros((2, 2)), np.zeros(2)), )
    assert_parameters_equal(expected_params, params)

    eager_out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), eager_out)

    jitted_out = net.apply(params, inputs, jit=True)
    assert np.array_equal(eager_out, jitted_out)
def test_parameters_from_sharing_between_multiple_parents():
    """Share one layer between the top level and a nested ``Sequential``.

    Layer ``a`` is called directly by ``net`` and is also the first stage of
    ``b``. Parameters built from ``a``'s parameters must hold them under
    ``dense``, leave ``sequential`` empty, and have exactly two entries.
    """
    a = Dense(2)
    b = Sequential(a, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    reference_out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    assert_parameters_equal((), params.sequential)
    assert len(params) == 2

    # Only the first output (the direct call to ``a``) is compared.
    direct_out, _ = net.apply(params, inputs)
    assert np.array_equal(reference_out, direct_out)
def test_Parameter_dense():
    """Define a dense layer from scratch with the ``parameter`` primitive.

    The locally defined ``Dense`` declares a kernel and a bias through
    ``parameter``; they must show up as ``parameter0`` and ``parameter1`` with
    the expected shapes, and the jitted output must have shape ``(1, 2)``.
    """

    def Dense(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            in_dim = inputs.shape[-1]
            kernel = parameter((in_dim, out_dim), kernel_init)
            bias = parameter((out_dim,), bias_init)
            return np.dot(inputs, kernel) + bias

        return dense

    net = Dense(2)
    inputs = np.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    assert params.parameter0.shape == (3, 2)
    assert params.parameter1.shape == (2,)

    result = net.apply(params, inputs, jit=True)
    assert result.shape == (1, 2)
def test_parameters_from_shared_submodules():
    """Build parameters when two sub-networks share the same sublayer.

    ``a`` and ``b`` both start with ``sublayer``. Parameters derived from
    ``a``'s parameters must place the kernel under ``sequential0.dense`` and
    leave ``sequential1`` empty. Every combination of plain / shaped network,
    plain / shaped key and with / without ``jit`` must then reproduce the
    output of a direct ``apply``.
    """
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_parameters_equal(a_params.dense.kernel,
                            params.sequential0.dense.kernel)
    assert_parameters_equal((), params.sequential1)

    # Reference output that all ``apply_from`` variants are compared against.
    out = net.apply(params, inputs)

    # Plain network, plain key.
    candidate = net.apply_from({a: a_params}, inputs)
    assert np.array_equal(out, candidate)
    candidate = net.apply_from({a: a_params}, inputs, jit=True)
    assert np.array_equal(out, candidate)

    # Plain network, shaped key.
    candidate = net.apply_from({a.shaped(inputs): a_params}, inputs)
    assert np.array_equal(out, candidate)
    candidate = net.apply_from({a.shaped(inputs): a_params}, inputs, jit=True)
    assert np.array_equal(out, candidate)

    # Shaped network, plain key.
    candidate = net.shaped(inputs).apply_from({a: a_params})
    assert np.array_equal(out, candidate)
    candidate = net.shaped(inputs).apply_from({a: a_params}, jit=True)
    assert np.array_equal(out, candidate)

    # Shaped network, shaped key.
    candidate = net.shaped(inputs).apply_from({a.shaped(inputs): a_params})
    assert np.array_equal(out, candidate)
    candidate = net.shaped(inputs).apply_from({a.shaped(inputs): a_params},
                                              jit=True)
    assert np.array_equal(out, candidate)