Example #1
def model_fn(features,
             labels,
             mode=tf.estimator.ModeKeys.TRAIN,
             params=None,
             config=None):
    inputs = features['inputs']
    lengths = features['lengths']
    mel_targets = None
    linear_targets = None
    train_hooks = []
    global_step = tf.train.get_global_step()
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Targets are only provided by the input_fn during training.
        mel_targets = labels['mel_targets']
        linear_targets = labels['linear_targets']
    # Build the Tacotron graph; loss and optimizer are added only in TRAIN mode.
    with tf.variable_scope('model'):
        model = Tacotron(params)
        model.initialize(inputs, lengths, mel_targets, linear_targets)
        if mode == tf.estimator.ModeKeys.TRAIN:
            model.add_loss()
            model.add_optimizer(global_step)
            # train_hooks.extend([
            #     LoggingTensorHook(
            #         [global_step, model.loss, tf.shape(model.linear_outputs)],
            #         every_n_secs=60,
            #     )
            # ])
        # Invert the predicted linear spectrograms to waveforms (used both as
        # predictions and for the audio summary below).
        outputs = tf.map_fn(inv_spectrogram_tensorflow, model.linear_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Training-only summaries for TensorBoard.
        with tf.variable_scope('stats'):
            tf.summary.histogram('linear_outputs', model.linear_outputs)
            tf.summary.histogram('linear_targets', model.linear_targets)
            tf.summary.histogram('mel_outputs', model.mel_outputs)
            tf.summary.histogram('mel_targets', model.mel_targets)
            tf.summary.scalar('loss_mel', model.mel_loss)
            tf.summary.scalar('loss_linear', model.linear_loss)
            tf.summary.scalar('learning_rate', model.learning_rate)
            tf.summary.scalar('loss', model.loss)
            gradient_norms = [tf.norm(grad) for grad in model.gradients]
            tf.summary.histogram('gradient_norm', gradient_norms)
            tf.summary.scalar('max_gradient_norm',
                              tf.reduce_max(gradient_norms))
            tf.summary.audio('outputs',
                             outputs,
                             params.sample_rate,
                             max_outputs=1)
            # The Estimator saves summaries through its own SummarySaverHook,
            # so the merged summary op does not need to be kept.
            tf.summary.merge_all()

    # getattr() falls back to None for loss and train_op when the model was
    # built without them (any mode other than TRAIN).
    return tf.estimator.EstimatorSpec(mode,
                                      predictions=outputs,
                                      loss=getattr(model, 'loss', None),
                                      train_op=getattr(model, 'optimize', None),
                                      eval_metric_ops=None,
                                      export_outputs=None,
                                      training_chief_hooks=None,
                                      training_hooks=train_hooks,
                                      scaffold=None,
                                      evaluation_hooks=None,
                                      prediction_hooks=None)
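
For reference, a model_fn like this is handed to tf.estimator.Estimator, which invokes it with the appropriate mode. A minimal sketch of that wiring (TF 1.x), where train_input_fn is a hypothetical input function and hparams is the same hyperparameter object passed through as params; neither is part of the example above:

import tensorflow as tf

# train_input_fn is assumed to return (features, labels) with the
# 'inputs'/'lengths' and 'mel_targets'/'linear_targets' keys that
# model_fn reads above.
estimator = tf.estimator.Estimator(model_fn=model_fn,
                                   model_dir='/tmp/tacotron-train',
                                   params=hparams)
estimator.train(input_fn=train_input_fn, max_steps=100000)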
Example #2
def train(log_dir, args):
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    # Log the checkpoint and training-data paths.
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = DataFeeder(coord, input_path, hparams)

    # Initialize the model.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model'):
        model = Tacotron(hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets,
                         feeder.stop_token_targets, global_step)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=1)

    # Start training.
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.summary_interval == 0:
                    summary_writer.add_summary(sess.run(stats), step)

                # Save a checkpoint every checkpoint_interval steps.
                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    # Save the synthesized audio sample.
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    time_string = datetime.now().strftime('%Y-%m-%d %H:%M')
                    # Plot the encoder-decoder alignment.
                    infolog.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s,  %s, step=%d, loss=%.5f' %
                        (args.model, time_string, step, loss))
                    # Log the input text of the synthesized sample.
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
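
A sketch of a command-line entry point for train(); the flag names are inferred from the attributes the function reads (args.base_dir, args.input, args.model, args.summary_interval, args.checkpoint_interval), and the defaults are illustrative placeholders rather than the original script's values:

import argparse
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default=os.path.expanduser('~/tacotron'))
    parser.add_argument('--input', default='training/train.txt')
    parser.add_argument('--model', default='tacotron')
    parser.add_argument('--summary_interval', type=int, default=100)
    parser.add_argument('--checkpoint_interval', type=int, default=1000)
    args = parser.parse_args()

    # Keep logs, checkpoints and synthesized samples under one directory.
    log_dir = os.path.join(args.base_dir, 'logs-%s' % args.model)
    os.makedirs(log_dir, exist_ok=True)
    train(log_dir, args)

if __name__ == '__main__':
    main()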