def train():
    """Prepare the input pipelines, dropout placeholders and optimizer.

    Builds the training dataset from ``FLAGS.train_files``, a reinitializable
    iterator shaped after it, one initializer op per dataset (training plus
    one per CSV in ``FLAGS.dev_files``, when given), the six dropout-rate
    placeholders with their feed dictionaries, and finally the optimizer.
    """
    # Training data and an iterator that can be re-pointed at any dataset
    # sharing the training set's structure
    train_csvs = FLAGS.train_files.split(',')
    train_set = create_dataset(train_csvs,
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True)

    output_types = tfv1.data.get_output_types(train_set)
    output_shapes = tfv1.data.get_output_shapes(train_set)
    output_classes = tfv1.data.get_output_classes(train_set)
    iterator = tfv1.data.Iterator.from_structure(output_types,
                                                 output_shapes,
                                                 output_classes=output_classes)

    # Initializer ops used to switch the iterator between the sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [
            create_dataset([dev_csv],
                           batch_size=FLAGS.dev_batch_size,
                           train_phase=False)
            for dev_csv in dev_csvs
        ]
        dev_init_ops = [iterator.make_initializer(one_set) for one_set in dev_sets]

    # One dropout-rate placeholder per layer
    dropout_rates = []
    for layer_index in range(6):
        placeholder = tfv1.placeholder(tf.float32,
                                       name='dropout_{}'.format(layer_index))
        dropout_rates.append(placeholder)

    configured_rates = (
        FLAGS.dropout_rate,
        FLAGS.dropout_rate2,
        FLAGS.dropout_rate3,
        FLAGS.dropout_rate4,
        FLAGS.dropout_rate5,
        FLAGS.dropout_rate6,
    )
    # Training feeds the configured rates; validation feeds zeros
    dropout_feed_dict = dict(zip(dropout_rates, configured_rates))
    no_dropout_feed_dict = dict.fromkeys(dropout_rates, 0.)

    # Building the graph
    optimizer = create_optimizer()
def evaluate(test_csvs, create_model, try_loading):
    """Evaluate a trained checkpoint on each test CSV and report WER/CER/loss.

    The evaluation runs in two passes per test set: first the acoustic model
    is run over the whole set (collecting softmaxed logits, losses, sequence
    lengths and ground-truth transcripts), then the collected logits are
    decoded with the CTC beam search decoder and scored.

    Args:
        test_csvs: test CSV paths, either a list or a comma-separated string.
            When empty or None, ``FLAGS.test_files`` is used instead.
        create_model: callable building the acoustic model graph. It is called
            with ``batch_x``, ``seq_length`` and ``dropout`` keyword arguments
            and must return a pair whose first element is the logits tensor.
        try_loading: callable ``(session, saver, checkpoint_filename, caption)``
            returning a truthy value when a checkpoint was restored.

    Returns:
        A list with the samples returned by ``calculate_report`` for every
        test set, concatenated in the order of ``test_csvs``.

    Raises:
        SystemExit: with status 1 when neither the best-validation nor the
            most recent checkpoint can be loaded.
    """
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                    FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                    Config.alphabet)

    # Honor the caller's list of CSVs; only fall back to the command line
    # flag when nothing was passed in.
    if not test_csvs:
        test_csvs = FLAGS.test_files.split(',')
    elif isinstance(test_csvs, str):
        test_csvs = test_csvs.split(',')

    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]

    # A single iterator is shared by all test sets; each set gets its own
    # initializer op to point the iterator at it.
    iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
                                               test_sets[0].output_shapes,
                                               output_classes=test_sets[0].output_classes)
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

    loss = tf.nn.ctc_loss(labels=batch_y,
                          inputs=logits,
                          sequence_length=batch_x_len)

    tf.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tf.train.Saver()

    with tf.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint: prefer the best
        # validation checkpoint, then the most recent one
        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
            # The bare `exit` builtin is injected by the `site` module and is
            # not guaranteed to exist; raise SystemExit directly instead.
            raise SystemExit(1)

        def run_test(init_op, dataset):
            """Run both evaluation passes on one test set and print its report."""
            logitses = []
            losses = []
            seq_lengths = []
            ground_truths = []

            bar = create_progressbar(prefix='Computing acoustic model predictions | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Computing acoustic model predictions...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_logits, batch_loss, lengths, transcripts = session.run(
                        [transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted
                    break

                step_count += 1
                bar.update(step_count)

                logitses.append(batch_logits)
                losses.extend(batch_loss)
                seq_lengths.append(lengths)
                ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))

            bar.finish()

            predictions = []

            bar = create_progressbar(max_value=step_count,
                                     prefix='Decoding predictions | ').start()
            log_progress('Decoding predictions...')

            # Second pass, decode logits and compute WER and edit distance metrics
            for batch_logits, seq_length in bar(zip(logitses, seq_lengths)):
                decoded = ctc_beam_search_decoder_batch(batch_logits, seq_length, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer)
                # Keep the second field of the first candidate of each item
                predictions.extend(d[0][1] for d in decoded)

            distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

            wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.distance, sample.loss))
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
def evaluate(test_csvs, create_model, try_loading):
    """Evaluate a trained checkpoint on each test CSV and report WER/CER/loss.

    For every test set the acoustic model is run batch by batch; each batch
    is decoded immediately with the CTC beam search decoder, and the
    predictions, ground truths, file names and losses are accumulated for
    the final report.

    Args:
        test_csvs: test CSV paths, either a list or a comma-separated string.
            When empty or None, ``FLAGS.test_files`` is used instead.
        create_model: callable building the acoustic model graph. It is called
            with ``batch_x``, ``batch_size``, ``seq_length`` and ``dropout``
            keyword arguments and must return a pair whose first element is
            the logits tensor.
        try_loading: callable ``(session, saver, checkpoint_filename, caption)``
            returning a truthy value when a checkpoint was restored.

    Returns:
        A list with the samples returned by ``calculate_report`` for every
        test set, concatenated in the order of ``test_csvs``.

    Raises:
        SystemExit: with status 1 when neither the best-validation nor the
            most recent checkpoint can be loaded.
    """
    # The language model scorer is optional
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                        Config.alphabet)
    else:
        scorer = None

    # Honor the caller's list of CSVs; only fall back to the command line
    # flag when nothing was passed in.
    if not test_csvs:
        test_csvs = FLAGS.test_files.split(',')
    elif isinstance(test_csvs, str):
        test_csvs = test_csvs.split(',')

    test_sets = [
        create_dataset([csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False) for csv in test_csvs
    ]

    # A single iterator is shared by all test sets; each set gets its own
    # initializer op to point the iterator at it.
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint: prefer the best
        # validation checkpoint, then the most recent one
        loaded = try_loading(session, saver, 'best_dev_checkpoint',
                             'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error(
                'Checkpoint directory ({}) does not contain a valid checkpoint state.'
                .format(FLAGS.checkpoint_dir))
            # The bare `exit` builtin is injected by the `site` module and is
            # not guaranteed to exist; raise SystemExit directly instead.
            raise SystemExit(1)

        def run_test(init_op, dataset):
            """Evaluate one test set, print its report and return its samples."""
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=[
                                         'Steps: ',
                                         progressbar.Counter(), ' | ',
                                         progressbar.Timer()
                                     ]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Compute losses and softmaxed logits batch by batch, decoding
            # each batch as soon as it is available
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob,
                    cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep the second field of the first candidate of each item
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts,
                                                 Config.alphabet))
                # File names come back from the session as bytes
                wav_filenames.extend(
                    wav_filename.decode('UTF-8')
                    for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            wer, cer, samples = calculate_report(wav_filenames, ground_truths,
                                                 predictions, losses)
            mean_loss = np.mean(losses)

            # Take only the first report_count items
            report_samples = itertools.islice(samples, FLAGS.report_count)

            print('Test on %s - WER: %f, CER: %f, loss: %f' %
                  (dataset, wer, cer, mean_loss))
            print('-' * 80)
            for sample in report_samples:
                print('WER: %f, CER: %f, loss: %f' %
                      (sample.wer, sample.cer, sample.loss))
                print(' - wav: file://%s' % sample.wav_filename)
                print(' - src: "%s"' % sample.src)
                print(' - res: "%s"' % sample.res)
                print('-' * 80)

            return samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples
def train():
    """Build the training graph and run the training/validation loop.

    The function:

    * creates the training dataset (and one dataset per CSV in
      ``FLAGS.dev_files``) behind a single reinitializable iterator;
    * builds the tower losses/gradients, the averaged gradient update, the
      step summaries and the two checkpoint savers (rolling and best-dev);
    * writes the current flags to ``flags.txt`` in the checkpoint directory;
    * initializes the variables from, in order of precedence, a CuDNN RNN
      checkpoint, the most recent checkpoint, the best validation
      checkpoint, a source model (``--load transfer``) or from scratch;
    * trains for ``FLAGS.epochs`` epochs, validating after every epoch when
      dev files are given, keeping the best validation checkpoint and
      optionally stopping early.

    A ``KeyboardInterrupt`` during the epoch loop ends training cleanly.

    Raises:
        SystemExit: with status 1 when the requested way of loading or
            initializing the model cannot be satisfied.
    """
    # Feature caching is switched off as soon as any augmentation is active
    # (presumably so that cached features do not freeze one augmented
    # version of the data).
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # The transfer learning approach here need us to supply the layers which we
    # want to exclude from the source model.
    # Say we want to exclude all layers except for the first one, we can use this:
    #
    #    drop_source_layers=['2', '3', 'lstm', '5', '6']
    #
    # If we want to use all layers from the source model except the last one, we use this:
    #
    #    drop_source_layers=['6']
    #
    if FLAGS.load == "transfer":
        source_layers = ['2', '3', 'lstm', '5', '6']
        n_dropped = int(FLAGS.drop_source_layers)
        # A plain [-n:] slice returns the WHOLE list for n == 0, which would
        # drop every layer instead of none.
        drop_source_layers = source_layers[-n_dropped:] if n_dropped else []
    else:
        drop_source_layers = None

    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates, drop_source_layers)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                          'was specified. The --cudnn_checkpoint flag is only '
                          'needed when converting a CuDNN RNN checkpoint to '
                          'a CPU-capable graph. If your system is capable of '
                          'using CuDNN RNN, you can just specify the CuDNN RNN '
                          'checkpoint normally with --checkpoint_dir.')
                sys.exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)

            # Check that the only missing variables are the Adam moment tensors
            if any('Adam' not in v.op.name for v in missing_variables):
                log_error('Tried to load a CuDNN RNN checkpoint but there were '
                          'more missing variables than just the Adam moment '
                          'tensors.')
                sys.exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
            log_info('Initializing missing Adam moment tensors.')
            init_op = tfv1.variables_initializer(missing_variables)
            session.run(init_op)
            loaded = True

        if not loaded and FLAGS.load in ['auto', 'last']:
            tfv1.get_default_graph().finalize()
            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            tfv1.get_default_graph().finalize()
            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            if FLAGS.load == "transfer":
                if not FLAGS.source_model_checkpoint_dir:
                    # Without a source model nothing would be loaded or
                    # initialized and training would fail later on
                    # uninitialized variables; fail early and explicitly.
                    log_error('Unable to load transfer model: '
                              '--source_model_checkpoint_dir is not set.')
                    sys.exit(1)

                print('Initializing model from', FLAGS.source_model_checkpoint_dir)
                ckpt = tfv1.train.load_checkpoint(FLAGS.source_model_checkpoint_dir)
                variables = list(ckpt.get_variable_to_shape_map().keys())
                print('variable', variables)

                graph_variables = tfv1.global_variables()
                print('global', graph_variables)

                # Variables belonging to a dropped layer are never read from
                # the source model; they are initialized from scratch below.
                dropped_variables = [
                    v for v in graph_variables
                    if any(layer in v.op.name for layer in drop_source_layers)
                ]

                # Load desired source variables
                missing_variables2 = []
                for v in graph_variables:
                    if any(layer in v.op.name for layer in drop_source_layers):
                        continue
                    print('Loading', v.op.name)
                    try:
                        v.load(ckpt.get_tensor(v.op.name), session=session)
                        print('OK')
                    except tf.errors.NotFoundError:
                        missing_variables2.append(v)
                        print('KO')
                    except ValueError:
                        # The checkpoint value could not be fed into the
                        # variable; initialize it from scratch rather than
                        # leaving it uninitialized.
                        missing_variables2.append(v)
                        print('KO for valueError')
                print('missing_variables =', missing_variables2)

                # Initialize all variables needed for DS, but not loaded from ckpt
                init_op = tfv1.variables_initializer(dropped_variables + missing_variables2)
                tfv1.get_default_graph().finalize()
                session.run(init_op)
            elif FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                tfv1.get_default_graph().finalize()
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                          ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)

        def run_set(set_name, epoch, init_op, dataset=None):
            """Run one pass over a dataset and return ``(mean_loss, step_count)``.

            ``set_name`` selects training (``'train'``: gradients are applied
            and dropout is fed) or evaluation-only behavior, and picks the
            summary writer. ``init_op`` points the shared iterator at the
            dataset; ``dataset`` is only used for the progress bar suffix.
            """
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                """Progress bar widget showing the running mean loss."""

                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.InvalidArgumentError as err:
                    # Failures are tolerated only when sparse warp
                    # augmentation is enabled; the batch is then skipped
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoints, training only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation: step-weighted mean loss over all dev sets
                    dev_loss = 0.0
                    total_steps = 0
                    for csv, init_op in zip(dev_csvs, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, csv))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
                    dev_loss = dev_loss / total_steps

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
                                                        latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        # Statistics exclude the newest loss, which is the
                        # value being compared against them
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break
        except KeyboardInterrupt:
            # Allow the user to stop training without a traceback
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
def train():
    """Build the training graph and run the training/validation loop.

    The function:

    * creates the training dataset (and one dataset per CSV in
      ``FLAGS.dev_files``) behind a single reinitializable iterator;
    * builds the tower losses/gradients, the averaged gradient update, the
      step summaries and the two checkpoint savers (rolling and best-dev);
    * initializes the variables from, in order of precedence, a CuDNN RNN
      checkpoint, the most recent checkpoint, the best validation
      checkpoint or from scratch, depending on ``FLAGS.load``;
    * trains for ``FLAGS.epochs`` epochs, validating after every epoch when
      dev files are given, keeping the best validation checkpoint and
      optionally stopping early.

    A ``KeyboardInterrupt`` during the epoch loop ends training cleanly.

    Raises:
        SystemExit: with status 1 when the requested way of loading or
            initializing the model cannot be satisfied.
    """
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.feature_cache)

    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(train_set),
        tfv1.data.get_output_shapes(train_set),
        output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_csvs = FLAGS.dev_files.split(',')
        dev_sets = [
            create_dataset([csv], batch_size=FLAGS.dev_batch_size)
            for csv in dev_csvs
        ]
        dev_init_ops = [
            iterator.make_initializer(dev_set) for dev_set in dev_sets
        ]

    # Dropout
    dropout_rates = [
        tfv1.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train':
        tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                                max_queue=120),
        'dev':
        tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                                max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    initializer = tfv1.global_variables_initializer()

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Loading or initializing
        loaded = False

        # Initialize training from a CuDNN RNN checkpoint
        if FLAGS.cudnn_checkpoint:
            if FLAGS.use_cudnn_rnn:
                log_error(
                    'Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
                    'was specified. The --cudnn_checkpoint flag is only '
                    'needed when converting a CuDNN RNN checkpoint to '
                    'a CPU-capable graph. If your system is capable of '
                    'using CuDNN RNN, you can just specify the CuDNN RNN '
                    'checkpoint normally with --checkpoint_dir.')
                # sys.exit rather than the site-injected `exit` builtin,
                # which is not guaranteed to be available
                sys.exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(
                FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
            missing_variables = []

            # Load compatible variables from checkpoint
            for v in tfv1.global_variables():
                try:
                    v.load(ckpt.get_tensor(v.op.name), session=session)
                except tf.errors.NotFoundError:
                    missing_variables.append(v)

            # Check that the only missing variables are the Adam moment tensors
            if any('Adam' not in v.op.name for v in missing_variables):
                log_error(
                    'Tried to load a CuDNN RNN checkpoint but there were '
                    'more missing variables than just the Adam moment '
                    'tensors.')
                sys.exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
            log_info('Initializing missing Adam moment tensors.')
            init_op = tfv1.variables_initializer(missing_variables)
            session.run(init_op)
            loaded = True

        # No ops are added past this point
        tfv1.get_default_graph().finalize()

        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        def run_set(set_name, epoch, init_op, dataset=None):
            """Run one pass over a dataset and return ``(mean_loss, step_count)``.

            ``set_name`` selects training (``'train'``: gradients are applied
            and dropout is fed) or evaluation-only behavior, and picks the
            summary writer. ``init_op`` points the shared iterator at the
            dataset; ``dataset`` is only used for the progress bar suffix.
            """
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                """Progress bar widget showing the running mean loss."""

                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(
                        self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(
                        self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(
                epoch, 'Training' if is_train else 'Validation')
            widgets = [
                ' | ',
                progressbar.widgets.Timer(), ' | Steps: ',
                progressbar.widgets.Counter(), ' | ',
                LossWidget()
            ]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix,
                                      widgets=widgets,
                                      suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, step_summary = \
                        session.run([train_op, global_step, loss, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted
                    break

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoints, training only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' %
                             (epoch, train_loss))
                checkpoint_saver.save(session,
                                      checkpoint_path,
                                      global_step=global_step)

                if FLAGS.dev_files:
                    # Validation: step-weighted mean loss over all dev sets
                    dev_loss = 0.0
                    total_steps = 0
                    for csv, init_op in zip(dev_csvs, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' %
                                     (epoch, csv))
                        set_loss, steps = run_set('dev',
                                                  epoch,
                                                  init_op,
                                                  dataset=csv)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress(
                            'Finished validating epoch %d on %s - loss: %f' %
                            (epoch, csv, set_loss))
                    dev_loss = dev_loss / total_steps

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(
                            session,
                            best_dev_path,
                            global_step=global_step,
                            latest_filename=best_dev_filename)
                        log_info(
                            "Saved new best validating model with loss %f to: %s"
                            % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        # Statistics exclude the newest loss, which is the
                        # value being compared against them
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug(
                            'Checking for early stopping (last %d steps) validation loss: '
                            '%f, with standard deviation: %f and mean: %f' %
                            (FLAGS.es_steps, dev_losses[-1], std_loss,
                             mean_loss))
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info(
                                'Early stop triggered as (for last %d steps) validation loss:'
                                ' %f with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))
                            break
        except KeyboardInterrupt:
            # Allow the user to stop training without a traceback
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() -
                                                      train_start_time))
    log_debug('Session closed.')
def evaluate_with_pruning(test_csvs,
                          prune_percentage,
                          random,
                          scores_file,
                          result_file,
                          verbose=True,
                          skip_lstm=False):
    '''Evaluate a (optionally pruned) DeepSpeech model on the given test CSVs.

    Code originally comes from the DeepSpeech repository
    (./DeepSpeech/evaluate.py). The code is adapted for evaluation on pruned
    versions of the DeepSpeech model: after the checkpoint is restored, the
    weight variables are multiplied in place by masks computed with
    ``prune_matrices`` before the test sets are evaluated.

    Args:
        test_csvs: comma-separated string of test CSV paths.
        prune_percentage: amount of pruning, forwarded to ``prune_matrices``.
            A falsy value (0, None) disables pruning altogether.
        random: forwarded to ``prune_matrices``; a truthy value is reported
            as 'random' pruning, a falsy one as 'score-based'.
        scores_file: path loaded with ``np.load`` to obtain the per-layer
            scores given to ``prune_matrices``.
        result_file: when truthy, a summary of each test set is appended to
            this file.
        verbose: when truthy, progress and results are printed.
        skip_lstm: when truthy, the LSTM kernel is left untouched.

    Returns:
        A flat list holding ``wer, cer, mean_loss`` for each test CSV, in
        the order the CSVs were given.

    Raises:
        SystemExit: with status 1 when no checkpoint can be loaded.
    '''
    tfv1.reset_default_graph()

    # The language model scorer is optional
    if FLAGS.lm_binary_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path,
                        FLAGS.lm_trie_path, Config.alphabet)
    else:
        scorer = None

    test_csvs = test_csvs.split(',')
    test_sets = [
        create_dataset([csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False) for csv in test_csvs
    ]

    # A single iterator is shared by all test sets; each set gets its own
    # initializer op to point the iterator at it.
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # Create a saver using variables from the above newly created graph
    saver = tfv1.train.Saver()

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
        loaded = False
        if not loaded and FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session,
                                 saver,
                                 'checkpoint',
                                 'most recent',
                                 load_step=False)
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session,
                                 saver,
                                 'best_dev_checkpoint',
                                 'best validation',
                                 load_step=False)
        if not loaded:
            print('Could not load checkpoint from {}'.format(
                FLAGS.checkpoint_dir))
            sys.exit(1)

        ###### PRUNING PART ######

        # Whether pruning happens depends on prune_percentage only; the
        # verbose flag merely controls the messages.
        if not prune_percentage:
            if verbose:
                print('No pruning done.')
        else:
            if verbose:
                print('-' * 80)
                print('pruning with {}%...'.format(prune_percentage))

            scores_per_layer = np.load(scores_file)
            layer_masks = prune_matrices(scores_per_layer,
                                         prune_percentage=prune_percentage,
                                         random=random,
                                         verbose=verbose,
                                         skip_lstm=skip_lstm)

            n_layers_to_prune = len(layer_masks)
            lstm_layer_name = 'cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/kernel:0'

            # Index of the next mask to apply; it only advances when a
            # variable was actually pruned.
            i = 0
            for v in tfv1.trainable_variables():
                # Only dense weights and the LSTM kernel are pruned
                if 'weights' not in v.name and v.name != lstm_layer_name:
                    continue

                # Stop once every mask has been applied
                if i >= n_layers_to_prune:
                    break

                # Make mask into the shape of the weights: every mask entry
                # is spread over a whole column of the weight matrix
                # (assumes layer_masks[i] is 1-D - TODO confirm)
                if v.name == lstm_layer_name:
                    if skip_lstm:
                        continue
                    # Shape of LSTM weights: [(2*neurons), (4*neurons)]
                    cell_template = np.ones((2, 4))
                    mask = np.repeat(layer_masks[i], v.shape[0] // 2, axis=0)
                    mask = mask.reshape(
                        [layer_masks[i].shape[0], v.shape[0] // 2])
                    mask = np.swapaxes(mask, 0, 1)
                    # Expand every mask entry into a 2x4 block to reach the
                    # kernel's shape
                    mask = np.kron(mask, cell_template)
                else:
                    mask = np.repeat(layer_masks[i], v.shape[0], axis=0)
                    mask = mask.reshape([layer_masks[i].shape[0], v.shape[0]])
                    mask = np.swapaxes(mask, 0, 1)

                # Apply mask to weights
                session.run(v.assign(tf.multiply(v, mask)))
                i += 1

        ###### END PRUNING PART ######

        def run_test(init_op, dataset):
            """Evaluate one test set and return ``(wer, cer, mean_loss)``."""
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=[
                                         'Steps: ',
                                         progressbar.Counter(), ' | ',
                                         progressbar.Timer()
                                     ]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Compute losses and softmaxed logits batch by batch, decoding
            # each batch as soon as it is available
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    # The dataset is exhausted
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits,
                    batch_lengths,
                    Config.alphabet,
                    FLAGS.beam_width,
                    num_processes=num_processes,
                    scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob,
                    cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep the second field of the first candidate of each item
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts,
                                                 Config.alphabet))
                # File names come back from the session as bytes
                wav_filenames.extend(
                    wav_filename.decode('UTF-8')
                    for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # The per-sample report is not used here, only the aggregates
            wer, cer, _ = calculate_report(wav_filenames, ground_truths,
                                           predictions, losses)
            mean_loss = np.mean(losses)

            if verbose:
                print('Test on %s - WER: %f, CER: %f, loss: %f' %
                      (dataset, wer, cer, mean_loss))
                print('-' * 80)

            if result_file:
                pruning_type = 'score-based' if not random else 'random'
                result_string = (
                    'Results for evaluating model with pruning percentage of {}% and {} pruning:\n'
                    'Test on {} - WER: {}, CER: {}, loss: {}\n'
                ).format(prune_percentage * 100, pruning_type, dataset, wer,
                         cer, mean_loss)
                write_to_file(result_file, result_string, 'a+')

            return wer, cer, mean_loss

        # extend() flattens every (wer, cer, mean_loss) tuple into the list
        results = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            if verbose:
                print('Testing model on {}'.format(csv))
            results.extend(run_test(init_op, dataset=csv))
        return results
def train():
    """Train the model for ``FLAGS.epochs`` epochs, validating after each one.

    Everything is configured through ``FLAGS``; the function takes no
    arguments and returns nothing. In order, it:

    * builds the training dataset (and a validation dataset when
      ``FLAGS.dev_files`` is set) behind one re-initializable iterator,
    * creates six dropout placeholders, the optimizer, the averaged tower
      gradients and the op that applies them,
    * restores variables from the most recent or the best-validation
      checkpoint, or initializes them, as selected by ``FLAGS.load``,
    * runs the epoch loop, saving a checkpoint after every training epoch
      and a separate "best dev" checkpoint whenever validation loss improves,
    * optionally stops early based on the recent validation-loss history.

    A ``KeyboardInterrupt`` raised inside the epoch loop ends training
    quietly. ``sys.exit(1)`` is called when no checkpoint could be loaded
    and ``FLAGS.load`` does not permit initializing from scratch.
    """
    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               cache_path=FLAGS.train_cached_features_path)

    iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                               train_set.output_shapes,
                                               output_classes=train_set.output_classes)

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_set = create_dataset(FLAGS.dev_files.split(','),
                                 batch_size=FLAGS.dev_batch_size,
                                 cache_path=FLAGS.dev_cached_features_path)
        dev_init_op = iterator.make_initializer(dev_set)

    # Dropout: one placeholder per rate so that training and validation can
    # feed different values into the same graph
    dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tf.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    # Must be created before the graph is finalized below
    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
                session.run(initializer)
            else:
                log_error('Unable to load %s model from specified checkpoint dir'
                          ' - consider using load option "auto" or "init".' % FLAGS.load)
                sys.exit(1)

        def run_set(set_name, init_op):
            """Run one pass over the dataset selected by ``init_op``.

            ``set_name`` is ``'train'`` or ``'dev'``; gradients are applied
            and dropout is fed only for ``'train'``. Returns the mean batch
            loss, or ``0.0`` when the dataset produced no batches.
            """
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            class LossWidget(progressbar.widgets.FormatLabel):
                """Progress bar label showing the running mean loss."""

                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            if FLAGS.show_progressbar:
                # `epoch` is read from the enclosing epoch loop at call time
                pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
                                                        ' | ', progressbar.widgets.Timer(),
                                                        ' | Steps: ', progressbar.widgets.Counter(),
                                                        ' | ', LossWidget()])
                pbar.start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop: runs until the iterator signals exhaustion
            while True:
                try:
                    _, current_step, batch_loss, step_summary = \
                        session.run([train_op, global_step, loss, step_summaries_op],
                                    feed_dict=feed_dict)
                except tf.errors.OutOfRangeError:
                    break

                total_loss += batch_loss
                step_count += 1

                if FLAGS.show_progressbar:
                    pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoints, training only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar.finish()

            # An empty dataset yields zero steps; avoid a ZeroDivisionError and
            # report 0.0, matching what LossWidget displays in that case
            return total_loss / step_count if step_count > 0 else 0.0

        log_info('STARTING Optimization')
        best_dev_loss = float('inf')
        dev_losses = []
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                if not FLAGS.show_progressbar:
                    log_info('Training epoch %d...' % epoch)
                train_loss = run_set('train', train_init_op)
                if not FLAGS.show_progressbar:
                    log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    if not FLAGS.show_progressbar:
                        log_info('Validating epoch %d...' % epoch)
                    dev_loss = run_set('dev', dev_init_op)
                    if not FLAGS.show_progressbar:
                        log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
                                                        latest_filename=best_dev_filename)
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                        # Statistics cover the window excluding the newest loss
                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                        dev_losses = dev_losses[-FLAGS.es_steps:]
                        log_debug('Checking for early stopping (last %d steps) validation loss: '
                                  '%f, with standard deviation: %f and mean: %f' %
                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))

                        # Stop when the newest loss is the worst in the window, or
                        # when the window has flattened out around its mean
                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                            log_info('Early stop triggered as (for last %d steps) validation loss:'
                                     ' %f with standard deviation: %f and mean: %f' %
                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                            break
        except KeyboardInterrupt:
            # Ctrl-C ends training without a traceback
            pass
    log_debug('Session closed.')
def train():
    """Train the model for ``FLAGS.epochs`` epochs with per-source validation.

    Everything is configured through ``FLAGS``; the function takes no
    arguments and returns nothing. In order, it:

    * disables feature caching when any of the augmentation flags is active,
    * builds the training dataset and one validation dataset per entry of
      ``FLAGS.dev_files`` behind a single re-initializable iterator,
    * creates a non-trainable learning-rate variable (so it can be reduced
      on a plateau), the optimizer, the averaged tower gradients and the op
      that applies them,
    * writes the current flags to ``flags.txt`` in
      ``FLAGS.save_checkpoint_dir``,
    * loads or initializes the graph via ``load_or_init_graph``,
    * runs the epoch loop, saving a checkpoint after every training epoch
      and a "best dev" checkpoint whenever validation loss improves, with
      optional early stopping and learning-rate reduction on plateau.

    A ``KeyboardInterrupt`` raised inside the epoch loop ends training
    quietly.
    """
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout: one placeholder per rate so that training and validation can
    # feed different values into the same graph
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_set(set_name, epoch, init_op, dataset=None):
            """Run one pass over the dataset selected by ``init_op``.

            ``set_name`` is ``'train'`` or ``'dev'``; gradients are applied
            and dropout is fed only for ``'train'``. ``dataset`` is an
            optional label shown in the progress bar suffix. Returns
            ``(mean_loss, step_count)``; the mean is ``0.0`` when no batch
            was processed.
            """
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                """Progress bar label showing the running mean loss."""

                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop: runs until the iterator signals exhaustion
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        # Skip the offending batch and keep going
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoints, training only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation: step-weighted mean over all validation sources
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    # All validation sets may be empty; avoid a ZeroDivisionError
                    # and use 0.0, the same convention run_set applies per set
                    dev_loss = dev_loss / total_steps if total_steps > 0 else 0.0
                    dev_losses.append(dev_loss)

                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                    # the improvement has to be greater than FLAGS.es_min_delta
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
                                                        latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and
                            epochs_without_improvement > 0):
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval()
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))
        except KeyboardInterrupt:
            # Ctrl-C ends training without a traceback
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
def evaluate(test_csvs, create_model, try_loading):
    """Evaluate a trained model on ``test_csvs`` and print a WER/CER report.

    Works in two passes: the first runs the acoustic model over the whole
    test set and collects the softmaxed logits, losses, sequence lengths and
    ground-truth transcripts; the second decodes the collected logits with
    the beam search decoder and computes the report.

    Args:
        test_csvs: passed as-is to ``create_dataset`` as the test sources.
        create_model: callable invoked as
            ``create_model(batch_x=..., seq_length=..., dropout=...)`` whose
            first return value is the logits tensor.
        try_loading: callable invoked as
            ``try_loading(session, saver, checkpoint_filename, caption)``;
            a truthy return value means the checkpoint was restored.

    Returns:
        The ``samples`` value returned by ``calculate_report``.

    Exits the process with status 1 when neither the best-validation nor the
    most recent checkpoint could be loaded.
    """
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                    FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                    Config.alphabet)

    test_set = create_dataset(test_csvs,
                              batch_size=FLAGS.test_batch_size,
                              cache_path=FLAGS.test_cached_features_path)
    it = test_set.make_one_shot_iterator()

    (batch_x, batch_x_len), batch_y = it.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

    loss = tf.nn.ctc_loss(labels=batch_y,
                          inputs=logits,
                          sequence_length=batch_x_len)

    # Called for its side effect only: the global step must exist in the graph
    # before the Saver below is created
    tf.train.get_or_create_global_step()

    with tf.Session(config=Config.session_config) as session:
        # Create a saver using variables from the above newly created graph
        saver = tf.train.Saver()

        # Restore variables from training checkpoint, preferring the best one
        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
        if not loaded:
            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'
                      .format(FLAGS.checkpoint_dir))
            exit(1)

        logitses = []
        losses = []
        seq_lengths = []
        ground_truths = []

        print('Computing acoustic model predictions...')
        bar = progressbar.ProgressBar(widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()])

        step_count = 0

        # First pass, compute losses and transposed logits for decoding
        while True:
            try:
                # Distinct names so the graph tensors `logits`/`loss` are not
                # shadowed by the numpy results of each batch
                batch_logits, batch_loss, batch_lengths, batch_transcripts = session.run(
                    [transposed, loss, batch_x_len, batch_y])
            except tf.errors.OutOfRangeError:
                break

            step_count += 1
            bar.update(step_count)

            logitses.append(batch_logits)
            losses.extend(batch_loss)
            seq_lengths.append(batch_lengths)
            ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))

        bar.finish()

    predictions = []

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        # Only the "cannot determine core count" case falls back to one process;
        # anything else (e.g. KeyboardInterrupt) must propagate
        num_processes = 1

    print('Decoding predictions...')
    # NOTE(review): `widget=` (singular) is kept exactly as in the original call;
    # confirm ProgressBar actually honours it - the keyword used elsewhere is `widgets`.
    bar = progressbar.ProgressBar(max_value=step_count,
                                  widget=progressbar.AdaptiveETA)

    # Second pass, decode logits and compute WER and edit distance metrics
    for batch_logits, seq_length in bar(zip(logitses, seq_lengths)):
        decoded = ctc_beam_search_decoder_batch(batch_logits, seq_length, Config.alphabet, FLAGS.beam_width,
                                                num_processes=num_processes, scorer=scorer)
        # Keep only the transcript of the first candidate for each utterance
        predictions.extend(d[0][1] for d in decoded)

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
    mean_loss = np.mean(losses)

    # Take only the first report_count items
    report_samples = itertools.islice(samples, FLAGS.report_count)

    print('Test - WER: %f, CER: %f, loss: %f' %
          (wer, cer, mean_loss))
    print('-' * 80)
    for sample in report_samples:
        print('WER: %f, CER: %f, loss: %f' %
              (sample.wer, sample.distance, sample.loss))
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)

    return samples
def evaluate(test_csvs, create_model):
    """Evaluate a trained model on each test CSV and return the collected samples.

    One dataset is created per CSV; all of them share a single
    re-initializable iterator, so the graph is built once and each CSV is
    evaluated by re-initializing the iterator.

    Args:
        test_csvs: list of test CSV paths, or a comma-separated string of
            paths. When empty or ``None``, ``FLAGS.test_files`` is used.
        create_model: callable invoked as ``create_model(batch_x=...,
            batch_size=..., seq_length=..., dropout=...)`` whose first return
            value is the logits tensor.

    Returns:
        A list with the samples returned by ``calculate_and_print_report``
        for every CSV, concatenated in CSV order.
    """
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.scorer_path, Config.alphabet)
    else:
        scorer = None

    # Honour the caller's argument instead of unconditionally overwriting it;
    # the flag remains the fallback for callers that pass nothing useful
    if not test_csvs:
        test_csvs = FLAGS.test_files
    if isinstance(test_csvs, str):
        test_csvs = test_csvs.split(',')

    test_sets = [
        create_dataset([test_csv],
                       batch_size=FLAGS.test_batch_size,
                       train_phase=False)
        for test_csv in test_csvs
    ]
    iterator = tfv1.data.Iterator.from_structure(
        tfv1.data.get_output_types(test_sets[0]),
        tfv1.data.get_output_shapes(test_sets[0]),
        output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [
        iterator.make_initializer(test_set) for test_set in test_sets
    ]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_test(init_op, dataset):
            """Evaluate one dataset; ``dataset`` is the label used in the report."""
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(
                    batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                    num_processes=num_processes, scorer=scorer,
                    cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep only the transcript of the first candidate for each utterance
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(
                    sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
                wav_filenames.extend(
                    wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # Print test summary
            test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
            return test_samples

        samples = []
        # `test_csv` rather than `csv`, so the stdlib module name is not shadowed
        for test_csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(test_csv))
            samples.extend(run_test(init_op, dataset=test_csv))
        return samples
def train():
    """Train the model up to a target epoch derived from ``FLAGS.epoch``.

    Everything is configured through ``FLAGS``; the function takes no
    arguments and returns nothing. The current epoch is derived from the
    ``global_step`` value read after loading/initializing, divided by the
    number of steps per epoch. A negative ``FLAGS.epoch`` is treated as
    "that many additional epochs"; a non-negative value is used as the
    absolute target epoch. When the target has already been reached,
    training is skipped.

    ``sys.exit(1)`` is called when no checkpoint could be loaded and
    ``FLAGS.load`` does not permit initializing from scratch.
    """
    # Create training and validation datasets
    # create_dataset returns a pair here; the second element is logged below
    # as the number of batches in the set
    train_set, train_batches = create_dataset(
        FLAGS.train_files.split(','),
        batch_size=FLAGS.train_batch_size,
        cache_path=FLAGS.train_cached_features_path)

    iterator = tf.data.Iterator.from_structure(
        train_set.output_types,
        train_set.output_shapes,
        output_classes=train_set.output_classes)

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_set, dev_batches = create_dataset(
            FLAGS.dev_files.split(','),
            batch_size=FLAGS.dev_batch_size,
            cache_path=FLAGS.dev_cached_features_path)
        dev_init_op = iterator.make_initializer(dev_set)

    # Dropout: one placeholder per rate so that training and validation can
    # feed different values into the same graph
    dropout_rates = [
        tf.placeholder(tf.float32, name='dropout_{}'.format(i))
        for i in range(6)
    ]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {rate: 0. for rate in dropout_rates}

    # Building the graph
    optimizer = create_optimizer()
    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tf.Variable(0, trainable=False, name='global_step')
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients,
                                                  global_step=global_step)

    # Summaries
    step_summaries_op = tf.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'),
                              max_queue=120),
        'dev':
        tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'),
                              max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
    checkpoint_filename = 'checkpoint'

    best_dev_saver = tf.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
    best_dev_filename = 'best_dev_checkpoint'

    # Must be created before the graph is finalized below
    initializer = tf.global_variables_initializer()

    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        tf.get_default_graph().finalize()

        # Loading or initializing
        loaded = False
        if FLAGS.load in ['auto', 'last']:
            loaded = try_loading(session, checkpoint_saver,
                                 checkpoint_filename, 'most recent epoch')
        if not loaded and FLAGS.load in ['auto', 'best']:
            loaded = try_loading(session, best_dev_saver, best_dev_filename,
                                 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing...')
                session.run(initializer)
            else:
                log_error(
                    'Unable to load %s model from specified checkpoint dir'
                    ' - consider using load option "auto" or "init".' %
                    FLAGS.load)
                sys.exit(1)

        # Retrieving global_step from restored model and setting training parameters accordingly
        step = session.run(global_step)
        num_gpus = len(Config.available_devices)
        steps_per_epoch = max(1, train_batches // num_gpus)
        current_epoch = step // steps_per_epoch
        # Negative FLAGS.epoch: run that many more epochs; otherwise it is the
        # absolute target (the conditional applies to the whole right-hand side)
        target_epoch = current_epoch + abs(
            FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch

        log_debug('step: %d' % step)
        log_debug('epoch: %d' % current_epoch)
        log_debug('target epoch: %d' % target_epoch)
        log_debug('steps per epoch: %d' % steps_per_epoch)
        log_debug('batches per step (GPUs): %d' % num_gpus)
        log_debug('number of batches in train set: %d' % train_batches)

        def run_set(set_name, init_op, num_batches):
            """Run a fixed number of steps over the dataset selected by ``init_op``.

            ``set_name`` is ``'train'`` or ``'dev'``; gradients are applied
            and dropout is fed only for ``'train'``. The step count is
            ``max(1, num_batches // num_gpus)``. Returns the accumulated loss
            divided by that planned step count (also when the loop ends early
            because the coordinator requested a stop).
            """
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_summary_writer = step_summary_writers.get(set_name)
            num_steps = max(1, num_batches // num_gpus)
            checkpoint_time = time.time()

            if FLAGS.show_progressbar:
                pbar = progressbar.ProgressBar(max_value=num_steps,
                                               redirect_stdout=True).start()
            else:
                # Identity stand-in so the loop below can always call pbar(...)
                pbar = lambda i: i

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            for step_index in pbar(range(num_steps)):
                # `coord` is bound in the enclosing scope before run_set is called
                if coord.should_stop():
                    break

                _, current_step, batch_loss, step_summary = \
                    session.run([train_op, global_step, loss, step_summaries_op],
                                feed_dict=feed_dict)
                total_loss += batch_loss
                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoints, training only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time(
                ) - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=current_step)
                    checkpoint_time = time.time()

            return total_loss / num_steps

        if target_epoch > current_epoch:
            log_info('STARTING Optimization')
            best_dev_loss = float('inf')
            dev_losses = []
            coord = tf.train.Coordinator()
            # NOTE(review): nothing in this function calls coord.join() or
            # otherwise re-raises after this block - confirm whether exceptions
            # raised inside the loop are meant to surface to the caller.
            with coord.stop_on_exception():
                for current_epoch in range(current_epoch, target_epoch):
                    if coord.should_stop():
                        break

                    # Training
                    log_info('Training epoch %d ...' % current_epoch)
                    train_loss = run_set('train', train_init_op, train_batches)
                    log_info('Finished training epoch %d - loss: %f' %
                             (current_epoch, train_loss))
                    checkpoint_saver.save(session,
                                          checkpoint_path,
                                          global_step=global_step)

                    if FLAGS.dev_files:
                        # Validation
                        log_info('Validating epoch %d ...' % current_epoch)
                        dev_loss = run_set('dev', dev_init_op, dev_batches)
                        dev_losses.append(dev_loss)
                        log_info('Finished validating epoch %d - loss: %f' %
                                 (current_epoch, dev_loss))

                        if dev_loss < best_dev_loss:
                            best_dev_loss = dev_loss
                            # Saved without a global_step argument, unlike the
                            # per-epoch checkpoint above
                            save_path = best_dev_saver.save(
                                session,
                                best_dev_path,
                                latest_filename=best_dev_filename)
                            log_info(
                                "Saved new best validating model with loss %f to: %s"
                                % (best_dev_loss, save_path))

                        # Early stopping
                        if FLAGS.early_stop and len(
                                dev_losses) >= FLAGS.es_steps:
                            # Statistics cover the window excluding the newest loss
                            mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
                            std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
                            dev_losses = dev_losses[-FLAGS.es_steps:]
                            log_debug(
                                'Checking for early stopping (last %d steps) validation loss: '
                                '%f, with standard deviation: %f and mean: %f'
                                % (FLAGS.es_steps, dev_losses[-1], std_loss,
                                   mean_loss))

                            # Stop when the newest loss is the worst in the window,
                            # or when the window has flattened out around its mean
                            if dev_losses[-1] > np.max(dev_losses[:-1]) or \
                               (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
                                log_info(
                                    'Early stop triggered as (for last %d steps) validation loss:'
                                    ' %f with standard deviation: %f and mean: %f'
                                    % (FLAGS.es_steps, dev_losses[-1],
                                       std_loss, mean_loss))
                                break
                coord.request_stop()
        else:
            log_info('Target epoch already reached - skipped training.')
    log_debug('Session closed.')