def train(log_dir, args, hparams, use_hvd=False):
    if use_hvd:
        import horovod.tensorflow as hvd
        # Initialize Horovod.
        hvd.init()
    else:
        hvd = None

    eval_dir, eval_plot_dir, eval_wav_dir, meta_folder, plot_dir, save_dir, tensorboard_dir, wav_dir = init_dir(log_dir)
    checkpoint_path = os.path.join(save_dir, 'centaur_model.ckpt')
    input_path = os.path.join(args.base_dir, args.input_dir)
    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = Feeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(feeder, hparams, global_step, hvd=hvd)
    eval_model = model_test_mode(feeder, hparams)

    # Embeddings metadata
    char_embedding_meta = os.path.join(meta_folder, 'CharacterEmbeddings.tsv')
    if not os.path.isfile(char_embedding_meta):
        with open(char_embedding_meta, 'w', encoding='utf-8') as f:
            for symbol in symbols:
                if symbol == ' ':
                    symbol = '\\s'  # For visual purposes, swap space with \s
                f.write('{}\n'.format(symbol))
    char_embedding_meta = char_embedding_meta.replace(log_dir, '..')

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=2)

    log('Centaur training set to a maximum of {} steps'.format(args.train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    if use_hvd:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Train
    with tf.Session(config=config) as sess:
        try:
            # Init model and load weights
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                restore_model(saver, sess, global_step, save_dir, checkpoint_path,
                              args.reset_global_step)
            else:
                log('Starting new training!', slack=True)
                saver.save(sess, checkpoint_path, global_step=global_step)

            # initializing feeder
            start_step = sess.run(global_step)
            feeder.start_threads(sess, start_step=start_step)

            # Horovod bcast vars across workers
            if use_hvd:
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                bcast = hvd.broadcast_global_variables(0)
                bcast.run()
                log('Worker{}: Initialized'.format(hvd.rank()))

            # Training loop
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            while not coord.should_stop() and step < args.train_steps:
                start_time = time.time()
                step, loss, opt = sess.run([global_step, model.loss, model.train_op])

                # Only the main process (rank 0 under Horovod) logs, summarizes and checkpoints.
                if use_hvd:
                    main_process = hvd.rank() == 0
                else:
                    main_process = True

                if main_process:
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                        step, time_window.average, loss, loss_window.average)
                    log(message, end='\r', slack=(step % args.checkpoint_interval == 0))

                    if np.isnan(loss) or loss > 100.:
                        log('Loss exploded to {:.5f} at step {}'.format(loss, step))
                        raise Exception('Loss exploded')

                    if step % args.summary_interval == 0:
                        log('\nWriting summary at step {}'.format(step))
                        summary_writer.add_summary(sess.run(stats), step)

                    if step % args.eval_interval == 0:
                        run_eval(args, eval_dir, eval_model, eval_plot_dir, eval_wav_dir,
                                 feeder, hparams, sess, step, summary_writer)

                    if step % args.checkpoint_interval == 0 or step == args.train_steps or step == 300:
                        save_current_model(args, checkpoint_path, global_step, hparams, loss,
                                           model, plot_dir, saver, sess, step, wav_dir)

                    if step % args.embedding_interval == 0 or step == args.train_steps or step == 1:
                        update_character_embedding(char_embedding_meta, save_dir, summary_writer)

            log('Centaur training complete after {} global steps!'.format(args.train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
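

# Hypothetical launcher for train() above; not part of the original source. It
# wires up only the argparse attributes that train() itself reads (base_dir,
# input_dir, model, restore, reset_global_step, train_steps and the *_interval
# flags); the default values are illustrative. `hparams` is assumed to be the
# repo's module-level hyper-parameter object, as used elsewhere in this code.
# With Horovod it would typically be started as, for example:
#   horovodrun -np 4 python train.py --use_hvd
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default='')
    parser.add_argument('--input_dir', default='training_data/train.txt')
    parser.add_argument('--model', default='Centaur')
    parser.add_argument('--train_steps', type=int, default=150000)
    parser.add_argument('--summary_interval', type=int, default=250)
    parser.add_argument('--eval_interval', type=int, default=5000)
    parser.add_argument('--checkpoint_interval', type=int, default=2500)
    parser.add_argument('--embedding_interval', type=int, default=5000)
    parser.add_argument('--restore', action='store_true', default=True)
    parser.add_argument('--reset_global_step', action='store_true')
    parser.add_argument('--use_hvd', action='store_true')
    return parser.parse_args()

if __name__ == '__main__':
    cli_args = parse_args()
    train(os.path.join(cli_args.base_dir, 'logs-centaur'), cli_args, hparams,
          use_hvd=cli_args.use_hvd)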
def main(args):
    eval_fn = os.path.join(args.model_dir, 'eval-detailed.txt')
    assert os.path.exists(args.model_dir), 'Model dir does not exist.'
    assert args.overwrite or not os.path.exists(eval_fn), \
        'Evaluation file already exists.'
    os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % args.gpu

    print '\n' + '=' * 30 + ' ARGUMENTS ' + '=' * 30
    params = myutils.load_params(args.model_dir)
    for k, v in params.__dict__.iteritems():
        print 'TRAIN | {}: {}'.format(k, v)
    for k, v in args.__dict__.iteritems():
        print 'EVAL | {}: {}'.format(k, v)
    sys.stdout.flush()

    DURATION = 0.1
    BATCH_SIZE = 16

    with tf.device('/cpu:0'), tf.variable_scope('feeder'):
        feeder = Feeder(params.db_dir,
                        subset_fn=args.subset_fn,
                        ambi_order=params.ambi_order,
                        audio_rate=params.audio_rate,
                        video_rate=params.video_rate,
                        context=params.context,
                        duration=DURATION,
                        return_video=VIDEO in params.encoders,
                        img_prep=myutils.img_prep_fcn(),
                        return_flow=FLOW in params.encoders,
                        frame_size=(224, 448),
                        queue_size=BATCH_SIZE * 5,
                        n_threads=4,
                        for_eval=True)
        batches = feeder.dequeue(BATCH_SIZE)
        ambix_batch = batches['ambix']
        video_batch = batches['video'] if VIDEO in params.encoders else None
        flow_batch = batches['flow'] if FLOW in params.encoders else None
        audio_mask_batch = batches['audio_mask']

    ss = int(params.audio_rate * params.context) / 2
    t = int(params.audio_rate * DURATION)
    audio_input = ambix_batch[:, :, :params.ambi_order**2]
    audio_target = ambix_batch[:, ss:ss + t, params.ambi_order**2:]

    print '\n' + '=' * 20 + ' MODEL ' + '=' * 20
    sys.stdout.flush()
    with tf.device('/gpu:0'):
        # Model
        num_sep = params.num_sep_tracks if params.separation != NO_SEPARATION else 1
        net_params = SptAudioGenParams(sep_num_tracks=num_sep,
                                       ctx_feats_fc_units=params.context_units,
                                       loc_fc_units=params.loc_units,
                                       sep_freq_mask_fc_units=params.freq_mask_units,
                                       sep_fft_window=params.fft_window)
        model = SptAudioGen(ambi_order=params.ambi_order,
                            audio_rate=params.audio_rate,
                            video_rate=params.video_rate,
                            context=params.context,
                            sample_duration=DURATION,
                            encoders=params.encoders,
                            separation=params.separation,
                            params=net_params)

        # Inference
        pred_t = model.inference_ops(audio=audio_input,
                                     video=video_batch,
                                     flow=flow_batch,
                                     is_training=False)

        # Losses and evaluation metrics
        with tf.variable_scope('metrics'):
            w_t = audio_input[:, ss:ss + t]
            _, stft_dist_ps, lsd_ps, mse_ps, snr_ps = model.evaluation_ops(
                pred_t, audio_target, w_t,
                mask_channels=audio_mask_batch[:, params.ambi_order**2:])

    # Loader
    vars2save = [v for v in tf.global_variables()
                 if not v.op.name.startswith('metrics')]
    saver = tf.train.Saver(vars2save)

    print '\n' + '=' * 30 + ' VARIABLES ' + '=' * 30
    model_vars = tf.global_variables()
    import numpy as np
    for v in model_vars:
        if 'Adam' in v.op.name.split('/')[-1]:
            continue
        print ' * {:50s} | {:20s} | {:7s} | {:10s}'.format(
            v.op.name, str(v.get_shape()), str(np.prod(v.get_shape())), str(v.dtype))

    print '\n' + '=' * 30 + ' EVALUATION ' + '=' * 30
    sys.stdout.flush()
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        print 'Loading model...'
        sess.run(model.init_ops)
        saver.restore(sess, tf.train.latest_checkpoint(args.model_dir))

        print 'Initializing data feeders...'
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess, coord)
        feeder.start_threads(sess)

        all_metrics = ['amplitude/predicted', 'amplitude/gt',
                       'mse/avg', 'mse/X', 'mse/Y', 'mse/Z',
                       'stft/avg', 'stft/X', 'stft/Y', 'stft/Z',
                       'lsd/avg', 'lsd/X', 'lsd/Y', 'lsd/Z',
                       'mel_lsd/avg', 'mel_lsd/X', 'mel_lsd/Y', 'mel_lsd/Z',
                       'snr/avg', 'snr/X', 'snr/Y', 'snr/Z',
                       'env_mse/avg', 'env_mse/X', 'env_mse/Y', 'env_mse/Z',
                       'emd/dir', 'emd/dir2']
        metrics = OrderedDict([(key, []) for key in all_metrics])
        sample_ids = []
        telapsed = deque(maxlen=20)

        print 'Start evaluation...'
        it = -1
        # run_options = tf.RunOptions(timeout_in_ms=60*1000)
        while True:
            it += 1
            if feeder.done(sess):
                break

            start_time = time.time()
            outs = sess.run([batches['id'], audio_mask_batch, w_t, audio_target,
                             pred_t, stft_dist_ps, lsd_ps, mse_ps, snr_ps])
            video_id, layout, mono, gt, pred = outs[:5]
            gt_m = np.concatenate((mono, gt), axis=2) * layout[:, np.newaxis, :]
            pred_m = np.concatenate((mono, pred), axis=2) * layout[:, np.newaxis, :]
            stft_dist, lsd, mse, snr = outs[5:]

            _env_time = 0.
            _emd_time = 0.
            _pow_time = 0.
            _lsd_time = 0.
            for smp in range(BATCH_SIZE):
                metrics['stft/avg'].append(np.mean(stft_dist[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['stft/' + ch].append(stft_dist[smp, i])

                metrics['lsd/avg'].append(np.mean(lsd[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['lsd/' + ch].append(lsd[smp, i])

                metrics['mse/avg'].append(np.mean(mse[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['mse/' + ch].append(mse[smp, i])

                metrics['snr/avg'].append(np.nanmean(snr[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['snr/' + ch].append(snr[smp, i])

                # Compute Mel LSD distance
                _t = time.time()
                mel_lsd = myutils.compute_lsd_dist(pred[smp], gt[smp], params.audio_rate)
                metrics['mel_lsd/avg'].append(np.mean(mel_lsd))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['mel_lsd/' + ch].append(mel_lsd[i])
                _lsd_time += (time.time() - _t)

                # Compute envelope distances
                _t = time.time()
                env_mse = myutils.compute_envelope_dist(pred[smp], gt[smp])
                metrics['env_mse/avg'].append(np.mean(env_mse))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['env_mse/' + ch].append(env_mse[i])
                _env_time += (time.time() - _t)

                # Compute EMD (for speed, only compute emd over first 0.1s of every 1sec)
                _t = time.time()
                emd_dir, emd_dir2 = ambix_emd(pred_m[smp], gt_m[smp], model.snd_rate, ang_res=30)
                metrics['emd/dir'].append(emd_dir)
                metrics['emd/dir2'].append(emd_dir2)
                _emd_time += (time.time() - _t)

                # Compute chunk power
                _t = time.time()
                metrics['amplitude/gt'].append(np.abs(gt[smp]).max())
                metrics['amplitude/predicted'].append(np.abs(pred[smp]).max())
                _pow_time += (time.time() - _t)

                sample_ids.append(video_id[smp])

            telapsed.append(time.time() - start_time)
            # print '\nTotal:', telapsed[-1]
            # print 'Env:', _env_time
            # print 'LSD:', _lsd_time
            # print 'EMD:', _emd_time
            # print 'POW:', _pow_time

            if it % 100 == 0:
                # Store evaluation metrics
                with open(eval_fn, 'w') as f:
                    f.write('SampleID | {}\n'.format(' '.join(metrics.keys())))
                    for smp in range(len(sample_ids)):
                        f.write('{} | {}\n'.format(
                            sample_ids[smp],
                            ' '.join([str(metrics[key][smp]) for key in metrics])))

            if it % 5 == 0:
                stats = OrderedDict([(m, np.mean(metrics[m])) for m in all_metrics])
                myutils.print_stats(stats.values(), stats.keys(), BATCH_SIZE, telapsed, it, tag='EVAL')
                sys.stdout.flush()

        # Print progress
        stats = OrderedDict([(m, np.mean(metrics[m])) for m in all_metrics])
        myutils.print_stats(stats.values(), stats.keys(), BATCH_SIZE, telapsed, it, tag='EVAL')
        sys.stdout.flush()

        with open(eval_fn, 'w') as f:
            f.write('SampleID | {}\n'.format(' '.join(metrics.keys())))
            for smp in range(len(sample_ids)):
                f.write('{} | {}\n'.format(
                    sample_ids[smp],
                    ' '.join([str(metrics[key][smp]) for key in metrics])))

    print('\n' + '#' * 60)
    print('End of evaluation.')
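

# Hypothetical CLI entry point for the evaluation main() above; not part of the
# original source. It exposes only the attributes main() reads from `args`
# (model_dir, subset_fn, gpu, overwrite); defaults are illustrative.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('model_dir', help='Directory containing the trained checkpoint and params.')
    parser.add_argument('--subset_fn', default='', help='Optional file listing the samples to evaluate.')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--overwrite', action='store_true',
                        help='Overwrite eval-detailed.txt if it already exists.')
    main(parser.parse_args())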
def main(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    print('\n' + '=' * 30 + ' ARGUMENTS ' + '=' * 30)
    sys.stdout.flush()
    if args.resume:
        params = myutils.load_params(args.model_dir)
        args.encoders = params.encoders
        args.separation = params.separation
        args.ambi_order = params.ambi_order
        args.audio_rate = params.audio_rate
        args.video_rate = params.video_rate
        args.context = params.context
        args.sample_dur = params.sample_dur
    else:
        myutils.save_params(args)
    myutils.print_params(args)

    # Feeder
    min_t = min([args.context, args.sample_dur, 1. / args.video_rate])
    args.video_rate = int(1. / min_t)
    with tf.device('/cpu:0'), tf.variable_scope('feeder'):
        feeder = Feeder(args.db_dir,
                        subset_fn=args.subset_fn,
                        ambi_order=args.ambi_order,
                        audio_rate=args.audio_rate,
                        video_rate=args.video_rate,
                        context=args.context,
                        duration=args.sample_dur,
                        return_video=VIDEO in args.encoders,
                        img_prep=myutils.img_prep_fcn(),
                        return_flow=FLOW in args.encoders,
                        frame_size=(224, 448),
                        queue_size=args.batch_size * 5,
                        n_threads=4,
                        for_eval=False)
        batches = feeder.dequeue(args.batch_size)
        ambix_batch = batches['ambix']
        video_batch = batches['video'] if 'video' in args.encoders else None
        flow_batch = batches['flow'] if 'flow' in args.encoders else None
        audio_mask_batch = batches['audio_mask']

    t = int(args.audio_rate * args.sample_dur)
    ss = int(args.audio_rate * args.context) / 2
    n_chann_in = args.ambi_order**2
    audio_input = ambix_batch[:, :, :n_chann_in]
    audio_target = ambix_batch[:, ss:ss + t, n_chann_in:]

    print('\n' + '=' * 20 + ' MODEL ' + '=' * 20)
    sys.stdout.flush()
    with tf.device('/gpu:0'):
        # Model
        num_sep = args.num_sep_tracks if args.separation != NO_SEPARATION else 1
        params = SptAudioGenParams(sep_num_tracks=num_sep,
                                   ctx_feats_fc_units=args.context_units,
                                   loc_fc_units=args.loc_units,
                                   sep_freq_mask_fc_units=args.freq_mask_units,
                                   sep_fft_window=args.fft_window)
        model = SptAudioGen(ambi_order=args.ambi_order,
                            audio_rate=args.audio_rate,
                            video_rate=args.video_rate,
                            context=args.context,
                            sample_duration=args.sample_dur,
                            encoders=args.encoders,
                            separation=args.separation,
                            params=params)
        ambix_pred = model.inference_ops(audio=audio_input,
                                         video=video_batch,
                                         flow=flow_batch,
                                         is_training=True)

        # Losses and evaluation metrics
        print(audio_mask_batch)
        with tf.variable_scope('metrics'):
            metrics_t, _, _, _, _ = model.evaluation_ops(
                ambix_pred, audio_target, audio_input[:, ss:ss + t],
                mask_channels=audio_mask_batch[:, args.ambi_order**2:])

        step_t = tf.Variable(0, trainable=False, name='step')
        with tf.variable_scope('loss'):
            loss_t = model.loss_ops(metrics_t, step_t)
        losses_t = {l: loss_t[l] for l in loss_t}
        regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if regularizers and 'regularization' in losses_t:
            losses_t['regularization'] = tf.add_n(regularizers)
        losses_t['total_loss'] = tf.add_n(losses_t.values())

        # Optimizer
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.variable_scope('optimization'), tf.control_dependencies(update_ops):
            train_op, lr_t = myutils.optimize(losses_t['total_loss'], step_t, args)

        # Initialization
        rest_ops = model.init_ops
        init_op = [tf.global_variables_initializer(),
                   tf.local_variables_initializer()]
        saver = tf.train.Saver(max_to_keep=1)

    # Tensorboard
    metrics_t['training_loss'] = losses_t['total_loss']
    metrics_t['queue'] = feeder.queue_state
    metrics_t['lr'] = lr_t
    myutils.add_scalar_summaries(metrics_t.values(), metrics_t.keys())
    summary_ops = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))
    summary_writer = tf.summary.FileWriter(args.model_dir, flush_secs=30)
    # summary_writer.add_graph(tf.get_default_graph())

    print('\n' + '=' * 30 + ' VARIABLES ' + '=' * 30)
    model_vars = tf.global_variables()
    import numpy as np
    for v in model_vars:
        if 'Adam' in v.op.name.split('/')[-1]:
            continue
        print(' * {:50s} | {:20s} | {:7s} | {:10s}'.format(
            v.op.name, str(v.get_shape()), str(np.prod(v.get_shape())), str(v.dtype)))

    print('\n' + '=' * 30 + ' TRAINING ' + '=' * 30)
    sys.stdout.flush()
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        print('Initializing network...')
        sess.run(init_op)
        if rest_ops:
            sess.run(rest_ops)

        print('Initializing data feeders...')
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess, coord)
        feeder.start_threads(sess)
        tf.get_default_graph().finalize()

        # Restore model
        init_step = 0
        if args.resume:
            print('Restoring previously saved model...')
            ckpt = tf.train.latest_checkpoint(args.model_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                init_step = sess.run(step_t)

        try:
            print('Start training...')
            duration = deque(maxlen=20)
            for step in range(init_step, args.n_iters):
                start_time = time.time()
                if step % 20 != 0:
                    sess.run(train_op)
                else:
                    outs = sess.run([train_op, summary_ops, losses_t['total_loss']] +
                                    losses_t.values() + metrics_t.values())
                    if math.isnan(outs[2]):
                        raise ValueError('Training produced a NaN metric or loss.')
                duration.append(time.time() - start_time)

                if step % 20 == 0:
                    # Print progress to terminal and tensorboard
                    myutils.print_stats(outs[3:], losses_t.keys() + metrics_t.keys(),
                                        args.batch_size, duration, step, tag='TRAIN')
                    summary_writer.add_summary(outs[1], step)
                    sys.stdout.flush()

                if step % 5000 == 0 and step != 0:
                    # Save checkpoint
                    saver.save(sess, args.model_dir + '/model.ckpt', global_step=step_t)
                    print('=' * 60 + '\nCheckpoint saved\n' + '=' * 60)

        except Exception, e:
            print(str(e))
        finally:
def main():
    # Instantiate Configs
    config = hparams

    # Directory Setting
    input_path = os.path.join(config.wavenet_input, 'map.txt')
    post_input_path = os.path.join(config.post_train_input, 'map.txt')  # post  # yk: set the input path for the post-train case
    log_dir = config.log_dir
    save_dir = os.path.join(log_dir, 'wave_pretrained')
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    eval_dir = os.path.join(log_dir, 'eval-dir')
    eval_plot_dir = os.path.join(eval_dir, 'plots')
    eval_wav_dir = os.path.join(eval_dir, 'wavs')
    tensorboard_dir = os.path.join(log_dir, 'wavenet_events')
    meta_folder = os.path.join(log_dir, 'metas')
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(meta_folder, exist_ok=True)

    checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
    input_path = os.path.join(config.base_dir, input_path)
    post_input_path = os.path.join(config.base_dir, post_input_path)  # post  # yk: set the input path for the post-train case

    log.info('Checkpoint_path: {}'.format(checkpoint_path))
    if config.spk_train_mode == True:  # spk-train mode  #post
        log.info('Loading spk-training data from: {}'.format(input_path))  # post
    elif config.post_train_mode == True:  # post-train mode  #post
        log.info('Loading post-training data from: {}'.format(post_input_path))  # post
    else:  # train all vars  #post
        log.info('Loading training data from: {}'.format(input_path))  # post
    log.info('Using model: {}'.format('WaveNet'))
    log.info(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(config.wavenet_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        if config.post_train_mode == True:  # post-train mode  #post
            feeder = Feeder(coord, post_input_path, config.base_dir, config)  # post
        else:  # train all vars  #post  # yk: the feeder loads different data depending on whether this is a post-train run
            feeder = Feeder(coord, input_path, config.base_dir, config)  # post

    # Instantiate Model Class (Graphing)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(feeder, config, global_step)
    eval_model = model_test_mode(feeder, config, global_step)  ##EVAL

    # Speaker Embeddings metadata
    if config.speakers_path is not None:
        speaker_embedding_meta = config.speakers_path
    else:
        speaker_embedding_meta = os.path.join(meta_folder, 'SpeakerEmbeddings.tsv')
        if not os.path.isfile(speaker_embedding_meta):
            with open(speaker_embedding_meta, 'w', encoding='utf-8') as f:
                for speaker in config.speakers:
                    f.write('{}\n'.format(speaker))
        speaker_embedding_meta = speaker_embedding_meta.replace(log_dir, '..')

    # book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    sh_saver = create_shadow_saver(model, global_step)

    if config.post_train_mode == True:
        log.info('Wavenet post training set to a maximum of {} steps'.format(config.post_train_steps))
        train_steps = config.post_train_steps
    elif config.pre_train_mode == True:
        log.info('Wavenet pre training set to a maximum of {} steps'.format(config.pre_train_steps))
        train_steps = config.pre_train_steps
    elif config.spk_train_mode == True:
        log.info('Wavenet spk training set to a maximum of {} steps'.format(config.spk_train_steps))
        train_steps = config.spk_train_steps
    else:  # train all vars  #post  # yk: take the post-train max step separately >> unified into train_steps (what used to be config.wavenet_train_steps in the train loop below was changed to train_steps)
        log.info('Wavenet training set to a maximum of {} steps'.format(config.wavenet_train_steps))
        train_steps = config.wavenet_train_steps

    # Memory allocation on the GPU as needed
    conf = tf.ConfigProto()
    conf.gpu_options.allow_growth = True
    conf.allow_soft_placement = True

    run_init = False

    # Train
    with tf.Session(config=conf) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if config.restore == True:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                    if (checkpoint_state and checkpoint_state.model_checkpoint_path):
                        log.info('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
                        load_averaged_model(sess, sh_saver, checkpoint_state.model_checkpoint_path)
                    else:
                        log.info('No model to load at {}'.format(save_dir))
                        if config.wavenet_weight_normalization:
                            run_init = True
                except tf.errors.OutOfRangeError as e:
                    log.info('Cannot restore checkpoint: {}'.format(e))
            else:
                log.info('Starting new training!')
                if config.wavenet_weight_normalization:
                    run_init = True

            if run_init:
                log.info('\nApplying Weight normalization in fresh training. Applying data dependent initialization forward pass..')
                # Create init_model
                init_model, _ = model_train_mode(feeder, config, global_step, init=True)

            # initializing feeder
            feeder.start_threads(sess)

            if run_init:
                # Run one forward pass for model parameters initialization (make prediction on init_batch)
                _ = sess.run(init_model.tower_y_hat)
                log.info('Data dependent initialization done. Starting training!')

            # Training loop
            while not coord.should_stop() and step < train_steps:
                start_time = time.time()
                step, loss, vq_loss, vq_perplexity, reconst_loss, spk_loss, opt = sess.run([
                    global_step, model.loss, model.vq_loss, model.vq_perplexity,
                    model.reconst_loss, model.spk_loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}, vq_loss={:.5f}, vq_perplexity={:.5f}, reconst_loss={:.5f}, spk_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average, vq_loss,
                    vq_perplexity, reconst_loss, spk_loss)
                log.info(message)

                if np.isnan(loss) or loss > 10000:
                    log.info('Loss exploded to {:.5f} at step {}'.format(loss, step))
                    raise Exception('Loss exploded')

                if step % config.summary_interval == 0:
                    log.info('\nWriting summary at step {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % config.checkpoint_interval == 0 or step == train_steps:
                    save_log(sess, step, model, plot_dir, wav_dir, config=config)
                    save_checkpoint(sess, sh_saver, checkpoint_path, global_step)

                if step % config.eval_interval == 0:
                    log.info('\nEvaluating at step {}'.format(step))
                    eval_step(sess, step, eval_model, eval_plot_dir, eval_wav_dir,
                              summary_writer=summary_writer, config=model._config)  ##EVAL

                if config.gin_channels > 0 and (step % config.embedding_interval == 0 or step == train_steps):  # or step == 1):
                    # Get current checkpoint state
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                    print("checkpoint_state : {}".format(checkpoint_state))

                    # Update Projector
                    log.info('\nSaving Model Speaker Embeddings visualization..')
                    add_embedding_stats(summary_writer, [model.embedding_table.name],
                                        [speaker_embedding_meta],
                                        checkpoint_state.model_checkpoint_path)
                    log.info('WaveNet Speaker embeddings have been updated on tensorboard!')

            log.info('Wavenet training complete after {} global steps'.format(train_steps))
            return save_dir

        except Exception as e:
            log.info('Exiting due to exception: {}'.format(e))
            traceback.print_exc()
            coord.request_stop(e)
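

# Checkpointing in main() above goes through create_shadow_saver()/load_averaged_model(),
# i.e. the saver tracks the exponential-moving-average ("shadow") copies of the
# WaveNet weights rather than the raw training values. A minimal sketch of that
# idea, assuming the model exposes `ema` (a tf.train.ExponentialMovingAverage)
# and `variables` (the list of averaged weights); the names and max_to_keep are
# illustrative, not the repository's actual implementation.
def create_shadow_saver_sketch(model, global_step=None):
    # Map each variable's EMA (shadow) name to the live variable, so saving
    # writes the averaged weights and restoring loads them back in place.
    shadow_names = [model.ema.average_name(v) for v in model.variables]
    variables = list(model.variables)
    if global_step is not None:
        shadow_names.append('global_step')
        variables.append(global_step)
    return tf.train.Saver(dict(zip(shadow_names, variables)), max_to_keep=5)

def load_averaged_model_sketch(sess, sh_saver, checkpoint_path):
    # Because the saver maps shadow names to live variables, restore() drops
    # the averaged weights directly into the model, which is what evaluation
    # and resumed training from an averaged checkpoint expect.
    sh_saver.restore(sess, checkpoint_path)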