def testEncodeDecode(self):
    """Round-trip a linear ramp through mu-law encoding and decoding.

    The first pass checks that the decoded signal stays within a loose
    tolerance of the input; the second pass checks that encoding and
    decoding the already-decoded signal reproduces it.
    """
    quantization_levels = 256
    signal = np.linspace(-1, 1, 1000).astype(np.float32)

    # First pass: the decoded signal should roughly match the input.
    with self.test_session() as sess:
        first_codes = mu_law_encode(signal, quantization_levels)
        first_decode_op = mu_law_decode(first_codes, quantization_levels)
        first_pass = sess.run(first_decode_op)
    self.assertAllClose(signal, first_pass, rtol=1e-1, atol=0.05)

    # Second pass: re-encoding the decoded waveform must leave it invariant.
    with self.test_session() as sess:
        second_codes = mu_law_encode(first_pass, quantization_levels)
        second_decode_op = mu_law_decode(second_codes, quantization_levels)
        second_pass = sess.run(second_decode_op)
    self.assertAllClose(first_pass, second_pass)
def testDecodeZeros(self):
    """Compare graph decoding with the manual reference on an all-zero signal.

    The zeros are encoded with ``manual_mu_law_encode``; the codes are then
    decoded both by ``manual_mu_law_decode`` and by ``mu_law_decode`` run in
    a test session, and the two results must be equal element-wise.
    """
    np.random.seed(40)
    sample_count = 100
    levels = 128

    silence = np.zeros(sample_count)
    codes = manual_mu_law_encode(silence, levels)
    expected = manual_mu_law_decode(codes, levels)

    with self.test_session() as sess:
        decode_op = mu_law_decode(codes, levels)
        actual = sess.run(decode_op)

    self.assertAllEqual(expected, actual)
def testDecodeUniformRandomNoise(self):
    """Compare graph decoding with the manual reference on uniform noise.

    512 values drawn uniformly from [-1, 1) with a fixed seed are encoded
    with ``manual_mu_law_encode`` using 128 channels. Decoding the codes
    with ``mu_law_decode`` must equal ``manual_mu_law_decode`` exactly.
    """
    np.random.seed(40)
    sample_count = 512
    levels = 128

    noise = np.random.uniform(-1, 1, sample_count)
    codes = manual_mu_law_encode(noise, levels)
    expected = manual_mu_law_decode(codes, levels)

    with self.test_session() as sess:
        decode_op = mu_law_decode(codes, levels)
        actual = sess.run(decode_op)

    self.assertAllEqual(expected, actual)
def testDecodeUniformRandomNoise(self):
    """Compare graph decoding with the manual reference on float32 noise.

    Ten float32 values drawn uniformly from [-1, 1) are encoded with
    ``manual_mu_law_encode`` using 256 channels. Decoding the codes with
    ``mu_law_decode`` must equal ``manual_mu_law_decode`` exactly.
    """
    # Fixed seed so the drawn noise is the same on every run.
    np.random.seed(1944)
    sample_count = 10
    levels = 256

    noise = np.random.uniform(-1, 1, sample_count).astype(np.float32)
    codes = manual_mu_law_encode(noise, levels)
    expected = manual_mu_law_decode(codes, levels)

    with self.test_session() as sess:
        decode_op = mu_law_decode(codes, levels)
        actual = sess.run(decode_op)

    self.assertAllEqual(expected, actual)
def testDecodeRamp(self):
    """Compare graph decoding with the manual reference on a linear ramp.

    The ramp runs from -1.0 towards 1.0 in increments of ``2.0 / 512``.
    It is encoded with ``manual_mu_law_encode`` using 128 channels, and the
    output of ``mu_law_decode`` must equal ``manual_mu_law_decode`` exactly.
    """
    np.random.seed(40)
    sample_count = 512
    levels = 128

    # Spacing between consecutive ramp values.
    increment = 2.0 / sample_count
    ramp = np.arange(-1.0, 1.0, increment)

    codes = manual_mu_law_encode(ramp, levels)
    expected = manual_mu_law_decode(codes, levels)

    with self.test_session() as sess:
        decode_op = mu_law_decode(codes, levels)
        actual = sess.run(decode_op)

    self.assertAllEqual(expected, actual)
def testDecodeRandomConstant(self):
    """Compare graph decoding with the manual reference on a constant signal.

    A single value is drawn uniformly from [-1, 1) with a fixed seed and
    repeated 512 times. The signal is encoded with ``manual_mu_law_encode``
    using 128 channels, and the output of ``mu_law_decode`` must equal
    ``manual_mu_law_decode`` exactly.
    """
    np.random.seed(40)
    sample_count = 512
    levels = 128

    # One random level, held for the whole signal (float64 like np.zeros).
    level = np.random.uniform(-1, 1)
    constant_signal = np.full(sample_count, level, dtype=np.float64)

    codes = manual_mu_law_encode(constant_signal, levels)
    expected = manual_mu_law_decode(codes, levels)

    with self.test_session() as sess:
        decode_op = mu_law_decode(codes, levels)
        actual = sess.run(decode_op)

    self.assertAllEqual(expected, actual)
def main():
    """Generate a waveform sample by sample from a restored WaveNet model.

    Reads options via ``get_arguments()``, loads the network parameters from
    the JSON file named by ``args.wavenet_params``, builds the network and
    restores ``args.checkpoint``. The waveform starts either from a seed
    built by ``create_seed`` (when ``args.wav_seed`` is set) or from one
    random quantized value, and ``args.samples`` new values are then drawn
    from the predicted distribution one at a time.

    While generating, the decoded waveform is written to
    ``args.wav_out_path`` every ``args.save_every`` steps when both options
    are set. After generation an audio summary is written below the log
    directory and, if ``args.wav_out_path`` is set, the final wav file too.
    """
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        fast_generation=args.fast_generation)

    samples = tf.placeholder(tf.int32)
    next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # Restore everything except the generation-only state buffers and
    # pointers. Keys are the variable names minus their last two characters
    # (presumably the ':0' output suffix -- confirm against the checkpoint).
    variables_to_restore = {}
    for var in tf.all_variables():
        if 'state_buffer' not in var.name and 'pointer' not in var.name:
            variables_to_restore[var.name[:-2]] = var
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    quantization_channels = wavenet_params['quantization_channels']
    decode = mu_law_decode(samples, quantization_channels)

    # Initial waveform: a user-provided seed, or one random quantized value.
    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        initial = np.random.randint(quantization_channels, size=(1, ))
        waveform = initial.tolist()

    for step in range(args.samples):
        step_number = step + 1

        outputs = [next_sample]
        if args.fast_generation:
            # Only the newest sample is fed; the push ops are run alongside
            # the prediction.
            window = waveform[-1]
            outputs.extend(net.push_ops)
        elif len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        results = sess.run(outputs, feed_dict={samples: window})
        prediction = results[0]

        # Draw the next quantized value from the predicted distribution.
        sample = np.random.choice(
            np.arange(quantization_channels), p=prediction)
        waveform.append(sample)
        print('Sample {:3<d}/{:3<d}: {}'.format(step_number, args.samples,
                                                sample))

        # Periodically write the waveform generated so far.
        if (args.wav_out_path and args.save_every and
                step_number % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Save the result as an audio summary.
    datestring = str(datetime.now()).replace(' ', 'T')
    summary_dir = os.path.join(logdir, 'generation', datestring)
    writer = tf.train.SummaryWriter(summary_dir)
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()

    column_waveform = np.reshape(waveform, [-1, 1])
    summary_out = sess.run(summaries, feed_dict={samples: column_waveform})
    writer.add_summary(summary_out)

    # Save the result as a wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def testDecodeNegativeDilation(self):
    """Expect ``mu_law_decode`` to raise ``TypeError`` for these inputs.

    The encoded values include numbers far above ``channels`` (10).

    NOTE(review): the previous form, ``assertRaises(TypeError, sess.run(...))``,
    passed the *result* of the call rather than a callable. An exception
    raised by the call escaped before ``assertRaises`` ran, and if the call
    succeeded ``assertRaises`` tried to call the returned value, which
    itself raises ``TypeError`` and made the test pass vacuously. With the
    context-manager form the test only passes when building or running the
    decode op really raises ``TypeError`` -- confirm that this is the
    behaviour ``mu_law_decode`` is meant to have.
    """
    channels = 10
    y = [0, 255, 243, 31, 156, 229, 0, 235, 202, 18]

    with self.test_session() as sess:
        # Both graph construction and execution sit inside the assertion
        # so a TypeError from either stage is caught.
        with self.assertRaises(TypeError):
            sess.run(mu_law_decode(y, channels))