Пример #1
0
# Boolean command-line switches. With action='store_true' each option is
# False unless the flag is present on the command line.
parser.add_argument('--finetune',
                    dest='finetune',
                    action='store_true',
                    help='Finetune the model from checkpoint "continue_from"')
parser.add_argument('--augment',
                    dest='augment',
                    action='store_true',
                    help='Use random tempo and gain perturbations.')

# Seed PyTorch's random number generators with a fixed value at import time;
# the second call seeds the CUDA generators on every device.
# NOTE(review): presumably intended for run-to-run reproducibility — other
# sources of randomness (Python `random`, NumPy, DataLoader workers) are not
# seeded here.
torch.manual_seed(123456)
torch.cuda.manual_seed_all(123456)

if __name__ == '__main__':
    args = parser.parse_args()

    train_dataset = SpectrogramDataset(args.data_path, 'train')
    valid_dataset = SpectrogramDataset(args.data_path, 'valid')
    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)
    valid_loader = AudioDataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers)

    dtype = torch.FloatTensor
    ltype = torch.LongTensor

    if torch.cuda.is_available():
        print('use gpu')
        dtype = torch.cuda.FloatTensor
Пример #2
0
    with open(args.labels_path, encoding="utf-8") as label_file:
        labels = json.load(label_file)

    vocab = Vocab()
    for label in labels:
        vocab.add_token(label)
        vocab.add_label(label)

    train_data_list = []
    for i in range(len(args.train_manifest_list)):
        if args.feat == "spectrogram":
            train_data = SpectrogramDataset(
                vocab,
                args,
                audio_conf,
                manifest_filepath_list=args.train_manifest_list,
                normalize=True,
                augment=args.augment,
                input_type=args.input_type,
                is_train=True,
                partitions=args.train_partition_list)
        elif args.feat == "logfbank":
            train_data = LogFBankDataset(
                vocab,
                args,
                audio_conf,
                manifest_filepath_list=args.train_manifest_list,
                normalize=True,
                augment=args.augment,
                input_type=args.input_type,
                is_train=True)
        train_data_list.append(train_data)
Пример #3
0
    # add PAD_CHAR, SOS_CHAR, EOS_CHAR
    labels = constant.PAD_CHAR + constant.SOS_CHAR + constant.EOS_CHAR + labels
    label2id, id2label = {}, {}
    count = 0
    for i in range(len(labels)):
        if labels[i] not in label2id:
            label2id[labels[i]] = count
            id2label[count] = labels[i]
            count += 1
        else:
            print("multiple label: ", labels[i])

    # label2id = dict([(labels[i], i) for i in range(len(labels))])
    # id2label = dict([(i, labels[i]) for i in range(len(labels))])

    train_data = SpectrogramDataset(audio_conf, manifest_filepath_list=args.train_manifest_list, label2id=label2id, normalize=True, augment=args.augment)
    train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
    train_loader = AudioDataLoader(
        train_data, num_workers=args.num_workers, batch_sampler=train_sampler)

    valid_loader_list, test_loader_list = [], []
    for i in range(len(args.valid_manifest_list)):
        valid_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[args.valid_manifest_list[i]], label2id=label2id,
                                        normalize=True, augment=False)
        valid_loader = AudioDataLoader(valid_data, num_workers=args.num_workers, batch_size=args.batch_size)
        valid_loader_list.append(valid_loader)

    for i in range(len(args.test_manifest_list)):
        test_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[args.test_manifest_list[i]], label2id=label2id,
                                    normalize=True, augment=False)
        test_loader = AudioDataLoader(test_data, num_workers=args.num_workers)
Пример #4
0
    with open(args.labels_path, encoding="utf-8") as label_file:
        labels = json.load(label_file)

    vocab = Vocab()
    for label in labels:
        vocab.add_token(label)
        vocab.add_label(label)

    train_loader_list, valid_loader_list, test_loader_list = [], [], []
    for i in range(len(args.train_manifest_list)):
        if args.feat == "spectrogram":
            train_data = SpectrogramDataset(
                vocab,
                args,
                audio_conf,
                manifest_filepath_list=[args.train_manifest_list[i]],
                normalize=True,
                augment=args.augment,
                input_type=args.input_type)
        elif args.feat == "logfbank":
            train_data = LogFBankDataset(
                vocab,
                args,
                audio_conf,
                manifest_filepath_list=[args.train_manifest_list[i]],
                normalize=True,
                augment=False,
                input_type=args.input_type)
        train_loader = AudioDataLoader(pad_token_id=0,
                                       dataset=train_data,
                                       num_workers=args.num_workers,
Пример #5
0
    train_manifest_list = args.train_manifest_list
    val_manifest_list = args.val_manifest_list
    test_manifest_list = args.test_manifest_list

    train_data = MultiSpectrogramDataset(audio_conf=audio_conf, manifest_filepath_list=train_manifest_list, label2id=label2id,
                                        normalize=True, augment=args.augment)
    train_sampler = MultiBucketingSampler(train_data, batch_size=args.batch_size)
    train_loader = MultiAudioDataLoader(
        train_data, num_workers=args.num_workers, batch_sampler=train_sampler)

    valid_loaders = []

    for i in range(len(val_manifest_list)):
        val_manifest = val_manifest_list[i]

        valid_data = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=val_manifest, label2id=label2id,
                                        normalize=True, augment=False)
        valid_sampler = BucketingSampler(valid_data, batch_size=args.batch_size)
        valid_loader = AudioDataLoader(valid_data, num_workers=args.num_workers, batch_sampler=valid_sampler)
        valid_loaders.append(valid_loader)

    start_epoch = 0
    metrics = None
    if constant.args.continue_from != "":
        print("Continue from checkpoint:", constant.args.continue_from)
        model, opt, epoch, metrics, loaded_args, label2id, id2label = load_model(
            constant.args.continue_from)
        start_epoch = (epoch-1)  # index starts from zero
        verbose = constant.args.verbose
    else:
        if constant.args.model == "TRFS":
            model = init_transformer_model(constant.args, label2id, id2label)
Пример #6
0
    if loaded_args.parallel:
        print("unwrap data parallel")
        model = model.module

    audio_conf = dict(sample_rate=loaded_args.sample_rate,
                      window_size=loaded_args.window_size,
                      window_stride=loaded_args.window_stride,
                      window=loaded_args.window,
                      noise_dir=loaded_args.noise_dir,
                      noise_prob=loaded_args.noise_prob,
                      noise_levels=(loaded_args.noise_min,
                                    loaded_args.noise_max))

    test_data = SpectrogramDataset(
        audio_conf=audio_conf,
        manifest_filepath_list=constant.args.test_manifest_list,
        label2id=label2id,
        normalize=True,
        augment=False)
    test_sampler = BucketingSampler(test_data,
                                    batch_size=constant.args.batch_size)
    test_loader = AudioDataLoader(test_data,
                                  num_workers=args.num_workers,
                                  batch_sampler=test_sampler)

    lm = None
    if constant.args.lm_rescoring:
        lm = LM(constant.args.lm_path)

    print(model)

    evaluate(model, test_loader, lm=lm)
Пример #7
0
def load_data(train_manifest_list,
              valid_manifest_list,
              test_manifest_list,
              batch_size=12,
              num_workers=4,
              labels_path='./labels.json'):
    """Build the training, validation and test data loaders.

    The label vocabulary is read from ``labels_path`` (a JSON list of
    strings that is concatenated into one string of characters) and is
    prefixed with the padding, start-of-sequence and end-of-sequence marker
    characters, which therefore receive ids 0, 1 and 2. Repeated characters
    keep the id of their first occurrence.

    Args:
        train_manifest_list: Manifest paths passed together to a single
            training ``SpectrogramDataset``.
        valid_manifest_list: Manifest paths; one validation loader is
            created per entry.
        test_manifest_list: Manifest paths; one test loader is created per
            entry.
        batch_size: Batch size for the training sampler and the validation
            loaders. Test loaders do not receive it and use whatever default
            ``AudioDataLoader`` applies.
        num_workers: Worker count forwarded to every ``AudioDataLoader``.
        labels_path: Path of the JSON file holding the label characters.

    Returns:
        A tuple ``(train_loader, valid_loader_list, test_loader_list)``.
    """
    audio_conf = dict(sample_rate=16000,
                      window_size=0.02,
                      window_stride=0.01,
                      window='hamming')
    PAD_CHAR = "¶"
    SOS_CHAR = "§"
    EOS_CHAR = "¤"

    # Read the vocabulary as UTF-8 explicitly: the platform default encoding
    # is locale-dependent and would make the label ids machine-specific for
    # non-ASCII label sets.
    with open(labels_path, encoding="utf-8") as label_file:
        labels = ''.join(json.load(label_file))

    # add PAD_CHAR, SOS_CHAR, EOS_CHAR
    labels = PAD_CHAR + SOS_CHAR + EOS_CHAR + labels
    # dict.fromkeys drops duplicate characters while keeping first-seen
    # order, so ids are dense and assigned in order of first appearance.
    label2id = {char: idx for idx, char in enumerate(dict.fromkeys(labels))}

    train_data = SpectrogramDataset(audio_conf,
                                    manifest_filepath_list=train_manifest_list,
                                    label2id=label2id,
                                    normalize=True,
                                    augment=False)
    train_sampler = BucketingSampler(train_data, batch_size=batch_size)
    train_loader = AudioDataLoader(train_data,
                                   num_workers=num_workers,
                                   batch_sampler=train_sampler)

    valid_loader_list, test_loader_list = [], []
    for valid_manifest in valid_manifest_list:
        valid_data = SpectrogramDataset(
            audio_conf,
            manifest_filepath_list=[valid_manifest],
            label2id=label2id,
            normalize=True,
            augment=False)
        valid_loader_list.append(
            AudioDataLoader(valid_data,
                            num_workers=num_workers,
                            batch_size=batch_size))

    for test_manifest in test_manifest_list:
        test_data = SpectrogramDataset(
            audio_conf,
            manifest_filepath_list=[test_manifest],
            label2id=label2id,
            normalize=True,
            augment=False)
        # No batch_size here on purpose: this mirrors the original call.
        test_loader_list.append(
            AudioDataLoader(test_data, num_workers=num_workers))
    print('done !')
    return train_loader, valid_loader_list, test_loader_list
Пример #8
0
    
    print("EPOCH:", epoch)

    audio_conf = dict(sample_rate=loaded_args.sample_rate,
                      window_size=loaded_args.window_size,
                      window_stride=loaded_args.window_stride,
                      window=loaded_args.window,
                      noise_dir=loaded_args.noise_dir,
                      noise_prob=loaded_args.noise_prob,
                      noise_levels=(loaded_args.noise_min, loaded_args.noise_max))

    test_manifest_list = args.test_manifest_list

    print("INPUT TYPE: ", args.input_type)
    if loaded_args.feat == "spectrogram":
        test_data = SpectrogramDataset(vocab, args, audio_conf=audio_conf, manifest_filepath_list=[test_manifest_list[0]], normalize=True, augment=False, input_type=args.input_type)
    elif loaded_args.feat == "logfbank":
        test_data = LogFBankDataset(vocab, args, audio_conf=audio_conf, manifest_filepath_list=[test_manifest_list[0]], normalize=True, augment=False, input_type=args.input_type)
    test_sampler = BucketingSampler(test_data, batch_size=args.k_test)
    test_loader = AudioDataLoader(vocab.PAD_ID, dataset=test_data, num_workers=args.num_workers, batch_sampler=test_sampler)

    print("Parameters: {}(trainable), {}(non-trainable)".format(compute_num_params(model)[0], compute_num_params(model)[1]))

    if not args.cuda:
        model = model.cpu()

    lm = None
    if args.lm_rescoring:
        lm = LM(args.lm_path, args)

    print(">>>>>>>>>", args.tgt_max_len)