def main():
    """Generate audio from a trained WaveNet checkpoint.

    Restores the model named by ``args.checkpoint``, autoregressively samples
    ``args.samples`` values, optionally writing intermediate/final audio to
    ``args.wav_out_path`` and an audio summary for TensorBoard.
    """
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)
    next_sample = net.predict_proba(samples)

    saver = tf.train.Saver()
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = net.decode(samples)

    # NOTE(review): the network above is constructed from
    # 'quantization_channels' but the sampling range below reads
    # 'quantization_steps' — verify both keys exist in the params JSON
    # and hold the same value.
    quantization_steps = wavenet_params['quantization_steps']
    # Seed the waveform with a single random quantized sample.
    waveform = np.random.randint(quantization_steps, size=(1,)).tolist()

    for step in range(args.samples):
        # Condition on at most the last `args.window` samples.
        if len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Predict a distribution over the next sample and draw from it.
        prediction = sess.run(next_sample, feed_dict={samples: window})
        sample = np.random.choice(np.arange(quantization_steps), p=prediction)
        waveform.append(sample)
        # Fixed format spec: '{:3<d}' meant fill='3', align='<' with no
        # width (i.e. no padding at all); the intent is left-align, width 3.
        print('Sample {:<3d}/{:<3d}: {}'
              .format(step + 1, args.samples, sample))

        # Periodically flush the in-progress waveform to disk.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Emit an audio summary so the result can be played back in TensorBoard.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the final generated waveform as a .wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def main():
    """Generate audio from a trained WaveNet checkpoint.

    Supports the fast incremental-generation path (``args.fast_generation``,
    feeding one sample per step via the net's push/init ops) and optional
    seeding from a .wav file (``args.wav_seed``). Writes audio to
    ``args.wav_out_path`` and an audio summary for TensorBoard.

    NOTE(review): this redefines ``main`` — an earlier ``main`` definition in
    this file is shadowed by this one; the earlier copy is dead code.
    """
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        fast_generation=args.fast_generation)

    samples = tf.placeholder(tf.int32)
    next_sample = net.predict_proba(samples)

    if args.fast_generation:
        sess.run(tf.initialize_all_variables())
        sess.run(net.init_ops)

    # The fast-generation graph adds state buffers/pointers that are not in
    # the checkpoint; restore everything except those.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.all_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    quantization_channels = wavenet_params['quantization_channels']
    if args.wav_seed:
        # Start from audio loaded off disk instead of a random sample.
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels)
        waveform = sess.run(seed).tolist()
    else:
        waveform = np.random.randint(quantization_channels,
                                     size=(1,)).tolist()

    for step in range(args.samples):
        if args.fast_generation:
            # Incremental path: feed only the newest sample; the push ops
            # advance the net's internal convolution state buffers.
            window = waveform[-1]
            outputs = [next_sample]
            outputs.extend(net.push_ops)
        else:
            # Full path: condition on at most the last `args.window` samples.
            if len(waveform) > args.window:
                window = waveform[-args.window:]
            else:
                window = waveform
            outputs = [next_sample]

        # Predict a distribution over the next sample and draw from it.
        prediction = sess.run(outputs, feed_dict={samples: window})[0]
        sample = np.random.choice(np.arange(quantization_channels),
                                  p=prediction)
        waveform.append(sample)
        # Fixed format spec: '{:3<d}' meant fill='3', align='<' with no
        # width (i.e. no padding at all); the intent is left-align, width 3.
        print('Sample {:<3d}/{:<3d}: {}'
              .format(step + 1, args.samples, sample))

        # Periodically flush the in-progress waveform to disk.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Emit an audio summary so the result can be played back in TensorBoard.
    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    # Save the final generated waveform as a .wav file.
    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')