Example #1
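A multi-GPU TensorFlow 1.x script for a two-speaker speech-separation model: `main()` restores a checkpoint and runs inference on a test file, while `train()` runs the training loop. The excerpt assumes that helpers such as `AudioReader`, `SpeechSeparation`, `optimizer_factory`, `average_gradients`, `create_inputdict`, `stft`, `mk_audio`, `load`, and `save`, plus module-level constants (`EPSILON`, `fs`, `framesz`, `hop`) and the usual imports, are defined elsewhere in the file.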
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    print("logdir", logdir)
    restore_from = directories['restore_from']

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = (args.silence_threshold
                             if args.silence_threshold > EPSILON else None)
        gc_enabled = args.gc_channels is not None
        reader = AudioReader(args.data_dir,
                             args.test_dir,
                             coord,
                             sample_rate=args.sample_rate,
                             gc_enabled=gc_enabled,
                             sample_size=args.sample_size,
                             silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)

    net = create_model(args)
    ######## Multi GPU ###########
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # Create optimizer (default is Adam)
    optim = optimizer_factory[args.optimizer](learning_rate=args.learning_rate,
                                              momentum=args.momentum)
    tower_grads = []
    losses = []
    speech_inputs_mix = []
    speech_inputs_1 = []
    speech_inputs_2 = []

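    # One set of non-trainable input variables per GPU; batches are written
    # into them through feed_dict at run time instead of placeholders.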
    for i in range(args.num_gpus):
        speech_inputs_2.append(
            tf.Variable(tf.zeros(
                [net.batch_size, net.seq_len, args.num_of_frequency_points]),
                        trainable=False,
                        name="speech_inputs_2",
                        dtype=tf.float32))
        speech_inputs_1.append(
            tf.Variable(tf.zeros(
                [net.batch_size, net.seq_len, args.num_of_frequency_points]),
                        trainable=False,
                        name="speech_inputs_1",
                        dtype=tf.float32))
        speech_inputs_mix.append(
            tf.Variable(tf.zeros(
                [net.batch_size, net.seq_len, args.num_of_frequency_points]),
                        trainable=False,
                        name="speech_inputs_mix",
                        dtype=tf.float32))

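    # Build one model replica ("tower") per GPU; all towers share weights.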
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(args.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('TOWER_%d' % (i)) as scope:
                    # Create model.
                    print("Creating model On Gpu:%d." % (i))
                    loss, mask_state, output1, output2 = net.loss_SampleRnn(
                        speech_inputs_1[i],
                        speech_inputs_2[i],
                        speech_inputs_mix[i],
                        l2_regularization_strength=args.l2_regularization_strength)

                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()

                    # Collect the per-tower loss and list the variables that
                    # will receive gradients.
                    losses.append(loss)
                    trainable = tf.trainable_variables()
                    for var in trainable:
                        print(var)

                    gradients = optim.compute_gradients(loss, trainable)
                    # Alternative aggregation methods, left for reference:
                    # aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
                    # aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE
                    print("==========================")
                    for grad_var in gradients:
                        print(grad_var)
                    # Keep track of the gradients across all towers.
                    tower_grads.append(gradients)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grad_vars = average_gradients(tower_grads)

    # Clip gradients by global norm to stabilize training.
    grads, variables = zip(*grad_vars)
    grads_clipped, _ = tf.clip_by_global_norm(grads, 5.0)
    grad_vars = list(zip(grads_clipped, variables))

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = optim.apply_gradients(grad_vars,
                                              global_step=global_step)

    ###################

    # Set up session
    # allow_soft_placement is set to True to build towers on GPU.
    tf_config = tf.ConfigProto(allow_soft_placement=True,
                               log_device_placement=False,
                               inter_op_parallelism_threads=1)
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)

    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt:
        print("  Checkpoint found: {}".format(ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(" Done.")
    else:
        print(" No checkpoint found.")
        sys.exit(" Your model seems to be invalid. ")

    try:
        inp_dict = {}
        global audio_test
        _, audio_test = scipy.io.wavfile.read(args.test_dir)

        X_test, X_hlf_test = stft(audio_test, fs, framesz, hop)
        amplitude_test = scipy.absolute(X_hlf_test)
        angle_test = np.angle(X_hlf_test)

        # Add a batch dimension and keep only the last 256 frames (net.seq_len).
        ne = np.reshape(amplitude_test,
                        (1, amplitude_test.shape[0], amplitude_test.shape[1]))
        fd = ne[:, -256:, :]
        angle = angle_test[-256:, :]
        print(fd.shape)

        inp_dict[speech_inputs_mix[0]] = fd

        outp1, outp2 = sess.run([output1, output2], feed_dict=inp_dict)

        mk_audio(outp1, angle, args.sample_rate, "_1_test_")
        mk_audio(outp2, angle, args.sample_rate, "_2_test_")

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
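

# `average_gradients` is used above but not shown in this excerpt. A minimal
# sketch, assuming the canonical multi-tower pattern where each element of
# `tower_grads` is the (gradient, variable) list computed on one GPU:
def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars pairs the same variable's gradient from every tower.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # Variables are shared across towers, so the first tower's copy suffices.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads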


def train(directories, args):  # args: all parsed command-line arguments
    logdir = directories['logdir']
    print("logdir:", logdir)
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into a location that's different from logdir.
    is_overwritten_training = logdir != restore_from  # False when both point to the same directory

    coord = tf.train.Coordinator()  # Coordinator to manage the reader threads.
    # create inputs
    gc_enabled = args.gc_channels is not None
    reader = AudioReader(args.data_dir,
                         args.test_dir,
                         coord,
                         sample_rate=args.sample_rate,
                         gc_enabled=gc_enabled)
    audio_batch = reader.dequeue(args.batch_size)

    # Initialize model
    net = SpeechSeparation(
        batch_size=args.batch_size,
        rnn_type=args.rnn_type,
        dim=args.dim,
        n_rnn=args.n_rnn,
        seq_len=args.seq_len,
        num_of_frequency_points=args.num_of_frequency_points)

    # TODO: fold these ops into the model definition.
    summary, output1, output2, losses, apply_gradient_op = net.initializer(
        net, args)  # output1/output2 shape: (1, 256, 257)
    speech_inputs_1 = net.speech_inputs_1  # (1,256,257)
    speech_inputs_2 = net.speech_inputs_2  # (1,256,257)
    speech_inputs_mix = net.speech_inputs_mix  # (1,256,257)

    # Set up session
    tf_config = tf.ConfigProto(
        # allow_soft_placement is set to True to build towers on GPU
        allow_soft_placement=True,
        log_device_placement=False,
        inter_op_parallelism_threads=1)
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)

    sess.run(tf.global_variables_initializer())

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()  # records op timing and memory usage

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)
    try:
        saved_global_step = load(saver, sess, restore_from)  # None on the first run
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except Exception:
        print(
            "Something went wrong while restoring checkpoint. We will terminate "
            "training to avoid accidentally overwriting the previous model.")
        raise


############################## Start Training ##############################
    last_saved_step = saved_global_step  # -1 for a fresh run
    step = last_saved_step  # guard: the `finally` block below references `step`
    try:
        for step in range(saved_global_step + 1,
                          args.num_steps):  # (0, 1000000)
            loss_sum = 0
            start_time = time.time()

            # One dequeued batch per GPU; each element holds four arrays of
            # shape (1, 697, 257).
            inputslist = [sess.run(audio_batch) for i in range(args.num_gpus)]
            inp_dict = create_inputdict(inputslist, args, speech_inputs_1,
                                        speech_inputs_2, speech_inputs_mix)
            summ, loss_value, _ = sess.run(
                [summary, losses, apply_gradient_op],
                feed_dict=inp_dict)  # maps each input variable to its batch values

            for g in range(args.num_gpus):
                loss_sum += loss_value[g] / args.num_gpus

            writer.add_summary(summ, step)
            duration = time.time() - start_time

            if step < 100:
                log_str = 'step %d - loss = %.3f (%.3f sec/step)' % (
                    step, loss_sum, duration)
                logging.warning(log_str)

            elif step % 100 == 0:
                log_str = 'step %d - loss = %.3f (%.3f sec/step)' % (
                    step, loss_sum, duration)
                logging.warning(log_str)

            if step % 2000 == 0:
                angle_test, inp_dict = create_inputdict(inputslist,
                                                        args,
                                                        speech_inputs_1,
                                                        speech_inputs_2,
                                                        speech_inputs_mix,
                                                        test=True)

                outp1, outp2 = sess.run([output1, output2], feed_dict=inp_dict)
                x_r = mk_audio(outp1, angle_test, args.sample_rate,
                               "spk1_test_" + str(step) + ".wav")
                y_r = mk_audio(outp2, angle_test, args.sample_rate,
                               "spk2_test_" + str(step) + ".wav")

                amplitude_test = inputslist[0][2]
                angle_test = inputslist[0][3]
                mk_audio(amplitude_test, angle_test, args.sample_rate,
                         "raw_test_" + str(step) + ".wav")

                # audio summary on tensorboard
                merged = sess.run(
                    tf.summary.merge([
                        tf.summary.audio('speaker1_' + str(step),
                                         x_r[None, :],
                                         args.sample_rate,
                                         max_outputs=1),
                        tf.summary.audio('speaker2_' + str(step),
                                         y_r[None, :],
                                         args.sample_rate,
                                         max_outputs=1)
                    ]))
                writer.add_summary(merged, step)

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
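

# `save` and `load` are referenced above but not shown here. Minimal sketches
# in the usual TF1 checkpointing style (assumes `import os`; names and print
# messages are illustrative, not the file's actual implementation):
def save(saver, sess, logdir, step):
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    print('Storing checkpoint to {} ...'.format(logdir))
    saver.save(sess, checkpoint_path, global_step=step)


def load(saver, sess, logdir):
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        # The global step is encoded as the checkpoint filename suffix.
        global_step = int(ckpt.model_checkpoint_path.split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        return global_step
    return None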