Example #1
0
def test_specaugment_load_state_dict():
    """Check that ``load_state_dict`` sets every SpecAugment hyperparameter."""
    # all values non-default
    non_default_state = {
        "time_warp_factor": 85,
        "num_feature_masks": 2,
        "features_mask_size": 12,
        "num_frame_masks": 2,
        "frames_mask_size": 71,
        "max_frames_mask_fraction": 0.25,
        "p": 0.6,
    }
    transform = SpecAugment()
    transform.load_state_dict(non_default_state)

    # Each loaded entry must be readable back as an attribute of the transform.
    for name in non_default_state:
        assert getattr(transform, name) == non_default_state[name]
Example #2
0
def test_specaugment_state_dict():
    """Check that ``state_dict`` reports the hyperparameters passed at construction."""
    # all values default
    expected_state = {
        "time_warp_factor": 80,
        "num_feature_masks": 1,
        "features_mask_size": 13,
        "num_frame_masks": 1,
        "frames_mask_size": 70,
        "max_frames_mask_fraction": 0.2,
        "p": 0.5,
    }
    transform = SpecAugment(**expected_state)
    exported = transform.state_dict()

    # Every constructor argument must appear in the exported state unchanged.
    for name in expected_state:
        assert exported[name] == expected_state[name]
Example #3
0
def test_specaugment_2d_input_raises_error():
    """Check that SpecAugment raises ``AssertionError`` on a single cut's features.

    The features come straight from ``cuts[0].load_features()`` without any
    collation; per the test name this is the un-batched 2D input case.
    """
    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    feats = torch.from_numpy(cuts[0].load_features())
    tfnm = SpecAugment(p=1.0, time_warp_factor=10)
    # Keep only the call under test inside the ``raises`` block. An extra
    # ``assert`` in here could itself raise AssertionError and make the test
    # pass even if the transform stopped rejecting this input.
    with pytest.raises(AssertionError):
        tfnm(feats)
def test_specaugment_batch(num_feature_masks, num_frame_masks):
    """Check that SpecAugment modifies a collated batch of features.

    Args:
        num_feature_masks: forwarded to ``SpecAugment(num_feature_masks=...)``.
        num_frame_masks: forwarded to ``SpecAugment(num_frame_masks=...)``.
    """
    cut_set = CutSet.from_json('test/fixtures/ljspeech/cuts.json')
    # The second element of the collation result is not needed here.
    batch, _ = collate_features(cut_set)
    spec_augment = SpecAugment(
        p=1.0,
        time_warp_factor=10,
        features_mask_size=5,
        frames_mask_size=20,
        num_feature_masks=num_feature_masks,
        num_frame_masks=num_frame_masks,
    )
    augmented = spec_augment(batch)
    # With p=1.0 at least one element is expected to differ from the input.
    differs = batch != augmented
    assert differs.any()
def main():
    """Train a Transformer or Conformer model, driven by command-line arguments.

    Steps performed:

    1. Fix the random seed and set up logging (plus a TensorBoard writer when
       ``args.tensorboard`` is set) under ``exp-<model_type>-noam-ctc-att-musan-sa``.
    2. Load the phone/word symbol tables and the inverted lexicon; the latter
       is cached as ``Linv.pt`` the first time it is built from ``L.fst.txt``.
    3. Load the train/dev/Musan cut manifests and build the datasets,
       samplers and dataloaders.
    4. Build the model and the Noam optimizer, resuming from the checkpoint of
       epoch ``start_epoch - 1`` when ``start_epoch > 0``.
    5. Run the remaining epochs, saving a checkpoint and an info file after
       every epoch and additionally the model with the lowest validation objf.

    Exits the process with status -1 if CUDA is not available.
    """
    args = get_parser().parse_args()

    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    max_duration = args.max_duration
    accum_grad = args.accum_grad
    att_rate = args.att_rate

    fix_random_seed(42)

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(
        log_dir=f'{exp_dir}/tensorboard') if args.tensorboard else None

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        # Cache the inverted, arc-sorted lexicon so later runs can skip this.
        L_inv = k2.arc_sort(L.invert_())
        torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(L_inv=L_inv,
                                              phones=phone_symbol_table,
                                              words=word_symbol_table)
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = load_manifest(feature_dir / 'cuts_train-clean-100.json.gz')
    if args.full_libri:
        cuts_train = (
            cuts_train +
            load_manifest(feature_dir / 'cuts_train-clean-360.json.gz') +
            load_manifest(feature_dir / 'cuts_train-other-500.json.gz'))
    logging.info("About to get dev cuts")
    cuts_dev = (load_manifest(feature_dir / 'cuts_dev-clean.json.gz') +
                load_manifest(feature_dir / 'cuts_dev-other.json.gz'))
    logging.info("About to get Musan cuts")
    cuts_musan = load_manifest(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
    if args.concatenate_cuts:
        logging.info(
            f'Using cut concatenation with duration factor {args.duration_factor} and gap {args.gap}.'
        )
        # Cut concatenation should be the first transform in the list,
        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
        transforms = [
            CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)
        ] + transforms

    # The same SpecAugment configuration is used for both feature pipelines.
    input_transforms = [
        SpecAugment(num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100)
    ]

    # Build the training dataset exactly once, for whichever pipeline is used.
    if args.on_the_fly_feats:
        # NOTE: a PerturbSpeed transform should be added here only if it is
        # removed from the data prep stage. Since it would originally have
        # increased the epoch size by 3, it would be applied with prob 2/3 and
        # 3x more epochs, e.g. PerturbSpeed(factors=[0.9, 1.1], p=2 / 3) placed
        # before the other cut transforms.
        # Drop feats to be on the safe side.
        cuts_train = cuts_train.drop_features()
        from lhotse.features.fbank import FbankConfig
        train = K2SpeechRecognitionDataset(
            cuts=cuts_train,
            cut_transforms=transforms,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))),
            input_transforms=input_transforms)
    else:
        train = K2SpeechRecognitionDataset(cuts_train,
                                           cut_transforms=transforms,
                                           input_transforms=input_transforms)

    if args.bucketing_sampler:
        logging.info('Using BucketingSampler.')
        train_sampler = BucketingSampler(cuts_train,
                                         max_duration=max_duration,
                                         shuffle=True,
                                         num_buckets=args.num_buckets)
    else:
        logging.info('Using SingleCutSampler.')
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=max_duration,
            shuffle=True,
        )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4,
    )

    logging.info("About to create dev dataset")
    if args.on_the_fly_feats:
        # FbankConfig was imported above, under the same flag.
        cuts_dev = cuts_dev.drop_features()
        validate = K2SpeechRecognitionDataset(
            cuts_dev,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
    else:
        validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(
        cuts_dev,
        max_duration=max_duration,
    )
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(validate,
                                           sampler=valid_sampler,
                                           batch_size=None,
                                           num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)

    # No decoder layers are created when the attention rate is zero.
    num_decoder_layers = 6 if att_rate != 0.0 else 0

    # Both model classes are constructed with the same arguments.
    model_cls = Transformer if model_type == "transformer" else Conformer
    model = model_cls(
        num_features=80,
        nhead=args.nhead,
        d_model=args.attention_dim,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4,
        num_decoder_layers=num_decoder_layers)

    model.to(device)
    describe(model)

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        # Resume from the checkpoint written at the end of the previous epoch.
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    if tb_writer is not None:
        # Release the writer so buffered events are written out before exit.
        tb_writer.close()

    logging.warning('Done')
Example #6
0
    def train_dataloaders(self) -> DataLoader:
        """Build the training ``DataLoader`` from the options in ``self.args``.

        The cuts come from ``self.train_cuts()``. Musan cuts are mixed in via
        ``CutMix``, optionally preceded by ``CutConcatenate``, and SpecAugment
        is applied as an input transform. Features are computed on the fly
        when ``self.args.on_the_fly_feats`` is set; otherwise the dataset's
        default input strategy is used.

        Returns:
            A ``DataLoader`` over a ``K2SpeechRecognitionDataset`` driven by a
            ``BucketingSampler`` or a ``SingleCutSampler``.
        """
        logging.info("About to get train cuts")
        cuts_train = self.train_cuts()

        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.feature_dir /
                                   'cuts_musan.json.gz')

        logging.info("About to create train dataset")
        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
        if self.args.concatenate_cuts:
            logging.info(
                f'Using cut concatenation with duration factor '
                f'{self.args.duration_factor} and gap {self.args.gap}.')
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
            transforms = [
                CutConcatenate(duration_factor=self.args.duration_factor,
                               gap=self.args.gap)
            ] + transforms

        input_transforms = [
            SpecAugment(num_frame_masks=2,
                        features_mask_size=27,
                        num_feature_masks=2,
                        frames_mask_size=100)
        ]

        # Build the dataset exactly once, for whichever pipeline is used.
        if self.args.on_the_fly_feats:
            # NOTE: a PerturbSpeed transform should be added here only if it is
            # removed from the data prep stage. Since it would originally have
            # increased the epoch size by 3, it would be applied with prob 2/3
            # and 3x more epochs, e.g. PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)
            # placed before the other cut transforms.
            # Drop feats to be on the safe side.
            cuts_train = cuts_train.drop_features()
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=True,
            )
        else:
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_transforms=input_transforms,
                return_cuts=True,
            )

        if self.args.bucketing_sampler:
            logging.info('Using BucketingSampler.')
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets)
        else:
            logging.info('Using SingleCutSampler.')
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=True,
        )
        return train_dl
Example #7
0
def test_specaugment_single():
    """Check that SpecAugment modifies the features of a single cut."""
    cut_set = CutSet.from_json('test/fixtures/ljspeech/cuts.json')
    first_cut_feats = cut_set[0].load_features()
    original = torch.from_numpy(first_cut_feats)
    spec_augment = SpecAugment(p=1.0, time_warp_factor=10)
    augmented = spec_augment(original)
    # With p=1.0 at least one element is expected to differ from the input.
    differs = original != augmented
    assert differs.any()