def main(args):
    """Transcribe a single audio file with a trained transducer checkpoint.

    Loads hparams and the text encoder from the checkpoint's directory,
    rebuilds the model, computes mel spectrograms for ``args.input``, and
    prints the greedy-decoded transcription.
    """
    # The checkpoint's directory holds the saved hparams and encoder files.
    model_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(model_dir)

    encode_fn, tok_to_text, vocab_size = encoding.load_encoder(
        model_dir, hparams=hparams)
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    # The first token of an empty-string encoding seeds the decoder.
    start_token = encode_fn('')[0]

    model = build_keras_model(hparams)
    model.load_weights(args.checkpoint)

    audio, sr = preprocessing.tf_load_audio(args.input)

    mel_specs = preprocessing.compute_mel_spectrograms(
        audio_arr=audio,
        sample_rate=sr,
        n_mel_bins=hparams[HP_MEL_BINS.name],
        frame_length=hparams[HP_FRAME_LENGTH.name],
        frame_step=hparams[HP_FRAME_STEP.name],
        hertz_low=hparams[HP_HERTZ_LOW.name],
        hertz_high=hparams[HP_HERTZ_HIGH.name])

    # Add a leading batch dimension of size 1 for the model.
    batched_specs = tf.expand_dims(mel_specs, axis=0)

    greedy_decode = decoding.greedy_decode_fn(model, start_token=start_token)
    token_ids = greedy_decode(batched_specs)[0]
    transcription = tok_to_text(token_ids)

    print('Transcription:', transcription.numpy().decode('utf8'))
Example #2
0
def transcribe_file():
    """Load a trained model from ``FLAGS.model_dir`` and transcribe
    ``FLAGS.test_file``, printing the result.

    Requires that ``FLAGS.model_dir`` contains a saved ``hparams.json``
    (i.e. a completed training run); otherwise prints an error and returns.
    """
    # Guard clause: bail out early if no trained model is present.
    if not os.path.exists(os.path.join(FLAGS.model_dir, 'hparams.json')):
        # Fixed garbled message ('need afford model_dir ').
        print('FLAGS.model_dir must contain a trained model '
              '(hparams.json not found).')
        return

    _hparams = model_utils.load_hparams(FLAGS.model_dir)
    encoder_fn, vocab_size = encoding.load_encoder(FLAGS.model_dir,
                                                   hparams=_hparams)
    model, loss_fn = model_utils.load_model(FLAGS.model_dir,
                                            vocab_size=vocab_size,
                                            hparams=_hparams)
    optimizer = tf.keras.optimizers.Adam(_hparams[HP_LEARNING_RATE.name])
    model.compile(loss=loss_fn,
                  optimizer=optimizer,
                  experimental_run_tf_function=False)

    transcription = model.predict(FLAGS.test_file)
    # Bug fix: the original printed FLAGS.input, but the file actually
    # transcribed is FLAGS.test_file.
    print('Input file: {}'.format(FLAGS.test_file))
    print('Transcription: {}'.format(transcription))
Example #3
0
def train():
    """Train the transducer model, resuming from ``FLAGS.model_dir`` when a
    previous run's ``hparams.json`` exists there, otherwise starting fresh.

    Builds (or reloads) the text encoder and model, compiles with Adam, and
    fits on the 'train' split while validating on 'dev', checkpointing each
    epoch and logging to TensorBoard.
    """

    # Default hyperparameters, keyed by HParam object (first domain value).
    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[0],
        HP_VOCAB_SIZE: HP_VOCAB_SIZE.domain.values[0],

        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],

        # Model
        HP_EMBEDDING_SIZE: HP_EMBEDDING_SIZE.domain.values[0],
        HP_ENCODER_LAYERS: HP_ENCODER_LAYERS.domain.values[0],
        HP_ENCODER_SIZE: HP_ENCODER_SIZE.domain.values[0],
        HP_TIME_REDUCT_INDEX: HP_TIME_REDUCT_INDEX.domain.values[0],
        HP_TIME_REDUCT_FACTOR: HP_TIME_REDUCT_FACTOR.domain.values[0],
        HP_PRED_NET_LAYERS: HP_PRED_NET_LAYERS.domain.values[0],
        HP_JOINT_NET_SIZE: HP_JOINT_NET_SIZE.domain.values[0],
        HP_SOFTMAX_SIZE: HP_SOFTMAX_SIZE.domain.values[0],
        HP_LEARNING_RATE: HP_LEARNING_RATE.domain.values[0]
    }

    if os.path.exists(os.path.join(FLAGS.model_dir, 'hparams.json')):
        # Resume: reload the saved (name-keyed) hparams, encoder and model.
        _hparams = model_utils.load_hparams(FLAGS.model_dir)

        encoder_fn, vocab_size = encoding.load_encoder(FLAGS.model_dir,
                                                       hparams=_hparams)

        model, loss_fn = model_utils.load_model(FLAGS.model_dir,
                                                vocab_size=vocab_size,
                                                hparams=_hparams)

    else:
        # Fresh run: convert HParam-keyed defaults to a name-keyed dict, then
        # build the encoder from the corpus texts and a new model.
        _hparams = {k.name: v for k, v in hparams.items()}

        texts_gen = common_voice.texts_generator(FLAGS.data_dir)

        encoder_fn, vocab_size = encoding.build_encoder(
            texts_gen, model_dir=FLAGS.model_dir, hparams=_hparams)

        model, loss_fn = build_keras_model(vocab_size, _hparams)

    logging.info('Using {} encoder with vocab size: {}'.format(
        _hparams[HP_TOKEN_TYPE.name], vocab_size))

    # Bug fix: pass the name-keyed `_hparams` (as every other consumer in
    # this file receives), not the HParam-object-keyed `hparams`.
    dataset_fn = get_dataset_fn(FLAGS.data_dir,
                                encoder_fn=encoder_fn,
                                batch_size=FLAGS.batch_size,
                                hparams=_hparams)

    train_dataset, train_steps = dataset_fn('train')
    dev_dataset, dev_steps = dataset_fn('dev')

    optimizer = tf.keras.optimizers.Adam(_hparams[HP_LEARNING_RATE.name])

    model.compile(loss=loss_fn,
                  optimizer=optimizer,
                  experimental_run_tf_function=False)

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    # Keras fills in {epoch} and {val_loss} per checkpoint.
    checkpoint_fp = os.path.join(FLAGS.model_dir,
                                 'model.{epoch:03d}-{val_loss:.4f}.hdf5')

    # Persist hparams so a later run (or transcribe) can resume/reload.
    model_utils.save_hparams(_hparams, FLAGS.model_dir)

    model.fit(train_dataset,
              epochs=FLAGS.n_epochs,
              steps_per_epoch=train_steps,
              validation_data=dev_dataset,
              validation_steps=dev_steps,
              callbacks=[
                  tf.keras.callbacks.TensorBoard(FLAGS.tb_log_dir),
                  tf.keras.callbacks.ModelCheckpoint(checkpoint_fp,
                                                     save_weights_only=True)
              ])
Example #4
0
def main(_):
    """Entry point for distributed train/eval/test of the transducer model.

    Builds (or restores from ``FLAGS.checkpoint``) the model under a
    ``MirroredStrategy``, registers the hparams/metrics config with
    TensorBoard's HParams plugin, then either runs custom-loop training
    (``FLAGS.mode == 'train'``) or evaluation ('eval'/'test').
    """

    configure_environment(FLAGS.fp16_run)

    # Hyperparameters keyed by HParam object; note values[1] selects the
    # second token type, unlike the other scripts in this file.
    hparams = {
        HP_TOKEN_TYPE: HP_TOKEN_TYPE.domain.values[1],

        # Preprocessing
        HP_MEL_BINS: HP_MEL_BINS.domain.values[0],
        HP_FRAME_LENGTH: HP_FRAME_LENGTH.domain.values[0],
        HP_FRAME_STEP: HP_FRAME_STEP.domain.values[0],
        HP_HERTZ_LOW: HP_HERTZ_LOW.domain.values[0],
        HP_HERTZ_HIGH: HP_HERTZ_HIGH.domain.values[0],

        # Model
        HP_EMBEDDING_SIZE: HP_EMBEDDING_SIZE.domain.values[0],
        HP_ENCODER_LAYERS: HP_ENCODER_LAYERS.domain.values[0],
        HP_ENCODER_SIZE: HP_ENCODER_SIZE.domain.values[0],
        HP_PROJECTION_SIZE: HP_PROJECTION_SIZE.domain.values[0],
        HP_TIME_REDUCT_INDEX: HP_TIME_REDUCT_INDEX.domain.values[0],
        HP_TIME_REDUCT_FACTOR: HP_TIME_REDUCT_FACTOR.domain.values[0],
        HP_PRED_NET_LAYERS: HP_PRED_NET_LAYERS.domain.values[0],
        HP_PRED_NET_SIZE: HP_PRED_NET_SIZE.domain.values[0],
        HP_JOINT_NET_SIZE: HP_JOINT_NET_SIZE.domain.values[0],
        HP_LEARNING_RATE: HP_LEARNING_RATE.domain.values[0]
    }

    # Register the hparams + metrics schema with the TensorBoard HParams
    # plugin so per-run hp.hparams() calls below are grouped correctly.
    with tf.summary.create_file_writer(
            os.path.join(FLAGS.tb_log_dir, 'hparams_tuning')).as_default():
        hp.hparams_config(
            hparams=[
                HP_TOKEN_TYPE, HP_VOCAB_SIZE, HP_EMBEDDING_SIZE,
                HP_ENCODER_LAYERS, HP_ENCODER_SIZE, HP_PROJECTION_SIZE,
                HP_TIME_REDUCT_INDEX, HP_TIME_REDUCT_FACTOR,
                HP_PRED_NET_LAYERS, HP_PRED_NET_SIZE, HP_JOINT_NET_SIZE
            ],
            metrics=[
                hp.Metric(METRIC_ACCURACY, display_name='Accuracy'),
                hp.Metric(METRIC_CER, display_name='CER'),
                hp.Metric(METRIC_WER, display_name='WER'),
            ],
        )

    _hparams = {k.name: v for k, v in hparams.items()}

    # Bug fix: the original tested `len(FLAGS) == 0`, which counts *defined*
    # flags and is never zero, so GPU auto-detection was dead code. Test the
    # gpus flag itself instead.
    if not FLAGS.gpus:
        gpus = [
            # Bug fix: str.strip removes a character *set* from both ends;
            # remove the '/physical_device:' prefix explicitly instead.
            x.name.replace('/physical_device:', '')
            for x in tf.config.experimental.list_physical_devices('GPU')
        ]
    else:
        gpus = ['GPU:' + str(gpu_id) for gpu_id in FLAGS.gpus]

    strategy = tf.distribute.MirroredStrategy(devices=gpus)
    # strategy = None

    dtype = tf.float32
    if FLAGS.fp16_run:
        dtype = tf.float16

    # initializer = tf.keras.initializers.RandomUniform(
    #     minval=-0.1, maxval=0.1)
    initializer = None

    if FLAGS.checkpoint is not None:

        # Restore: hparams and encoder come from the checkpoint's directory.
        checkpoint_dir = os.path.dirname(os.path.realpath(FLAGS.checkpoint))

        _hparams = model_utils.load_hparams(checkpoint_dir)
        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            checkpoint_dir, hparams=_hparams)

        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams,
                                          initializer=initializer,
                                          dtype=dtype)
                model.load_weights(FLAGS.checkpoint)
        else:
            model = build_keras_model(_hparams,
                                      initializer=initializer,
                                      dtype=dtype)
            model.load_weights(FLAGS.checkpoint)

        logging.info('Restored weights from {}.'.format(FLAGS.checkpoint))

    else:

        # Fresh run: copy the pre-built subword encoder into output_dir,
        # then build a new model and persist the hparams alongside it.
        os.makedirs(FLAGS.output_dir, exist_ok=True)

        shutil.copy(os.path.join(FLAGS.data_dir, 'encoder.subwords'),
                    os.path.join(FLAGS.output_dir, 'encoder.subwords'))

        encoder_fn, idx_to_text, vocab_size = encoding.load_encoder(
            FLAGS.output_dir, hparams=_hparams)
        _hparams[HP_VOCAB_SIZE.name] = vocab_size

        if strategy is not None:
            with strategy.scope():
                model = build_keras_model(_hparams,
                                          initializer=initializer,
                                          dtype=dtype)
        else:
            model = build_keras_model(_hparams,
                                      initializer=initializer,
                                      dtype=dtype)
        model_utils.save_hparams(_hparams, FLAGS.output_dir)

    logging.info('Using {} encoder with vocab size: {}'.format(
        _hparams[HP_TOKEN_TYPE.name], vocab_size))

    loss_fn = get_loss_fn(
        reduction_factor=_hparams[HP_TIME_REDUCT_FACTOR.name])

    # The first token of an empty-string encoding seeds greedy decoding.
    start_token = encoder_fn('')[0]
    decode_fn = decoding.greedy_decode_fn(model, start_token=start_token)

    accuracy_fn = metrics.build_accuracy_fn(decode_fn)
    cer_fn = metrics.build_cer_fn(decode_fn, idx_to_text)
    wer_fn = metrics.build_wer_fn(decode_fn, idx_to_text)

    optimizer = tf.keras.optimizers.SGD(_hparams[HP_LEARNING_RATE.name],
                                        momentum=0.9)

    if FLAGS.fp16_run:
        # Dynamic loss scaling keeps fp16 gradients from underflowing.
        optimizer = mixed_precision.LossScaleOptimizer(optimizer,
                                                       loss_scale='dynamic')

    # NOTE(review): assumes layer order from build_keras_model puts the
    # encoder at index 2 and the prediction network at index 3 — confirm.
    encoder = model.layers[2]
    prediction_network = model.layers[3]

    encoder.summary()
    prediction_network.summary()

    model.summary()

    # The dev set is needed in both 'train' (validation) and 'eval' modes.
    dev_dataset, _ = get_dataset(FLAGS.data_dir,
                                 'dev',
                                 batch_size=FLAGS.batch_size,
                                 n_epochs=FLAGS.n_epochs,
                                 strategy=strategy,
                                 max_size=FLAGS.eval_size)
    # dev_steps = dev_specs['size'] // FLAGS.batch_size

    log_dir = os.path.join(FLAGS.tb_log_dir,
                           datetime.now().strftime('%Y%m%d-%H%M%S'))

    with tf.summary.create_file_writer(log_dir).as_default():

        hp.hparams(hparams)

        if FLAGS.mode == 'train':

            train_dataset, _ = get_dataset(FLAGS.data_dir,
                                           'train',
                                           batch_size=FLAGS.batch_size,
                                           n_epochs=FLAGS.n_epochs,
                                           strategy=strategy)
            # train_steps = train_specs['size'] // FLAGS.batch_size

            os.makedirs(FLAGS.output_dir, exist_ok=True)
            checkpoint_template = os.path.join(
                FLAGS.output_dir, 'checkpoint_{step}_{val_loss:.4f}.hdf5')

            run_training(model=model,
                         optimizer=optimizer,
                         loss_fn=loss_fn,
                         train_dataset=train_dataset,
                         batch_size=FLAGS.batch_size,
                         n_epochs=FLAGS.n_epochs,
                         checkpoint_template=checkpoint_template,
                         strategy=strategy,
                         steps_per_log=FLAGS.steps_per_log,
                         steps_per_checkpoint=FLAGS.steps_per_checkpoint,
                         eval_dataset=dev_dataset,
                         train_metrics=[],
                         eval_metrics=[accuracy_fn, cer_fn, wer_fn],
                         gpus=gpus)

        elif FLAGS.mode == 'eval' or FLAGS.mode == 'test':

            if FLAGS.checkpoint is None:
                raise Exception(
                    'You must provide a checkpoint to perform eval.')

            if FLAGS.mode == 'test':
                dataset, test_specs = get_dataset(FLAGS.data_dir,
                                                  'test',
                                                  batch_size=FLAGS.batch_size,
                                                  n_epochs=FLAGS.n_epochs)
            else:
                dataset = dev_dataset

            eval_start_time = time.time()

            eval_loss, eval_metrics_results = run_evaluate(
                model=model,
                optimizer=optimizer,
                loss_fn=loss_fn,
                eval_dataset=dataset,
                batch_size=FLAGS.batch_size,
                strategy=strategy,
                metrics=[accuracy_fn, cer_fn, wer_fn],
                gpus=gpus)

            validation_log_str = 'VALIDATION RESULTS: Time: {:.4f}, Loss: {:.4f}'.format(
                time.time() - eval_start_time, eval_loss)
            for metric_name, metric_result in eval_metrics_results.items():
                validation_log_str += ', {}: {:.4f}'.format(
                    metric_name, metric_result)

            print(validation_log_str)
def main(args):
    """Stream microphone audio and print incremental greedy transcriptions.

    Restores a stateful model from ``args.checkpoint`` and feeds it fixed-size
    audio chunks from a PyAudio callback, printing the running transcription
    whenever it grows.
    """
    model_dir = os.path.dirname(os.path.realpath(args.checkpoint))
    hparams = model_utils.load_hparams(model_dir)

    _, tok_to_text, vocab_size = encoding.load_encoder(model_dir,
                                                       hparams=hparams)
    hparams[HP_VOCAB_SIZE.name] = vocab_size

    # stateful=True keeps RNN state across successive audio chunks.
    model = build_keras_model(hparams, stateful=True)
    model.load_weights(args.checkpoint)

    decode = decoding.greedy_decode_fn(model)

    pa = pyaudio.PyAudio()

    def on_audio_chunk(in_data, frame_count, time_info, status):
        # Runs on PyAudio's callback thread; appends newly decoded text to
        # the running transcription held in the module-level LAST_OUTPUT.
        global LAST_OUTPUT

        samples = tf.io.decode_raw(in_data, out_type=tf.float32)

        features = preprocessing.compute_mel_spectrograms(
            audio_arr=samples,
            sample_rate=SAMPLE_RATE,
            n_mel_bins=hparams[HP_MEL_BINS.name],
            frame_length=hparams[HP_FRAME_LENGTH.name],
            frame_step=hparams[HP_FRAME_STEP.name],
            hertz_low=hparams[HP_HERTZ_LOW.name],
            hertz_high=hparams[HP_HERTZ_HIGH.name])

        # Leading batch dimension of size 1.
        features = tf.expand_dims(features, axis=0)

        new_tokens = decode(features, max_length=5)[0]
        new_text = tok_to_text(new_tokens).numpy().decode('utf8')
        transcription = LAST_OUTPUT + new_text

        if transcription != LAST_OUTPUT:
            LAST_OUTPUT = transcription
            print(transcription)

        return in_data, pyaudio.paContinue

    stream = pa.open(format=pyaudio.paFloat32,
                     channels=NUM_CHANNELS,
                     rate=SAMPLE_RATE,
                     input=True,
                     frames_per_buffer=CHUNK_SIZE,
                     stream_callback=on_audio_chunk)

    print('Listening...')

    stream.start_stream()

    # Keep the main thread alive while the callback does the work.
    while stream.is_active():
        time.sleep(0.1)

    stream.stop_stream()
    stream.close()
    pa.terminate()