Example #1
    def save_per_epoch(self):

        if self.save_verbose:
            logger.log(f"Save model snapshot.ep.{self.ep}")
            torch.save(self.asr_model.state_dict(), \
                       self.log_dir.joinpath(f"snapshot.ep.{self.ep}"))

        logger.log("Save model snapshot.latest", prefix='info')
        torch.save(self.asr_model.state_dict(), \
                   self.log_dir.joinpath("snapshot.latest"))
        if isinstance(self.asr_opt, TransformerOptimizer):
            with open(self.log_dir.joinpath("optimizer.latest"), "wb") as fout:
                pickle.dump(self.asr_opt, fout)
        else:
            torch.save(self.asr_opt.state_dict(), \
                       self.log_dir.joinpath("optimizer.latest"))

        with open(self.log_dir.joinpath("info_dict.latest"), 'wb') as fout:
            pickle.dump(self.train_info, fout)

        with open(self.log_dir.joinpath("global_step"), 'w') as fout:
            print(self.global_step, file=fout)
        self.dashboard.log_step()

        with open(Path(self.log_dir, 'epoch'), 'w') as fout:
            print(self.ep, file=fout)
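
    # A possible counterpart for resuming (not shown in these snippets; the method
    # name is hypothetical): read back the plain-text bookkeeping files written above.
    def load_step_and_epoch(self):
        with open(self.log_dir.joinpath("global_step")) as fin:
            self.global_step = int(fin.read().strip())
        with open(self.log_dir.joinpath("epoch")) as fin:
            self.ep = int(fin.read().strip())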
    def __init__(self, data_dir, is_memmap):
        """
        data_dir: Path, directory holding the preprocessed features/labels
        is_memmap: bool, if True read the feature file through numpy memmap
        """
        if is_memmap:
            feat_path = data_dir.joinpath('feat').with_suffix('.dat')
            logger.log(f"Loading {feat_path} from memmap...",prefix='info')
            self.feat = np.load(feat_path, mmap_mode='r')
        else:
            feat_path = data_dir.joinpath('feat').with_suffix('.npy')
            logger.warning(f"Loading whole data ({feat_path}) into RAM")
            self.feat = np.load(feat_path)
        
        self.ilens = np.load(data_dir.joinpath('ilens.npy'))
        self.iptr = np.zeros(len(self.ilens)+1, dtype=int)
        self.ilens.cumsum(out=self.iptr[1:])

        self.label = np.load(data_dir.joinpath('label.npy'))
        self.olens = np.load(data_dir.joinpath('olens.npy'))
        self.optr = np.zeros(len(self.olens) + 1, dtype=int)
        self.olens.cumsum(out=self.optr[1:])

        assert len(self.ilens) == len(self.olens), \
        "Number of samples should be the same in features and labels"
    def __init__(self, config, paras, id2accent):
        super(FOMetaASRInterface, self).__init__(config, paras, id2accent)

        assert paras.meta_k is not None
        self.meta_k = paras.meta_k
        if paras.meta_batch_size is None:
            logger.log("Meta batch size not set, falling back to num_pretrain", prefix='info')
            self.meta_batch_size = self.num_pretrain
        else:
            self.meta_batch_size = paras.meta_batch_size

        self.asr_model = None
        self.asr_opt = None

        self.max_step = paras.max_step if paras.max_step > 0 else config['solver']['total_steps']
        self.dashboard.set_status('pretraining')

        self._train = partial(self.run_batch, train=True)
        self._eval = partial(self.run_batch, train=False)

        self._updates = None
        self._counter = 0

        d_model = config['asr_model']['d_model']
        opt_k = config['asr_model']['meta']['optimizer_opt']['k']
        warmup_steps = config['asr_model']['meta']['optimizer_opt']['warmup_steps']
        self.inner_lr = opt_k * d_model ** (-0.5) * warmup_steps ** (-0.5)

        logger.notice("Meta Interface Information")
        logger.log(f"Sampling strategy: {self.sample_strategy}", prefix='info')
        logger.log(f"Meta batch size  : {self.meta_batch_size}", prefix='info')
        logger.log(f"# inner-loop step: {self.meta_k}", prefix='info')
        # Peak lr of the noam schedule
        logger.log(f"Inner loop lr    : {round(self.inner_lr, 5)}", prefix='info')
    def set_model(self):

        logger.notice(f"Load trained ASR model from {self.model_path}")
        if self.model_name == 'blstm':
            from src.model.blstm.mono_blstm import MonoBLSTM as ASRModel
        elif self.model_name == 'las':
            from src.model.seq2seq.mono_las import MonoLAS as ASRModel
        elif self.model_name == 'transformer':
            from src.model.transformer_pytorch.mono_transformer_torch import MyTransformer as ASRModel
        else:
            raise NotImplementedError
        self.asr_model = ASRModel(self.id2ch, self.config['asr_model'])

        if self.use_gpu:
            self.asr_model.load_state_dict(torch.load(self.model_path))
            self.asr_model = self.asr_model.cuda()
        else:
            self.asr_model.load_state_dict(torch.load(self.model_path, map_location=torch.device('cpu')))
        self.asr_model.eval()
        logger.log(f'ASR model device {self.asr_model.device}', prefix='debug')
        self.sos_id = self.asr_model.sos_id
        self.eos_id = self.asr_model.eos_id
        self.blank_id = None if self.model_name != 'blstm' else self.asr_model.blank_id

        #FIXME: will merge it later
        if self.decode_mode != 'greedy':
            if self.model_name == 'blstm':
                raise NotImplementedError
            elif self.model_name == 'transformer':
                raise NotImplementedError
            else:
                raise NotImplementedError
    def load_data(self):
        logger.notice(f"Loading data from {self.data_dir}")
        if self.model_name == 'blstm':
            self.id2ch = [BLANK_SYMBOL]
        elif self.model_name == 'transformer':
            self.id2ch = [SOS_SYMBOL]
        else:
            raise NotImplementedError

        with open(self.config['solver']['spm_mapping']) as fin:
            for line in fin.readlines():
                self.id2ch.append(line.rstrip().split(' ')[0])
            self.id2ch.append(EOS_SYMBOL)
        logger.log(f"Train units: {self.id2ch}")

        self.eval_set = get_loader(
            self.data_dir.joinpath('test'),
            # self.data_dir.joinpath('dev'),
            batch_size = self.batch_size,
            half_batch_ilen = 512 if self.batch_size > 1 else None,
            is_memmap = self.is_memmap,
            is_bucket = False,
            shuffle = False,
            num_workers = 1,
        )
    def exec(self):
        if self.decode_mode != 'greedy':
            logger.notice(f"Start decoding with beam search (beam size: {self.config['solver']['beam_decode']['beam_size']})")
            self._decode = self.beam_decode
            raise NotImplementedError(f"{self.decode_mode} decoding is not supported yet")
        else:
            logger.notice("Start greedy decoding")
            if self.batch_size > 1:
                dev = 'gpu' if self.use_gpu else 'cpu'
                logger.log(f"Number of utterance batches to decode: {len(self.eval_set)}, decoding with {self.batch_size} batch_size using {dev}")
                self._decode = self.batch_greedy_decode
                self.njobs = 1
            else:
                logger.log(f"Number of utterances to decode: {len(self.eval_set)}, decoding with {self.njobs} threads using cpu")
                self._decode = self.greedy_decode

        if self.njobs > 1:
            try:
                _ = Parallel(n_jobs=self.njobs)(delayed(self._decode)(i, x, ilen, y, olen) for i, (x, ilen, y, olen) in enumerate(self.eval_set))

            #NOTE: cannot log comet here, since it cannot serialize
            except KeyboardInterrupt:
                logger.warning("Decoding stopped")
            else:
                logger.notice("Decoding done")
                # self.comet_exp.log_other('status','decoded')
        else:

            tbar = get_bar(total=len(self.eval_set), leave=True)

            for cur_b, (xs, ilens, ys, olens) in enumerate(self.eval_set):
                self.batch_greedy_decode(xs, ilens, ys, olens)
                tbar.update(1)
Example #7
    def load_model(self):
        logger.log("ASR model initialization")

        if self.paras.resume:
            logger.notice(
                f"Resume training from epoch {self.ep} (best cer: {self.best_cer}, best wer: {self.best_wer})"
            )
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            if isinstance(self.asr_opt, TransformerOptimizer):
                with open(self.optimizer_path, 'rb') as fin:
                    self.asr_opt = pickle.load(fin)
            else:
                self.asr_opt.load_state_dict(torch.load(self.optimizer_path))
            self.dashboard.set_step(self.global_step)
        elif self.paras.pretrain:
            model_dict = self.asr_model.state_dict()
            logger.notice(
                f"Load pretrained {','.join(self.pretrain_module)} from {self.pretrain_model_path}"
            )
            pretrain_dict = self.filter_model(
                torch.load(self.pretrain_model_path))
            model_dict.update(pretrain_dict)
            self.asr_model.load_state_dict(model_dict)
            if 'freeze_module' in self.config['solver']:
                logger.warning(
                    "Part of the model will be frozen during fine-tuning")
                self.freeze_module(self.config['solver']['freeze_module'])
            logger.notice("Done!")
        else:  # simple monolingual training from step 0
            logger.notice("Training from scratch")
    def load_model(self):
        logger.log("MultiASR model for pretraining initialization")

        if self.paras.resume:
            logger.notice(f"Resume pretraining from {self.global_step}")
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            self.dashboard.set_step(self.global_step)
        else:
            logger.notice(f"Start pretraining from {self.global_step}")
Example #9
    def probe_model(self, pred, ys_out, accent_idx):
        if accent_idx is not None:
            logger.log(f"Probe on {self.accents[accent_idx]}",
                       prefix='debug')
        self.metric_observer.cal_att_cer(torch.argmax(pred[0], dim=-1),
                                         ys_out[0],
                                         show=True,
                                         show_decode=True)
        self.metric_observer.cal_att_wer(torch.argmax(pred[0], dim=-1),
                                         ys_out[0],
                                         show=True)

    def beam_decode(self, cur_step, x, ilen, y, olen):
        if cur_step > self.prev_decode_step:
            if cur_step % 500 == 0:
                logger.log(f"Current step {cur_step}")
            with torch.no_grad():
                model = copy.deepcopy(self.asr_model)
                hyp = model.beam_decode(x, ilen)
                del model
            self.write_hyp(y[0], hyp[0])
            del hyp
        del x, ilen, y, olen
        return True
Example #11
    def __init__(self, model_path, id2units, sos_id, eos_id, ignore_id=None):
        self.spm = spmlib.SentencePieceProcessor()
        self.spm.Load(model_path)
        self.id2units = id2units
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.blank_id = None if BLANK_SYMBOL not in id2units else id2units.index(BLANK_SYMBOL)
        self.ignore_id = ignore_id
        # self.sos_id = self.id2units.index('<s>')
        # self.eos_id = self.id2units.index('</s>')

        logger.log(f"Train units: {self.id2units}")
    def set_model(self):
        self.asr_model = MyTransformer(self.id2ch,
                                       self.config['asr_model']).cuda()
        self.label_smooth_rate = self.config['solver']['label_smoothing']
        if self.label_smooth_rate > 0.0:
            logger.log(
                f"Use label smoothing rate {self.label_smooth_rate}",
                prefix='info')
        # self.asr_opt = optim.RAdam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-9)
        # if self.config['asr_model']['optimizer_cls'] == 'noam':
        if 'inner_optimizer_cls' not in self.config['asr_model']:  # multi or mono
            if self.config['asr_model']['optimizer_cls'] == 'noam':
                logger.notice(
                    "Use noam optimizer; it is recommended for monolingual training"
                )
                self.asr_opt = TransformerOptimizer(
                    torch.optim.Adam(self.asr_model.parameters(),
                                     betas=(0.9, 0.98),
                                     eps=1e-09,
                                     weight_decay=1e-3),
                    self.config['asr_model']['optimizer_opt']['k'],
                    self.config['asr_model']['d_model'],
                    self.config['asr_model']['optimizer_opt']['warmup_steps'])
            elif self.config['asr_model']['optimizer_cls'] == 'RAdam':
                logger.notice(
                    f"Use third-party {self.config['asr_model']['optimizer_cls']} optimizer"
                )
                self.asr_opt = getattr(extra_optim,
                                       self.config['asr_model']['optimizer_cls'])
                self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                            **self.config['asr_model']['optimizer_opt'])
            else:
                logger.notice(
                    f"Use {self.config['asr_model']['optimizer_cls']} optimizer; it is recommended for fine-tuning"
                )
                self.asr_opt = getattr(torch.optim,
                                       self.config['asr_model']['optimizer_cls'])
                self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                            **self.config['asr_model']['optimizer_opt'])
        else:
            logger.notice(
                "During meta-training, the model optimizer is reset after each task"
            )

        self.sos_id = self.asr_model.sos_id
        self.eos_id = self.asr_model.eos_id

        super().load_model()
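
# The TransformerOptimizer used above is defined elsewhere in the repo. As a
# reference, a minimal noam-style wrapper (a sketch of the schedule from
# "Attention Is All You Need", lr = k * d_model**-0.5 * min(step**-0.5,
# step * warmup**-1.5); the class name and details here are assumptions, not
# the repo's API) could look like this:
class NoamWrapperSketch:
    def __init__(self, optimizer, k, d_model, warmup_steps):
        self.optimizer = optimizer
        self.k = k
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        # Update the lr of every param group, then take a regular optimizer step.
        self.step_num += 1
        lr = self.k * self.d_model ** (-0.5) * min(
            self.step_num ** (-0.5),
            self.step_num * self.warmup_steps ** (-1.5))
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()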
Example #13
    def freeze_encoder(self, module):
        if module is not None:
            if module == 'VGG':
                for p in self.asr_model.encoder.vgg.parameters():
                    p.requires_grad = False
            elif module == 'VGG_BLSTM':
                for p in self.asr_model.encoder.parameters():
                    p.requires_grad = False
            else:
                raise ValueError(
                    f"Unknown freeze layer {module} (VGG, VGG_BLSTM)")

            logger.log(f"Freeze {' '.join(module.split('_'))} layer",
                       prefix='info')
    def greedy_decode(self, cur_step, x, ilen, y, olen):
        if cur_step > self.prev_decode_step:
            if cur_step % 500 == 0:
                logger.log(f"Current step {cur_step}")
            with torch.no_grad():
                hyp, _ = self.asr_model.greedy_decode(x, ilen)
                hyp = self.trim(torch.argmax(hyp[0], dim=-1).tolist())
                hyp = [x[0] for x in groupby(hyp)]
                if self.blank_id is not None:
                    hyp = [x for x in hyp if x != self.blank_id]
                # del model
            self.write_hyp(y[0].tolist(), hyp)
            del hyp
        del x, ilen, y, olen
        return True
    def save_per_steps(self):
        assert self.asr_model is not None

        logger.log("Save model snapshot.latest", prefix='info')

        torch.save(self.asr_model.state_dict(), \
                   self.log_dir.joinpath("snapshot.latest"))
        with open(self.log_dir.joinpath("info_dict.latest"),'wb') as fin:
            pickle.dump(self.train_info, fin)

        with open(self.log_dir.joinpath("global_step"),'w') as fout:
            print(self.global_step, file=fout)

        # Used for transfer (as init weight for training)
        model_save_name = f"snapshot.step.{self.global_step}"
        logger.log(f"Save model {model_save_name}", prefix='info')

        torch.save(self.asr_model.state_dict(), self.log_dir.joinpath(model_save_name))
        self.dashboard.log_step()
    def load_model(self):
        logger.log("First Order GBML ASRmodel for pretraining initialization")

        if self.paras.resume:
            logger.notice(f"Resume pretraining from {self.global_step}")
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            self.dashboard.set_step(self.global_step)
        else:
            logger.notice(f"Start pretraining from {self.global_step}")

        self._original = clone_state_dict(self.asr_model.state_dict(keep_vars=True))
        params = [p for p in self._original.values() if getattr(p, 'requires_grad', False)]

        if self.config['asr_model']['meta_opt_cls'] == 'noam':
            self.meta_opt = TransformerOptimizer(
                torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-09),
                self.config['asr_model']['meta']['optimizer_opt']['k'],
                self.config['asr_model']['d_model'],
                self.config['asr_model']['meta']['optimizer_opt']['warmup_steps']
            )
        else:
            raise NotImplementedError(f"Should use noam optimizer in outer loop transformer learning, but got {self.asr_model['meta_opt_cls']}")
Example #17
    def cal_att_cer(self, pred, y, show=False, show_decode=False):
        show_pred = self.discard_ch_after_eos(pred.tolist())
        show_pred = [x for x in show_pred if x != self.sos_id]
        show_pred_text = self.spm.DecodePieces([self.id2units[x] for x in show_pred])

        show_y = [self.id2units[x] for x in y.tolist() if x != self.eos_id and x != IGNORE_ID]
        show_y_text = self.spm.DecodePieces(show_y)

        cer = float(editdistance.eval(show_pred_text, show_y_text)) / len(show_y_text) * 100

        if show_decode:
            logger.log(f"Hyp:\t {show_pred_text}", prefix='debug')
            logger.log(f"Ref:\t {show_y_text}", prefix='debug')
        if show:
            logger.log(f"CER: {cer}", prefix='debug')

        return cer
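
# Tiny self-contained check of the error-rate formula above (the strings are made
# up): editdistance.eval("kitten", "sitting") is 3, and the reference has 7
# characters, so the score is 3 / 7 * 100 ≈ 42.86.
import editdistance

hyp_text, ref_text = "kitten", "sitting"
print(round(editdistance.eval(hyp_text, ref_text) / len(ref_text) * 100, 2))  # 42.86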
Example #18
    def cal_ctc_cer(self, pred, y, show=False, show_decode=False):
        assert self.blank_id is not None

        show_pred = pred.tolist()
        show_pred = [x[0] for x in groupby(show_pred)]
        show_pred = [x for x in show_pred if x != self.sos_id and x != self.eos_id and x != self.blank_id]

        show_pred_text = self.spm.DecodePieces([self.id2units[x] for x in show_pred])
        
        show_y = [self.id2units[x] for x in y.tolist()]
        show_y_text = self.spm.DecodePieces(show_y)

        cer = float(editdistance.eval(show_pred_text, show_y_text)) / len(show_y_text) * 100
        
        if show_decode:
            logger.log(f"Hyp:\t {show_pred_text}", prefix='debug')
            logger.log(f"Ref:\t {show_y_text}", prefix='debug')
        if show:
            logger.log(f"CER: {cer}", prefix='debug')

        return cer
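
# Quick illustration of the CTC-style collapse used in cal_ctc_cer (and in
# greedy_decode earlier); the ids and blank id below are hypothetical:
from itertools import groupby

frames = [0, 3, 3, 0, 0, 5, 5, 5, 0, 4]      # per-frame argmax ids, 0 = blank
collapsed = [x[0] for x in groupby(frames)]  # [0, 3, 0, 5, 0, 4]
no_blank = [x for x in collapsed if x != 0]  # [3, 5, 4]
print(no_blank)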
Example #19
    def save_init(self):
        logger.log("Save model snapshot.init", prefix='info')
        torch.save(self.asr_model.state_dict(),
                   self.log_dir.joinpath("snapshot.init"))
Example #20
### Cal CER ####################################################################
logger.notice("CER calculating...")
cer = 0.0
with open(Path(decode_dir, 'best-hyp'), 'r') as hyp_ref_in:
    cnt = 0
    for line in hyp_ref_in.readlines():
        cnt += 1
        ref, hyp = line.rstrip().split('\t')
        # NOTE: `filter` and `remove_accent` are assumed to be project-specific text
        # helpers here (`filter` shadows the builtin).
        ref = filter(remove_accent(spm.DecodePieces(to_list(ref))))
        ref = ref.upper()
        hyp = filter(remove_accent(spm.DecodePieces(to_list(hyp))))
        hyp = hyp.upper()
        cer += (editdistance.eval(ref, hyp) / len(ref) * 100)
    cer = cer / cnt
logger.log(f"CER: {cer}", prefix='test')
comet_exp.log_other(f"cer({paras.decode_mode})", round(cer, 2))
with open(Path(decode_dir, 'cer'), 'w') as fout:
    print(str(cer), file=fout)
################################################################################

### Cal SER ####################################################################
logger.notice("Symbol error rate calculating...")
with open(Path(decode_dir,'best-hyp'),'r') as hyp_ref_in, \
     open(Path(decode_dir,'hyp.trn'),'w') as hyp_out, \
     open(Path(decode_dir,'ref.trn'),'w') as ref_out:
    for i, line in enumerate(hyp_ref_in.readlines()):
        foo = line.rstrip().split('\t')
        if len(foo) == 1:
            print(f"{' '.join(to_list(foo[0]))} ({i//1000}k_{i})",
                  file=ref_out)