def character_based():
    """Tell whether decoding is character based (scorer in UTF-8 mode).

    Returns False when no ``--scorer_path`` is configured; otherwise loads the
    scorer and returns the result of its ``is_utf8_mode()``.
    """
    if not FLAGS.scorer_path:
        return False
    configured_scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
    return configured_scorer.is_utf8_mode()
def transcribe_file(audio_path, tlog_path):
    """Transcribe a single audio file and write its segment transcripts as JSON.

    The audio is split into segments by ``split_audio_file`` (driven by the
    VAD and outlier flags). Every batch is run through the model restored by
    ``load_graph_for_evaluation`` and decoded with CTC beam search. The
    segments are then sorted by start time and written to ``tlog_path``.

    Args:
        audio_path: Audio file to transcribe; opened through ``AudioFile``
            with ``as_path=True``.
        tlog_path: Destination JSON file. It receives a list of
            ``{'start': int, 'end': int, 'transcript': ...}`` objects
            (start/end presumably in milliseconds - TODO confirm against
            ``split_audio_file``).
    """
    from mozilla_voice_stt_training.train import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
    from mozilla_voice_stt_training.util.checkpoints import load_graph_for_evaluation
    initialize_globals()

    # Only load a language-model scorer when one is configured; without a
    # scorer path the decoder runs plain beam search (scorer=None).
    scorer = None
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)

    # Number of CPU cores for parallel batch decoding; fall back to one.
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with AudioFile(audio_path, as_path=True) as wav_path:
        data_set = split_audio_file(wav_path,
                                    batch_size=FLAGS.batch_size,
                                    aggressiveness=FLAGS.vad_aggressiveness,
                                    outlier_duration_ms=FLAGS.outlier_duration_ms,
                                    outlier_batch_size=FLAGS.outlier_batch_size)
        iterator = tf.data.Iterator.from_structure(data_set.output_types,
                                                   data_set.output_shapes,
                                                   output_classes=data_set.output_classes)
        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()

        # Dropout disabled: one ``None`` rate per layer.
        no_dropout = [None] * 6
        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
        # Softmax over the logits transposed with perm [1, 0, 2] for the decoder.
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        tf.train.get_or_create_global_step()

        with tf.Session(config=Config.session_config) as session:
            load_graph_for_evaluation(session)
            session.run(iterator.make_initializer(data_set))

            transcripts = []
            while True:
                try:
                    starts, ends, batch_logits, batch_lengths = session.run(
                        [batch_time_start, batch_time_end, transposed, batch_x_len])
                except tf.errors.OutOfRangeError:
                    break  # all segments consumed
                # Honour the same cutoff flags as the other decoding paths.
                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet,
                                                        FLAGS.beam_width,
                                                        num_processes=num_processes,
                                                        scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob,
                                                        cutoff_top_n=FLAGS.cutoff_top_n)
                # Keep the text of the first (best) result for each segment.
                transcripts.extend(zip(starts, ends, [d[0][1] for d in decoded]))

            transcripts.sort(key=lambda t: t[0])
            transcripts = [{'start': int(start),
                            'end': int(end),
                            'transcript': transcript} for start, end, transcript in transcripts]
            with open(tlog_path, 'w') as tlog_file:
                json.dump(transcripts, tlog_file, default=float)
def early_training_checks():
    """Run sanity checks on the configuration before training starts.

    * If ``--scorer_path`` is set, the scorer is constructed once and
      discarded immediately, so a broken scorer fails early.
    * If both train and test files are given and the load and save checkpoint
      directories differ, a warning is logged: the test step would not
      evaluate the freshly trained checkpoint.
    """
    if FLAGS.scorer_path:
        # Constructed only for its validation side effect; dropped right away.
        Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)

    trains_and_tests = FLAGS.train_files and FLAGS.test_files
    if trains_and_tests and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir:
        log_warn('WARNING: You specified different values for --load_checkpoint_dir '
                 'and --save_checkpoint_dir, but you are running training and testing '
                 'in a single invocation. The testing step will respect --load_checkpoint_dir, '
                 'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. '
                 'Train and test in two separate invocations, specifying the correct '
                 '--load_checkpoint_dir in both cases, or use the same location '
                 'for loading and saving.')
def do_single_file_inference(input_file_path):
    """Run the model on one audio file and print its best transcription.

    Builds an inference graph with ``batch_size=1`` and restores the trained
    weights with ``load_graph_for_evaluation``. It then feeds the file's
    windowed features with all-zero ``previous_state_c`` / ``previous_state_h``
    inputs. The model output is decoded with CTC beam search, using a scorer
    only when ``--scorer_path`` is set.

    Args:
        input_file_path: Audio file passed to ``audiofile_to_features``.
    """
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from the training checkpoint.
        load_graph_for_evaluation(session)

        raw_features, raw_length = audiofile_to_features(input_file_path)

        # All-zero initial recurrent state for a single-item batch.
        zero_state_c = np.zeros([1, Config.n_cell_dim])
        zero_state_h = np.zeros([1, Config.n_cell_dim])

        # Prepend a batch dimension, window the features, then evaluate both.
        batched_features = tf.expand_dims(raw_features, 0)
        batched_length = tf.expand_dims(raw_length, 0)
        windowed_features = create_overlapping_windows(batched_features).eval(session=session)
        length_value = batched_length.eval(session=session)

        feed = {
            inputs['input']: windowed_features,
            inputs['input_lengths']: length_value,
            inputs['previous_state_c']: zero_state_c,
            inputs['previous_state_h']: zero_state_h,
        }
        model_output = np.squeeze(outputs['outputs'].eval(feed_dict=feed, session=session))

        scorer = (Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
                  if FLAGS.scorer_path else None)
        decoded = ctc_beam_search_decoder(model_output, Config.alphabet, FLAGS.beam_width,
                                          scorer=scorer,
                                          cutoff_prob=FLAGS.cutoff_prob,
                                          cutoff_top_n=FLAGS.cutoff_top_n)

        # The first decoder entry is the highest-probability result.
        print(decoded[0][1])
def evaluate(test_csvs, create_model):
    """Evaluate a model on one or more test CSV files.

    Every test set shares a single iterator whose structure comes from the
    first set. The model graph is built with dropout disabled and weights are
    restored via ``load_graph_for_evaluation``. Then, for each CSV in order,
    every batch is decoded with CTC beam search and the results are reported
    through ``calculate_and_print_report``.

    Args:
        test_csvs: Sequence of test CSV paths; each becomes its own dataset.
        create_model: Callable that builds the model graph. It is called as
            ``create_model(batch_x=..., seq_length=..., dropout=...)`` and
            must return a pair whose first item is the logits tensor.

    Returns:
        A list concatenating, in CSV order, the samples returned by
        ``calculate_and_print_report`` for each test set.
    """
    scorer = (Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
              if FLAGS.scorer_path else None)

    test_sets = [create_dataset([csv_path],
                                batch_size=FLAGS.test_batch_size,
                                train_phase=False,
                                reverse=FLAGS.reverse_test,
                                limit=FLAGS.limit_test)
                 for csv_path in test_csvs]

    # One reinitializable iterator serves every test set.
    template = test_sets[0]
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(template),
                                                 tfv1.data.get_output_shapes(template),
                                                 output_classes=tfv1.data.get_output_classes(template))
    init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # Dropout disabled: one ``None`` rate per layer.
    logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=[None] * 6)

    # Softmax over logits transposed with perm [1, 0, 2] (batch major) for the decoder.
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
    loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # CPU cores available for parallel decoding; fall back to a single process.
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        load_graph_for_evaluation(session)

        def run_test(init_op, dataset):
            """Decode every batch of one test set and return its report samples."""
            wav_filenames, losses, predictions, ground_truths = [], [], [], []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            # Point the shared iterator at this dataset.
            session.run(init_op)

            steps_done = 0
            while True:
                try:
                    fetched = session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break
                names, probs, batch_loss, lengths, references = fetched

                decoded = ctc_beam_search_decoder_batch(probs, lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob,
                                                        cutoff_top_n=FLAGS.cutoff_top_n)
                predictions.extend(best[0][1] for best in decoded)
                ground_truths.extend(sparse_tensor_value_to_texts(references, Config.alphabet))
                wav_filenames.extend(name.decode('UTF-8') for name in names)
                losses.extend(batch_loss)

                steps_done += 1
                bar.update(steps_done)

            bar.finish()

            # Print the test summary and hand back its samples.
            return calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)

        samples = []
        for csv_path, init_op in zip(test_csvs, init_ops):
            print('Testing model on {}'.format(csv_path))
            samples.extend(run_test(init_op, dataset=csv_path))
        return samples