def train(logdir, args):
    # TODO: parse ckpt, arguments, hparams
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    # TODO: set up datafeeder
    with tf.variable_scope('datafeeder') as scope:
        hp.data_length = None
        hp.initial_learning_rate = 0.0001
        hp.batch_size = 256
        hp.prime = True
        hp.stcmd = True
        feeder = DataFeeder(args=hp)
        log('num_sentences:' + str(len(feeder.wav_lst)))           # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst) // hp.batch_size
        log('steps_per_epoch:' + str(hp.steps_per_epoch))          # 17725
        log('pinyin_vocab_size:' + str(hp.input_vocab_size))       # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size :' + str(hp.label_vocab_size))       # 6291
    # TODO: set up model
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0, name='valid_global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hp)
        model.build_graph()
        model.add_loss()
        model.add_optimizer(global_step=global_step, loss=model.mean_loss)
    # TODO: summary
    stats = add_stats(model=model)
    valid_stats = add_dev_stats(model)
    # TODO: Set up saver and bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_acc_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True
    # TODO: train
    with tf.Session() as sess:
        log(hparams_debug_string(hp))
        try:
            # TODO: Set writer and initializer
            summary_writer = tf.summary.FileWriter(logdir + '/train', sess.graph)
            summary_writer_dev = tf.summary.FileWriter(logdir + '/dev')
            sess.run(tf.global_variables_initializer())
            # TODO: Restore
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s' % restore_path)
            else:
                log('Starting new training run')
            step = 0
            # TODO: epochs, steps, batches
            for i in range(args.epochs):
                batch_data = feeder.get_lm_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch, label_batch = next(batch_data)
                    feed_dict = {
                        model.x: input_batch,
                        model.y: label_batch,
                    }
                    # TODO: Run one step
                    start_time = time.time()
                    total_step, batch_loss, batch_acc, opt = sess.run(
                        [global_step, model.mean_loss, model.acc, model.optimize],
                        feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step
                    # TODO: Append loss
                    loss_window.append(batch_loss)
                    acc_window.append(batch_acc)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f, lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,
                        batch_acc, acc_window.average, K.get_value(model.learning_rate))
                    log(message)
                    # TODO: Check loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss Exploded')
                    # TODO: Check summary
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats, feed_dict=feed_dict), step)
                    # TODO: Check checkpoint
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')
                        label, final_pred_label = sess.run([model.y, model.preds], feed_dict=feed_dict)
                        log('label.shape :' + str(label.shape))  # (batch_size, label_length)
                        log('final_pred_label.shape:' + str(np.asarray(final_pred_label).shape))  # (1, batch_size, decode_length <= label_length)
                        log('label : ' + str(label[0]))
                        log('final_pred_label: ' + str(np.asarray(final_pred_label)[0]))
                    # TODO: serving
                    if args.serving:  # and total_step // hp.steps_per_epoch > 5:
                        np.save('logdir/lm_pinyin_dict.npy', feeder.pny_vocab)
                        np.save('logdir/lm_hanzi_dict.npy', feeder.han_vocab)
                        log('Exporting serving model at step %d' % total_step)
                        # TODO: Set up serving builder and signature map
                        serve_dir = args.serving_dir + '0001'
                        if os.path.exists(serve_dir):
                            shutil.rmtree(serve_dir)  # the export dir must not already exist
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input_info = tf.saved_model.utils.build_tensor_info(model.x)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.preds)
                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'pinyin': input_info},
                                outputs={'hanzi': output_labels},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess,
                                tags=[tf.saved_model.tag_constants.SERVING],
                                signature_def_map={
                                    'predict_Pinyin2Hanzi': prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True)
                        builder.save()
                        log('Done store serving-model')
                        raise Exception('Done store serving-model')
                    # TODO: Validation
                    # if total_step % hp.steps_per_epoch == 0 and i >= 10:
                    if total_step % hp.steps_per_epoch == 0:
                        log('validation...')
                        valid_start = time.time()
                        # TODO: validation
                        valid_hp = copy.deepcopy(hp)
                        print('feature_type: ', hp.feature_type)
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = True
                        valid_hp.prime = True
                        valid_hp.stcmd = True
                        valid_hp.shuffle = True
                        valid_hp.data_length = None
                        valid_feeder = DataFeeder(args=valid_hp)
                        valid_feeder.pny_vocab = feeder.pny_vocab
                        valid_feeder.han_vocab = feeder.han_vocab
                        # valid_feeder.am_vocab = feeder.am_vocab
                        valid_batch_data = valid_feeder.get_lm_batch()
                        log('valid_num_sentences:' + str(len(valid_feeder.wav_lst)))  # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch:' + str(valid_hp.steps_per_epoch))      # 951
                        log('valid_pinyin_vocab_size:' + str(valid_hp.input_vocab_size))   # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size :' + str(valid_hp.label_vocab_size))   # 3327
                        # a single pass over the dev set is enough
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch, valid_label_batch = next(valid_batch_data)
                                valid_feed_dict = {
                                    model.x: valid_input_batch,
                                    model.y: valid_label_batch,
                                }
                                # TODO: Run one step
                                valid_start_time = time.time()
                                valid_batch_loss, valid_batch_acc = sess.run(
                                    [model.mean_loss, model.acc], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_acc_window.append(valid_batch_acc)
                                # print('loss', loss, 'batch_loss', batch_loss)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, valid_acc=%.05f, avg_acc=%.05f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss,
                                    valid_loss_window.average, valid_batch_acc, valid_acc_window.average)
                                log(message)
                                summary_writer_dev.add_summary(
                                    sess.run(valid_stats, feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                        log('Done Validation! Total Time Cost(sec): ' + str(time.time() - valid_start))
        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
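# The training loops in this file rely on a running-average helper named ValueWindow
# that is not defined here. A minimal sketch consistent with the usage above
# (ValueWindow(100), .append(value), .average) could look like the following;
# the project's actual implementation may differ.
from collections import deque


class ValueWindow:
    """Keeps the most recent `window_size` values and exposes their mean."""

    def __init__(self, window_size=100):
        self._values = deque(maxlen=window_size)

    def append(self, value):
        self._values.append(value)

    @property
    def average(self):
        return sum(self._values) / len(self._values) if self._values else 0.0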
def test(): parser = argparse.ArgumentParser() # TODO: add arguments parser.add_argument('--log_dir', default=os.path.expanduser('~/my_asr2/logdir/logging')) parser.add_argument( '--serving_dir', default=os.path.expanduser('~/my_asr2/logdir/serving_am/')) parser.add_argument('--data_dir', default=os.path.expanduser('~/corpus_zn')) parser.add_argument('--model', default='ASR_wavnet') # parser.add_argument('--epochs', type=int, help='Max epochs to run.', default=100) parser.add_argument('--restore_step', type=int, help='Global step to restore from checkpoint.', default=2100) parser.add_argument('--serving', type=bool, help='', default=False) # parser.add_argument('--validation_interval', type=int, help='一个epoch验证5次,每次200步共3200条数据', default=7090) # 35450//5 parser.add_argument('--summary_interval', type=int, default=1, help='Steps between running summary ops.') # parser.add_argument('--checkpoint_interval', type=int, default=100, help='Steps between writing checkpoints.') parser.add_argument( '--hparams', default='', help= 'Hyperparameter overrides as a comma-separated list of name=value pairs' ) args = parser.parse_args() run_name = args.model logdir = os.path.join(args.log_dir, 'logs-%s' % run_name) init(os.path.join(logdir, 'test.log'), run_name) hp.parse(args.hparams) # TODO:parse ckpt,arguments,hparams checkpoint_path = os.path.join(logdir, 'model.ckpt') input_path = args.data_dir log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from : %s ' % input_path) log('Using model : %s' % args.model) # TODO:set up datafeeder with tf.variable_scope('datafeeder') as scope: hp.data_type = 'test' hp.feature_type = 'mfcc' hp.data_length = None hp.initial_learning_rate = 0.0005 hp.batch_size = 1 hp.aishell = False hp.prime = False hp.stcmd = False hp.AM = True hp.LM = False hp.shuffle = False hp.is_training = False # TODO: 在infer的时候一定要设置为False否则bn会扰乱所有的值! feeder = DataFeeder_wavnet(args=hp) log('num_wavs:' + str(len(feeder.wav_lst))) feeder.am_vocab = np.load('logdir/am_pinyin_dict.npy').tolist() hp.input_vocab_size = len(feeder.am_vocab) hp.final_output_dim = len(feeder.am_vocab) hp.steps_per_epoch = len(feeder.wav_lst) // hp.batch_size log('steps_per_epoch:' + str(hp.steps_per_epoch)) log('pinyin_vocab_size:' + str(hp.input_vocab_size)) # TODO: set up model with tf.variable_scope('model') as scope: model = create_model(args.model, hp) model.build_graph() model.add_loss() model.add_decoder() # model.add_optimizer(global_step=global_step) # TODO: summary stats = add_stats(model) # TODO:Set up saver and Bookkeeping time_window = ValueWindow(100) loss_window = ValueWindow(100) wer_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=20) # TODO: test with tf.Session(graph=tf.get_default_graph()) as sess: log(hparams_debug_string(hp)) try: # TODO: Set writer and initializer summary_writer = tf.summary.FileWriter(logdir + '/test', sess.graph) sess.run(tf.global_variables_initializer()) # TODO: Restore if args.restore_step: # Restore from a checkpoint if the user requested it. 
restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s ' % restore_path) else: log('Starting new training run ') # TODO: epochs steps batch step = 0 batch_data = feeder.get_am_batch() for j in range(hp.steps_per_epoch): input_batch = next(batch_data) feed_dict = { model.inputs: input_batch['the_inputs'], model.labels: input_batch['the_labels'], model.input_lengths: input_batch['input_length'], model.label_lengths: input_batch['label_length'] } # TODO: Run one step start_time = time.time() array_loss, batch_loss, wer, label, final_pred_label = sess.run( [ model.ctc_loss, model.batch_loss, model.WER, model.labels, model.decoded1 ], feed_dict=feed_dict) time_window.append(time.time() - start_time) step = step + 1 # TODO: Append loss loss_window.append(batch_loss) wer_window.append(wer) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, wer=%.05f, avg_wer=%.05f]' % ( step, time_window.average, batch_loss, loss_window.average, wer, wer_window.average) log(message) # TODO: show pred and write summary log('label.shape :' + str(label.shape)) # (batch_size , label_length) log('final_pred_label.shape:' + str(np.asarray(final_pred_label).shape)) log('label : ' + str(label[0])) log('final_pred_label: ' + str(np.asarray(final_pred_label)[0][0])) log('Writing summary at step: %d' % step) summary_writer.add_summary( sess.run(stats, feed_dict=feed_dict), step) # TODO: Check loss if math.isnan(batch_loss): log('Loss exploded to %.05f at step %d!' % (batch_loss, step)) raise Exception('Loss Exploded') log('serving step: ' + str(step)) # TODO: Set up serving builder and signature map serve_dir = args.serving_dir + '0001' if os.path.exists(serve_dir): shutil.rmtree(serve_dir) log('delete exists dirs:' + serve_dir) builder = tf.saved_model.builder.SavedModelBuilder( export_dir=serve_dir) input_spec = tf.saved_model.utils.build_tensor_info(model.inputs) input_len = tf.saved_model.utils.build_tensor_info( model.input_lengths) output_labels = tf.saved_model.utils.build_tensor_info( model.decoded1) output_logits = tf.saved_model.utils.build_tensor_info( model.pred_softmax) prediction_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={ 'mfcc': input_spec, 'len': input_len }, outputs={ 'label': output_labels, 'logits': output_logits }, method_name=tf.saved_model.signature_constants. PREDICT_METHOD_NAME)) builder.add_meta_graph_and_variables( sess=sess, tags=[tf.saved_model.tag_constants.SERVING], signature_def_map={ 'predict_AudioSpec2Pinyin': prediction_signature, }, main_op=tf.tables_initializer(), strip_default_attrs=False) builder.save() log('Done store serving-model') except Exception as e: log('Exiting due to exception: %s' % e) traceback.print_exc()
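# A rough sketch of how the SavedModel exported by test() above could be loaded
# back and queried through its 'predict_AudioSpec2Pinyin' signature using the
# TF1-style loader API. The export directory and the feature arrays are
# placeholders, not paths from this project.
import numpy as np
import tensorflow as tf


def run_exported_am(export_dir, mfcc_batch, input_length):
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        signature = meta_graph.signature_def['predict_AudioSpec2Pinyin']
        feeds = {
            sess.graph.get_tensor_by_name(signature.inputs['mfcc'].name): mfcc_batch,
            sess.graph.get_tensor_by_name(signature.inputs['len'].name): input_length,
        }
        label = sess.graph.get_tensor_by_name(signature.outputs['label'].name)
        logits = sess.graph.get_tensor_by_name(signature.outputs['logits'].name)
        return sess.run([label, logits], feed_dict=feeds)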
def train(self, dataset_train, dataset_test, dataset_train_lengths, dataset_test_lengths): # Setup data loaders with tf.variable_scope('train_iterator'): self.iterator_data_train = dataset_train.make_initializable_iterator( ) self.iterator_length_train = dataset_train_lengths.make_initializable_iterator( ) next_train_data = self.iterator_data_train.get_next() next_train_length = self.iterator_length_train.get_next() with tf.variable_scope('test_iterator'): self.iterator_data_test = dataset_test.make_initializable_iterator( ) self.iterator_length_test = dataset_test_lengths.make_initializable_iterator( ) next_test_data = self.iterator_data_test.get_next() next_test_length = self.iterator_length_test.get_next() # Set up model self.initializers = [ i.initializer for i in [ self.iterator_data_test, self.iterator_data_train, self.iterator_length_test, self.iterator_length_train ] ] self.model, self.stats = self.model_train_mode(next_train_data, next_train_length, self.global_step) self.eval_model = self.model_eval_mode(next_test_data, next_test_length) if self.all_params.use_ema: self.saver = create_shadow_saver(self.model, self.global_step) else: self.saver = tf.train.Saver(max_to_keep=100) # Book keeping step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) print('Training set to a maximum of {} steps'.format(self.all_params.train_steps)) \ # Memory allocation on the GPU as needed config = tf.ConfigProto() config.gpu_options.allow_growth = True # Train print('Starting training') with tf.Session(config=config) as sess: summary_writer = tf.summary.FileWriter(self.tensorboard_dir, sess.graph) # Allow the full trace to be stored at run time. options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) # Create a fresh metadata object: run_metadata = tf.RunMetadata() sess.run(tf.global_variables_initializer()) for init in self.initializers: sess.run(init) # saved model restoring if self.all_params.restore: # Restore saved model if the user requested it, default = True try: checkpoint_state = tf.train.get_checkpoint_state( self.save_dir) if (checkpoint_state and checkpoint_state.model_checkpoint_path): print('Loading checkpoint {}'.format( checkpoint_state.model_checkpoint_path)) self.saver.restore( sess, checkpoint_state.model_checkpoint_path) else: print('No model to load at {}'.format(self.save_dir)) self.saver.save(sess, self.checkpoint_path, global_step=self.global_step) except tf.errors.OutOfRangeError as e: print('Cannot restore checkpoint: {}'.format(e)) else: print('Starting new training!') self.saver.save(sess, self.checkpoint_path, global_step=self.global_step) # Training loop while not self.coord.should_stop( ) and step < self.all_params.train_steps: start_time = time.time() step, loss, opt = sess.run( [self.global_step, self.model.loss, self.model.optimize], options=options, run_metadata=run_metadata) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format( step, time_window.average, loss, loss_window.average) print(message) if np.isnan(loss) or loss > 100.: print('Loss exploded to {:.5f} at step {}'.format( loss, step)) raise Exception('Loss exploded') if step % self.all_params.summary_interval == 0: print('\nWriting summary at step {}'.format(step)) summary_writer.add_summary(sess.run(self.stats), step) if step % self.all_params.checkpoint_interval == 0 or step == self.all_params.train_steps: print('Saving model!') # Save model and current global step self.saver.save(sess, 
self.checkpoint_path, global_step=self.global_step) if step % self.all_params.eval_interval == 0: # Run eval and save eval stats print('\nRunning evaluation at step {}'.format(step)) all_logits = [] all_outputs = [] all_targets = [] all_lengths = [] val_losses = [] for i in tqdm(range(4)): val_loss, logits, outputs, targets, lengths = sess.run( [ self.eval_model.loss, self.eval_model.logits, self.eval_model.outputs, self.eval_model.targets, self.eval_model.input_lengths ]) all_logits.append(logits) all_outputs.append(outputs) all_targets.append(targets) all_lengths.append(lengths) val_losses.append(val_loss) logits = [l for logits in all_logits for l in logits] outputs = [o for output in all_outputs for o in output] targets = [t for target in all_targets for t in target] lengths = [l for length in all_lengths for l in length] logits = np.array([ e for o, l in zip(logits, lengths) for e in o[:l] ]).reshape(-1) outputs = np.array([ e for o, l in zip(outputs, lengths) for e in o[:l] ]).reshape(-1) targets = np.array([ e for t, l in zip(targets, lengths) for e in t[:l] ]).reshape(-1) val_loss = sum(val_losses) / len(val_losses) assert len(targets) == len(outputs) capture_rate, fig_path = evaluate_and_plot( outputs, targets, index=np.arange(0, len(targets)), model_name=self.all_params.name or self.all_params.self.model, weight=self.all_params.capture_weight, out_dir=self.eval_dir, use_tf=False, sess=sess, step=step) add_eval_stats(summary_writer, step, val_loss, capture_rate) tensorboard_file = os.path.join( self.tensorboard_dir, os.listdir(self.tensorboard_dir)[0]) ###### Replace these lines ########################### print(f'train_loss: {float(loss)}') print(f'validation_loss{float(val_loss)}') print(f'validation_capture_rate{float(capture_rate)}') ###################################################### print('Training complete after {} global steps!'.format( self.all_params.train_steps))
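# `add_eval_stats(summary_writer, step, val_loss, capture_rate)` is called in the
# evaluation branch above but not defined in this file. A plausible minimal
# version writes the validation metrics as manual scalar summaries; the tag
# names here are illustrative, not the project's actual ones.
import tensorflow as tf


def add_eval_stats(summary_writer, step, val_loss, capture_rate):
    values = [
        tf.Summary.Value(tag='eval/loss', simple_value=float(val_loss)),
        tf.Summary.Value(tag='eval/capture_rate', simple_value=float(capture_rate)),
    ]
    summary_writer.add_summary(tf.Summary(value=values), step)
    summary_writer.flush()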
def train(train_loader, model, device, mels_criterion, stop_criterion,
          optimizer, scheduler, writer, train_dir):
    batch_time = ValueWindow()
    data_time = ValueWindow()
    losses = ValueWindow()

    # switch to train mode
    model.train()

    end = time.time()
    global global_epoch
    global global_step
    for i, (txts, mels, stop_tokens, txt_lengths, mels_lengths) in enumerate(train_loader):
        scheduler.adjust_learning_rate(optimizer, global_step)
        # measure data loading time
        data_time.update(time.time() - end)

        if device > -1:
            txts = txts.cuda(device)
            mels = mels.cuda(device)
            stop_tokens = stop_tokens.cuda(device)
            txt_lengths = txt_lengths.cuda(device)
            mels_lengths = mels_lengths.cuda(device)

        # compute output
        frames, decoder_frames, stop_tokens_predict, alignment = model(txts, txt_lengths, mels)
        decoder_frames_loss = mels_criterion(decoder_frames, mels, lengths=mels_lengths)
        frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
        stop_token_loss = stop_criterion(stop_tokens_predict, stop_tokens, lengths=mels_lengths)
        loss = decoder_frames_loss + frames_loss + stop_token_loss
        # print(frames_loss, decoder_frames_loss)
        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if hparams.clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.get_trainable_parameters(), hparams.clip_thresh)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % hparams.print_freq == 0:
            log('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    global_epoch, i, len(train_loader),
                    batch_time=batch_time, data_time=data_time, loss=losses))

        # Logs
        writer.add_scalar("loss", float(loss.item()), global_step)
        writer.add_scalar("avg_loss in {} window".format(losses.get_dinwow_size),
                          float(losses.avg), global_step)
        writer.add_scalar("stop_token_loss", float(stop_token_loss.item()), global_step)
        writer.add_scalar("decoder_frames_loss", float(decoder_frames_loss.item()), global_step)
        writer.add_scalar("output_frames_loss", float(frames_loss.item()), global_step)
        if hparams.clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, global_step)
        writer.add_scalar("learning rate", optimizer.param_groups[0]['lr'], global_step)

        global_step += 1

        dst_alignment_path = join(train_dir, "{}_alignment.png".format(global_step))
        alignment = alignment.cpu().detach().numpy()
        plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                       dst_alignment_path,
                       info="{}, {}".format(hparams.builder, global_step))
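# `plot_alignment` is used above but not defined in this file. A minimal
# matplotlib-based sketch of such a helper (figure size and axis labels are
# illustrative, not taken from the project):
import matplotlib
matplotlib.use('Agg')  # render without a display
import matplotlib.pyplot as plt


def plot_alignment(alignment, path, info=None):
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    ax.set_xlabel('Decoder step' + ('\n' + info if info else ''))
    ax.set_ylabel('Encoder step')
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)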
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, rank, group_name, hparams, refine_from): """Training and validation logging results to tensorboard and stdout Params ------ output_directory (string): directory to save checkpoints log_directory (string) directory to save tensorboard logs checkpoint_path(string): checkpoint path n_gpus (int): number of gpus rank (int): rank of current gpu hparams (object): comma separated list of "name=value" pairs. """ if hparams.distributed_run: init_distributed(hparams, n_gpus, rank, group_name) torch.manual_seed(hparams.seed) torch.cuda.manual_seed(hparams.seed) model = load_model(hparams) learning_rate = hparams.initial_learning_rate optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay) if hparams.use_GAN and hparams.GAN_type=='lsgan': from discriminator import Lsgan_Loss, Calculate_Discrim model_D = Calculate_Discrim(hparams).cuda() if torch.cuda.is_available() else Calculate_Discrim(hparams) lsgan_loss = Lsgan_Loss(hparams) optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay) if hparams.use_GAN and hparams.GAN_type=='wgan-gp': from discriminator import Wgan_GP, GP model_D = Wgan_GP(hparams).cuda() if torch.cuda.is_available() else Wgan_GP(hparams) calc_gradient_penalty = GP(hparams).cuda() if torch.cuda.is_available() else GP(hparams) optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay) if hparams.is_partial_refine: refine_list=['speaker_embedding.weight', 'spkemb_projection.weight', 'spkemb_projection.bias', 'projection.weight', 'projection.bias', 'encoder.encoders.0.norm1.w.weight', 'encoder.encoders.0.norm1.w.bias', 'encoder.encoders.0.norm1.b.weight', 'encoder.encoders.0.norm1.b.bias', 'encoder.encoders.0.norm2.w.weight', 'encoder.encoders.0.norm2.w.bias', 'encoder.encoders.0.norm2.b.weight', 'encoder.encoders.0.norm2.b.bias', 'encoder.encoders.1.norm1.w.weight', 'encoder.encoders.1.norm1.w.bias', 'encoder.encoders.1.norm1.b.weight', 'encoder.encoders.1.norm1.b.bias', 'encoder.encoders.1.norm2.w.weight', 'encoder.encoders.1.norm2.w.bias', 'encoder.encoders.1.norm2.b.weight', 'encoder.encoders.1.norm2.b.bias', 'encoder.encoders.2.norm1.w.weight', 'encoder.encoders.2.norm1.w.bias', 'encoder.encoders.2.norm1.b.weight', 'encoder.encoders.2.norm1.b.bias', 'encoder.encoders.2.norm2.w.weight', 'encoder.encoders.2.norm2.w.bias', 'encoder.encoders.2.norm2.b.weight', 'encoder.encoders.2.norm2.b.bias', 'encoder.encoders.3.norm1.w.weight', 'encoder.encoders.3.norm1.w.bias', 'encoder.encoders.3.norm1.b.weight', 'encoder.encoders.3.norm1.b.bias', 'encoder.encoders.3.norm2.w.weight', 'encoder.encoders.3.norm2.w.bias', 'encoder.encoders.3.norm2.b.weight', 'encoder.encoders.3.norm2.b.bias', 'encoder.encoders.4.norm1.w.weight', 'encoder.encoders.4.norm1.w.bias', 'encoder.encoders.4.norm1.b.weight', 'encoder.encoders.4.norm1.b.bias', 'encoder.encoders.4.norm2.w.weight', 'encoder.encoders.4.norm2.w.bias', 'encoder.encoders.4.norm2.b.weight', 'encoder.encoders.4.norm2.b.bias', 'encoder.encoders.5.norm1.w.weight', 'encoder.encoders.5.norm1.w.bias', 'encoder.encoders.5.norm1.b.weight', 'encoder.encoders.5.norm1.b.bias', 'encoder.encoders.5.norm2.w.weight', 'encoder.encoders.5.norm2.w.bias', 'encoder.encoders.5.norm2.b.weight', 'encoder.encoders.5.norm2.b.bias', 'encoder.after_norm.w.weight', 'encoder.after_norm.w.bias', 'encoder.after_norm.b.weight', 'encoder.after_norm.b.bias', 
'duration_predictor.norm.0.w.weight', 'duration_predictor.norm.0.w.bias', 'duration_predictor.norm.0.b.weight', 'duration_predictor.norm.0.b.bias', 'duration_predictor.norm.1.w.weight', 'duration_predictor.norm.1.w.bias', 'duration_predictor.norm.1.b.weight', 'duration_predictor.norm.1.b.bias', 'decoder.encoders.0.norm1.w.weight', 'decoder.encoders.0.norm1.w.bias', 'decoder.encoders.0.norm1.b.weight', 'decoder.encoders.0.norm1.b.bias', 'decoder.encoders.0.norm2.w.weight', 'decoder.encoders.0.norm2.w.bias', 'decoder.encoders.0.norm2.b.weight', 'decoder.encoders.0.norm2.b.bias', 'decoder.encoders.1.norm1.w.weight', 'decoder.encoders.1.norm1.w.bias', 'decoder.encoders.1.norm1.b.weight', 'decoder.encoders.1.norm1.b.bias', 'decoder.encoders.1.norm2.w.weight', 'decoder.encoders.1.norm2.w.bias', 'decoder.encoders.1.norm2.b.weight', 'decoder.encoders.1.norm2.b.bias', 'decoder.encoders.2.norm1.w.weight', 'decoder.encoders.2.norm1.w.bias', 'decoder.encoders.2.norm1.b.weight', 'decoder.encoders.2.norm1.b.bias', 'decoder.encoders.2.norm2.w.weight', 'decoder.encoders.2.norm2.w.bias', 'decoder.encoders.2.norm2.b.weight', 'decoder.encoders.2.norm2.b.bias', 'decoder.encoders.3.norm1.w.weight', 'decoder.encoders.3.norm1.w.bias', 'decoder.encoders.3.norm1.b.weight', 'decoder.encoders.3.norm1.b.bias', 'decoder.encoders.3.norm2.w.weight', 'decoder.encoders.3.norm2.w.bias', 'decoder.encoders.3.norm2.b.weight', 'decoder.encoders.3.norm2.b.bias', 'decoder.encoders.4.norm1.w.weight', 'decoder.encoders.4.norm1.w.bias', 'decoder.encoders.4.norm1.b.weight', 'decoder.encoders.4.norm1.b.bias', 'decoder.encoders.4.norm2.w.weight', 'decoder.encoders.4.norm2.w.bias', 'decoder.encoders.4.norm2.b.weight', 'decoder.encoders.4.norm2.b.bias', 'decoder.encoders.5.norm1.w.weight', 'decoder.encoders.5.norm1.w.bias', 'decoder.encoders.5.norm1.b.weight', 'decoder.encoders.5.norm1.b.bias', 'decoder.encoders.5.norm2.w.weight', 'decoder.encoders.5.norm2.w.bias', 'decoder.encoders.5.norm2.b.weight', 'decoder.encoders.5.norm2.b.bias', 'decoder.after_norm.w.weight', 'decoder.after_norm.w.bias', 'decoder.after_norm.b.weight', 'decoder.after_norm.b.bias'] if hparams.is_refine_style: style_list= ['gst.ref_enc.convs.0.weight', 'gst.ref_enc.convs.1.weight', 'gst.ref_enc.convs.1.bias', 'gst.ref_enc.convs.3.weight', 'gst.ref_enc.convs.4.weight', 'gst.ref_enc.convs.4.bias', 'gst.ref_enc.convs.6.weight', 'gst.ref_enc.convs.7.weight', 'gst.ref_enc.convs.7.bias', 'gst.ref_enc.convs.9.weight', 'gst.ref_enc.convs.10.weight', 'gst.ref_enc.convs.10.bias', 'gst.ref_enc.convs.12.weight', 'gst.ref_enc.convs.13.weight', 'gst.ref_enc.convs.13.bias', 'gst.ref_enc.convs.15.weight', 'gst.ref_enc.convs.16.weight', 'gst.ref_enc.convs.16.bias', 'gst.ref_enc.gru.weight_ih_l0,' 'gst.ref_enc.gru.weight_hh_l0', 'gst.ref_enc.gru.bias_ih_l0', 'gst.ref_enc.gru.bias_hh_l0', 'gst.stl.gst_embs', 'gst.stl.mha.linear_q.weight', 'gst.stl.mha.linear_q.bias', 'gst.stl.mha.linear_k.weight', 'gst.stl.mha.linear_k.bias', 'gst.stl.mha.linear_v.weight', 'gst.stl.mha.linear_v.bias', 'gst.stl.mha.linear_out.weight', 'gst.stl.mha.linear_out.bias', 'gst.choosestl.choose_mha.linear_q.weight', 'gst.choosestl.choose_mha.linear_q.bias', 'gst.choosestl.choose_mha.linear_k.weight', 'gst.choosestl.choose_mha.linear_k.bias', 'gst.choosestl.choose_mha.linear_v.weight', 'gst.choosestl.choose_mha.linear_v.bias', 'gst.choosestl.choose_mha.linear_out.weight', 'gst.choosestl.choose_mha.linear_out.bias', 'gst_projection.weight', 'gst_projection.bias' ] refine_list += style_list for name, param in 
model.named_parameters(): if hparams.is_partial_refine: if name in refine_list: param.requires_grad = True else: param.requires_grad = False print(name, param.requires_grad, param.shape) if hparams.distributed_run: model = apply_gradient_allreduce(model) if hparams.use_GAN: model_D = apply_gradient_allreduce(model_D) logger = prepare_directories_and_logger(output_directory, log_directory, rank) train_loader, valset, collate_fn, trainset = prepare_dataloaders(hparams) # Load checkpoint if one exists iteration = 0 epoch_offset = 0 if not checkpoint_path: checkpoint_path = get_checkpoint_path(output_directory) if not hparams.is_partial_refine else get_checkpoint_path(refine_from) if checkpoint_path is not None: if warm_start: model = warm_start_model( checkpoint_path, model, hparams.ignore_layers) else: model, optimizer, _learning_rate, iteration = load_checkpoint( checkpoint_path, model, optimizer, hparams, style_list=style_list if hparams.is_refine_style else None) if hparams.use_saved_learning_rate: learning_rate = _learning_rate iteration = (iteration + 1) if not hparams.is_partial_refine else 0# next iteration is iteration + 1 epoch_offset = max(0, int(iteration / len(train_loader))) if not hparams.is_partial_refine else 0 model.train() if hparams.use_GAN: model_D.train() else: hparams.use_GAN = True hparams.Generator_pretrain_step = hparams.iters is_overflow = False epoch = epoch_offset time_window = ValueWindow(100) loss_window = ValueWindow(100) # ================ MAIN TRAINNIG LOOP! =================== while iteration <= hparams.iters: # print("Epoch: {}".format(epoch)) if hparams.distributed_run and hparams.batch_criterion == 'utterance': train_loader.sampler.set_epoch(epoch) for i, batch in enumerate(train_loader): start = time.perf_counter() learning_rate = learning_rate_decay(iteration, hparams) if hparams.use_GAN: # Discriminator turn if iteration > hparams.Generator_pretrain_step: for param_group in optimizer_D.param_groups: param_group['lr'] = learning_rate optimizer.zero_grad() optimizer_D.zero_grad() for name, param in model.named_parameters(): param.requires_grad = False for name, param in model_D.named_parameters(): param.requires_grad = True loss, loss_dict, weight, pred_outs, ys, olens = model(*model._parse_batch(batch,hparams,utt_mels=trainset.utt_mels if hparams.is_refine_style else None)) if hparams.GAN_type=='lsgan': discrim_gen_output, discrim_target_output = model_D(pred_outs + (torch.randn(pred_outs.size()).cuda() if hparams.add_noise else 0), ys + (torch.randn(pred_outs.size()).cuda() if hparams.add_noise else 0), olens) loss_D = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='D') loss_G = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='G') loss_D.backward(retain_graph=True) if hparams.GAN_type=='wgan-gp': D_real = model_D(ys, olens) D_real = -D_real.mean() D_real.backward(retain_graph=True) D_fake = model_D(pred_outs, olens) D_fake = D_fake.mean() D_fake.backward() gradient_penalty = calc_gradient_penalty(model_D, ys.data, pred_outs.data, olens.data) gradient_penalty.backward() D_cost = D_real + D_fake + gradient_penalty Wasserstein_D = -D_real - D_fake grad_norm_D = torch.nn.utils.clip_grad_norm_(model_D.parameters(), hparams.grad_clip_thresh) optimizer_D.step() print('\n') if hparams.GAN_type=='lsgan': print("Epoch:{} step:{} loss_D: {:>9.6f}, loss_G: {:>9.6f}, Grad Norm: {:>9.6f}".format(epoch, iteration, loss_D, loss_G, grad_norm_D)) if hparams.GAN_type=='wgan-gp': print("Epoch:{} step:{} D_cost: {:>9.6f}, 
Wasserstein_D: {:>9.6f}, GP: {:>9.6f}, Grad Norm: {:>9.6f}".format(epoch, iteration, D_cost, Wasserstein_D, gradient_penalty, grad_norm_D)) # Generator turn for param_group in optimizer.param_groups: param_group['lr'] = learning_rate optimizer.zero_grad() if iteration > hparams.Generator_pretrain_step: for name, param in model.named_parameters(): if hparams.is_partial_refine: if name in refine_list: param.requires_grad = True else: param.requires_grad = True for name, param in model_D.named_parameters(): param.requires_grad = False optimizer_D.zero_grad() loss, loss_dict, weight, pred_outs, ys, olens = model(*model._parse_batch(batch,hparams,utt_mels=trainset.utt_mels if hparams.is_refine_style else None)) if iteration > hparams.Generator_pretrain_step: if hparams.GAN_type=='lsgan': discrim_gen_output, discrim_target_output = model_D(pred_outs, ys, olens) loss_D = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='D') loss_G = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='G') if hparams.GAN_type=='wgan-gp': loss_G = model_D(pred_outs, olens) loss_G = -loss_G.mean() loss = loss + loss_G*hparams.GAN_alpha*abs(loss.item()/loss_G.item()) if hparams.distributed_run: reduced_loss = reduce_tensor(loss.data, n_gpus).item() if loss_dict: for key in loss_dict: loss_dict[key] = reduce_tensor(loss_dict[key].data, n_gpus).item() else: reduced_loss = loss.item() if loss_dict: for key in loss_dict: loss_dict[key] = loss_dict[key].item() loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) optimizer.step() duration = time.perf_counter() - start time_window.append(duration) loss_window.append(reduced_loss) if not is_overflow and (rank == 0): if iteration % hparams.log_per_checkpoint == 0: if hparams.GAN_type=='lsgan': print("Epoch:{} step:{} Train loss: {:>9.6f}, avg loss: {:>9.6f}, Grad Norm: {:>9.6f}, {:>5.2f}s/it, {:s} loss: {:>9.6f}, D_loss: {:>9.6f}, G_loss: {:>9.6f}, duration loss: {:>9.6f}, ssim loss: {:>9.6f}, lr: {:>4}".format( epoch, iteration, reduced_loss, loss_window.average, grad_norm, time_window.average, hparams.loss_type, loss_dict[hparams.loss_type], loss_D.item() if iteration > hparams.Generator_pretrain_step else 0, loss_G.item() if iteration > hparams.Generator_pretrain_step else 0, loss_dict["duration_loss"], loss_dict["ssim_loss"], learning_rate)) if hparams.GAN_type=='wgan-gp': print("Epoch:{} step:{} Train loss: {:>9.6f}, avg loss: {:>9.6f}, Grad Norm: {:>9.6f}, {:>5.2f}s/it, {:s} loss: {:>9.6f}, G_loss: {:>9.6f}, duration loss: {:>9.6f}, ssim loss: {:>9.6f}, lr: {:>4}".format( epoch, iteration, reduced_loss, loss_window.average, grad_norm, time_window.average, hparams.loss_type, loss_dict[hparams.loss_type], loss_G.item() if iteration > hparams.Generator_pretrain_step else 0, loss_dict["duration_loss"], loss_dict["ssim_loss"], learning_rate)) logger.log_training( reduced_loss, grad_norm, learning_rate, duration, iteration, loss_dict) if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0): if valset is not None: validate(model, valset, iteration, hparams.batch_criterion, hparams.batch_size, n_gpus, collate_fn, logger, hparams.distributed_run, rank) if rank == 0: checkpoint_path = os.path.join( output_directory, "checkpoint_{}_refine_{}".format(iteration, hparams.training_files.split('/')[-2].split('_')[-1]) if hparams.is_partial_refine else "checkpoint_{}".format(iteration)) save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path) iteration += 1 
torch.cuda.empty_cache() epoch += 1
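# `learning_rate_decay(iteration, hparams)` is called in the GAN training loop
# above but not defined in this file. One common shape for such a schedule is a
# linear warmup followed by inverse-square-root decay; the formula and the
# hparam names below (warmup_steps, initial_learning_rate) are assumptions,
# not the project's actual schedule.
def learning_rate_decay(iteration, hparams):
    warmup = getattr(hparams, 'warmup_steps', 4000)
    base_lr = hparams.initial_learning_rate
    step = max(1, iteration)
    if step < warmup:
        return base_lr * step / warmup
    return base_lr * (warmup ** 0.5) * (step ** -0.5)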
sess.run(dev_iterator.initializer, {num_frames: max_frames})

try:
    checkpoint_state = tf.train.get_checkpoint_state(train_hparams.checkpoint_path)
    if (checkpoint_state and checkpoint_state.model_checkpoint_path):
        log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path), slack=True)
        saver.restore(sess, checkpoint_state.model_checkpoint_path)
    else:
        log('No model to load at {}'.format(train_hparams.checkpoint_path), slack=True)
        saver.save(sess, checkpoint_path, global_step=train_resnet.global_step)
except OutOfRangeError as e:
    log('Cannot restore checkpoint: {}'.format(e), slack=True)

step = 0
time_window = ValueWindow(100)
loss_window = ValueWindow(100)
acc_window = ValueWindow(100)

while step < train_hparams.train_steps:
    start_time = time.time()
    fetches = [train_resnet.global_step, train_resnet.train_op,
               train_resnet.cost, train_resnet.accuracy]
    feed_dict = {handle: train_handle}
    try:
        step, _, loss, acc = sess.run(fetches=fetches, feed_dict=feed_dict)
    except OutOfRangeError as e:
        sess.run(train_iterator.initializer, {num_frames: max_frames})
        continue
    time_window.append(time.time() - start_time)
def train_epoch(model, train_loader, loss_fn, optimizer, scheduler,
                batch_size, epoch, start_step):
    model.train()
    count = 0
    total_loss = 0
    n = batch_size
    step = start_step
    examples = []
    total_loss_window = ValueWindow(100)
    post_loss_window = ValueWindow(100)
    post_acc_window = ValueWindow(100)
    for x, y in train_loader:
        count += 1
        examples.append([x[0], y[0]])
        if count % 8 == 0:
            examples.sort(key=lambda ex: len(ex[-1]))
            examples = (np.vstack([ex[0] for ex in examples]),
                        np.vstack([ex[1] for ex in examples]))
            batches = [(examples[0][i:i + n], examples[1][i:i + n])
                       for i in range(0, len(examples[-1]) + 1 - n, n)]
            if len(examples[-1]) % n != 0:
                batches.append(
                    (np.vstack((examples[0][-(len(examples[-1]) % n):],
                                examples[0][:n - (len(examples[0]) % n)])),
                     np.vstack((examples[1][-(len(examples[-1]) % n):],
                                examples[1][:n - (len(examples[-1]) % n)]))))
            for batch in batches:
                # mini batch
                # train_data(?, 7, 80), train_label(?, 7)
                step += 1
                train_data = torch.as_tensor(batch[0], dtype=torch.float32).to(DEVICE)
                train_label = torch.as_tensor(batch[1], dtype=torch.float32).to(DEVICE)
                optimizer.zero_grad(True)
                midnet_output, postnet_output, alpha = model(train_data)
                postnet_accuracy, pipenet_accuracy = prediction(
                    train_label, midnet_output, postnet_output)
                loss, postnet_loss, pipenet_loss, attention_loss = loss_fn(
                    model, train_label, postnet_output, midnet_output, alpha)
                total_loss += loss.detach().item()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 5, norm_type=2)
                optimizer.step()
                scheduler.step()
                lr = scheduler._rate
                total_loss_window.append(loss.detach().item())
                post_loss_window.append(postnet_loss.detach().item())
                post_acc_window.append(postnet_accuracy)
                if step % 10 == 0:
                    print('{} Epoch: {}, Step: {}, overall loss: {:.5f}, postnet loss: {:.5f}, '
                          'postnet acc: {:.4f}, lr :{:.5f}'.format(
                              datetime.now().strftime(_format)[:-3], epoch, step,
                              total_loss_window.average, post_loss_window.average,
                              post_acc_window.average, lr))
                if step % 50_000 == 0:
                    print('{} save checkpoint.'.format(
                        datetime.now().strftime(_format)[:-3]))
                    checkpoint = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'step': step,
                        'scheduler_lr': scheduler._rate,
                        'scheduler_step': scheduler._step
                    }
                    if not os.path.isdir("./checkpoint"):
                        os.mkdir("./checkpoint")
                    torch.save(checkpoint,
                               './checkpoint/STAM_weights_%s_%s.pth' % (str(epoch), str(step / 1_000_000)))
            gc.collect()
            torch.cuda.empty_cache()
            del batches, examples
            examples = []
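# The checkpoint dict saved above stores model and optimizer state plus the
# scheduler rate and step. A matching resume helper might look like this; the
# function name is illustrative, and the scheduler attribute names simply
# mirror the ones used in train_epoch.
def load_checkpoint(path, model, optimizer, scheduler):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler._rate = checkpoint['scheduler_lr']
    scheduler._step = checkpoint['scheduler_step']
    return checkpoint['epoch'], checkpoint['step']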
def train(logdir,args): # TODO:parse ckpt,arguments,hparams checkpoint_path = os.path.join(logdir,'model.ckpt') input_path = args.data_dir log('Checkpoint path: %s' % checkpoint_path) log('Loading training data from : %s ' % input_path) log('Using model : %s' %args.model) # TODO:set up datafeeder with tf.variable_scope('datafeeder') as scope: hp.aishell=True hp.prime=False hp.stcmd=False hp.data_path = 'D:/pycharm_proj/corpus_zn/' hp.initial_learning_rate = 0.001 hp.decay_learning_rate=False hp.data_length = 512 hp.batch_size = 64 feeder = DataFeeder(args=hp) log('num_wavs:'+str(len(feeder.wav_lst))) # 283600 hp.input_vocab_size = len(feeder.pny_vocab) hp.final_output_dim = len(feeder.pny_vocab) hp.steps_per_epoch = len(feeder.wav_lst)//hp.batch_size log('steps_per_epoch:' + str(hp.steps_per_epoch)) # 17725 log('pinyin_vocab_size:'+str(hp.input_vocab_size)) # 1292 hp.label_vocab_size = len(feeder.han_vocab) log('label_vocab_size :' + str(hp.label_vocab_size)) # 6291 # TODO:set up model global_step = tf.Variable(initial_value=0,name='global_step',trainable=False) valid_step = 0 # valid_global_step = tf.Variable(initial_value=0,name='valid_global_step',trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model,hp) model.build_graph() model.add_loss() model.add_decoder() model.add_optimizer(global_step=global_step) # TODO: summary stats = add_stats(model=model) valid_stats = add_dev_stats(model) # TODO:Set up saver and Bookkeeping time_window = ValueWindow(100) loss_window = ValueWindow(100) wer_window = ValueWindow(100) valid_time_window = ValueWindow(100) valid_loss_window = ValueWindow(100) valid_wer_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=20) first_serving = True # TODO: train with tf.Session() as sess: try: # TODO: Set writer and initializer summary_writer = tf.summary.FileWriter(logdir, sess.graph) sess.run(tf.global_variables_initializer()) # TODO: Restore if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s ' % restore_path) else: log('Starting new training run ') step = 0 # TODO: epochs steps batch for i in range(args.epochs): batch_data = feeder.get_am_batch() log('Traning epoch '+ str(i)+':') for j in range(hp.steps_per_epoch): input_batch = next(batch_data) feed_dict = {model.inputs:input_batch['the_inputs'], model.labels:input_batch['the_labels'], model.input_lengths:input_batch['input_length'], model.label_lengths:input_batch['label_length']} # TODO: Run one step start_time = time.time() total_step, array_loss, batch_loss,opt = sess.run([global_step, model.ctc_loss, model.batch_loss,model.optimize],feed_dict=feed_dict) time_window.append(time.time() - start_time) step = total_step # TODO: Append loss # loss = np.sum(array_loss).item()/hp.batch_size loss_window.append(batch_loss) # print('loss',loss,'batch_loss',batch_loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, lr=%.07f]' % ( step, time_window.average, batch_loss, loss_window.average,K.get_value(model.learning_rate)) log(message) # ctcloss返回值是[batch_size,1]形式的所以刚开始没sum报错only size-1 arrays can be converted to Python scalars # TODO: Check loss if math.isnan(batch_loss): log('Loss exploded to %.05f at step %d!' 
% (batch_loss, step)) raise Exception('Loss Exploded') # TODO: Check sumamry if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats,feed_dict=feed_dict), step) # TODO: Check checkpoint if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('test acc...') # eval_start_time = time.time() # with tf.name_scope('eval') as scope: # with open(os.path.expanduser('~/my_asr2/datasets/resource/preprocessedData/dev-meta.txt'), encoding='utf-8') as f: # metadata = [line.strip().split('|') for line in f] # random.shuffle(metadata) # eval_loss = [] # batch_size = args.hp.batch_size # batchs = len(metadata)//batch_size # for i in range(batchs): # batch = metadata[i*batch_size : i*batch_size+batch_size] # batch = list(map(eval_get_example,batch)) # batch = eval_prepare_batch(batch) # feed_dict = {'labels':batch[0],} # label,final_pred_label ,log_probabilities = sess.run([ # model.labels[0], model.decoded[0], model.log_probabilities[0]]) # # 刚开始没有加[]会报错 https://github.com/tensorflow/tensorflow/issues/11840 # print('label: ' ,label) # print('final_pred_label: ', final_pred_label[0]) # log('eval time: %.03f, avg_eval_loss: %.05f' % (time.time()-eval_start_time,np.mean(eval_loss))) label,final_pred_label ,log_probabilities,y_pred2 = sess.run([ model.labels, model.decoded, model.log_probabilities,model.y_pred2],feed_dict=feed_dict) # 刚开始没有加[]会报错 https://github.com/tensorflow/tensorflow/issues/11840 print('label.shape :',label.shape) # (batch_size , label_length) print('final_pred_label.shape:',np.asarray(final_pred_label).shape) # (1, batch_size, decode_length<=label_length) print('y_pred2.shape : ', y_pred2.shape) print('label: ' ,label[0]) print('y_pred2 : ', y_pred2[0]) print('final_pred_label: ', np.asarray(final_pred_label)[0][0]) # 刚开始打不出来,因为使用的tf.nn.ctc_beam_decoder,这个返回的是sparse tensor所以打不出来 # 后来用keras封装的decoder自动将sparse转成dense tensor才能打出来 # waveform = audio.inv_spectrogram(spectrogram.T) # audio.save_wav(waveform, os.path.join(logdir, 'step-%d-audio.wav' % step)) # plot.plot_alignment(alignment, os.path.join(logdir, 'step-%d-align.png' % step), # info='%s, %s, %s, step=%d, loss=%.5f' % ( # args.model, commit, time_string(), step, loss)) # log('Input: %s' % sequence_to_text(input_seq)) # TODO: Check stop if step % hp.steps_per_epoch ==0: # TODO: Set up serving builder and signature map serve_dir = args.serving_dir + '_' + str(total_step//hp.steps_per_epoch -1) if os.path.exists(serve_dir): os.removedirs(serve_dir) builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir) input_spec = tf.saved_model.utils.build_tensor_info(model.inputs) input_len = tf.saved_model.utils.build_tensor_info(model.input_lengths) output_labels = tf.saved_model.utils.build_tensor_info(model.decoded2) output_logits = tf.saved_model.utils.build_tensor_info(model.pred_logits) prediction_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={'spec': input_spec, 'len': input_len}, outputs={'label': output_labels, 'logits': output_logits}, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) ) if first_serving: first_serving = False builder.add_meta_graph_and_variables( sess=sess, tags=[tf.saved_model.tag_constants.SERVING, 'ASR'], signature_def_map={ 'predict_AudioSpec2Pinyin': prediction_signature, }, main_op=tf.tables_initializer(), strip_default_attrs=True ) else: builder.add_meta_graph_and_variables( sess=sess, 
tags=[tf.saved_model.tag_constants.SERVING, 'ASR'], signature_def_map={ 'predict_AudioSpec2Pinyin': prediction_signature, }, strip_default_attrs=True ) builder.save() log('Done store serving-model') # TODO: Validation if step % hp.steps_per_epoch ==0 and i >= 10: log('validation...') valid_start = time.time() # TODO: validation valid_hp = hp valid_hp.data_type = 'dev' valid_hp.thchs30 = True valid_hp.aishell = False valid_hp.prime = False valid_hp.stcmd = False valid_hp.shuffle = True valid_hp.data_length = None valid_feeder = DataFeeder(args=valid_hp) valid_batch_data = valid_feeder.get_am_batch() log('valid_num_wavs:' + str(len(valid_feeder.wav_lst))) # 15219 valid_hp.input_vocab_size = len(valid_feeder.pny_vocab) valid_hp.final_output_dim = len(valid_feeder.pny_vocab) valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size log('valid_steps_per_epoch:' + str(valid_hp.steps_per_epoch)) # 951 log('valid_pinyin_vocab_size:' + str(valid_hp.input_vocab_size)) # 1124 valid_hp.label_vocab_size = len(valid_feeder.han_vocab) log('valid_label_vocab_size :' + str(valid_hp.label_vocab_size)) # 3327 words_num = 0 word_error_num = 0 # dev 只跑一个epoch就行 with tf.variable_scope('validation') as scope: for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size): valid_input_batch = next(valid_batch_data) valid_feed_dict = {model.inputs: valid_input_batch['the_inputs'], model.labels: valid_input_batch['the_labels'], model.input_lengths: valid_input_batch['input_length'], model.label_lengths: valid_input_batch['label_length']} # TODO: Run one step valid_start_time = time.time() valid_batch_loss,valid_WER = sess.run([model.batch_loss,model.WER], feed_dict=valid_feed_dict) valid_time_window.append(time.time() - valid_start_time) valid_loss_window.append(valid_batch_loss) valid_wer_window.append(valid_WER) # print('loss',loss,'batch_loss',batch_loss) message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, WER=%.05f, avg_WER=%.05f, lr=%.07f]' % ( valid_step, valid_time_window.average, valid_batch_loss, valid_loss_window.average,valid_WER,valid_wer_window.average,K.get_value(model.learning_rate)) log(message) summary_writer.add_summary(sess.run(valid_stats,feed_dict=valid_feed_dict), valid_step) valid_step += 1 log('Done Validation!Total Time Cost(sec):' + str(time.time()-valid_start)) except Exception as e: log('Exiting due to exception: %s' % e) traceback.print_exc()
def train(log_dir, args, hparams, use_hvd=False): if use_hvd: import horovod.tensorflow as hvd # Initialize Horovod. hvd.init() else: hvd = None eval_dir, eval_plot_dir, eval_wav_dir, meta_folder, plot_dir, save_dir, tensorboard_dir, wav_dir = init_dir( log_dir) checkpoint_path = os.path.join(save_dir, 'centaur_model.ckpt') input_path = os.path.join(args.base_dir, args.input_dir) log('Checkpoint path: {}'.format(checkpoint_path)) log('Loading training data from: {}'.format(input_path)) log('Using model: {}'.format(args.model)) log(hparams_debug_string()) # Start by setting a seed for repeatability tf.set_random_seed(hparams.random_seed) # Set up data feeder coord = tf.train.Coordinator() with tf.variable_scope('datafeeder'): feeder = Feeder(coord, input_path, hparams) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) model, stats = model_train_mode(feeder, hparams, global_step, hvd=hvd) eval_model = model_test_mode(feeder, hparams) # Embeddings metadata char_embedding_meta = os.path.join(meta_folder, 'CharacterEmbeddings.tsv') if not os.path.isfile(char_embedding_meta): with open(char_embedding_meta, 'w', encoding='utf-8') as f: for symbol in symbols: if symbol == ' ': symbol = '\\s' # For visual purposes, swap space with \s f.write('{}\n'.format(symbol)) char_embedding_meta = char_embedding_meta.replace(log_dir, '..') # Book keeping step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=2) log('Centaur training set to a maximum of {} steps'.format( args.train_steps)) # Memory allocation on the GPU as needed config = tf.ConfigProto() config.allow_soft_placement = True config.gpu_options.allow_growth = True if use_hvd: config.gpu_options.visible_device_list = str(hvd.local_rank()) # Train with tf.Session(config=config) as sess: try: # Init model and load weights sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) sess.run(tf.tables_initializer()) # saved model restoring if args.restore: # Restore saved model if the user requested it, default = True restore_model(saver, sess, global_step, save_dir, checkpoint_path, args.reset_global_step) else: log('Starting new training!', slack=True) saver.save(sess, checkpoint_path, global_step=global_step) # initializing feeder start_step = sess.run(global_step) feeder.start_threads(sess, start_step=start_step) # Horovod bcast vars across workers if use_hvd: # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when # training is started with random weights or restored from a checkpoint. 
bcast = hvd.broadcast_global_variables(0) bcast.run() log('Worker{}: Initialized'.format(hvd.rank())) # Training loop summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph) while not coord.should_stop() and step < args.train_steps: start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.train_op]) if use_hvd: main_process = hvd.rank() == 0 if main_process: time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format( step, time_window.average, loss, loss_window.average) log(message, end='\r', slack=(step % args.checkpoint_interval == 0)) if np.isnan(loss) or loss > 100.: log('Loss exploded to {:.5f} at step {}'.format( loss, step)) raise Exception('Loss exploded') if step % args.summary_interval == 0: log('\nWriting summary at step {}'.format(step)) summary_writer.add_summary(sess.run(stats), step) if step % args.eval_interval == 0: run_eval(args, eval_dir, eval_model, eval_plot_dir, eval_wav_dir, feeder, hparams, sess, step, summary_writer) if step % args.checkpoint_interval == 0 or step == args.train_steps or step == 300: save_current_model(args, checkpoint_path, global_step, hparams, loss, model, plot_dir, saver, sess, step, wav_dir) if step % args.embedding_interval == 0 or step == args.train_steps or step == 1: update_character_embedding(char_embedding_meta, save_dir, summary_writer) log('Centaur training complete after {} global steps!'.format( args.train_steps), slack=True) return save_dir except Exception as e: log('Exiting due to exception: {}'.format(e), slack=True) traceback.print_exc() coord.request_stop(e)
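# In the Horovod path above, the gradient averaging itself is expected to happen
# inside `model_train_mode` (not shown in this file). The usual TF1 + Horovod
# pattern wraps the optimizer as below; the optimizer choice and helper name
# are illustrative, not the project's actual code.
def build_distributed_train_op(loss, global_step, hparams, hvd=None):
    optimizer = tf.train.AdamOptimizer(hparams.initial_learning_rate)
    if hvd is not None:
        # Average gradients across workers before applying them.
        optimizer = hvd.DistributedOptimizer(optimizer)
    return optimizer.minimize(loss, global_step=global_step)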
def train(log_dir, args): checkpoint_path = os.path.join(hdfs_ckpts, log_dir, 'model.ckpt') log(hp.to_string(), is_print=False) log('Loading training data from: %s' % args.tfr_dir) log('Checkpoint path: %s' % checkpoint_path) log('Using model: sygst tacotron2') tf_dset = TFDataSet(hp, args.tfr_dir) feats = tf_dset.get_train_next() # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) training = tf.placeholder_with_default(True, shape=(), name='training') with tf.name_scope('model'): model = Tacotron2SYGST(hp) model(feats['inputs'], mel_inputs=feats['mel_targets'], spec_inputs=feats['linear_targets'], spec_lengths=feats['spec_lengths'], ref_inputs=feats['mel_targets'], ref_lengths=feats['spec_lengths'], arousal_labels=feats['soft_arousal_labels'], valence_labels=feats['soft_valance_labels'], training=training) """ text_x, mel_x, spec_x, spec_len, aro, val = debug_data(2, 5, 10) model(text_x, mel_x, spec_x, spec_len, mel_x, spec_len, aro, val, training=training) """ model.add_loss() model.add_optimizer(global_step) stats = model.add_stats() # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=50, keep_checkpoint_every_n_hours=2) # Train! config = tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True)) with tf.Session(config=config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if args.restore_step: # Restore from a checkpoint if the user requested it. restore_path = '%s-%s' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s' % restore_path, slack=True) else: log('Starting a new training run ...', slack=True) """ fetches = [global_step, model.optimize, model.loss, model.mel_loss, model.spec_loss, model.stop_loss, model.arousal_loss, model.valence_loss, model.mel_grad_norms_max, model.spec_grad_norms_max, model.stop_grad_norms_max, model.aro_grad_norms_max, model.val_grad_norms_max] """ fetches = [ global_step, model.optimize, model.loss, model.mel_loss, model.spec_loss, model.stop_loss, model.arousal_loss, model.valence_loss ] for _ in range(_max_step): start_time = time.time() sess.run(debug.get_ops()) # step, _, loss, mel_loss, spec_loss, stop_loss, aro_loss, val_loss, mel_g, spec_g, stop_g, aro_g, val_g = sess.run(fetches) step, _, loss, mel_loss, spec_loss, stop_loss, aro_loss, val_loss = sess.run( fetches) time_window.append(time.time() - start_time) loss_window.append(loss) """ message = 'Step %-7d [%.3f sec/step,ml=%.3f,spl=%.3f,sl=%.3f,al=%.3f,vl=%.3f,mg=%.4f,spg=%.4f,sg=%.4f,ag=%.4f,vg=%.4f]' % ( step, time_window.average, mel_loss, spec_loss, stop_loss, aro_loss, val_loss, mel_g, spec_g, stop_g, aro_g, val_g) """ message = 'Step %-7d [%.3f sec/step,ml=%.3f,spl=%.3f,sl=%.3f,al=%.3f,vl=%.3f]' % ( step, time_window.average, mel_loss, spec_loss, stop_loss, aro_loss, val_loss) log(message, slack=(step % args.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.5f at step %d!' 
% (loss, step), slack=True) raise Exception('Loss Exploded') if step % args.summary_interval == 0: log('Writing summary at step: %d' % step) try: summary_writer.add_summary(sess.run(stats), step) except Exception as e: log(f'summary failed and ignored: {str(e)}') if step % args.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) log('Saving audio and alignment...') gt_mel, gt_spec, seq, mel, spec, align = sess.run([ model.mel_targets[0], model.spec_targets[0], model.text_targets[0], model.mel_outputs[0], model.spec_outputs[0], model.alignment_outputs[0] ]) text = sequence_to_text(seq) wav = audio.inv_spectrogram(hp, spec.T) wav_path = os.path.join(log_dir, 'step-%d-audio.wav' % step) mel_path = os.path.join(log_dir, 'step-%d-mel.png' % step) spec_path = os.path.join(log_dir, 'step-%d-spec.png' % step) align_path = os.path.join(log_dir, 'step-%d-align.png' % step) info = '%s, %s, step=%d, loss=%.5f\n %s' % ( args.model, time_string(), step, loss, text) plot.plot_alignment(align, align_path, info=info) plot.plot_mel(mel, mel_path, info=info, gt_mel=gt_mel) plot.plot_mel(spec, spec_path, info=info, gt_mel=gt_spec) audio.save_wav(hp, wav, wav_path) log('Input: %s' % text) except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc()
def train(args): if args.model_path is None: msg = 'Prepare for new run ...' output_dir = os.path.join( args.log_dir, args.run_name + '_' + datetime.datetime.now().strftime('%m%d_%H%M')) if not os.path.exists(output_dir): os.makedirs(output_dir) ckpt_dir = os.path.join( args.ckpt_dir, args.run_name + '_' + datetime.datetime.now().strftime('%m%d_%H%M')) if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) else: msg = 'Restart previous run ...\nlogs to save to %s, ckpt to save to %s, model to load from %s' % \ (args.log_dir, args.ckpt_dir, args.model_path) output_dir = args.log_dir ckpt_dir = args.ckpt_dir if not os.path.isdir(output_dir): print('Invalid log dir: %s' % output_dir) return if not os.path.isdir(ckpt_dir): print('Invalid ckpt dir: %s' % ckpt_dir) return set_logger(os.path.join(output_dir, 'outputs.log')) logging.info(msg) global device if args.device is not None: logging.info('Setting device to ' + args.device) device = torch.device(args.device) else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logging.info('Setting up...') hparams.parse(args.hparams) logging.info(hparams_debug_string()) model = EdgeClassification if hparams.use_roberta: logging.info('Using Roberta...') model = RobertaEdgeClassification global_step = 0 if args.model_path is None: if hparams.load_pretrained: logging.info('Load online pretrained model...' + ( ('cached at ' + args.cache_path) if args.cache_path is not None else '')) if hparams.use_roberta: model = model.from_pretrained('roberta-base', cache_dir=args.cache_path, hparams=hparams) else: model = model.from_pretrained('bert-base-uncased', cache_dir=args.cache_path, hparams=hparams) else: logging.info('Build model from scratch...') if hparams.use_roberta: config = RobertaConfig.from_pretrained('roberta-base') else: config = BertConfig.from_pretrained('bert-base-uncased') model = model(config=config, hparams=hparams) else: if not os.path.isdir(args.model_path): raise OSError(str(args.model_path) + ' not found') logging.info('Load saved model from %s ...'
% (args.model_path)) model = model.from_pretrained(args.model_path, hparams=hparams) step = args.model_path.split('_')[-1] if step.isnumeric(): global_step = int(step) logging.info('Initial step=%d' % global_step) if hparams.use_roberta: tokenizer = RobertaTokenizer.from_pretrained('roberta-base') else: tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') hparams.parse(args.hparams) logging.info(hparams_debug_string()) if hparams.text_sample_eval: if args.eval_text_path is None: raise ValueError('eval_text_path not given') if ':' not in args.eval_text_path: eval_data_paths = [args.eval_text_path] else: eval_data_paths = args.eval_text_path.split(':') eval_feeder = [] for p in eval_data_paths: name = os.path.split(p)[-1] if name.endswith('.tsv'): name = name[:-4] eval_feeder.append( (name, ExternalTextFeeder(p, hparams, tokenizer, 'dev'))) else: eval_feeder = [('', DataFeeder(args.data_dir, hparams, tokenizer, 'dev'))] tb_writer = SummaryWriter(output_dir) no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], 'weight_decay': hparams.weight_decay }, { 'params': [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], 'weight_decay': 0.0 }] optimizer = AdamW(optimizer_grouped_parameters, lr=hparams.learning_rate, eps=hparams.adam_epsilon) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=hparams.warmup_steps, lr_decay_step=hparams.lr_decay_step, max_lr_decay_rate=hparams.max_lr_decay_rate) acc_step = global_step * hparams.gradient_accumulation_steps time_window = ValueWindow() loss_window = ValueWindow() acc_window = ValueWindow() model.to(device) model.zero_grad() tr_loss = tr_acc = 0.0 start_time = time.time() if args.model_path is not None: logging.info('Load saved model from %s ...' 
% (args.model_path)) if os.path.exists(os.path.join(args.model_path, 'optimizer.pt')) \ and os.path.exists(os.path.join(args.model_path, 'scheduler.pt')): optimizer.load_state_dict( torch.load(os.path.join(args.model_path, 'optimizer.pt'))) optimizer.load_state_dict(optimizer.state_dict()) scheduler.load_state_dict( torch.load(os.path.join(args.model_path, 'scheduler.pt'))) scheduler.load_state_dict(scheduler.state_dict()) else: logging.warning('Could not find saved optimizer/scheduler') if global_step > 0: logs = run_eval(args, model, eval_feeder) for key, value in logs.items(): tb_writer.add_scalar(key, value, global_step) logging.info('Start training...') if hparams.text_sample_train: train_feeder = PrebuiltTrainFeeder(args.train_text_path, hparams, tokenizer, 'train') else: train_feeder = DataFeeder(args.data_dir, hparams, tokenizer, 'train') while True: batch = train_feeder.next_batch() model.train() outputs = model(input_ids=batch.input_ids.to(device), attention_mask=batch.input_mask.to(device), token_type_ids=None if batch.token_type_ids is None else batch.token_type_ids.to(device), labels=batch.labels.to(device)) loss = outputs['loss'] preds = outputs['preds'] acc = torch.mean((preds.cpu() == batch.labels).float()) preds = preds.cpu().detach().numpy() labels = batch.labels.detach().numpy() t_acc = np.sum(np.logical_and(preds == 1, labels == 1)) / np.sum(labels == 1) f_acc = np.sum(np.logical_and(preds == 0, labels == 0)) / np.sum(labels == 0) if hparams.gradient_accumulation_steps > 1: loss = loss / hparams.gradient_accumulation_steps acc = acc / hparams.gradient_accumulation_steps tr_loss += loss.item() tr_acc += acc.item() loss.backward() acc_step += 1 if acc_step % hparams.gradient_accumulation_steps != 0: continue torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.max_grad_norm) optimizer.step() scheduler.step(None) model.zero_grad() global_step += 1 step_time = time.time() - start_time time_window.append(step_time) loss_window.append(tr_loss) acc_window.append(tr_acc) if global_step % args.save_steps == 0: # Save model checkpoint model_to_save = model.module if hasattr(model, 'module') else model cur_ckpt_dir = os.path.join(ckpt_dir, 'checkpoint_%d' % (global_step)) if not os.path.exists(cur_ckpt_dir): os.makedirs(cur_ckpt_dir) model_to_save.save_pretrained(cur_ckpt_dir) torch.save(args, os.path.join(cur_ckpt_dir, 'training_args.bin')) torch.save(optimizer.state_dict(), os.path.join(cur_ckpt_dir, 'optimizer.pt')) torch.save(scheduler.state_dict(), os.path.join(cur_ckpt_dir, 'scheduler.pt')) logging.info("Saving model checkpoint to %s", cur_ckpt_dir) if global_step % args.logging_steps == 0: logs = run_eval(args, model, eval_feeder) learning_rate_scalar = scheduler.get_lr()[0] logs['learning_rate'] = learning_rate_scalar logs['loss'] = loss_window.average logs['acc'] = acc_window.average for key, value in logs.items(): tb_writer.add_scalar(key, value, global_step) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f, t_acc=%.05f, f_acc=%.05f]' % ( global_step, step_time, tr_loss, loss_window.average, tr_acc, acc_window.average, t_acc, f_acc) logging.info(message) tr_loss = tr_acc = 0.0 start_time = time.time()
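# The loop above only steps AdamW every hparams.gradient_accumulation_steps batches.
# Below is a self-contained toy version of that accumulate/clip/step pattern; the model,
# data and accum_steps are invented for illustration and are not part of the project.
import torch

toy_model = torch.nn.Linear(4, 2)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
accum_steps = 4

for acc_step in range(1, 9):
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(toy_model(x), y)
    (loss / accum_steps).backward()   # scale so the accumulated gradient matches one large batch
    if acc_step % accum_steps != 0:
        continue                      # keep accumulating
    torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 1.0)
    toy_optimizer.step()
    toy_model.zero_grad()             # one optimizer step per accum_steps mini-batches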
def train(log_dir, args): save_dir = os.path.join(log_dir, 'pretrained/') checkpoint_path = os.path.join(save_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) log('Checkpoint path: {}'.format(checkpoint_path)) log('Loading training data from: {}'.format(input_path)) log('Using model: {}'.format(args.model)) log(hparams_debug_string()) #Set up feeder coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = Feeder(coord, input_path, hparams) #Set up model step_count = 0 try: #simple text file to keep count of global step with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file: step_count = int(file.read()) except: print( 'no step_counter file found, assuming there is no saved checkpoint' ) global_step = tf.Variable(step_count, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) #Book keeping step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) #Memory allocation on the GPU as needed config = tf.ConfigProto() config.gpu_options.allow_growth = True #Train with tf.Session(config=config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) #saved model restoring if args.restore: #Restore saved model if the user requested it, Default = True. try: checkpoint_state = tf.train.get_checkpoint_state(log_dir) except tf.errors.OutOfRangeError as e: log('Cannot restore checkpoint: {}'.format(e)) if (checkpoint_state and checkpoint_state.model_checkpoint_path): log('Loading checkpoint {}'.format( checkpoint_state.model_checkpoint_path)) saver.restore(sess, checkpoint_state.model_checkpoint_path) else: if not args.restore: log('Starting new training!') else: log('No model to load at {}'.format(save_dir)) #initiating feeder feeder.start_in_session(sess) #Training loop while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format( step, time_window.average, loss, loss_window.average) log(message, end='\r') if loss > 100 or np.isnan(loss): log('Loss exploded to {:.5f} at step {}'.format( loss, step)) raise Exception('Loss exploded') if step % args.summary_interval == 0: log('\nWriting summary at step: {}'.format(step)) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: with open(os.path.join(log_dir, 'step_counter.txt'), 'w') as file: file.write(str(step)) log('Saving checkpoint to: {}-{}'.format( checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) input_seq, prediction = sess.run( [model.inputs[0], model.output[0]]) #Save an example prediction at this step log('Input at step {}: {}'.format( step, sequence_to_text(input_seq))) log('Model prediction: {}'.format( class_to_str(prediction))) except Exception as e: log('Exiting due to exception: {}'.format(e), slack=True) traceback.print_exc() coord.request_stop(e)
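# This variant keeps the global step in a plain step_counter.txt next to the logs instead
# of parsing it out of a checkpoint name. The helpers below sketch that read/write pattern
# in isolation; they are illustrative, not functions defined by the project.
import os

def read_step_counter(log_dir, filename='step_counter.txt'):
    # Return the last saved global step, or 0 when no counter file exists yet.
    try:
        with open(os.path.join(log_dir, filename), 'r') as f:
            return int(f.read().strip())
    except (FileNotFoundError, ValueError):
        return 0

def write_step_counter(log_dir, step, filename='step_counter.txt'):
    with open(os.path.join(log_dir, filename), 'w') as f:
        f.write(str(step))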
def train(): checkpoint_path = os.path.join(Config.LogDir, 'model.ckpt') save_dir = os.path.join(Config.LogDir, 'pretrained/') input_path = Config.DataDir #Set up data feeder coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = Feeder(coord, input_path) #Set up model: step_count = 0 try: #simple text file to keep count of global step with open('step_counter.txt', 'r') as file: step_count = int(file.read()) except: print('no step_counter file found, assuming there is no saved checkpoint') global_step = tf.Variable(step_count, name='global_step', trainable=False) model = Tacotron2.Tacotron2(global_step, feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.target_lengths) model.buildTacotron2() model.addLoss(feeder.masks) #Book keeping step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) #Train with tf.Session() as sess: try: sess.run(tf.global_variables_initializer()) #saver.restore(sess, checkpoint_state.model_checkpoint_path) #initiating feeder feeder.start_in_session(sess) #Training loop while not coord.should_stop(): start_time = time.time() step, loss, _ = sess.run([model.global_step, model.loss, model.optim]) time_window.append(time.time() - start_time) loss_window.append(loss) if step % 1 == 0: message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format( step, time_window.average, loss, loss_window.average) print (message) ''' if loss > 100 or np.isnan(loss): log('Loss exploded to {:.5f} at step {}'.format(loss, step)) raise Exception('Loss exploded') if step % Config.CheckpointInterval == 0: with open('step_counter.txt', 'w') as file: file.write(str(step)) log('Saving checkpoint to: {}-{}'.format(checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) # Unlike the original tacotron, we won't save audio # because we yet have to use wavenet as vocoder log('Saving alignement..') input_seq, prediction, alignment = sess.run([model.inputs[0], model.mel_outputs[0], model.alignments[0], ]) #save predicted spectrogram to disk (for plot and manual evaluation purposes) mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(step) np.save(os.path.join(log_dir, mel_filename), prediction.T, allow_pickle=False) #save alignment plot to disk (evaluation purposes) plot.plot_alignment(alignment, os.path.join(log_dir, 'step-{}-align.png'.format(step)), info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss)) log('Input at step {}: {}'.format(step, sequence_to_text(input_seq))) ''' except Exception as e: #log('Exiting due to exception: {}'.format(e), slack=True) traceback.print_exc() coord.request_stop(e)
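# Like the other TF1 scripts here, this one ties the feeder threads to the training loop
# with tf.train.Coordinator. A minimal standalone illustration of that should_stop /
# request_stop protocol (no model involved):
import tensorflow as tf

coord = tf.train.Coordinator()
try:
    while not coord.should_stop():
        # one training step would run here
        coord.request_stop()          # normally triggered by an exception or a stop condition
except Exception as e:
    coord.request_stop(e)
coord.join([])                        # wait for any registered threads (none in this toy)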
def train(log_dir, args): checkpoint_path = os.path.join(log_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, 'training/train.txt') logger.log('Checkpoint path: %s' % checkpoint_path) logger.log('Loading training data from: %s' % input_path) # set up DataFeeder coordi = tf.train.Coordinator() with tf.compat.v1.variable_scope('data_feeder'): feeder = DataFeeder(coordi, input_path) # set up Model global_step = tf.Variable(0, name='global_step', trainable=False) with tf.compat.v1.variable_scope('model'): model = Tacotron() model.init(feeder.inputs, feeder.input_lengths, mel_targets=feeder.mel_targets, linear_targets=feeder.linear_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) # book keeping step = 0 loss_window = ValueWindow(100) time_window = ValueWindow(100) saver = tf.compat.v1.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) # start training already! with tf.compat.v1.Session() as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) # initialize parameters sess.run(tf.compat.v1.global_variables_initializer()) # if requested, restore from step if (args.restore_step): restore_path = '%s-%d' % (checkpoint_path, args.restore_step) saver.restore(sess, restore_path) logger.log('Resuming from checkpoint: %s' % restore_path) else: logger.log('Starting a new training!') feeder.start_in_session(sess) while not coordi.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) msg = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) logger.log(msg) if loss > 100 or math.isnan(loss): # bad situation logger.log('Loss exploded to %.05f at step %d!' % (loss, step)) raise Exception('Loss Exploded') if step % args.summary_interval == 0: # it's time to write summary logger.log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: # it's time to save a checkpoint logger.log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) logger.log('Saving audio and alignment...') input_seq, spectrogram, alignment = sess.run([ model.inputs[0], model.linear_outputs[0], model.alignments[0] ]) # convert spectrogram to waveform waveform = audio.spectrogram_to_wav(spectrogram.T) # save it audio.save_audio( waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step)) plotter.plot_alignment( alignment, os.path.join(log_dir, 'step-%d-align.png' % step), info='%s, %s, step=%d, loss=%.5f' % ('tacotron', time_string(), step, loss)) logger.log('Input: %s' % sequence_to_text(input_seq)) except Exception as e: logger.log('Exiting due to exception %s' % e) traceback.print_exc() coordi.request_stop(e)
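# add_stats(model) builds the merged summary op that this loop writes every
# summary_interval steps, but it is not defined in this collection. A hedged sketch of
# what such a helper typically returns; the exact tensors it summarizes are an assumption.
import tensorflow as tf

def add_stats(model):
    with tf.variable_scope('stats'):
        tf.summary.scalar('loss', model.loss)
        tf.summary.histogram('linear_outputs', model.linear_outputs)
        return tf.summary.merge_all()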
def train(log_dir, config): config.data_paths = config.data_paths # ['datasets/moon'] data_dirs = config.data_paths # ['datasets/moon\\data'] num_speakers = len(data_dirs) config.num_test = config.num_test_per_speaker * num_speakers # 2*1 if num_speakers > 1 and hparams.model_type not in [ "multi-speaker", "simple" ]: raise Exception("[!] Unkown model_type for multi-speaker: {}".format( config.model_type)) commit = get_git_commit() if config.git else 'None' checkpoint_path = os.path.join( log_dir, 'model.ckpt' ) # 'logdir-tacotron\\moon_2018-08-28_13-06-42\\model.ckpt' #log(' [*] git recv-parse HEAD:\n%s' % get_git_revision_hash()) # hccho: 주석 처리 log('=' * 50) #log(' [*] dit diff:\n%s' % get_git_diff()) log('=' * 50) log(' [*] Checkpoint path: %s' % checkpoint_path) log(' [*] Loading training data from: %s' % data_dirs) log(' [*] Using model: %s' % config.model_dir) # 'logdir-tacotron\\moon_2018-08-28_13-06-42' log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: # DataFeeder의 6개 placeholder: train_feeder.inputs, train_feeder.input_lengths, train_feeder.loss_coeff, train_feeder.mel_targets, train_feeder.linear_targets, train_feeder.speaker_id train_feeder = DataFeederTacotron2(coord, data_dirs, hparams, config, 32, data_type='train', batch_size=config.batch_size) test_feeder = DataFeederTacotron2(coord, data_dirs, hparams, config, 8, data_type='test', batch_size=config.num_test) # Set up model: global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(hparams) model.initialize(inputs=train_feeder.inputs, input_lengths=train_feeder.input_lengths, num_speakers=num_speakers, speaker_id=train_feeder.speaker_id, mel_targets=train_feeder.mel_targets, linear_targets=train_feeder.linear_targets, is_training=True, loss_coeff=train_feeder.loss_coeff, stop_token_targets=train_feeder.stop_token_targets) model.add_loss() model.add_optimizer(global_step) train_stats = add_stats(model, scope_name='train') # legacy with tf.variable_scope('model', reuse=True) as scope: test_model = create_model(hparams) test_model.initialize( inputs=test_feeder.inputs, input_lengths=test_feeder.input_lengths, num_speakers=num_speakers, speaker_id=test_feeder.speaker_id, mel_targets=test_feeder.mel_targets, linear_targets=test_feeder.linear_targets, is_training=False, loss_coeff=test_feeder.loss_coeff, stop_token_targets=test_feeder.stop_token_targets) test_model.add_loss() # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2) sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) sess_config.gpu_options.allow_growth = True # Train! #with tf.Session(config=sess_config) as sess: with tf.Session(config=sess_config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if config.load_path: # Restore from a checkpoint if the user requested it. 
restore_path = get_most_recent_checkpoint(config.model_dir) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) elif config.initialize_path: restore_path = get_most_recent_checkpoint( config.initialize_path) saver.restore(sess, restore_path) log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) zero_step_assign = tf.assign(global_step, 0) sess.run(zero_step_assign) start_step = sess.run(global_step) log('=' * 50) log(' [*] Global step is reset to {}'.format(start_step)) log('=' * 50) else: log('Starting new training run at commit: %s' % commit, slack=True) start_step = sess.run(global_step) train_feeder.start_in_session(sess, start_step) test_feeder.start_in_session(sess, start_step) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss_without_coeff, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % config.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) raise Exception('Loss Exploded') if step % config.summary_interval == 0: log('Writing summary at step: %d' % step) summary_writer.add_summary(sess.run(train_stats), step) if step % config.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) if step % config.test_interval == 0: log('Saving audio and alignment...') num_test = config.num_test fetches = [ model.inputs[:num_test], model.linear_outputs[:num_test], model.alignments[:num_test], test_model.inputs[:num_test], test_model.linear_outputs[:num_test], test_model.alignments[:num_test], ] sequences, spectrograms, alignments, test_sequences, test_spectrograms, test_alignments = sess.run( fetches) #librosa requires ffmpeg. save_and_plot( sequences[:1], spectrograms[:1], alignments[:1], log_dir, step, loss, "train" ) # spectrograms: (num_test,200,1025), alignments: (num_test,encoder_length,decoder_length) save_and_plot(test_sequences, test_spectrograms, test_alignments, log_dir, step, loss, "test") if step == 50: log("Stopping at step %d (last avg loss=%.05f)" % (step, loss_window.average)) coord.request_stop() except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
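# Both restore branches above call get_most_recent_checkpoint(), which is not defined in
# this collection. One plausible sketch, assuming the directory contains a standard TF1
# `checkpoint` state file; the real helper may instead sort checkpoint file names.
import tensorflow as tf

def get_most_recent_checkpoint(checkpoint_dir):
    state = tf.train.get_checkpoint_state(checkpoint_dir)
    if state is None or not state.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in %s' % checkpoint_dir)
    return state.model_checkpoint_path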
def train(log_dir, config): config.data_paths = config.data_paths # data paths from the parsed command-line arguments : default='datasets/kr_example' data_dirs = [os.path.join(data_path, "data") \ for data_path in config.data_paths] num_speakers = len(data_dirs) # number of speakers being trained: 1 for a single-speaker model, 2 for a multi-speaker model config.num_test = config.num_test_per_speaker * num_speakers if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]: # when training a multi-speaker model, the model type must be "deepvoice" or "simple" raise Exception("[!] Unknown model_type for multi-speaker: {}".format(config.model_type)) # typo: hparams.model_type is reported here as config.model_type commit = get_git_commit() if config.git else 'None' # git bookkeeping, can be ignored checkpoint_path = os.path.join(log_dir, 'model.ckpt') # set checkpoint_path - path of the model.ckpt file log(' [*] git recv-parse HEAD:\n%s' % get_git_revision_hash()) # git log log('='*50) # separator line ===== #log(' [*] dit diff:\n%s' % get_git_diff()) log('='*50) # separator line ===== log(' [*] Checkpoint path: %s' % checkpoint_path) # print the checkpoint path log(' [*] Loading training data from: %s' % data_dirs) log(' [*] Using model: %s' % config.model_dir) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() # coordinator for the feeder threads with tf.variable_scope('datafeeder') as scope: train_feeder = DataFeeder( coord, data_dirs, hparams, config, 32, data_type='train', batch_size=hparams.batch_size) # def __init__(self, coordinator, data_dirs, hparams, config, batches_per_group, data_type, batch_size): test_feeder = DataFeeder( coord, data_dirs, hparams, config, 8, data_type='test', batch_size=config.num_test) # Set up model: is_randomly_initialized = config.initialize_path is None global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(hparams) # create the Tacotron model model.initialize( train_feeder.inputs, train_feeder.input_lengths, num_speakers, train_feeder.speaker_id, train_feeder.mel_targets, train_feeder.linear_targets, train_feeder.loss_coeff, is_randomly_initialized=is_randomly_initialized) model.add_loss() model.add_optimizer(global_step) train_stats = add_stats(model, scope_name='stats') # legacy with tf.variable_scope('model', reuse=True) as scope: test_model = create_model(hparams) # create the Tacotron test model test_model.initialize( test_feeder.inputs, test_feeder.input_lengths, num_speakers, test_feeder.speaker_id, test_feeder.mel_targets, test_feeder.linear_targets, test_feeder.loss_coeff, rnn_decoder_test_mode=True, is_randomly_initialized=is_randomly_initialized) test_model.add_loss() test_stats = add_stats(test_model, model, scope_name='test') # records the model's loss values etc. to TensorBoard / test_model is passed as model, model as model2 test_stats = tf.summary.merge([test_stats, train_stats]) # Bookkeeping: step = 0 time_window = ValueWindow(100) # ValueWindow class with window_size = 100 loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2) # auto-save every 2 hours; checkpoints are never deleted sess_config = tf.ConfigProto( log_device_placement=False, # log_device_placement reports which device each op is placed on allow_soft_placement=True) # if allow_soft_placement is False, this errors out when no GPU is available sess_config.gpu_options.allow_growth=True # allocate GPU memory on demand # Train!
#with tf.Session(config=sess_config) as sess: with tf.Session() as sess: # everything inside this with block runs in this session (CPU or GPU) try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) # write evaluated summary ops and the TensorBoard graph to log_dir sess.run(tf.global_variables_initializer()) # once the dataset is loaded and the graph is fully defined, initialize variables and start training if config.load_path: # if a load path was specified # Restore from a checkpoint if the user requested it. restore_path = get_most_recent_checkpoint(config.model_dir) # path of the most recently saved checkpoint saver.restore(sess, restore_path) # restore from restore_path log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) # log via slack, tagged with the git commit elif config.initialize_path: # if a checkpoint to initialize from was specified restore_path = get_most_recent_checkpoint(config.initialize_path) # most recent checkpoint under the given path saver.restore(sess, restore_path) # restore from restore_path log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) # log via slack, tagged with the git commit zero_step_assign = tf.assign(global_step, 0) # op that resets the global_step variable to 0 sess.run(zero_step_assign) # run the reset op start_step = sess.run(global_step) # use the reset global_step value as the starting step log('='*50) log(' [*] Global step is reset to {}'. \ format(start_step)) # i.e., report that the starting step has been reset to 0 log('='*50) else: log('Starting new training run at commit: %s' % commit, slack=True) # no previous run is reused, so log that this is a new training run start_step = sess.run(global_step) # fetch the starting step train_feeder.start_in_session(sess, start_step) test_feeder.start_in_session(sess, start_step) while not coord.should_stop(): # as long as the threads have not been asked to stop start_time = time.time() # record the step start time (seconds since the Unix epoch) step, loss, opt = sess.run( [global_step, model.loss_without_coeff, model.optimize], feed_dict=model.get_dummy_feed_dict()) # step comes from global_step, loss from loss_without_coeff time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % config.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) raise Exception('Loss Exploded') if step % config.summary_interval == 0: log('Writing summary at step: %d' % step) feed_dict = { **model.get_dummy_feed_dict(), **test_model.get_dummy_feed_dict() } summary_writer.add_summary(sess.run( test_stats, feed_dict=feed_dict), step) if step % config.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) if step % config.test_interval == 0: log('Saving audio and alignment...') num_test = config.num_test fetches = [ model.inputs[:num_test], model.linear_outputs[:num_test], model.alignments[:num_test], test_model.inputs[:num_test], test_model.linear_outputs[:num_test], test_model.alignments[:num_test], ] feed_dict = { **model.get_dummy_feed_dict(), **test_model.get_dummy_feed_dict() } sequences, spectrograms, alignments, \ test_sequences, test_spectrograms, test_alignments = \ sess.run(fetches, feed_dict=feed_dict) save_and_plot(sequences[:1], spectrograms[:1], alignments[:1], log_dir, step, loss, "train") save_and_plot(test_sequences, test_spectrograms, test_alignments, log_dir, step, loss, "test") except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
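# The initialize_path branch above restores the weights but restarts counting from step 0
# via tf.assign. The same reset in a standalone TF1 snippet (variable name is arbitrary):
import tensorflow as tf

toy_global_step = tf.Variable(10000, name='toy_global_step', trainable=False)
reset_op = tf.assign(toy_global_step, 0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(reset_op)
    print(sess.run(toy_global_step))  # -> 0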
def train(log_dir, config): config.data_paths = config.data_paths data_dirs = [os.path.join(data_path, "data") \ for data_path in config.data_paths] num_speakers = len(data_dirs) config.num_test = config.num_test_per_speaker * num_speakers if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]: raise Exception("[!] Unkown model_type for multi-speaker: {}".format( config.model_type)) commit = get_git_commit() if config.git else 'None' checkpoint_path = os.path.join(log_dir, 'model.ckpt') #log(' [*] git recv-parse HEAD:\n%s' % get_git_revision_hash()) log('=' * 50) #log(' [*] dit diff:\n%s' % get_git_diff()) log('=' * 50) log(' [*] Checkpoint path: %s' % checkpoint_path) log(' [*] Loading training data from: %s' % data_dirs) log(' [*] Using model: %s' % config.model_dir) log(hparams_debug_string()) # Set up DataFeeder: coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: train_feeder = DataFeeder(coord, data_dirs, hparams, config, 32, data_type='train', batch_size=hparams.batch_size) test_feeder = DataFeeder(coord, data_dirs, hparams, config, 8, data_type='test', batch_size=config.num_test) # Set up model: is_randomly_initialized = config.initialize_path is None global_step = tf.Variable(0, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(hparams) model.initialize(train_feeder.inputs, train_feeder.input_lengths, num_speakers, train_feeder.speaker_id, train_feeder.mel_targets, train_feeder.linear_targets, train_feeder.loss_coeff, is_randomly_initialized=is_randomly_initialized) model.add_loss() model.add_optimizer(global_step) train_stats = add_stats(model, scope_name='stats') # legacy with tf.variable_scope('model', reuse=True) as scope: test_model = create_model(hparams) test_model.initialize(test_feeder.inputs, test_feeder.input_lengths, num_speakers, test_feeder.speaker_id, test_feeder.mel_targets, test_feeder.linear_targets, test_feeder.loss_coeff, rnn_decoder_test_mode=True, is_randomly_initialized=is_randomly_initialized) test_model.add_loss() test_stats = add_stats(test_model, model, scope_name='test') test_stats = tf.summary.merge([test_stats, train_stats]) # Bookkeeping: step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2) sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True) sess_config.gpu_options.allow_growth = True # Train! #with tf.Session(config=sess_config) as sess: with tf.Session() as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) if config.load_path: # Restore from a checkpoint if the user requested it. restore_path = get_most_recent_checkpoint(config.model_dir) saver.restore(sess, restore_path) log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) elif config.initialize_path: restore_path = get_most_recent_checkpoint( config.initialize_path) saver.restore(sess, restore_path) log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) zero_step_assign = tf.assign(global_step, 0) sess.run(zero_step_assign) start_step = sess.run(global_step) log('=' * 50) log(' [*] Global step is reset to {}'. 
\ format(start_step)) log('=' * 50) else: log('Starting new training run at commit: %s' % commit, slack=True) start_step = sess.run(global_step) train_feeder.start_in_session(sess, start_step) test_feeder.start_in_session(sess, start_step) while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run( [global_step, model.loss_without_coeff, model.optimize], feed_dict=model.get_dummy_feed_dict()) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( step, time_window.average, loss, loss_window.average) log(message, slack=(step % config.checkpoint_interval == 0)) if loss > 100 or math.isnan(loss): log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) raise Exception('Loss Exploded') if step % config.summary_interval == 0: log('Writing summary at step: %d' % step) feed_dict = { **model.get_dummy_feed_dict(), **test_model.get_dummy_feed_dict() } summary_writer.add_summary( sess.run(test_stats, feed_dict=feed_dict), step) if step % config.checkpoint_interval == 0: log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) if step % config.test_interval == 0: log('Saving audio and alignment...') num_test = config.num_test fetches = [ model.inputs[:num_test], model.linear_outputs[:num_test], model.alignments[:num_test], test_model.inputs[:num_test], test_model.linear_outputs[:num_test], test_model.alignments[:num_test], ] feed_dict = { **model.get_dummy_feed_dict(), **test_model.get_dummy_feed_dict() } sequences, spectrograms, alignments, \ test_sequences, test_spectrograms, test_alignments = \ sess.run(fetches, feed_dict=feed_dict) save_and_plot(sequences[:1], spectrograms[:1], alignments[:1], log_dir, step, loss, "train") save_and_plot(test_sequences, test_spectrograms, test_alignments, log_dir, step, loss, "test") except Exception as e: log('Exiting due to exception: %s' % e, slack=True) traceback.print_exc() coord.request_stop(e)
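# Both multi-speaker variants build test_model inside tf.variable_scope('model', reuse=True)
# so it evaluates with the training model's weights instead of creating new ones. A minimal
# TF1 illustration of that sharing:
import tensorflow as tf

with tf.variable_scope('model'):
    w_train = tf.get_variable('w', shape=[2, 2])
with tf.variable_scope('model', reuse=True):
    w_test = tf.get_variable('w')  # returns the existing variable, no new weights
assert w_train is w_test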
def validate(val_loader, model, device, mels_criterion, stop_criterion, writer, val_dir): batch_time = ValueWindow() losses = ValueWindow() # switch to evaluate mode model.eval() global global_epoch global global_step with torch.no_grad(): end = time.time() for i, (txts, mels, stop_tokens, txt_lengths, mels_lengths) in enumerate(val_loader): # measure data loading time batch_time.update(time.time() - end) if device > -1: txts = txts.cuda(device) mels = mels.cuda(device) stop_tokens = stop_tokens.cuda(device) txt_lengths = txt_lengths.cuda(device) mels_lengths = mels_lengths.cuda(device) # compute output frames, decoder_frames, stop_tokens_predict, alignment = model( txts, txt_lengths, mels) decoder_frames_loss = mels_criterion(decoder_frames, mels, lengths=mels_lengths) frames_loss = mels_criterion(frames, mels, lengths=mels_lengths) stop_token_loss = stop_criterion(stop_tokens_predict, stop_tokens, lengths=mels_lengths) loss = decoder_frames_loss + frames_loss + stop_token_loss losses.update(loss.item()) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % hparams.print_freq == 0: log('Epoch: [{0}]\t' 'Test: [{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format( global_epoch, i, len(val_loader), batch_time=batch_time, loss=losses)) # Logs writer.add_scalar("loss", float(loss.item()), global_step) writer.add_scalar( "avg_loss in {} window".format(losses.get_dinwow_size), float(losses.avg), global_step) writer.add_scalar("stop_token_loss", float(stop_token_loss.item()), global_step) writer.add_scalar("decoder_frames_loss", float(decoder_frames_loss.item()), global_step) writer.add_scalar("output_frames_loss", float(frames_loss.item()), global_step) dst_alignment_path = join(val_dir, "{}_alignment.png".format(global_step)) alignment = alignment.cpu().detach().numpy() plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]], dst_alignment_path, info="{}, {}".format(hparams.builder, global_step)) return losses.avg
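# Unlike the TF loops above, this PyTorch validate() drives its windows through
# .update() / .val / .avg, i.e. an AverageMeter-style interface. The class below is a
# minimal sketch matching that usage, written as an assumption about the helper rather
# than its actual code (the script's `get_dinwow_size` attribute is not modeled here).
class AverageMeter:
    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)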
def train(log_dir, args): save_dir = os.path.join(log_dir, 'pretrained/') checkpoint_path = os.path.join(save_dir, 'model.ckpt') input_path = os.path.join(args.base_dir, args.input) plot_dir = os.path.join(log_dir, 'plots') os.makedirs(plot_dir, exist_ok=True) log('Checkpoint path: {}'.format(checkpoint_path)) log('Loading training data from: {}'.format(input_path)) log('Using model: {}'.format(args.model)) log(hparams_debug_string()) #Set up data feeder coord = tf.train.Coordinator() with tf.variable_scope('datafeeder') as scope: feeder = Feeder(coord, input_path, hparams) #Set up model: step_count = 0 try: #simple text file to keep count of global step with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file: step_count = int(file.read()) except: print('no step_counter file found, assuming there is no saved checkpoint') global_step = tf.Variable(step_count, name='global_step', trainable=False) with tf.variable_scope('model') as scope: model = create_model(args.model, hparams) model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets) model.add_loss() model.add_optimizer(global_step) stats = add_stats(model) #Book keeping step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) saver = tf.train.Saver(max_to_keep=5) #Memory allocation on the GPU as needed config = tf.ConfigProto() config.gpu_options.allow_growth = True #Train with tf.Session(config=config) as sess: try: summary_writer = tf.summary.FileWriter(log_dir, sess.graph) sess.run(tf.global_variables_initializer()) #saved model restoring if args.restore: #Restore saved model if the user requested it, Default = True. try: checkpoint_state = tf.train.get_checkpoint_state(save_dir) except tf.errors.OutOfRangeError as e: log('Cannot restore checkpoint: {}'.format(e)) if (checkpoint_state and checkpoint_state.model_checkpoint_path): log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path)) saver.restore(sess, checkpoint_state.model_checkpoint_path) else: if not args.restore: log('Starting new training!') else: log('No model to load at {}'.format(save_dir)) #initiating feeder feeder.start_in_session(sess) #Training loop while not coord.should_stop(): start_time = time.time() step, loss, opt = sess.run([global_step, model.loss, model.optimize]) time_window.append(time.time() - start_time) loss_window.append(loss) message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format( step, time_window.average, loss, loss_window.average) log(message, end='\r') if loss > 100 or np.isnan(loss): log('Loss exploded to {:.5f} at step {}'.format(loss, step)) raise Exception('Loss exploded') if step % args.summary_interval == 0: log('\nWriting summary at step: {}'.format(step)) summary_writer.add_summary(sess.run(stats), step) if step % args.checkpoint_interval == 0: with open(os.path.join(log_dir,'step_counter.txt'), 'w') as file: file.write(str(step)) log('Saving checkpoint to: {}-{}'.format(checkpoint_path, step)) saver.save(sess, checkpoint_path, global_step=step) # Unlike the original tacotron, we won't save audio # because we yet have to use wavenet as vocoder log('Saving alignement and Mel-Spectrograms..') input_seq, prediction, alignment, target = sess.run([model.inputs[0], model.mel_outputs[0], model.alignments[0], model.mel_targets[0], ]) #save predicted spectrogram to disk (for plot and manual evaluation purposes) mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(step) np.save(os.path.join(log_dir, mel_filename), prediction, 
allow_pickle=False) #save alignment plot to disk (control purposes) plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)), info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss)) #save real mel-spectrogram plot to disk (control purposes) plot.plot_spectrogram(target, os.path.join(plot_dir, 'step-{}-real-mel-spectrogram.png'.format(step)), info='{}, {}, step={}, Real'.format(args.model, time_string(), step, loss)) #save predicted mel-spectrogram plot to disk (control purposes) plot.plot_spectrogram(prediction, os.path.join(plot_dir, 'step-{}-pred-mel-spectrogram.png'.format(step)), info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, loss)) log('Input at step {}: {}'.format(step, sequence_to_text(input_seq))) except Exception as e: log('Exiting due to exception: {}'.format(e), slack=True) traceback.print_exc() coord.request_stop(e)
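# Each checkpoint above also dumps the attention alignment as a PNG via plot.plot_alignment.
# A hedged matplotlib sketch compatible with those calls; the project's own plotting util
# may differ in orientation and styling.
import matplotlib
matplotlib.use('Agg')  # render without a display, as on a training server
import matplotlib.pyplot as plt

def plot_alignment(alignment, path, info=None):
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Encoder timestep')
    fig.tight_layout()
    fig.savefig(path, format='png')
    plt.close(fig)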
def main(): parser = argparse.ArgumentParser() parser.add_argument('-config', type=str, default='config/hparams.yaml') parser.add_argument('-load_model', type=str, default=None) parser.add_argument('-model_name', type=str, default='P_S_Transformer_debug', help='model name') # parser.add_argument('-batches_per_allreduce', type=int, default=1, # help='number of batches processed locally before ' # 'executing allreduce across workers; it multiplies ' # 'total batch size.') parser.add_argument('-num_wokers', type=int, default=0, help='how many subprocesses to use for data loading. ' '0 means that the data will be loaded in the main process') parser.add_argument('-log', type=str, default='train.log') opt = parser.parse_args() configfile = open(opt.config) config = AttrDict(yaml.load(configfile,Loader=yaml.FullLoader)) log_name = opt.model_name or config.model.name log_folder = os.path.join(os.getcwd(),'logdir/logging',log_name) if not os.path.isdir(log_folder): os.mkdir(log_folder) logger = init_logger(log_folder+'/'+opt.log) # TODO: build dataloader train_datafeeder = DataFeeder(config,'debug') # TODO: build model or load pre-trained model global global_step global_step = 0 learning_rate = CustomSchedule(config.model.d_model) # learning_rate = 0.00002 optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=config.optimizer.beta1, beta_2=config.optimizer.beta2, epsilon=config.optimizer.epsilon) logger.info('config.optimizer.beta1:' + str(config.optimizer.beta1)) logger.info('config.optimizer.beta2:' + str(config.optimizer.beta2)) logger.info('config.optimizer.epsilon:' + str(config.optimizer.epsilon)) # print(str(config)) model = Speech_transformer(config=config,logger=logger) #Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs. checkpoint_path = log_folder ckpt = tf.train.Checkpoint(transformer=model, optimizer=optimizer) ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5) # if a checkpoint exists, restore the latest checkpoint. 
if ckpt_manager.latest_checkpoint: ckpt.restore(ckpt_manager.latest_checkpoint) logger.info('Latest checkpoint restored!!') else: logger.info('Start new run') # define metrics and summary writer train_loss = tf.keras.metrics.Mean(name='train_loss') train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') # summary_writer = tf.keras.callbacks.TensorBoard(log_dir=log_folder) summary_writer = summary_ops_v2.create_file_writer_v2(log_folder+'/train') # @tf.function def train_step(batch_data): inp = batch_data['the_inputs'] # batch*time*feature tar = batch_data['the_labels'] # batch*time # inp_len = batch_data['input_length'] # tar_len = batch_data['label_length'] gtruth = batch_data['ground_truth'] tar_inp = tar tar_real = gtruth # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp[:,:,0], tar_inp) combined_mask = create_combined_mask(tar=tar_inp) with tf.GradientTape() as tape: predictions, _ = model(inp, tar_inp, True, None, combined_mask, None) # logger.info('config.train.label_smoothing_epsilon:' + str(config.train.label_smoothing_epsilon)) loss = LableSmoothingLoss(tar_real, predictions,config.model.vocab_size,config.train.label_smoothing_epsilon) gradients = tape.gradient(loss, model.trainable_variables) clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0) optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables)) train_loss(loss) train_accuracy(tar_real, predictions) time_window = ValueWindow(100) loss_window = ValueWindow(100) acc_window = ValueWindow(100) logger.info('config.train.epoches:' + str(config.train.epoches)) first_time = True for epoch in range(config.train.epoches): logger.info('start epoch '+ str(epoch)) logger.info('total wavs: '+ str(len(train_datafeeder))) logger.info('batch size: ' + str(train_datafeeder.batch_size)) logger.info('batch per epoch: ' + str(len(train_datafeeder)//train_datafeeder.batch_size)) train_data = train_datafeeder.get_batch() start_time = time.time() train_loss.reset_states() train_accuracy.reset_states() for step in range(len(train_datafeeder)//train_datafeeder.batch_size): batch_data = next(train_data) step_time = time.time() train_step(batch_data) if first_time: model.summary() first_time=False time_window.append(time.time()-step_time) loss_window.append(train_loss.result()) acc_window.append(train_accuracy.result()) message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % ( global_step, time_window.average, train_loss.result(), loss_window.average, train_accuracy.result(),acc_window.average) logger.info(message) if global_step % 10 == 0: with summary_ops_v2.always_record_summaries(): with summary_writer.as_default(): summary_ops_v2.scalar('train_loss', train_loss.result(), step=global_step) summary_ops_v2.scalar('train_acc', train_accuracy.result(), step=global_step) global_step += 1 ckpt_save_path = ckpt_manager.save() logger.info('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path)) logger.info('Time taken for 1 epoch: {} secs\n'.format(time.time() - start_time))
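# CustomSchedule(config.model.d_model) above is not defined in this collection. Speech
# Transformer setups like this one usually use the warmup schedule from "Attention Is All
# You Need"; the class below is a sketch under that assumption, not the project's own code.
import tensorflow as tf

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)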