def main():
    """Restore a trained model from ``logdir`` and run it on one test file.

    The function parses the command-line arguments, rebuilds the multi-tower
    training graph (one tower per ``args.num_gpus``), restores the trainable
    variables from the latest checkpoint found in ``logdir``, feeds the
    magnitude spectrogram of the wav file at ``args.test_dir`` into tower 0
    and passes the two network outputs to ``mk_audio``.

    The STFT parameters ``fs``, ``framesz`` and ``hop`` are not defined in
    this function and are read from the enclosing module scope.

    Returns:
        None. Returns early if argument validation fails; exits the process
        via ``sys.exit`` if no checkpoint is found.
    """
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    print("logdir", logdir)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        if args.silence_threshold > EPSILON:
            silence_threshold = args.silence_threshold
        else:
            silence_threshold = None
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(args.data_dir,
                             args.test_dir,
                             coord,
                             sample_rate=args.sample_rate,
                             gc_enabled=gc_enabled,
                             sample_size=args.sample_size,
                             silence_threshold=silence_threshold)
        # The dequeued batch is not consumed below; the call is kept so the
        # graph is built exactly as before.
        reader.dequeue(args.batch_size)

    net = create_model(args)

    # ---------------------------- Multi GPU ----------------------------
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # Create optimizer (default is Adam).
    optim = optimizer_factory[args.optimizer](
        learning_rate=args.learning_rate, momentum=args.momentum)

    def _zero_input():
        # Non-trainable buffer that is overwritten through feed_dict.
        return tf.Variable(
            tf.zeros([net.batch_size, net.seq_len,
                      args.num_of_frequency_points]),
            trainable=False,
            name="speech_batch_inputs",
            dtype=tf.float32)

    tower_grads = []
    tower_outputs = []
    speech_inputs_mix = []
    speech_inputs_1 = []
    speech_inputs_2 = []
    # `range` works on both Python 2 and 3 and matches train().
    for _ in range(args.num_gpus):
        # Creation order (2, 1, mix) is kept from the original code.
        speech_inputs_2.append(_zero_input())
        speech_inputs_1.append(_zero_input())
        speech_inputs_mix.append(_zero_input())

    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(args.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('TOWER_%d' % (i)):
                    # Create model.
                    print("Creating model On Gpu:%d." % (i))
                    loss, _, output1, output2 = net.loss_SampleRnn(
                        speech_inputs_1[i],
                        speech_inputs_2[i],
                        speech_inputs_mix[i],
                        l2_regularization_strength=args.
                        l2_regularization_strength)

                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()

                    trainable = tf.trainable_variables()
                    for name in trainable:
                        print(name)

                    gradients = optim.compute_gradients(loss, trainable)
                    print("==========================")
                    for name in gradients:
                        print(name)

                    # Keep track of the gradients and outputs of every tower.
                    tower_grads.append(gradients)
                    tower_outputs.append((output1, output2))

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grad_vars = average_gradients(tower_grads)

    # Clip by global norm; `variables` avoids shadowing the builtin `vars`.
    grads, variables = zip(*grad_vars)
    grads_clipped, _ = tf.clip_by_global_norm(grads, 5.0)
    grad_vars = list(zip(grads_clipped, variables))

    # Apply the gradients to adjust the shared variables. The op is never run
    # here; it is still built so the graph matches the original.
    optim.apply_gradients(grad_vars, global_step=global_step)

    # Set up session.
    tf_config = tf.ConfigProto(
        # allow_soft_placement is set to True to build towers on GPU
        allow_soft_placement=True,
        log_device_placement=False,
        inter_op_parallelism_threads=1)
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)

    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt:
        print(" Checkpoint found: {}".format(ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(" Done.")
    else:
        print(" No checkpoint found.")
        sys.exit(" Your model seems to be invalid. ")

    # The input fed below is speech_inputs_mix[0], so the fetched outputs must
    # come from tower 0 as well. The original fetched the outputs of the LAST
    # tower, which ignores the fed data when num_gpus > 1.
    output1, output2 = tower_outputs[0]

    try:
        inp_dict = {}
        global audio_test
        _, audio_test = scipy.io.wavfile.read(args.test_dir)
        _, X_hlf_test = stft(audio_test, fs, framesz, hop)
        # np.abs replaces the deprecated scipy.absolute alias.
        amplitude_test = np.abs(X_hlf_test)
        angle_test = np.angle(X_hlf_test)

        # Add a leading batch axis of size 1.
        ne = np.reshape(
            amplitude_test,
            (1, amplitude_test.shape[0], amplitude_test.shape[1]))
        # Keep only the last 256 entries along axis 1.
        # NOTE(review): 256 presumably has to match net.seq_len - confirm.
        fd = ne[:, -256:, :]
        angle = angle_test[-256:, :]
        print(fd.shape)

        inp_dict[speech_inputs_mix[0]] = fd
        outp1, outp2 = sess.run([output1, output2], feed_dict=inp_dict)
        mk_audio(outp1, angle, "_1_test_")
        mk_audio(outp2, angle, "_2_test_")
    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
def train(directories, args):
    """Run the training loop with logging, test renders and checkpoints.

    Args:
        directories: mapping that provides the ``'logdir'`` and
            ``'restore_from'`` entries.
        args: all parsed command-line arguments.

    Side effects:
        Writes TensorBoard summaries and checkpoints to ``logdir``; every
        2000 steps calls ``mk_audio`` for both network outputs and for the
        raw test input; logs the loss via ``logging.warning``.

    Raises:
        Re-raises any exception raised while restoring the checkpoint, so
        an existing model is not accidentally overwritten.
    """
    logdir = directories['logdir']
    print("logdir:", logdir)
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into a location that's different from
    # logdir. (False means both directories are the same.)
    is_overwritten_training = logdir != restore_from

    # Coordinator that manages the input threads.
    coord = tf.train.Coordinator()

    # Create inputs.
    gc_enabled = args.gc_channels is not None
    reader = AudioReader(args.data_dir,
                         args.test_dir,
                         coord,
                         sample_rate=args.sample_rate,
                         gc_enabled=gc_enabled)
    audio_batch = reader.dequeue(args.batch_size)

    # Initialize model.
    net = SpeechSeparation(
        batch_size=args.batch_size,
        rnn_type=args.rnn_type,
        dim=args.dim,
        n_rnn=args.n_rnn,
        seq_len=args.seq_len,
        num_of_frequency_points=args.num_of_frequency_points)

    # Original author's note: output1/2 and the three input tensors have
    # shape (1, 256, 257) - not verifiable from this function.
    summary, output1, output2, losses, apply_gradient_op = net.initializer(
        net, args)
    speech_inputs_1 = net.speech_inputs_1
    speech_inputs_2 = net.speech_inputs_2
    speech_inputs_mix = net.speech_inputs_mix

    # Set up session.
    tf_config = tf.ConfigProto(
        # allow_soft_placement is set to True to build towers on GPU
        allow_soft_placement=True,
        log_device_placement=False,
        inter_op_parallelism_threads=1)
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    sess.run(tf.global_variables_initializer())

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    try:
        # `load` yields None when nothing was restored (e.g. first run).
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except Exception:
        print(
            "Something went wrong while restoring checkpoint. We will terminate "
            "training to avoid accidentally overwriting the previous model.")
        raise

    # ------------------------- Start training -------------------------
    last_saved_step = saved_global_step
    # Bind `step` up front: if the loop below executes zero iterations the
    # `finally` clause would otherwise hit an unbound name.
    step = last_saved_step

    # Loss accumulated since the last log line, so the periodic message
    # reports a real mean instead of one step's loss divided by 100.
    window_loss = 0.0
    window_steps = 0

    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()

            # One dequeued batch per GPU.
            inputslist = [
                sess.run(audio_batch) for _ in range(args.num_gpus)
            ]
            inp_dict = create_inputdict(inputslist, args, speech_inputs_1,
                                        speech_inputs_2, speech_inputs_mix)

            # feed_dict maps the graph inputs to the actual values.
            summ, loss_value, _ = sess.run(
                [summary, losses, apply_gradient_op], feed_dict=inp_dict)

            # Mean loss over all towers for this step.
            step_loss = sum(loss_value[g] / args.num_gpus
                            for g in range(args.num_gpus))
            window_loss += step_loss
            window_steps += 1

            writer.add_summary(summ, step)
            duration = time.time() - start_time

            # Log every step for the first 100 steps, then every 100th step.
            if step < 100 or step % 100 == 0:
                log_str = (
                    'step {%d} - loss = {%0.3f}, ({%0.3f} sec/step') % (
                        step, window_loss / window_steps, duration)
                logging.warning(log_str)
                window_loss = 0.0
                window_steps = 0

            if step % 2000 == 0:
                # With test=True, create_inputdict also returns the phase
                # used to rebuild the waveforms.
                angle_test, inp_dict = create_inputdict(inputslist,
                                                        args,
                                                        speech_inputs_1,
                                                        speech_inputs_2,
                                                        speech_inputs_mix,
                                                        test=True)
                outp1, outp2 = sess.run([output1, output2],
                                        feed_dict=inp_dict)
                x_r = mk_audio(outp1, angle_test, args.sample_rate,
                               "spk1_test_" + str(step) + ".wav")
                y_r = mk_audio(outp2, angle_test, args.sample_rate,
                               "spk2_test_" + str(step) + ".wav")

                # Also render the raw test input of the first batch.
                amplitude_test = inputslist[0][2]
                angle_test = inputslist[0][3]
                mk_audio(amplitude_test, angle_test, args.sample_rate,
                         "raw_test_" + str(step) + ".wav")

                # Audio summary on TensorBoard.
                # NOTE(review): these summary ops are created inside the
                # loop, so the graph presumably grows on every test render.
                merged = sess.run(
                    tf.summary.merge([
                        tf.summary.audio('speaker1_' + str(step),
                                         x_r[None, :],
                                         args.sample_rate,
                                         max_outputs=1),
                        tf.summary.audio('speaker2_' + str(step),
                                         y_r[None, :],
                                         args.sample_rate,
                                         max_outputs=1)
                    ]))
                writer.add_summary(merged, step)

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message is on
        # its own line.
        print()
    finally:
        # Persist any progress made since the last checkpoint.
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)