# Code example #1
def run_evaluation(H):
    """Evaluate a trained SpeechCNN model on the test split.

    Builds the audio/label transform pipelines, loads the test dataset and
    the saved model checkpoint, decodes predictions with greedy CTC decoding,
    and logs BLEU / WER / CER / accuracy against the reference transcripts.

    Args:
        H: hyper-parameter/config object exposing the attributes read below
           (ROOT_DIR, EXPERIMENT, TARGET_ENCODING, AUDIO_SAMPLE_RATE,
           SPECT_*, NORMALIZE_*, MANIFESTS, BATCH_SIZE, NUM_WORKERS,
           CNN_HIDDEN_SIZE, CNN_DROPOUT, USE_CUDA, MODEL_NAME, ...).
    """
    vocab = Vocabulary(os.path.join(H.ROOT_DIR, H.EXPERIMENT),
                       encoding=H.TARGET_ENCODING)

    # Audio pipeline: dB normalization -> spectrogram -> per-feature
    # normalization -> float tensor.
    audio_transform = transforms.Compose([
        AudioNormalizeDB(db=H.NORMALIZE_DB, max_gain_db=H.NORMALIZE_MAX_GAIN),
        AudioSpectrogram(sample_rate=H.AUDIO_SAMPLE_RATE,
                         window_size=H.SPECT_WINDOW_SIZE,
                         window_stride=H.SPECT_WINDOW_STRIDE,
                         window=H.SPECT_WINDOW),
        AudioNormalize(),
        FromNumpyToTensor(tensor_type=torch.FloatTensor)
    ])

    # Label pipeline: transcript -> CTC label indices -> long tensor.
    label_transform = transforms.Compose([
        TranscriptEncodeCTC(vocab),
        FromNumpyToTensor(tensor_type=torch.LongTensor)
    ])

    test_dataset = AudioDataset(os.path.join(H.ROOT_DIR, H.EXPERIMENT),
                                manifests_files=H.MANIFESTS,
                                datasets="test",
                                transform=audio_transform,
                                label_transform=label_transform,
                                max_data_size=None,
                                sorted_by='recording_duration')

    # shuffle=False keeps iteration order stable, so a second pass over the
    # loader (below, to collect reference transcripts) aligns with the
    # hypotheses produced by the recognizer.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=H.BATCH_SIZE,
                                              num_workers=H.NUM_WORKERS,
                                              shuffle=False,
                                              collate_fn=collate_fn,
                                              pin_memory=True)

    logger.info(test_loader.dataset)

    model_pred = SpeechCNN(len(vocab),
                           input_size=256,
                           hidden_size=H.CNN_HIDDEN_SIZE,
                           dropout=H.CNN_DROPOUT,
                           initialize=torch_weight_init)
    if H.USE_CUDA:
        model_pred.cuda()

    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host
    # (and avoids allocating on a CUDA device that isn't being used).
    # NOTE(review): this path omits H.ROOT_DIR while the vocabulary/dataset
    # paths include it — confirm whether that asymmetry is intentional.
    device = 'cuda' if H.USE_CUDA else 'cpu'
    state = torch.load(os.path.join(H.EXPERIMENT, H.MODEL_NAME + '.tar'),
                       map_location=device)
    model_pred.load_state_dict(state)

    ctc_decoder = CTCGreedyDecoder(vocab)

    recognizer = Recognizer(model_pred, ctc_decoder, test_loader)

    hypotheses = recognizer()

    # Second pass over the loader: decode the reference labels so they can be
    # scored against the hypotheses. Assumes each batch unpacks as
    # (inputs, labels, input_sizes, label_sizes, ...) — matches collate_fn.
    transcripts = []
    for _, labels, _, label_sizes, _ in test_loader:
        label_seq = CTCGreedyDecoder.decode_labels(labels, label_sizes, vocab)
        transcripts.extend(label_seq)

    bleu = Scorer.get_moses_multi_bleu(hypotheses,
                                       transcripts,
                                       lowercase=False)
    wer, cer = Scorer.get_wer_cer(hypotheses, transcripts)
    acc = Scorer.get_acc(hypotheses, transcripts)

    # Lazy %-style args: formatting only happens if the record is emitted.
    logger.info('Test Summary \n'
                'Bleu: %.3f\n'
                'WER:  %.3f\n'
                'CER:  %.3f\n'
                'ACC:  %.3f',
                bleu, wer * 100, cer * 100, acc * 100)