def run(rank, world_size, args):
    """Train an MMI model (optionally with an attention decoder) on one GPU.

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        It is also used as the CUDA device index. The node with
        rank 0 is responsible for saving checkpoint and is the only
        one that may create a TensorBoard writer.
      world_size:
        Number of GPUs for DDP training. Distributed setup/teardown
        only happens when it is greater than 1.
      args:
        The return value of get_parser().parse_args()
    """
    model_type = args.model_type
    start_epoch, num_epochs = args.start_epoch, args.num_epochs
    accum_grad = args.accum_grad
    den_scale, att_rate = args.den_scale, args.att_rate
    use_pruned_intersect = args.use_pruned_intersect

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, args.master_port)

    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')
    # TensorBoard summaries are written by rank 0 only, and only on request.
    tb_writer = (SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
                 if args.tensorboard and rank == 0 else None)

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device = torch.device('cuda', rank)

    # The phone LM P is cached as P.pt; build it from P.fst.txt on first use.
    p_cache = lang_dir / 'P.pt'
    if p_cache.is_file():
        logging.debug('Loading pre-compiled P')
        P = k2.Fsa.from_dict(torch.load(p_cache))
    else:
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as p_fst:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(p_fst.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), p_cache)

    graph_compiler = MmiTrainingGraphCompiler(lexicon=lexicon,
                                              P=P,
                                              device=device)
    phone_ids = lexicon.phone_symbols()

    # ---- data ----
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train_transforms = [
        CutConcatenate(),
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20)),
    ]
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=train_transforms)
    train_sampler = SingleCutSampler(cuts_train,
                                     max_frames=90000,
                                     shuffle=True)
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)

    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    if use_pruned_intersect:
        logging.info('Use pruned intersect for den_lats')
    else:
        logging.info("Don't use pruned intersect for den_lats")

    # ---- model ----
    logging.info("About to create model")
    # The attention decoder is only built when the attention loss is used.
    num_decoder_layers = 6 if att_rate != 0.0 else 0
    num_classes = len(phone_ids) + 1  # +1 for the blank symbol

    if model_type in ("transformer", "conformer"):
        encoder_kwargs = dict(num_features=40,
                              nhead=args.nhead,
                              d_model=args.attention_dim,
                              num_classes=num_classes,
                              subsampling_factor=4,
                              num_decoder_layers=num_decoder_layers,
                              vgg_frontend=True)
        if model_type == "transformer":
            model = Transformer(**encoder_kwargs)
        else:
            model = Conformer(is_espnet_structure=True, **encoder_kwargs)
    elif model_type == "contextnet":
        model = ContextNet(num_features=40, num_classes=num_classes)
    else:
        raise NotImplementedError(
            f"Model of type {model_type} is not implemented")

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)

    model.to(device)
    describe(model)

    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(num_features=40,
                               num_classes=num_classes,
                               subsampling_factor=4)

        ali_model_fname = Path(
            f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
            f'ali model filename {ali_model_fname} does not exist!'
        ali_state = torch.load(ali_model_fname, map_location='cpu')
        ali_model.load_state_dict(ali_state['state_dict'])
        ali_model.to(device)

        # The alignment model is frozen and used for inference only.
        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    # ---- bookkeeping / resume ----
    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        resume_path = os.path.join(exp_dir, f'epoch-{start_epoch - 1}.pt')
        ckpt = load_checkpoint(filename=resume_path,
                               model=model,
                               optimizer=optimizer,
                               scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    # ---- training loop ----
    for epoch in range(start_epoch, num_epochs):
        # train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch,
                                 global_batch_idx_train)

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler)

        # Checkpoints are exported as TorchScript from `torchscript_epoch`
        # onwards; -1 disables that.
        as_torchscript = (args.torchscript_epoch != -1
                          and epoch >= args.torchscript_epoch)
        # Fields shared by the "best" and the per-epoch checkpoint.
        ckpt_fields = dict(model=model,
                           epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           valid_objf=valid_objf,
                           global_batch_idx_train=global_batch_idx_train,
                           local_rank=rank,
                           torchscript=as_torchscript)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf, best_objf, best_epoch = valid_objf, objf, epoch
            # The best-model checkpoint stores the model only.
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            **ckpt_fields)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        **ckpt_fields)
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
def export_to_kaldi(
    recordings: RecordingSet,
    supervisions: SupervisionSet,
    output_dir: Pathlike,
    map_underscores_to: Optional[str] = None,
):
    """
    Export a pair of ``RecordingSet`` and ``SupervisionSet`` to a Kaldi data
    directory. It even supports recordings that have multiple channels but
    the recordings will still have to have a single ``AudioSource``.

    The ``RecordingSet`` and ``SupervisionSet`` must be compatible, i.e. it must
    be possible to create a ``CutSet`` out of them.

    The files written are ``wav.scp``, ``segments``, ``reco2dur``, ``text``,
    ``utt2spk`` and ``utt2dur``; ``utt2lang`` and ``utt2gender`` are written
    only when every supervision has a language / gender. When at least one
    recording has more than one channel, recording keys get a
    ``_<channel>`` suffix.

    :param recordings: a ``RecordingSet`` manifest.
    :param supervisions: a ``SupervisionSet`` manifest.
    :param output_dir: path where the Kaldi-style data directory will be created.
    :param map_underscores_to: optional string with which we will replace all
        underscores in supervision ids and speaker names. This helps avoid
        issues with Kaldi data dir sorting. Supervisions whose ``speaker``
        is ``None`` keep ``None``.
    :raises AssertionError: when a recording has more than one audio source.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    assert all(len(r.sources) == 1 for r in recordings), (
        "Kaldi export of Recordings with multiple audio sources "
        "is currently not supported.")

    if map_underscores_to is not None:

        def _map_underscores(s):
            # A supervision may have no speaker; calling ``.replace`` on
            # None would raise AttributeError, so only rewrite real names.
            speaker = s.speaker
            if speaker is not None:
                speaker = speaker.replace("_", map_underscores_to)
            return fastcopy(
                s,
                id=s.id.replace("_", map_underscores_to),
                speaker=speaker,
            )

        supervisions = supervisions.map(_map_underscores)

    # Create a simple CutSet that ties together
    # the recording <-> supervision information.
    cuts = CutSet.from_manifests(
        recordings=recordings,
        supervisions=supervisions).trim_to_supervisions()

    if all(r.num_channels == 1 for r in recordings):
        # if all the recordings are single channel, we won't add
        # the channel id affix to retain back compatibility
        # and the ability to receive back the same utterances after
        # importing the exported directory back

        # wav.scp
        save_kaldi_text_mapping(
            data={
                recording.id: make_wavscp_channel_string_map(
                    source, sampling_rate=recording.sampling_rate)[0]
                for recording in recordings
                for source in recording.sources
            },
            path=output_dir / "wav.scp",
        )
        # segments
        save_kaldi_text_mapping(
            data={
                cut.supervisions[0].id:
                f"{cut.recording_id} {cut.start} {cut.end}"
                for cut in cuts
            },
            path=output_dir / "segments",
        )
        # reco2dur
        save_kaldi_text_mapping(
            data={
                recording.id: recording.duration
                for recording in recordings
            },
            path=output_dir / "reco2dur",
        )
    else:
        # wav.scp
        save_kaldi_text_mapping(
            data={
                f"{recording.id}_{channel}": make_wavscp_channel_string_map(
                    source, sampling_rate=recording.sampling_rate)[channel]
                for recording in recordings
                for source in recording.sources
                for channel in source.channels
            },
            path=output_dir / "wav.scp",
        )
        # segments
        save_kaldi_text_mapping(
            data={
                cut.supervisions[0].id:
                f"{cut.recording_id}_{cut.channel} {cut.start} {cut.end}"
                for cut in cuts
            },
            path=output_dir / "segments",
        )
        # reco2dur
        save_kaldi_text_mapping(
            data={
                f"{recording.id}_{channel}": recording.duration
                for recording in recordings
                for channel in recording.sources[0].channels
            },
            path=output_dir / "reco2dur",
        )

    # text
    save_kaldi_text_mapping(
        data={
            cut.supervisions[0].id: cut.supervisions[0].text
            for cut in cuts
        },
        path=output_dir / "text",
    )
    # utt2spk
    save_kaldi_text_mapping(
        data={
            cut.supervisions[0].id: cut.supervisions[0].speaker
            for cut in cuts
        },
        path=output_dir / "utt2spk",
    )
    # utt2dur
    save_kaldi_text_mapping(
        data={cut.supervisions[0].id: cut.duration for cut in cuts},
        path=output_dir / "utt2dur",
    )
    # utt2lang [optional]
    if all(s.language is not None for s in supervisions):
        save_kaldi_text_mapping(
            data={
                cut.supervisions[0].id: cut.supervisions[0].language
                for cut in cuts
            },
            path=output_dir / "utt2lang",
        )
    # utt2gender [optional]
    if all(s.gender is not None for s in supervisions):
        save_kaldi_text_mapping(
            data={
                cut.supervisions[0].id: cut.supervisions[0].gender
                for cut in cuts
            },
            path=output_dir / "utt2gender",
        )
def cut_set(cut):
    """Build a two-element ``CutSet``: ``cut`` itself and a padded variant.

    The second element is ``cut`` appended to itself and then padded to 3.0;
    the padding tests if feature extraction works correctly with a PaddingCut.
    """
    doubled_and_padded = cut.append(cut).pad(3.0)
    return CutSet.from_cuts([cut, doubled_and_padded])
def prepare_gigaspeech(
    corpus_dir: Pathlike,
    output_dir: Optional[Pathlike],
    dataset_parts: Union[str, Sequence[str]] = "auto",
    num_jobs: int = 1,
) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
    """
    Prepare recording, supervision and cut manifests for GigaSpeech.

    Manifests are streamed to ``output_dir`` as
    ``gigaspeech_{recordings,supervisions,cuts}_<part>.jsonl.gz``. A part
    whose manifests already exist there is not processed again.

    :param corpus_dir: path to the GigaSpeech corpus, handed to
        ``speechcolab``'s ``GigaSpeech``.
    :param output_dir: directory where the manifests are written. It is
        created if missing. NOTE(review): despite the ``Optional``
        annotation the value is passed straight to ``Path(...)``, so an
        actual path is required.
    :param dataset_parts: ``"auto"`` (meaning ``XL``, ``DEV`` and ``TEST``),
        a single part name, or a sequence of part names.
    :param num_jobs: forwarded to ``parallel_map`` when parsing utterances.
    :return: a dict keyed by part name. Parts processed in this call map to
        a dict with the keys ``"recordings"``, ``"supervisions"`` and
        ``"cuts"`` holding lazily opened manifests; previously prepared
        parts hold whatever ``read_manifests_if_cached`` found for them.
    :raises ImportError: when the optional ``speechcolab`` package is missing.
    """
    if is_module_available("speechcolab"):
        from speechcolab.datasets.gigaspeech import GigaSpeech
    else:
        raise ImportError(
            "To process the GigaSpeech corpus, please install optional dependency: pip install speechcolab"
        )

    subsets = ("XL", "DEV", "TEST") if dataset_parts == "auto" else dataset_parts
    if isinstance(subsets, str):
        subsets = [subsets]

    corpus_dir = Path(corpus_dir)
    gigaspeech = GigaSpeech(corpus_dir)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Maybe some manifests already exist: we can read them and save a bit of preparation time.
    # Look them up by the normalised part names: the raw ``dataset_parts``
    # may be "auto" or a bare string, which would not name any real part,
    # and the cached parts skipped below would be missing from the result.
    manifests = read_manifests_if_cached(
        dataset_parts=subsets,
        output_dir=output_dir,
        prefix="gigaspeech",
        suffix="jsonl.gz",
        lazy=True,
    )

    for part in subsets:
        logging.info(f"Processing GigaSpeech subset: {part}")
        if manifests_exist(
            part=part, output_dir=output_dir, prefix="gigaspeech", suffix="jsonl.gz"
        ):
            logging.info(f"GigaSpeech subset: {part} already prepared - skipping.")
            continue

        with RecordingSet.open_writer(
            output_dir / f"gigaspeech_recordings_{part}.jsonl.gz"
        ) as rec_writer, SupervisionSet.open_writer(
            output_dir / f"gigaspeech_supervisions_{part}.jsonl.gz"
        ) as sup_writer, CutSet.open_writer(
            output_dir / f"gigaspeech_cuts_{part}.jsonl.gz"
        ) as cut_writer:
            for recording, segments in tqdm(
                parallel_map(
                    parse_utterance,
                    gigaspeech.audios("{" + part + "}"),
                    repeat(gigaspeech.gigaspeech_dataset_dir),
                    num_jobs=num_jobs,
                ),
                desc="Processing GigaSpeech JSON entries",
            ):
                # Fix and validate the recording + supervisions
                recordings, segments = fix_manifests(
                    recordings=RecordingSet.from_recordings([recording]),
                    supervisions=SupervisionSet.from_segments(segments),
                )
                validate_recordings_and_supervisions(
                    recordings=recordings, supervisions=segments
                )
                # Create the cut since most users will need it anyway.
                # There will be exactly one cut since there's exactly one recording.
                cuts = CutSet.from_manifests(
                    recordings=recordings, supervisions=segments
                )
                # Write the manifests
                rec_writer.write(recordings[0])
                for s in segments:
                    sup_writer.write(s)
                cut_writer.write(cuts[0])

        # Re-open what was just written lazily, so nothing is kept in memory.
        manifests[part] = {
            "recordings": RecordingSet.from_jsonl_lazy(rec_writer.path),
            "supervisions": SupervisionSet.from_jsonl_lazy(sup_writer.path),
            "cuts": CutSet.from_jsonl_lazy(cut_writer.path),
        }

    return dict(manifests)
def test_collate_custom_temporal_array_ints(pad_direction):
    """Check ``collate_custom_field`` on per-cut integer temporal arrays.

    Random int16 codebook indices are stored for every cut, collated with the
    given ``pad_direction`` ("right", "left" or "both"), and the resulting
    tensor, its lengths and the position/value of the padding are verified.
    """
    CODEBOOK_SIZE = 512
    FRAME_SHIFT = 0.04
    EXPECTED_PAD_VALUE = 0

    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    frame_counts = [
        seconds_to_frames(cut.duration, FRAME_SHIFT) for cut in cuts
    ]
    max_num_frames = max(frame_counts)

    with NamedTemporaryFile(suffix=".h5") as tmp, NumpyHdf5Writer(
            tmp.name) as writer:
        # Attach a random index array to every cut, remembering what was
        # stored so the collated batch can be compared against it.
        expected_codebook_indices = []
        for cut, num_frames in zip(cuts, frame_counts):
            indices = np.random.randint(CODEBOOK_SIZE,
                                        size=(num_frames, )).astype(np.int16)
            expected_codebook_indices.append(indices)
            cut.codebook_indices = writer.store_array(
                cut.id,
                indices,
                frame_shift=FRAME_SHIFT,
                temporal_dim=0,
            )

        codebook_indices, codebook_indices_lens = collate_custom_field(
            cuts, "codebook_indices", pad_direction=pad_direction)

        assert isinstance(codebook_indices_lens, torch.Tensor)
        assert codebook_indices_lens.dtype == torch.int32
        assert codebook_indices_lens.shape == (len(cuts), )
        assert codebook_indices_lens.tolist() == frame_counts

        assert isinstance(codebook_indices, torch.Tensor)
        # the dtype got promoted by default
        assert codebook_indices.dtype == torch.int64
        assert codebook_indices.shape == (len(cuts), max_num_frames)

        for idx, expected in enumerate(expected_codebook_indices):
            exp_len = expected.shape[0]
            row = codebook_indices[idx]
            # PyTorch < 1.9.0 doesn't have an assert_equal function.
            if pad_direction == "right":
                np.testing.assert_equal(row[:exp_len].numpy(), expected)
                np.testing.assert_equal(row[exp_len:].numpy(),
                                        EXPECTED_PAD_VALUE)
            elif pad_direction == "left":
                np.testing.assert_equal(row[-exp_len:].numpy(), expected)
                np.testing.assert_equal(row[:-exp_len].numpy(),
                                        EXPECTED_PAD_VALUE)
            elif pad_direction == "both":
                half = (max_num_frames - exp_len) // 2
                np.testing.assert_equal(row[:half].numpy(),
                                        EXPECTED_PAD_VALUE)
                np.testing.assert_equal(row[half:half + exp_len].numpy(),
                                        expected)
                if half > 0:
                    # indexing like [idx, -0:] would return the whole array
                    # rather than an empty slice.
                    np.testing.assert_equal(row[-half:].numpy(),
                                            EXPECTED_PAD_VALUE)
def __call__(self, cuts: CutSet) -> CutSet:
    """Sort ``cuts`` by decreasing duration and mix ``self.cuts`` into them.

    :param cuts: the batch of cuts to augment; it must be non-empty because
        its longest cut is looked up by index.
    :return: the result of ``CutSet.mix`` on the sorted cuts, using
        ``self.snr`` and ``self.prob``.
    """
    # Read the target duration from the *sorted* set so that it is always
    # the duration of the longest cut, whatever order the caller used.
    # NOTE(review): ``duration`` is presumably the length every mixed cut is
    # brought to - confirm against ``CutSet.mix``.
    longest_first = cuts.sort_by_duration(ascending=False)
    return longest_first.mix(
        self.cuts,
        duration=longest_first[0].duration,
        snr=self.snr,
        mix_prob=self.prob,
    )
def main():
    """Train a TDNN-LSTM CTC model with Musan noise augmentation on one GPU.

    Everything is hard-coded: the experiment directory
    ``exp-lstm-adam-ctc-musan``, 8 epochs, AdamW with a constant learning
    rate of 1e-3. A checkpoint and an info file are written after every
    epoch, plus a "best" pair whenever the validation objective improves.
    Exits with status -1 when no GPU is available.
    """
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 8

    exp_dir = 'exp-lstm-adam-ctc-musan'
    setup_logger(f'{exp_dir}/log/log-train')
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    # The inverted, arc-sorted lexicon is cached as Linv.pt.
    linv_cache = lang_dir / 'Linv.pt'
    if linv_cache.exists():
        L_inv = k2.Fsa.from_dict(torch.load(linv_cache))
    else:
        with open(lang_dir / 'L.fst.txt') as lexicon_fst:
            L = k2.Fsa.from_openfst(lexicon_fst.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), linv_cache)

    graph_compiler = CtcTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)

    # ---- data ----
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir /
                                  'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train_transforms = [
        CutConcatenate(),
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20)),
    ]
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=train_transforms)
    train_sampler = SingleCutSampler(cuts_train,
                                     max_frames=90000,
                                     shuffle=True)
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # ---- model / optimizer ----
    logging.info("About to create model")
    device = torch.device('cuda', 0)
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4)

    model.to(device)
    describe(model)

    learning_rate = 1e-3
    optimizer = optim.AdamW(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)

    # ---- bookkeeping / resume ----
    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        resume_path = os.path.join(exp_dir, f'epoch-{start_epoch - 1}.pt')
        ckpt = load_checkpoint(filename=resume_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    # ---- training loop ----
    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        # The rate is only logged here; no decay schedule is applied.
        curr_learning_rate = 1e-3
        tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch)

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train)

        # Fields shared by the "best" and the per-epoch checkpoint.
        ckpt_fields = dict(model=model,
                           epoch=epoch,
                           scheduler=None,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           valid_objf=valid_objf,
                           global_batch_idx_train=global_batch_idx_train)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf, best_objf, best_epoch = valid_objf, objf, epoch
            # The best-model checkpoint carries no optimizer state.
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            **ckpt_fields)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        **ckpt_fields)
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def main():
    """Train a TDNN-LSTM model with MMI and a bigram phone LM on one GPU.

    Everything is hard-coded: the experiment directory
    ``exp-lstm-adam-mmi-bigram-musan``, 10 epochs, AdamW starting at 1e-3
    with the rate multiplied by 0.8 on every epoch after the 7th. The bigram
    LM scores are zero-initialised and registered on the model as the
    trainable parameter ``P_scores``. A checkpoint and an info file are
    written after every epoch, plus a "best" pair whenever the objective
    returned by ``train_one_epoch`` improves.

    NOTE(review): ``global_batch_idx_train`` / ``global_batch_idx_valid`` are
    passed to ``train_one_epoch`` but never updated in this function.
    """
    fix_random_seed(42)

    exp_dir = 'exp-lstm-adam-mmi-bigram-musan'
    setup_logger(f'{exp_dir}/log/log-train')
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    # The inverted, arc-sorted lexicon is cached as Linv.pt.
    linv_cache = lang_dir / 'Linv.pt'
    if linv_cache.exists():
        L_inv = k2.Fsa.from_dict(torch.load(linv_cache))
    else:
        with open(lang_dir / 'L.fst.txt') as lexicon_fst:
            L = k2.Fsa.from_openfst(lexicon_fst.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), linv_cache)

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)

    # ---- data ----
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionIterableDataset(cuts_train,
                                               max_frames=30000,
                                               shuffle=True,
                                               aug_cuts=cuts_musan,
                                               aug_prob=0.5,
                                               aug_snr=(10, 20))
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionIterableDataset(cuts_dev,
                                                  max_frames=30000,
                                                  shuffle=False,
                                                  concat_cuts=False)
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=2)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # ---- model ----
    logging.info("About to create model")
    device = torch.device('cuda', 0)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    # Make the LM scores part of the model so they are trained with it.
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    # ---- bookkeeping / resume ----
    learning_rate = 1e-3
    start_epoch = 0
    num_epochs = 10
    best_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only
    global_batch_idx_valid = 0  # for logging only

    if start_epoch > 0:
        resume_path = os.path.join(exp_dir, f'epoch-{start_epoch - 1}.pt')
        epoch, learning_rate, objf = load_checkpoint(filename=resume_path,
                                                     model=model)
        best_objf = objf
        logging.info(f"epoch = {epoch}, objf = {objf}")

    model.to(device)
    describe(model)

    # No lr is given here: it is assigned to every parameter group at the
    # start of each epoch below.
    optimizer = optim.AdamW(model.parameters(), weight_decay=5e-4)

    # ---- training loop ----
    curr_learning_rate = learning_rate
    for epoch in range(start_epoch, num_epochs):
        if epoch > 6:
            curr_learning_rate *= 0.8
        for group in optimizer.param_groups:
            group['lr'] = curr_learning_rate

        tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch)

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               P=P,
                               device=device,
                               graph_compiler=graph_compiler,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               tb_writer=tb_writer,
                               num_epochs=num_epochs,
                               global_batch_idx_train=global_batch_idx_train,
                               global_batch_idx_valid=global_batch_idx_valid)

        # Fields shared by the "best" and the per-epoch checkpoint.
        ckpt_fields = dict(model=model,
                           epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf)

        # the lower, the better
        if objf < best_objf:
            best_objf, best_epoch = objf, epoch
            save_checkpoint(filename=best_model_path, **ckpt_fields)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path, **ckpt_fields)
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def main():
    """Train ``Model`` against the inverted lexicon graph on a single GPU.

    Everything is hard-coded: the experiment directory ``exp``, 3 epochs,
    Adam with an initial learning rate of 1e-4 that is multiplied by 0.4
    every epoch. A checkpoint and an info file are written after every
    epoch, plus a "best" pair whenever the objective returned by
    ``train_one_epoch`` improves. Exits with status -1 when no GPU is
    available.
    """
    # load L, G, symbol_table
    lang_dir = 'data/lang_nosp'
    symbol_table = k2.SymbolTable.from_file(lang_dir + '/words.txt')

    # An earlier version built and cached an LG graph (LG.pt) here by
    # intersecting the arc-sorted inverted L with G. We don't need that
    # there: G has disambiguation symbols which L.fst doesn't support, so
    # only L is loaded.

    print("Loading L.fst")
    # The inverted, arc-sorted lexicon is cached as Linv.pt.
    linv_cache = lang_dir + '/Linv.pt'
    if os.path.exists(linv_cache):
        L = k2.Fsa.from_dict(torch.load(linv_cache))
    else:
        with open(lang_dir + '/L.fst.txt') as lexicon_fst:
            L = k2.Fsa.from_openfst(lexicon_fst.read(), acceptor=False)
            L = k2.arc_sort(L.invert_())
            torch.save(L.as_dict(), linv_cache)

    # ---- data ----
    feature_dir = 'exp/data'
    print("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir +
                                  '/cuts_train-clean-100.json.gz')
    print("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz')

    print("About to create train dataset")
    train = K2SpeechRecognitionIterableDataset(cuts_train, shuffle=True)
    print("About to create dev dataset")
    validate = K2SpeechRecognitionIterableDataset(cuts_dev, shuffle=False)
    print("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=1)
    print("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    exp_dir = 'exp'
    setup_logger(f'{exp_dir}/log/log-train')

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # ---- model / optimizer ----
    print("About to create model")
    device = torch.device('cuda', 0)
    model = Model(num_features=40, num_classes=364)
    model.to(device)

    learning_rate = 0.0001
    start_epoch = 0
    num_epochs = 3
    best_objf = 100000
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)

    # ---- training loop ----
    for epoch in range(start_epoch, num_epochs):
        # Exponential decay: the rate shrinks by a factor of 0.4 per epoch.
        curr_learning_rate = learning_rate * pow(0.4, epoch)
        for group in optimizer.param_groups:
            group['lr'] = curr_learning_rate

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               device=device,
                               L=L,
                               symbols=symbol_table,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               num_epochs=num_epochs)

        # Fields shared by the "best" and the per-epoch checkpoint.
        ckpt_fields = dict(model=model,
                           epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf)

        # the lower, the better
        if objf < best_objf:
            best_objf, best_epoch = objf, epoch
            save_checkpoint(filename=best_model_path, **ckpt_fields)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path, **ckpt_fields)
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def main():
    """Train the ``TdnnLstm1a`` speech/silence classifier on GPU 0.

    Loads the train/dev cut manifests from ``exp/data``, trains with AdamW
    at a constant learning rate, and after every epoch writes a checkpoint
    plus a training-info file into ``exp-tl1a-adam-xent``.  The checkpoint
    with the lowest validation objective seen so far is additionally saved
    as ``best_model.pt``.  Exits with status -1 when CUDA is unavailable.
    """
    fix_random_seed(42)

    if not torch.cuda.is_available():
        logging.error("No GPU detected!")
        sys.exit(-1)
    device = torch.device("cuda", 0)

    # Reserve the GPU with a dummy variable
    reserve_variable = torch.ones(1).to(device)

    start_epoch = 0
    num_epochs = 100

    exp_dir = "exp-tl1a-adam-xent"
    setup_logger(f"{exp_dir}/log/log-train")
    tb_writer = SummaryWriter(log_dir=f"{exp_dir}/tensorboard")

    # load dataset
    feature_dir = Path("exp/data")
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / "cuts_train.json.gz")
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / "cuts_dev.json.gz")

    logging.info("About to create train dataset")
    train_set = K2VadDataset(cuts_train)
    train_sampler = SingleCutSampler(cuts_train, max_frames=90000, shuffle=True)
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train_set,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
    )

    logging.info("About to create dev dataset")
    dev_set = K2VadDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(
        dev_set,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=1,
    )

    logging.info("About to create model")
    # num_classes=2: speech/silence
    model = TdnnLstm1a(num_features=80, num_classes=2, subsampling_factor=1)
    model.to(device)
    describe(model)

    learning_rate = 1e-4
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=5e-4
    )

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, "best_model.pt")
    best_epoch_info_filename = os.path.join(exp_dir, "best-epoch-info")
    global_batch_idx_train = 0  # for logging only

    # Resume from the checkpoint of the epoch preceding `start_epoch`.
    if start_epoch > 0:
        resume_path = os.path.join(exp_dir, f"epoch-{start_epoch - 1}.pt")
        ckpt = load_checkpoint(
            filename=resume_path, model=model, optimizer=optimizer
        )
        best_objf = ckpt["objf"]
        best_valid_objf = ckpt["valid_objf"]
        global_batch_idx_train = ckpt["global_batch_idx_train"]
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = learning_rate  # constant schedule

        tb_writer.add_scalar("learning_rate", curr_learning_rate, epoch)
        logging.info(f"epoch {epoch}, learning rate {curr_learning_rate}")

        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        # the lower, the better
        improved = valid_objf < best_valid_objf
        if improved:
            best_valid_objf, best_objf, best_epoch = valid_objf, objf, epoch

        # Statistics shared by both info files written below; computed after
        # the "best" bookkeeping so they already reflect this epoch.
        stats = dict(
            current_epoch=epoch,
            learning_rate=curr_learning_rate,
            objf=objf,
            best_objf=best_objf,
            valid_objf=valid_objf,
            best_valid_objf=best_valid_objf,
            best_epoch=best_epoch,
        )

        if improved:
            # The best-model checkpoint carries no optimizer/scheduler state.
            save_checkpoint(
                filename=best_model_path,
                model=model,
                epoch=epoch,
                optimizer=None,
                scheduler=None,
                learning_rate=curr_learning_rate,
                objf=objf,
                valid_objf=valid_objf,
                global_batch_idx_train=global_batch_idx_train,
            )
            save_training_info(
                filename=best_epoch_info_filename,
                model_path=best_model_path,
                **stats,
            )

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f"epoch-{epoch}.pt")
        save_checkpoint(
            filename=model_path,
            model=model,
            optimizer=optimizer,
            scheduler=None,
            epoch=epoch,
            learning_rate=curr_learning_rate,
            objf=objf,
            valid_objf=valid_objf,
            global_batch_idx_train=global_batch_idx_train,
        )
        epoch_info_filename = os.path.join(exp_dir, f"epoch-{epoch}-info")
        save_training_info(
            filename=epoch_info_filename,
            model_path=model_path,
            **stats,
        )

    logging.warning("Done")
def __call__(self, cuts: CutSet) -> CutSet:
    """Return a new ``CutSet`` in which some cuts are speed-perturbed.

    For each cut one number is drawn from ``self.random``; when the draw is
    ``>= self.p`` the cut is replaced by ``cut.perturb_speed`` with a factor
    picked from ``self.factors``, otherwise the cut is passed through as is.
    """
    # NOTE(review): with this comparison a cut is perturbed when the draw is
    # at or above `p`, i.e. roughly with probability 1 - p if the draw is
    # uniform on [0, 1) -- confirm this matches the intended meaning of `p`.
    def _transform(cut):
        if self.random.random() >= self.p:
            return cut.perturb_speed(factor=self.random.choice(self.factors))
        return cut

    return CutSet.from_cuts(_transform(cut) for cut in cuts)
def main():
    """Decode the ``test-clean`` cuts with a trained ``TdnnLstm1b`` and HLG.

    Loads ``epoch-{args.epoch}.pt`` from the experiment directory, builds the
    HLG decoding graph (or loads the cached ``HLG.pt`` from the lang
    directory), decodes the test cuts and writes the transcripts and the
    error statistics next to the checkpoints.
    """
    args = get_parser().parse_args()

    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)
    phone_ids = lexicon.phone_symbols()

    # Id 0 is prepended as the blank symbol for the CTC topology.
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)

    checkpoint = os.path.join(exp_dir, f'epoch-{args.epoch}.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(lexicon.phones)
        first_word_disambig_id = find_first_disambig_symbol(lexicon.words)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        # Cache the compiled graph so later runs skip compilation.
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        # The cached graph is HLG (the message used to say "LG").
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=40000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    logging.debug("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HLG=HLG,
                     symbols=lexicon.words)

    test_set = 'test-clean'
    recog_path = exp_dir / f'recogs-{test_set}.txt'
    store_transcripts(path=recog_path, texts=results)
    logging.info(f'The transcripts are stored in {recog_path}')

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    errs_filename = exp_dir / f'errs-{test_set}.txt'
    with open(errs_filename, 'w') as f:
        # The return value was bound to an unused local before; only the
        # file output is needed here.
        write_error_stats(f, f'{test_set}', results)
    logging.info(f'The error stats are stored in {errs_filename}')
def run_test(
    rank: Optional[int],
    n_shards: int,
    root: str,
    world_size: Optional[int],
    expected_cut_ids: List[str],
    num_workers: int,
) -> None:
    """Read WebDataset shards and check that the expected cuts are seen.

    Args:
        rank: Rank of this process; replaced by ``None`` when ``world_size``
            is ``None``.
        n_shards: Number of shards; files ``shard-000000.tar`` through
            ``shard-{n_shards - 1:06d}.tar`` under ``root`` are opened.
        root: Directory holding the shard files.
        world_size: Number of processes, or ``None`` to skip the
            ``torch.distributed`` initialization.
        expected_cut_ids: Cut IDs of the whole dataset; when ``world_size``
            is given, only every ``world_size``-th ID starting at ``rank``
            is expected in this process.
        num_workers: Number of ``DataLoader`` workers.

    Raises:
        AssertionError: If the number or the identity of the cuts read
            differs from the expectation.
    """
    # Initialize DDP if needed
    if world_size is not None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12354"
        torch.distributed.init_process_group(
            "gloo",
            world_size=world_size,
            rank=rank,
        )
        # adjust the expected cut IDs according to rank
        expected_cut_ids = expected_cut_ids[rank::world_size]
    else:
        rank = None

    # Open CutSet with options that de-duplicate the data across nodes and workers
    cuts_wds = CutSet.from_webdataset(
        "%s/shard-{000000..%06d}.tar" % (root, n_shards - 1),
        split_by_node=True,
        split_by_worker=True,
        shuffle_shards=True,
    )

    # Iterate the data
    tot = 0
    cut_ids = []
    # The sampler is explicitly given rank=0 / world_size=1; rank and
    # world_size are instead passed to the worker init function below.
    sampler = SimpleCutSampler(cuts_wds, max_duration=100, rank=0, world_size=1)
    dloader = DataLoader(
        IterableDatasetWrapper(dataset=DummyDataset(), sampler=sampler),
        batch_size=None,
        num_workers=num_workers,
        worker_init_fn=make_worker_init_fn(
            rank=rank,
            world_size=world_size,
        ),
    )
    for batch in dloader:
        tot += len(batch)
        cut_ids.extend(c.id for c in batch)

    print(f"[Rank {rank}/{world_size}] Actual cuts: ", sorted(cut_ids))
    print(f"[Rank {rank}/{world_size}] Expected cuts: ", sorted(expected_cut_ids))
    try:
        assert tot == len(
            expected_cut_ids), f"{tot} != {len(expected_cut_ids)}"
        assert sorted(cut_ids) == sorted(
            expected_cut_ids
        ), f"{sorted(cut_ids)}\n!=\n{sorted(expected_cut_ids)}"
    except AssertionError:
        # Pytest doesn't work great with subprocesses, so emit the traceback
        # here.  print_exc() writes it itself and returns None; wrapping it
        # in print() only added a stray "None" line.
        traceback.print_exc()
        raise
def main():
    """Decode ``test-clean`` with a ``Tdnn1a`` model and an LG graph.

    The script is currently disabled: it raises ``AssertionError``
    immediately.  The remainder of the body is kept for when the script is
    re-enabled; it loads or compiles LG, decodes the test cuts with the
    ``epoch-9.pt`` checkpoint and prints a Kaldi-like WER summary.

    Raises:
        AssertionError: Always, until the known issues are resolved.
    """
    # Raise explicitly instead of using `assert False`: assert statements are
    # removed under `python -O`, which would let this known-broken script run.
    raise AssertionError(
        'We are still working on this script as it has some issues, so please do NOT try to run it for now.'
    )

    exp_dir = Path('exp')
    setup_logger('{}/log/log-decode'.format(exp_dir))

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    if not os.path.exists(lang_dir / 'LG.pt'):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fsa.txt")
        with open(lang_dir / 'G.fsa.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=True)
        LG = compile_LG(L=L,
                        G=G,
                        labels_disambig_id_start=347,
                        aux_labels_disambig_id_start=200004)
        torch.save(LG.as_dict(), lang_dir / 'LG.pt')
    else:
        print("Loading pre-compiled LG")
        d = torch.load(lang_dir / 'LG.pt')
        LG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = exp_dir / 'data'
    print("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    print("About to create test dataset")
    test = K2SpeechRecognitionIterableDataset(cuts_test,
                                              max_frames=100000,
                                              shuffle=False)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    device = torch.device('cuda')
    model = Tdnn1a(num_features=40, num_classes=364)
    checkpoint = os.path.join(exp_dir, 'epoch-9.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    LG = LG.to(device)
    LG.requires_grad_(False)
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     LG=LG,
                     symbols=symbol_table)

    for ref, hyp in results:
        print('ref=', ref, ', hyp=', hyp)

    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    print(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
def libri_cut_set():
    """Build a four-cut ``CutSet`` derived from the first libri fixture cut.

    The set holds the first cut of ``test/fixtures/libri/cuts.json``, two
    copies of it re-labelled ``copy-1`` and ``copy-2``, and the cut appended
    to itself.
    """
    fixture_cuts = CutSet.from_json("test/fixtures/libri/cuts.json")
    base = fixture_cuts[0]
    members = [
        base,
        base.with_id("copy-1"),
        base.with_id("copy-2"),
        base.append(base),
    ]
    return CutSet.from_cuts(members)
def main():
    """Decode ``test-clean`` with a (possibly averaged) Transformer model.

    Command-line options read here: ``epoch``, ``max_frames``, ``avg`` and
    ``att_rate``.  With ``avg == 1`` the checkpoint ``epoch-{epoch - 1}.pt``
    is loaded; otherwise the ``avg`` checkpoints ``epoch-{epoch - avg}.pt``
    through ``epoch-{epoch - 1}.pt`` are averaged.  The LG graph is compiled
    (and cached as ``LG.pt``) when no cached copy exists.  All ref/hyp pairs
    and a Kaldi-like WER line are written to the log.
    """
    args = get_parser().parse_args()
    epoch = args.epoch
    max_frames = args.max_frames
    avg = args.avg
    att_rate = args.att_rate

    exp_dir = Path('exp-transformer-noam-ctc-att-musan')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_ids = get_phone_symbols(phone_symbol_table)
    # Id 0 is prepended as the blank symbol for the CTC topology.
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    if not os.path.exists(lang_dir / 'LG.pt'):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        LG = compile_LG(L=L,
                        G=G,
                        ctc_topo=ctc_topo,
                        labels_disambig_id_start=first_phone_disambig_id,
                        aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(LG.as_dict(), lang_dir / 'LG.pt')
    else:
        print("Loading pre-compiled LG")
        d = torch.load(lang_dir / 'LG.pt')
        LG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    print("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    print("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=max_frames)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    # The attention decoder is only built when the attention rate is non-zero.
    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    model = Transformer(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4,
        num_decoder_layers=num_decoder_layers)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    print("convert LG to device")
    LG = LG.to(device)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    LG.requires_grad_(False)

    print("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     LG=LG,
                     symbols=symbol_table)

    # Build the ref/hyp report with a single join instead of repeated `+=`,
    # which is quadratic in the number of utterances.
    logging.info(''.join(f'ref={ref}\nhyp={hyp}\n' for ref, hyp in results))

    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
def resampling_cuts(cut_with_supervision, cut_with_supervision_start01):
    """Bundle the two given cuts into one ``CutSet``, in argument order."""
    members = [cut_with_supervision, cut_with_supervision_start01]
    return CutSet.from_cuts(members)
def main():
    """Decode the test cuts with ``TdnnLstm1b`` and report error rates.

    Loads ``epoch-9.pt`` (including the learned ``P_scores``), dumps the
    phone bigram transition probabilities before and after
    ``set_scores_stochastic_`` to text files, builds or loads the LG graph,
    decodes ``cuts_test.json.gz`` and logs two Kaldi-like error lines: one
    over the ref/hyp items as returned by ``decode`` and one over their
    individual characters.
    """
    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)

    # Id 0 is prepended as the blank symbol for the CTC topology.
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    # Register P_scores on the model before loading the checkpoint.
    # NOTE(review): presumably so the checkpoint's stored P_scores has a
    # matching parameter to load into -- confirm against load_checkpoint.
    model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

    checkpoint = os.path.join(exp_dir, 'epoch-9.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    assert P.requires_grad is False
    P.scores = model.P_scores.cpu()
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='model_P_scores.txt')

    P.set_scores_stochastic_(model.P_scores)
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='P_scores.txt')

    if not os.path.exists(lang_dir / 'LG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        LG = compile_LG(L=L,
                        G=G,
                        ctc_topo=ctc_topo,
                        labels_disambig_id_start=first_phone_disambig_id,
                        aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(LG.as_dict(), lang_dir / 'LG.pt')
    else:
        logging.debug("Loading pre-compiled LG")
        d = torch.load(lang_dir / 'LG.pt')
        LG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=100000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    logging.debug("convert LG to device")
    LG = LG.to(device)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    LG.requires_grad_(False)

    logging.debug("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     LG=LG,
                     symbols=symbol_table)

    # Character-level view of every pair: concatenate the items of ref/hyp
    # and split the result into single characters.
    results2 = [(list(''.join(ref)), list(''.join(hyp)))
                for ref, hyp in results]
    # One join instead of repeated `+=` (quadratic in the number of pairs).
    logging.info(''.join(f'ref={ref}\nhyp={hyp}\n' for ref, hyp in results))

    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    dists2 = [edit_distance(r, h) for r, h in results2]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    errors2 = {
        key: sum(dist[key] for dist in dists2)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    total_chars = sum(len(ref) for ref, _ in results2)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
    # Same summary computed over characters (the label is kept as %WER).
    logging.info(
        f'%WER {errors2["total"] / total_chars:.2%} '
        f'[{errors2["total"]} / {total_chars}, {errors2["ins"]} ins, {errors2["del"]} del, {errors2["sub"]} sub ]'
    )
def libri_cut_set():
    """Build a four-cut ``CutSet`` derived from the first libri fixture cut.

    Members, in order: the first cut of ``test/fixtures/libri/cuts.json``,
    that cut with id ``copy-1``, that cut with id ``copy-2``, and the cut
    appended to itself.
    """
    source = CutSet.from_json('test/fixtures/libri/cuts.json')
    first = source[0]
    renamed = [first.with_id(new_id) for new_id in ('copy-1', 'copy-2')]
    doubled = first.append(first)
    return CutSet.from_cuts([first, *renamed, doubled])
def main():
    # Entry point for (optionally distributed) MMI training of TdnnLstm1b
    # with a trainable bigram phone LM.  Reads `world_size`, `local_rank` and
    # `bucketing_sampler` from the command line, trains for `num_epochs`
    # epochs and writes per-epoch and best checkpoints plus info files into
    # `exp_dir`.
    args = get_parser().parse_args()
    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank, world_size=args.world_size)
    fix_random_seed(42)

    exp_dir = f'exp-lstm-adam-mmi-bigram-musan-dist'
    # Console logging and TensorBoard are enabled only for local rank 0.
    setup_logger('{}/log/log-train'.format(exp_dir),
                 use_console=args.local_rank == 0)
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.local_rank == 0 else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    # Use the cached inverted lexicon when present; otherwise build it from
    # L.fst.txt and cache it as Linv.pt.
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
    # Start from all-zero bigram scores; a trainable copy is attached to the
    # model below as `P_scores`.
    P.scores = torch.zeros_like(P.scores)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir /
                                  'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if not args.bucketing_sampler:
        # We don't mix concatenating the cuts and bucketing
        # Here we insert concatenation before mixing so that the
        # noises from Musan are mixed onto almost-zero-energy
        # padding frames.
        transforms = [CutConcatenate()] + transforms
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_frames=40000,
                                         shuffle=True,
                                         num_buckets=30)
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_frames=30000,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)

    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    # Note: we explicitly set world_size to 1 to disable the auto-detection of
    # distributed training inside the sampler. This way, every GPU will
    # perform the computation on the full dev set. It is a bit wasteful,
    # but unfortunately loss aggregation between multiple processes with
    # torch.distributed.all_reduce() tends to hang indefinitely inside
    # NCCL after ~3000 steps. With the current approach, we can still report
    # the loss on the full validation set.
    valid_sampler = SingleCutSampler(cuts_dev,
                                     max_frames=90000,
                                     world_size=1,
                                     rank=0)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    # Each process uses the GPU whose index equals its local rank.
    device_id = args.local_rank
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    start_epoch = 0
    num_epochs = 10
    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only
    use_adam = True

    # Resume from the checkpoint of the epoch preceding `start_epoch`.
    # NOTE(review): no optimizer/scheduler is passed to load_checkpoint here,
    # and both are created only further below -- confirm that resuming
    # without their state is intended.
    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    model.to(device)

    if args.world_size > 1:
        logging.info(
            'Using DistributedDataParallel in training. '
            'The reported loss, num_frames, etc. for training steps include '
            'only the batches seen in the master process (the actual loss '
            'includes batches from all GPUs, and the actual num_frames is '
            f'approx. {args.world_size}x larger.')
        # For now do not sync BatchNorm across GPUs due to NCCL hanging in all_gather...
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer,
            gamma=lr_schedule_gamma,
            last_epoch=start_epoch - 1)

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        # tb_writer is None on every rank except local rank 0.
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            local_rank=args.local_rank,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        local_rank=args.local_rank,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

        # Advance the schedule once per epoch, after the checkpoints have
        # been written with the learning rate used for this epoch.
        lr_scheduler.step()

    logging.warning('Done')
    cleanup_dist()
def main():
    """Train ``TdnnLstm1b`` with the MMI+MBR graph compiler.

    Runs on GPU 0 when CUDA is available and falls back to the CPU (with a
    warning) otherwise.  Lexicon and grammar FSAs are loaded from cached
    ``.pt`` files in ``data/lang_nosp`` or compiled from their text form and
    cached.  After every epoch a checkpoint and an info file are written to
    ``exp-lstm-adam-mmi-mbr-musan``; the epoch with the lowest validation
    objective is additionally saved as ``best_model.pt``.
    """
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = 'exp-lstm-adam-mmi-mbr-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    if not torch.cuda.is_available():
        # logging.warn is a deprecated alias; use logging.warning.
        logging.warning('No GPU detected!')
        logging.warning('USE CPU (very slow)!')
        device = torch.device('cpu')
    else:
        logging.info('Use GPU')
        device_id = 0
        device = torch.device('cuda', device_id)

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        logging.info('Loading precompiled L')
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        logging.info('Compiling L')
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    logging.info("Loading L_disambig.fst")
    if (lang_dir / 'L_disambig.pt').exists():
        logging.info('Loading precompiled L_disambig')
        L_disambig = k2.Fsa.from_dict(torch.load(lang_dir / 'L_disambig.pt'))
    else:
        logging.info('Compiling L_disambig')
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L_disambig = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_disambig = k2.arc_sort(L_disambig)
            torch.save(L_disambig.as_dict(), lang_dir / 'L_disambig.pt')

    logging.info("Loading G.fst")
    if (lang_dir / 'G_uni.pt').exists():
        logging.info('Loading precompiled G')
        G = k2.Fsa.from_dict(torch.load(lang_dir / 'G_uni.pt'))
    else:
        logging.info('Compiling G')
        with open(lang_dir / 'G_uni.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            G = k2.arc_sort(G)
            torch.save(G.as_dict(), lang_dir / 'G_uni.pt')

    graph_compiler = MmiMbrTrainingGraphCompiler(
        L_inv=L_inv,
        L_disambig=L_disambig,
        G=G,
        device=device,
        phones=phone_symbol_table,
        words=word_symbol_table
    )
    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)
    # Start from all-zero bigram scores; a trainable copy is attached to the
    # model below as `P_scores`.
    P.scores = torch.zeros_like(P.scores)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(
        cuts_train,
        cut_transforms=[
            CutConcatenate(),
            CutMix(
                cuts=cuts_musan,
                prob=0.5,
                snr=(10, 20)
            )
        ]
    )
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=30000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4
    )

    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=60000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(
        validate,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=1
    )

    logging.info("About to create model")
    model = TdnnLstm1b(num_features=40,
                       num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
                       subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    P = P.to(device)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda ep: 1.0 if ep < 7 else 0.8 ** (ep - 6)
        )
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay
        )
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer,
            gamma=lr_schedule_gamma
        )

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    # Resume model, optimizer and scheduler from the previous epoch.
    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        train_sampler.set_epoch(epoch)
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        # Step before saving so the per-epoch checkpoint stores the scheduler
        # in its post-epoch state; `curr_learning_rate` still holds the rate
        # that was used for this epoch.
        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            # The best-model checkpoint carries no optimizer/scheduler state.
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def main():
    """Train a Transformer acoustic model with the Noam optimizer.

    The run is configured through ``get_parser()``.  The function loads (or
    builds and caches as ``Linv.pt``) the inverted lexicon, creates a bigram
    phone LM whose scores are reset to zero, prepares the train/dev
    dataloaders and then calls ``train_one_epoch`` for every epoch in
    ``[start_epoch, num_epochs)``.  When ``start_epoch > 0`` the state is
    restored from ``epoch-{start_epoch - 1}.pt``.  After every epoch a
    checkpoint and an info file are written; the epoch with the lowest
    validation objective so far is additionally saved as ``best_model.pt``.
    """
    args = get_parser().parse_args()

    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_frames = args.max_frames
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-transformer-noam-mmi-att-musan')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # Symbol tables and the inverted lexicon (cached on disk once built).
    lang_dir = Path('data/lang_nosp')
    phones = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    words = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    linv_path = lang_dir / 'Linv.pt'
    if linv_path.exists():
        L_inv = k2.Fsa.from_dict(torch.load(linv_path))
    else:
        with open(lang_dir / 'L.fst.txt') as lexicon_txt:
            lexicon_fsa = k2.Fsa.from_openfst(lexicon_txt.read(),
                                              acceptor=False)
        L_inv = k2.arc_sort(lexicon_fsa.invert_())
        torch.save(L_inv.as_dict(), linv_path)

    graph_compiler = MmiTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phones,
                                              words=words)
    phone_ids = get_phone_symbols(phones)
    P = create_bigram_phone_lm(phone_ids)
    # Start from an all-zero LM; the scores become a model parameter below.
    P.scores = torch.zeros_like(P.scores)

    # Manifests.
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir /
                                  'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    noise_mix = CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))
    if args.bucketing_sampler:
        transforms = [noise_mix]
    else:
        # We don't mix concatenating the cuts and bucketing.
        # Concatenation goes before mixing so that the noises from Musan
        # are mixed onto almost-zero-energy padding frames.
        transforms = [CutConcatenate(), noise_mix]
    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)

    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_frames=max_frames,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using regular sampler with cut concatenation.')
        train_sampler = SingleCutSampler(cuts_train,
                                         max_frames=max_frames,
                                         shuffle=True)

    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)

    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=max_frames)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device = torch.device('cuda', 0)

    # The attention decoder is only built when the attention loss is used.
    num_decoder_layers = 6 if att_rate != 0.0 else 0
    model = Transformer(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4,
        num_decoder_layers=num_decoder_layers)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)
    model.to(device)
    describe(model)

    optimizer = Noam(model.parameters(),
                     model_size=256,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        # Resume from the checkpoint written at the end of the previous epoch.
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        # Fields shared by both checkpoints written for this epoch.
        ckpt_fields = dict(model=model,
                           epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           valid_objf=valid_objf,
                           global_batch_idx_train=global_batch_idx_train)

        # the lower, the better
        improved = valid_objf < best_valid_objf
        if improved:
            best_valid_objf, best_objf, best_epoch = valid_objf, objf, epoch

        # Built after the "best" bookkeeping so it reports up-to-date values.
        info_fields = dict(current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

        if improved:
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            **ckpt_fields)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               **info_fields)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        **ckpt_fields)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           **info_fields)

    logging.warning('Done')
def main():
    """Decode the test cuts with a trained speech/silence model.

    Exits with status -1 when no GPU is available.  Otherwise the function
    loads (or prepares, arc-sorts and caches as ``HCLG.pt``) the decoding
    graph, restores ``best_model.pt`` into a ``TdnnLstm1a`` model, runs
    ``decode`` over the test dataloader, logs a classification report and
    writes the resulting segments to ``<exp_dir>/segments``.
    """
    if not torch.cuda.is_available():
        logging.error("No GPU detected!")
        sys.exit(-1)

    device = torch.device("cuda", 0)

    # Reserve the GPU with a dummy variable
    reserve_variable = torch.ones(1).to(device)

    exp_dir = Path("exp-tl1a-adam-xent")
    setup_logger("{}/log/log-decode".format(exp_dir), log_level="debug")

    graph_path = exp_dir / "HCLG.pt"
    if os.path.exists(graph_path):
        logging.info("Loading pre-compiled decoding graph")
        HCLG = k2.Fsa.from_dict(torch.load(graph_path))
    else:
        logging.info("Preparing decoding graph")
        HCLG = prepare_decoding_graph(
            min_silence_duration=0.03,
            min_speech_duration=0.3,
            max_speech_duration=10.0,
        )
        # Arc sort the HCLG since it is needed for intersect
        logging.info("Sorting decoding graph by outgoing arcs")
        HCLG = k2.arc_sort(HCLG)
        torch.save(HCLG.as_dict(), graph_path)

    # Test data.
    feature_dir = Path("exp/data")
    logging.info("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / "cuts_test.json.gz")

    logging.info("About to create test dataset")
    test = K2VadDataset(cuts_test, return_cuts=True)
    sampler = SingleCutSampler(cuts_test, max_frames=100000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    logging.info("About to load model")
    model = TdnnLstm1a(
        num_features=80,
        num_classes=2,  # speech/silence
        subsampling_factor=1,
    )
    load_checkpoint(os.path.join(exp_dir, "best_model.pt"), model)
    model.to(device)
    model.eval()

    logging.info("convert decoding graph to device")
    HCLG = HCLG.to(device)
    HCLG.requires_grad_(False)

    logging.info("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HCLG=HCLG)

    # Compute frame-level accuracy and precision-recall metrics.
    # Each result is a (cut, reference, hypothesis) triple.
    references = [ref for _cut, ref, _hyp in results]
    hypotheses = [hyp for _cut, _ref, hyp in results]
    y_true = torch.cat(references, dim=0).numpy()
    y_pred = torch.cat(hypotheses, dim=0).numpy()
    report = classification_report(y_true,
                                   y_pred,
                                   target_names=["silence", "speech"])
    logging.info("Results: \n{}".format(report))

    # Create output segments per recording
    create_and_write_segments(
        [triple[0] for triple in results],  # cuts
        [triple[2] for triple in results],  # outputs
        exp_dir / "segments",  # segments file
    )
def main():
    """Decode the clean test set with a TDNN-LSTM model and report WER.

    The LG decoding graph is loaded from ``LG.pt`` when that file exists and
    is otherwise compiled from ``L_disambig.fst.txt`` / ``G.fsa.txt`` and
    cached.  The model weights come from ``epoch-7.pt``.  Every
    reference/hypothesis pair is logged, followed by a Kaldi-like ``%WER``
    summary line.
    """
    exp_dir = Path('exp-lstm-adam')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # Symbol tables and the CTC topology.
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_id_list = list(phone_symbol_table._id2sym.keys())
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_id_list))

    lg_path = lang_dir / 'LG.pt'
    if os.path.exists(lg_path):
        print("Loading pre-compiled LG")
        LG = k2.Fsa.from_dict(torch.load(lg_path))
    else:
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as lexicon_txt:
            L = k2.Fsa.from_openfst(lexicon_txt.read(), acceptor=False)
        print("Loading G.fsa.txt")
        with open(lang_dir / 'G.fsa.txt') as grammar_txt:
            G = k2.Fsa.from_openfst(grammar_txt.read(), acceptor=True)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        LG = compile_LG(L=L,
                        G=G,
                        ctc_topo=ctc_topo,
                        labels_disambig_id_start=first_phone_disambig_id,
                        aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(LG.as_dict(), lg_path)

    # Test data.
    feature_dir = Path('exp/data')
    print("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    print("About to create test dataset")
    test = K2SpeechRecognitionIterableDataset(cuts_test,
                                              max_frames=30000,
                                              shuffle=False,
                                              concat_cuts=False)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          num_workers=1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    device = torch.device('cuda')
    model = TdnnLstm1b(num_features=40, num_classes=len(phone_symbol_table))
    load_checkpoint(os.path.join(exp_dir, 'epoch-7.pt'), model)
    model.to(device)
    model.eval()

    print("convert LG to device")
    LG = LG.to(device)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    LG.requires_grad_(False)

    print("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     LG=LG,
                     symbols=symbol_table)

    # One "ref=.../hyp=..." pair of lines per decoded utterance.
    logging.info(''.join(f'ref={ref}\nhyp={hyp}\n' for ref, hyp in results))

    # compute WER
    dists = [edit_distance(ref, hyp) for ref, hyp in results]
    errors = {}
    for key in ('sub', 'ins', 'del', 'total'):
        errors[key] = sum(dist[key] for dist in dists)
    total_words = sum(len(ref) for ref, _ in results)

    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
def test_collate_custom_attribute_missing():
    """Collating the field "nonexistent_attribute" must raise AttributeError."""
    ljspeech_cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    missing_field = "nonexistent_attribute"
    with pytest.raises(AttributeError):
        collate_custom_field(ljspeech_cuts, missing_field)
def main():
    """Train a TDNN-LSTM model with the MMI objective in a distributed run.

    The process joins the process group via ``setup_dist`` using
    ``args.local_rank`` / ``args.world_size`` / ``args.master_port`` and does
    all GPU work on ``cuda:<local_rank>``.  It loads (or builds from
    ``P.fst.txt`` and caches as ``P.pt``) the phone LM ``P``, creates the
    train/dev dataloaders, and runs ``train_one_epoch`` for epochs
    ``[start_epoch, num_epochs)``, stepping the LR scheduler once per epoch.
    A checkpoint and an info file are written after every epoch; the epoch
    with the lowest validation objective is also saved as ``best_model.pt``.

    NOTE(review): the model is not wrapped in any distributed wrapper inside
    this function, and every rank writes to the same log/tensorboard/
    checkpoint paths -- confirm that this is intended.
    """
    args = get_parser().parse_args()
    print('World size:', args.world_size, 'Rank:', args.local_rank)
    setup_dist(rank=args.local_rank,
               world_size=args.world_size,
               master_port=args.master_port)
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 10
    use_adam = True

    exp_dir = 'exp-lstm-adam-mmi-bigram-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    # One GPU per process.  This device is shared by the graph compiler and
    # the model; they must live on the same GPU.
    device_id = args.local_rank
    device = torch.device('cuda', device_id)

    phone_ids = lexicon.phone_symbols()

    if not Path(lang_dir / 'P.pt').is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        torch.save(P.as_dict(), lang_dir / 'P.pt')
    else:
        logging.debug('Loading pre-compiled P')
        d = torch.load(lang_dir / 'P.pt')
        P = k2.Fsa.from_dict(d)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=[
                                           CutConcatenate(),
                                           CutMix(cuts=cuts_musan,
                                                  prob=0.5,
                                                  snr=(10, 20))
                                       ])
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=40000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(train,
                                           sampler=train_sampler,
                                           batch_size=None,
                                           num_workers=4)

    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    # Keep using the rank-local ``device`` defined above.  It used to be
    # reset to cuda:0 at this point, which put the model of every rank > 0
    # on a different GPU than its graph compiler.
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    if use_adam:
        learning_rate = 1e-3
        weight_decay = 5e-4
        optimizer = optim.AdamW(model.parameters(),
                                lr=learning_rate,
                                weight_decay=weight_decay)
        # Equivalent to the following in the epoch loop:
        #  if epoch > 6:
        #      curr_learning_rate *= 0.8
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: 1.0 if ep < 7 else 0.8**(ep - 6))
    else:
        learning_rate = 5e-5
        weight_decay = 1e-5
        momentum = 0.9
        lr_schedule_gamma = 0.7
        optimizer = optim.SGD(model.parameters(),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=lr_schedule_gamma)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        # Resume from the checkpoint written at the end of the previous epoch.
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer,
                               scheduler=lr_scheduler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        # LR scheduler can hold multiple learning rates for multiple parameter groups;
        # For now we report just the first LR which we assume concerns most of the parameters.
        curr_learning_rate = lr_scheduler.get_last_lr()[0]
        tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                             global_batch_idx_train)
        tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )

        lr_scheduler.step()

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            optimizer=None,
                            scheduler=None,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=lr_scheduler,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def main():
    """Decode the test set with a trained MMI/attention model; report WER/CER.

    The model architecture is selected by ``args.model_type``
    ("transformer", "conformer" or "contextnet"); any other value raises
    ``NotImplementedError``.  With ``args.avg == 1`` the weights come from
    ``epoch-{epoch - 1}.pt``; otherwise the checkpoints of epochs
    ``[epoch - avg, epoch)`` are passed to ``average_checkpoint``.  The HLG
    decoding graph is loaded from ``HLG.pt`` when present and compiled and
    cached otherwise.  No LM rescoring graph is loaded here (``G`` is passed
    to ``decode`` as ``None``).  Every reference/hypothesis pair is logged,
    followed by Kaldi-like ``%WER`` and ``%CER`` summary lines.
    """
    parser = get_parser()
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use n-best list for rescoring
        # when n is less than 1
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    # The attention decoder is only built when the attention loss was used.
    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=40,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            # Same option as in the conformer branch below; this branch
            # previously read the misspelled ``args.vgg_fronted``.
            vgg_frontend=args.vgg_frontend)
    elif model_type == "conformer":
        model = Conformer(
            num_features=40,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend,
            is_espnet_structure=args.is_espnet_structure)
    elif model_type == "contextnet":
        model = ContextNet(num_features=40,
                           num_classes=len(phone_ids) +
                           1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        # Average the ``avg`` checkpoints that precede ``epoch``.
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    logging.debug('Decoding without LM rescoring')
    # No rescoring LM is loaded; this also discards the G that may have been
    # read above for HLG compilation.
    G = None
    if num_paths > 1:
        logging.debug(f'Use n-best list decoding, n is {num_paths}')
    else:
        logging.debug('Use 1-best decoding')

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    test_sampler = SingleCutSampler(cuts_test, max_frames=12000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          sampler=test_sampler,
                                          batch_size=None,
                                          num_workers=1)

    logging.info("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     HLG=HLG,
                     symbols=symbol_table,
                     num_paths=num_paths,
                     G=G,
                     use_whole_lattice=use_whole_lattice,
                     output_beam_size=output_beam_size)

    # Character-level view of every pair: the tokens of each side are
    # concatenated and split into single characters, for the CER.
    results2 = [(list(''.join(ref)), list(''.join(hyp)))
                for ref, hyp in results]
    # One "ref=.../hyp=..." pair of lines per utterance, built in one pass.
    logging.info(''.join(f'ref={ref}\nhyp={hyp}\n' for ref, hyp in results))

    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    dists2 = [edit_distance(r, h) for r, h in results2]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    errors2 = {
        key: sum(dist[key] for dist in dists2)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    total_chars = sum(len(ref) for ref, _ in results2)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
    logging.info(
        f'%CER {errors2["total"] / total_chars:.2%} '
        f'[{errors2["total"]} / {total_chars}, {errors2["ins"]} ins, {errors2["del"]} del, {errors2["sub"]} sub ]'
    )
def test_global_mvn_from_cuts():
    """``GlobalMVN.from_cuts`` returns a ``GlobalMVN`` with and without ``max_cuts``."""
    fixture_cuts = CutSet.from_json('test/fixtures/ljspeech/cuts.json')
    # Compute both variants first, then check them, in that order.
    computed = [
        GlobalMVN.from_cuts(fixture_cuts, **extra_kwargs)
        for extra_kwargs in ({}, {'max_cuts': 1})
    ]
    for stats in computed:
        assert isinstance(stats, GlobalMVN)
def global_mvn():
    """Return a ``GlobalMVN`` computed from the LJSpeech fixture cuts."""
    return GlobalMVN.from_cuts(
        CutSet.from_json('test/fixtures/ljspeech/cuts.json'))
def _next_batch(self) -> Tuple[CutSet, CutSet]:
    """Collect the next batch of paired source/target cuts.

    Cuts are drawn in lock-step from ``self.source_cuts`` and
    ``self.target_cuts`` and accumulated until adding another pair would
    exceed ``self.source_constraints`` or ``self.target_constraints``.  The
    pair that overflows a non-empty batch is handed back to the iterators
    with ``take_back``.  A pair for which ``self._filter_fn`` rejects
    either cut is skipped.

    Only manifests are handled here; this method itself does not read any
    audio or feature data.

    Returns:
        A ``(source, target)`` tuple of ``CutSet`` objects of equal length.

    Raises:
        StopIteration: when the underlying iterators are exhausted and there
            is no partial batch to return, or the partial batch is dropped
            because ``self.drop_last`` is set and neither constraint is
            close to being exceeded.
        AssertionError: when a drawn source/target pair has differing IDs.
    """
    # Keep iterating the underlying CutSets as long as we do not hit or exceed
    # the constraints provided by the user.
    self.source_constraints.reset()
    self.target_constraints.reset()
    source_cuts = []
    target_cuts = []
    while True:
        # Check that we have not reached the end of the dataset.
        try:
            # We didn't - grab the next cut
            next_source_cut = next(self.source_cuts)
            next_target_cut = next(self.target_cuts)
            assert next_source_cut.id == next_target_cut.id, (
                "Sampled source and target cuts with differing IDs. "
                "Ensure that your source and target cuts have the same length, "
                "the same IDs, and the same order.")
        except StopIteration:
            # No more cuts to sample from: if we have a partial batch,
            # we may output it, unless the user requested to drop it.
            # We also check if the batch is "almost there" to override drop_last.
            if source_cuts and (
                    not self.drop_last
                    or self.source_constraints.close_to_exceeding()
                    or self.target_constraints.close_to_exceeding()):
                # We have a partial batch and we can return it.
                assert len(source_cuts) == len(
                    target_cuts
                ), "Unexpected state: some cuts in source / target are missing their counterparts..."
                self.diagnostics.keep(source_cuts)
                return CutSet.from_cuts(source_cuts), CutSet.from_cuts(
                    target_cuts)
            else:
                # There is nothing more to return or it's discarded:
                # signal the iteration code to stop.
                self.diagnostics.discard(source_cuts)
                raise StopIteration()

        # Check whether the cuts we're about to sample satisfy optional user-requested predicate.
        if not self._filter_fn(next_source_cut) or not self._filter_fn(
                next_target_cut):
            # No - try another one.
            self.diagnostics.discard_single(next_source_cut)
            continue

        self.source_constraints.add(next_source_cut)
        self.target_constraints.add(next_target_cut)

        # Did we exceed the max_source_frames and max_cuts constraints?
        if (not self.source_constraints.exceeded()
                and not self.target_constraints.exceeded()):
            # No - add the next cut to the batch, and keep trying.
            source_cuts.append(next_source_cut)
            target_cuts.append(next_target_cut)
        else:
            # Yes. Do we have at least one cut in the batch?
            if source_cuts:
                # Yes. Return it.
                self.source_cuts.take_back(next_source_cut)
                self.target_cuts.take_back(next_target_cut)
                break
            else:
                # No. We'll warn the user that the constraints might be too tight,
                # and return the cut anyway.
                # The two literals are concatenated, so the first one must
                # end with a separator to keep the message readable.
                warnings.warn(
                    "The first cut drawn in batch collection violates one of the max_... constraints - "
                    "we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc."
                )
                source_cuts.append(next_source_cut)
                target_cuts.append(next_target_cut)

    assert len(source_cuts) == len(
        target_cuts
    ), "Unexpected state: some cuts in source / target are missing their counterparts..."
    self.diagnostics.keep(source_cuts)
    return CutSet.from_cuts(source_cuts), CutSet.from_cuts(target_cuts)