def export(recordings: Pathlike, supervisions: Pathlike, output_dir: Pathlike):
    """
    Convert a pair of ``RecordingSet`` and ``SupervisionSet`` manifests
    into a Kaldi-style data directory.
    """
    export_to_kaldi(
        recordings=load_manifest(recordings),
        supervisions=load_manifest(supervisions),
        output_dir=output_dir,
    )
def export(
    recordings: Pathlike,
    supervisions: Pathlike,
    output_dir: Pathlike,
    map_underscores_to: Optional[str],
):
    """
    Convert a pair of ``RecordingSet`` and ``SupervisionSet`` manifests
    into a Kaldi-style data directory.
    """
    from lhotse import load_manifest
    from lhotse.kaldi import export_to_kaldi

    output_dir = Path(output_dir)
    export_to_kaldi(
        recordings=load_manifest(recordings),
        supervisions=load_manifest(supervisions),
        output_dir=output_dir,
        map_underscores_to=map_underscores_to,
    )
    click.secho(
        "Export completed! You likely need to run the following Kaldi commands:",
        bold=True,
        fg="yellow",
    )
    click.secho(
        f"  utils/utt2spk_to_spk2utt.pl {output_dir}/utt2spk > {output_dir}/spk2utt",
        fg="yellow",
    )
    click.secho(f"  utils/fix_data_dir.sh {output_dir}", fg="yellow")
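# ---- Usage sketch: the same export done directly through the Python API, mirroring
# what the CLI command above does internally. The manifest paths, output directory,
# and the map_underscores_to value are illustrative assumptions, not files from this repo.
from pathlib import Path

from lhotse import load_manifest
from lhotse.kaldi import export_to_kaldi

recs = load_manifest("recordings.jsonl.gz")    # RecordingSet
sups = load_manifest("supervisions.jsonl.gz")  # SupervisionSet
export_to_kaldi(
    recordings=recs,
    supervisions=sups,
    output_dir=Path("data/kaldi_dir"),
    map_underscores_to="-",  # value chosen for illustration only
)
# Afterwards, run the Kaldi fix-up scripts printed by the CLI command above.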
def train_cuts(self) -> CutSet:
    logging.info("About to get train cuts")
    cuts_train = load_manifest(
        self.args.feature_dir / 'cuts_train-clean-100.json.gz')
    if self.args.full_libri:
        cuts_train = (
            cuts_train +
            load_manifest(self.args.feature_dir / 'cuts_train-clean-360.json.gz') +
            load_manifest(self.args.feature_dir / 'cuts_train-other-500.json.gz'))
    return cuts_train
def validate_(recordings: Pathlike, supervisions: Pathlike, read_data: bool):
    """
    Validate a pair of Lhotse RECORDINGS and SUPERVISIONS manifest files.
    Checks whether the two manifests are consistent with each other.
    """
    from lhotse import load_manifest, validate_recordings_and_supervisions

    recs = load_manifest(recordings)
    sups = load_manifest(supervisions)
    validate_recordings_and_supervisions(
        recordings=recs, supervisions=sups, read_data=read_data
    )
def test_cut_set_batch_feature_extraction_resume(cut_set, overwrite):
    # This test checks that we can keep writing to the same file
    # and the previously written results are not lost.
    # Since we don't have an easy way to interrupt the execution in a test,
    # we just write another CutSet to the same file.
    # The effect is the same.
    extractor = Fbank()
    cut_set = cut_set.resample(16000)
    subsets = cut_set.split(num_splits=2)
    processed = []
    with NamedTemporaryFile() as feat_f, NamedTemporaryFile(
            suffix=".jsonl.gz") as manifest_f:
        for cuts in subsets:
            processed.append(
                cuts.compute_and_store_features_batch(
                    extractor=extractor,
                    storage_path=feat_f.name,
                    manifest_path=manifest_f.name,
                    num_workers=0,
                    overwrite=overwrite,
                ))
        feat_f.flush()
        manifest_f.flush()
        merged = load_manifest(manifest_f.name)
        if overwrite:
            assert list(merged.ids) == list(subsets[-1].ids)
        else:
            assert list(merged.ids) == list(cut_set.ids)
        validate(merged, read_data=True)
def read_manifests_if_cached(
        dataset_parts: Optional[Sequence[str]],
        output_dir: Optional[Pathlike],
        prefix: str = '',
        suffix: Optional[str] = 'json',
        types: Iterable[str] = DEFAULT_DETECTED_MANIFEST_TYPES
) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
    """
    Loads manifests from the disk, or a subset of them if only some exist.
    The manifests are searched for using the pattern
    ``output_dir / f'{prefix}_{manifest}_{part}.json'``,
    where `manifest` is one of ``["recordings", "supervisions"]`` and ``part`` is specified
    in ``dataset_parts``.
    This function is intended to speed up data preparation if it has already been done before.

    :param dataset_parts: Names of dataset pieces, e.g. in LibriSpeech: ``["test-clean", "dev-clean", ...]``.
    :param output_dir: Where to look for the files.
    :param prefix: Optional common prefix for the manifest files (underscore is automatically added).
    :param suffix: Optional common suffix for the manifest files ("json" by default).
    :param types: Which types of manifests are searched for (default: 'recordings' and 'supervisions').
    :return: A dict with the manifests (``d[dataset_part]['recordings'|'supervisions']``),
        or an empty dict if ``output_dir`` is ``None``.
    """
    if output_dir is None:
        return {}
    if prefix and not prefix.endswith('_'):
        prefix = f'{prefix}_'
    if suffix.startswith('.'):
        suffix = suffix[1:]
    manifests = defaultdict(dict)
    for part in dataset_parts:
        for manifest in types:
            path = output_dir / f'{prefix}{manifest}_{part}.{suffix}'
            if not path.is_file():
                continue
            manifests[part][manifest] = load_manifest(path)
    return dict(manifests)
def subset(
    manifest: Pathlike,
    output_manifest: Pathlike,
    first: Optional[int],
    last: Optional[int],
    cutids: Optional[str],
):
    """Load MANIFEST, select the FIRST or LAST number of items and store it in OUTPUT_MANIFEST."""
    from lhotse import load_manifest

    output_manifest = Path(output_manifest)
    manifest = Path(manifest)
    any_set = load_manifest(manifest)

    cids = None
    if cutids is not None:
        if os.path.exists(cutids):
            with open(cutids, "rt") as r:
                cids = json.load(r)
        else:
            cids = json.loads(cutids)

    if isinstance(any_set, CutSet):
        a_subset = any_set.subset(first=first, last=last, cut_ids=cids)
    else:
        if cutids is not None:
            raise ValueError(
                f"Expected a CutSet manifest with cut_ids argument; got {type(any_set)}"
            )
        a_subset = any_set.subset(first=first, last=last)

    a_subset.to_file(output_manifest)
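# ---- Usage sketch for the cut-ID selection path above. The --cutids option accepts
# either an inline JSON list or a path to a JSON file containing such a list; the
# manifest path and cut IDs below are illustrative assumptions.
import json

from lhotse import CutSet, load_manifest

cuts = load_manifest("cuts.jsonl.gz")
assert isinstance(cuts, CutSet)
wanted = json.loads('["cut-001", "cut-017"]')  # same parsing as the command above
subset_cuts = cuts.subset(cut_ids=wanted)
subset_cuts.to_file("cuts_subset.jsonl.gz")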
def test_generic_serialization(manifests, manifest_type, format, compressed):
    manifest = manifests[manifest_type]
    with NamedTemporaryFile(suffix='.' + format + ('.gz' if compressed else '')) as f:
        store_manifest(manifest, f.name)
        restored = load_manifest(f.name)
        assert manifest == restored
def combine(manifests: Pathlike, output_manifest: Pathlike):
    """Load MANIFESTS, combine them into a single one, and write it to OUTPUT_MANIFEST."""
    from lhotse import load_manifest
    from lhotse.manipulation import combine as combine_manifests

    data_set = combine_manifests(*[load_manifest(m) for m in manifests])
    data_set.to_file(output_manifest)
def read_manifests_if_cached(
        dataset_parts: Optional[Sequence[str]],
        output_dir: Optional[Pathlike],
        prefix: str = '',
        suffix: Optional[str] = 'json'
) -> Optional[Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]]:
    """
    Loads manifests from the disk if all of them exist in the specified paths.
    The manifests are searched for using the pattern
    `output_dir / f'{prefix}_{manifest}_{part}.json'`,
    where `manifest` is one of `["recordings", "supervisions"]` and `part` is specified
    in `dataset_parts`.
    This function is intended to speed up data preparation if it has already been done before.

    :param dataset_parts: Names of dataset pieces, e.g. in LibriSpeech: ``["test-clean", "dev-clean", ...]``.
    :param output_dir: Where to look for the files.
    :param prefix: Optional common prefix for the manifest files (underscore is automatically added).
    :param suffix: Optional common suffix for the manifest files ("json" by default).
    :return: A dict with the manifests (``d[dataset_part]['recordings'|'supervisions']``) or ``None``.
    """
    if output_dir is None:
        return None
    if prefix:
        prefix = f'{prefix}_'
    manifests = defaultdict(dict)
    for part in dataset_parts:
        for manifest in ('recordings', 'supervisions'):
            path = output_dir / f'{prefix}{manifest}_{part}.{suffix}'
            if not path.is_file():
                # If one of the manifests is not available, assume we need to read and prepare everything
                # to simplify the rest of the code.
                return None
            manifests[part][manifest] = load_manifest(path)
    return dict(manifests)
def copy(input_manifest, output_manifest):
    """
    Load INPUT_MANIFEST and store it to OUTPUT_MANIFEST.
    Useful for conversion between different serialization formats (e.g. JSON, JSONL, YAML).
    Automatically supports gzip compression when '.gz' suffix is detected.
    """
    data = load_manifest(input_manifest)
    data.to_file(output_manifest)
def subset(manifest: Pathlike, output_manifest: Pathlike,
           first: Optional[int], last: Optional[int]):
    """Load MANIFEST, select the FIRST or LAST number of items and store it in OUTPUT_MANIFEST."""
    output_manifest = Path(output_manifest)
    manifest = Path(manifest)
    any_set = load_manifest(manifest)
    a_subset = any_set.subset(first=first, last=last)
    a_subset.to_file(output_manifest)
def test_load_any_lhotse_manifest_lazy(path, exception_expectation):
    with exception_expectation:
        me = load_manifest(path)

        # some temporary files are needed to convert JSON to JSONL
        with NamedTemporaryFile(suffix=".jsonl.gz") as f:
            me.to_file(f.name)
            f.flush()

            ml = load_manifest_lazy(f.name)
            assert list(me) == list(ml)  # equal under iteration
def test_cuts(self) -> List[CutSet]:
    test_sets = ['test-clean', 'test-other']
    cuts = []
    for test_set in test_sets:
        logging.debug("About to get test cuts")
        cuts.append(
            load_manifest(self.args.feature_dir / f'cuts_{test_set}.json.gz'))
    return cuts
def filter(predicate: str, manifest: Pathlike, output_manifest: Pathlike):
    """
    Filter a MANIFEST according to the rule specified in PREDICATE, and save the result to OUTPUT_MANIFEST.
    It is intended to work generically with most manifest types - it supports RecordingSet,
    SupervisionSet and CutSet.

    \b
    The PREDICATE specifies which attribute is used for item selection. Some examples:
    lhotse filter 'duration>4.5' supervision.json output.json
    lhotse filter 'num_frames<600' cuts.json output.json
    lhotse filter 'start=0' cuts.json output.json
    lhotse filter 'channel!=0' audio.json output.json

    It currently only supports comparison of numerical manifest item attributes, such as:
    start, duration, end, channel, num_frames, num_features, etc.
    """
    data_set = load_manifest(manifest)

    predicate_pattern = re.compile(
        r'(?P<key>\w+)(?P<op>=|==|!=|>|<|>=|<=)(?P<value>[0-9.]+)')
    match = predicate_pattern.match(predicate)
    if match is None:
        raise ValueError(
            "Invalid predicate! Run with --help option to learn what predicates are allowed."
        )

    compare = {
        '<': operator.lt,
        '>': operator.gt,
        '>=': operator.ge,
        '<=': operator.le,
        '=': isclose,
        '==': isclose,
        '!=': complement(isclose)
    }[match.group('op')]
    try:
        value = int(match.group('value'))
    except ValueError:
        value = float(match.group('value'))

    retained_items = []
    try:
        for item in data_set:
            attr = getattr(item, match.group('key'))
            if compare(attr, value):
                retained_items.append(item)
    except AttributeError:
        click.echo(
            f'Invalid predicate! Items in "{manifest}" do not have the attribute "{match.group("key")}"',
            err=True)
        exit(1)

    filtered_data_set = to_manifest(retained_items)
    if filtered_data_set is None:
        click.echo('No items satisfying the predicate.', err=True)
        exit(0)
    filtered_data_set.to_file(output_manifest)
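# ---- Illustration of how a predicate string is interpreted by the command above.
# The predicate and the numbers are made up for demonstration.
import operator
import re

predicate = "duration>4.5"
pattern = re.compile(r"(?P<key>\w+)(?P<op>=|==|!=|>|<|>=|<=)(?P<value>[0-9.]+)")
m = pattern.match(predicate)
key, op, value = m.group("key"), m.group("op"), float(m.group("value"))
compare = operator.gt  # looked up from the same operator table as in the command
# Each manifest item is kept when compare(getattr(item, key), value) is True,
# e.g. a supervision with duration 5.0 passes 'duration>4.5'.
print(key, op, value, compare(5.0, value))  # duration > 4.5 True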
def test_cut_set_decompose_output_dir():
    c = dummy_cut(
        0,
        start=5.0,
        duration=10.0,
        supervisions=[
            dummy_supervision(0, start=0.0),
            dummy_supervision(1, start=6.5)
        ],
    )
    assert c.start == 5.0
    assert c.end == 15.0
    cuts = CutSet.from_cuts([c])

    with TemporaryDirectory() as td:
        td = Path(td)
        recs, sups, feats = cuts.decompose(output_dir=td)
        assert list(recs) == list(load_manifest(td / "recordings.jsonl.gz"))
        assert list(sups) == list(load_manifest(td / "supervisions.jsonl.gz"))
        assert list(feats) == list(load_manifest(td / "features.jsonl.gz"))
def test_cuts(self) -> CutSet:
    if self.args.use_context_for_test:
        path = (
            self.args.feature_dir /
            f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
        )
    else:
        path = self.args.feature_dir / f"gigaspeech_cuts_TEST.jsonl.gz"
    logging.info(f"About to get test cuts from {path}")
    cuts_test = load_manifest(path)
    return cuts_test
def valid_cuts(self) -> CutSet:
    if self.args.use_context_for_test:
        path = (
            self.args.feature_dir /
            f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz")
    else:
        path = self.args.feature_dir / f"gigaspeech_cuts_DEV.jsonl.gz"
    logging.info(f"About to get valid cuts from {path}")
    cuts_valid = load_manifest(path)
    if self.args.small_dev:
        return cuts_valid.subset(first=1000)
    else:
        return cuts_valid
def split(num_splits: int, manifest: Pathlike, output_dir: Pathlike, shuffle: bool):
    """
    Load MANIFEST, split it into NUM_SPLITS equal parts and save as separate manifests in OUTPUT_DIR.
    """
    output_dir = Path(output_dir)
    manifest = Path(manifest)
    suffix = ''.join(manifest.suffixes)
    any_set = load_manifest(manifest)
    parts = any_set.split(num_splits=num_splits, shuffle=shuffle)
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, part in enumerate(parts):
        part.to_file(
            (output_dir / manifest).with_suffix(f'.{idx + 1}{suffix}'))
def read_if_cached(
        dataset_parts: Optional[Tuple[str]],
        output_dir: Optional[Pathlike]
) -> Optional[Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]]:
    if output_dir is None:
        return None
    manifests = defaultdict(dict)
    for part in dataset_parts:
        for manifest in ('recordings', 'supervisions'):
            path = output_dir / f'{manifest}_{part}.json'
            if not path.is_file():
                # If one of the manifests is not available, assume we need to read and prepare everything
                # to simplify the rest of the code.
                return None
            manifests[part][manifest] = load_manifest(path)
    return dict(manifests)
def split(num_splits: int, manifest: Pathlike, output_dir: Pathlike, shuffle: bool):
    """
    Load MANIFEST, split it into NUM_SPLITS equal parts and save as separate manifests in OUTPUT_DIR.
    """
    from lhotse import load_manifest

    output_dir = Path(output_dir)
    manifest = Path(manifest)
    suffix = "".join(manifest.suffixes)
    any_set = load_manifest(manifest)
    parts = any_set.split(num_splits=num_splits, shuffle=shuffle)
    output_dir.mkdir(parents=True, exist_ok=True)
    num_digits = len(str(num_splits))
    for idx, part in enumerate(parts):
        idx = f"{idx + 1}".zfill(num_digits)
        part.to_file(
            (output_dir / manifest.stem).with_suffix(f".{idx}{suffix}"))
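# ---- Filename construction used by the split command above, shown in isolation.
# The manifest name, output directory, and split count are illustrative assumptions.
from pathlib import Path

manifest = Path("cuts.jsonl.gz")
output_dir = Path("splits")
suffix = "".join(manifest.suffixes)   # '.jsonl.gz'
num_splits = 10
num_digits = len(str(num_splits))     # 2 -> zero-padded indices 01..10
for idx in range(num_splits):
    idx_str = f"{idx + 1}".zfill(num_digits)
    print((output_dir / manifest.stem).with_suffix(f".{idx_str}{suffix}"))
    # splits/cuts.01.jsonl.gz, splits/cuts.02.jsonl.gz, ..., splits/cuts.10.jsonl.gz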
def read_cv_manifests_if_cached(
    output_dir: Optional[Pathlike],
    language: str,
) -> Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]:
    """
    Returns:
        {'train': {'recordings': ..., 'supervisions': ...}, 'dev': ..., 'test': ...}
    """
    if output_dir is None:
        return {}
    manifests = defaultdict(dict)
    for part in ["train", "dev", "test"]:
        for manifest in ["recordings", "supervisions"]:
            path = output_dir / f"cv_{manifest}_{language}_{part}.jsonl.gz"
            if not path.is_file():
                continue
            manifests[part][manifest] = load_manifest(path)
    return manifests
def test_cut_set_decompose_output_dir_doesnt_duplicate_recording():
    c = dummy_cut(0)
    c2 = dummy_cut(0)
    c2.id = "dummy-cut-0001"  # override cut ID, retain identical recording ID as `c`
    cuts = CutSet.from_cuts([c, c2])

    with TemporaryDirectory() as td:
        td = Path(td)
        cuts.decompose(output_dir=td)

        text = load_jsonl(td / "recordings.jsonl.gz")
        print(list(text))

        recs = load_manifest(td / "recordings.jsonl.gz")
        assert isinstance(recs, RecordingSet)
        # deduplicated recording
        assert len(recs) == 1
        assert recs[0].id == "dummy-recording-0000"
def read_manifests_if_cached(
    dataset_parts: Optional[Sequence[str]],
    output_dir: Optional[Pathlike],
    prefix: str = "",
    suffix: Optional[str] = "jsonl.gz",
    types: Iterable[str] = DEFAULT_DETECTED_MANIFEST_TYPES,
    lazy: bool = False,
) -> Optional[Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]]:
    """
    Loads manifests from the disk, or a subset of them if only some exist.
    The manifests are searched for using the pattern
    ``output_dir / f'{prefix}_{manifest}_{part}.json'``,
    where `manifest` is one of ``["recordings", "supervisions"]`` and `part` is specified
    in ``dataset_parts``.
    This function is intended to speed up data preparation if it has already been done before.

    :param dataset_parts: Names of dataset pieces, e.g. in LibriSpeech: ``["test-clean", "dev-clean", ...]``.
    :param output_dir: Where to look for the files.
    :param prefix: Optional common prefix for the manifest files (underscore is automatically added).
    :param suffix: Optional common suffix for the manifest files ("jsonl.gz" by default).
    :param types: Which types of manifests are searched for (default: 'recordings' and 'supervisions').
    :return: A dict with the manifests (``d[dataset_part]['recordings'|'supervisions']``) or ``None``.
    """
    if output_dir is None:
        return None
    if prefix and not prefix.endswith("_"):
        prefix = f"{prefix}_"
    if suffix.startswith("."):
        suffix = suffix[1:]
    if lazy and not suffix.startswith("jsonl"):
        raise ValueError(
            f"Only JSONL manifests can be opened lazily (got suffix: '{suffix}')"
        )
    manifests = defaultdict(dict)
    output_dir = Path(output_dir)
    for part in dataset_parts:
        for manifest in types:
            path = output_dir / f"{prefix}{manifest}_{part}.{suffix}"
            if not path.is_file():
                continue
            if lazy:
                manifests[part][manifest] = TYPES_TO_CLASSES[
                    manifest].from_jsonl_lazy(path)
            else:
                manifests[part][manifest] = load_manifest(path)
    return dict(manifests)
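# ---- Usage sketch for the cached-manifest lookup above. The dataset parts, prefix,
# and directory are illustrative assumptions; in Lhotse this helper is typically
# imported from lhotse.recipes.utils.
from pathlib import Path

from lhotse.recipes.utils import read_manifests_if_cached

manifests = read_manifests_if_cached(
    dataset_parts=["dev-clean", "test-clean"],
    output_dir=Path("data/manifests"),
    prefix="librispeech",
    suffix="jsonl.gz",
)
for part, part_manifests in (manifests or {}).items():
    # Each part maps to {'recordings': RecordingSet, 'supervisions': SupervisionSet}
    # for whichever manifest files were already found on disk.
    recs = part_manifests.get("recordings")
    sups = part_manifests.get("supervisions")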
def simple(
    output_cut_manifest: Pathlike,
    recording_manifest: Optional[Pathlike],
    feature_manifest: Optional[Pathlike],
    supervision_manifest: Optional[Pathlike],
):
    """
    Create a CutSet stored in OUTPUT_CUT_MANIFEST. Depending on the provided options,
    it may contain any combination of recording, feature and supervision manifests.
    Either RECORDING_MANIFEST or FEATURE_MANIFEST has to be provided.
    When SUPERVISION_MANIFEST is provided, the cuts time span will correspond to that of the supervision segments.
    Otherwise, that time span corresponds to the one found in features, if available, otherwise recordings.
    """
    supervision_set, feature_set, recording_set = [
        load_manifest(p) if p is not None else None
        for p in (supervision_manifest, feature_manifest, recording_manifest)
    ]
    cut_set = CutSet.from_manifests(recordings=recording_set,
                                    supervisions=supervision_set,
                                    features=feature_set)
    cut_set.to_file(output_cut_manifest)
def read_manifests_if_cached(
        dataset_parts: Optional[Sequence[str]],
        output_dir: Optional[Pathlike]
) -> Optional[Dict[str, Dict[str, Union[RecordingSet, SupervisionSet]]]]:
    """
    Loads manifests from the disk if all of them exist in the specified paths.
    The manifests are searched for using the pattern `output_dir / f'{manifest}_{part}.json'`,
    where `manifest` is one of `["recordings", "supervisions"]` and `part` is specified in `dataset_parts`.
    This function is intended to speed up data preparation if it has already been done before.
    """
    if output_dir is None:
        return None
    manifests = defaultdict(dict)
    for part in dataset_parts:
        for manifest in ('recordings', 'supervisions'):
            path = output_dir / f'{manifest}_{part}.json'
            if not path.is_file():
                # If one of the manifests is not available, assume we need to read and prepare everything
                # to simplify the rest of the code.
                return None
            manifests[part][manifest] = load_manifest(path)
    return dict(manifests)
def main():
    args = get_parser().parse_args()

    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_duration = args.max_duration
    accum_grad = args.accum_grad
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.tensorboard else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = load_manifest(feature_dir / 'cuts_train-clean-100.json.gz')
    if args.full_libri:
        cuts_train = (
            cuts_train +
            load_manifest(feature_dir / 'cuts_train-clean-360.json.gz') +
            load_manifest(feature_dir / 'cuts_train-other-500.json.gz'))

    logging.info("About to get dev cuts")
    cuts_dev = (load_manifest(feature_dir / 'cuts_dev-clean.json.gz') +
                load_manifest(feature_dir / 'cuts_dev-other.json.gz'))
    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor {args.duration_factor} and gap {args.gap}.'
        )
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)
        ] + transforms

    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=transforms,
                                       input_transforms=[
                                           SpecAugment(num_frame_masks=2,
                                                       features_mask_size=27,
                                                       num_feature_masks=2,
                                                       frames_mask_size=100)
                                       ])

    if args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        from lhotse.features.fbank import FbankConfig
        train = K2SpeechRecognitionDataset(
            cuts=cuts_train,
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))),
            input_transforms=[
                SpecAugment(num_frame_masks=2,
                            features_mask_size=27,
                            num_feature_masks=2,
                            frames_mask_size=100)
            ])

    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_duration=max_duration,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=max_duration,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
    )

    logging.info("About to create dev dataset")
    if args.on_the_fly_feats:
        cuts_dev = cuts_dev.drop_features()
        validate = K2SpeechRecognitionDataset(
            cuts_dev.drop_features(),
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
    else:
        validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(
        cuts_dev,
        max_duration=max_duration,
    )
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    model.to(device)
    describe(model)

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def train_dataloaders(self) -> DataLoader:
    logging.info("About to get train cuts")
    cuts_train = self.train_cuts()

    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(self.args.feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if self.args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor '
            f'{self.args.duration_factor} and gap {self.args.gap}.')
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=self.args.duration_factor,
                           gap=self.args.gap)
        ] + transforms

    input_transforms = [
        SpecAugment(num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100)
    ]

    train = K2SpeechRecognitionDataset(
        cut_transforms=transforms,
        input_transforms=input_transforms,
        return_cuts=True,
    )

    if self.args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(
                Fbank(FbankConfig(num_mel_bins=80))),
            input_transforms=input_transforms,
            return_cuts=True,
        )

    if self.args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
            num_buckets=self.args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
        )

    logging.info("About to create train dataloader")
    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
        persistent_workers=True,
    )
    return train_dl
def validate_(manifest: Pathlike, read_data: bool):
    """Validate a Lhotse manifest file."""
    from lhotse import load_manifest, validate

    data = load_manifest(manifest)
    validate(data, read_data=read_data)
def valid_cuts(self) -> CutSet:
    logging.info("About to get dev cuts")
    cuts_valid = (
        load_manifest(self.args.feature_dir / 'cuts_dev-clean.json.gz') +
        load_manifest(self.args.feature_dir / 'cuts_dev-other.json.gz'))
    return cuts_valid