# Note: these snippets assume the usual jax imports (np = jax.numpy, random, vmap,
# PRNGKey) and the jaxnet API (parametrized, Dense, Sequential, relu, logsoftmax,
# optimizers, L2Regularized), plus the Wavenet and PixelCNN++ example modules
# (Wavenet, calculate_receptive_field, discretized_mix_logistic_loss, PixelCNNPP,
# dataset, image_dtype). Exact import paths depend on the repository layout.


def test_wavenet():
    filter_width = 2
    initial_filter_width = 3
    residual_channels = 4
    dilation_channels = 5
    skip_channels = 6
    dilations = [1, 2]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    batch = random.normal(PRNGKey(0), (1, receptive_field + 1000, 1))
    output_width = batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(optimizers.exponential_decay(1e-3, decay_steps=1,
                                                       decay_rate=0.999995))
    state = opt.init(loss.init_parameters(batch, key=PRNGKey(0)))
    state, train_loss = opt.update_and_get_loss(loss.apply, state, batch,
                                                jit=True)
    trained_params = opt.get_parameters(state)
    assert () == train_loss.shape


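# The test above and the main() example below rely on calculate_receptive_field,
# which presumably lives in the Wavenet example module. A minimal sketch of what
# it computes, assuming the standard WaveNet receptive-field formula for a stack
# of causal dilated convolutions preceded by an initial causal convolution:

def calculate_receptive_field(filter_width, dilations, initial_filter_width):
    # Each dilated layer widens the receptive field by (filter_width - 1) * dilation;
    # the initial causal convolution contributes initial_filter_width - 1 samples,
    # plus one for the current sample.
    return (filter_width - 1) * sum(dilations) + initial_filter_width

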
def main():
    filter_width = 2
    initial_filter_width = 32
    residual_channels = 32
    dilation_channels = 32
    skip_channels = 512
    dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
                 1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
                 1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
                 1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
                 1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    def get_batches(batches=100, sequence_length=1000, rng=PRNGKey(0)):
        for _ in range(batches):
            rng, rng_now = random.split(rng)
            yield random.normal(rng_now, (1, receptive_field + sequence_length, 1))

    batches = get_batches()
    init_batch = next(batches)
    output_width = init_batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        # negative log-likelihood in nats, converted to bits per predicted sample
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(
        optimizers.exponential_decay(1e-3, decay_steps=1, decay_rate=0.999995))
    print('Initializing parameters.')
    state = opt.init(loss.init_parameters(PRNGKey(0), next(batches)))
    for batch in batches:
        print(f'Training on batch {opt.get_step(state)}.')
        state, train_loss = opt.update_and_get_loss(loss.apply, state, batch,
                                                    jit=True)

    trained_params = opt.get_parameters(state)


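# Under the formula sketched above, this configuration would give a receptive
# field of (2 - 1) * (5 * 1023) + 32 = 5147 samples, so each training sequence
# of receptive_field + 1000 samples yields 1000 predicted outputs.

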
def main(batch_size=32, epochs=10, step_size=.001, decay_rate=.999995,
         nr_filters=1, nr_resnet=0, dropout_p=.5):
    unbatched_loss = PixelCNNPP(nr_filters=nr_filters, nr_resnet=nr_resnet,
                                dropout_p=dropout_p)

    @parametrized
    def loss(rng, batch):
        batch_size = batch.shape[0]
        loss = vmap(unbatched_loss, (None, 0, 0))
        rngs = random.split(rng, batch_size)
        losses = loss(rngs, batch)
        assert losses.shape == (batch_size,)
        return np.mean(losses)

    get_train_batches, test_batches = dataset(batch_size)
    rng, rng_init_1, rng_init_2 = random.split(PRNGKey(0), 3)

    # TODO fix: params = unbatched_loss.init_parameters(rng_init_1, rng_init_2, next(test_batches)[0])
    # TODO fix batched version:
    # TODO rng_init_2 = random.split(rng_init_2, test_batch_size)
    # TODO vmap(loss, (0, 0))
    params = loss.init_parameters(rng_init_1, rng_init_2, next(test_batches))

    opt = optimizers.Adam(optimizers.exponential_decay(step_size, 1, decay_rate))
    state = opt.init(params)

    for epoch in range(epochs):
        for batch in get_train_batches():
            rng, rng_update = random.split(rng)
            i = opt.get_step(state)
            state, train_loss = opt.update_and_get_loss(loss.apply, state,
                                                        rng_update, batch)

            if i % 100 == 0 or i < 10:
                rng, rng_test = random.split(rng)
                test_loss = loss.apply(opt.get_parameters(state), rng_test,
                                       next(test_batches))
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f}")


def test():
    @parametrized
    def loss(inputs, targets):
        return -np.mean(
            Sequential(Dense(4), relu, Dense(4), logsoftmax)(inputs) * targets)

    def next_batch():
        return np.zeros((3, 784)), np.zeros((3, 4))

    params = loss.init_parameters(PRNGKey(0), *next_batch())

    opt = optimizers.Adam()
    state = opt.init(params)
    for _ in range(3):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    for _ in range(3):
        state, l = opt.update_and_get_loss(loss.apply, state, *next_batch(),
                                           jit=True)

    assert () == l.shape
    assert 6 == opt.get_step(state)
    assert 6 == state.step
    assert (4, 4) == opt.get_parameters(state).sequential.dense1.kernel.shape
    out = loss.apply(opt.get_parameters(state), *next_batch())
    assert () == out.shape

    # TODO waiting for https://github.com/google/jax/issues/1278
    # path = Path('/tmp') / 'test.params'
    # save_params(state, path)
    # state = load_params(path)

    assert 6 == opt.get_step(state)
    assert 6 == state.step
    assert (4, 4) == opt.get_parameters(state).sequential.dense1.kernel.shape
    out = loss.apply(opt.get_parameters(state), *next_batch())
    assert () == out.shape


def test_pixelcnn():
    loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
    images = np.zeros((2, 16, 16, 3), image_dtype)
    opt = optimizers.Adam()
    state = opt.init(loss.init_parameters(images, key=PRNGKey(0)))
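
    # A possible continuation (an assumption, not part of the original test),
    # mirroring test_wavenet above: one jitted update step to check that training
    # runs and the loss is a scalar, assuming the returned loss module takes a
    # batch of images as its only input, as in the init_parameters call above.
    state, train_loss = opt.update_and_get_loss(loss.apply, state, images, jit=True)
    assert () == train_loss.shape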