def main(args):
    """Transcribe one audio file with a trained transducer checkpoint.

    Loads hparams and the subword encoder from the checkpoint's directory,
    rebuilds the Keras model, restores the weights, then greedy-decodes the
    log-mel spectrogram of ``args.input`` and prints the transcription.

    Args:
        args: Parsed CLI namespace with ``checkpoint`` (path to a weights
            file) and ``input`` (path to an audio file).
    """
    model_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(model_dir)

    # Only the id->text decoder is needed for inference; the text->id
    # encode_fn returned first was unused, so discard it explicitly.
    _, tok_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=model_dir,
        hparams=hparams)
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    model = build_keras_model(hparams)
    model.load_weights(args.checkpoint)

    audio, sr = preprocessing.tf_load_audio(args.input)

    log_melspec = preprocessing.preprocess_audio(
        audio=audio, sample_rate=sr, hparams=hparams)
    # Add a leading batch dimension of 1 for the model.
    log_melspec = tf.expand_dims(log_melspec, axis=0)

    decoder_fn = decoding.greedy_decode_fn(model, hparams)
    decoded = decoder_fn(log_melspec)[0]
    transcription = tok_to_text(decoded)

    print('Transcription:', transcription.numpy().decode('utf8'))
def main(args):
    """Run live microphone transcription with a stateful streaming model.

    Opens a PyAudio input stream; each audio chunk is preprocessed to a
    log-mel spectrogram and greedy-decoded, and any newly extended
    transcription is printed. The model is built ``stateful=True`` so RNN
    state carries across chunks between callbacks.

    Args:
        args: Parsed CLI namespace with ``checkpoint`` (path to weights).
    """
    model_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(model_dir)

    _, tok_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=model_dir,
        hparams=hparams)
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    model = build_keras_model(hparams, stateful=True)
    model.load_weights(args.checkpoint)

    decoder_fn = decoding.greedy_decode_fn(model, hparams)

    p = pyaudio.PyAudio()

    def listen_callback(in_data, frame_count, time_info, status):
        """PyAudio stream callback: decode one raw float32 audio chunk."""
        global LAST_OUTPUT

        audio = tf.io.decode_raw(in_data, out_type=tf.float32)

        log_melspec = preprocessing.preprocess_audio(
            audio=audio, sample_rate=SAMPLE_RATE, hparams=hparams)
        log_melspec = tf.expand_dims(log_melspec, axis=0)

        decoded = decoder_fn(log_melspec)[0]
        # Accumulate onto the previous output; only print when it grew.
        transcription = LAST_OUTPUT + tok_to_text(decoded)\
            .numpy().decode('utf8')

        if transcription != LAST_OUTPUT:
            LAST_OUTPUT = transcription
            print(transcription)

        return in_data, pyaudio.paContinue

    stream = p.open(
        format=pyaudio.paFloat32,
        channels=NUM_CHANNELS,
        rate=SAMPLE_RATE,
        input=True,
        frames_per_buffer=CHUNK_SIZE,
        stream_callback=listen_callback)

    print('Listening...')

    stream.start_stream()

    try:
        while stream.is_active():
            time.sleep(0.1)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop listening; fall through to cleanup.
        pass
    finally:
        # BUG FIX: cleanup was previously unreachable on KeyboardInterrupt,
        # leaking the PortAudio stream. Always close and terminate.
        stream.stop_stream()
        stream.close()
        p.terminate()
def main(_):
    """Train or evaluate the transducer model according to ``FLAGS.mode``.

    Configures the distribution strategy, restores/copies the subword
    encoder into the output directory, builds the model under the strategy
    scope, then either runs training (``mode == 'train'``) or a one-shot
    evaluation (``mode in ('eval', 'test')``).
    """
    strategy, dtype = configure_environment(
        gpu_names=FLAGS.gpus,
        fp16_run=FLAGS.fp16_run)

    hparams, tb_hparams = setup_hparams(
        log_dir=FLAGS.tb_log_dir,
        checkpoint=FLAGS.checkpoint)

    os.makedirs(FLAGS.output_dir, exist_ok=True)

    # The subword encoder lives next to the checkpoint when resuming,
    # otherwise alongside the raw data; copy it so the run is self-contained.
    if FLAGS.checkpoint is None:
        encoder_dir = FLAGS.data_dir
    else:
        encoder_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))
    shutil.copy(
        os.path.join(encoder_dir, 'encoder.subwords'),
        os.path.join(FLAGS.output_dir, 'encoder.subwords'))

    # The text->id encode function is not needed here; only id->text is used
    # (for WER). The previously-bound encoder_fn was unused.
    _, idx_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=FLAGS.output_dir,
        hparams=hparams)
    if HP_VOCAB_SIZE.name not in hparams:
        hparams[HP_VOCAB_SIZE.name] = vocab_size

    with strategy.scope():
        model = build_keras_model(hparams, dtype=dtype)

        if FLAGS.checkpoint is not None:
            model.load_weights(FLAGS.checkpoint)
            logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))

        model_utils.save_hparams(hparams, FLAGS.output_dir)

        optimizer = tf.keras.optimizers.SGD(hparams[HP_LEARNING_RATE.name],
                                            momentum=0.9)
        if FLAGS.fp16_run:
            # Dynamic loss scaling prevents fp16 gradient underflow.
            optimizer = mixed_precision.LossScaleOptimizer(
                optimizer, loss_scale='dynamic')

        logging.info('Using {} encoder with vocab size: {}'.format(
            hparams[HP_TOKEN_TYPE.name], vocab_size))

        loss_fn = get_loss_fn(
            reduction_factor=hparams[HP_TIME_REDUCT_FACTOR.name])

        decode_fn = decoding.greedy_decode_fn(model, hparams)
        accuracy_fn = metrics.build_accuracy_fn(decode_fn)
        wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

        # NOTE(review): relies on fixed layer ordering inside
        # build_keras_model — confirm indices 2/3 are the audio encoder and
        # the prediction network.
        encoder = model.layers[2]
        prediction_network = model.layers[3]

        encoder.summary()
        prediction_network.summary()
        model.summary()

        dev_dataset = None
        if FLAGS.eval_size != 0:
            dev_dataset = get_dataset(FLAGS.data_dir, 'dev',
                                      batch_size=FLAGS.batch_size,
                                      n_epochs=FLAGS.n_epochs,
                                      strategy=strategy,
                                      max_size=FLAGS.eval_size)

        log_dir = os.path.join(FLAGS.tb_log_dir,
                               datetime.now().strftime('%Y%m%d-%H%M%S'))

        with tf.summary.create_file_writer(log_dir).as_default():
            hp.hparams(tb_hparams)

            if FLAGS.mode == 'train':
                train_dataset = get_dataset(FLAGS.data_dir, 'train',
                                            batch_size=FLAGS.batch_size,
                                            n_epochs=FLAGS.n_epochs,
                                            strategy=strategy)

                # (Redundant second makedirs removed; output_dir was already
                # created above with exist_ok=True.)
                checkpoint_template = os.path.join(
                    FLAGS.output_dir, 'checkpoint_{step}_{val_loss:.4f}.hdf5')

                run_training(
                    model=model,
                    optimizer=optimizer,
                    loss_fn=loss_fn,
                    train_dataset=train_dataset,
                    batch_size=FLAGS.batch_size,
                    n_epochs=FLAGS.n_epochs,
                    checkpoint_template=checkpoint_template,
                    hparams=hparams,
                    strategy=strategy,
                    steps_per_log=FLAGS.steps_per_log,
                    steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                    eval_dataset=dev_dataset,
                    train_metrics=[],
                    eval_metrics=[accuracy_fn, wer_fn])
            elif FLAGS.mode == 'eval' or FLAGS.mode == 'test':
                if FLAGS.checkpoint is None:
                    raise Exception(
                        'You must provide a checkpoint to perform eval.')

                if FLAGS.mode == 'test':
                    dataset = get_dataset(FLAGS.data_dir, 'test',
                                          batch_size=FLAGS.batch_size,
                                          n_epochs=FLAGS.n_epochs)
                else:
                    dataset = dev_dataset

                eval_start_time = time.time()
                # BUG FIX: the original passed gpus=gpus, but no local name
                # `gpus` exists in this scope (NameError at runtime on the
                # eval/test path). FLAGS.gpus is the value used everywhere
                # else for GPU selection.
                eval_loss, eval_metrics_results = run_evaluate(
                    model=model,
                    optimizer=optimizer,
                    loss_fn=loss_fn,
                    eval_dataset=dataset,
                    batch_size=FLAGS.batch_size,
                    hparams=hparams,
                    strategy=strategy,
                    metrics=[accuracy_fn, wer_fn],
                    gpus=FLAGS.gpus)

                validation_log_str = (
                    'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format(
                        time.time() - eval_start_time, eval_loss))
                for metric_name, metric_result in eval_metrics_results.items():
                    validation_log_str += ', {}: {:.4f}'.format(
                        metric_name, metric_result)

                print(validation_log_str)
def main(_):
    """Preprocess LibriSpeech splits into serialized train/dev/test datasets.

    Builds (or loads) the subword encoder from the training texts, then
    loads each split, preprocesses it to encoded features, and writes the
    results via ``write_dataset``.
    """
    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[1],
        HP_VOCAB_SIZE: HP_VOCAB_SIZE.domain.values[0],

        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],
        HP_DOWNSAMPLE_FACTOR: HP_DOWNSAMPLE_FACTOR.domain.values[0]
    }

    # Small splits for quick smoke runs; swap in the full LibriSpeech
    # splits below for a real preprocessing pass.
    train_splits = ['dev-clean']
    dev_splits = ['dev-clean']
    test_splits = ['dev-clean']

    # train_splits = [
    #     'train-clean-100',
    #     'train-clean-360',
    #     'train-other-500'
    # ]

    # dev_splits = [
    #     'dev-clean',
    #     'dev-other'
    # ]

    # test_splits = [
    #     'test-clean',
    #     'test-other'
    # ]

    # Re-key hparams by string name (HParam objects -> plain dict keys).
    _hparams = {k.name: v for k, v in hparams.items()}

    texts_gen = librispeech.texts_generator(FLAGS.data_dir,
                                            split_names=train_splits)

    # Only the text->id encode function is needed during preprocessing; the
    # id->text decoder_fn previously bound here was unused.
    encoder_fn, _, vocab_size = encoding.get_encoder(
        output_dir=FLAGS.output_dir,
        hparams=_hparams,
        texts_generator=texts_gen)
    _hparams[HP_VOCAB_SIZE.name] = vocab_size

    train_dataset = librispeech.load_dataset(FLAGS.data_dir, train_splits)
    dev_dataset = librispeech.load_dataset(FLAGS.data_dir, dev_splits)
    test_dataset = librispeech.load_dataset(FLAGS.data_dir, test_splits)

    train_dataset = preprocessing.preprocess_dataset(
        train_dataset,
        encoder_fn=encoder_fn,
        hparams=_hparams,
        max_length=FLAGS.max_length,
        save_plots=True)
    write_dataset(train_dataset, 'train')

    dev_dataset = preprocessing.preprocess_dataset(
        dev_dataset,
        encoder_fn=encoder_fn,
        hparams=_hparams,
        max_length=FLAGS.max_length)
    write_dataset(dev_dataset, 'dev')

    test_dataset = preprocessing.preprocess_dataset(
        test_dataset,
        encoder_fn=encoder_fn,
        hparams=_hparams,
        max_length=FLAGS.max_length)
    write_dataset(test_dataset, 'test')