Example #1
def test_wavenet():
    filter_width = 2
    initial_filter_width = 3
    residual_channels = 4
    dilation_channels = 5
    skip_channels = 6
    dilations = [1, 2]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    batch = random.normal(PRNGKey(0), (1, receptive_field + 1000, 1))
    output_width = batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(optimizers.exponential_decay(1e-3, decay_steps=1, decay_rate=0.999995))
    state = opt.init(loss.init_parameters(batch, key=PRNGKey(0)))
    state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, jit=True)
    trained_params = opt.get_parameters(state)
    assert () == train_loss.shape
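This test (together with the training script in Example #3) appears to come from the jaxnet repository's WaveNet example. A minimal sketch of the imports that would make the snippet self-contained; the exact module paths, in particular examples.wavenet, are an assumption rather than something stated here:

import jax.numpy as np
from jax import random
from jax.random import PRNGKey
from jaxnet import L2Regularized, optimizers, parametrized
# Wavenet, calculate_receptive_field and discretized_mix_logistic_loss are
# assumed to live in the repository's WaveNet example module.
from examples.wavenet import (Wavenet, calculate_receptive_field,
                              discretized_mix_logistic_loss)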
Example #2
def test_regularized_submodule():
    net = Sequential(Conv(2, (1, 1)), relu, Conv(2, (1, 1)), relu, flatten,
                     L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), .1))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(input, key=PRNGKey(0))
    assert (2, 2) == params.regularized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert () == out.shape
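A short sketch of the shape bookkeeping behind the two asserts (an illustration inferred from the layer definitions, not library code):

# (1, 3, 3, 1) -> Conv(2, (1, 1)) -> (1, 3, 3, 2) -> relu
#              -> Conv(2, (1, 1)) -> (1, 3, 3, 2) -> relu -> flatten -> (1, 18)
# Inside the regularized submodule: Dense(2) -> relu -> Dense(2) -> np.sum.
# dense1 is the second Dense, mapping 2 features to 2, hence its (2, 2) kernel;
# np.sum reduces the result to a scalar, so the network's output shape is ().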
Example #3
def main():
    filter_width = 2
    initial_filter_width = 32
    residual_channels = 32
    dilation_channels = 32
    skip_channels = 512
    dilations = [
        1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128,
        256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1, 2, 4, 8, 16, 32,
        64, 128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512
    ]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    def get_batches(batches=100, sequence_length=1000, rng=PRNGKey(0)):
        for _ in range(batches):
            rng, rng_now = random.split(rng)
            yield random.normal(rng_now,
                                (1, receptive_field + sequence_length, 1))

    batches = get_batches()
    init_batch = next(batches)
    output_width = init_batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width,
                      output_width, residual_channels, dilation_channels,
                      skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # now slice the padding off the batch
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16),
                        axis=0) * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(
        optimizers.exponential_decay(1e-3, decay_steps=1, decay_rate=0.999995))
    print('Initializing parameters.')
    state = opt.init(loss.init_parameters(PRNGKey(0), next(batches)))
    for batch in batches:
        print(f'Training on batch {opt.get_step(state)}.')
        state, train_loss = opt.update_and_get_loss(loss.apply,
                                                    state,
                                                    batch,
                                                    jit=True)

    trained_params = opt.get_parameters(state)
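trained_params is left unused above; a hedged follow-up that could be appended inside main() would evaluate the regularized loss on a fresh batch with the same apply call the optimizer drives during training (the extra PRNGKey(1) batch is illustrative, not from the original):

    # hypothetical evaluation step at the end of main()
    eval_batch = next(get_batches(batches=1, rng=PRNGKey(1)))
    eval_loss = loss.apply(trained_params, eval_batch)
    print(f'Evaluation loss: {eval_loss}.')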
Example #4
def test_L2Regularized_sequential():
    loss = Sequential(Dense(1, ones, ones), relu, Dense(1, ones, ones), sum)

    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.ones(1)
    params = reg_loss.init_parameters(PRNGKey(0), inputs)
    assert np.array_equal(np.ones((1, 1)), params.model.dense0.kernel)
    assert np.array_equal(np.ones((1, 1)), params.model.dense1.kernel)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 7 == reg_loss_out
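The asserted value 7 is consistent with a quadratic penalty of 0.5 * scale * sum of squared parameters: the forward pass gives relu(1*1 + 1) = 2, then 2*1 + 1 = 3, and the four unit-valued parameters (two kernels, two biases) add 0.5 * 2 * (1 + 1 + 1 + 1) = 4. A minimal sketch reproducing that arithmetic (the penalty formula is inferred from the asserts here and in Example #5, not taken from the library's source):

def sketched_l2_value(data_loss, param_values, scale):
    # assumed penalty: 0.5 * scale * sum of squared parameter values
    return data_loss + .5 * scale * sum(p ** 2 for p in param_values)

assert 7 == sketched_l2_value(3., [1., 1., 1., 1.], scale=2)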
Example #5
def test_L2Regularized():
    @parametrized
    def loss(inputs):
        a = parameter((), ones, inputs, 'a')
        b = parameter((), lambda rng, shape: 2 * np.ones(shape), inputs, 'b')

        return a + b

    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.zeros(())
    params = reg_loss.init_parameters(PRNGKey(0), inputs)
    assert np.array_equal(np.ones(()), params.model.a)
    assert np.array_equal(2 * np.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)

    assert 1 + 2 + 1 + 4 == reg_loss_out
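The same inferred penalty accounts for the final assert: the model term is a + b = 1 + 2 = 3 and the penalty is 0.5 * 2 * (1**2 + 2**2) = 5, i.e. 1 + 2 + 1 + 4 = 8. Reusing the sketch from Example #4 (same hedged assumption about the penalty):

assert 8 == sketched_l2_value(3., [1., 2.], scale=2)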