def model_fn(features, labels, mode=tf.estimator.ModeKeys.TRAIN, params=None, config=None):
    """tf.Estimator model function building a Tacotron graph.

    Args:
        features: dict with 'inputs' (encoder input ids) and 'lengths'
            (sequence lengths) tensors.
        labels: in TRAIN mode, dict with 'mel_targets' and 'linear_targets'
            tensors; ignored otherwise.
        mode: a tf.estimator.ModeKeys value.
        params: hyperparameter object passed to Tacotron (same shape as the
            global `hparams` — assumed to expose `sample_rate`; TODO confirm).
        config: unused estimator RunConfig (part of the model_fn signature).

    Returns:
        tf.estimator.EstimatorSpec with inverse-spectrogram waveforms as
        predictions and, in TRAIN mode, the model loss and train op.
    """
    inputs = features['inputs']
    lengths = features['lengths']
    mel_targets = None
    linear_targets = None
    train_hooks = []
    global_step = tf.train.get_global_step()
    if mode == tf.estimator.ModeKeys.TRAIN:
        mel_targets = labels['mel_targets']
        linear_targets = labels['linear_targets']

    with tf.variable_scope('model'):
        model = Tacotron(params)
        model.initialize(inputs, lengths, mel_targets, linear_targets)
        if mode == tf.estimator.ModeKeys.TRAIN:
            model.add_loss()
            model.add_optimizer(global_step)

    # Convert predicted linear spectrograms back to waveforms (one per batch item).
    outputs = tf.map_fn(inv_spectrogram_tensorflow, model.linear_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        with tf.variable_scope('stats'):
            tf.summary.histogram('linear_outputs', model.linear_outputs)
            tf.summary.histogram('linear_targets', model.linear_targets)
            tf.summary.histogram('mel_outputs', model.mel_outputs)
            tf.summary.histogram('mel_targets', model.mel_targets)
            tf.summary.scalar('loss_mel', model.mel_loss)
            tf.summary.scalar('loss_linear', model.linear_loss)
            tf.summary.scalar('learning_rate', model.learning_rate)
            tf.summary.scalar('loss', model.loss)
            gradient_norms = [tf.norm(grad) for grad in model.gradients]
            tf.summary.histogram('gradient_norm', gradient_norms)
            tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms))
            # FIX: was `hparams.sample_rate` (module-level global) while the
            # model itself is configured from `params`; use the same source.
            tf.summary.audio('outputs', outputs, params.sample_rate, max_outputs=1)
        # NOTE: no explicit merge_all() call — Estimator's summary saver hook
        # merges and writes the summaries registered above by itself; the
        # previous bare `tf.summary.merge_all()` discarded its result.

    # In non-TRAIN modes the model has no `loss`/`optimize` attributes, so
    # getattr falls back to None, which EstimatorSpec accepts for PREDICT.
    return tf.estimator.EstimatorSpec(
        mode,
        predictions=outputs,
        loss=getattr(model, 'loss', None),
        train_op=getattr(model, 'optimize', None),
        eval_metric_ops=None,
        export_outputs=None,
        training_chief_hooks=None,
        training_hooks=train_hooks,
        scaffold=None,
        evaluation_hooks=None,
        prediction_hooks=None)
def train(log_dir, args):
    """Run the Tacotron training loop.

    Builds the data feeder and model graph, then trains until the coordinator
    is stopped, periodically writing summaries, checkpoints, a synthesized
    audio sample, and an encoder-decoder alignment plot under `log_dir`.

    Args:
        log_dir: directory for checkpoints, summaries, audio and plots.
        args: parsed CLI args; reads `base_dir`, `input`, `model`,
            `summary_interval`, `checkpoint_interval`.
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    # Log the model's path information.
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)

    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = DataFeeder(coord, input_path, hparams)

    # Initialize the model.
    # NOTE(review): this initialize() signature (with stop_token_targets and
    # global_step) differs from the 4-argument call in model_fn above —
    # verify against Tacotron.initialize's definition.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model'):
        model = Tacotron(hparams)
        model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets,
                         feeder.linear_targets, feeder.stop_token_targets, global_step)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    step = 0
    time_window = ValueWindow(100)  # rolling average of sec/step
    loss_window = ValueWindow(100)  # rolling average of loss
    saver = tf.train.Saver(max_to_keep=1)

    # Start training.
    summary_writer = None
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run([global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.summary_interval == 0:
                    summary_writer.add_summary(sess.run(stats), step)

                # Every checkpoint_interval steps, save a checkpoint plus a
                # synthesized audio sample and alignment plot for inspection.
                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0], model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    # Synthesize a sample waveform.
                    audio.save_wav(
                        waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    time_string = datetime.now().strftime('%Y-%m-%d %H:%M')
                    # Plot the encoder-decoder alignment.
                    infolog.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' % (args.model, time_string,
                                                             step, loss))
                    # Show the text of the synthesized sample.
                    log('Input: %s' % sequence_to_text(input_seq))
        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            # FIX: previously the FileWriter was never closed, leaking the
            # event file handle and risking unflushed summaries on exit.
            if summary_writer is not None:
                summary_writer.close()