# Imports assume the layout of the jaxnet repository, where the WaveNet model
# and its helpers live in examples/wavenet.py; adjust the paths to your setup.
from jax import numpy as np, random
from jax.nn.initializers import ones
from jax.random import PRNGKey

from jaxnet import (Conv, Dense, L2Regularized, Sequential, flatten, optimizers,
                    parameter, parametrized, relu)

from examples.wavenet import (Wavenet, calculate_receptive_field,
                              discretized_mix_logistic_loss)


def test_wavenet():
    filter_width = 2
    initial_filter_width = 3
    residual_channels = 4
    dilation_channels = 5
    skip_channels = 6
    dilations = [1, 2]
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)
    batch = random.normal(PRNGKey(0), (1, receptive_field + 1000, 1))
    output_width = batch.shape[1] - receptive_field + 1
    wavenet = Wavenet(dilations, filter_width, initial_filter_width, output_width,
                      residual_channels, dilation_channels, skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # Slice the padding off the batch so targets align with the predictions.
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(optimizers.exponential_decay(1e-3, decay_steps=1,
                                                       decay_rate=0.999995))
    state = opt.init(loss.init_parameters(batch, key=PRNGKey(0)))
    state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, jit=True)
    trained_params = opt.get_parameters(state)

    assert () == train_loss.shape
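# A minimal sketch of the receptive-field computation the test above relies on,
# assuming the standard WaveNet formula; the real calculate_receptive_field is
# imported from the wavenet example, so this stand-in is for illustration only.
def receptive_field_sketch(filter_width, dilations, initial_filter_width):
    # Each dilated layer widens the context by (filter_width - 1) * dilation;
    # the initial causal convolution contributes its full filter width.
    return (filter_width - 1) * sum(dilations) + initial_filter_width


# For the settings in test_wavenet this gives (2 - 1) * (1 + 2) + 3 = 6 input
# samples of context per output sample.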
def test_regularized_submodule():
    net = Sequential(Conv(2, (1, 1)), relu, Conv(2, (1, 1)), relu, flatten,
                     L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), .1))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(input, key=PRNGKey(0))
    # The regularized submodule's parameters are nested under regularized.model;
    # the second Dense(2) maps 2 features to 2, hence a (2, 2) kernel.
    assert (2, 2) == params.regularized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert () == out.shape
def main():
    filter_width = 2
    initial_filter_width = 32
    residual_channels = 32
    dilation_channels = 32
    skip_channels = 512
    # Five stacks of dilated convolutions with exponentially growing dilation:
    dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] * 5
    nr_mix = 10
    receptive_field = calculate_receptive_field(filter_width, dilations,
                                                initial_filter_width)

    def get_batches(batches=100, sequence_length=1000, rng=PRNGKey(0)):
        # Synthetic stand-in data: one random sequence per batch.
        for _ in range(batches):
            rng, rng_now = random.split(rng)
            yield random.normal(rng_now, (1, receptive_field + sequence_length, 1))

    batches = get_batches()
    init_batch = next(batches)
    output_width = init_batch.shape[1] - receptive_field + 1

    wavenet = Wavenet(dilations, filter_width, initial_filter_width, output_width,
                      residual_channels, dilation_channels, skip_channels, nr_mix)

    @parametrized
    def loss(batch):
        theta = wavenet(batch)[:, :-1, :]
        # Slice the padding off the batch so targets align with the predictions.
        sliced_batch = batch[:, receptive_field:, :]
        return (np.mean(discretized_mix_logistic_loss(
            theta, sliced_batch, num_class=1 << 16), axis=0)
                * np.log2(np.e) / (output_width - 1))

    loss = L2Regularized(loss, .01)

    opt = optimizers.Adam(optimizers.exponential_decay(1e-3, decay_steps=1,
                                                       decay_rate=0.999995))
    print('Initializing parameters.')
    state = opt.init(loss.init_parameters(next(batches), key=PRNGKey(0)))
    for batch in batches:
        print(f'Training on batch {opt.get_step(state)}.')
        state, train_loss = opt.update_and_get_loss(loss.apply, state, batch, jit=True)

    trained_params = opt.get_parameters(state)
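# Allow the example to be run directly as a script:
if __name__ == '__main__':
    main()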
def test_L2Regularized_sequential():
    loss = Sequential(Dense(1, ones, ones), relu, Dense(1, ones, ones), sum)
    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.ones(1)
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert np.array_equal(np.ones((1, 1)), params.model.dense0.kernel)
    assert np.array_equal(np.ones((1, 1)), params.model.dense1.kernel)

    reg_loss_out = reg_loss.apply(params, inputs)
    # loss = relu(1 * 1 + 1) * 1 + 1 = 3; penalty = scale / 2 * (1 + 1 + 1 + 1) = 4.
    assert 7 == reg_loss_out
def test_L2Regularized():
    @parametrized
    def loss(inputs):
        a = parameter((), ones, inputs, 'a')
        b = parameter((), lambda rng, shape: 2 * np.ones(shape), inputs, 'b')
        return a + b

    reg_loss = L2Regularized(loss, scale=2)

    inputs = np.zeros(())
    params = reg_loss.init_parameters(inputs, key=PRNGKey(0))
    assert np.array_equal(np.ones(()), params.model.a)
    assert np.array_equal(2 * np.ones(()), params.model.b)

    reg_loss_out = reg_loss.apply(params, inputs)
    # loss = a + b = 1 + 2; penalty = scale / 2 * (a ** 2 + b ** 2) = 1 + 4.
    assert 1 + 2 + 1 + 4 == reg_loss_out
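# A minimal sketch of the regularization convention the two tests above pin
# down, assuming L2Regularized adds scale / 2 times the sum of squared
# parameters to the wrapped loss. l2_penalty is a hypothetical helper for
# illustration, not part of the jaxnet API:
from jax.tree_util import tree_leaves


def l2_penalty(params, scale):
    # Sum the squared entries of every parameter leaf in the nested params tuple.
    return scale / 2 * sum(np.sum(leaf ** 2) for leaf in tree_leaves(params))


# Consistency check against the expected values above:
# test_L2Regularized_sequential: loss 3 + penalty 2 / 2 * 4 = 7.
# test_L2Regularized: loss (1 + 2) + penalty 2 / 2 * (1 + 4) = 8.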