def __init__(self, config, paras, id2accent):
        super(FOMetaASRInterface, self).__init__(config, paras, id2accent)

        assert paras.meta_k is not None
        self.meta_k = paras.meta_k
        if paras.meta_batch_size is None:
            logger.log("Meta batch_size not set, defaulting to num_pretrain", prefix='info')
            self.meta_batch_size = self.num_pretrain
        else:
            self.meta_batch_size = paras.meta_batch_size

        self.asr_model = None
        self.asr_opt = None

        self.max_step = paras.max_step if paras.max_step > 0 else config['solver']['total_steps']
        self.dashboard.set_status('pretraining')

        self._train = partial(self.run_batch, train=True)
        self._eval = partial(self.run_batch, train=False)

        self._updates = None
        self._counter = 0

        d_model = config['asr_model']['d_model']
        opt_k = config['asr_model']['meta']['optimizer_opt']['k']
        warmup_steps = config['asr_model']['meta']['optimizer_opt']['warmup_steps']
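        # Noam schedule: lr(step) = k * d_model^{-0.5} * min(step^{-0.5}, step * warmup_steps^{-1.5});
        # it peaks at step == warmup_steps, and that peak value is taken as the inner-loop lr below.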
        self.inner_lr = d_model ** (-0.5) * opt_k * ((warmup_steps) ** (-0.5))

        logger.notice("Meta Interface Information")
        logger.log(f"Sampling strategy: {self.sample_strategy}", prefix='info')
        logger.log(f"Meta batch size  : {self.meta_batch_size}", prefix='info')
        logger.log(f"# inner-loop step: {self.meta_k}", prefix='info')
        # Max. of lr in noam
        logger.log( "Inner loop lr    : {}".format(round(self.inner_lr, 5)),prefix='info')
    def load_data(self):
        logger.notice(f"Loading data from {self.data_dir}")
        if self.model_name == 'blstm':
            self.id2ch = [BLANK_SYMBOL]
        elif self.model_name == 'transformer':
            self.id2ch = [SOS_SYMBOL]
        else:
            raise NotImplementedError

        with open(self.config['solver']['spm_mapping']) as fin:
            for line in fin.readlines():
                self.id2ch.append(line.rstrip().split(' ')[0])
            self.id2ch.append(EOS_SYMBOL)
        logger.log(f"Train units: {self.id2ch}")

        setattr(self, 'eval_set', get_loader(
            self.data_dir.joinpath('test'),
            # self.data_dir.joinpath('dev'),
            batch_size = self.batch_size,
            half_batch_ilen = 512 if self.batch_size > 1 else None,
            is_memmap = self.is_memmap,
            is_bucket = False,
            shuffle = False,
            num_workers = 1,
        ))
    def set_model(self):

        logger.notice(f"Load trained ASR model from {self.model_path}")
        if self.model_name == 'blstm':
            from src.model.blstm.mono_blstm import MonoBLSTM as ASRModel
        elif self.model_name == 'las':
            from src.model.seq2seq.mono_las import MonoLAS as ASRModel
        elif self.model_name == 'transformer':
            from src.model.transformer_pytorch.mono_transformer_torch import MyTransformer as ASRModel
        else:
            raise NotImplementedError
        self.asr_model = ASRModel(self.id2ch, self.config['asr_model'])

        if self.use_gpu:
            self.asr_model.load_state_dict(torch.load(self.model_path))
            self.asr_model = self.asr_model.cuda()
        else:
            self.asr_model.load_state_dict(torch.load(self.model_path, map_location=torch.device('cpu')))
        self.asr_model.eval()
        logger.log(f'ASR model device {self.asr_model.device}', prefix='debug')
        self.sos_id = self.asr_model.sos_id
        self.eos_id = self.asr_model.eos_id
        self.blank_id = None if self.model_name != 'blstm' else self.asr_model.blank_id

        #FIXME: will merge it later
        if self.decode_mode != 'greedy':
            if self.model_name == 'blstm':
                raise NotImplementedError
            elif self.model_name == 'transformer':
                raise NotImplementedError
            else:
                raise NotImplementedError
    def load_data(self):
        logger.notice(
            f"Loading data from {self.data_dir} with {self.paras.njobs} threads"
        )

        #TODO: combine the following with Metric
        self.id2ch = self.id2units

        setattr(
            self,
            'train_set',
            get_loader(
                self.data_dir.joinpath('train'),
                batch_size=self.config['solver']['batch_size'],
                min_ilen=self.config['solver']['min_ilen'],
                max_ilen=self.config['solver']['max_ilen'],
                half_batch_ilen=self.config['solver']['half_batch_ilen'],
                # bucket_reverse=True,
                bucket_reverse=False,
                is_memmap=self.is_memmap,
                is_bucket=self.is_bucket,
                num_workers=self.paras.njobs,
                # shuffle=False, #debug
            ))
        setattr(
            self, 'dev_set',
            get_loader(
                self.data_dir.joinpath('dev'),
                batch_size=self.config['solver']['dev_batch_size'],
                is_memmap=self.is_memmap,
                is_bucket=False,
                shuffle=False,
                num_workers=self.paras.njobs,
            ))
def get_loader(data_dir, batch_size, is_memmap, is_bucket, num_workers=0, 
               min_ilen=None, max_ilen=None, half_batch_ilen=None, 
               bucket_reverse=False, shuffle=True, read_file=False, 
               drop_last=False, pin_memory=True):

    assert not read_file, "Loading from Kaldi ark has not been implemented yet"
    dset = ESPDataset(data_dir, is_memmap)

    # if data is already loaded in memory
    if not is_memmap: 
        num_workers = 0

    logger.notice(f"Loading data from {data_dir} with {num_workers} threads")

    if is_bucket:
        my_sampler = BucketSampler(dset.ilens, 
                                   min_ilen = min_ilen, 
                                   max_ilen = max_ilen, 
                                   half_batch_ilen = half_batch_ilen, 
                                   batch_size=batch_size, 
                                   bucket_size=BUCKET_SIZE, 
                                   bucket_reverse=bucket_reverse, 
                                   drop_last = drop_last)

        # batch_sampler already defines the batches and drop_last behaviour,
        # so batch_size/drop_last must not be passed to DataLoader again
        loader = DataLoader(dset, num_workers=num_workers,
                            collate_fn=collate_fn, batch_sampler=my_sampler,
                            pin_memory=pin_memory)
    else:
        loader = DataLoader(dset, batch_size=batch_size, num_workers=num_workers,
                            collate_fn=collate_fn, shuffle=shuffle,
                            drop_last=drop_last, pin_memory=pin_memory)

    return loader
def get_trainer(cls, config, paras, id2accent):
    logger.notice("LAS Trainer Inint...")

    class LASTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(LASTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = MonoLAS(self.id2ch,
                                     self.config['asr_model']).cuda()
            self.seq_loss = nn.CrossEntropyLoss(ignore_index=IGNORE_ID,
                                                reduction='none')

            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            assert self.config['asr_model']['optimizer'][
                'type'] == 'adadelta', "Use AdaDelta for pure seq2seq"
            self.asr_opt = getattr(torch.optim,\
                                   self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(),\
                                        **self.config['asr_model']['optimizer_opt'])

            # self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            # self.asr_opt, mode='min',
            # factor=0.2, patience=3,
            # verbose=True)
            super().load_model()

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train):
            sos = ys[0].new([self.sos_id])
            eos = ys[0].new([self.eos_id])
            ys_out = [torch.cat([sos, y, eos], dim=0) for y in ys]

            # NOTE: the forward pass and loss computation are truncated in this
            # snippet; `pred`, `loss`, and `acc` below are assumed to come from a
            # model forward on (x, ilens, ys_out) followed by self.seq_loss.
            if train:
                info = {'loss': loss.item(), 'acc': acc}
                # if self.global_step % 5 == 0:
                if self.global_step % 500 == 0:
                    self.probe_model(pred, ys)
                self.asr_opt.zero_grad()
                loss.backward()
            else:
                wer = self.metric_observer.batch_cal_wer(
                    pred.detach(), ys, ['att'])['att']
                info = {'wer': wer, 'loss': loss.item(), 'acc': acc}

            return info

        def probe_model(self, pred, ys_out):
            self.metric_observer.cal_att_wer(torch.argmax(pred[0], dim=-1),
                                             ys_out[0],
                                             show=True)

    return LASTrainer(config, paras, id2accent)
    def load_model(self):
        logger.log("MultiASR model for pretraining initialization")

        if self.paras.resume:
            logger.notice(f"Resume pretraining from {self.global_step}")
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            self.dashboard.set_step(self.global_step)
        else:
            logger.notice(f"Start pretraining from {self.global_step}")
    def save_best_model(self, tpe='wer', only_stat=False):
        assert self.asr_model is not None

        if not only_stat:
            model_save_path = self.log_dir.joinpath(f'model.{tpe}.best')
            logger.notice('Current best {}: {:.3f}, saving model to {}'.format(
                tpe.upper(), getattr(self, f'best_{tpe}'), model_save_path))
            torch.save(self.asr_model.state_dict(), model_save_path)

        with open(self.log_dir.joinpath(f'best_{tpe}'), 'w') as fout:
            print('{} {}'.format(self.global_step,
                                 getattr(self, f'best_{tpe}')), file=fout)
    def train(self):
        try:
            while self.global_step < self.max_step:
                tbar = get_bar(total=self.eval_ival, \
                               desc=f"Step {self.global_step}", leave=True)

                for _ in range(self.eval_ival):
                    #TODO: we can add sampling method to compare Meta and Multi fair
                    idx, (x, ilens, ys,
                          olens) = self.data_container.get_item()[0]

                    batch_size = len(ys)
                    info = self._train(idx,
                                       x,
                                       ilens,
                                       ys,
                                       olens,
                                       accent_idx=idx)
                    self.train_info.add(info, batch_size)

                    grad_norm = nn.utils.clip_grad_norm_(
                        self.asr_model.parameters(), GRAD_CLIP)

                    if math.isnan(grad_norm):
                        logger.warning(
                            f"grad norm NaN @ step {self.global_step}")
                    else:
                        self.asr_opt.step()

                    if isinstance(self.asr_opt, TransformerOptimizer):
                        self.log_msg(self.asr_opt.lr)
                    else:
                        self.log_msg()
                    self.check_evaluate()

                    self.global_step += 1
                    self.dashboard.step()

                    del x, ilens, ys, olens
                    tbar.update(1)

                    if self.global_step % self.save_ival == 0:
                        self.save_per_steps()
                    self.dashboard.check()
                tbar.close()

        except KeyboardInterrupt:
            logger.warning("Pretraining stopped")
            self.save_per_steps()
            self.dashboard.set_status('pretrained(SIGINT)')
        else:
            logger.notice("Pretraining completed")
            self.dashboard.set_status('pretrained')
    def train(self):
        self.evaluate()
        try:
            if self.save_verbose:
                self.save_init()
            while self.ep < self.max_epoch:
                tbar = get_bar(total=len(self.train_set), \
                               desc=f"Epoch {self.ep}", leave=True)

                for cur_b, (x, ilens, ys, olens) in enumerate(self.train_set):

                    batch_size = len(ys)
                    info = self._train(cur_b, x, ilens, ys, olens)
                    self.train_info.add(info, batch_size)

                    grad_norm = nn.utils.clip_grad_norm_(
                        self.asr_model.parameters(), GRAD_CLIP)

                    if math.isnan(grad_norm):
                        logger.warning(
                            f"grad norm NaN @ step {self.global_step}")
                    else:
                        self.asr_opt.step()

                    if isinstance(self.asr_opt, TransformerOptimizer):
                        self.log_msg(self.asr_opt.lr)
                    else:
                        self.log_msg()
                    self.check_evaluate()

                    self.global_step += 1
                    self.dashboard.step()

                    del x, ilens, ys, olens
                    tbar.update(1)

                self.ep += 1
                self.save_per_epoch()
                self.dashboard.check()
                tbar.close()

                if self.eval_every_epoch:
                    self.evaluate()

        except KeyboardInterrupt:
            logger.warning("Training stopped")
            self.evaluate()
            self.dashboard.set_status('trained(SIGINT)')
        else:
            logger.notice("Training completed")
            self.dashboard.set_status('trained')
    def train(self):

        try:
            task_ids = list(range(self.num_pretrain))
            while self.global_step < self.max_step:
                tbar = get_bar(total=self.eval_ival, \
                               desc=f"Step {self.global_step}", leave=True)
                for _ in range(self.eval_ival):
                    shuffle(task_ids)

                    #FIXME: Here split to inner-train and inner-test (should observe whether the performance drops)
                    for accent_id in task_ids[:self.meta_batch_size]:
                        # inner-loop learn
                        tr_batches = self.data_container.get_item(accent_id, self.meta_k)
                        self.run_task(tr_batches)

                        # inner-loop test
                        val_batch = self.data_container.get_item(accent_id)[0]
                        batch_size = len(val_batch[1][2])
                        info = self._train(val_batch[0], *val_batch[1],
                                           accent_idx=val_batch[0])
                        grad_norm = nn.utils.clip_grad_norm_(
                            self.asr_model.parameters(), GRAD_CLIP)

                        if math.isnan(grad_norm):
                            logger.warning(f"grad norm NaN @ step {self.global_step} on {self.accents[accent_id]}, ignore...")

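                        # NOTE (assumed from the method name): _partial_meta_update copies this
                        # task's first-order gradient into the meta-update buffer and restores
                        # the pre-adaptation weights before the next task.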
                        self._partial_meta_update()
                        del val_batch
                        self.train_info.add(info, batch_size)

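                    # NOTE (assumed): _final_meta_update averages the accumulated task gradients
                    # and applies a single outer-loop step with self.meta_opt.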
                    self._final_meta_update()

                    self.log_msg(self.meta_opt.lr)
                    self.check_evaluate()
                    self.global_step += 1
                    self.dashboard.step()
                    tbar.update(1)

                    if self.global_step % self.save_ival == 0:
                        self.save_per_steps()
                    self.dashboard.check()
                tbar.close()

        except KeyboardInterrupt:
            logger.warning("Pretraining stopped")
            self.save_per_steps()
            self.dashboard.set_status('pretrained(SIGINT)')
        else:
            logger.notice("Pretraining completed")
            self.dashboard.set_status('pretrained')
    def load_model(self):
        logger.log("ASR model initialization")

        if self.paras.resume:
            logger.notice(
                f"Resume training from epoch {self.ep} (best cer: {self.best_cer}, best wer: {self.best_wer})"
            )
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            if isinstance(self.asr_opt, TransformerOptimizer):
                with open(self.optimizer_path, 'rb') as fin:
                    self.asr_opt = pickle.load(fin)
            else:
                self.asr_opt.load_state_dict(torch.load(self.optimizer_path))
            self.dashboard.set_step(self.global_step)
        elif self.paras.pretrain:
            model_dict = self.asr_model.state_dict()
            logger.notice(
                f"Load pretraining {','.join(self.pretrain_module)} from {self.pretrain_model_path}"
            )
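            # NOTE (assumed from the name): filter_model keeps only the parameters belonging
            # to the modules listed in self.pretrain_module.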
            pretrain_dict = self.filter_model(
                torch.load(self.pretrain_model_path))
            model_dict.update(pretrain_dict)
            self.asr_model.load_state_dict(model_dict)
            if 'freeze_module' in self.config['solver']:
                logger.warning(
                    "Part of the model will be frozen during fine-tuning")
                self.freeze_module(self.config['solver']['freeze_module'])
            logger.notice("Done!")
        else:  # simple monolingual training from step 0
            logger.notice("Training from scratch")
def get_trainer(cls, config, paras, id2accent):
    logger.notice("Transformer Trainer Init...")

    class TransformerTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(TransformerTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = Transformer(self.id2ch, self.config['asr_model']).cuda()
            self.asr_opt = optim.RAdam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-9)
            # self.asr_opt = TransformerOptimizer(
                # torch.optim.Adam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-09),
                # optim.RAdam(self.asr_model.parameters())
                # self.config['asr_model']['optimizer_opt']['k'],
                # self.config['asr_model']['encoder']['d_model'],
                # self.config['asr_model']['optimizer_opt']['warmup_steps']
            # )
            self.label_smoothing = self.config['solver']['label_smoothing']
            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            super().load_model()

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train):

            pred, gold = self.asr_model(x, ilens, ys, olens)
            loss, acc = cal_performance(pred, gold, self.label_smoothing)


            if train:
                info = { 'loss': loss.item(), 'acc': acc}
                if self.global_step % 500 == 0:
                    self.probe_model(pred[0], gold[0])
                self.asr_opt.zero_grad()
                loss.backward()
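                # gradient clipping and asr_opt.step() are performed by the outer
                # training loop (see train()), not inside run_batch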

            else:
                wer = self.metric_observer.batch_cal_wer(pred.detach(), gold)
                info = { 'wer': wer, 'loss':loss.item(), 'acc': acc}

            return info

        def probe_model(self, pred, ys_out):
            self.metric_observer.cal_wer(torch.argmax(pred, dim=-1), ys_out, show=True)

    return TransformerTrainer(config, paras, id2accent)
        def set_model(self):
            self.asr_model = MyTransformer(self.id2ch,
                                           self.config['asr_model']).cuda()
            self.label_smooth_rate = self.config['solver']['label_smoothing']
            if self.label_smooth_rate > 0.0:
                logger.log(
                    f"Use label smoothing rate {self.label_smooth_rate}",
                    prefix='info')
            # self.asr_opt = optim.RAdam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-9)
            # if self.config['asr_model']['optimizer_cls'] == 'noam':
            if 'inner_optimizer_cls' not in self.config['asr_model']:  # multi or mono
                if self.config['asr_model']['optimizer_cls'] == 'noam':
                    logger.notice(
                        "Using the noam optimizer (recommended for monolingual training)"
                    )
                    self.asr_opt = TransformerOptimizer(
                        torch.optim.Adam(self.asr_model.parameters(),
                                         betas=(0.9, 0.98),
                                         eps=1e-09,
                                         weight_decay=1e-3),
                        self.config['asr_model']['optimizer_opt']['k'],
                        self.config['asr_model']['d_model'],
                        self.config['asr_model']['optimizer_opt']['warmup_steps'])
                elif self.config['asr_model']['optimizer_cls'] == 'RAdam':
                    logger.notice(
                        f"Using third-party {self.config['asr_model']['optimizer_cls']} optimizer"
                    )
                    self.asr_opt = getattr(extra_optim,\
                                           self.config['asr_model']['optimizer_cls'])
                    self.asr_opt = self.asr_opt(self.asr_model.parameters(), \
                                                **self.config['asr_model']['optimizer_opt'])
                else:
                    logger.notice(
                        f"Using the {self.config['asr_model']['optimizer_cls']} optimizer (recommended for fine-tuning)"
                    )
                    self.asr_opt = getattr(torch.optim,\
                                           self.config['asr_model']['optimizer_cls'])
                    self.asr_opt = self.asr_opt(self.asr_model.parameters(), \
                                                **self.config['asr_model']['optimizer_opt'])
            else:
                logger.notice(
                    "During meta-training, the model optimizer is reset after each task"
                )

            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            super().load_model()
    def exec(self):
        if self.decode_mode != 'greedy':
            logger.notice(f"Start decoding with beam search (with beam size: {self.config['solver']['beam_decode']['beam_size']})")
            raise NotImplementedError(f"{self.decode_mode} haven't supported yet")
            self._decode = self.beam_decode
        else:
            logger.notice("Start greedy decoding")
            if self.batch_size > 1:
                dev = 'gpu' if self.use_gpu else 'cpu'
                logger.log(f"Number of utterance batches to decode: {len(self.eval_set)}, decoding with {self.batch_size} batch_size using {dev}")
                self._decode = self.batch_greedy_decode
                self.njobs = 1
            else:
                logger.log(f"Number of utterances to decode: {len(self.eval_set)}, decoding with {self.njobs} threads using cpu")
                self._decode = self.greedy_decode

        if self.njobs > 1:
            try:
                _ = Parallel(n_jobs=self.njobs)(delayed(self._decode)(i, x, ilen, y, olen) for i, (x, ilen, y, olen) in enumerate(self.eval_set))

            #NOTE: cannot log comet here, since it cannot serialize
            except KeyboardInterrupt:
                logger.warning("Decoding stopped")
            else:
                logger.notice("Decoding done")
                # self.comet_exp.log_other('status','decoded')
        else:

            tbar = get_bar(total=len(self.eval_set), leave=True)

            for cur_b, (xs, ilens, ys, olens) in enumerate(self.eval_set):
                self.batch_greedy_decode(xs, ilens, ys, olens)
                tbar.update(1)
    def load_model(self):
        logger.log("First Order GBML ASRmodel for pretraining initialization")

        if self.paras.resume:
            logger.notice(f"Resume pretraining from {self.global_step}")
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            self.dashboard.set_step(self.global_step)
        else:
            logger.notice(f"Start pretraining from {self.global_step}")

        self._original = clone_state_dict(self.asr_model.state_dict(keep_vars=True))
        params = [p for p in self._original.values() if getattr(p, 'requires_grad', False)]
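        # NOTE (assumed from the first-order GBML setup): self._original holds the slow
        # meta-parameters optimized by meta_opt below, while self.asr_model carries the
        # task-adapted fast weights during inner loops.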

        if self.config['asr_model']['meta_opt_cls'] == 'noam':
            self.meta_opt = TransformerOptimizer(
                torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-09),
                self.config['asr_model']['meta']['optimizer_opt']['k'],
                self.config['asr_model']['d_model'],
                self.config['asr_model']['meta']['optimizer_opt']['warmup_steps']
            )
        else:
            raise NotImplementedError(f"Should use noam optimizer in outer loop transformer learning, but got {self.asr_model['meta_opt_cls']}")
def get_loader(data_dir, batch_size, is_memmap, is_bucket, num_workers=0, split_rate=1.0, split_seed=531,
               min_ilen=None, max_ilen=None, half_batch_ilen=None, 
               bucket_reverse=False, shuffle=True, read_file=False, 
               drop_last=False, pin_memory=True):

    assert not read_file, "Loading from Kaldi ark has not been implemented yet"
    dset = CommonVoiceDataset(data_dir, is_memmap)

    if split_rate < 1.0:
        logger.notice(f"Only use {split_rate * 100}% data for training (Split seed: {split_seed})")
        tot_sz = len(dset)
        num_tr = int(tot_sz * split_rate)
        num_drop = tot_sz - num_tr

        dset, _ = random_split(dset, [num_tr, num_drop], generator=torch.Generator().manual_seed(split_seed))


    # if data is already loaded in memory
    if not is_memmap: 
        num_workers = 0

    logger.notice(f"Loading data from {data_dir} with {num_workers} threads")

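    # random_split returns a Subset without the custom .ilens attribute,
    # so bucket sampling is only possible when the full dataset is used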
    if is_bucket and split_rate == 1.0:
        my_sampler = BucketSampler(dset.ilens, 
                                   min_ilen = min_ilen, 
                                   max_ilen = max_ilen, 
                                   half_batch_ilen = half_batch_ilen, 
                                   batch_size=batch_size, 
                                   bucket_size=BUCKET_SIZE, 
                                   bucket_reverse=bucket_reverse, 
                                   drop_last = drop_last)

        # batch_sampler already defines the batches and drop_last behaviour,
        # so batch_size/drop_last must not be passed to DataLoader again
        loader = DataLoader(dset, num_workers=num_workers,
                            collate_fn=collate_fn, batch_sampler=my_sampler,
                            pin_memory=pin_memory)
    else:
        logger.notice("No bucket sampling, the efficincy will be a little bit worse")
        loader = DataLoader(dset, batch_size=batch_size, num_workers=num_workers,
                            collate_fn=collate_fn, shuffle=shuffle,
                            drop_last=drop_last, pin_memory=pin_memory)

    return loader
def get_trainer(cls, config, paras, id2accent):
    logger.notice("Transformer Trainer Init...")

    class TransformerTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(TransformerTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = MyTransformer(self.id2ch,
                                           self.config['asr_model']).cuda()
            self.label_smooth_rate = self.config['solver']['label_smoothing']
            if self.label_smooth_rate > 0.0:
                logger.log(
                    f"Use label smoothing rate {self.label_smooth_rate}",
                    prefix='info')
            # self.asr_opt = optim.RAdam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-9)
            # if self.config['asr_model']['optimizer_cls'] == 'noam':
            if 'inner_optimizer_cls' not in self.config['asr_model']:  # multi or mono
                if self.config['asr_model']['optimizer_cls'] == 'noam':
                    logger.notice(
                        "Using the noam optimizer (recommended for monolingual training)"
                    )
                    self.asr_opt = TransformerOptimizer(
                        torch.optim.Adam(self.asr_model.parameters(),
                                         betas=(0.9, 0.98),
                                         eps=1e-09),
                        self.config['asr_model']['optimizer_opt']['k'],
                        self.config['asr_model']['d_model'],
                        self.config['asr_model']['optimizer_opt']['warmup_steps'])
                elif self.config['asr_model']['optimizer_cls'] == 'RAdam':
                    logger.notice(
                        f"Using third-party {self.config['asr_model']['optimizer_cls']} optimizer"
                    )
                    self.asr_opt = getattr(extra_optim,\
                                           self.config['asr_model']['optimizer_cls'])
                    self.asr_opt = self.asr_opt(self.asr_model.parameters(), \
                                                **self.config['asr_model']['optimizer_opt'])
                else:
                    logger.notice(
                        f"Using the {self.config['asr_model']['optimizer_cls']} optimizer (recommended for fine-tuning)"
                    )
                    self.asr_opt = getattr(torch.optim,\
                                           self.config['asr_model']['optimizer_cls'])
                    self.asr_opt = self.asr_opt(self.asr_model.parameters(), \
                                                **self.config['asr_model']['optimizer_opt'])
            else:
                logger.notice(
                    "During meta-training, the model optimizer is reset after each task"
                )

            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            super().load_model()

        def exec(self):
            self.train()

        def run_batch(self,
                      cur_b,
                      x,
                      ilens,
                      ys,
                      olens,
                      train,
                      accent_idx=None):
            # if accent_idx is not None -> monolingual training

            batch_size = len(ys)
            pred, gold = self.asr_model(x, ilens, ys, olens)
            pred_cat = pred.view(-1, pred.size(2))
            gold_cat = gold.contiguous().view(-1)
            non_pad_mask = gold_cat.ne(IGNORE_ID)
            n_total = non_pad_mask.sum().item()

            if self.label_smooth_rate > 0.0:
                eps = self.label_smooth_rate
                n_class = pred_cat.size(1)

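                # smoothed targets: (1 - eps) on the gold class, eps / n_class elsewhere;
                # gold_for_scatter maps padding positions to class 0 so scatter stays valid,
                # and those positions are masked out of the loss below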
                gold_for_scatter = gold_cat.ne(IGNORE_ID).long() * gold_cat
                one_hot = torch.zeros_like(pred_cat).scatter(
                    1, gold_for_scatter.view(-1, 1), 1)
                one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
                log_prb = F.log_softmax(pred_cat, dim=-1)
                loss = -(one_hot * log_prb).sum(dim=1)
                loss = loss.masked_select(non_pad_mask).sum() / n_total

            else:
                loss = F.cross_entropy(pred_cat,
                                       gold_cat,
                                       ignore_index=IGNORE_ID)

            pred_cat = pred_cat.detach().max(1)[1]
            n_correct = pred_cat.eq(gold_cat).masked_select(
                non_pad_mask).sum().item()

            if train:
                info = {'loss': loss.item(), 'acc': float(n_correct) / n_total}
                # if self.global_step % 5 == 0:
                if self.global_step % 500 == 0:
                    self.probe_model(pred.detach(), gold, accent_idx)
                self.asr_opt.zero_grad()
                loss.backward()

            else:
                cer = self.metric_observer.batch_cal_er(
                    pred.detach(), gold, ['att'], ['cer'])['att_cer']
                wer = self.metric_observer.batch_cal_er(
                    pred.detach(), gold, ['att'], ['wer'])['att_wer']
                info = {
                    'cer': cer,
                    'wer': wer,
                    'loss': loss.item(),
                    'acc': float(n_correct) / n_total
                }

            return info

        def probe_model(self, pred, ys_out, accent_idx):
            if accent_idx is not None:
                logger.log(f"Probe on {self.accents[accent_idx]}",
                           prefix='debug')
            self.metric_observer.cal_att_cer(torch.argmax(pred[0], dim=-1),
                                             ys_out[0],
                                             show=True,
                                             show_decode=True)
            self.metric_observer.cal_att_wer(torch.argmax(pred[0], dim=-1),
                                             ys_out[0],
                                             show=True)

    return TransformerTrainer(config, paras, id2accent)
def to_list(s):
    ret = list(map(int, s.split(' ')))
    ret = [id2units[i] for i in ret]
    return ret


def remove_accent(s):
    return unidecode.unidecode(s)


def filter(s):
    # return re.sub(r'([,.!:-?"\'])\1+', r'\1', s)
    return s.translate({ord(c): None for c in IGNORE_CH_LIST})


### Cal CER ####################################################################
logger.notice("CER calculating...")
cer = 0.0
with open(Path(decode_dir, 'best-hyp'), 'r') as hyp_ref_in:
    cnt = 0
    for line in hyp_ref_in.readlines():
        cnt += 1
        ref, hyp = line.rstrip().split('\t')
        ref = filter(remove_accent(spm.DecodePieces(to_list(ref))))
        ref = ref.upper()
        hyp = filter(remove_accent(spm.DecodePieces(to_list(hyp))))
        hyp = hyp.upper()
        cer += (editdistance.eval(ref, hyp) / len(ref) * 100)
    cer = cer / cnt
logger.log(f"CER: {cer}", prefix='test')
comet_exp.log_other(f"cer({paras.decode_mode})", round(cer, 2))
with open(Path(decode_dir, 'cer'), 'w') as fout:
    print(cer, file=fout)  # body truncated in the original snippet; assumed to record the final CER
    def __init__(self, config, paras, log_dir, train_type, resume=False):
        self.log_dir = log_dir
        self.expkey_f = Path(self.log_dir, 'exp_key')
        self.global_step = 1

        if resume:
            assert self.expkey_f.exists(), f"Cannot find comet exp key in {self.log_dir}"
            with open(Path(self.log_dir, 'exp_key'), 'r') as f:
                exp_key = f.read().strip()
            self.exp = ExistingExperiment(
                previous_experiment=exp_key,
                project_name=COMET_PROJECT_NAME,
                workspace=COMET_WORKSPACE,
                auto_output_logging=None,
                auto_metric_logging=None,
                display_summary_level=0,
            )
        else:
            self.exp = Experiment(
                project_name=COMET_PROJECT_NAME,
                workspace=COMET_WORKSPACE,
                auto_output_logging=None,
                auto_metric_logging=None,
                display_summary_level=0,
            )
            #TODO: is there a better way to do this?
            with open(self.expkey_f, 'w') as f:
                print(self.exp.get_key(), file=f)

            self.exp.log_other('seed', paras.seed)
            self.log_config(config)
            if train_type == 'evaluation':
                if paras.pretrain:
                    self.exp.set_name(
                        f"{paras.pretrain_suffix}-{paras.eval_suffix}")
                    self.exp.add_tags([
                        paras.pretrain_suffix, config['solver']['setting'],
                        paras.accent, paras.algo, paras.eval_suffix
                    ])
                    if paras.pretrain_model_path:
                        self.exp.log_other("pretrain-model-path",
                                           paras.pretrain_model_path)
                    else:
                        self.exp.log_other("pretrain-runs",
                                           paras.pretrain_runs)
                        self.exp.log_other("pretrain-setting",
                                           paras.pretrain_setting)
                        self.exp.log_other("pretrain-tgt-accent",
                                           paras.pretrain_tgt_accent)
                else:
                    self.exp.set_name(paras.eval_suffix)
                    self.exp.add_tags(
                        ["mono", config['solver']['setting'], paras.accent])
            else:
                self.exp.set_name(paras.pretrain_suffix)
                self.exp.log_others({
                    f"accent{i}": k
                    for i, k in enumerate(paras.pretrain_accents)
                })
                self.exp.log_other('accent', paras.tgt_accent)
                self.exp.add_tags([
                    paras.algo, config['solver']['setting'], paras.tgt_accent
                ])
            #TODO: Need to add pretrain setting

        ##slurm-related
        hostname = os.uname()[1]
        if len(hostname.split('.')) == 2 and hostname.split(
                '.')[1] == 'speech':
            logger.notice(f"Running on Battleship {hostname}")
            self.exp.log_other('jobid', int(os.getenv('SLURM_JOBID')))
        else:
            logger.notice(f"Running on {hostname}")
    def __init__(self, config, paras, id2accent):

        self.config = config
        self.paras = paras
        self.train_type = 'evaluation'
        self.is_memmap = paras.is_memmap
        self.model_name = paras.model_name

        self.njobs = paras.njobs

        if paras.algo == 'no' and paras.pretrain_suffix is None:
            paras.pretrain_suffix = paras.eval_suffix

        ### Set path
        cur_path = Path.cwd()
        self.data_dir = Path(config['solver']['data_root'], id2accent[paras.accent])
        self.log_dir = Path(cur_path, LOG_DIR, self.train_type, 
                            config['solver']['setting'], paras.algo, \
                            paras.pretrain_suffix, paras.eval_suffix, \
                            id2accent[paras.accent], str(paras.runs))
        self.model_path = Path(self.log_dir, paras.test_model)
        assert self.model_path.exists(), f"{self.model_path.as_posix()} does not exist..."
        self.decode_dir = Path(self.log_dir, paras.decode_suffix)


        ### Decode
        self.decode_mode = paras.decode_mode
        self.beam_decode_param = config['solver']['beam_decode']
        self.batch_size = paras.decode_batch_size
        self.use_gpu = paras.cuda

        if paras.decode_mode == 'lm_beam':
            assert paras.lm_model_path is not None, "In LM beam decode mode, lm_model_path should be specified"
            # assert self.model_name == 'blstm', "LM Beam decode is only supported in blstm model"
            self.lm_model_path = paras.lm_model_path
        else:
            self.lm_model_path = None

        # if paras.decode_mode == 'greedy':
            # self._decode = self.greedy_decode
        # elif paras.decode_mode == 'beam' or paras.decode_mode == 'lm_beam':
            # self._decode = self.beam_decode
        # else :
            # raise NotImplementedError
        #####################################################################

        ### Resume Mechanism
        if not paras.resume:
            if self.decode_dir.exists():
                assert paras.overwrite, \
                    f"Path exists ({self.decode_dir}). Use --overwrite or change decode suffix"
                # time.sleep(10)
                logger.warning('Overwrite existing directory')
                rmtree(self.decode_dir)
            self.decode_dir.mkdir(parents=True)
            self.prev_decode_step = -1
        else:
            with open(Path(self.decode_dir, 'best-hyp'), 'r') as f:
                for i, l in enumerate(f):
                    pass
                self.prev_decode_step = i+1
            logger.notice(f"Decode resume from {self.prev_decode_step}")

        ### Comet
        with open(Path(self.log_dir, 'exp_key'), 'r') as f:
            exp_key = f.read().strip()
            comet_exp = ExistingExperiment(previous_experiment=exp_key,
                                           project_name=COMET_PROJECT_NAME,
                                           workspace=COMET_WORKSPACE,
                                           auto_output_logging=None,
                                           auto_metric_logging=None,
                                           display_summary_level=0,
                                           )
        comet_exp.log_other('status', 'decode')
def get_trainer(cls, config, paras, id2accent):
    logger.notice("BLSTM Trainer Init...")

    class BLSTMTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(BLSTMTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = MonoBLSTM(self.id2ch,
                                       self.config['asr_model']).cuda()
            self.ctc_loss = nn.CTCLoss(blank=0,
                                       reduction='mean',
                                       zero_infinity=True)

            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            self.asr_opt = getattr(torch.optim, \
                                   self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(), \
                                        **self.config['asr_model']['optimizer_opt'])

            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.asr_opt, mode='min', factor=0.2, patience=3, verbose=True)
            super().load_model()
            self.freeze_encoder(paras.freeze_layer)

        def freeze_encoder(self, module):
            if module is not None:
                if module == 'VGG':
                    for p in self.asr_model.encoder.vgg.parameters():
                        p.requires_grad = False
                elif module == 'VGG_BLSTM':
                    for p in self.asr_model.encoder.parameters():
                        p.requires_grad = False
                else:
                    raise ValueError(
                        f"Unknown freeze layer {module} (VGG, VGG_BLSTM)")

                logger.log(f"Freeze {' '.join(module.split('_'))} layer",
                           prefix='info')

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train):
            sos = ys[0].new([self.sos_id])
            eos = ys[0].new([self.eos_id])
            ys_out = [torch.cat([sos, y, eos], dim=0) for y in ys]
            olens += 2  # pad <sos> and <eos>

            y_true = torch.cat(ys_out)

            pred, enc_lens = self.asr_model(x, ilens)
            olens = to_device(self.asr_model, olens)
            pred = F.log_softmax(pred, dim=-1)  # (T, o_dim)

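            # nn.CTCLoss expects log-probabilities of shape (T, N, C), targets concatenated
            # over the batch, and per-utterance input/target lengths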
            loss = self.ctc_loss(
                pred.transpose(0, 1).contiguous(),
                y_true.cuda().to(dtype=torch.long),
                enc_lens.cpu().to(dtype=torch.long),
                olens.cpu().to(dtype=torch.long))

            if train:
                info = {'loss': loss.item()}
                # if self.global_step % 5 == 0:
                if self.global_step % 500 == 0:
                    self.probe_model(pred, ys)
                self.asr_opt.zero_grad()
                loss.backward()

            else:
                cer = self.metric_observer.batch_cal_er(
                    pred.detach(), ys, ['ctc'], ['cer'])['ctc_cer']
                wer = self.metric_observer.batch_cal_er(
                    pred.detach(), ys, ['ctc'], ['wer'])['ctc_wer']
                info = {'cer': cer, 'wer': wer, 'loss': loss.item()}

            return info

        def probe_model(self, pred, ys):
            self.metric_observer.cal_ctc_cer(torch.argmax(pred[0], dim=-1),
                                             ys[0],
                                             show=True,
                                             show_decode=True)
            self.metric_observer.cal_ctc_wer(torch.argmax(pred[0], dim=-1),
                                             ys[0],
                                             show=True)

    return BLSTMTrainer(config, paras, id2accent)