def train(log_dir, args):
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % args.checkpoint_interval == 0))

        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
          log('Input: %s' % sequence_to_text(input_seq))

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
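# Illustrative sketch only (not part of the original sources): ValueWindow is
# used above as a running-average helper (values are append()ed each step and
# the .average property is read for logging). Assuming exactly that behavior,
# a minimal implementation could look like this:
class ValueWindow:
  def __init__(self, window_size=100):
    self._window_size = window_size
    self._values = []

  def append(self, x):
    self._values.append(x)
    # Drop the oldest entries once the window is full.
    if len(self._values) > self._window_size:
      self._values = self._values[-self._window_size:]

  @property
  def average(self):
    # Average over whatever has been collected so far (avoids division by zero).
    return sum(self._values) / max(1, len(self._values))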
def train(log_dir, args):
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % args.checkpoint_interval == 0))

        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
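# Illustrative sketch only: the train() functions above do not include their
# command-line entry point. Assuming flag names implied by the attributes they
# read (args.base_dir, args.input, args.model, args.restore_step, args.git,
# args.summary_interval, args.checkpoint_interval), a minimal driver might look
# like the following. The defaults and the log-dir layout are guesses, not the
# original code.
def example_main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--base_dir', default=os.path.expanduser('~/tacotron'))
  parser.add_argument('--input', default='training/train.txt')
  parser.add_argument('--model', default='tacotron')
  parser.add_argument('--restore_step', type=int, default=None)
  parser.add_argument('--summary_interval', type=int, default=100)
  parser.add_argument('--checkpoint_interval', type=int, default=1000)
  parser.add_argument('--git', action='store_true')
  args = parser.parse_args()

  # One run directory per model name; train() writes checkpoints and audio here.
  log_dir = os.path.join(args.base_dir, 'logs-%s' % args.model)
  os.makedirs(log_dir, exist_ok=True)
  train(log_dir, args)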
def train(log_dir, config):
  config.data_paths = config.data_paths

  data_dirs = [os.path.join(data_path, "data")
               for data_path in config.data_paths]
  num_speakers = len(data_dirs)
  config.num_test = config.num_test_per_speaker * num_speakers

  if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
    raise Exception("[!] Unknown model_type for multi-speaker: {}".format(
        config.model_type))

  commit = get_git_commit() if config.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')

  log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
  log('=' * 50)
  #log(' [*] git diff:\n%s' % get_git_diff())
  log('=' * 50)
  log(' [*] Checkpoint path: %s' % checkpoint_path)
  log(' [*] Loading training data from: %s' % data_dirs)
  log(' [*] Using model: %s' % config.model_dir)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    train_feeder = DataFeeder(coord, data_dirs, hparams, config, 32,
                              data_type='train', batch_size=hparams.batch_size)
    test_feeder = DataFeeder(coord, data_dirs, hparams, config, 8,
                             data_type='test', batch_size=config.num_test)

  # Set up model:
  is_randomly_initialized = config.initialize_path is None
  # Counts training steps; trainable=False keeps it out of the optimizer's variables.
  global_step = tf.Variable(0, name='global_step', trainable=False)

  # Build the Tacotron model.
  with tf.variable_scope('model') as scope:
    model = create_model(hparams)
    model.initialize(train_feeder.inputs, train_feeder.input_lengths,
                     num_speakers, train_feeder.speaker_id,
                     train_feeder.mel_targets, train_feeder.linear_targets,
                     train_feeder.loss_coeff,
                     is_randomly_initialized=is_randomly_initialized)
    model.add_loss()
    model.add_optimizer(global_step)
    train_stats = add_stats(model, scope_name='stats')  # legacy

  with tf.variable_scope('model', reuse=True) as scope:
    test_model = create_model(hparams)
    test_model.initialize(test_feeder.inputs, test_feeder.input_lengths,
                          num_speakers, test_feeder.speaker_id,
                          test_feeder.mel_targets, test_feeder.linear_targets,
                          test_feeder.loss_coeff,
                          rnn_decoder_test_mode=True,
                          is_randomly_initialized=is_randomly_initialized)
    test_model.add_loss()

  test_stats = add_stats(test_model, model, scope_name='test')
  test_stats = tf.summary.merge([test_stats, train_stats])

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

  sess_config = tf.ConfigProto(log_device_placement=False,
                               allow_soft_placement=True)
  sess_config.gpu_options.allow_growth = True

  # Train!
  #with tf.Session(config=sess_config) as sess:
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      # First check whether an existing checkpoint was passed as an argument.
      if config.load_path:
        # Restore from a checkpoint if the user requested it.
        restore_path = get_most_recent_checkpoint(config.model_dir)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      elif config.initialize_path:
        restore_path = get_most_recent_checkpoint(config.initialize_path)
        saver.restore(sess, restore_path)
        log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)

        # Added code for recognizing whole global steps by soundbrew
        reg_step = re.compile(r"model.ckpt-(\d+)")
        # reg_initial_target = re.compile(r"logs/(\w+)_")
        # reg_new_target = re.compile(r"datasets/(\w+)")
        m1 = reg_step.search(restore_path)
        # m2 = reg_initial_target.search(restore_path)
        # m3 = reg_new_target.search(config.data_paths)
        prev_global_step = m1.group(1)
        # prev_target = m2.group(1)
        # new_target = m3.group(1)

        #if new_target == prev_target:
        #  zero_step_assign = tf.assign(global_step, int(prev_global_step))
        #else:
        #  zero_step_assign = tf.assign(global_step, 0)
        zero_step_assign = tf.assign(global_step, int(prev_global_step))
        sess.run(zero_step_assign)

        start_step = sess.run(global_step)
        log('=' * 50)
        log(' [*] Global step is reset to {}'.format(start_step))
        log('=' * 50)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      start_step = sess.run(global_step)
      train_feeder.start_in_session(sess, start_step)
      test_feeder.start_in_session(sess, start_step)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run(
            [global_step, model.loss_without_coeff, model.optimize],
            feed_dict=model.get_dummy_feed_dict())

        time_window.append(time.time() - start_time)
        loss_window.append(loss)

        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
            step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % config.checkpoint_interval == 0))

        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % config.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          feed_dict = {
              **model.get_dummy_feed_dict(),
              **test_model.get_dummy_feed_dict()
          }
          summary_writer.add_summary(
              sess.run(test_stats, feed_dict=feed_dict), step)

        if step % config.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)

        if step % config.test_interval == 0:
          log('Saving audio and alignment...')
          num_test = config.num_test
          fetches = [
              model.inputs[:num_test],
              model.linear_outputs[:num_test],
              model.alignments[:num_test],
              test_model.inputs[:num_test],
              test_model.linear_outputs[:num_test],
              test_model.alignments[:num_test],
          ]
          feed_dict = {
              **model.get_dummy_feed_dict(),
              **test_model.get_dummy_feed_dict()
          }

          sequences, spectrograms, alignments, \
              test_sequences, test_spectrograms, test_alignments = \
              sess.run(fetches, feed_dict=feed_dict)

          save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
                        log_dir, step, loss, "train")
          save_and_plot(test_sequences, test_spectrograms, test_alignments,
                        log_dir, step, loss, "test")

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
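# Illustrative sketch only: get_most_recent_checkpoint() is called above but its
# definition is not part of this section. Assuming checkpoints are written by
# tf.train.Saver as "model.ckpt-NNN" prefixes, one possible implementation is:
def get_most_recent_checkpoint(checkpoint_dir):
  # tf.train.latest_checkpoint reads the 'checkpoint' bookkeeping file that
  # tf.train.Saver maintains and returns the newest checkpoint prefix, or None.
  path = tf.train.latest_checkpoint(checkpoint_dir)
  if path is None:
    # Fall back to scanning the directory for "model.ckpt-NNN" prefixes.
    steps = []
    for name in os.listdir(checkpoint_dir):
      m = re.search(r"model\.ckpt-(\d+)", name)
      if m:
        steps.append(int(m.group(1)))
    if steps:
      path = os.path.join(checkpoint_dir, 'model.ckpt-%d' % max(steps))
  return path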
def train(log_dir, args):
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')

  # Input paths cover both the positive and the negative training data.
  input_path_pos = os.path.join(args.base_dir, args.input_pos)
  input_path_neg = os.path.join(args.base_dir, args.input_neg)

  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading positive training data from: %s' % input_path_pos)
  log('Loading negative training data from: %s' % input_path_neg)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    feeder = DataFeeder(coord, input_path_pos, input_path_neg, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model') as scope:
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs_pos, feeder.input_lengths_pos,
                     feeder.mel_targets_pos, feeder.linear_targets_pos,
                     feeder.mel_targets_neg, feeder.linear_targets_neg,
                     feeder.labels_pos, feeder.labels_neg)
    model.add_loss()
    model.add_optimizer(global_step)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  with tf.Session() as sess:
    try:
      #summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if args.restore_step:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()

        # Train the discriminator.
        sess.run(model.d_optimize)

        # Train the generator.
        step, rec_loss, style_loss, d_loss, g_loss, _ = sess.run([
            global_step, model.rec_loss, model.style_loss, model.d_loss,
            model.g_loss, model.g_optimize
        ])

        time_window.append(time.time() - start_time)
        message = 'Step %-7d [%.03f sec/step, rec_loss=%.05f, style_loss=%.05f, d_loss=%.05f, g_loss=%.05f]' % (
            step, time_window.average, rec_loss, style_loss, d_loss, g_loss)
        log(message, slack=(step % args.checkpoint_interval == 0))

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')

          input_seq, spectrogram_pos, spectrogram_neg, alignment_pos, alignment_neg = sess.run([
              model.inputs[0], model.linear_outputs_pos[0],
              model.linear_outputs_neg[0], model.alignments_pos[0],
              model.alignments_neg[0]
          ])

          waveform_pos = audio.inv_spectrogram(spectrogram_pos.T)
          waveform_neg = audio.inv_spectrogram(spectrogram_neg.T)
          audio.save_wav(waveform_pos, os.path.join(log_dir, 'step-%d-audio_pos.wav' % step))
          audio.save_wav(waveform_neg, os.path.join(log_dir, 'step-%d-audio_neg.wav' % step))

          plot.plot_alignment(
              alignment_pos,
              os.path.join(log_dir, 'step-%d-align_pos.png' % step),
              info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, rec_loss))
          plot.plot_alignment(
              alignment_neg,
              os.path.join(log_dir, 'step-%d-align_neg.png' % step),
              info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, rec_loss))
          log('Input: %s' % sequence_to_text(input_seq))

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
def train(log_dir, config):
  config.data_paths = config.data_paths

  sub_dirs = ['A11', 'A12', 'A13', 'A14', 'A19', 'A2', 'A22', 'A23', 'A32', 'A33',
              'A34', 'A36', 'A4', 'A6', 'A7', 'A8', 'B11', 'B12', 'B15', 'B2',
              'B22', 'B31', 'B32', 'B33', 'B4', 'B6', 'B7', 'B8', 'C12', 'C13',
              'C14', 'C17', 'C18', 'C19', 'C2', 'C20', 'C21', 'C22', 'C23', 'C31',
              'C32', 'C4', 'C6', 'C7', 'C8', 'D11', 'D12', 'D13', 'D21', 'D31',
              'D32', 'D4', 'D6', 'D7', 'D8']
  data_dirs = [os.path.join(config.data_paths[0], sub_dir) for sub_dir in sub_dirs]
  num_speakers = len(data_dirs)
  setattr(hparams, "num_speakers", len(data_dirs))
  config.num_test = config.num_test_per_speaker * num_speakers

  if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
    raise Exception("[!] Unknown model_type for multi-speaker: {}".format(config.model_type))

  commit = get_git_commit() if config.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')

  log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
  log('=' * 50)
  #log(' [*] git diff:\n%s' % get_git_diff())
  log('=' * 50)
  log(' [*] Checkpoint path: %s' % checkpoint_path)
  log(' [*] Loading training data from: %s' % data_dirs)
  log(' [*] Using model: %s' % config.model_dir)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder') as scope:
    train_feeder = DataFeeder(
        coord, data_dirs, hparams, config, 32,
        data_type='train', batch_size=hparams.batch_size)
    test_feeder = DataFeeder(
        coord, data_dirs, hparams, config, 8,
        data_type='test', batch_size=config.num_test)

  # Set up model:
  is_randomly_initialized = config.initialize_path is None
  global_step = tf.Variable(0, name='global_step', trainable=False)

  with tf.variable_scope('model') as scope:
    model = create_model(hparams)
    model.initialize(
        train_feeder.inputs, train_feeder.input_lengths,
        num_speakers, train_feeder.speaker_id,
        train_feeder.mel_targets, train_feeder.linear_targets,
        train_feeder.loss_coeff,
        is_randomly_initialized=is_randomly_initialized)
    model.add_loss()
    model.add_optimizer(global_step)
    train_stats = add_stats(model, scope_name='stats')  # legacy

  with tf.variable_scope('model', reuse=True) as scope:
    test_model = create_model(hparams)
    test_model.initialize(
        test_feeder.inputs, test_feeder.input_lengths,
        num_speakers, test_feeder.speaker_id,
        test_feeder.mel_targets, test_feeder.linear_targets,
        test_feeder.loss_coeff,
        rnn_decoder_test_mode=True,
        is_randomly_initialized=is_randomly_initialized)
    test_model.add_loss()

  test_stats = add_stats(test_model, model, scope_name='test')
  test_stats = tf.summary.merge([test_stats, train_stats])

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

  sess_config = tf.ConfigProto(
      log_device_placement=False, allow_soft_placement=True)
  sess_config.gpu_options.allow_growth = True

  # Train!
  #with tf.Session(config=sess_config) as sess:
  with tf.Session() as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      if config.load_path:
        # Restore from a checkpoint if the user requested it.
        restore_path = get_most_recent_checkpoint(config.model_dir)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      elif config.initialize_path:
        restore_path = get_most_recent_checkpoint(config.initialize_path)
        saver.restore(sess, restore_path)
        log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)

        zero_step_assign = tf.assign(global_step, 0)
        sess.run(zero_step_assign)

        start_step = sess.run(global_step)
        log('=' * 50)
        log(' [*] Global step is reset to {}'.format(start_step))
        log('=' * 50)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      start_step = sess.run(global_step)
      train_feeder.start_in_session(sess, start_step)
      test_feeder.start_in_session(sess, start_step)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run(
            [global_step, model.loss_without_coeff, model.optimize],
            feed_dict=model.get_dummy_feed_dict())

        time_window.append(time.time() - start_time)
        loss_window.append(loss)

        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
            step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % config.checkpoint_interval == 0))

        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % config.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          feed_dict = {
              **model.get_dummy_feed_dict(),
              **test_model.get_dummy_feed_dict()
          }
          summary_writer.add_summary(sess.run(
              test_stats, feed_dict=feed_dict), step)

        if step % config.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)

        if step % config.test_interval == 0:
          log('Saving audio and alignment...')
          num_test = config.num_test
          fetches = [
              model.inputs[:num_test],
              model.linear_outputs[:num_test],
              model.alignments[:num_test],
              test_model.inputs[:num_test],
              test_model.linear_outputs[:num_test],
              test_model.alignments[:num_test],
          ]
          feed_dict = {
              **model.get_dummy_feed_dict(),
              **test_model.get_dummy_feed_dict()
          }

          sequences, spectrograms, alignments, \
              test_sequences, test_spectrograms, test_alignments = \
              sess.run(fetches, feed_dict=feed_dict)

          save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
                        log_dir, step, loss, "train")
          save_and_plot(test_sequences, test_spectrograms, test_alignments,
                        log_dir, step, loss, "test")

    except Exception as e:
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/hparams.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-model_name', type=str, default='P_S_Transformer_debug',
                        help='model name')
    # parser.add_argument('-batches_per_allreduce', type=int, default=1,
    #                     help='number of batches processed locally before '
    #                          'executing allreduce across workers; it multiplies '
    #                          'total batch size.')
    parser.add_argument('-num_workers', type=int, default=0,
                        help='how many subprocesses to use for data loading. '
                             '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    log_name = opt.model_name or config.model.name
    log_folder = os.path.join(os.getcwd(), 'logdir/logging', log_name)
    if not os.path.isdir(log_folder):
        # makedirs so that intermediate directories are created as well.
        os.makedirs(log_folder)
    logger = init_logger(log_folder + '/' + opt.log)

    # TODO: build dataloader
    train_datafeeder = DataFeeder(config, 'debug')

    # TODO: build model or load pre-trained model
    global global_step
    global_step = 0
    learning_rate = CustomSchedule(config.model.d_model)
    # learning_rate = 0.00002
    optimizer = tf.keras.optimizers.Adam(learning_rate,
                                         beta_1=config.optimizer.beta1,
                                         beta_2=config.optimizer.beta2,
                                         epsilon=config.optimizer.epsilon)
    logger.info('config.optimizer.beta1:' + str(config.optimizer.beta1))
    logger.info('config.optimizer.beta2:' + str(config.optimizer.beta2))
    logger.info('config.optimizer.epsilon:' + str(config.optimizer.epsilon))
    # print(str(config))

    model = Speech_transformer(config=config, logger=logger)

    # Create the checkpoint path and the checkpoint manager.
    # This will be used to save checkpoints every n epochs.
    checkpoint_path = log_folder
    ckpt = tf.train.Checkpoint(transformer=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # If a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logger.info('Latest checkpoint restored!!')
    else:
        logger.info('Start new run')

    # Define metrics and summary writer.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    # summary_writer = tf.keras.callbacks.TensorBoard(log_dir=log_folder)
    summary_writer = summary_ops_v2.create_file_writer_v2(log_folder + '/train')

    # @tf.function
    def train_step(batch_data):
        inp = batch_data['the_inputs']    # batch * time * feature
        tar = batch_data['the_labels']    # batch * time
        # inp_len = batch_data['input_length']
        # tar_len = batch_data['label_length']
        gtruth = batch_data['ground_truth']
        tar_inp = tar
        tar_real = gtruth
        # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp[:, :, 0], tar_inp)
        combined_mask = create_combined_mask(tar=tar_inp)

        with tf.GradientTape() as tape:
            predictions, _ = model(inp, tar_inp, True, None, combined_mask, None)
            # logger.info('config.train.label_smoothing_epsilon:' + str(config.train.label_smoothing_epsilon))
            loss = LableSmoothingLoss(tar_real, predictions, config.model.vocab_size,
                                      config.train.label_smoothing_epsilon)

        gradients = tape.gradient(loss, model.trainable_variables)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
        # Apply the clipped gradients; applying the unclipped ones would make the
        # clipping above a no-op.
        optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(tar_real, predictions)

    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)

    logger.info('config.train.epoches:' + str(config.train.epoches))
    first_time = True
    for epoch in range(config.train.epoches):
        logger.info('start epoch ' + str(epoch))
        logger.info('total wavs: ' + str(len(train_datafeeder)))
        logger.info('batch size: ' + str(train_datafeeder.batch_size))
        logger.info('batch per epoch: ' + str(len(train_datafeeder) // train_datafeeder.batch_size))
        train_data = train_datafeeder.get_batch()
        start_time = time.time()
        train_loss.reset_states()
        train_accuracy.reset_states()

        for step in range(len(train_datafeeder) // train_datafeeder.batch_size):
            batch_data = next(train_data)
            step_time = time.time()
            train_step(batch_data)
            if first_time:
                model.summary()
                first_time = False
            time_window.append(time.time() - step_time)
            loss_window.append(train_loss.result())
            acc_window.append(train_accuracy.result())
            message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % (
                global_step, time_window.average, train_loss.result(), loss_window.average,
                train_accuracy.result(), acc_window.average)
            logger.info(message)

            if global_step % 10 == 0:
                with summary_ops_v2.always_record_summaries():
                    with summary_writer.as_default():
                        summary_ops_v2.scalar('train_loss', train_loss.result(), step=global_step)
                        summary_ops_v2.scalar('train_acc', train_accuracy.result(), step=global_step)
            global_step += 1

        ckpt_save_path = ckpt_manager.save()
        logger.info('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
        logger.info('Time taken for 1 epoch: {} secs\n'.format(time.time() - start_time))
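# Illustrative sketch only: CustomSchedule is used above but not defined in this
# section. Assuming it follows the standard Transformer warmup schedule
# (lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5), as in
# "Attention Is All You Need"), a minimal version could look like this:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Linear warmup for warmup_steps, then inverse-square-root decay.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)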