Example #1
0
def main(args):
    """Transcribe one audio file with a trained checkpoint and print the text.

    Args:
        args: parsed CLI arguments providing `checkpoint` (weights file path)
            and `input` (audio file path).
    """
    # Hyper-parameters and the subword encoder are stored next to the checkpoint.
    checkpoint_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(checkpoint_dir)

    _, tok_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=checkpoint_dir, hparams=hparams)
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    model = build_keras_model(hparams)
    model.load_weights(args.checkpoint)

    # Turn the waveform into a log-mel spectrogram with a leading batch axis.
    audio, sample_rate = preprocessing.tf_load_audio(args.input)
    features = preprocessing.preprocess_audio(audio=audio,
                                              sample_rate=sample_rate,
                                              hparams=hparams)
    features = tf.expand_dims(features, axis=0)

    decode = decoding.greedy_decode_fn(model, hparams)
    transcription = tok_to_text(decode(features)[0])

    print('Transcription:', transcription.numpy().decode('utf8'))
Example #2
0
def main(args):
    """Stream microphone audio through the model and print live transcriptions.

    Args:
        args: parsed CLI arguments providing `checkpoint` (weights file path).
    """
    ckpt_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(ckpt_dir)

    _, tok_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=ckpt_dir,
        hparams=hparams)
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    # stateful=True so recurrent state carries over between audio chunks.
    model = build_keras_model(hparams, stateful=True)
    model.load_weights(args.checkpoint)

    decode = decoding.greedy_decode_fn(model, hparams)

    audio_interface = pyaudio.PyAudio()

    def on_chunk(chunk_bytes, frame_count, time_info, status):
        """PyAudio callback: decode one chunk and print the running hypothesis."""
        global LAST_OUTPUT

        waveform = tf.io.decode_raw(chunk_bytes, out_type=tf.float32)

        features = preprocessing.preprocess_audio(
            audio=waveform,
            sample_rate=SAMPLE_RATE,
            hparams=hparams)
        features = tf.expand_dims(features, axis=0)

        tokens = decode(features)[0]

        hypothesis = LAST_OUTPUT + tok_to_text(tokens)\
            .numpy().decode('utf8')

        # Only print when the transcription actually grew/changed.
        if hypothesis != LAST_OUTPUT:
            LAST_OUTPUT = hypothesis
            print(hypothesis)

        return chunk_bytes, pyaudio.paContinue

    stream = audio_interface.open(
        format=pyaudio.paFloat32,
        channels=NUM_CHANNELS,
        rate=SAMPLE_RATE,
        input=True,
        frames_per_buffer=CHUNK_SIZE,
        stream_callback=on_chunk)

    print('Listening...')

    stream.start_stream()

    # The callback runs on PyAudio's thread; just keep the main thread alive.
    while stream.is_active():
        time.sleep(0.1)

    stream.stop_stream()
    stream.close()
    audio_interface.terminate()
def main(_):
    """Train or evaluate the transducer model, driven by absl FLAGS.

    Builds the model under the distribution strategy, optionally restores a
    checkpoint, then dispatches on FLAGS.mode ('train', 'eval', or 'test').
    Results and hparams are logged to TensorBoard under FLAGS.tb_log_dir.
    """

    strategy, dtype = configure_environment(
        gpu_names=FLAGS.gpus,
        fp16_run=FLAGS.fp16_run)

    hparams, tb_hparams = setup_hparams(
        log_dir=FLAGS.tb_log_dir,
        checkpoint=FLAGS.checkpoint)

    os.makedirs(FLAGS.output_dir, exist_ok=True)

    # The subword encoder lives next to the checkpoint when resuming,
    # otherwise next to the raw data.
    if FLAGS.checkpoint is None:
        encoder_dir = FLAGS.data_dir
    else:
        encoder_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))

    shutil.copy(
        os.path.join(encoder_dir, 'encoder.subwords'),
        os.path.join(FLAGS.output_dir, 'encoder.subwords'))

    # The encode function is not needed here; only decoding is used below.
    _, idx_to_text, vocab_size = encoding.get_encoder(
        encoder_dir=FLAGS.output_dir,
        hparams=hparams)

    if HP_VOCAB_SIZE.name not in hparams:
        hparams[HP_VOCAB_SIZE.name] = vocab_size

    with strategy.scope():

        model = build_keras_model(hparams,
            dtype=dtype)

        if FLAGS.checkpoint is not None:
            model.load_weights(FLAGS.checkpoint)
            logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))

        model_utils.save_hparams(hparams, FLAGS.output_dir)

        optimizer = tf.keras.optimizers.SGD(hparams[HP_LEARNING_RATE.name],
            momentum=0.9)

        if FLAGS.fp16_run:
            # Dynamic loss scaling keeps fp16 gradients from underflowing.
            optimizer = mixed_precision.LossScaleOptimizer(optimizer,
                loss_scale='dynamic')

    logging.info('Using {} encoder with vocab size: {}'.format(
        hparams[HP_TOKEN_TYPE.name], vocab_size))

    loss_fn = get_loss_fn(
        reduction_factor=hparams[HP_TIME_REDUCT_FACTOR.name])

    decode_fn = decoding.greedy_decode_fn(model, hparams)

    accuracy_fn = metrics.build_accuracy_fn(decode_fn)
    wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

    # NOTE(review): assumes build_keras_model places the encoder at
    # layers[2] and the prediction network at layers[3] — confirm there.
    encoder = model.layers[2]
    prediction_network = model.layers[3]

    encoder.summary()
    prediction_network.summary()

    model.summary()

    dev_dataset = None
    if FLAGS.eval_size != 0:
        dev_dataset = get_dataset(FLAGS.data_dir, 'dev',
            batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs,
            strategy=strategy, max_size=FLAGS.eval_size)

    log_dir = os.path.join(FLAGS.tb_log_dir,
        datetime.now().strftime('%Y%m%d-%H%M%S'))

    with tf.summary.create_file_writer(log_dir).as_default():

        hp.hparams(tb_hparams)

        if FLAGS.mode == 'train':

            train_dataset = get_dataset(FLAGS.data_dir, 'train',
                batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs,
                strategy=strategy)

            # FLAGS.output_dir was already created above; no need to repeat it.
            checkpoint_template = os.path.join(FLAGS.output_dir,
                'checkpoint_{step}_{val_loss:.4f}.hdf5')

            run_training(
                model=model,
                optimizer=optimizer,
                loss_fn=loss_fn,
                train_dataset=train_dataset,
                batch_size=FLAGS.batch_size,
                n_epochs=FLAGS.n_epochs,
                checkpoint_template=checkpoint_template,
                hparams=hparams,
                strategy=strategy,
                steps_per_log=FLAGS.steps_per_log,
                steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                eval_dataset=dev_dataset,
                train_metrics=[],
                eval_metrics=[accuracy_fn, wer_fn])

        elif FLAGS.mode == 'eval' or FLAGS.mode == 'test':

            if FLAGS.checkpoint is None:
                raise Exception('You must provide a checkpoint to perform eval.')

            if FLAGS.mode == 'test':
                dataset = get_dataset(FLAGS.data_dir, 'test',
                    batch_size=FLAGS.batch_size, n_epochs=FLAGS.n_epochs)
            else:
                dataset = dev_dataset

            # Guard against eval mode with --eval_size 0, which would
            # otherwise pass dataset=None into run_evaluate.
            if dataset is None:
                raise Exception(
                    'Eval mode requires a non-zero eval_size to build the dev dataset.')

            eval_start_time = time.time()

            eval_loss, eval_metrics_results = run_evaluate(
                model=model,
                optimizer=optimizer,
                loss_fn=loss_fn,
                eval_dataset=dataset,
                batch_size=FLAGS.batch_size,
                hparams=hparams,
                strategy=strategy,
                metrics=[accuracy_fn, wer_fn],
                gpus=FLAGS.gpus)  # BUG FIX: was `gpus=gpus`, an undefined name

            validation_log_str = 'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format(
                time.time() - eval_start_time, eval_loss)
            for metric_name, metric_result in eval_metrics_results.items():
                validation_log_str += ', {}: {:.4f}'.format(metric_name, metric_result)

            print(validation_log_str)
Example #4
0
def main(_):
    """Build the subword encoder and write preprocessed LibriSpeech splits.

    Reads raw audio/text from FLAGS.data_dir, fits/loads the encoder in
    FLAGS.output_dir, and writes the preprocessed train/dev/test datasets.
    """
    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[1],
        HP_VOCAB_SIZE: HP_VOCAB_SIZE.domain.values[0],

        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],
        HP_DOWNSAMPLE_FACTOR: HP_DOWNSAMPLE_FACTOR.domain.values[0]
    }

    # NOTE: every split points at 'dev-clean' for quick iteration. For a real
    # run, use the full LibriSpeech splits instead:
    #   train: train-clean-100, train-clean-360, train-other-500
    #   dev:   dev-clean, dev-other
    #   test:  test-clean, test-other
    split_names = {
        'train': ['dev-clean'],
        'dev': ['dev-clean'],
        'test': ['dev-clean'],
    }

    # Re-key by hparam name for the helpers below.
    _hparams = {hp_key.name: value for hp_key, value in hparams.items()}

    texts_gen = librispeech.texts_generator(FLAGS.data_dir,
                                            split_names=split_names['train'])

    encoder_fn, decoder_fn, vocab_size = encoding.get_encoder(
        output_dir=FLAGS.output_dir,
        hparams=_hparams,
        texts_generator=texts_gen)
    _hparams[HP_VOCAB_SIZE.name] = vocab_size

    # Load all splits up front, then preprocess and write each one.
    datasets = {split: librispeech.load_dataset(FLAGS.data_dir, names)
                for split, names in split_names.items()}

    for split in ('train', 'dev', 'test'):
        # Only the training split saves diagnostic plots; dev/test keep the
        # callee's default.
        extra = {'save_plots': True} if split == 'train' else {}
        processed = preprocessing.preprocess_dataset(
            datasets[split],
            encoder_fn=encoder_fn,
            hparams=_hparams,
            max_length=FLAGS.max_length,
            **extra)
        write_dataset(processed, split)