Beispiel #1
0
def test_wavenet():
    """Run a single Adam update on a tiny WaveNet and check the loss is scalar.

    Builds a small Wavenet, wraps a discretized mix-logistic loss around it
    with L2 regularization, performs one optimizer step with ``jit=True`` on
    random input and asserts that the returned training loss has shape ``()``.
    """
    dilations = [1, 2]
    filter_width, initial_filter_width = 2, 3
    residual_channels, dilation_channels, skip_channels = 4, 5, 6
    nr_mix = 10
    receptive_field = calculate_receptive_field(
        filter_width, dilations, initial_filter_width)

    # Axis 1 has receptive_field + 1000 entries, so output_width is 1001.
    signal = random.normal(PRNGKey(0), (1, receptive_field + 1000, 1))
    output_width = signal.shape[1] - receptive_field + 1

    wavenet = Wavenet(
        dilations, filter_width, initial_filter_width, output_width,
        residual_channels, dilation_channels, skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        # Drop the last position along axis 1 of the network output.
        mix_params = wavenet(batch)[:, :-1, :]
        # Strip the leading receptive_field samples (the padding) off the input.
        targets = batch[:, receptive_field:, :]
        per_example = discretized_mix_logistic_loss(
            mix_params, targets, num_class=2 ** 16)
        batch_mean = np.mean(per_example, axis=0)
        # Scale by log2(e) (presumably nats -> bits; confirm against the loss
        # implementation) and normalize by output_width - 1.
        return batch_mean * np.log2(np.e) / (output_width - 1)

    loss = L2Regularized(loss, 0.01)

    schedule = optimizers.exponential_decay(
        1e-3, decay_steps=1, decay_rate=0.999995)
    opt = optimizers.Adam(schedule)
    initial_params = loss.init_parameters(signal, key=PRNGKey(0))
    state = opt.init(initial_params)
    state, train_loss = opt.update_and_get_loss(
        loss.apply, state, signal, jit=True)
    # Exercise parameter extraction as well; its result is not inspected.
    opt.get_parameters(state)
    assert train_loss.shape == ()
Beispiel #2
0
def test_mnist_classifier():
    """Smoke-test the MNIST classifier example: train briefly, evaluate, predict.

    Runs ten momentum updates (``jit=True``) on all-zero dummy batches, checks
    that they complete in under five seconds, and verifies the output shapes
    of ``accuracy`` and ``predict`` when they reuse the parameters trained
    through ``loss``.
    """
    from examples.mnist_classifier import predict, loss, accuracy

    def next_batch():
        # Dummy batch: two all-zero arrays of shape (3, 784) and (3, 10).
        return np.zeros((3, 784)), np.zeros((3, 10))

    opt = optimizers.Momentum(0.001, mass=0.9)
    state = opt.init(loss.init_parameters(*next_batch(), key=PRNGKey(0)))

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    for _ in range(10):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    elapsed = time.perf_counter() - start
    assert elapsed < 5

    params = opt.get_parameters(state)
    train_acc = accuracy.apply_from({loss: params}, *next_batch(), jit=True)
    assert train_acc.shape == ()

    # Derive predict's parameters from the ones trained via the loss.
    predict_params = predict.parameters_from(
        {loss.shaped(*next_batch()): params}, next_batch()[0])
    predictions = predict.apply(predict_params, next_batch()[0], jit=True)
    assert predictions.shape == (3, 10)
Beispiel #3
0
def test_pixelcnn():
    """Build a minimal PixelCNNPP loss and initialize Adam optimizer state for it."""
    # PixelCNNPP returns a pair; only the first element (the loss) is used here.
    loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
    # All-zero input of shape (2, 16, 16, 3) — presumably two 16x16 RGB
    # images in NHWC layout; confirm against PixelCNNPP.
    images = np.zeros((2, 16, 16, 3), image_dtype)
    opt = optimizers.Adam()
    state = opt.init(loss.init_parameters(images, key=PRNGKey(0)))