def test():
    """Smoke-test WaveNet training on synthetic sine-wave audio.

    Generates audio with ``make_sine_waves``, left-pads it with
    ``receptive_field - 1`` zeros, mu-law encodes it into 256 levels and
    one-hot encodes the result. The network is trained for 301 steps to map
    the input sequence (every time step except the last) to the target
    sequence (the encoded signal from index ``receptive_field`` onward),
    printing the mean categorical cross-entropy every 100 steps.
    """
    np.random.seed(42)
    quantization_channels = 2**8
    audio, _ = make_sine_waves(None)  # the second return value is unused here
    dilations = [2**i for i in range(7)] * 2
    receptive_field = WaveNet.calculate_receptive_field(2, dilations)
    # Zero-pad at the front only; presumably so the earliest targets still
    # see a full receptive field of (padded) history.
    audio = np.pad(audio, (receptive_field - 1, 0),
                   'constant').astype(np.float32)

    encoded = mu_law_encode(audio, quantization_channels)
    encoded = encoded[np.newaxis, :]  # add a leading (batch) axis
    encoded_one_hot = one_hot(encoded, quantization_channels)

    signal_length = int(tf.shape(encoded_one_hot)[1] - 1)
    # Input: every time step except the last one.
    input_one_hot = tf.slice(encoded_one_hot, [0, 0, 0],
                             [-1, signal_length, -1])
    # Target: the encoded signal from index `receptive_field` to the end.
    target_one_hot = tf.slice(encoded_one_hot, [0, receptive_field, 0],
                              [-1, -1, -1])
    print('input shape: ', tf.shape(input_one_hot))
    print('output shape: ', tf.shape(target_one_hot))
    # Targets never change between steps, so flatten them once here rather
    # than inside the training loop: [b, T_out, C] => [b * T_out, C].
    target_flat = tf.reshape(target_one_hot, [-1, quantization_channels])

    net = WaveNet(1, dilations, 2, signal_length, 32, 32, 32,
                  quantization_channels, True, 0.01)
    net.build(input_shape=(None, signal_length, quantization_channels))
    # `learning_rate` is the supported keyword; `lr` is deprecated in
    # tf.keras and no longer accepted by Keras 3.
    optimizer = Adam(learning_rate=1e-3)

    for epoch in range(301):
        with tf.GradientTape() as tape:
            # The loss below needs the network output to have T_out time
            # steps (len(input) - receptive_field + 1, the target length);
            # NOTE(review): confirm WaveNet's output length matches.
            logits = net(input_one_hot, training=True)
            # [b, T_out, C] => [b * T_out, C]
            logits = tf.reshape(logits, [-1, quantization_channels])
            # Compute the mean cross-entropy over all flattened time steps.
            loss = tf.reduce_mean(
                tf.losses.categorical_crossentropy(target_flat, logits,
                                                   from_logits=True))

        grads = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(grads, net.trainable_variables))
        if epoch % 100 == 0:
            print(epoch, 'loss: ', float(loss))
Beispiel #2
0
		test_list = f.readlines()
		num_test = len(f.readlines())

	# Enqueue jobs
	for i in range(num_train):
		tasks.put(Task(train_list[i], DATA_SAVE_DIR, 'train', i))
	
	for i in range(num_test):
		tasks.put(Task(test_list[i], DATA_SAVE_DIR, 'test', i))

	# Add a poison pill for each consumer
	for i in range(num_consumers):
		tasks.put(None)
	
	wvn = WaveNet(input_dim=256+406+2, dilations=[1,2,4,8,16,32,64,128,256,512], filter_width=2)
	wvn.build()
	wvn.compile()
	wvn.plot()
	wvn.add_callbacks(os.path.join(CKPT_PATH,'weights.epoch001.{epoch:02d}.hdf5'), None)

	# Start 1st epoch training
	num_jobs = num_train
	train_files = []
	train_times = []
	while num_jobs:
		f = results_tr.get()
		train_files.append(f)
		# Model training
		start = T()
		wvn.fit_on_file(f)
		end = T()