예제 #1
0
    # Prepend the special tokens so they receive the lowest ids
    # (PAD first, then SOS, then EOS) in the mapping built below.
    labels = constant.PAD_CHAR + constant.SOS_CHAR + constant.EOS_CHAR + labels
    # Build char->id and id->char maps, keeping only the first occurrence of
    # each character and reporting any duplicates found in the label set.
    label2id, id2label = {}, {}
    count = 0
    for i in range(len(labels)):
        if labels[i] not in label2id:
            label2id[labels[i]] = count
            id2label[count] = labels[i]
            count += 1
        else:
            print("multiple label: ", labels[i])

    # label2id = dict([(labels[i], i) for i in range(len(labels))])
    # id2label = dict([(i, labels[i]) for i in range(len(labels))])

    # Training set: bucketed sampling groups similar-length utterances so
    # batches need less padding.
    train_data = SpectrogramDataset(audio_conf, manifest_filepath_list=args.train_manifest_list, label2id=label2id, normalize=True, augment=args.augment)
    train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
    train_loader = AudioDataLoader(
        train_data, num_workers=args.num_workers, batch_sampler=train_sampler)

    # One validation loader per manifest file; evaluation data is never augmented.
    valid_loader_list, test_loader_list = [], []
    for i in range(len(args.valid_manifest_list)):
        valid_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[args.valid_manifest_list[i]], label2id=label2id,
                                        normalize=True, augment=False)
        valid_loader = AudioDataLoader(valid_data, num_workers=args.num_workers, batch_size=args.batch_size)
        valid_loader_list.append(valid_loader)

    # One test loader per manifest file; no batch_size given here, so the
    # loader falls back to AudioDataLoader's default.
    for i in range(len(args.test_manifest_list)):
        test_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[args.test_manifest_list[i]], label2id=label2id,
                                    normalize=True, augment=False)
        test_loader = AudioDataLoader(test_data, num_workers=args.num_workers)
        test_loader_list.append(test_loader)
예제 #2
0
                args,
                audio_conf,
                manifest_filepath_list=[args.valid_manifest_list[i]],
                normalize=True,
                # NOTE(review): this branch uses augment=args.augment while the
                # logfbank branch below hard-codes False — confirm intended.
                augment=args.augment,
                input_type=args.input_type)
        elif args.feat == "logfbank":
            # Log filter-bank features instead of raw spectrograms.
            valid_data = LogFBankDataset(
                vocab,
                args,
                audio_conf,
                manifest_filepath_list=[args.valid_manifest_list[i]],
                normalize=True,
                augment=False,
                input_type=args.input_type)
        # NOTE(review): valid_sampler is created but never passed to
        # AudioDataLoader below — either pass batch_sampler=valid_sampler or
        # drop this line; verify against the training-loader construction.
        valid_sampler = BucketingSampler(valid_data, batch_size=args.k_train)
        valid_loader = AudioDataLoader(pad_token_id=vocab.PAD_ID,
                                       dataset=valid_data,
                                       num_workers=args.num_workers)
        valid_loader_list.append(valid_loader)

    start_epoch = 0
    metrics = None
    loaded_args = None
    if args.continue_from != "":
        # Resume: restore model, vocab, optimizer, epoch, and metrics from the
        # checkpoint instead of building them from scratch.
        logging.info("Continue from checkpoint:" + args.continue_from)
        model, vocab, opt, epoch, metrics, loaded_args = load_joint_model(
            args.continue_from)
        start_epoch = (epoch)  # index starts from zero
        verbose = args.verbose
    else:
예제 #3
0
def load_data(train_manifest_list,
              valid_manifest_list,
              test_manifest_list,
              batch_size=12):
    """Build the train loader plus per-manifest valid/test loaders.

    Reads the character inventory from ``./labels.json``, prepends the
    PAD/SOS/EOS special characters so they get the lowest ids, and constructs
    one SpectrogramDataset-backed loader for training (bucketed batches) and
    one loader per validation/test manifest file.

    Args:
        train_manifest_list: manifest file paths for the training set.
        valid_manifest_list: manifest paths; one validation loader is built
            for each entry.
        test_manifest_list: manifest paths; one test loader is built for each
            entry.
        batch_size: batch size for the train and validation loaders (the test
            loaders keep AudioDataLoader's default).

    Returns:
        (train_loader, valid_loader_list, test_loader_list)
    """
    audio_conf = dict(sample_rate=16000,
                      window_size=0.02,
                      window_stride=0.01,
                      window='hamming')
    PAD_CHAR = "¶"
    SOS_CHAR = "§"
    EOS_CHAR = "¤"

    labels_path = './labels.json'
    with open(labels_path) as label_file:
        # json.load yields a list of characters; join them into one string.
        labels = ''.join(json.load(label_file))

    # Special tokens first so they map to the smallest ids (PAD=0, SOS=1, EOS=2).
    labels = PAD_CHAR + SOS_CHAR + EOS_CHAR + labels
    # Map each distinct character to its first-occurrence index. (The original
    # also built an id->char inverse, but it was never used or returned.)
    label2id = {}
    for ch in labels:
        if ch not in label2id:
            label2id[ch] = len(label2id)

    def _make_dataset(manifest_filepath_list):
        # All splits share the audio config and label mapping; no augmentation.
        return SpectrogramDataset(audio_conf,
                                  manifest_filepath_list=manifest_filepath_list,
                                  label2id=label2id,
                                  normalize=True,
                                  augment=False)

    # Training loader: bucketed sampling groups similar-length utterances.
    train_data = _make_dataset(train_manifest_list)
    train_sampler = BucketingSampler(train_data, batch_size=batch_size)
    train_loader = AudioDataLoader(train_data,
                                   num_workers=4,
                                   batch_sampler=train_sampler)

    # One evaluation loader per manifest file.
    valid_loader_list = [
        AudioDataLoader(_make_dataset([manifest]),
                        num_workers=4,
                        batch_size=batch_size)
        for manifest in valid_manifest_list
    ]
    test_loader_list = [
        AudioDataLoader(_make_dataset([manifest]), num_workers=4)
        for manifest in test_manifest_list
    ]
    print('done !')
    return train_loader, valid_loader_list, test_loader_list
예제 #4
0
파일: test.py 프로젝트: yf1291/nlp4
        model = model.module

    # Rebuild the audio preprocessing config from the arguments saved in the
    # checkpoint so features match how the model was trained.
    audio_conf = dict(sample_rate=loaded_args.sample_rate,
                      window_size=loaded_args.window_size,
                      window_stride=loaded_args.window_stride,
                      window=loaded_args.window,
                      noise_dir=loaded_args.noise_dir,
                      noise_prob=loaded_args.noise_prob,
                      noise_levels=(loaded_args.noise_min,
                                    loaded_args.noise_max))

    # Test data: no augmentation, bucketed batches for less padding.
    test_data = SpectrogramDataset(
        audio_conf=audio_conf,
        manifest_filepath_list=constant.args.test_manifest_list,
        label2id=label2id,
        normalize=True,
        augment=False)
    test_sampler = BucketingSampler(test_data,
                                    batch_size=constant.args.batch_size)
    # NOTE(review): mixes constant.args and a local args (num_workers) —
    # confirm both refer to the same parsed arguments.
    test_loader = AudioDataLoader(test_data,
                                  num_workers=args.num_workers,
                                  batch_sampler=test_sampler)

    # Optional language model for rescoring hypotheses during decoding.
    lm = None
    if constant.args.lm_rescoring:
        lm = LM(constant.args.lm_path)

    print(model)

    evaluate(model, test_loader, lm=lm)
예제 #5
0
    # Run the full training loop (with periodic evaluation / early stopping).
    trainer = Trainer()
    trainer.train(model, vocab, train_loader, valid_loader_list, loss_type, start_epoch, num_epochs, args, last_metrics=metrics, evaluate_every=args.evaluate_every, early_stop=args.early_stop, opt_name=args.opt_name)

    # test
    logging.info("Test")
    print("Test")
    test_manifest_list = args.test_manifest_list
    # Cap decoded target length during test-time evaluation.
    args.tgt_max_len = 150

    # Evaluate on each test manifest separately, collecting CER/WER per set.
    cer_list, wer_list = [], []
    for i in range(len(test_manifest_list)):
        # Feature type must match what the checkpointed model was trained on.
        # NOTE(review): if loaded_args.feat is neither value, test_data is
        # unbound and the next line raises NameError — confirm feat is
        # validated upstream.
        if loaded_args.feat == "spectrogram":
            test_data = SpectrogramDataset(vocab, args, audio_conf=audio_conf, manifest_filepath_list=[test_manifest_list[i]], normalize=True, augment=False, input_type=args.input_type)
        elif loaded_args.feat == "logfbank":
            test_data = LogFBankDataset(vocab, args, audio_conf=audio_conf, manifest_filepath_list=[test_manifest_list[i]], normalize=True, augment=False, input_type=args.input_type)
        test_sampler = BucketingSampler(test_data, batch_size=args.k_test)
        test_loader = AudioDataLoader(vocab.PAD_ID, dataset=test_data, num_workers=args.num_workers, batch_sampler=test_sampler)

        # Optional language-model rescoring during decoding.
        lm = None
        if args.lm_rescoring:
            lm = LM(args.lm_path, args)

        cer, wer = evaluate(model, vocab, test_loader, args, lm=lm, start_token=vocab.SOS_ID)
        cer_list.append(cer)
        wer_list.append(wer)

    # Summarize per-dataset error rates to stdout and the log.
    print("="*50)
    logging.info("="*50)
    for index, (cer, wer) in enumerate(zip(cer_list, wer_list)):
        print("TEST DATASET [{}] CER:{:.4f} WER:{:.4f}".format(index, cer, wer))
        logging.info("TEST DATASET [{}] CER:{:.4f} WER:{:.4f}".format(index, cer, wer))