def test_k2_speech_recognition_on_the_fly_feature_extraction_with_randomized_smoothing(
    k2_cut_set,
):
    dataset = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(extractor=Fbank()))
    rs_dataset = K2SpeechRecognitionDataset(input_strategy=OnTheFlyFeatures(
        extractor=Fbank(),
        # Use p=1.0 to ensure that smoothing is applied in this test.
        wave_transforms=[RandomizedSmoothing(sigma=0.5, p=1.0)],
    ))
    sampler = SingleCutSampler(k2_cut_set, shuffle=False, max_cuts=1)
    for cut_ids in sampler:
        batch = dataset[cut_ids]
        rs_batch = rs_dataset[cut_ids]
        # Additive noise should cause the energies to go up
        assert (rs_batch["inputs"] - batch["inputs"]).sum() > 0
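The test snippets in this listing omit their imports and pytest fixtures (k2_cut_set is a fixture that provides a lhotse CutSet). A minimal sketch of the imports they assume, for a reasonably recent lhotse version (exact module paths may differ between releases):

import torch
from lhotse import Fbank, FbankConfig
from lhotse.dataset import (
    K2SpeechRecognitionDataset,
    OnTheFlyFeatures,
    PrecomputedFeatures,
    SimpleCutSampler,
    SingleCutSampler,  # the older name of SimpleCutSampler
)
from lhotse.dataset.signal_transforms import RandomizedSmoothing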
Example #2
    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        cuts = self.test_cuts()
        is_list = isinstance(cuts, list)
        test_loaders = []
        if not is_list:
            cuts = [cuts]

        for cuts_test in cuts:
            logging.debug("About to create test dataset")
            test = K2SpeechRecognitionDataset(
                input_strategy=(
                    OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                    if self.args.on_the_fly_feats else PrecomputedFeatures()),
                return_cuts=True,
            )
            sampler = SingleCutSampler(cuts_test,
                                       max_duration=self.args.max_duration)
            logging.debug("About to create test dataloader")
            test_dl = DataLoader(test,
                                 batch_size=None,
                                 sampler=sampler,
                                 num_workers=1)
            test_loaders.append(test_dl)

        if is_list:
            return test_loaders
        else:
            return test_loaders[0]
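A hypothetical caller of this method (datamodule stands for an instance of the class that defines it, which is not shown in the snippet); it normalizes the single-loader and list-of-loaders cases:

loaders = datamodule.test_dataloaders()
if not isinstance(loaders, list):
    loaders = [loaders]
for test_dl in loaders:
    for batch in test_dl:
        ...  # run evaluation on each batch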
Example #3
def test_k2_speech_recognition_on_the_fly_feature_extraction(
        k2_cut_set, use_batch_extract, fault_tolerant):
    precomputed_dataset = K2SpeechRecognitionDataset()
    on_the_fly_dataset = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(
            Fbank(FbankConfig(num_mel_bins=40)),
            use_batch_extract=use_batch_extract,
            fault_tolerant=fault_tolerant,
        ))
    sampler = SimpleCutSampler(k2_cut_set, shuffle=False, max_cuts=1)
    for cut_ids in sampler:
        batch_pc = precomputed_dataset[cut_ids]
        batch_otf = on_the_fly_dataset[cut_ids]

        # Check that the features do not differ too much.
        norm_pc = torch.linalg.norm(batch_pc["inputs"])
        norm_diff = torch.linalg.norm(batch_pc["inputs"] - batch_otf["inputs"])
        # The precomputed and on-the-fly features are different due to mixing in time/fbank domains
        # and lilcom compression.
        assert norm_diff < 0.01 * norm_pc

        # Check that the supervision boundaries are the same.
        assert (batch_pc["supervisions"]["start_frame"] ==
                batch_otf["supervisions"]["start_frame"]).all()
        assert (batch_pc["supervisions"]["num_frames"] ==
                batch_otf["supervisions"]["num_frames"]).all()
Example #4
    def valid_dataloaders(self) -> DataLoader:
        logging.info("About to get dev cuts")
        cuts_valid = self.valid_cuts()

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            cuts_valid = cuts_valid.drop_features()
            validate = K2SpeechRecognitionDataset(
                cuts_valid,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))))
        else:
            validate = K2SpeechRecognitionDataset(cuts_valid)
        valid_sampler = SingleCutSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=True,
        )
        return valid_dl
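For reference, a sketch of consuming the returned loader; K2SpeechRecognitionDataset yields dict batches with the keys used throughout these examples:

for batch in valid_dl:
    feats = batch["inputs"]                        # (N, T, F) padded features
    starts = batch["supervisions"]["start_frame"]  # per-supervision offsets
    frames = batch["supervisions"]["num_frames"]   # per-supervision lengths
    break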
Example #5
    def valid_dataloaders(self) -> DataLoader:
        self.validate_args()
        logging.info("About to get dev cuts")
        cuts_valid = self.valid_cuts()

        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(duration_factor=self.args.duration_factor,
                               gap=self.args.gap)
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80)), num_workers=8),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = SingleCutSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        # valid_dl = DataLoader(
        #    validate,
        #    sampler=valid_sampler,
        #    batch_size=None,
        #    num_workers=8,
        #    persistent_workers=True,
        # )
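        # Note: LhotseDataLoader is not part of PyTorch; it is an experimental
        # loader used in some k2/icefall recipes. The commented-out torch
        # DataLoader above is the standard alternative.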
        valid_dl = LhotseDataLoader(
            validate,
            sampler=valid_sampler,
            num_workers=2,
        )
        return valid_dl
Example #6
def test_k2_speech_recognition_on_the_fly_feature_extraction(k2_cut_set):
    precomputed_dataset = K2SpeechRecognitionDataset(k2_cut_set)
    on_the_fly_dataset = K2SpeechRecognitionDataset(
        k2_cut_set.drop_features(),
        input_strategy=OnTheFlyFeatures(Fbank())
    )
    sampler = SingleCutSampler(k2_cut_set, shuffle=False, max_cuts=1)
    for cut_ids in sampler:
        batch_pc = precomputed_dataset[cut_ids]
        batch_otf = on_the_fly_dataset[cut_ids]

        # Check that the features do not differ too much.
        norm_pc = torch.linalg.norm(batch_pc['inputs'])
        norm_diff = torch.linalg.norm(batch_pc['inputs'] - batch_otf['inputs'])
        # The precomputed and on-the-fly features are different due to mixing in time/fbank domains
        # and lilcom compression.
        assert norm_diff < 0.01 * norm_pc

        # Check that the supervision boundaries are the same.
        assert (batch_pc['supervisions']['start_frame'] == batch_otf['supervisions']['start_frame']).all()
        assert (batch_pc['supervisions']['num_frames'] == batch_otf['supervisions']['num_frames']).all()
Example #7
def main():
    args = get_parser().parse_args()

    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_duration = args.max_duration
    accum_grad = args.accum_grad
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.tensorboard else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = load_manifest(feature_dir / 'cuts_train-clean-100.json.gz')
    if args.full_libri:
        cuts_train = (
            cuts_train +
            load_manifest(feature_dir / 'cuts_train-clean-360.json.gz') +
            load_manifest(feature_dir / 'cuts_train-other-500.json.gz'))
    logging.info("About to get dev cuts")
    cuts_dev = (load_manifest(feature_dir / 'cuts_dev-clean.json.gz') +
                load_manifest(feature_dir / 'cuts_dev-other.json.gz'))
    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor {args.duration_factor} and gap {args.gap}.'
        )
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)
        ] + transforms
    train = K2SpeechRecognitionDataset(cuts_train,
                                       cut_transforms=transforms,
                                       input_transforms=[
                                           SpecAugment(num_frame_masks=2,
                                                       features_mask_size=27,
                                                       num_feature_masks=2,
                                                       frames_mask_size=100)
                                       ])

    if args.on_the_fly_feats:
        # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
        # # Add on-the-fly speed perturbation; since originally it would have increased epoch
        # # size by 3, we will apply prob 2/3 and use 3x more epochs.
        # # Speed perturbation probably should come first before concatenation,
        # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        from lhotse.features.fbank import FbankConfig
        train = K2SpeechRecognitionDataset(
            cuts=cuts_train,
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))),
            input_transforms=[
                SpecAugment(num_frame_masks=2,
                            features_mask_size=27,
                            num_feature_masks=2,
                            frames_mask_size=100)
            ])

    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_duration=max_duration,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=max_duration,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
    )

    logging.info("About to create dev dataset")
    if args.on_the_fly_feats:
        cuts_dev = cuts_dev.drop_features()
        validate = K2SpeechRecognitionDataset(
            cuts_dev,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
    else:
        validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(
        cuts_dev,
        max_duration=max_duration,
    )
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)

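    # A non-zero att_rate enables the auxiliary attention-decoder loss,
    # which needs decoder layers; att_rate == 0 means CTC-only training.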
    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    model.to(device)
    describe(model)

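    # Noam is the LR schedule from "Attention Is All You Need": linear
    # warm-up over the first warm_step batches, then inverse-sqrt decay.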
    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example #8
    def train_dataloaders(self) -> DataLoader:
        logging.info("About to get train cuts")
        cuts_train = self.train_cuts()

        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.feature_dir /
                                   'cuts_musan.json.gz')

        logging.info("About to create train dataset")
        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
        if self.args.concatenate_cuts:
            logging.info(
                f'Using cut concatenation with duration factor '
                f'{self.args.duration_factor} and gap {self.args.gap}.')
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
            transforms = [
                CutConcatenate(duration_factor=self.args.duration_factor,
                               gap=self.args.gap)
            ] + transforms

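        # SpecAugment: 2 time masks and 2 frequency masks per utterance
        # (mask sizes: 100 frames and 27 mel bins, respectively).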
        input_transforms = [
            SpecAugment(num_frame_masks=2,
                        features_mask_size=27,
                        num_feature_masks=2,
                        frames_mask_size=100)
        ]

        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=True,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we remove it from data prep stage.
            # # Add on-the-fly speed perturbation; since originally it would have increased epoch
            # # size by 3, we will apply prob 2/3 and use 3x more epochs.
            # # Speed perturbation probably should come first before concatenation,
            # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
            # Drop feats to be on the safe side.
            cuts_train = cuts_train.drop_features()
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=True,
            )

        if self.args.bucketing_sampler:
            logging.info('Using BucketingSampler.')
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets)
        else:
            logging.info('Using SingleCutSampler.')
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=True,
        )
        return train_dl
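A hypothetical sketch of wiring these methods into a training loop (AsrDataModule and num_epochs are assumptions, not part of the snippet):

datamodule = AsrDataModule(args)  # hypothetical owner of these methods
train_dl = datamodule.train_dataloaders()
valid_dl = datamodule.valid_dataloaders()
for epoch in range(num_epochs):
    train_dl.sampler.set_epoch(epoch)  # deterministic per-epoch reshuffling
    # ... run one epoch of training, then validate with valid_dl ...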
Example #9
def main():
    args = get_parser().parse_args()

    model_type = args.model_type
    epoch = args.epoch
    max_duration = args.max_duration
    avg = args.avg
    att_rate = args.att_rate

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

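    # The HLG decoding graph (H: CTC topology, L: lexicon, G: language model)
    # is compiled once and cached as lang_dir / 'HLG.pt' for subsequent runs.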
    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    # load dataset
    feature_dir = Path('exp/data')
    test_sets = ['test-clean', 'test-other']
    for test_set in test_sets:
        logging.info(f'* DECODING: {test_set}')

        logging.debug("About to get test cuts")
        cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
        logging.debug("About to create test dataset")
        from lhotse.dataset.input_strategies import OnTheFlyFeatures
        from lhotse import Fbank, FbankConfig
        test = K2SpeechRecognitionDataset(
            cuts_test,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
        sampler = SingleCutSampler(cuts_test, max_duration=max_duration)
        logging.debug("About to create test dataloader")
        test_dl = torch.utils.data.DataLoader(test,
                                              batch_size=None,
                                              sampler=sampler,
                                              num_workers=1)

        logging.debug("About to decode")
        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')
        # compute WER
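        # edit_distance (presumably kaldialign's) returns per-utterance
        # counts keyed by 'sub', 'ins', 'del', and 'total'.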
        dists = [edit_distance(r, h) for r, h in results]
        errors = {
            key: sum(dist[key] for dist in dists)
            for key in ['sub', 'ins', 'del', 'total']
        }
        total_words = sum(len(ref) for ref, _ in results)
        # Print Kaldi-like message:
        # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
        logging.info(
            f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
            f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
        )