def test_ocr_rnn():
    length = 5
    carry_size = 3
    class_count = 4
    inputs = np.zeros((1, length, 4))

    def rnn():
        return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(PRNGKey(0), inputs)

    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    assert np.array_equal(np.zeros((7, 3)), cell.update_kernel)
    assert np.array_equal(np.zeros((7, 3)), cell.reset_kernel)
    assert np.array_equal(np.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)
    assert np.array_equal(.25 * np.ones((1, 5, 4)), out)

def conv_block(inputs):
    main = Sequential(Conv(filters1, (1, 1), strides), BatchNorm(), relu,
                      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), relu,
                      Conv(filters3, (1, 1)), BatchNorm())
    shortcut = Sequential(Conv(filters3, (1, 1), strides), BatchNorm())
    return relu(sum((main(inputs), shortcut(inputs))))

def res_layer(inputs):
    """
    From original docstring: The layer contains a gated filter that connects to
    dense output and to a skip connection:

           |-> [gate]   -|        |-> 1x1 conv -> skip output
           |             |-> (*) -|
    input -|-> [filter] -|        |-> 1x1 conv -|
           |                                    |-> (+) -> dense output
           |------------------------------------|

    Where `[gate]` and `[filter]` are causal convolutions with a
    non-linear activation at the output.
    """
    gated = Sequential(Conv1D(dilation_channels, (filter_width,), dilation=(dilation,)),
                       sigmoid)(inputs)
    filtered = Sequential(Conv1D(dilation_channels, (filter_width,), dilation=(dilation,)),
                          np.tanh)(inputs)
    p = gated * filtered
    out = Conv1D(residual_channels, (1,), padding='SAME')(p)
    # Add the transformed output of the resblock to the sliced input:
    sliced_inputs = lax.dynamic_slice(
        inputs, [0, inputs.shape[1] - out.shape[1], 0],
        [inputs.shape[0], out.shape[1], inputs.shape[2]])
    new_out = sum((out, sliced_inputs))
    skip = Conv1D(residual_channels, (1,), padding='SAME')(skip_slice(p, output_width))
    return new_out, skip

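# `res_layer` above relies on a `skip_slice` helper that is not shown here.
# A minimal sketch, assuming it keeps only the last `output_width` steps of the
# time axis, mirroring the `lax.dynamic_slice` pattern used for the residual path:
def skip_slice(inputs, output_width):
    # Slice out the trailing `output_width` steps along the time (second) axis.
    return lax.dynamic_slice(
        inputs, [0, inputs.shape[1] - output_width, 0],
        [inputs.shape[0], output_width, inputs.shape[2]])
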
def test_regularized_submodule():
    net = Sequential(Conv(2, (1, 1)), relu, Conv(2, (1, 1)), relu, flatten,
                     L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), .1))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(input, key=PRNGKey(0))
    assert (2, 2) == params.regularized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert () == out.shape

def test_submodule_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(inputs, key=PRNGKey(0))
    net1_params = net1.init_parameters(inputs, key=PRNGKey(1), reuse={layer: layer_params})
    net2_params = net2.init_parameters(inputs, key=PRNGKey(2), reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_parameters_equal(layer_params, net1_params[0])
    assert_dense_parameters_equal(layer_params, net2_params[0])

    new_layer_params = layer.init_parameters(inputs, key=PRNGKey(3))
    combined_params = net1.parameters_from({net1: net1_params, layer: new_layer_params},
                                           inputs)
    assert_dense_parameters_equal(new_layer_params, combined_params.dense0)
    assert_dense_parameters_equal(net1_params.dense1, combined_params.dense1)

def test_reparametrized_submodule():
    net = Sequential(Conv(2, (3, 3)), relu, Conv(2, (3, 3)), relu, flatten,
                     Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(PRNGKey(0), input)
    assert (2, 2) == params.reparametrized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert (1, 2) == out.shape

def test_pool_shape(Pool):
    conv = Conv(2, filter_shape=(3, 3), padding='SAME', kernel_init=zeros, bias_init=zeros)
    inputs = np.zeros((1, 5, 5, 2))

    pooled = Sequential(conv, Pool(window_shape=(1, 1), strides=(2, 2)))
    params = pooled.init_parameters(PRNGKey(0), inputs)
    out = pooled.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 3, 3, 2)), out)

def test_input_dependent_nested_modules():
    @parametrized
    def layer(inputs):
        return Dense(inputs.shape[0])(inputs)

    net = Sequential(Dense(3), layer)

    inputs = np.zeros((5, 3))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert (5, 5) == out.shape

def test_reuse_api():
    inputs = np.zeros((1, 2))
    net = Dense(5)
    net_params = net.init_parameters(inputs, key=PRNGKey(0))

    # train net params...

    transfer_net = Sequential(net, relu, Dense(2))
    transfer_net_params = transfer_net.init_parameters(inputs, key=PRNGKey(1),
                                                       reuse={net: net_params})

    assert net_params == transfer_net_params.dense0

def test_external_param_sharing():
    layer = Dense(2, zeros, zeros)
    shared_net = Sequential(layer, layer)

    inputs = np.zeros((1, 2))
    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2)),), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)

    out = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(np.zeros((1, 2)), out)

def test_Sequential_graceful_update_message():
    message = 'Call like Sequential(Dense(10), relu), without "[" and "]". ' \
              '(Or pass iterables with Sequential(*layers).)'

    try:
        Sequential([Dense(2), relu])
        assert False
    except ValueError as e:
        assert message == str(e)

    try:
        Sequential(Dense(2) for _ in range(3))
        assert False
    except ValueError as e:
        assert message == str(e)

def test_flatten_shape():
    conv = Conv(2, filter_shape=(3, 3), padding='SAME', kernel_init=zeros, bias_init=zeros)
    inputs = np.zeros((1, 5, 5, 2))

    params = conv.init_parameters(PRNGKey(0), inputs)
    out = conv.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 5, 5, 2)), out)

    flattened = Sequential(conv, flatten)
    out = flattened.apply_from({conv: params}, inputs)
    assert np.array_equal(np.zeros((1, 50)), out)

def test_diamond_shared_submodules():
    p = Parameter(lambda rng: np.ones(()))
    a = Sequential(p)
    b = Sequential(p)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    params = net.init_parameters(PRNGKey(0), np.zeros(()))

    assert 1 == len(params)
    assert np.array_equal(np.ones(()), params)
    a, b = net.apply(params, np.zeros(()))
    assert np.array_equal(np.ones(()), a)
    assert np.array_equal(np.ones(()), b)

def test_collection_input(type):
    @parametrized
    def net(inputs):
        assert isinstance(inputs, type)
        return inputs[0] * inputs[1] * parameter((), zeros)

    inputs = type((np.zeros(2), np.zeros(2)))
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros(2), out)

    net = Sequential(net)
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros(2), out)

def identity_block(inputs):
    main = Sequential(Conv(filters1, (1, 1)), BatchNorm(), relu,
                      Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), relu,
                      Conv(inputs.shape[3], (1, 1)), BatchNorm())
    return relu(sum((main(inputs), inputs)))

def main(batch_size=256, env_name="CartPole-v1"):
    env = gym.make(env_name)
    policy = Sequential(Dense(64), relu, Dense(env.action_space.n))

    @parametrized
    def loss(observations, actions, rewards_to_go):
        logprobs = log_softmax(policy(observations))
        action_logprobs = logprobs[np.arange(logprobs.shape[0]), actions]
        return -np.mean(action_logprobs * rewards_to_go, axis=0)

    opt = Adam()
    shaped_loss = loss.shaped(np.zeros((1,) + env.observation_space.shape),
                              np.array([0]), np.array([0]))

    @jit
    def sample_action(state, key, observation):
        loss_params = opt.get_parameters(state)
        logits = policy.apply_from({shaped_loss: loss_params}, observation)
        return sample_categorical(key, logits)

    rng_init, rng = random.split(PRNGKey(0))
    state = opt.init(shaped_loss.init_parameters(key=rng_init))

    returns, observations, actions, rewards_to_go = [], [], [], []

    for i in range(250):
        while len(observations) < batch_size:
            observation = env.reset()
            episode_done = False
            rewards = []

            while not episode_done:
                rng_step, rng = random.split(rng)
                action = sample_action(state, rng_step, observation)
                observations.append(observation)
                actions.append(action)
                observation, reward, episode_done, info = env.step(int(action))
                rewards.append(reward)

            returns.append(onp.sum(rewards))
            rewards_to_go += list(onp.flip(onp.cumsum(onp.flip(rewards))))

        print(f'Batch {i}, recent mean return: {onp.mean(returns[-100:]):.1f}')

        state = opt.update(loss.apply, state,
                           np.array(observations[:batch_size]),
                           np.array(actions[:batch_size]),
                           np.array(rewards_to_go[:batch_size]), jit=True)

        observations = observations[batch_size:]
        actions = actions[batch_size:]
        rewards_to_go = rewards_to_go[batch_size:]

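# `sample_categorical` is used above but not defined in this snippet. A minimal
# sketch via the Gumbel-max trick (this implementation is an assumption;
# jax's built-in `jax.random.categorical` would serve the same purpose):
def sample_categorical(key, logits, axis=-1):
    # Add Gumbel noise to the logits and take the argmax to draw one sample.
    gumbel_noise = -np.log(-np.log(random.uniform(key, logits.shape)))
    return np.argmax(logits + gumbel_noise, axis=axis)
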
def test_params_from_shared_submodules2():
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(PRNGKey(0), inputs)
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_params_equal(a_params.dense, params.sequential0.dense)
    assert_dense_params_equal(a_params.dense, params.sequential1.dense)
    # TODO parameters are duplicated, optimization with weight sharing is wrong:
    # TODO instead: assert 1 == len(params)
    out_, _ = net.apply(params, inputs)
    assert np.array_equal(out, out_)

def wavenet(inputs):
    hidden = Conv1D(residual_channels, (initial_filter_width,))(inputs)
    out = np.zeros((hidden.shape[0], out_width, residual_channels), 'float32')
    for dilation in dilations:
        res = ResLayer(dilation_channels, residual_channels,
                       filter_width, dilation, out_width)(hidden)
        hidden, out_partial = res
        out += out_partial
    return Sequential(relu, Conv1D(skip_channels, (1,)), relu,
                      Conv1D(3 * nr_mix, (1,)))(out)

def test_ocr_rnn():
    length = 5
    carry_size = 3
    class_count = 4
    inputs = jnp.zeros((1, length, 4))

    def rnn():
        return Rnn(*GRUCell(carry_size, zeros))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: jnp.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(class_count, zeros, zeros),
        softmax,
        lambda x: jnp.reshape(x, (-1, length, class_count)))

    params = net.init_parameters(inputs, key=PRNGKey(0))

    assert len(params) == 4
    cell = params.rnn0.gru_cell
    assert len(cell) == 3
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.update_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.reset_kernel)
    assert jnp.array_equal(jnp.zeros((7, 3)), cell.compute_kernel)

    out = net.apply(params, inputs)

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return jnp.mean(-jnp.sum(targets * jnp.log(prediction), (1, 2)))

    opt = optimizers.RmsProp(0.003)
    state = opt.init(cross_entropy.init_parameters(inputs, out, key=PRNGKey(0)))
    state = opt.update(cross_entropy.apply, state, inputs, out)
    opt.update(cross_entropy.apply, state, inputs, out, jit=True)

def test_L2Regularized_sequential():
    loss = Sequential(Dense(1, ones, ones), relu, Dense(1, ones, ones), sum)

    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.ones(1)
    params = reg_loss.init_parameters(PRNGKey(0), inputs)
    assert np.array_equal(np.ones((1, 1)), params.model.dense0.kernel)
    assert np.array_equal(np.ones((1, 1)), params.model.dense1.kernel)

    reg_loss_out = reg_loss.apply(params, inputs)
    assert 7 == reg_loss_out

def test_parameters_from_shared_submodules():
    sublayer = Dense(2)
    a = Sequential(sublayer, relu)
    b = Sequential(sublayer, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs) * b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_parameters_equal(a_params.dense.kernel, params.sequential0.dense.kernel)
    assert_parameters_equal((), params.sequential1)
    out = net.apply(params, inputs)

    out_ = net.apply_from({a: a_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a: a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs)
    assert np.array_equal(out, out_)

    out_ = net.apply_from({a.shaped(inputs): a_params}, inputs, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a: a_params})
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a: a_params}, jit=True)
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params})
    assert np.array_equal(out, out_)

    out_ = net.shaped(inputs).apply_from({a.shaped(inputs): a_params}, jit=True)
    assert np.array_equal(out, out_)

def ResNet50(num_classes):
    return Sequential(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), relu, MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        AvgPool((7, 7)), flatten, Dense(num_classes), logsoftmax)

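# A minimal usage sketch for `ResNet50`, following the init_parameters/apply
# pattern of the snippets above. The 'HWCN' input spec of the first GeneralConv
# puts the batch axis last; the image size and class count here are assumptions:
resnet = ResNet50(num_classes=1000)
images = np.zeros((224, 224, 3, 1))  # (height, width, channels, batch)
params = resnet.init_parameters(images, key=PRNGKey(0))
logits = resnet.apply(params, images)  # log-probabilities, shape (1, 1000)
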
def test_parameters_from_subsubmodule():
    subsublayer = Dense(2)
    sublayer = Sequential(subsublayer, relu)
    net = Sequential(sublayer, np.sum)
    inputs = np.zeros((1, 3))

    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)

    subsublayer_params = subsublayer.init_parameters(inputs, key=PRNGKey(0))

    params_ = net.parameters_from({subsublayer: subsublayer_params}, inputs)
    assert_dense_parameters_equal(subsublayer_params, params_[0][0])
    out_ = net.apply(params_, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs)
    assert out.shape == out_.shape

    out_ = net.apply_from({subsublayer: subsublayer_params}, inputs, jit=True)
    assert out.shape == out_.shape

def main():
    # TODO https://github.com/JuliusKunze/jaxnet/issues/4
    print("Sorry, this example does not yet work with the new jax version.")
    return

    train, test = read_dataset()
    _, length, x_size = train.data.shape
    class_count = train.target.shape[2]
    carry_size = 200
    batch_size = 10

    def rnn():
        return Rnn(*GRUCell(carry_size=carry_size,
                            param_init=lambda rng, shape: random.normal(rng, shape) * 0.01))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(out_dim=class_count),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return np.mean(-np.sum(targets * np.log(prediction), (1, 2)))

    @parametrized
    def error(inputs, targets):
        prediction = net(inputs)
        return np.mean(np.not_equal(np.argmax(targets, 2), np.argmax(prediction, 2)))

    opt = optimizers.RmsProp(0.003)

    batch = train.sample(batch_size)
    params = cross_entropy.init_parameters(random.PRNGKey(0), batch.data, batch.target)
    state = opt.init(params)

    for epoch in range(10):
        params = get_params(state)
        e = error.apply_from({cross_entropy: params}, test.data, test.target, jit=True)
        print(f'Epoch {epoch} error {e * 100:.1f}')

        break  # TODO https://github.com/JuliusKunze/jaxnet/issues/2

        for _ in range(100):
            batch = train.sample(batch_size)
            state = opt.update(cross_entropy.apply, state, batch.data, batch.target, jit=True)

def test_submodule_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    net2 = Sequential(layer, Dense(3))

    layer_params = layer.init_parameters(PRNGKey(0), inputs)
    net1_params = net1.init_parameters(PRNGKey(1), inputs, reuse={layer: layer_params})
    net2_params = net2.init_parameters(PRNGKey(2), inputs, reuse={layer: layer_params})

    out1 = net1.apply(net1_params, inputs)
    assert out1.shape == (1, 2)

    out2 = net2.apply(net2_params, inputs)
    assert out2.shape == (1, 3)

    assert_dense_params_equal(layer_params, net1_params[0])
    assert_dense_params_equal(layer_params, net2_params[0])

def test_parameters_from_sharing_between_multiple_parents():
    a = Dense(2)
    b = Sequential(a, np.sum)

    @parametrized
    def net(inputs):
        return a(inputs), b(inputs)

    inputs = np.zeros((1, 3))
    a_params = a.init_parameters(inputs, key=PRNGKey(0))
    out = a.apply(a_params, inputs)

    params = net.parameters_from({a: a_params}, inputs)
    assert_dense_parameters_equal(a_params, params.dense)
    assert_parameters_equal((), params.sequential)
    assert 2 == len(params)
    out_, _ = net.apply(params, inputs)
    assert np.array_equal(out, out_)

def test_mnist_vae():
    @parametrized
    def encode(input):
        input = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(input)
        variance = Sequential(Dense(10), softplus)(input)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), np.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape

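# `gaussian_sample`, `gaussian_kl` and `bernoulli_logpdf` are assumed above.
# Minimal sketches following the standard VAE formulation (reparametrized
# Gaussian sampling, analytic KL to the standard normal, Bernoulli likelihood);
# the exact implementations here are assumptions:
def gaussian_kl(mu, sigmasq):
    # KL(N(mu, sigmasq) || N(0, 1)) for a diagonal Gaussian.
    return -0.5 * np.sum(1. + np.log(sigmasq) - mu ** 2. - sigmasq)

def gaussian_sample(key, mu, sigmasq):
    # Reparametrized sample: mu + sigma * eps, with eps ~ N(0, 1).
    return mu + np.sqrt(sigmasq) * random.normal(key, mu.shape)

def bernoulli_logpdf(logits, x):
    # Bernoulli log-likelihood of binary data x given logits.
    return -np.sum(np.logaddexp(0., np.where(x, -1., 1.) * logits))
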
def test_no_reuse():
    inputs = np.zeros((1, 2))

    layer = Dense(5)
    net1 = Sequential(layer, Dense(2))
    p1 = net1.init_parameters(inputs, key=PRNGKey(0))

    net2 = Sequential(layer, Dense(3))
    p2 = net2.init_parameters(inputs, key=PRNGKey(1))

    assert p1[0].kernel.shape == p2[0].kernel.shape
    assert p1[0].bias.shape == p2[0].bias.shape
    assert not np.array_equal(p1[0][0], p2[0][0])
    assert not np.array_equal(p1[0][1], p2[0][1])

def main():
    train, test = read_dataset()
    _, length, x_size = train.data.shape
    class_count = train.target.shape[2]
    carry_size = 200
    batch_size = 10

    def rnn():
        return Rnn(*GRUCell(carry_size=carry_size,
                            param_init=lambda key, shape: random.normal(key, shape) * 0.01))

    net = Sequential(
        rnn(),
        rnn(),
        rnn(),
        lambda x: np.reshape(x, (-1, carry_size)),  # -> same weights for all time steps
        Dense(out_dim=class_count),
        softmax,
        lambda x: np.reshape(x, (-1, length, class_count)))

    @parametrized
    def cross_entropy(images, targets):
        prediction = net(images)
        return np.mean(-np.sum(targets * np.log(prediction), (1, 2)))

    @parametrized
    def error(inputs, targets):
        prediction = net(inputs)
        return np.mean(np.not_equal(np.argmax(targets, 2), np.argmax(prediction, 2)))

    opt = optimizers.RmsProp(0.003)

    batch = train.sample(batch_size)
    state = opt.init(cross_entropy.init_parameters(batch.data, batch.target, key=PRNGKey(0)))

    for epoch in range(10):
        params = opt.get_parameters(state)
        e = error.apply_from({cross_entropy: params}, test.data, test.target, jit=True)
        print(f'Epoch {epoch} error {e * 100:.1f}')

        for _ in range(100):
            batch = train.sample(batch_size)
            state = opt.update(cross_entropy.apply, state, batch.data, batch.target, jit=True)

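# `read_dataset` is assumed above. The training loop only relies on each split
# exposing `.data` of shape (n, length, x_size), one-hot `.target` of shape
# (n, length, class_count), and a `sample(batch_size)` method. A purely
# hypothetical stand-in for illustration:
class Dataset:
    def __init__(self, data, target):
        self.data = data      # (n, length, x_size)
        self.target = target  # (n, length, class_count), one-hot

    def sample(self, batch_size):
        # Draw a random batch (with replacement) as a new Dataset view.
        idx = onp.random.randint(0, self.data.shape[0], batch_size)
        return Dataset(self.data[idx], self.target[idx])
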
def test_readme():
    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

    @parametrized
    def loss(inputs, targets):
        return -np.mean(net(inputs) * targets)

    def next_batch():
        return np.zeros((3, 784)), np.zeros((3, 4))

    params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

    print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
    assert np.allclose([-0.01101029, -0.00749435, -0.00952365, 0.00493979],
                       params.sequential.dense2.bias)

    out = loss.apply(params, *next_batch())
    assert () == out.shape

    out_ = loss.apply(params, *next_batch(), jit=True)
    assert out.shape == out_.shape