def transcribe_file(audio_path, tlog_path):
    '''
    Transcribes one audio file and writes the result as a JSON transcription log.

    The audio is cut into segments by ``split_audio_file``; every batch of
    segments is run through the acoustic model and decoded with the batched CTC
    beam search decoder. The best transcript of each segment is stored together
    with the segment's start and end values, the entries are sorted by start,
    and the list is dumped to ``tlog_path``.

    Args:
        audio_path: Path of the audio file to transcribe.
        tlog_path: Path of the JSON file that receives the list of
            ``{'start', 'end', 'transcript'}`` entries.
    '''
    from DeepSpeech import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
    from util.checkpoints import load_or_init_graph
    initialize_globals()

    # The external scorer is optional: only build it when a scorer path was
    # given, as evaluate() and do_single_file_inference() already do.
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
    else:
        scorer = None

    # Number of worker processes for the batched decoder
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    # The session has to run while the AudioFile context is still open, because
    # the data set reads from wav_path only when batches are requested.
    with AudioFile(audio_path, as_path=True) as wav_path:
        data_set = split_audio_file(wav_path,
                                    batch_size=FLAGS.batch_size,
                                    aggressiveness=FLAGS.vad_aggressiveness,
                                    outlier_duration_ms=FLAGS.outlier_duration_ms,
                                    outlier_batch_size=FLAGS.outlier_batch_size)
        iterator = tf.data.Iterator.from_structure(data_set.output_types,
                                                   data_set.output_shapes,
                                                   output_classes=data_set.output_classes)
        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()

        # One (disabled) dropout rate per layer
        no_dropout = [None] * 6
        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)

        # Swap the first two axes of the logits and apply softmax for the decoder
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        tf.train.get_or_create_global_step()

        with tf.Session(config=Config.session_config) as session:
            if FLAGS.load == 'auto':
                method_order = ['best', 'last']
            else:
                method_order = [FLAGS.load]
            load_or_init_graph(session, method_order)
            session.run(iterator.make_initializer(data_set))

            transcripts = []
            while True:
                try:
                    starts, ends, batch_logits, batch_lengths = \
                        session.run([batch_time_start, batch_time_end, transposed, batch_x_len])
                except tf.errors.OutOfRangeError:
                    # Data set exhausted
                    break
                decoded = ctc_beam_search_decoder_batch(batch_logits,
                                                        batch_lengths,
                                                        Config.alphabet,
                                                        FLAGS.beam_width,
                                                        num_processes=num_processes,
                                                        scorer=scorer)
                # Keep only the transcript of the first candidate of every segment
                decoded = [d[0][1] for d in decoded]
                transcripts.extend(zip(starts, ends, decoded))

            transcripts.sort(key=lambda t: t[0])
            transcripts = [{
                'start': int(start),
                'end': int(end),
                'transcript': transcript
            } for start, end, transcript in transcripts]

            with open(tlog_path, 'w') as tlog_file:
                json.dump(transcripts, tlog_file, default=float)
def do_single_file_inference(input_file_path):
    '''
    Runs the inference graph on a single audio file and prints the transcript
    of the first decoder candidate.

    Args:
        input_file_path: Path of the audio file handed to
            ``audiofile_to_features``.
    '''
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        load_methods = ['best', 'last'] if FLAGS.load == 'auto' else [FLAGS.load]
        load_or_init_graph(session, load_methods)

        raw_features, raw_features_len = audiofile_to_features(input_file_path)

        # Zero-filled values fed for the previous_state_c / previous_state_h inputs
        state_shape = [1, Config.n_cell_dim]
        previous_state_c = np.zeros(state_shape)
        previous_state_h = np.zeros(state_shape)

        # Add batch dimension
        batched_features = tf.expand_dims(raw_features, 0)
        batched_features_len = tf.expand_dims(raw_features_len, 0)

        # Evaluate the input tensors to concrete values
        features = create_overlapping_windows(batched_features).eval(session=session)
        features_len = batched_features_len.eval(session=session)

        feed = {
            inputs['input']: features,
            inputs['input_lengths']: features_len,
            inputs['previous_state_c']: previous_state_c,
            inputs['previous_state_h']: previous_state_h,
        }
        logits = np.squeeze(outputs['outputs'].eval(feed_dict=feed, session=session))

        # The external scorer is optional
        scorer = (Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
                  if FLAGS.scorer_path else None)

        decoded = ctc_beam_search_decoder(logits,
                                          Config.alphabet,
                                          FLAGS.beam_width,
                                          scorer=scorer,
                                          cutoff_prob=FLAGS.cutoff_prob,
                                          cutoff_top_n=FLAGS.cutoff_top_n)

        # Print highest probability result
        best_candidate = decoded[0]
        print(best_candidate[1])
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.

    The inference graph is extended with ``metadata_*`` constants, frozen with
    the variables restored from a checkpoint and written to
    ``FLAGS.export_dir`` either as a ``.pb`` graph or, when
    ``FLAGS.export_tflite`` is set, as a ``.tflite`` model. A metadata file
    describing the exported model is written next to it.
    '''
    log_info('Exporting the model...')
    from tensorflow.python.framework.ops import Tensor, Operation

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size,
                                                n_steps=FLAGS.n_steps,
                                                tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    # Metadata constants that are embedded into the exported graph
    outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
    outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')],
                                                   name='metadata_language')

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    # Node names to keep: outputs can hold both tensors and plain operations
    output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
    output_names = output_names_tensors + output_names_ops

    with tf.Session() as session:
        # Restore variables from checkpoint
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export and os.path.isdir(FLAGS.export_dir):
            log_info('Removing old export')
            shutil.rmtree(FLAGS.export_dir)

        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        # exist_ok avoids the check-then-create race of isdir() + makedirs()
        os.makedirs(FLAGS.export_dir, exist_ok=True)

        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names)

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph,
            dest_nodes=output_names)

        if not FLAGS.export_tflite:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            output_tflite_path = os.path.join(FLAGS.export_dir,
                                              output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph,
                                                input_tensors=inputs.values(),
                                                output_tensors=outputs.values())
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)

        log_info('Models exported at %s' % (FLAGS.export_dir))

    metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
        FLAGS.export_author_id,
        FLAGS.export_model_name,
        FLAGS.export_model_version))

    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
    metadata_lines = [
        '---\n',
        'author: {}\n'.format(FLAGS.export_author_id),
        'model_name: {}\n'.format(FLAGS.export_model_name),
        'model_version: {}\n'.format(FLAGS.export_model_version),
        'contact_info: {}\n'.format(FLAGS.export_contact_info),
        'license: {}\n'.format(FLAGS.export_license),
        'language: {}\n'.format(FLAGS.export_language),
        'runtime: {}\n'.format(model_runtime),
        'min_ds_version: {}\n'.format(FLAGS.export_min_ds_version),
        'max_ds_version: {}\n'.format(FLAGS.export_max_ds_version),
        'acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n',
        'scorer_url: <replace this with a publicly available URL of the scorer, if present>\n',
        '---\n',
        '{}\n'.format(FLAGS.export_description),
    ]
    # The fields are free text taken from flags, so write them as UTF-8
    # instead of relying on the platform's default encoding.
    with open(metadata_fname, 'w', encoding='utf-8') as f:
        f.writelines(metadata_lines)

    log_info(('Model metadata file saved to {}. Before submitting the exported model for publishing '
              'make sure all information in the metadata file is correct, and complete the URL '
              'fields.').format(metadata_fname))
def train():
    '''
    Builds the training graph and runs the optimization loop.

    Every epoch processes the training set once and saves a checkpoint. When
    ``FLAGS.dev_files`` is set, every dev source is evaluated afterwards and
    the step-weighted mean dev loss drives best-model saving, early stopping
    and learning rate reduction on a plateau. A ``KeyboardInterrupt`` ends the
    epoch loop without an error. The summary writers are closed when the epoch
    loop is left, whatever the reason.
    '''
    do_cache_dataset = True

    # Feature caching is switched off as soon as one of these augmentation flags is active
    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_set(set_name, epoch, init_op, dataset=None):
            '''Runs one pass over a data set; returns (mean loss, number of steps).'''
            is_train = set_name == 'train'
            # Validation runs the same fetches, just without the gradient update
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    raise
                except tf.errors.OutOfRangeError:
                    # Data set exhausted
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                # Time-based intermediate checkpoints, training only
                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    # NOTE(review): raises ZeroDivisionError when no dev set yields a single step
                    dev_loss = dev_loss / total_steps

                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                    # the improvement has to be greater than FLAGS.es_min_delta
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
                                                        latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and
                            epochs_without_improvement > 0):
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval(session=session)
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))
        except KeyboardInterrupt:
            pass
        finally:
            # Flush queued summary events and release the event files
            for writer in step_summary_writers.values():
                writer.close()
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
def export():
    r'''
    Freezes the inference graph with the restored checkpoint variables and
    writes it to ``FLAGS.export_dir`` for serving, either as a ``.pb`` graph
    or, when ``FLAGS.export_tflite`` is set, as a ``.tflite`` model.
    '''
    log_info('Exporting the model...')
    from tensorflow.python.framework.ops import Tensor, Operation

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size,
                                                n_steps=FLAGS.n_steps,
                                                tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    # Metadata constants embedded into the exported graph
    outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
    outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        encoded_language = FLAGS.export_language.encode('utf-8')
        outputs['metadata_language'] = tf.constant([encoded_language], name='metadata_language')

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    # Names of the nodes to keep: tensors first, then plain operations
    tensor_node_names = [out.op.name for out in outputs.values() if isinstance(out, Tensor)]
    op_node_names = [out.name for out in outputs.values() if isinstance(out, Operation)]
    output_names = tensor_node_names + op_node_names

    with tf.Session() as session:
        # Restore variables from checkpoint
        load_methods = ['best', 'last'] if FLAGS.load == 'auto' else [FLAGS.load]
        load_or_init_graph(session, load_methods)

        output_filename = FLAGS.export_name + '.pb'

        if FLAGS.remove_export and os.path.isdir(FLAGS.export_dir):
            log_info('Removing old export')
            shutil.rmtree(FLAGS.export_dir)

        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        graph_def = tfv1.get_default_graph().as_graph_def()
        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=graph_def,
            output_node_names=output_names)

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph,
            dest_nodes=output_names)

        if FLAGS.export_tflite:
            output_tflite_path = os.path.join(FLAGS.export_dir,
                                              output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph,
                                                input_tensors=inputs.values(),
                                                output_tensors=outputs.values())
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)
        else:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())

        log_info('Models exported at %s' % (FLAGS.export_dir))
def evaluate(test_csvs, create_model):
    '''
    Evaluates the model on a list of test sets and returns the collected samples.

    Args:
        test_csvs: Test set files to evaluate, one data set per entry. A
            comma-separated string is split into entries; when empty or
            ``None``, ``FLAGS.test_files`` is used instead.
        create_model: Callable that builds the acoustic model graph; called with
            ``batch_x``, ``batch_size``, ``seq_length`` and ``dropout``.

    Returns:
        The concatenation of what ``calculate_and_print_report`` returns for
        every test set.
    '''
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.scorer_path, Config.alphabet)
    else:
        scorer = None

    # Honour the caller's selection. Previously the argument was unconditionally
    # overwritten with FLAGS.test_files, so it is kept as the fallback for
    # callers that passed nothing meaningful.
    if not test_csvs:
        test_csvs = FLAGS.test_files
    if isinstance(test_csvs, str):
        test_csvs = test_csvs.split(',')

    test_sets = [create_dataset([test_csv], batch_size=FLAGS.test_batch_size, train_phase=False)
                 for test_csv in test_csvs]
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                 tfv1.data.get_output_shapes(test_sets[0]),
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_test(init_op, dataset):
            '''Runs one test set through the model and returns its report samples.'''
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    # Data set exhausted
                    break

                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes,
                                                        scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob,
                                                        cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep only the transcript of the first candidate of every sample
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # Print test summary
            test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
            return test_samples

        samples = []
        for test_csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(test_csv))
            samples.extend(run_test(init_op, dataset=test_csv))
        return samples