Example #1
def main():
    args = get_arguments()
    args.logdir = os.path.join(hparams.logdir_root, args.run_name)
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    # Create coordinator.
    coord = tf.train.Coordinator()

    with tf.device('/cpu:0'):
        with tf.name_scope('inputs'):
            reader = DataReader(coord, args.wave_dir, args.lc_dir,
                                hparams.sample_size)

    global_step = tf.get_variable("global_step", [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    if hparams.learning_rate_decay_way == "cosine":
        learning_rate = consine_learning_rate_decay(global_step)
    else:
        learning_rate = linear_learning_rate_decay(global_step)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    tf.summary.scalar('learning_rate', learning_rate)

    gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
    tower_grads, tower_losses = [], []

    x_placeholder = []
    lc_placeholder = []
    for _ in range(len(gpu_ids)):
        x_placeholder.append(
            tf.placeholder(dtype=tf.float32, shape=[None, None, 1]))
        lc_placeholder.append(
            tf.placeholder(dtype=tf.float32,
                           shape=[None, None, hparams.num_mels]))

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        for i in range(len(gpu_ids)):
            device = assign_to_device('/gpu:%d' % int(gpu_ids[i]),
                                      ps_device='cpu:0')
            with tf.device(device), tf.name_scope('tower_%d' % i):
                model = WaveGlow(lc_channels=hparams.num_mels)

                model.create_network(x_placeholder[i], lc_placeholder[i])

                #with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
                #    tf.contrib.quantize.create_training_graph(quant_delay=0)
                model.add_loss()
                tower_losses.append(model.loss)
                grads = optimizer.compute_gradients(model.loss)
                tower_grads.append(grads)
                add_stats(model)

    with tf.name_scope('average_grad'):
        averaged_loss = tf.add_n(tower_losses) / len(tower_losses)
        tf.summary.scalar('average_loss', averaged_loss)
        averaged_gradients = average_gradients(tower_grads)

        # gradient clipping
        gradients = [grad for grad, var in averaged_gradients]
        params = [var for grad, var in averaged_gradients]
        clipped_gradients, norm = tf.clip_by_global_norm(
            gradients, hparams.clip_norm)

        # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
        # https://github.com/tensorflow/tensorflow/issues/1122
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.apply_gradients(zip(clipped_gradients,
                                                     params),
                                                 global_step=global_step)
            update_op = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)  # batch norm update

            train_ops = tf.group([train_op, update_op])

    stats = tf.summary.merge_all()

    saver = tf.train.Saver(max_to_keep=10)
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)

    #tf.contrib.quantize.create_training_graph(
    #        input_graph=tf.get_default_graph(),
    #        quant_delay=0)
    with tf.Session(config=config) as sess:
        try:
            reader.start_threads()
            summary_writer = tf.summary.FileWriter(args.logdir, sess.graph)
            sess.run(tf.global_variables_initializer())
            saved_global_step = 0
            last_saved_step = 0
            step = 0

            if args.restore_from is not None:
                try:
                    saved_global_step = load(saver, sess, args.restore_from)
                except Exception:
                    print(
                        "Something went wrong while restoring checkpoint. "
                        "We will terminate training to avoid accidentally overwriting "
                        "the previous model.")
                    raise
                print("Restore model successfully!")
            else:
                print("Start new training.")
            last_saved_step = saved_global_step

            for step in range(saved_global_step + 1, hparams.train_steps):
                start_time = time.time()
                x, lc = reader.dequeue(num_elements=hparams.batch_size *
                                       len(gpu_ids))
                dicts = dict()
                for i in range(len(gpu_ids)):
                    dicts[x_placeholder[i]] = x[
                        i * hparams.batch_size:(i + 1) * hparams.batch_size]
                    dicts[lc_placeholder[i]] = lc[
                        i * hparams.batch_size:(i + 1) * hparams.batch_size]

                _, loss, lr, _, summary = sess.run(
                    [global_step, averaged_loss, learning_rate, train_ops, stats],
                    feed_dict=dicts)
                duration = time.time() - start_time
                step_log = 'step {:d} loss={:.3f} lr={:.8f} time={:4f}'\
                        .format(step, loss, lr, duration)
                print(step_log)

                if step % hparams.save_model_every == 0:
                    save(saver, sess, args.logdir, step)
                    last_saved_step = step

                if step % hparams.summary_interval == 0:
                    summary_writer.add_summary(summary, step)

        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            if step > last_saved_step:
                save(saver, sess, args.logdir, step)
            coord.request_stop()
            coord.join()
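
Both multi-GPU examples rely on average_gradients (and Example #1 also on assign_to_device), which are defined elsewhere in the projects. Below is a minimal sketch of these two helpers in the conventional TF1 tower-training style; treat it as an assumption about their behavior rather than the projects' actual code.

# Hypothetical sketch of the tower-training helpers referenced above.
import tensorflow as tf

PS_OPS = ('Variable', 'VariableV2', 'VarHandleOp')


def assign_to_device(device, ps_device='/cpu:0'):
    """Device function: keep variable ops on ps_device, everything else on device."""
    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        return ps_device if node_def.op in PS_OPS else device
    return _assign


def average_gradients(tower_grads):
    """Average the (grad, var) lists returned by compute_gradients on each tower."""
    averaged = []
    for grad_and_vars in zip(*tower_grads):
        grads = tf.stack([g for g, _ in grad_and_vars], axis=0)
        mean_grad = tf.reduce_mean(grads, axis=0)
        averaged.append((mean_grad, grad_and_vars[0][1]))  # variables are shared across towers
    return averaged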
Example #2
File: train.py  Project: mcf330/WaveGlow
def main():
    args = get_arguments()
    args.logdir = os.path.join(hparams.logdir_root, args.run_name)
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    assert hparams.upsampling_rate == hparams.hop_length, 'upsampling rate should be the same as hop_length'

    # Create coordinator.
    coord = tf.train.Coordinator()
    global_step = tf.get_variable("global_step", [], initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(hparams.lr, global_step, hparams.decay_steps, 0.95, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    with tf.device('/cpu:0'):
        with tf.name_scope('inputs'):
            reader = DataReader(coord, args.filelist, args.wave_dir, args.lc_dir)

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True))
    reader.start_threads()

    audio_placeholder = tf.placeholder(tf.float32, shape=[None, None, 1], name='audio')
    lc_placeholder = tf.placeholder(tf.float32, shape=[None, None, hparams.num_mels], name='lc')

    tower_losses = []
    tower_grads = []
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        for i in range(args.ngpu):
            with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i):
                glow = WaveGlow(lc_dim=hparams.num_mels,
                                n_flows=hparams.n_flows,
                                n_group=hparams.n_group,
                                n_early_every=hparams.n_early_every,
                                n_early_size=hparams.n_early_size)
                print('create network %i' % i)

                local_audio_placeholder = audio_placeholder[i * hparams.batch_size:(i + 1) * hparams.batch_size, :, :]
                local_lc_placeholder = lc_placeholder[i * hparams.batch_size:(i + 1) * hparams.batch_size, :, :]

                output_audio, log_s_list, log_det_W_list = glow.create_forward_network(local_audio_placeholder,
                                                                                       local_lc_placeholder)
                loss = compute_waveglow_loss(output_audio, log_s_list, log_det_W_list, sigma=hparams.sigma)
                grads = optimizer.compute_gradients(loss, var_list=tf.trainable_variables())

                tower_losses.append(loss)
                tower_grads.append(grads)

                tf.summary.scalar('loss_tower_%d' % i, loss)

    # # gradient clipping
    # gradients = [grad for grad, var in averaged_gradients]
    # params = [var for grad, var in averaged_gradients]
    # clipped_gradients, norm = tf.clip_by_global_norm(gradients, 1.0)
    #
    # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    #     train_ops = optimizer.apply_gradients(zip(clipped_gradients, params), global_step=global_step)

    print("create network finished")
    loss = tf.reduce_mean(tower_losses)
    averaged_gradients = average_gradients(tower_grads)

    train_ops = optimizer.apply_gradients(averaged_gradients, global_step=global_step)

    tf.summary.scalar('loss', loss)

    # Set up logging for TensorBoard.
    writer = tf.summary.FileWriter(args.logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.summary.merge_all()

    # Set up session
    init = tf.global_variables_initializer()
    sess.run(init)
    print('parameters initialization finished')

    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=30)

    saved_global_step = 0
    if args.restore_from is not None:
        try:
            saved_global_step = load(saver, sess, args.restore_from)
            if saved_global_step is None:
                # The first training step will be saved_global_step + 1,
                # so fall back to 0 for new or overwritten trainings.
                saved_global_step = 0
        except Exception:
            print("Something went wrong while restoring checkpoint. "
                  "We will terminate training to avoid accidentally overwriting "
                  "the previous model.")
            raise

        print("restore model successfully!")

    print('start training.')
    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, hparams.train_steps):
            audio, lc = reader.dequeue(num_elements=hparams.batch_size * args.ngpu)

            if hparams.lc_encode or hparams.transposed_upsampling:
                # if using local-condition bi-LSTM encoding or transposed-conv upsampling,
                # no need to upsample here; upsampling is done inside the TF graph
                lc = np.reshape(lc, [hparams.batch_size * args.ngpu, -1, hparams.num_mels])
            else:
                # upsample by repeating each frame directly
                lc = np.tile(lc, [1, 1, hparams.upsampling_rate])
                lc = np.reshape(lc, [hparams.batch_size * args.ngpu, -1, hparams.num_mels])

            start_time = time.time()
            if step % 50 == 0 and args.store_metadata:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _, lr = sess.run(
                    [summaries, loss, train_ops, learning_rate],
                    feed_dict={audio_placeholder: audio, lc_placeholder: lc},
                    options=run_options,
                    run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(args.logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _, lr = sess.run([summaries, loss, train_ops, learning_rate],
                                                      feed_dict={audio_placeholder: audio, lc_placeholder: lc})
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            step_log = 'step {:d} - loss = {:.3f}, lr={:.8f}, time cost={:4f}'\
                .format(step, loss_value, lr, duration)
            print(step_log)

            if step % hparams.save_model_every == 0:
                save(saver, sess, args.logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, args.logdir, step)
        coord.request_stop()
        coord.join()
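
Example #2 calls compute_waveglow_loss without showing it. The sketch below assumes it implements the standard WaveGlow negative log-likelihood: a Gaussian prior term on the latent z plus the log-determinant contributions from the affine couplings (log_s) and the invertible 1x1 convolutions (log_det_W), averaged over all latent elements. The project's actual helper may normalize differently.

# Hypothetical sketch of compute_waveglow_loss under the assumptions above.
import tensorflow as tf


def compute_waveglow_loss(z, log_s_list, log_det_w_list, sigma=1.0):
    # prior term: latent z is modeled as a zero-mean Gaussian with std sigma
    prior = tf.reduce_sum(z * z) / (2.0 * sigma * sigma)
    # change-of-variable terms from the couplings and 1x1 convolutions
    log_s_total = tf.add_n([tf.reduce_sum(log_s) for log_s in log_s_list])
    log_det_w_total = tf.add_n([tf.reduce_sum(w) for w in log_det_w_list])
    loss = prior - log_s_total - log_det_w_total
    # average over the number of latent elements
    return loss / tf.cast(tf.size(z), tf.float32)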
Example #3
def main():
    args = get_arguments()
    args.logdir = os.path.join(hparams.logdir_root, args.run_name)
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    assert hparams.upsampling_rate == hparams.hop_length, 'upsampling rate should be the same as hop_length'

    # Create coordinator.
    coord = tf.train.Coordinator()

    with tf.device('/cpu:0'):
        with tf.name_scope('inputs'):
            reader = DataReader(coord, args.filelist, args.wave_dir, args.lc_dir)

    with tf.Graph().as_default():
        global_step = tf.get_variable("global_step", [], initializer=tf.constant_initializer(0), trainable=False)
        learning_rate = tf.train.exponential_decay(hparams.lr, global_step, hparams.decay_steps, 0.95, staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True))
        reader.start_threads()

        seq_len = hparams.seqlen
        batch_size = hparams.batch_size
        lc_dim = hparams.num_mels

        assert hparams.upsampling_rate == hparams.hop_length
        assert hparams.seqlen % hparams.upsampling_rate == 0
        assert np.cumprod(hparams.upsample_factors)[-1] == hparams.upsampling_rate

        lc_frames = hparams.seqlen // hparams.upsampling_rate + 2 * hparams.lc_pad

        x_placeholder = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_len, 1])          # B*1800*1
        y_placeholder = tf.placeholder(dtype=tf.float32, shape=[batch_size, seq_len])             # B*1800
        lc_placeholder = tf.placeholder(dtype=tf.float32, shape=[batch_size, lc_frames, lc_dim])  # B*9*80

        wave_rnn = WaveRNN_Alternative()
        loss = wave_rnn.build_network(x_placeholder, y_placeholder, lc_placeholder)

        grads = optimizer.compute_gradients(loss)
        grads = [(tf.clip_by_norm(grad, 2), var) for grad, var in grads]
        train_ops = optimizer.apply_gradients(grads, global_step=global_step)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)     # batch norm update
        train_ops = tf.group([train_ops, update_ops])

        tf.summary.scalar('loss', loss)

        # Set up logging for TensorBoard.
        writer = tf.summary.FileWriter(args.logdir)
        writer.add_graph(tf.get_default_graph())
        summaries = tf.summary.merge_all()

        # create test model
        test_lcnet_placeholder = tf.placeholder(dtype=tf.float32, shape=[1, None, hparams.num_mels])
        # build lc network for transposed upsampling
        test_lc_upsample_net = wave_rnn.upsample_network(test_lcnet_placeholder)

        test_x_placeholder = tf.placeholder(dtype=tf.float32, shape=[1, 1], name='x')
        test_mel_placeholder = tf.placeholder(dtype=tf.float32, shape=[1, hparams.lc_dims], name='mel')
        test_aux_placeholder = tf.placeholder(dtype=tf.float32, shape=[1, hparams.lc_dims], name='aux')

        infer_ops = wave_rnn.inference(test_x_placeholder, test_mel_placeholder, test_aux_placeholder)

        # Set up session
        init = tf.global_variables_initializer()
        sess.run(init)
        print('parameters initialization finished')

        saver = tf.train.Saver(max_to_keep=30)

        wave_dir = os.path.join(args.logdir, 'test_wave')
        if not os.path.exists(wave_dir):
            os.makedirs(wave_dir)

        # load local condition
        test_lc = read_binary_lc(reader.test_lc, hparams.num_mels)

        saved_global_step = 0
        if args.restore_from is not None:
            try:
                saved_global_step = load(saver, sess, args.restore_from)
                if saved_global_step is None:
                    # The first training step will be saved_global_step + 1,
                    # so fall back to 0 for new or overwritten trainings.
                    saved_global_step = 0
            except Exception:
                print("Something went wrong while restoring checkpoint. "
                      "We will terminate training to avoid accidentally overwriting "
                      "the previous model.")
                raise

            print("restore model successfully!")

        print('start training.')
        last_saved_step = saved_global_step
        try:
            for step in range(saved_global_step + 1, hparams.train_steps):
                x, y, lc = reader.dequeue(num_elements=hparams.batch_size * args.ngpu)

                y = np.reshape(y, [hparams.batch_size, seq_len])

                start_time = time.time()
                summary, loss_value, _, lr = sess.run([summaries, loss, train_ops, learning_rate],
                                                      feed_dict={
                                                          x_placeholder: x,
                                                          y_placeholder: y,
                                                          lc_placeholder: lc
                                                      })
                writer.add_summary(summary, step)

                duration = time.time() - start_time
                step_log = 'step {:d} - loss = {:.3f}, lr={:.8f}, time cost={:4f}'\
                    .format(step, loss_value, lr, duration)
                print(step_log)

                if step % hparams.save_model_every == 0:
                    save(saver, sess, args.logdir, step)
                    last_saved_step = step

                    # pad the local condition at both ends so the length stays the same
                    # after the conv1d operations with filter_width=3
                    lc = np.pad(test_lc, ((hparams.lc_pad, hparams.lc_pad), (0, 0)), mode='constant')  # pad lc_pad frames at the start and end
                    lc = np.reshape(lc, [1, -1, hparams.num_mels])

                    # lc go through the lc_net
                    aux, mel = sess.run(test_lc_upsample_net, feed_dict={
                        test_lcnet_placeholder: lc
                    })

                    mel = np.reshape(mel, [-1, hparams.lc_dims])
                    aux = np.reshape(aux, [-1, hparams.lc_dims])

                    # create generation model
                    print('generating samples')
                    wave_save_name = os.path.join(wave_dir, 'wavernn_test_model_{}.wav'.format(str(step).zfill(7)))
                    generate_seq(wave_save_name, wave_rnn, sess, mel, aux, infer_ops, test_x_placeholder,
                                 test_mel_placeholder, test_aux_placeholder)

        except KeyboardInterrupt:
            # Introduce a line break after ^C is displayed so save message
            # is on its own line.
            print()
        finally:
            if step > last_saved_step:
                save(saver, sess, args.logdir, step)
            coord.request_stop()
            coord.join()
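
All three examples call save and load checkpoint helpers that are not shown. The sketch below follows the usual tf.train.Saver pattern (checkpoint files named model.ckpt-<step> under the log directory); the bodies are assumptions, only the call signatures are taken from the examples above.

# Hypothetical sketch of the checkpoint helpers used by all three examples.
import os
import tensorflow as tf


def save(saver, sess, logdir, step):
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    print('Storing checkpoint to {} ...'.format(logdir))
    saver.save(sess, checkpoint_path, global_step=step)


def load(saver, sess, logdir):
    print('Trying to restore saved checkpoints from {} ...'.format(logdir))
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        # the step is encoded in the checkpoint filename, e.g. model.ckpt-1000
        step = int(ckpt.model_checkpoint_path.split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        return step
    return None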