def test_k2_speech_recognition_on_the_fly_feature_extraction_with_randomized_smoothing(
    k2_cut_set,
):
    """Randomized smoothing should increase on-the-fly Fbank energies.

    Two datasets compute Fbank features on the fly from the same cuts. The
    second one additionally applies ``RandomizedSmoothing`` to the waveform,
    so for every single-cut batch its features must sum to more than the
    clean ones.
    """
    clean_dataset = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(extractor=Fbank(), )
    )
    smoothing_strategy = OnTheFlyFeatures(
        extractor=Fbank(),
        # Use p=1.0 to ensure that smoothing is applied in this test.
        wave_transforms=[RandomizedSmoothing(sigma=0.5, p=1.0)],
    )
    smoothed_dataset = K2SpeechRecognitionDataset(input_strategy=smoothing_strategy)

    one_cut_sampler = SingleCutSampler(k2_cut_set, shuffle=False, max_cuts=1)
    for cut_ids in one_cut_sampler:
        clean_batch = clean_dataset[cut_ids]
        smoothed_batch = smoothed_dataset[cut_ids]
        energy_gain = (smoothed_batch["inputs"] - clean_batch["inputs"]).sum()
        # Additive noise should cause the energies to go up
        assert energy_gain > 0
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
    """Create one test DataLoader per test CutSet.

    ``self.test_cuts()`` may return a single CutSet or a list of them; the
    result mirrors that shape (a single DataLoader, or a list of DataLoaders
    in the same order). Features are computed on the fly with an 80-bin
    Fbank when ``self.args.on_the_fly_feats`` is set, otherwise the
    precomputed features are used.
    """
    requested = self.test_cuts()
    many = isinstance(requested, list)
    cut_sets = requested if many else [requested]

    loaders = []
    for cuts_test in cut_sets:
        logging.debug("About to create test dataset")
        if self.args.on_the_fly_feats:
            strategy = OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
        else:
            strategy = PrecomputedFeatures()
        dataset = K2SpeechRecognitionDataset(
            input_strategy=strategy,
            return_cuts=True,
        )
        sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
        logging.debug("About to create test dataloader")
        # batch_size=None: the sampler already yields whole batches of cuts.
        loaders.append(
            DataLoader(dataset, batch_size=None, sampler=sampler, num_workers=1)
        )

    return loaders if many else loaders[0]
def test_k2_speech_recognition_on_the_fly_feature_extraction(
        k2_cut_set, use_batch_extract, fault_tolerant):
    """On-the-fly 40-bin Fbank features should closely match precomputed ones.

    For every single-cut batch, the relative L2 difference between the
    precomputed and on-the-fly feature tensors must stay below 1%. The
    supervision frame boundaries must be identical.
    """
    precomputed_dataset = K2SpeechRecognitionDataset()
    otf_strategy = OnTheFlyFeatures(
        Fbank(FbankConfig(num_mel_bins=40)),
        use_batch_extract=use_batch_extract,
        fault_tolerant=fault_tolerant,
    )
    on_the_fly_dataset = K2SpeechRecognitionDataset(input_strategy=otf_strategy)

    sampler = SimpleCutSampler(k2_cut_set, shuffle=False, max_cuts=1)
    for cut_ids in sampler:
        reference = precomputed_dataset[cut_ids]
        computed = on_the_fly_dataset[cut_ids]

        # The precomputed and on-the-fly features are different due to mixing in time/fbank domains
        # and lilcom compression, so only require them to be close.
        reference_norm = torch.linalg.norm(reference["inputs"])
        difference_norm = torch.linalg.norm(reference["inputs"] - computed["inputs"])
        assert difference_norm < 0.01 * reference_norm

        # Supervision boundaries must agree exactly.
        for field in ("start_frame", "num_frames"):
            assert (reference["supervisions"][field] == computed["supervisions"][field]).all()
def valid_dataloaders(self) -> DataLoader:
    """Build the validation (dev) DataLoader.

    When ``self.args.on_the_fly_feats`` is set, the precomputed features are
    dropped from the dev cuts and 80-bin Fbank features are computed on the
    fly. Otherwise the precomputed features are used.

    Returns:
        A DataLoader over ``K2SpeechRecognitionDataset``. ``batch_size=None``
        is used because the sampler already yields batches of cuts.
    """
    logging.info("About to get dev cuts")
    cuts_valid = self.valid_cuts()

    logging.info("About to create dev dataset")
    if self.args.on_the_fly_feats:
        # Drop the precomputed features exactly once; the same feature-free
        # CutSet feeds both the dataset and the sampler below.
        cuts_valid = cuts_valid.drop_features()
        validate = K2SpeechRecognitionDataset(
            cuts_valid,
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        )
    else:
        validate = K2SpeechRecognitionDataset(cuts_valid)

    valid_sampler = SingleCutSampler(
        cuts_valid,
        max_duration=self.args.max_duration,
    )
    logging.info("About to create dev dataloader")
    valid_dl = DataLoader(
        validate,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=2,
        persistent_workers=True,
    )
    return valid_dl
def valid_dataloaders(self) -> DataLoader:
    """Build the validation (dev) data loader.

    Validates the CLI arguments and loads the dev cuts. When
    ``self.args.concatenate_cuts`` is set, a ``CutConcatenate`` cut transform
    is applied. Features are computed on the fly with an 80-bin Fbank
    (``num_workers=8`` for the input strategy) when
    ``self.args.on_the_fly_feats`` is set, otherwise precomputed features are
    used. Cuts are batched by a non-shuffling ``SingleCutSampler`` and served
    through ``LhotseDataLoader`` with 2 workers.
    """
    self.validate_args()
    logging.info("About to get dev cuts")
    cuts_valid = self.valid_cuts()

    if self.args.concatenate_cuts:
        cut_transforms = [
            CutConcatenate(duration_factor=self.args.duration_factor,
                           gap=self.args.gap)
        ]
    else:
        cut_transforms = []

    logging.info("About to create dev dataset")
    dataset_kwargs = {
        "cut_transforms": cut_transforms,
        "return_cuts": self.args.return_cuts,
    }
    if self.args.on_the_fly_feats:
        dataset_kwargs["input_strategy"] = OnTheFlyFeatures(
            Fbank(FbankConfig(num_mel_bins=80)), num_workers=8)
    validate = K2SpeechRecognitionDataset(**dataset_kwargs)

    valid_sampler = SingleCutSampler(
        cuts_valid,
        max_duration=self.args.max_duration,
        shuffle=False,
    )
    logging.info("About to create dev dataloader")
    return LhotseDataLoader(validate, sampler=valid_sampler, num_workers=2)
def test_k2_speech_recognition_on_the_fly_feature_extraction(k2_cut_set):
    """On-the-fly Fbank features should closely match precomputed ones.

    Uses the API where the CutSet is passed to the dataset constructor. The
    on-the-fly dataset receives the cuts with their precomputed features
    dropped. For every single-cut batch, the relative L2 difference between
    the two feature tensors must stay below 1%, and supervision boundaries
    must match exactly.
    """
    reference_dataset = K2SpeechRecognitionDataset(k2_cut_set)
    feature_free_cuts = k2_cut_set.drop_features()
    computed_dataset = K2SpeechRecognitionDataset(
        feature_free_cuts, input_strategy=OnTheFlyFeatures(Fbank())
    )

    sampler = SingleCutSampler(k2_cut_set, shuffle=False, max_cuts=1)
    for cut_ids in sampler:
        expected = reference_dataset[cut_ids]
        actual = computed_dataset[cut_ids]

        # The precomputed and on-the-fly features are different due to mixing in time/fbank domains
        # and lilcom compression, so only require them to be close.
        expected_norm = torch.linalg.norm(expected['inputs'])
        gap_norm = torch.linalg.norm(expected['inputs'] - actual['inputs'])
        assert gap_norm < 0.01 * expected_norm

        # Supervision boundaries must agree exactly.
        for field in ('start_frame', 'num_frames'):
            assert (expected['supervisions'][field] == actual['supervisions'][field]).all()
def main():
    """Train a Transformer/Conformer CTC(+attention) model on LibriSpeech cuts.

    Steps:
    1. Parse the CLI options.
    2. Load the phone/word symbol tables and the inverted lexicon. The
       lexicon is read from ``Linv.pt`` if that file exists; otherwise it is
       built from ``L.fst.txt`` and cached to ``Linv.pt``.
    3. Build the train/dev Lhotse datasets. Training uses MUSAN noise mixing
       via ``CutMix``, optional ``CutConcatenate``, and SpecAugment.
    4. Run the Noam-scheduled training loop, saving a checkpoint every epoch
       and tracking the best epoch by validation objective.

    The process exits early (``sys.exit(-1)``) if no CUDA device is available.
    """
    args = get_parser().parse_args()

    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_duration = args.max_duration
    accum_grad = args.accum_grad
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.tensorboard else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        L_inv = k2.arc_sort(L.invert_())
        # Cache the inverted, arc-sorted lexicon for subsequent runs.
        torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = load_manifest(feature_dir / 'cuts_train-clean-100.json.gz')
    if args.full_libri:
        cuts_train = (
            cuts_train +
            load_manifest(feature_dir / 'cuts_train-clean-360.json.gz') +
            load_manifest(feature_dir / 'cuts_train-other-500.json.gz'))
    logging.info("About to get dev cuts")
    cuts_dev = (load_manifest(feature_dir / 'cuts_dev-clean.json.gz') +
                load_manifest(feature_dir / 'cuts_dev-other.json.gz'))
    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor {args.duration_factor} and gap {args.gap}.'
        )
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)
        ] + transforms
    input_transforms = [
        SpecAugment(num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100)
    ]
    # Build the train dataset exactly once: either from precomputed features or
    # with on-the-fly feature extraction (previously the precomputed variant was
    # always built and then thrown away in the on-the-fly case).
    if args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        from lhotse.features.fbank import FbankConfig
        train = K2SpeechRecognitionDataset(
            cuts=cuts_train,
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))),
            input_transforms=input_transforms)
    else:
        train = K2SpeechRecognitionDataset(cuts_train,
                                           cut_transforms=transforms,
                                           input_transforms=input_transforms)

    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_duration=max_duration,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=max_duration,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
    )

    logging.info("About to create dev dataset")
    if args.on_the_fly_feats:
        # Drop the precomputed dev features once; the same CutSet is used by
        # both the dataset and the sampler.
        cuts_dev = cuts_dev.drop_features()
        validate = K2SpeechRecognitionDataset(
            cuts_dev,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
    else:
        validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(
        cuts_dev,
        max_duration=max_duration,
    )
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)

    # The attention decoder is only built when its loss contributes (att_rate != 0).
    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    model.to(device)
    describe(model)

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        # Resume from the checkpoint written at the end of the previous epoch.
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    # Flush any pending TensorBoard events and release the event file.
    if tb_writer is not None:
        tb_writer.close()

    logging.warning('Done')
def train_dataloaders(self) -> DataLoader:
    """Build the training data loader.

    Cut transforms:
    - MUSAN noise is always mixed in via ``CutMix`` (prob 0.5, SNR 10-20).
    - ``CutConcatenate`` is prepended when ``self.args.concatenate_cuts`` is
      set.

    SpecAugment is always applied to the features. When
    ``self.args.on_the_fly_feats`` is set, precomputed features are dropped
    and 80-bin Fbank features are computed on the fly.

    The sampler is a ``BucketingSampler`` or a ``SingleCutSampler`` depending
    on ``self.args.bucketing_sampler``.
    """
    logging.info("About to get train cuts")
    cuts_train = self.train_cuts()

    logging.info("About to get Musan cuts")
    musan_cuts = load_manifest(self.args.feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    cut_transforms = [CutMix(cuts=musan_cuts, prob=0.5, snr=(10, 20))]
    if self.args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor '
            f'{self.args.duration_factor} and gap {self.args.gap}.')
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        cut_transforms.insert(
            0,
            CutConcatenate(duration_factor=self.args.duration_factor,
                           gap=self.args.gap))

    spec_augment = [
        SpecAugment(num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100)
    ]

    train = K2SpeechRecognitionDataset(
        cut_transforms=cut_transforms,
        input_transforms=spec_augment,
        return_cuts=True,
    )

    if self.args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        train = K2SpeechRecognitionDataset(
            cut_transforms=cut_transforms,
            input_strategy=OnTheFlyFeatures(
                Fbank(FbankConfig(num_mel_bins=80))),
            input_transforms=spec_augment,
            return_cuts=True,
        )

    use_buckets = self.args.bucketing_sampler
    if use_buckets:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
            num_buckets=self.args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
        )

    logging.info("About to create train dataloader")
    return DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
        persistent_workers=True,
    )
def main():
    """Decode test-clean / test-other with a trained model and an HLG graph.

    Steps:
    1. Load (or average) the trained checkpoint(s).
    2. Load ``HLG.pt``, or compile it from ``L_disambig.fst.txt``,
       ``G.fst.txt`` and a CTC topology and then cache it.
    3. For each test set, compute 80-bin Fbank features on the fly, decode,
       store the transcripts, and log a Kaldi-style WER line.
    """
    args = get_parser().parse_args()

    model_type = args.model_type
    epoch = args.epoch
    max_duration = args.max_duration
    avg = args.avg
    att_rate = args.att_rate

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    # Prepend label 0 to the phone set for the CTC topology (0 is the blank).
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    # With avg == 1 load the last checkpoint (epoch - 1); otherwise average the
    # `avg` checkpoints epoch-avg .. epoch-1.
    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        # Cache the compiled graph for subsequent runs.
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    # The feature pipeline is the same for every test set, so import it once
    # instead of on every loop iteration.
    from lhotse import Fbank, FbankConfig
    from lhotse.dataset.input_strategies import OnTheFlyFeatures

    # load dataset
    feature_dir = Path('exp/data')
    test_sets = ['test-clean', 'test-other']
    for test_set in test_sets:
        logging.info(f'* DECODING: {test_set}')

        logging.debug("About to get test cuts")
        cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
        logging.debug("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            cuts_test,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
        sampler = SingleCutSampler(cuts_test, max_duration=max_duration)
        logging.debug("About to create test dataloader")
        test_dl = torch.utils.data.DataLoader(test,
                                              batch_size=None,
                                              sampler=sampler,
                                              num_workers=1)

        logging.debug("About to decode")
        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table)
        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')

        # compute WER
        dists = [edit_distance(r, h) for r, h in results]
        errors = {
            key: sum(dist[key] for dist in dists)
            for key in ['sub', 'ins', 'del', 'total']
        }
        total_words = sum(len(ref) for ref, _ in results)
        if total_words == 0:
            # Avoid ZeroDivisionError on an empty / reference-less test set;
            # keep decoding the remaining sets instead of crashing.
            logging.warning(
                f'[{test_set}] No reference words found; skipping WER computation'
            )
            continue
        # Print Kaldi-like message:
        # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
        logging.info(
            f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
            f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
        )