Example #1
def load_openvino_model():
    logdir = os.path.join('logs', FLAGS.name)

    tokenizer = HuggingFaceTokenizer(cache_dir=logdir,
                                     vocab_size=FLAGS.bpe_size)

    _, transform, input_size = build_transform(feature_type=FLAGS.feature,
                                               feature_size=FLAGS.feature_size,
                                               n_fft=FLAGS.n_fft,
                                               win_length=FLAGS.win_length,
                                               hop_length=FLAGS.hop_length,
                                               delta=FLAGS.delta,
                                               cmvn=FLAGS.cmvn,
                                               downsample=FLAGS.downsample,
                                               pad_to_divisible=False,
                                               T_mask=FLAGS.T_mask,
                                               T_num_mask=FLAGS.T_num_mask,
                                               F_mask=FLAGS.F_mask,
                                               F_num_mask=FLAGS.F_num_mask)

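    # Read the exported OpenVINO IR files (.xml graph, .bin weights) for each
    # sub-network and compile them for CPU execution.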
    ie = IECore()
    encoder_net = ie.read_network(model=os.path.join(logdir, 'encoder.xml'),
                                  weights=os.path.join(logdir, 'encoder.bin'))
    encoder = ie.load_network(network=encoder_net, device_name='CPU')

    decoder_net = ie.read_network(model=os.path.join(logdir, 'decoder.xml'),
                                  weights=os.path.join(logdir, 'decoder.bin'))
    decoder = ie.load_network(network=decoder_net, device_name='CPU')

    joint_net = ie.read_network(model=os.path.join(logdir, 'joint.xml'),
                                weights=os.path.join(logdir, 'joint.bin'))
    joint = ie.load_network(network=joint_net, device_name='CPU')

    return encoder, decoder, joint, tokenizer, transform
Example #2
    def __init__(self, FLAGS):
        self.FLAGS = FLAGS
        logdir = os.path.join('logs', FLAGS.name)

        self.tokenizer = HuggingFaceTokenizer(
            cache_dir=logdir, vocab_size=FLAGS.bpe_size)

        _, self.transform, input_size = build_transform(
            feature_type=FLAGS.feature, feature_size=FLAGS.feature_size,
            n_fft=FLAGS.n_fft, win_length=FLAGS.win_length,
            hop_length=FLAGS.hop_length, delta=FLAGS.delta, cmvn=FLAGS.cmvn,
            downsample=FLAGS.downsample, pad_to_divisible=False,
            T_mask=FLAGS.T_mask, T_num_mask=FLAGS.T_num_mask,
            F_mask=FLAGS.F_mask, F_num_mask=FLAGS.F_num_mask)

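        # Read the encoder/decoder/joint IR files and compile each network for
        # the CPU device.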
        ie = IECore()
        encoder_net = ie.read_network(
            model=os.path.join(logdir, 'encoder.xml'),
            weights=os.path.join(logdir, 'encoder.bin'))
        self.encoder = ie.load_network(network=encoder_net, device_name='CPU')

        decoder_net = ie.read_network(
            model=os.path.join(logdir, 'decoder.xml'),
            weights=os.path.join(logdir, 'decoder.bin'))
        self.decoder = ie.load_network(network=decoder_net, device_name='CPU')

        joint_net = ie.read_network(
            model=os.path.join(logdir, 'joint.xml'),
            weights=os.path.join(logdir, 'joint.bin'))
        self.joint = ie.load_network(network=joint_net, device_name='CPU')

        self.reset_profile()
        self.reset()
Example #3
    def __init__(self):
        super(ParallelTraining, self).__init__()
        _, _, input_size = build_transform(feature_type=FLAGS.feature,
                                           feature_size=FLAGS.feature_size,
                                           n_fft=FLAGS.n_fft,
                                           win_length=FLAGS.win_length,
                                           hop_length=FLAGS.hop_length,
                                           delta=FLAGS.delta,
                                           cmvn=FLAGS.cmvn,
                                           downsample=FLAGS.downsample,
                                           T_mask=FLAGS.T_mask,
                                           T_num_mask=FLAGS.T_num_mask,
                                           F_mask=FLAGS.F_mask,
                                           F_num_mask=FLAGS.F_num_mask)
        self.log_path = None
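        # RNN-T loss with NUL as the blank label index.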
        self.loss_fn = RNNTLoss(blank=NUL)

        if FLAGS.tokenizer == 'char':
            self.tokenizer = CharTokenizer(cache_dir=self.logdir)
        else:
            self.tokenizer = HuggingFaceTokenizer(cache_dir='BPE-2048',
                                                  vocab_size=FLAGS.bpe_size)
        self.vocab_size = self.tokenizer.vocab_size
        print(FLAGS.enc_type)

        self.model = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.vocab_size,
            input_size=input_size,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
            module_type=FLAGS.enc_type,
            output_loss=False,
        )
        self.latest_alignment = None
        self.steps = 0
        self.epoch = 0
        self.best_wer = 1000
Example #4
    def __init__(self, FLAGS):
        self.FLAGS = FLAGS
        logdir = os.path.join('logs', FLAGS.name)

        self.tokenizer = HuggingFaceTokenizer(
            cache_dir='BPE-'+str(FLAGS.bpe_size), vocab_size=FLAGS.bpe_size)
        
        assert self.tokenizer.tokenizer is not None

        _, self.transform, input_size = build_transform(
            feature_type=FLAGS.feature, feature_size=FLAGS.feature_size,
            n_fft=FLAGS.n_fft, win_length=FLAGS.win_length,
            hop_length=FLAGS.hop_length, delta=FLAGS.delta, cmvn=FLAGS.cmvn,
            downsample=FLAGS.downsample, pad_to_divisible=False,
            T_mask=FLAGS.T_mask, T_num_mask=FLAGS.T_num_mask,
            F_mask=FLAGS.F_mask, F_num_mask=FLAGS.F_num_mask)

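        # Prefer a checkpoint under logs/<name>/models/, falling back to
        # logs/<name>/; the map_location lambda keeps all tensors on the CPU.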
        model_path = os.path.join(logdir, 'models', FLAGS.model_name)
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, lambda storage, loc: storage)
        else:
            model_path = os.path.join(logdir, FLAGS.model_name)
            checkpoint = torch.load(model_path, lambda storage, loc: storage)

        transducer = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.tokenizer.vocab_size,
            input_size=input_size,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
            output_loss=False,
        )

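        # Convert the Lightning checkpoint to a plain state dict and keep only
        # the encoder, prediction network and joint needed for decoding.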
        transducer.load_state_dict(convert_lightning2normal(checkpoint)['model'])
        transducer.eval()
        self.encoder = transducer.encoder
        self.decoder = transducer.decoder
        self.joint = transducer.joint

        self.reset_profile()
        self.reset()
Example #5
def main(argv):
    assert FLAGS.step_n_frame % 2 == 0, ("step_n_frame must be divisible by "
                                         "reduction_factor of TimeReduction")

    logdir = os.path.join('logs', FLAGS.name)

    tokenizer = HuggingFaceTokenizer(cache_dir=logdir,
                                     vocab_size=FLAGS.bpe_size)

    transform_train, transform_test, input_size = build_transform(
        feature_type=FLAGS.feature,
        feature_size=FLAGS.feature_size,
        n_fft=FLAGS.n_fft,
        win_length=FLAGS.win_length,
        hop_length=FLAGS.hop_length,
        delta=FLAGS.delta,
        cmvn=FLAGS.cmvn,
        downsample=FLAGS.downsample,
        T_mask=FLAGS.T_mask,
        T_num_mask=FLAGS.T_num_mask,
        F_mask=FLAGS.F_mask,
        F_num_mask=FLAGS.F_num_mask)

    model_path = os.path.join(logdir, 'models', FLAGS.model_name)
    checkpoint = torch.load(model_path, lambda storage, loc: storage)
    transducer = Transducer(
        vocab_embed_size=FLAGS.vocab_embed_size,
        vocab_size=tokenizer.vocab_size,
        input_size=input_size,
        enc_hidden_size=FLAGS.enc_hidden_size,
        enc_layers=FLAGS.enc_layers,
        enc_dropout=FLAGS.enc_dropout,
        enc_proj_size=FLAGS.enc_proj_size,
        dec_hidden_size=FLAGS.dec_hidden_size,
        dec_layers=FLAGS.dec_layers,
        dec_dropout=FLAGS.dec_dropout,
        dec_proj_size=FLAGS.dec_proj_size,
        joint_size=FLAGS.joint_size,
    )
    transducer.load_state_dict(checkpoint['model'])
    transducer.eval()

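    # Export the encoder, prediction network (decoder) and joint network as
    # separate graphs so they can be loaded individually at inference time.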
    export_encoder(transducer, input_size, tokenizer.vocab_size, logdir)
    export_decoder(transducer, input_size, tokenizer.vocab_size, logdir)
    export_join(transducer, input_size, tokenizer.vocab_size, logdir)
Example #6
def load_pytorch_model():
    logdir = os.path.join('logs', FLAGS.name)

    tokenizer = HuggingFaceTokenizer(cache_dir=logdir,
                                     vocab_size=FLAGS.bpe_size)

    _, transform, input_size = build_transform(feature_type=FLAGS.feature,
                                               feature_size=FLAGS.feature_size,
                                               n_fft=FLAGS.n_fft,
                                               win_length=FLAGS.win_length,
                                               hop_length=FLAGS.hop_length,
                                               delta=FLAGS.delta,
                                               cmvn=FLAGS.cmvn,
                                               downsample=FLAGS.downsample,
                                               pad_to_divisible=False,
                                               T_mask=FLAGS.T_mask,
                                               T_num_mask=FLAGS.T_num_mask,
                                               F_mask=FLAGS.F_mask,
                                               F_num_mask=FLAGS.F_num_mask)

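    # Load the checkpoint saved at the requested training step
    # (logs/<name>/models/<step>.pt) onto the CPU.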
    model_path = os.path.join(logdir, 'models', '%d.pt' % FLAGS.step)
    checkpoint = torch.load(model_path, lambda storage, loc: storage)
    transducer = Transducer(
        vocab_embed_size=FLAGS.vocab_embed_size,
        vocab_size=tokenizer.vocab_size,
        input_size=input_size,
        enc_hidden_size=FLAGS.enc_hidden_size,
        enc_layers=FLAGS.enc_layers,
        enc_dropout=FLAGS.enc_dropout,
        enc_proj_size=FLAGS.enc_proj_size,
        dec_hidden_size=FLAGS.dec_hidden_size,
        dec_layers=FLAGS.dec_layers,
        dec_dropout=FLAGS.dec_dropout,
        dec_proj_size=FLAGS.dec_proj_size,
        joint_size=FLAGS.joint_size,
    )
    transducer.load_state_dict(checkpoint['model'])
    transducer.eval()
    encoder = transducer.encoder
    decoder = transducer.decoder
    joint = transducer.joint
    return encoder, decoder, joint, tokenizer, transform
Example #7
class ParallelTraining(pl.LightningModule):
    def __init__(self):
        super(ParallelTraining, self).__init__()
        _, _, input_size = build_transform(feature_type=FLAGS.feature,
                                           feature_size=FLAGS.feature_size,
                                           n_fft=FLAGS.n_fft,
                                           win_length=FLAGS.win_length,
                                           hop_length=FLAGS.hop_length,
                                           delta=FLAGS.delta,
                                           cmvn=FLAGS.cmvn,
                                           downsample=FLAGS.downsample,
                                           T_mask=FLAGS.T_mask,
                                           T_num_mask=FLAGS.T_num_mask,
                                           F_mask=FLAGS.F_mask,
                                           F_num_mask=FLAGS.F_num_mask)
        self.log_path = None
        self.loss_fn = RNNTLoss(blank=NUL)

        if FLAGS.tokenizer == 'char':
            self.tokenizer = CharTokenizer(cache_dir=self.logdir)
        else:
            self.tokenizer = HuggingFaceTokenizer(cache_dir='BPE-2048',
                                                  vocab_size=FLAGS.bpe_size)
        self.vocab_size = self.tokenizer.vocab_size
        print(FLAGS.enc_type)

        self.model = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.vocab_size,
            input_size=input_size,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
            module_type=FLAGS.enc_type,
            output_loss=False,
        )
        self.latest_alignment = None
        self.steps = 0
        self.epoch = 0
        self.best_wer = 1000

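    # Linearly ramp the learning rate up to FLAGS.lr over the first
    # FLAGS.warmup_step optimizer steps.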
    def warmup_optimizer_step(self, steps):
        if steps < FLAGS.warmup_step:
            lr_scale = min(1., float(steps + 1) / FLAGS.warmup_step)
            for pg in self.optimizer.param_groups:
                pg['lr'] = lr_scale * FLAGS.lr

    def forward(self, batch):
        xs, ys, xlen, ylen = batch
        # xs, ys, xlen = xs.cuda(), ys, xlen.cuda()
        alignment = self.model(xs, ys, xlen, ylen)
        return alignment

    def training_step(self, batch, batch_nb):
        xs, ys, xlen, ylen = batch
        # xs, ys, xlen = xs.cuda(), ys, xlen.cuda()
        if xs.shape[1] != xlen.max():
            xs = xs[:, :xlen.max()]
            ys = ys[:, :ylen.max()]
        alignment = self.model(xs, ys, xlen, ylen)
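        # Rescale the input lengths to the encoder's time-reduced output before
        # computing the RNN-T loss.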
        xlen = self.model.scale_length(alignment, xlen)
        loss = self.loss_fn(alignment, ys.int(), xlen, ylen)

        if batch_nb % 100 == 0:
            lr_val = 0
            for param_group in self.optimizer.param_groups:
                lr_val = param_group['lr']
            self.logger.experiment.add_scalar('lr', lr_val, self.steps)

        self.steps += 1

        if self.steps < FLAGS.warmup_step:
            self.warmup_optimizer_step(self.steps)

        return {'loss': loss, 'log': {'loss': loss.item()}}

    def validation_step(self, batch, batch_nb):
        xs, ys, xlen, ylen = batch
        y, nll = self.model.greedy_decode(xs, xlen)

        hypothesis = self.tokenizer.decode_plus(y)
        ground_truth = self.tokenizer.decode_plus(ys.cpu().numpy())
        measures = jiwer.compute_measures(ground_truth, hypothesis)

        return {
            'val_loss': nll.mean().item(),
            'wer': measures['wer'],
            'ground_truth': ground_truth[0],
            'hypothesis': hypothesis[0]
        }

    def validation_end(self, outputs):
        # OPTIONAL
        self.logger.experiment.add_text('test', 'This is a test', 0)

        avg_wer = np.mean([x['wer'] for x in outputs])
        ppl = np.mean([x['val_loss'] for x in outputs])
        self.logger.experiment.add_scalar('val/WER', avg_wer, self.steps)
        self.logger.experiment.add_scalar('val/perplexity', ppl, self.steps)

        hypothesis, ground_truth = '', ''
        for idx in range(min(5, len(outputs))):
            hypothesis += outputs[idx]['hypothesis'] + '\n\n'
            ground_truth += outputs[idx]['ground_truth'] + '\n\n'

        self.logger.experiment.add_text('generated', hypothesis, self.steps)
        self.logger.experiment.add_text('ground_truth', ground_truth,
                                        self.steps)
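        # Plot one randomly chosen alignment from the latest batch (blank token
        # masked out) and log it as an image.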
        if self.latest_alignment is not None:
            alignment = self.latest_alignment
            idx = random.randint(0, alignment.size(0) - 1)
            alignment = torch.softmax(alignment[idx], dim=-1)
            alignment[:, :, 0] = 0  # ignore blank token
            alignment = alignment.mean(dim=-1)

            self.logger.experiment.add_image("alignment",
                                             plot_alignment_to_numpy(
                                                 alignment.data.numpy().T),
                                             self.steps,
                                             dataformats='HWC')
        self.logger.experiment.flush()

        if self.best_wer > avg_wer and self.epoch > 0:
            print('best checkpoint found!')
            # checkpoint = {
            #     'model': self.model.state_dict(),
            #     'optimizer': self.optimizer.state_dict(),
            #     'epoch': self.epoch
            # }
            # if FLAGS.apex:
            #     checkpoint['amp'] = amp.state_dict()
            # torch.save(checkpoint, os.path.join(self.log_path, str(self.epoch)+'amp_checkpoint.pt'))
            self.trainer.save_checkpoint(
                os.path.join(self.log_path,
                             str(self.epoch) + 'amp_checkpoint.pt'))

            self.best_wer = avg_wer

        self.plateau_scheduler.step(avg_wer)
        self.epoch += 1

        return {
            'val/WER': torch.tensor(avg_wer),
            'wer': torch.tensor(avg_wer),
            'val/perplexity': torch.tensor(ppl)
        }

    def validation_epoch_end(self, outputs):
        avg_wer = np.mean([x['wer'] for x in outputs])
        ppl = np.mean([x['val_loss'] for x in outputs])

        hypothesis, ground_truth = '', ''
        for idx in range(min(5, len(outputs))):
            hypothesis += outputs[idx]['hypothesis'] + '\n\n'
            ground_truth += outputs[idx]['ground_truth'] + '\n\n'

        self.logger.experiment.add_text('generated', hypothesis, self.steps)
        self.logger.experiment.add_text('ground_truth', ground_truth,
                                        self.steps)

        if self.latest_alignment is not None:
            alignment = self.latest_alignment
            idx = random.randint(0, alignment.size(0) - 1)
            alignment = torch.softmax(alignment[idx], dim=-1)
            alignment[:, :, 0] = 0  # ignore blank token
            alignment = alignment.mean(dim=-1)

            writer.add_image("alignment",
                             plot_alignment_to_numpy(alignment.data.numpy().T),
                             self.steps,
                             dataformats='HWC')

        self.logger.experiment.add_scalar('val/WER', avg_wer, self.steps)
        self.logger.experiment.add_scalar('val/perplexity', ppl, self.steps)
        self.logger.experiment.flush()

        self.plateau_scheduler.step(avg_wer)

        self.epoch += 1
        return {
            'val/WER': torch.tensor(avg_wer),
            'val/perplexity': torch.tensor(ppl)
        }

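    # Choose the optimizer (AdamW, SM3 or Novograd) and, optionally, a
    # ReduceLROnPlateau scheduler that is stepped manually on validation WER.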
    def configure_optimizers(self):
        if FLAGS.optim == 'adam':
            self.optimizer = AdamW(self.model.parameters(),
                                   lr=FLAGS.lr,
                                   weight_decay=1e-5)
        elif FLAGS.optim == 'sm3':
            self.optimizer = SM3(self.model.parameters(),
                                 lr=FLAGS.lr,
                                 momentum=0.0)
        else:
            self.optimizer = Novograd(self.model.parameters(),
                                      lr=FLAGS.lr,
                                      weight_decay=1e-3)
        scheduler = []
        if FLAGS.sched:
            self.plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                patience=FLAGS.sched_patience,
                factor=FLAGS.sched_factor,
                min_lr=FLAGS.sched_min_lr,
                verbose=1)
            scheduler = [self.plateau_scheduler]

        self.warmup_optimizer_step(0)
        return [self.optimizer]

    @pl.data_loader
    def train_dataloader(self):
        transform_train, _, _ = build_transform(
            feature_type=FLAGS.feature,
            feature_size=FLAGS.feature_size,
            n_fft=FLAGS.n_fft,
            win_length=FLAGS.win_length,
            hop_length=FLAGS.hop_length,
            delta=FLAGS.delta,
            cmvn=FLAGS.cmvn,
            downsample=FLAGS.downsample,
            T_mask=FLAGS.T_mask,
            T_num_mask=FLAGS.T_num_mask,
            F_mask=FLAGS.F_mask,
            F_num_mask=FLAGS.F_num_mask)

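        # Merge several corpora (LibriSpeech 500h/360h, TED-LIUM, Common Voice
        # and YouTube captions) into a single training dataset.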
        dataloader = DataLoader(
            dataset=MergedDataset([
                Librispeech(root=FLAGS.LibriSpeech_train_500,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                Librispeech(root=FLAGS.LibriSpeech_train_360,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                # Librispeech(
                #     root=FLAGS.LibriSpeech_train_100,
                #     tokenizer=self.tokenizer,
                #     transform=transform_train,
                #     audio_max_length=FLAGS.audio_max_length),
                TEDLIUM(root=FLAGS.TEDLIUM_train,
                        tokenizer=self.tokenizer,
                        transform=transform_train,
                        audio_max_length=FLAGS.audio_max_length),
                CommonVoice(root=FLAGS.CommonVoice,
                            labels='train.tsv',
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length,
                            audio_min_length=1),
                YoutubeCaption(root='../speech_data/youtube-speech-text/',
                               labels='bloomberg2_meta.csv',
                               tokenizer=self.tokenizer,
                               transform=transform_train,
                               audio_max_length=FLAGS.audio_max_length,
                               audio_min_length=1),
                YoutubeCaption(root='../speech_data/youtube-speech-text/',
                               labels='life_meta.csv',
                               tokenizer=self.tokenizer,
                               transform=transform_train,
                               audio_max_length=FLAGS.audio_max_length,
                               audio_min_length=1),
                YoutubeCaption(root='../speech_data/youtube-speech-text/',
                               labels='news_meta.csv',
                               tokenizer=self.tokenizer,
                               transform=transform_train,
                               audio_max_length=FLAGS.audio_max_length,
                               audio_min_length=1),
                YoutubeCaption(root='../speech_data/youtube-speech-text/',
                               labels='english2_meta.csv',
                               tokenizer=self.tokenizer,
                               transform=transform_train,
                               audio_max_length=FLAGS.audio_max_length,
                               audio_min_length=1),
            ]),
            batch_size=FLAGS.sub_batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers,
            collate_fn=seq_collate,
            drop_last=True)
        return dataloader

    @pl.data_loader
    def val_dataloader(self):
        _, transform_test, _ = build_transform(feature_type=FLAGS.feature,
                                               feature_size=FLAGS.feature_size,
                                               n_fft=FLAGS.n_fft,
                                               win_length=FLAGS.win_length,
                                               hop_length=FLAGS.hop_length,
                                               delta=FLAGS.delta,
                                               cmvn=FLAGS.cmvn,
                                               downsample=FLAGS.downsample,
                                               T_mask=FLAGS.T_mask,
                                               T_num_mask=FLAGS.T_num_mask,
                                               F_mask=FLAGS.F_mask,
                                               F_num_mask=FLAGS.F_num_mask)

        val_dataloader = DataLoader(dataset=MergedDataset([
            Librispeech(root=FLAGS.LibriSpeech_test,
                        tokenizer=self.tokenizer,
                        transform=transform_test,
                        reverse_sorted_by_length=True)
        ]),
                                    batch_size=FLAGS.eval_batch_size,
                                    shuffle=False,
                                    num_workers=FLAGS.num_workers,
                                    collate_fn=seq_collate)
        return val_dataloader
Example #8
    def __init__(self):
        self.name = FLAGS.name
        self.logdir = os.path.join('logs', FLAGS.name)
        self.model_dir = os.path.join(self.logdir, 'models')

        # Transform
        transform_train, transform_test, input_size = build_transform(
            feature_type=FLAGS.feature,
            feature_size=FLAGS.feature_size,
            n_fft=FLAGS.n_fft,
            win_length=FLAGS.win_length,
            hop_length=FLAGS.hop_length,
            delta=FLAGS.delta,
            cmvn=FLAGS.cmvn,
            downsample=FLAGS.downsample,
            T_mask=FLAGS.T_mask,
            T_num_mask=FLAGS.T_num_mask,
            F_mask=FLAGS.F_mask,
            F_num_mask=FLAGS.F_num_mask)

        # Tokenizer
        if FLAGS.tokenizer == 'char':
            self.tokenizer = CharTokenizer(cache_dir=self.logdir)
        else:
            self.tokenizer = HuggingFaceTokenizer(cache_dir=self.logdir,
                                                  vocab_size=FLAGS.bpe_size)

        # Dataloader
        self.dataloader_train = DataLoader(
            dataset=MergedDataset([
                Librispeech(root=FLAGS.LibriSpeech_train_500,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                Librispeech(root=FLAGS.LibriSpeech_train_360,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                Librispeech(root=FLAGS.LibriSpeech_train_100,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                # TEDLIUM(
                #     root=FLAGS.TEDLIUM_train,
                #     tokenizer=self.tokenizer,
                #     transform=transform_train,
                #     audio_max_length=FLAGS.audio_max_length),
                # CommonVoice(
                #     root=FLAGS.CommonVoice, labels='train.tsv',
                #     tokenizer=self.tokenizer,
                #     transform=transform_train,
                #     audio_max_length=FLAGS.audio_max_length)
            ]),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers,
            collate_fn=seq_collate,
            drop_last=True)

        self.dataloader_val = DataLoader(dataset=MergedDataset([
            Librispeech(root=FLAGS.LibriSpeech_test,
                        tokenizer=self.tokenizer,
                        transform=transform_test,
                        reverse_sorted_by_length=True)
        ]),
                                         batch_size=FLAGS.eval_batch_size,
                                         shuffle=False,
                                         num_workers=FLAGS.num_workers,
                                         collate_fn=seq_collate)

        self.tokenizer.build(self.dataloader_train.dataset.texts())
        self.vocab_size = self.dataloader_train.dataset.tokenizer.vocab_size

        # Model
        self.model = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.vocab_size,
            input_size=input_size,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
        ).to(device)

        # Optimizer
        if FLAGS.optim == 'adam':
            self.optim = optim.Adam(self.model.parameters(), lr=FLAGS.lr)
        else:
            self.optim = optim.SGD(self.model.parameters(),
                                   lr=FLAGS.lr,
                                   momentum=0.9)
        # Scheduler
        if FLAGS.sched:
            self.sched = optim.lr_scheduler.ReduceLROnPlateau(
                self.optim,
                patience=FLAGS.sched_patience,
                factor=FLAGS.sched_factor,
                min_lr=FLAGS.sched_min_lr,
                verbose=1)
        # Apex
        if FLAGS.apex:
            self.model, self.optim = amp.initialize(self.model,
                                                    self.optim,
                                                    opt_level=FLAGS.opt_level)
        # Multi GPU
        if FLAGS.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
Example #9
class Trainer:
    def __init__(self):
        self.name = FLAGS.name
        self.logdir = os.path.join('logs', FLAGS.name)
        self.model_dir = os.path.join(self.logdir, 'models')

        # Transform
        transform_train, transform_test, input_size = build_transform(
            feature_type=FLAGS.feature,
            feature_size=FLAGS.feature_size,
            n_fft=FLAGS.n_fft,
            win_length=FLAGS.win_length,
            hop_length=FLAGS.hop_length,
            delta=FLAGS.delta,
            cmvn=FLAGS.cmvn,
            downsample=FLAGS.downsample,
            T_mask=FLAGS.T_mask,
            T_num_mask=FLAGS.T_num_mask,
            F_mask=FLAGS.F_mask,
            F_num_mask=FLAGS.F_num_mask)

        # Tokenizer
        if FLAGS.tokenizer == 'char':
            self.tokenizer = CharTokenizer(cache_dir=self.logdir)
        else:
            self.tokenizer = HuggingFaceTokenizer(cache_dir=self.logdir,
                                                  vocab_size=FLAGS.bpe_size)

        # Dataloader
        self.dataloader_train = DataLoader(
            dataset=MergedDataset([
                Librispeech(root=FLAGS.LibriSpeech_train_500,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                Librispeech(root=FLAGS.LibriSpeech_train_360,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                Librispeech(root=FLAGS.LibriSpeech_train_100,
                            tokenizer=self.tokenizer,
                            transform=transform_train,
                            audio_max_length=FLAGS.audio_max_length),
                # TEDLIUM(
                #     root=FLAGS.TEDLIUM_train,
                #     tokenizer=self.tokenizer,
                #     transform=transform_train,
                #     audio_max_length=FLAGS.audio_max_length),
                # CommonVoice(
                #     root=FLAGS.CommonVoice, labels='train.tsv',
                #     tokenizer=self.tokenizer,
                #     transform=transform_train,
                #     audio_max_length=FLAGS.audio_max_length)
            ]),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers,
            collate_fn=seq_collate,
            drop_last=True)

        self.dataloader_val = DataLoader(dataset=MergedDataset([
            Librispeech(root=FLAGS.LibriSpeech_test,
                        tokenizer=self.tokenizer,
                        transform=transform_test,
                        reverse_sorted_by_length=True)
        ]),
                                         batch_size=FLAGS.eval_batch_size,
                                         shuffle=False,
                                         num_workers=FLAGS.num_workers,
                                         collate_fn=seq_collate)

        self.tokenizer.build(self.dataloader_train.dataset.texts())
        self.vocab_size = self.dataloader_train.dataset.tokenizer.vocab_size

        # Model
        self.model = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.vocab_size,
            input_size=input_size,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
        ).to(device)

        # Optimizer
        if FLAGS.optim == 'adam':
            self.optim = optim.Adam(self.model.parameters(), lr=FLAGS.lr)
        else:
            self.optim = optim.SGD(self.model.parameters(),
                                   lr=FLAGS.lr,
                                   momentum=0.9)
        # Scheduler
        if FLAGS.sched:
            self.sched = optim.lr_scheduler.ReduceLROnPlateau(
                self.optim,
                patience=FLAGS.sched_patience,
                factor=FLAGS.sched_factor,
                min_lr=FLAGS.sched_min_lr,
                verbose=1)
        # Apex
        if FLAGS.apex:
            self.model, self.optim = amp.initialize(self.model,
                                                    self.optim,
                                                    opt_level=FLAGS.opt_level)
        # Multi GPU
        if FLAGS.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)

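    # Scale the original frame lengths down to the time-reduced length of the
    # encoder output.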
    def scale_length(self, prob, xlen):
        scale = (xlen.max().float() / prob.shape[1]).ceil()
        xlen = (xlen / scale).ceil().int()
        return xlen

    def train(self, start_step=1):
        if FLAGS.mode == "resume":
            exist_ok = True
        else:
            exist_ok = False
        os.makedirs(self.model_dir, exist_ok=exist_ok)
        writer = SummaryWriter(self.logdir)
        writer.add_text('flagfile',
                        FLAGS.flags_into_string().replace('\n', '\n\n'))
        FLAGS.append_flags_into_file(os.path.join(self.logdir, 'flagfile.txt'))

        looper = infloop(self.dataloader_train)
        losses = []
        steps = len(self.dataloader_train) * FLAGS.epochs
        with trange(start_step, steps + 1, dynamic_ncols=True) as pbar:
            for step in pbar:
                if step <= FLAGS.warmup_step:
                    scale = step / FLAGS.warmup_step
                    self.optim.param_groups[0]['lr'] = FLAGS.lr * scale
                batch, epoch = next(looper)
                loss = self.train_step(batch)
                losses.append(loss)
                lr = self.optim.param_groups[0]['lr']
                pbar.set_description('Epoch %d, loss: %.4f, lr: %.3E' %
                                     (epoch, loss, lr))

                if step % FLAGS.loss_step == 0:
                    train_loss = torch.stack(losses).mean()
                    losses = []
                    writer.add_scalar('train_loss', train_loss, step)

                if step % FLAGS.save_step == 0:
                    self.save(step)

                if step % FLAGS.eval_step == 0:
                    pbar.set_description('Evaluating ...')
                    val_loss, wer, pred_seqs, true_seqs = self.evaluate()
                    if FLAGS.sched:
                        self.sched.step(val_loss)
                    writer.add_scalar('WER', wer, step)
                    writer.add_scalar('val_loss', val_loss, step)
                    for i in range(FLAGS.sample_size):
                        log = "`%s`\n\n`%s`" % (true_seqs[i], pred_seqs[i])
                        writer.add_text('val/%d' % i, log, step)
                    pbar.write('Epoch %d, step %d, loss: %.4f, WER: %.4f' %
                               (epoch, step, val_loss, wer))

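    # One optimizer step over a full batch, accumulated from sub-batches of
    # FLAGS.sub_batch_size to keep peak memory bounded.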
    def train_step(self, batch):
        sub_losses = []
        start_idxs = range(0, FLAGS.batch_size, FLAGS.sub_batch_size)
        self.optim.zero_grad()
        for sub_batch_idx, start_idx in enumerate(start_idxs):
            sub_slice = slice(start_idx, start_idx + FLAGS.sub_batch_size)
            xs, ys, xlen, ylen = [x[sub_slice].to(device) for x in batch]
            xs = xs[:, :xlen.max()].contiguous()
            ys = ys[:, :ylen.max()].contiguous()
            loss = self.model(xs, ys, xlen, ylen)
            if FLAGS.multi_gpu:
                loss = loss.mean() / len(start_idxs)
            else:
                loss = loss / len(start_idxs)
            if FLAGS.apex:
                delay_unscale = sub_batch_idx < len(start_idxs) - 1
                with amp.scale_loss(
                        loss, self.optim,
                        delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            sub_losses.append(loss.detach())

        if FLAGS.gradclip is not None:
            if FLAGS.apex:
                parameters = amp.master_params(self.optim)
            else:
                parameters = self.model.parameters()
            torch.nn.utils.clip_grad_norm_(parameters, FLAGS.gradclip)
        self.optim.step()

        loss = torch.stack(sub_losses).sum()
        return loss

    def evaluate(self):
        self.model.eval()
        wers = []
        losses = []
        pred_seqs = []
        true_seqs = []
        with torch.no_grad():
            with tqdm(self.dataloader_val, dynamic_ncols=True) as pbar:
                for batch in pbar:
                    loss, wer, pred_seq, true_seq = self.evaluate_step(batch)
                    wers.append(wer)
                    losses.append(loss)
                    sample_nums = FLAGS.sample_size - len(pred_seqs)
                    pred_seqs.extend(pred_seq[:sample_nums])
                    true_seqs.extend(true_seq[:sample_nums])
                    pbar.set_description('wer: %.4f, loss: %.4f' % (wer, loss))
        loss = np.mean(losses)
        wer = np.mean(wers)
        self.model.train()
        return loss, wer, pred_seqs, true_seqs

    def evaluate_step(self, batch):
        xs, ys, xlen, ylen = [x.to(device) for x in batch]
        xs = xs[:, :xlen.max()]
        ys = ys[:, :ylen.max()].contiguous()
        loss = self.model(xs, ys, xlen, ylen)
        if FLAGS.multi_gpu:
            loss = loss.mean()
        if FLAGS.multi_gpu:
            ys_hat, nll = self.model.module.greedy_decode(xs, xlen)
        else:
            ys_hat, nll = self.model.greedy_decode(xs, xlen)
        pred_seq = self.tokenizer.decode_plus(ys_hat)
        true_seq = self.tokenizer.decode_plus(ys.cpu().numpy())
        wer = jiwer.wer(true_seq, pred_seq)
        return loss.item(), wer, pred_seq, true_seq

    def save(self, step):
        checkpoint = {'optim': self.optim.state_dict()}

        if FLAGS.multi_gpu:
            checkpoint.update({'model': self.model.module.state_dict()})
        else:
            checkpoint.update({'model': self.model.state_dict()})

        if FLAGS.sched:
            checkpoint.update({'sched': self.sched.state_dict()})

        if FLAGS.apex:
            checkpoint.update({'amp': amp.state_dict()})

        path = os.path.join(self.model_dir, '%d.pt' % step)
        torch.save(checkpoint, path)

    def load(self, path):
        checkpoint = torch.load(path)
        # self.optim.load_state_dict(checkpoint['optim'])

        if FLAGS.multi_gpu:
            self.model.module.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint['model'])

        if FLAGS.sched:
            self.sched.load_state_dict(checkpoint['sched'])

        if FLAGS.apex:
            amp.load_state_dict(checkpoint['amp'])

    def sanity_check(self):
        self.model.eval()
        batch = next(iter(self.dataloader_val))
        self.evaluate_step(batch)
        self.model.train()
Example #10
    def __init__(self):
        self.name = FLAGS.name
        self.logdir = os.path.join('logs', FLAGS.name)
        self.model_dir = os.path.join(self.logdir, 'models')

        # Transform
        transform = torch.nn.Sequential(
            TrimAudio(sampling_rate=16000, max_audio_length=FLAGS.audio_max_length)
        )
        transform_train, transform_test = transform, transform

        # Tokenizer
        if FLAGS.tokenizer == 'char':
            self.tokenizer = CharTokenizer(cache_dir=self.logdir)
        else:
            self.tokenizer = HuggingFaceTokenizer(
                cache_dir='BPE-2048', vocab_size=FLAGS.bpe_size)

        # Dataloader
        self.dataloader_train = DataLoader(
            dataset=MergedDataset([
                Librispeech(
                    root=FLAGS.LibriSpeech_train_100,
                    tokenizer=self.tokenizer,
                    transform=transform_train,
                    audio_max_length=FLAGS.audio_max_length),
                Librispeech(
                    root=FLAGS.LibriSpeech_dev,
                    tokenizer=self.tokenizer,
                    transform=transform_train,
                    audio_max_length=FLAGS.audio_max_length),
            ]),
            batch_size=FLAGS.batch_size, shuffle=True,
            num_workers=FLAGS.num_workers, collate_fn=seq_collate,
            drop_last=True
        )

        self.dataloader_val = DataLoader(
            dataset=MergedDataset([
                Librispeech(
                    root=FLAGS.LibriSpeech_test,
                    tokenizer=self.tokenizer,
                    transform=transform_test,
                    reverse_sorted_by_length=True)]),
            batch_size=FLAGS.eval_batch_size, shuffle=False,
            num_workers=FLAGS.num_workers, collate_fn=seq_collate)

        self.tokenizer.build(self.dataloader_train.dataset.texts())
        self.vocab_size = self.dataloader_train.dataset.tokenizer.vocab_size

        # Model
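        # Convolutional front end on raw waveform samples; each tuple in
        # frontend_params is assumed to describe one conv layer
        # (e.g. kernel size, stride, output channels).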
        self.frontend = FrontEnd(
            frontend_params=[(10, 5, 32)] + [(3, 2, 128)] * 4 + [(2, 2, 128)] * 3,
            bias=True,
        )

        self.model = Transducer(
            vocab_embed_size=FLAGS.vocab_embed_size,
            vocab_size=self.vocab_size,
            input_size=128,
            enc_hidden_size=FLAGS.enc_hidden_size,
            enc_layers=FLAGS.enc_layers,
            enc_dropout=FLAGS.enc_dropout,
            enc_proj_size=FLAGS.enc_proj_size,
            enc_time_reductions=[],
            dec_hidden_size=FLAGS.dec_hidden_size,
            dec_layers=FLAGS.dec_layers,
            dec_dropout=FLAGS.dec_dropout,
            dec_proj_size=FLAGS.dec_proj_size,
            joint_size=FLAGS.joint_size,
        )
        if FLAGS.use_pretrained:
            self.frontend, self.model = load_pretrained_model(self.frontend, self.model)
            print('loaded pretrained model')

        self.frontend = self.frontend.to(device)
        self.model = self.model.to(device)

        # Optimizer
        if FLAGS.optim == 'adam':
            self.optim = optim.Adam(
                list(self.model.parameters()) + list(self.frontend.parameters()),
                lr=FLAGS.lr)
        else:
            self.optim = optim.SGD(
                list(self.model.parameters()) + list(self.frontend.parameters()),
                lr=FLAGS.lr, momentum=0.9)
        # Scheduler
        if FLAGS.sched:
            self.sched = optim.lr_scheduler.ReduceLROnPlateau(
                self.optim, patience=FLAGS.sched_patience,
                factor=FLAGS.sched_factor, min_lr=FLAGS.sched_min_lr,
                verbose=1)
        # Apex
        if FLAGS.apex:
            self.model, self.optim = amp.initialize(
                self.model, self.optim, opt_level=FLAGS.opt_level)
        # Multi GPU
        if FLAGS.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
Example #11
                                                                  reduce=False)
            for key, value in logging_output.items():
                if key not in logging_outputs:
                    logging_outputs[key] = []
                if FLAGS.multi_gpu and isinstance(value, torch.Tensor):
                    value = value.mean()

                logging_outputs[key].append(value)

    model.train()
    return {key: np.mean(scores) for key, scores in logging_outputs.items()}


if __name__ == '__main__':
    # tokenizer is not needed in this stage
    tokenizer = HuggingFaceTokenizer(cache_dir='BPE-2048', vocab_size=2048)

    transform = torch.nn.Sequential(
        TrimAudio(sampling_rate=16000, max_audio_length=15))

    dataloader = DataLoader(dataset=MergedDataset([
        YoutubeCaption(
            '../yt_speech/',
            labels='news_dummy.csv',
            tokenizer=tokenizer,
            transform=transform,
            audio_max_length=14,
        ),
        YoutubeCaption(
            '../yt_speech/',
            labels='life_dummy.csv',
def main(argv):
    assert FLAGS.step_n_frame % 2 == 0, ("step_n_frame must be divisible by "
                                         "reduction_factor of TimeReduction")

    tokenizer = HuggingFaceTokenizer(cache_dir=os.path.join(
        'logs', FLAGS.name),
                                     vocab_size=FLAGS.bpe_size)

    dataloader = DataLoader(dataset=MergedDataset([
        Librispeech(root=FLAGS.LibriSpeech_test,
                    tokenizer=tokenizer,
                    transform=None,
                    reverse_sorted_by_length=True)
    ]),
                            batch_size=1,
                            shuffle=False,
                            num_workers=0)

    pytorch_decoder = PytorchStreamDecoder(FLAGS)
    # pytorch_decoder.reset_profile()
    # wers = []
    # total_time = 0
    # total_frame = 0
    # with tqdm(dataloader, dynamic_ncols=True) as pbar:
    #     pbar.set_description("Pytorch full sequence decode")
    #     for waveform, tokens in pbar:
    #         true_seq = tokenizer.decode(tokens[0].numpy())
    #         # pytorch: Encode waveform at a time
    #         start = time.time()
    #         pred_seq, frames = fullseq_decode(pytorch_decoder, waveform)
    #         # pbar.write(true_seq)
    #         # pbar.write(pred_seq)
    #         elapsed = time.time() - start
    #         total_time += elapsed
    #         total_frame += frames
    #         wer = jiwer.wer(true_seq, pred_seq)
    #         wers.append(wer)
    #         pbar.set_postfix(wer='%.3f' % wer, elapsed='%.3f' % elapsed)
    # wer = np.mean(wers)
    # print('Mean wer: %.3f, Frame: %d, Time: %.3f, FPS: %.3f, speed: %.3f' % (
    #     wer, total_frame, total_time, total_frame / total_time,
    #     total_frame / total_time / 16000))

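    # Benchmark frame-by-frame (streaming) decoding with the PyTorch decoder:
    # WER, throughput and per-module (encoder/decoder/joint) latencies.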
    pytorch_decoder.reset_profile()
    wers = []
    total_time = 0
    total_frame = 0
    with tqdm(dataloader, dynamic_ncols=True) as pbar:
        pbar.set_description("Pytorch frame wise decode")
        for waveform, tokens in pbar:
            true_seq = tokenizer.decode(tokens[0].numpy())
            # pytorch: Encode waveform at a time
            start = time.time()
            pred_seq, frames = stream_decode(pytorch_decoder, waveform)
            elapsed = time.time() - start
            total_time += elapsed
            total_frame += frames
            wer = jiwer.wer(true_seq, pred_seq)
            wers.append(wer)
            pbar.set_postfix(wer='%.3f' % wer, elapsed='%.3f' % elapsed)
    wer = np.mean(wers)
    print('Mean wer: %.3f, Frame: %d, Time: %.3f, FPS: %.3f, speed: %.3f' %
          (wer, total_frame, total_time, total_frame / total_time,
           total_frame / total_time / 16000))
    print("Mean encoding time: %.3f ms" %
          (1000 * np.mean(pytorch_decoder.encoder_elapsed)))
    print("Mean decoding time: %.3f ms" %
          (1000 * np.mean(pytorch_decoder.decoder_elapsed)))
    print("Mean joint time: %.3f ms" %
          (1000 * np.mean(pytorch_decoder.joint_elapsed)))

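    # Repeat the same streaming benchmark with the OpenVINO decoder for
    # comparison.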
    openvino_decoder = OpenVINOStreamDecoder(FLAGS)
    openvino_decoder.reset_profile()
    wers = []
    total_time = 0
    total_frame = 0
    with tqdm(dataloader, dynamic_ncols=True) as pbar:
        pbar.set_description("OpenVINO frame wise decode")
        for waveform, tokens in pbar:
            true_seq = tokenizer.decode(tokens[0].numpy())
            # pytorch: Encode waveform at a time
            start = time.time()
            pred_seq, frames = stream_decode(openvino_decoder, waveform)
            # pbar.write(true_seq)
            # pbar.write(pred_seq)
            elapsed = time.time() - start
            total_time += elapsed
            total_frame += frames
            wer = jiwer.wer(true_seq, pred_seq)
            wers.append(wer)
            pbar.set_postfix(wer='%.3f' % wer, elapsed='%.3f' % elapsed)
    wer = np.mean(wers)
    print('Mean wer: %.3f, Frame: %d, Time: %.3f, FPS: %.3f, speed: %.3f' %
          (wer, total_frame, total_time, total_frame / total_time,
           total_frame / total_time / 16000))
    print("Mean encoding time: %.3f ms" %
          (1000 * np.mean(openvino_decoder.encoder_elapsed)))
    print("Mean decoding time: %.3f ms" %
          (1000 * np.mean(openvino_decoder.decoder_elapsed)))
    print("Mean joint time: %.3f ms" %
          (1000 * np.mean(openvino_decoder.joint_elapsed)))