Example #1
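The listing omits the snippets' imports. For this example (and Example #2 below, which uses the same WaveNet setup) they would look roughly like the following; the exact module paths are an assumption based on jaxnet's public API and its WaveNet example:

# Assumed imports, not part of the original snippet:
from jax import numpy as np, random
from jax.random import PRNGKey
from jaxnet import parametrized, optimizers, L2Regularized
# Wavenet, calculate_receptive_field and discretized_mix_logistic_loss are
# assumed to come from the WaveNet example module these snippets are based on.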
def test_wavenet():
    filter_width = 2
    initial_filter_width = 3
    residual_channels = 4
    dilation_channels = 5
    skip_channels = 6
    dilations = [1, 2]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    batch = random.normal(PRNGKey(0), (1, receptive_field + 1000, 1))
    output_width = batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        # Average negative log-likelihood over the batch, converted from nats
        # to bits and normalized per predicted sample:
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(optimizers.exponential_decay(1e-3, decay_steps=1, decay_rate=0.999995))
    state = opt.init(loss.init_parameters(batch, key=PRNGKey(0)))
    state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, jit=True)
    trained_params = opt.get_parameters(state)
    assert () == train_loss.shape
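calculate_receptive_field is not defined in the listing. For a WaveNet-style stack of dilated causal convolutions it is conventionally computed as below; this is a sketch assuming the standard definition, not code from the original:

def calculate_receptive_field(filter_width, dilations, initial_filter_width):
    # Each dilated layer adds (filter_width - 1) * dilation samples of context;
    # the initial causal convolution contributes its full width.
    return (filter_width - 1) * sum(dilations) + initial_filter_width

With the values in this test (filter_width=2, dilations=[1, 2], initial_filter_width=3) this comes out to (2 - 1) * (1 + 2) + 3 = 6 samples.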
Example #2
def main():
    filter_width = 2
    initial_filter_width = 32
    residual_channels = 32
    dilation_channels = 32
    skip_channels = 512
    dilations = [
        1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128,
        256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1, 2, 4, 8, 16, 32,
        64, 128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512
    ]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    def get_batches(batches=100, sequence_length=1000, rng=PRNGKey(0)):
        for _ in range(batches):
            rng, rng_now = random.split(rng)
            yield random.normal(rng_now,
                                (1, receptive_field + sequence_length, 1))

    batches = get_batches()
    init_batch = next(batches)
    output_width = init_batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16),
                        axis=0) * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(
        optimizers.exponential_decay(1e-3, decay_steps=1, decay_rate=0.999995))
    print('Initializing parameters.')
    state = opt.init(loss.init_parameters(PRNGKey(0), next(batches)))
    for batch in batches:
        print(f'Training on batch {opt.get_step(state)}.')
        state, train_loss = opt.update_and_get_loss(loss.apply,
                                                    state,
                                                    batch,
                                                    jit=True)

    trained_params = opt.get_parameters(state)
Example #3
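As with the other snippets, the imports are omitted; they would roughly be the following, with dataset and PixelCNNPP assumed to come from the PixelCNN++ example's own modules:

from jax import numpy as np, random, vmap
from jax.random import PRNGKey
from jaxnet import parametrized, optimizers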
def main(batch_size=32,
         epochs=10,
         step_size=.001,
         decay_rate=.999995,
         nr_filters=1,
         nr_resnet=0,
         dropout_p=.5):
    unbatched_loss = PixelCNNPP(nr_filters=nr_filters,
                                nr_resnet=nr_resnet,
                                dropout_p=dropout_p)

    @parametrized
    def loss(rng, batch):
        batch_size = batch.shape[0]
        # vmap over per-example rngs and images (unbatched_loss takes an rng and a single image):
        batched_loss = vmap(unbatched_loss, (0, 0))
        rngs = random.split(rng, batch_size)
        losses = batched_loss(rngs, batch)
        assert losses.shape == (batch_size, )
        return np.mean(losses)

    get_train_batches, test_batches = dataset(batch_size)
    rng, rng_init_1, rng_init_2 = random.split(PRNGKey(0), 3)
    # TODO fix:
    params = unbatched_loss.init_parameters(rng_init_1, rng_init_2,
                                            next(test_batches)[0])
    # TODO fix batched version:
    # TODO rng_init_2 = random.split(rng_init_2, test_batch_size)
    # TODO vmap(loss, (0, 0))
    params = loss.init_parameters(rng_init_1, rng_init_2, next(test_batches))

    opt = optimizers.Adam(
        optimizers.exponential_decay(step_size, 1, decay_rate))
    state = opt.init(params)

    for epoch in range(epochs):
        for batch in get_train_batches():
            rng, rng_update = random.split(rng)
            i = opt.get_step(state)
            state, train_loss = opt.update_and_get_loss(
                loss.apply, state, rng_update, batch)

            if i % 100 == 0 or i < 10:
                rng, rng_test = random.split(rng)
                test_loss = loss.apply(opt.get_parameters(state), rng_test,
                                       next(test_batches))
                print(f"Epoch {epoch}, iteration {i}, "
                      f"train loss {train_loss:.3f}, "
                      f"test loss {test_loss:.3f} ")
Example #4
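The building blocks used here are jaxnet's standard exports; the omitted imports are presumably:

from jax import numpy as np
from jax.random import PRNGKey
from jaxnet import parametrized, Dense, Sequential, relu, logsoftmax, optimizers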
def test():
    @parametrized
    def loss(inputs, targets):
        return -np.mean(
            Sequential(Dense(4), relu, Dense(4), logsoftmax)(inputs) * targets)

    def next_batch():
        return np.zeros((3, 784)), np.zeros((3, 4))

    params = loss.init_parameters(PRNGKey(0), *next_batch())

    opt = optimizers.Adam()
    state = opt.init(params)
    for _ in range(3):
        state = opt.update(loss.apply, state, *next_batch(), jit=True)

    for _ in range(3):
        state, l = opt.update_and_get_loss(loss.apply,
                                           state,
                                           *next_batch(),
                                           jit=True)
        assert () == l.shape

    assert 6 == opt.get_step(state)
    assert 6 == state.step
    assert (4, 4) == opt.get_parameters(state).sequential.dense1.kernel.shape

    out = loss.apply(opt.get_parameters(state), *next_batch())
    assert () == out.shape

    # TODO waiting for https://github.com/google/jax/issues/1278
    # path = Path('/tmp') / 'test.params'
    # save_params(state, path)
    # state = load_params(path)

    assert 6 == opt.get_step(state)
    assert 6 == state.step
    assert (4, 4) == opt.get_parameters(state).sequential.dense1.kernel.shape

    out = loss.apply(opt.get_parameters(state), *next_batch())
    assert () == out.shape
Example #5
def test_pixelcnn():
    loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
    images = np.zeros((2, 16, 16, 3), image_dtype)
    opt = optimizers.Adam()
    state = opt.init(loss.init_parameters(images, key=PRNGKey(0)))
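    # The original snippet ends here; a single training step would presumably
    # follow the same pattern as the WaveNet test in Example #1:
    state, train_loss = opt.update_and_get_loss(loss.apply, state, images, jit=True)
    assert () == train_loss.shape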