# Imports assumed for the SpecAugment tests in this excerpt (module paths may differ
# slightly between lhotse versions). The training-script code further below comes from a
# separate file and needs additional imports (k2, snowfall, etc.) that are not repeated here.
import pytest
import torch

from lhotse import CutSet
from lhotse.dataset import SpecAugment
from lhotse.dataset.collation import collate_features


def test_specaugment_load_state_dict():
    # all values non-default
    config = dict(
        time_warp_factor=85,
        num_feature_masks=2,
        features_mask_size=12,
        num_frame_masks=2,
        frames_mask_size=71,
        max_frames_mask_fraction=0.25,
        p=0.6,
    )
    specaug = SpecAugment()
    specaug.load_state_dict(config)
    for key, value in config.items():
        assert getattr(specaug, key) == value
def test_specaugment_state_dict():
    # all values default
    config = dict(
        time_warp_factor=80,
        num_feature_masks=1,
        features_mask_size=13,
        num_frame_masks=1,
        frames_mask_size=70,
        max_frames_mask_fraction=0.2,
        p=0.5,
    )
    specaug = SpecAugment(**config)
    state_dict = specaug.state_dict()
    for key, value in config.items():
        assert state_dict[key] == value
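# Usage sketch (not part of the original tests): the state_dict()/load_state_dict() pair
# exercised above makes it easy to persist the SpecAugment configuration next to a model
# checkpoint and restore it later. The helper name and checkpoint path are hypothetical.
def _example_roundtrip_specaugment_config(path='specaug_config.pt'):
    specaug = SpecAugment(num_feature_masks=2, features_mask_size=12)
    torch.save({'specaugment': specaug.state_dict()}, path)
    restored = SpecAugment()
    restored.load_state_dict(torch.load(path)['specaugment'])
    assert restored.num_feature_masks == 2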
def test_specaugment_2d_input_raises_error():
    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    feats = torch.from_numpy(cuts[0].load_features())
    tfnm = SpecAugment(p=1.0, time_warp_factor=10)
    with pytest.raises(AssertionError):
        # The call is expected to fail on the bare 2-D input, so the assert below is not reached.
        augmented = tfnm(feats)
        assert (feats != augmented).any()
# Parametrization decorators restored so the test is runnable; the exact value grids are assumed.
@pytest.mark.parametrize('num_feature_masks', [0, 1, 2])
@pytest.mark.parametrize('num_frame_masks', [1, 2, 3])
def test_specaugment_batch(num_feature_masks, num_frame_masks):
    cuts = CutSet.from_json('test/fixtures/ljspeech/cuts.json')
    feats, feat_lens = collate_features(cuts)
    tfnm = SpecAugment(p=1.0,
                       time_warp_factor=10,
                       features_mask_size=5,
                       frames_mask_size=20,
                       num_feature_masks=num_feature_masks,
                       num_frame_masks=num_frame_masks)
    augmented = tfnm(feats)
    assert (feats != augmented).any()
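# For reference, a sketch of the shapes collate_features produces for the fixture above,
# assuming lhotse's standard collation behavior of padding every cut to the longest one:
def _example_collate_feature_shapes():
    cuts = CutSet.from_json('test/fixtures/ljspeech/cuts.json')
    feats, feat_lens = collate_features(cuts)
    assert feats.dim() == 3                 # (num_cuts, max_num_frames, num_features)
    assert feat_lens.shape[0] == len(cuts)  # original (unpadded) frame count per cut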
def main():
    args = get_parser().parse_args()

    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_duration = args.max_duration
    accum_grad = args.accum_grad
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.tensorboard else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = load_manifest(feature_dir / 'cuts_train-clean-100.json.gz')
    if args.full_libri:
        cuts_train = (
            cuts_train +
            load_manifest(feature_dir / 'cuts_train-clean-360.json.gz') +
            load_manifest(feature_dir / 'cuts_train-other-500.json.gz'))

    logging.info("About to get dev cuts")
    cuts_dev = (load_manifest(feature_dir / 'cuts_dev-clean.json.gz') +
                load_manifest(feature_dir / 'cuts_dev-other.json.gz'))
    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor '
            f'{args.duration_factor} and gap {args.gap}.')
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)
        ] + transforms

    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=transforms,
                                       input_transforms=[
                                           SpecAugment(num_frame_masks=2,
                                                       features_mask_size=27,
                                                       num_feature_masks=2,
                                                       frames_mask_size=100)
                                       ])

    if args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        from lhotse.features.fbank import FbankConfig
        train = K2SpeechRecognitionDataset(
            cuts=cuts_train,
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))),
            input_transforms=[
                SpecAugment(num_frame_masks=2,
                            features_mask_size=27,
                            num_feature_masks=2,
                            frames_mask_size=100)
            ])

    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_duration=max_duration,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=max_duration,
            shuffle=True,
        )

    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
    )

    logging.info("About to create dev dataset")
    if args.on_the_fly_feats:
        cuts_dev = cuts_dev.drop_features()
        validate = K2SpeechRecognitionDataset(
            cuts_dev,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
    else:
        validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(
        cuts_dev,
        max_duration=max_duration,
    )
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    model.to(device)
    describe(model)

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
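# main() relies on a get_parser() helper that is not shown in this excerpt. Below is a minimal
# sketch of the parser it would need, inferred from the attributes read off `args` above; the
# defaults, types and help strings are assumptions, not the original script's values.
import argparse


def get_parser():
    parser = argparse.ArgumentParser(
        description='CTC/attention transformer training (sketch)')
    parser.add_argument('--model-type', default='conformer',
                        help='"transformer" or "conformer"')
    parser.add_argument('--start-epoch', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=10)
    parser.add_argument('--max-duration', type=int, default=200,
                        help='max total speech duration (seconds) per batch')
    parser.add_argument('--accum-grad', type=int, default=1)
    parser.add_argument('--att-rate', type=float, default=0.0,
                        help='weight of the attention-decoder loss; 0 disables the decoder')
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--attention-dim', type=int, default=256)
    parser.add_argument('--warm-step', type=int, default=25000)
    parser.add_argument('--tensorboard', action='store_true')
    parser.add_argument('--full-libri', action='store_true')
    parser.add_argument('--on-the-fly-feats', action='store_true')
    parser.add_argument('--concatenate-cuts', action='store_true')
    parser.add_argument('--duration-factor', type=float, default=1.0)
    parser.add_argument('--gap', type=float, default=1.0)
    parser.add_argument('--bucketing-sampler', action='store_true')
    parser.add_argument('--num-buckets', type=int, default=30)
    return parser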
def train_dataloaders(self) -> DataLoader:
    logging.info("About to get train cuts")
    cuts_train = self.train_cuts()

    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(self.args.feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if self.args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor '
            f'{self.args.duration_factor} and gap {self.args.gap}.')
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=self.args.duration_factor,
                           gap=self.args.gap)
        ] + transforms

    input_transforms = [
        SpecAugment(num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100)
    ]

    train = K2SpeechRecognitionDataset(
        cut_transforms=transforms,
        input_transforms=input_transforms,
        return_cuts=True,
    )

    if self.args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(
                Fbank(FbankConfig(num_mel_bins=80))),
            input_transforms=input_transforms,
            return_cuts=True,
        )

    if self.args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
            num_buckets=self.args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
        )

    logging.info("About to create train dataloader")
    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
        persistent_workers=True,
    )

    return train_dl
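# Usage sketch for the data-module method above. The owning class name and the way its `args`
# namespace is constructed are assumptions for illustration; the batch keys ('inputs',
# 'supervisions') are the ones produced by lhotse's K2SpeechRecognitionDataset.
def _example_inspect_one_training_batch(args):
    datamodule = LibriSpeechAsrDataModule(args)  # hypothetical class owning train_dataloaders()
    train_dl = datamodule.train_dataloaders()
    batch = next(iter(train_dl))
    feats = batch['inputs']               # (batch, time, num_mel_bins) feature tensor
    supervisions = batch['supervisions']  # per-utterance frame offsets, lengths and transcripts
    print(feats.shape, sorted(supervisions.keys()))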
def test_specaugment_single():
    cuts = CutSet.from_json('test/fixtures/ljspeech/cuts.json')
    # Add a leading batch dimension: as tested above, SpecAugment rejects bare 2-D input.
    feats = torch.from_numpy(cuts[0].load_features()).unsqueeze(0)
    tfnm = SpecAugment(p=1.0, time_warp_factor=10)
    augmented = tfnm(feats)
    assert (feats != augmented).any()