class Synthesizer():
    """
    Synthesizer

    Holds a Tacotron inference graph together with the session used to run
    it, and turns text into encoded audio bytes.

    NOTE(review): the set-up method is named ``init`` (not ``__init__``), so
    it has to be called explicitly after construction -- confirm intended.
    """

    def init(self, checkpoint_path):
        """
        Build the inference graph and restore weights from a checkpoint

        @type   checkpoint_path     str

        @param  checkpoint_path     path to checkpoint to be restored
        """
        v1 = tf.compat.v1

        print('Constructing Tacotron Model ...')
        # Batch size is fixed to 1; the sequence dimension is left open.
        text_ids = v1.placeholder(tf.int32, [1, None], 'inputs')
        text_lengths = v1.placeholder(tf.int32, [1], 'input_lengths')
        with v1.variable_scope('model'):
            self.model = Tacotron()
            self.model.init(text_ids, text_lengths)
            # Only the first (and only) batch entry is converted to audio.
            first_linear = self.model.linear_outputs[0]
            self.wav_output = audio.spectrogram_to_wav_tf(first_linear)

        print('Loading checkpoint: %s' % checkpoint_path)
        self.session = v1.Session()
        self.session.run(v1.global_variables_initializer())
        # Restoring overwrites the freshly initialised variables.
        v1.train.Saver().restore(self.session, checkpoint_path)

    def synthesize(self, text):
        """
        Run the restored model on a piece of text and return the audio

        @type   text    str

        @param  text    text to be synthesized

        @rtype          bytes
        @return         contents of the buffer that audio.save_audio
                        wrote the synthesized speech into
        """
        token_ids = text_to_sequence(text)
        feeds = {
            self.model.inputs: [np.asarray(token_ids, dtype=np.int32)],
            self.model.input_lengths: np.asarray([len(token_ids)],
                                                 dtype=np.int32),
        }

        raw = self.session.run(self.wav_output, feed_dict=feeds)
        samples = audio.inv_preemphasis(raw)
        # Drop everything after the endpoint reported by the audio helper.
        end = audio.find_endpoint(samples)

        buffer = io.BytesIO()
        audio.save_audio(samples[:end], buffer)
        return buffer.getvalue()
def train(log_dir, args):
    """
    Train the Tacotron model, logging progress and saving checkpoints

    Runs until the coordinator requests a stop. Any exception raised inside
    the session (including the deliberate 'Loss Exploded' one) is logged,
    its traceback printed, and forwarded to the coordinator; it is not
    re-raised to the caller.

    @type   log_dir     str
    @type   args        object

    @param  log_dir     directory receiving checkpoints, summaries, audio
                        samples and alignment plots
    @param  args        options object; the attributes read here are
                        base_dir, restore_step, summary_interval and
                        checkpoint_interval
    """
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, 'training/train.txt')

    logger.log('Checkpoint path: %s' % checkpoint_path)
    logger.log('Loading training data from: %s' % input_path)

    # set up DataFeeder
    coordi = tf.compat.v1.train.Coordinator()
    with tf.compat.v1.variable_scope('data_feeder'):
        feeder = DataFeeder(coordi, input_path)

    # set up Model
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.compat.v1.variable_scope('model'):
        model = Tacotron()
        model.init(feeder.inputs, feeder.input_lengths,
                   mel_targets=feeder.mel_targets,
                   linear_targets=feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # book keeping
    step = 0
    loss_window = ValueWindow(100)
    time_window = ValueWindow(100)
    saver = tf.compat.v1.train.Saver(max_to_keep=5,
                                     keep_checkpoint_every_n_hours=2)

    # start training already!
    with tf.compat.v1.Session() as sess:
        # Bound before the try so the finally clause can tell whether the
        # writer was ever created.
        summary_writer = None
        try:
            # Go through compat.v1 like the rest of the file: plain
            # tf.summary has no FileWriter under TensorFlow 2.x.
            summary_writer = tf.compat.v1.summary.FileWriter(
                log_dir, sess.graph)

            # initialize parameters
            sess.run(tf.compat.v1.global_variables_initializer())

            # if requested, restore from step
            if args.restore_step:
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                logger.log('Resuming from checkpoint: %s' % restore_path)
            else:
                logger.log('Starting a new training!')

            feeder.start_in_session(sess)

            while not coordi.should_stop():
                start_time = time.time()
                # The optimize op is run for its side effect only.
                step, loss, _ = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                msg = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                logger.log(msg)

                if loss > 100 or math.isnan(loss):
                    # bad situation
                    logger.log('Loss exploded to %.05f at step %d!'
                               % (loss, step))
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    # it's time to write summary
                    logger.log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    # it's time to save a checkpoint
                    logger.log('Saving checkpoint to: %s-%d'
                               % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                    logger.log('Saving audio and alignment...')
                    # Fetch the first entry of a batch as a sample to
                    # render alongside the checkpoint.
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0],
                        model.linear_outputs[0],
                        model.alignments[0]
                    ])

                    # convert spectrogram to waveform
                    waveform = audio.spectrogram_to_wav(spectrogram.T)
                    # save it
                    audio.save_audio(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))

                    plotter.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f'
                             % ('tacotron', time_string(), step, loss))

                    logger.log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            logger.log('Exiting due to exception %s' % e)
            traceback.print_exc()
            coordi.request_stop(e)
        finally:
            # Flush and release the event file so buffered summaries are
            # not lost when training stops.
            if summary_writer is not None:
                summary_writer.close()