# Put the PAD/SOS/EOS characters in front of the vocabulary so that they are
# the first symbols to be numbered below.
labels = constant.PAD_CHAR + constant.SOS_CHAR + constant.EOS_CHAR + labels

# Two-way symbol <-> id lookup. Ids are handed out in first-occurrence order;
# a symbol seen again is reported on stdout and keeps its original id.
label2id, id2label = {}, {}
count = 0
for symbol in labels:
    if symbol in label2id:
        print("multiple label: ", symbol)
        continue
    label2id[symbol] = count
    id2label[count] = symbol
    count += 1

# Training data: every training manifest goes into one dataset, batched by a
# bucketing sampler; augmentation follows the command-line flag.
train_data = SpectrogramDataset(
    audio_conf,
    manifest_filepath_list=args.train_manifest_list,
    label2id=label2id,
    normalize=True,
    augment=args.augment,
)
train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
train_loader = AudioDataLoader(
    train_data,
    num_workers=args.num_workers,
    batch_sampler=train_sampler,
)

valid_loader_list, test_loader_list = [], []

# One loader per validation manifest, never augmented.
for valid_manifest in args.valid_manifest_list:
    valid_data = SpectrogramDataset(
        audio_conf,
        manifest_filepath_list=[valid_manifest],
        label2id=label2id,
        normalize=True,
        augment=False,
    )
    valid_loader = AudioDataLoader(
        valid_data,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
    )
    valid_loader_list.append(valid_loader)

# One loader per test manifest. No batch size is passed here, so the loader's
# own default applies.
for test_manifest in args.test_manifest_list:
    test_data = SpectrogramDataset(
        audio_conf,
        manifest_filepath_list=[test_manifest],
        label2id=label2id,
        normalize=True,
        augment=False,
    )
    test_loader = AudioDataLoader(test_data, num_workers=args.num_workers)
    test_loader_list.append(test_loader)
args, audio_conf, manifest_filepath_list=[args.valid_manifest_list[i]], normalize=True, augment=args.augment, input_type=args.input_type) elif args.feat == "logfbank": valid_data = LogFBankDataset( vocab, args, audio_conf, manifest_filepath_list=[args.valid_manifest_list[i]], normalize=True, augment=False, input_type=args.input_type) valid_sampler = BucketingSampler(valid_data, batch_size=args.k_train) valid_loader = AudioDataLoader(pad_token_id=vocab.PAD_ID, dataset=valid_data, num_workers=args.num_workers) valid_loader_list.append(valid_loader) start_epoch = 0 metrics = None loaded_args = None if args.continue_from != "": logging.info("Continue from checkpoint:" + args.continue_from) model, vocab, opt, epoch, metrics, loaded_args = load_joint_model( args.continue_from) start_epoch = (epoch) # index starts from zero verbose = args.verbose else:
def load_data(train_manifest_list, valid_manifest_list, test_manifest_list,
              batch_size=12, labels_path='./labels.json', num_workers=4):
    """Build the training, validation and test data loaders.

    The character vocabulary is read from the JSON file at ``labels_path``,
    prefixed with the PAD/SOS/EOS characters, and turned into a
    character -> id mapping that every dataset shares.

    Args:
        train_manifest_list: manifest paths that together form the single
            training dataset.
        valid_manifest_list: manifest paths; one validation loader is built
            per entry.
        test_manifest_list: manifest paths; one test loader is built per
            entry.
        batch_size: batch size for the training sampler and the validation
            loaders.
        labels_path: JSON file holding the label characters (a sequence of
            strings that is concatenated into one vocabulary string).
        num_workers: worker count handed to every ``AudioDataLoader``.

    Returns:
        A tuple ``(train_loader, valid_loader_list, test_loader_list)``.
    """
    audio_conf = dict(sample_rate=16000, window_size=0.02,
                      window_stride=0.01, window='hamming')

    PAD_CHAR = "¶"
    SOS_CHAR = "§"
    EOS_CHAR = "¤"

    # JSON text is UTF-8; decode it explicitly so non-ASCII label characters
    # do not depend on the platform's locale encoding.
    with open(labels_path, encoding="utf-8") as label_file:
        labels = ''.join(json.load(label_file))

    # add PAD_CHAR, SOS_CHAR, EOS_CHAR
    labels = PAD_CHAR + SOS_CHAR + EOS_CHAR + labels

    # Ids are assigned in first-occurrence order; repeated characters keep
    # the id they were given the first time.
    label2id, id2label = {}, {}
    for char in labels:
        if char not in label2id:
            next_id = len(label2id)
            label2id[char] = next_id
            id2label[next_id] = char

    train_data = SpectrogramDataset(
        audio_conf, manifest_filepath_list=train_manifest_list,
        label2id=label2id, normalize=True, augment=False)
    train_sampler = BucketingSampler(train_data, batch_size=batch_size)
    train_loader = AudioDataLoader(
        train_data, num_workers=num_workers, batch_sampler=train_sampler)

    valid_loader_list, test_loader_list = [], []

    for valid_manifest in valid_manifest_list:
        valid_data = SpectrogramDataset(
            audio_conf, manifest_filepath_list=[valid_manifest],
            label2id=label2id, normalize=True, augment=False)
        valid_loader_list.append(AudioDataLoader(
            valid_data, num_workers=num_workers, batch_size=batch_size))

    # Test loaders are created without an explicit batch size, so the
    # loader's own default applies.
    for test_manifest in test_manifest_list:
        test_data = SpectrogramDataset(
            audio_conf, manifest_filepath_list=[test_manifest],
            label2id=label2id, normalize=True, augment=False)
        test_loader_list.append(
            AudioDataLoader(test_data, num_workers=num_workers))

    print('done !')
    return train_loader, valid_loader_list, test_loader_list
# Unwrap the inner module from its wrapper.
# NOTE(review): in the original layout this statement may sit under a
# condition that is outside this view - confirm before relying on it running
# unconditionally.
model = model.module

# Audio front-end settings are taken from the arguments stored with the
# checkpoint rather than from the current command line.
audio_conf = {
    "sample_rate": loaded_args.sample_rate,
    "window_size": loaded_args.window_size,
    "window_stride": loaded_args.window_stride,
    "window": loaded_args.window,
    "noise_dir": loaded_args.noise_dir,
    "noise_prob": loaded_args.noise_prob,
    "noise_levels": (loaded_args.noise_min, loaded_args.noise_max),
}

# All test manifests go into a single, un-augmented dataset.
test_data = SpectrogramDataset(
    audio_conf=audio_conf,
    manifest_filepath_list=constant.args.test_manifest_list,
    label2id=label2id,
    normalize=True,
    augment=False,
)
test_sampler = BucketingSampler(
    test_data,
    batch_size=constant.args.batch_size,
)
# NOTE(review): `args` is used here while the rest of this section reads
# `constant.args` - presumably the same object; verify.
test_loader = AudioDataLoader(
    test_data,
    num_workers=args.num_workers,
    batch_sampler=test_sampler,
)

# A language model is only created when rescoring was requested.
lm = LM(constant.args.lm_path) if constant.args.lm_rescoring else None

print(model)
evaluate(model, test_loader, lm=lm)
# Train.
trainer = Trainer()
trainer.train(
    model,
    vocab,
    train_loader,
    valid_loader_list,
    loss_type,
    start_epoch,
    num_epochs,
    args,
    last_metrics=metrics,
    evaluate_every=args.evaluate_every,
    early_stop=args.early_stop,
    opt_name=args.opt_name,
)

# Test: evaluate each test manifest separately and collect its CER / WER.
logging.info("Test")
print("Test")

test_manifest_list = args.test_manifest_list
args.tgt_max_len = 150

cer_list, wer_list = [], []
for test_manifest in test_manifest_list:
    # The dataset class follows the feature type recorded in `loaded_args`.
    # NOTE(review): this reads `loaded_args.feat`, not `args.feat`, so it
    # presumably requires `loaded_args` to be populated - confirm against the
    # code that sets it.
    if loaded_args.feat == "spectrogram":
        test_data = SpectrogramDataset(
            vocab,
            args,
            audio_conf=audio_conf,
            manifest_filepath_list=[test_manifest],
            normalize=True,
            augment=False,
            input_type=args.input_type,
        )
    elif loaded_args.feat == "logfbank":
        test_data = LogFBankDataset(
            vocab,
            args,
            audio_conf=audio_conf,
            manifest_filepath_list=[test_manifest],
            normalize=True,
            augment=False,
            input_type=args.input_type,
        )

    test_sampler = BucketingSampler(test_data, batch_size=args.k_test)
    test_loader = AudioDataLoader(
        vocab.PAD_ID,
        dataset=test_data,
        num_workers=args.num_workers,
        batch_sampler=test_sampler,
    )

    # NOTE(review): the language model is rebuilt for every manifest; it
    # looks loop-invariant, but hoisting it is left for a behavior-changing
    # edit.
    lm = LM(args.lm_path, args) if args.lm_rescoring else None

    cer, wer = evaluate(
        model,
        vocab,
        test_loader,
        args,
        lm=lm,
        start_token=vocab.SOS_ID,
    )
    cer_list.append(cer)
    wer_list.append(wer)

# Report the per-dataset results on stdout and in the log.
separator = "=" * 50
print(separator)
logging.info(separator)
for index, (cer, wer) in enumerate(zip(cer_list, wer_list)):
    summary = "TEST DATASET [{}] CER:{:.4f} WER:{:.4f}".format(index, cer, wer)
    print(summary)
    logging.info(summary)