def train(log_dir, config):
    config.data_paths = config.data_paths

    data_dirs = [os.path.join(data_path, "data") \
                 for data_path in config.data_paths]
    num_speakers = len(data_dirs)
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(
                hparams.model_type))

    commit = get_git_commit() if config.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    #log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
    log('=' * 50)
    #log(' [*] git diff:\n%s' % get_git_diff())
    log('=' * 50)
    log(' [*] Checkpoint path: %s' % checkpoint_path)
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(coord, data_dirs, hparams, config, 32,
                                  data_type='train', batch_size=hparams.batch_size)
        test_feeder = DataFeeder(coord, data_dirs, hparams, config, 8,
                                 data_type='test', batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)
        model.initialize(train_feeder.inputs, train_feeder.input_lengths,
                         num_speakers, train_feeder.speaker_id,
                         train_feeder.mel_targets, train_feeder.linear_targets,
                         train_feeder.loss_coeff,
                         is_randomly_initialized=is_randomly_initialized)
        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats')  # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)
        test_model.initialize(test_feeder.inputs, test_feeder.input_lengths,
                              num_speakers, test_feeder.speaker_id,
                              test_feeder.mel_targets, test_feeder.linear_targets,
                              test_feeder.loss_coeff,
                              rnn_decoder_test_mode=True,
                              is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    # Train!
    #with tf.Session(config=sess_config) as sess:
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit), slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(config.initialize_path)
                saver.restore(sess, restore_path)
                log('Initialized from checkpoint: %s at commit: %s' %
                    (restore_path, commit), slack=True)

                zero_step_assign = tf.assign(global_step, 0)
                sess.run(zero_step_assign)
                start_step = sess.run(global_step)

                log('=' * 50)
                log(' [*] Global step is reset to {}'.format(start_step))
                log('=' * 50)
            else:
                log('Starting new training run at commit: %s' % commit, slack=True)

            start_step = sess.run(global_step)

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss_without_coeff, model.optimize],
                    feed_dict=model.get_dummy_feed_dict())

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(
                        sess.run(test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                        model.inputs[:num_test],
                        model.linear_outputs[:num_test],
                        model.alignments[:num_test],
                        test_model.inputs[:num_test],
                        test_model.linear_outputs[:num_test],
                        test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                        test_sequences, test_spectrograms, test_alignments = \
                        sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
                                  log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms, test_alignments,
                                  log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
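All of these training loops report a smoothed sec/step and avg_loss through a ValueWindow helper that the snippets do not include. A minimal sketch of such a moving-average window (the upstream implementation may differ in detail):

# Sketch only: moving-average window used for the sec/step and avg_loss logging above.
class ValueWindow:
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Keep only the most recent `window_size` values.
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def average(self):
        # Average over whatever has been collected so far (0 divides safely).
        return sum(self._values) / max(len(self._values), 1)

    def reset(self):
        self._values = []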
def train(log_dir, args): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) # Train! with tf.Session() as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) else: log('Starting new training run at commit: %s' % commit, slack=True) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run([global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, spectrogram, alignment = sess.run([ model.inputs[0], model.linear_outputs[0], model.alignments[0]]) waveform = audio.inv_spectrogram(spectrogram.T) audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step)) plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss)) log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
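The `stats` summary op comes from an `add_stats(model)` helper that is not shown in these snippets. One plausible TF 1.x sketch, assuming the model also exposes `linear_outputs`, `linear_targets`, `mel_loss`, `linear_loss`, `learning_rate`, and `gradients` attributes (none of which are visible in the loops above):

def add_stats(model):
    # Sketch of the TensorBoard summary helper assumed by the training loops;
    # the attribute names are assumptions, not confirmed by the snippets.
    with tf.variable_scope('stats'):
        tf.summary.histogram('linear_outputs', model.linear_outputs)
        tf.summary.histogram('linear_targets', model.linear_targets)
        tf.summary.scalar('loss_mel', model.mel_loss)
        tf.summary.scalar('loss_linear', model.linear_loss)
        tf.summary.scalar('learning_rate', model.learning_rate)
        tf.summary.scalar('loss', model.loss)
        gradient_norms = [tf.norm(grad) for grad in model.gradients]
        tf.summary.histogram('gradient_norm', gradient_norms)
        tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms))
        return tf.summary.merge_all()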
def train(log_dir, args):
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = SkipNet(hparams)
        model.initialize(feeder.txt_A, feeder.txt_A_lenth, feeder.txt_B, feeder.txt_B_lenth, \
                         feeder.mel_targets, feeder.image_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            sess.run(tf.global_variables_initializer())

            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
            else:
                log('Starting new training run at commit: %s' % commit, slack=True)

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                # The model was initialized directly on the feeder's queue tensors,
                # so these runs need no feed_dict.

                # iter 1: dataset A: image - text pairs
                step, img_loss, txt_B_loss, d_loss, g_loss, opt1, opt2 = sess.run(
                    [global_step, model.recon_img_loss, model.recon_txt_loss_A,
                     model.domain_d_loss, model.domain_g_loss,
                     model.optimize_recon, model.optimize_domain])

                # iter 2: dataset B: speech - text pairs
                step, speech_loss, txt_A_loss, d_loss, g_loss, opt1, opt2 = sess.run(
                    [global_step, model.recon_speech_loss, model.recon_txt_loss_B,
                     model.domain_d_loss, model.domain_g_loss,
                     model.optimize_recon, model.optimize_domain])

                # Combined reconstruction loss, used only for logging and the
                # explosion check below (the original left `loss` undefined).
                loss = img_loss + txt_B_loss + speech_loss + txt_A_loss

                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss Exploded')

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
def train(log_dir, args, trans_ckpt_dir=None): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') if trans_ckpt_dir != None: trans_checkpoint_path = os.path.join(trans_ckpt_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: %s' % trans_checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) # Train! with tf.Session() as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (trans_checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) else: log('Starting new training run at commit: %s' % commit, slack=True) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, spectrogram, alignment = sess.run([ model.inputs[0], model.linear_outputs[0], model.alignments[0] ]) waveform = audio.inv_spectrogram(spectrogram.T) audio.save_wav( waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step)) plot.plot_alignment( alignment, os.path.join(log_dir, 'step-%d-align.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss)) log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
def train(log_dir, pretrain_log_dir, args): checkpoint_path = os.path.join(log_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) # Set up DataFeeder: ### input_path: linear, mel, frame_num, ppgs coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(input_lengths=feeder.input_lengths, mel_targets=feeder.mel_targets, linear_targets=feeder.linear_targets, ppgs=feeder.ppgs, speakers=feeder.speakers) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) acc_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) # Train! with tf.Session() as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) try: if pretrain_log_dir != None: checkpoint_state = tf.train.get_checkpoint_state( pretrain_log_dir) else: checkpoint_state = tf.train.get_checkpoint_state(log_dir) if (checkpoint_state and checkpoint_state.model_checkpoint_path): log('Loading checkpoint {}'.format( checkpoint_state.model_checkpoint_path), slack=True) saver.restore(sess, checkpoint_state.model_checkpoint_path) else: log('No model to load at {}'.format(log_dir), slack=True) saver.save(sess, checkpoint_path, global_step=global_step) except tf.errors.OutOfRangeError as e: log('Cannot restore checkpoint: {}'.format(e), slack=True) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() ### how to run training if args.model == 'tacotron': step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) elif args.model == 'nnet1': step, loss, opt, ppgs, logits = sess.run([ global_step, model.loss, model.optimize, model.ppgs, model.logits ]) ## cal acc ppgs = np.argmax(ppgs, axis=-1) # (N, 201, ) logits = np.argmax(logits, axis=-1) # (N, 201, ) num_hits = np.sum(np.equal(ppgs, logits)) num_targets = np.shape(ppgs)[0] * np.shape(ppgs)[1] acc = num_hits / num_targets ## summerize time_window.append(time.time() - start_time) loss_window.append(loss) acc_window.append(acc) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % ( step, time_window.average, loss, loss_window.average, acc, acc_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) else: print('input error!!') assert 1 == 0 ### save model and logs if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' 
% (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
def train(log_dir, args): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) ps_hosts = args.ps_hosts.split(",") worker_hosts = args.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=args.job_name, task_index=args.task_index) # Block further graph execution if current node is parameter server if args.job_name == "ps": server.join() with tf.device( tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % args.task_index, cluster=cluster)): # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2, sharded=True) hooks = [tf.train.StopAtStepHook(last_step=1000000)] # Train! # Monitored... automatycznie wznawia z checkpointu. is_chief = (args.task_index == 0) init_op = tf.global_variables_initializer() sv = tf.train.Supervisor(is_chief=(args.task_index == 0), logdir="train_logs", init_op=init_op, summary_op=stats, saver=saver, save_model_secs=600) with sv.managed_session(server.target) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(init_op) if args.restore_step and is_chief: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) else: log('Starting new training run at commit: %s' % commit, slack=True) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' 
% (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0 and is_chief: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, spectrogram, alignment = sess.run([ model.inputs[0], model.linear_outputs[0], model.alignments[0] ]) waveform = audio.inv_spectrogram(spectrogram.T) audio.save_wav( waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step)) plot.plot_alignment( alignment, os.path.join(log_dir, 'step-%d-align.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss)) log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
def train(log_dir, args): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') ## input path is lists of both postive path and negtiva path input_path_pos = os.path.join(args.base_dir, args.input_pos) input_path_neg = os.path.join(args.base_dir, args.input_neg) log('Checkpoint path: %s' % checkpoint_path) log('Loading positive training data from: %s' % input_path_pos) log('Loading negative training data from: %s' % input_path_neg) log('Using model: %s' % args.model) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path_pos, input_path_neg, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs_pos, feeder.input_lengths_pos, feeder.mel_targets_pos, feeder.linear_targets_pos, feeder.mel_targets_neg, feeder.linear_targets_neg, feeder.labels_pos, feeder.labels_neg) model.add_loss() model.add_optimizer(global_step) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) # Train! with tf.Session() as sess: try: #summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) else: log('Starting new training run at commit: %s' % commit, slack=True) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() # train d sess.run(model.d_optimize) # train g step, rec_loss, style_loss, d_loss, g_loss, _ = sess.run([ global_step, model.rec_loss, model.style_loss, model.d_loss, model.g_loss, model.g_optimize ]) time_window.append(time.time() - start_time) message = 'Step %-7d [%.03f sec/step, rec_loss=%.05f, style_loss=%.05f, d_loss=%.05f, g_loss=%.05f]' % ( step, time_window.average, rec_loss, style_loss, d_loss, g_loss) log(message, slack=(step % args.checkpoint_interval == 0)) if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, spectrogram_pos, spectrogram_neg, alignment_pos, alignment_neg = sess.run( [ model.inputs[0], model.linear_outputs_pos[0], model.linear_outputs_neg[0], model.alignments_pos[0], model.alignments_neg[0] ]) waveform_pos = audio.inv_spectrogram(spectrogram_pos.T) waveform_neg = audio.inv_spectrogram(spectrogram_neg.T) audio.save_wav( waveform_pos, os.path.join(log_dir, 'step-%d-audio_pos.wav' % step)) audio.save_wav( waveform_neg, os.path.join(log_dir, 'step-%d-audio_neg.wav' % step)) plot.plot_alignment( alignment_pos, os.path.join(log_dir, 'step-%d-align_pos.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, rec_loss)) plot.plot_alignment( alignment_neg, os.path.join(log_dir, 'step-%d-align_neg.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, rec_loss)) log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
def train(log_dir, config):
    config.data_paths = config.data_paths  # Data path from the parsed command-line arguments: default='datasets/kr_example'

    data_dirs = [os.path.join(data_path, "data") \
                 for data_path in config.data_paths]
    num_speakers = len(data_dirs)  # Count the speakers being trained: 1 for a single-speaker model, 2 for a multi-speaker model
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:  # If training a multi-speaker model and the model type is neither "deepvoice" nor "simple"
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(
                hparams.model_type))  # (the upstream code mistyped this as config.model_type)

    commit = get_git_commit() if config.git else 'None'  # Git-related, so ignore
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')  # Set checkpoint_path: path to the model.ckpt file

    log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())  # git log
    log('=' * 50)  # Separator line =====
    #log(' [*] git diff:\n%s' % get_git_diff())
    log('=' * 50)  # Separator line =====
    log(' [*] Checkpoint path: %s' % checkpoint_path)  # Print the checkpoint path
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()  # Declare the thread coordinator
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(
                coord, data_dirs, hparams, config, 32,
                data_type='train', batch_size=hparams.batch_size)
        # def __init__(self, coordinator, data_dirs, hparams, config, batches_per_group, data_type, batch_size):
        test_feeder = DataFeeder(
                coord, data_dirs, hparams, config, 8,
                data_type='test', batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)  # Create the Tacotron model
        model.initialize(
                train_feeder.inputs, train_feeder.input_lengths,
                num_speakers, train_feeder.speaker_id,
                train_feeder.mel_targets, train_feeder.linear_targets,
                train_feeder.loss_coeff,
                is_randomly_initialized=is_randomly_initialized)
        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats')  # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)  # Create the Tacotron test model
        test_model.initialize(
                test_feeder.inputs, test_feeder.input_lengths,
                num_speakers, test_feeder.speaker_id,
                test_feeder.mel_targets, test_feeder.linear_targets,
                test_feeder.loss_coeff,
                rnn_decoder_test_mode=True,
                is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')  # Record the model's losses etc. to TensorBoard / test_model is passed as model, model as model2
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)  # ValueWindow class with window_size = 100
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)  # Auto-save every 2 hours; checkpoints are never deleted

    sess_config = tf.ConfigProto(
            log_device_placement=False,  # log_device_placement reports which device each op is placed on
            allow_soft_placement=True)  # If allow_soft_placement is False, an error occurs when no GPU is available
    sess_config.gpu_options.allow_growth = True  # Grow GPU memory usage as needed

    # Train!
    #with tf.Session(config=sess_config) as sess:
    with tf.Session() as sess:  # Everything inside this with-block runs on the CPU or GPU
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)  # Save evaluated summary ops and the TensorBoard graph to log_dir
            sess.run(tf.global_variables_initializer())  # Once the dataset is loaded and the graph is defined, initialize the variables and start training

            if config.load_path:  # If a path to saved settings (load_path) was specified
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)  # Path of the most recently saved checkpoint
                saver.restore(sess, restore_path)  # Restore from restore_path
                log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)  # Log output via git and Slack
            elif config.initialize_path:  # If initializing from saved settings (initialize_path) was specified
                restore_path = get_most_recent_checkpoint(config.initialize_path)  # Most recently saved checkpoint under the given path
                saver.restore(sess, restore_path)  # Restore from restore_path
                log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)  # Log output via git and Slack

                zero_step_assign = tf.assign(global_step, 0)  # Define an op that resets the global_step variable to 0
                sess.run(zero_step_assign)  # Run the reset op
                start_step = sess.run(global_step)  # Use the global_step value as the starting step

                log('=' * 50)
                log(' [*] Global step is reset to {}'.format(start_step))  # i.e., report that the starting step has been reset to 0
                log('=' * 50)
            else:
                log('Starting new training run at commit: %s' % commit, slack=True)  # If no previous checkpoint is used, log that this is a new training run

            start_step = sess.run(global_step)  # Get the starting step

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():  # While the threads do not need to stop
                start_time = time.time()  # Record the start time (seconds since the Unix epoch, 1970-01-01 UTC)
                step, loss, opt = sess.run(
                    [global_step, model.loss_without_coeff, model.optimize],
                    feed_dict=model.get_dummy_feed_dict())  # step is set from global_step; loss comes from loss_without_coeff

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(sess.run(
                        test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                        model.inputs[:num_test],
                        model.linear_outputs[:num_test],
                        model.alignments[:num_test],
                        test_model.inputs[:num_test],
                        test_model.linear_outputs[:num_test],
                        test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                        test_sequences, test_spectrograms, test_alignments = \
                        sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
                                  log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms, test_alignments,
                                  log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
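The multi-speaker variants above restore weights through a `get_most_recent_checkpoint` helper that is not shown. A minimal sketch, assuming TensorFlow's standard `model.ckpt-<step>` naming inside the directory:

import glob

def get_most_recent_checkpoint(checkpoint_dir):
    # Sketch only: prefer TensorFlow's own checkpoint-state file when present,
    # otherwise parse step numbers out of the saved .index files.
    state = tf.train.get_checkpoint_state(checkpoint_dir)
    if state is not None and state.model_checkpoint_path:
        return state.model_checkpoint_path
    paths = glob.glob(os.path.join(checkpoint_dir, 'model.ckpt-*.index'))
    steps = [int(os.path.basename(p).split('-')[-1].split('.')[0]) for p in paths]
    return os.path.join(checkpoint_dir, 'model.ckpt-{}'.format(max(steps)))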
def gst_train(log_dir, args): commit = get_git_commit() if args.git else 'None' save_dir = os.path.join(log_dir, 'gst_pretrained/') checkpoint_path = os.path.join(save_dir, 'gst_model.ckpt') input_path = os.path.join(args.base_dir, args.gst_input) plot_dir = os.path.join(log_dir, 'plots') wav_dir = os.path.join(log_dir, 'wavs') mel_dir = os.path.join(log_dir, 'mel-spectrograms') eval_dir = os.path.join(log_dir, 'eval-dir') eval_plot_dir = os.path.join(eval_dir, 'plots') eval_wav_dir = os.path.join(eval_dir, 'wavs') os.makedirs(eval_dir, exist_ok=True) os.makedirs(plot_dir, exist_ok=True) os.makedirs(wav_dir, exist_ok=True) os.makedirs(mel_dir, exist_ok=True) os.makedirs(eval_plot_dir, exist_ok=True) os.makedirs(eval_wav_dir, exist_ok=True) log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log(hparams_debug_string()) #Start by setting a seed for repeatability tf.set_random_seed(hparams.random_seed) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) model, stats = model_train_mode(args, feeder, hparams, global_step) eval_model = model_test_mode(args, feeder, hparams, global_step) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) #Memory allocation on the GPU as needed config = tf.ConfigProto() config.gpu_options.allow_growth = True # Train! with tf.Session(config=config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) checkpoint_state = False #saved model restoring if args.restore_step: #Restore saved model if the user requested it, Default = True. try: checkpoint_state = tf.train.get_checkpoint_state(save_dir) except tf.errors.OutOfRangeError as e: log('Cannot restore checkpoint: {}'.format(e)) if (checkpoint_state and checkpoint_state.model_checkpoint_path): log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path)) saver.restore(sess, checkpoint_state.model_checkpoint_path) else: if not args.restore_step: log('Starting new training!') else: log('No model to load at {}'.format(save_dir)) feeder.start_in_session(sess) while not coord.should_stop() and step < args.gst_train_steps: start_time = time.time() step, loss, opt = sess.run([global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' 
% (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.eval_interval == 0: #Run eval and save eval stats log('\nRunning evaluation at step {}'.format(step)) eval_losses = [] linear_losses = [] #TODO: FIX TO ENCOMPASS MORE LOSS for i in tqdm(range(feeder.test_steps)): eloss, linear_loss, mel_p, mel_t, t_len, align, lin_p = sess.run([eval_model.loss, eval_model.linear_loss, eval_model.mel_outputs[0], eval_model.mel_targets[0], eval_model.targets_lengths[0], eval_model.alignments[0], eval_model.linear_outputs[0]]) eval_losses.append(eloss) linear_losses.append(linear_loss) eval_loss = sum(eval_losses) / len(eval_losses) linear_loss = sum(linear_losses) / len(linear_losses) wav = audio.inv_linear_spectrogram(lin_p.T) audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-waveform-linear.wav'.format(step))) log('Saving eval log to {}..'.format(eval_dir)) #Save some log to monitor model improvement on same unseen sequence wav = audio.inv_mel_spectrogram(mel_p.T) audio.save_wav(wav, os.path.join(eval_wav_dir, 'step-{}-eval-waveform-mel.wav'.format(step))) plot.plot_alignment(align, os.path.join(eval_plot_dir, 'step-{}-eval-align.png'.format(step)), info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, eval_loss), max_len=t_len // hparams.outputs_per_step) plot.plot_spectrogram(mel_p, os.path.join(eval_plot_dir, 'step-{}-eval-mel-spectrogram.png'.format(step)), info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, eval_loss), target_spectrogram=mel_t, ) log('Eval loss for global step {}: {:.3f}'.format(step, eval_loss)) log('Writing eval summary!') add_eval_stats(summary_writer, step, linear_loss, eval_loss) if step % args.checkpoint_interval == 0 or step == args.gst_train_steps: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, mel_pred, alignment, target, target_len = sess.run([model.inputs[0], model.mel_outputs[0], model.alignments[0], model.mel_targets[0], model.targets_lengths[0], ]) #save predicted mel spectrogram to disk (debug) mel_filename = 'mel-prediction-step-{}.npy'.format(step) np.save(os.path.join(mel_dir, mel_filename), mel_pred.T, allow_pickle=False) #save griffin lim inverted wav for debug (mel -> wav) wav = audio.inv_mel_spectrogram(mel_pred.T) audio.save_wav(wav, os.path.join(wav_dir, 'step-{}-wave-from-mel.wav'.format(step))) #save alignment plot to disk (control purposes) plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)), info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss), max_len=target_len // hparams.outputs_per_step) #save real and predicted mel-spectrogram plot to disk (control purposes) plot.plot_spectrogram(mel_pred, os.path.join(plot_dir, 'step-{}-mel-spectrogram.png'.format(step)), info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, loss), target_spectrogram=target, max_len=target_len) log('Input at step {}: {}'.format(step, sequence_to_text(input_seq))) log('GST Taco training complete after {} global steps!'.format(args.gst_train_steps)) return save_dir except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
def train(log_dir, args): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as _: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as _: model = create_model(args.model, hparams) model.initialize(feeder.inputs, args.vgg19_pretrained_model, feeder.mel_targets, feeder.linear_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # Bookkeeping: time_window = ValueWindow() loss_window = ValueWindow() saver = tf.train.Saver(keep_checkpoint_every_n_hours=2) # Train! with tf.Session() as sess: try: train_start_time = time.time() summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (checkpoint_path, args.restore_step) checkpoint_saver = tf.train.import_meta_graph( '%s.%s' % (restore_path, 'meta')) checkpoint_saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit)) else: log('Starting new training run at commit: %s' % commit) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.summary_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step)) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio...') _, spectrogram, _ = sess.run([ model.inputs[0], model.linear_outputs[0], model.alignments[0] ]) waveform = audio.inv_spectrogram(spectrogram.T) audio_path = os.path.join(log_dir, 'step-%d-audio.wav' % step) audio.save_wav(waveform, audio_path) infolog.upload_to_slack(audio_path, step) time_so_far = time.time() - train_start_time hrs, rest = divmod(time_so_far, 3600) min, secs = divmod(rest, 60) log('{:.0f} hrs, {:.0f}mins and {:.1f}sec since the training process began' .format(hrs, min, secs)) if asked_to_stop(step): coord.request_stop() except Exception as e: log('@channel: Exiting due to exception: %s' % e) traceback.print_exc() coord.request_stop(e)
def train(log_dir, args):
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)

    # Print the model path information
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)

    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Initialize the model
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = Tacotron(hparams)
        model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets,
                         feeder.linear_targets, feeder.stop_token_targets, global_step)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=1)

    # Start training
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if step % args.summary_interval == 0:
                    summary_writer.add_summary(sess.run(stats), step)

                # Save a checkpoint every fixed number of training steps
                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0], model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)  # Synthesize a sample waveform
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    time_string = datetime.now().strftime('%Y-%m-%d %H:%M')
                    # Plot the encoder-decoder alignment
                    infolog.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                        (args.model, time_string, step, loss))
                    # Print the text of the synthesized sample
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
def train(log_dir, args): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) # graph with tf.Graph().as_default(), tf.device('/cpu:0'): #new attributes of hparams #hparams.num_GPU = len(GPUs_id) #hparams.datasets = eval(args.datasets) hparams.datasets = eval(args.datasets) hparams.prenet_layer1 = args.prenet_layer1 hparams.prenet_layer2 = args.prenet_layer2 hparams.gru_size = args.gru_size hparams.attention_size = args.attention_size hparams.rnn_size = args.rnn_size hparams.enable_fv1 = args.enable_fv1 hparams.enable_fv2 = args.enable_fv2 if args.batch_size: hparams.batch_size = args.batch_size # Multi-GPU settings GPUs_id = eval(args.GPUs_id) hparams.num_GPU = len(GPUs_id) tower_grads = [] tower_loss = [] models = [] global_step = tf.Variable(-1, name='global_step', trainable=False) if hparams.decay_learning_rate: learning_rate = _learning_rate_decay(hparams.initial_learning_rate, global_step, hparams.num_GPU) else: learning_rate = tf.convert_to_tensor(hparams.initial_learning_rate) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: input_path = os.path.join(args.base_dir, args.input) feeder = DataFeeder(coord, input_path, hparams) inputs = feeder.inputs inputs = tf.split(inputs, hparams.num_GPU, 0) input_lengths = feeder.input_lengths input_lengths = tf.split(input_lengths, hparams.num_GPU, 0) mel_targets = feeder.mel_targets mel_targets = tf.split(mel_targets, hparams.num_GPU, 0) linear_targets = feeder.linear_targets linear_targets = tf.split(linear_targets, hparams.num_GPU, 0) # Set up model: with tf.variable_scope('model') as scope: optimizer = tf.train.AdamOptimizer(learning_rate, hparams.adam_beta1, hparams.adam_beta2) for i, GPU_id in enumerate(GPUs_id): with tf.device('/gpu:%d' % GPU_id): with tf.name_scope('GPU_%d' % GPU_id): if hparams.enable_fv1 or hparams.enable_fv2: net = ResCNN(data=mel_targets[i], batch_size=hparams.batch_size, hyparam=hparams) net.inference() voice_print_feature = tf.reduce_mean( net.features, 0) else: voice_print_feature = None models.append(None) models[i] = create_model(args.model, hparams) models[i].initialize( inputs=inputs[i], input_lengths=input_lengths[i], mel_targets=mel_targets[i], linear_targets=linear_targets[i], voice_print_feature=voice_print_feature) models[i].add_loss() """L2 weight decay loss.""" if args.weight_decay > 0: costs = [] for var in tf.trainable_variables(): #if var.op.name.find(r'DW') > 0: costs.append(tf.nn.l2_loss(var)) # tf.summary.histogram(var.op.name, var) weight_decay = tf.cast(args.weight_decay, tf.float32) cost = models[i].loss models[i].loss += tf.multiply( weight_decay, tf.add_n(costs)) cost_pure_wd = tf.multiply(weight_decay, tf.add_n(costs)) else: cost = models[i].loss cost_pure_wd = tf.constant([0]) tower_loss.append(models[i].loss) tf.get_variable_scope().reuse_variables() models[i].add_optimizer(global_step, optimizer) tower_grads.append(models[i].gradients) # calculate average gradient gradients = average_gradients(tower_grads) stats = add_stats(models[0], gradients, learning_rate) time.sleep(10) # apply average gradient with tf.control_dependencies(tf.get_collection( tf.GraphKeys.UPDATE_OPS)): apply_gradient_op = optimizer.apply_gradients( gradients, global_step=global_step) # Bookkeeping: step = 0 
time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) # Train! config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) else: log('Starting new training run at commit: %s' % commit, slack=True) feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() model = models[0] step, loss, opt, loss_wd, loss_pure_wd = sess.run([ global_step, cost, apply_gradient_op, model.loss, cost_pure_wd ]) feeder._batch_in_queue -= 1 log('feed._batch_in_queue: %s' % str(feeder._batch_in_queue), slack=True) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, loss_wd=%.05f, loss_pure_wd=%.05f]' % ( step, time_window.average, loss, loss_window.average, loss_wd, loss_pure_wd) log(message, slack=(step % args.checkpoint_interval == 0)) #if the gradient seems to explode, then restore to the previous step if loss > 2 * loss_window.average or math.isnan(loss): log('recover to the previous checkpoint') #tf.reset_default_graph() restore_step = int( (step - 10) / args.checkpoint_interval ) * args.checkpoint_interval restore_path = '%s-%d' % (checkpoint_path, restore_step) saver.restore(sess, restore_path) continue if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) raise Exception('Loss Exploded') try: if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) except: pass if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, spectrogram, alignment = sess.run([ model.inputs[0], model.linear_outputs[0], model.alignments[0] ]) waveform = audio.inv_spectrogram(spectrogram.T) audio.save_wav( waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step)) plot.plot_alignment( alignment, os.path.join(log_dir, 'step-%d-align.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss)) log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
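The multi-GPU variant above relies on an `average_gradients(tower_grads)` helper that is not included. The usual multi-tower pattern averages each variable's gradient across towers; a sketch, assuming every tower yields its `(gradient, variable)` pairs for the same variables in the same order:

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU tower,
    # as collected from models[i].gradients above. Sketch of the standard
    # synchronous multi-tower averaging.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        var = grad_and_vars[0][1]
        average_grads.append((grad, var))
    return average_grads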
def train(log_dir, args): commit = get_git_commit() if args.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') DATA_PATH = {'bznsyp': "BZNSYP", 'ljspeech': "LJSpeech-1.1"}[args.dataset] input_path = os.path.join(args.base_dir, 'DATA', DATA_PATH, 'training', 'train.txt') log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from: %s' % input_path) log('Using model: %s' % args.model) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = DataFeeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs, feeder.input_lengths, feeder.lpc_targets, feeder.stop_token_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=999, keep_checkpoint_every_n_hours=2) # Train! config = tf.ConfigProto() config.gpu_options.allow_growth = True with tf.Session(config=config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. checkpoint_state = tf.train.get_checkpoint_state(log_dir) # restore_path = '%s-%d' % (checkpoint_path, args.restore_step) if checkpoint_state is not None: saver.restore(sess, checkpoint_state.model_checkpoint_path) log('Resuming from checkpoint: %s at commit: %s' % (checkpoint_state.model_checkpoint_path, commit), slack=True) else: log('Starting new training run at commit: %s' % commit, slack=True) if args.restore_decoder: models = [ f for f in os.listdir('pretrain') if f.find('.meta') != -1 ] decoder_ckpt_path = os.path.join( 'pretrain', models[0].replace('.meta', '')) global_vars = tf.global_variables() var_list = [] valid_scope = [ 'model/inference/decoder', 'model/inference/post_cbhg', 'model/inference/dense', 'model/inference/memory_layer' ] for v in global_vars: if v.name.find('attention') != -1: continue if v.name.find('Attention') != -1: continue for scope in valid_scope: if v.name.startswith(scope): var_list.append(v) decoder_saver = tf.train.Saver(var_list) decoder_saver.restore(sess, decoder_ckpt_path) print('restore pretrained decoder ...') feeder.start_in_session(sess) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % args.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' 
% (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') input_seq, lpc_targets, alignment = sess.run([ model.inputs[0], model.lpc_outputs[0], model.alignments[0] ]) plot.plot_alignment( alignment, os.path.join(log_dir, 'step-%d-align.png' % step), info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss)) np.save(os.path.join(log_dir, 'step-%d-lpc.npy' % step), lpc_targets) log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
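Most variants synthesize audio with `audio.inv_spectrogram`, which is not reproduced here. Conceptually it undoes the dB normalization and then runs Griffin-Lim to estimate phase; a simplified sketch with librosa, where the parameter names and default values (`n_fft`, `hop_length`, `win_length`, `griffin_lim_iters`, `ref_level_db`, `min_level_db`) are placeholders rather than the exact upstream hyperparameters:

import numpy as np
import librosa

def inv_spectrogram_sketch(spectrogram_db, n_fft=2048, hop_length=275,
                           win_length=1100, griffin_lim_iters=60,
                           ref_level_db=20, min_level_db=-100):
    # Sketch: map the normalized dB spectrogram (freq_bins x frames) back to a
    # linear magnitude, then alternate ISTFT/STFT to estimate phase (Griffin-Lim).
    db = np.clip(spectrogram_db, 0, 1) * -min_level_db + min_level_db + ref_level_db
    magnitude = np.power(10.0, db * 0.05)  # dB -> linear amplitude

    angles = np.exp(2j * np.pi * np.random.rand(*magnitude.shape))
    y = librosa.istft(magnitude.astype(np.complex64) * angles,
                      hop_length=hop_length, win_length=win_length)
    for _ in range(griffin_lim_iters):
        angles = np.exp(1j * np.angle(librosa.stft(y, n_fft=n_fft,
                                                   hop_length=hop_length,
                                                   win_length=win_length)))
        y = librosa.istft(magnitude * angles,
                          hop_length=hop_length, win_length=win_length)
    return y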