import argparse
import os
import time
import traceback

import numpy as np
import tensorflow as tf

# Project-local helpers. This file does not show where these modules live, so
# the import paths below are assumptions; adjust them to match the repo layout.
from feeder import Feeder
from hparams import hparams, hparams_debug_string
from infolog import log
from models import Pronet, create_model
from utils import add_stats, plot, sequence_to_text, time_string


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', default='.', help='Directory to save model checkpoints and logs to')
    parser.add_argument('--data_dir', default='./data/casp7', help='Directory containing the training data')
    parser.add_argument('--restore_dir', default=None, help='Checkpoint directory to restore training from')
    parser.add_argument('--param_path', default='./hparams.json', help='Hyperparameter file')
    args = parser.parse_args()

    # Fail fast when no GPU is visible, then let TensorFlow grow GPU memory
    # on demand instead of grabbing it all up front.
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    assert len(physical_devices) > 0, 'Not enough GPU hardware devices available'
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

    tf.random.set_seed(117)

    hparams = Params(args.param_path)
    model = Pronet(hparams)
    feeder = Feeder(args.data_dir, hparams)

    #if args.restore:
    #    pass
    #else:
    #    pass

    train(model, feeder, hparams)
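
# A minimal sketch of the `Params` helper used above, assuming it simply
# exposes the keys of hparams.json as attributes. The real class is not shown
# in this file, so treat this as illustrative, not the repo's implementation.
import json

class Params:
    """Hypothetical hyperparameter container backed by a JSON file."""

    def __init__(self, param_path):
        with open(param_path, 'r') as f:
            # e.g. {"batch_size": 32, "lr": 1e-3} -> self.batch_size, self.lr
            self.__dict__.update(json.load(f))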
def train(log_dir, args):
    save_dir = os.path.join(log_dir, 'pretrained/')
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    plot_dir = os.path.join(log_dir, 'plots')
    os.makedirs(plot_dir, exist_ok=True)
    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    #Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, hparams)

    #Set up model
    step_count = 0
    try:
        #simple text file to keep count of the global step
        with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file:
            step_count = int(file.read())
    except (FileNotFoundError, ValueError):
        print('No step_counter file found, assuming there is no saved checkpoint')

    global_step = tf.Variable(step_count, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    #Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5)

    #Allocate GPU memory as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            #Restore the saved model if the user requested it. Default = True
            checkpoint_state = None  # avoids a NameError when --restore is off
            if args.restore:
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e))

            if checkpoint_state and checkpoint_state.model_checkpoint_path:
                log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
                saver.restore(sess, checkpoint_state.model_checkpoint_path)
            elif not args.restore:
                log('Starting new training!')
            else:
                log('No model to load at {}'.format(save_dir))

            #Start the feeder
            feeder.start_in_session(sess)

            #Training loop
            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run([global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message, end='\r')

                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step: {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    with open(os.path.join(log_dir, 'step_counter.txt'), 'w') as file:
                        file.write(str(step))
                    log('Saving checkpoint to: {}-{}'.format(checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                    #Unlike the original Tacotron, we don't save audio here
                    #because we have yet to integrate WaveNet as the vocoder.
                    log('Saving alignment and mel-spectrograms..')
                    input_seq, prediction, alignment, target = sess.run([
                        model.inputs[0],
                        model.mel_outputs[0],
                        model.alignments[0],
                        model.mel_targets[0],
                    ])

                    #save predicted spectrogram to disk (for plot and manual evaluation purposes)
                    mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(step)
                    np.save(os.path.join(log_dir, mel_filename), prediction,
                            allow_pickle=False)

                    #save alignment plot to disk (control purposes)
                    plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)),
                        info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss))

                    #save real mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(target, os.path.join(plot_dir, 'step-{}-real-mel-spectrogram.png'.format(step)),
                        info='{}, {}, step={}, Real'.format(args.model, time_string(), step))

                    #save predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(prediction, os.path.join(plot_dir, 'step-{}-pred-mel-spectrogram.png'.format(step)),
                        info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss))

                    log('Input at step {}: {}'.format(step, sequence_to_text(input_seq)))

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
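
# A minimal sketch of the `ValueWindow` helper used for the running sec/step
# and loss averages above. Only the `append`/`average` API exercised by the
# training loop is assumed; the original implementation may differ.
from collections import deque

class ValueWindow:
    def __init__(self, window_size=100):
        self._values = deque(maxlen=window_size)  # oldest entries drop out automatically

    def append(self, x):
        self._values.append(x)

    @property
    def average(self):
        # average over whatever has been recorded so far (0.0 when empty)
        return sum(self._values) / len(self._values) if self._values else 0.0


# Standard entry point so the script can be launched directly, e.g.:
#   python train.py --data_dir ./data/casp7 --param_path ./hparams.json
if __name__ == '__main__':
    main()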