def main():
    """Build the audio input pipeline used for WaveNet training.

    Reads and validates the command-line arguments, loads the WaveNet
    hyper-parameters from the JSON file named by ``args.wavenet_params`` and
    creates an ``AudioReader`` together with the dequeue op for the audio
    batch. When ``args.gc_channels`` is set, it also creates the dequeue op for
    the global-conditioning id batch; otherwise that batch is ``None``.

    Prints a message and returns early if ``validate_directories`` rejects
    the arguments with a ``ValueError``.
    """
    args = get_arguments()
    try:
        directories = validate_directories(args)
    except ValueError as err:
        print("Some arguments are wrong:")
        print(str(err))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Writing the trained model somewhere other than the restore location is
    # treated as a fresh training run, even if a checkpoint was restored.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as params_file:
        wavenet_params = json.load(params_file)

    coord = tf.train.Coordinator()

    # Input pipeline: raw waveforms read from the corpus directory.
    with tf.name_scope('create_inputs'):
        # A threshold at or below EPSILON disables silence trimming.
        if args.silence_threshold > EPSILON:
            silence_threshold = args.silence_threshold
        else:
            silence_threshold = None
        gc_enabled = args.gc_channels is not None

        # Constructor arguments are bound in the same order the original
        # call evaluated them.
        data_dir = args.data_dir
        sample_rate = wavenet_params['sample_rate']
        receptive_field = WaveNetModel.calculate_receptive_field(
            wavenet_params['filter_width'],
            wavenet_params['dilations'],
            wavenet_params['scalar_input'],
            wavenet_params['initial_filter_width'])
        reader = AudioReader(
            data_dir,
            coord,
            sample_rate=sample_rate,
            gc_enabled=gc_enabled,
            receptive_field=receptive_field,
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)

        audio_batch = reader.dequeue(args.batch_size)
        gc_id_batch = (reader.dequeue_gc(args.batch_size)
                       if gc_enabled else None)
    # NOTE(review): audio_batch / gc_id_batch are not used after this point
    # in the visible code; this definition looks truncated — confirm.
def main():
    """Train a WaveNet model on the audio corpus given on the command line.

    Reads and validates the command-line arguments, loads the WaveNet
    hyper-parameters from ``args.wavenet_params`` (JSON), and builds the audio
    input pipeline, the network, the Adam optimizer and a TensorBoard summary
    writer. It then optionally restores a checkpoint from ``restore_from`` and
    runs training steps up to ``args.num_steps``.

    A checkpoint is saved every ``checkpoint_every`` steps. A final checkpoint
    is saved on exit, including Ctrl-C or another exception, when progress has
    not been saved yet. Prints a message and returns early if
    ``validate_directories`` rejects the arguments with a ``ValueError``.
    Errors raised while restoring a checkpoint are reported and re-raised.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        reader = AudioReader(args.data_dir, coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNet(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session.
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except Exception:
        # Re-raised below. Catching Exception rather than using a bare except
        # keeps Ctrl-C from printing a misleading restore-failure message.
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    # Interval (in steps) between checkpoints and between traced runs.
    checkpoint_every = 50
    metadata_every = 50

    # Bound before the loop so the ``finally`` clause can tell whether any
    # step ran. Previously an empty step range (restored step >= num_steps)
    # or an interrupt before the first iteration left ``step`` unbound.
    # That raised UnboundLocalError/NameError there and masked the real exit
    # reason.
    step = None
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % metadata_every == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run(
                    [summaries, loss, optim],
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                # Each traced step overwrites the same trace file.
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        # Persist any progress made since the last periodic checkpoint.
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
#AUDIO_FILE_PATH = '/Users/andrewszot/Downloads/VCTK-Corpus' gc_enabled = False reader = AudioReader( AUDIO_FILE_PATH, coord, sample_rate=wavenet_params['sample_rate'], gc_enabled=gc_enabled, receptive_field=calculate_receptive_field(wavenet_params["filter_width"], wavenet_params["dilations"], wavenet_params["scalar_input"], wavenet_params["initial_filter_width"]), sample_size=39939, silence_threshold=silence_threshold) audio_batch = reader.dequeue(1) if gc_enabled: gc_id_batch = reader.dequeue_gc(1) else: gc_id_batch = None global_step = tf.Variable(0, trainable=False) sess = tf.Session(config=tf.ConfigProto(log_device_placement=False)) threads = tf.train.start_queue_runners(sess=sess, coord=coord) reader.start_threads(sess) saver = tf.train.Saver() try: # 100K