def test_wavenet():
    """Run one jitted Adam training step on a small L2-regularized Wavenet.

    Builds a Wavenet with two dilation layers, feeds it a random batch whose
    length is the receptive field plus 1000 extra steps, performs a single
    ``update_and_get_loss`` call, and checks that the returned loss is a
    scalar.
    """
    filter_width = 2
    initial_filter_width = 3
    residual_channels = 4
    dilation_channels = 5
    skip_channels = 6
    dilations = [1, 2]
    nr_mix = 10
    receptive_field = calculate_receptive_field(
        filter_width, dilations, initial_filter_width)

    batch = random.normal(PRNGKey(0), (1, receptive_field + 1000, 1))
    output_width = batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        # Drop the final output step. Assuming the network emits
        # output_width steps (TODO confirm against Wavenet), theta then has
        # output_width - 1 steps, the same length as sliced_batch below.
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        # Mean over axis 0, scaled by log2(e) (presumably nats -> bits) and
        # normalized by the number of predicted steps.
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(optimizers.exponential_decay(
        1e-3, decay_steps=1, decay_rate=0.999995))
    state = opt.init(loss.init_parameters(batch, key=PRNGKey(0)))
    state, train_loss = opt.update_and_get_loss(
        loss.apply, state, batch, jit=True)
    # Smoke-check that parameters can be read back from the updated state;
    # the returned value itself is not inspected, so it is not bound.
    opt.get_parameters(state)
    assert () == train_loss.shape
def test_mnist_classifier():
    """Briefly train the MNIST example classifier and check output shapes.

    Runs ten jitted Momentum updates on an all-zeros batch and asserts they
    finish within 5 seconds. It then checks that ``accuracy`` returns a
    scalar and that ``predict`` returns an array of shape (3, 10) for the
    3-example batch.
    """
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        # All-zero inputs of shape (3, 784) and targets of shape (3, 10).
        return np.zeros((3, 784)), np.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    # perf_counter is the clock meant for measuring short durations and is
    # not affected by system clock updates, unlike time.time(). The timed
    # span presumably includes JIT compilation on the first update.
    start = time.perf_counter()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)
    elapsed = time.perf_counter() - start
    assert 5 > elapsed

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert () == train_acc.shape

    predict_params = predict.parameters_from(
        {loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert (3, 10) == predictions.shape
def test_pixelcnn():
    """Initialize a tiny PixelCNN++ model and wrap its parameters in Adam.

    Only initialization is exercised here. The model is built with
    ``nr_filters=1`` and ``nr_resnet=1``, and its parameters are created
    from an all-zeros array of shape (2, 16, 16, 3).
    """
    # PixelCNNPP returns a pair; only the first element (the loss) is used.
    loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
    images = np.zeros((2, 16, 16, 3), image_dtype)
    opt = optimizers.Adam()
    initial_parameters = loss.init_parameters(images, key=PRNGKey(0))
    state = opt.init(initial_parameters)