def __init__(self, data_dir, is_memmap):
        """
        data_dir: str
        is_memmap: bool
        """
        if is_memmap:
            feat_path = data_dir.joinpath('feat').with_suffix('.dat')
            logger.log(f"Loading {feat_path} from memmap...",prefix='info')
            self.feat = np.load(feat_path, mmap_mode='r')
        else:
            feat_path = data_dir.joinpath('feat').with_suffix('.npy')
            logger.warning(f"Loading whole data ({feat_path}) into RAM")
            self.feat = np.load(feat_path)
        
        self.ilens = np.load(data_dir.joinpath('ilens.npy'))
        # cumulative offsets: sample i spans feat[iptr[i]:iptr[i+1]]
        self.iptr = np.zeros(len(self.ilens) + 1, dtype=int)
        self.ilens.cumsum(out=self.iptr[1:])

        self.label = np.load(data_dir.joinpath('label.npy'))
        self.olens = np.load(data_dir.joinpath('olens.npy'))
        self.optr = np.zeros(len(self.olens) + 1, dtype=int)
        self.olens.cumsum(out=self.optr[1:])

        assert len(self.ilens) == len(self.olens), \
            "Number of samples should be the same in features and labels"
    def load_model(self):
        logger.log("ASR model initialization")

        if self.paras.resume:
            logger.notice(
                f"Resume training from epoch {self.ep} (best cer: {self.best_cer}, best wer: {self.best_wer})"
            )
            self.asr_model.load_state_dict(torch.load(self.resume_model_path))
            if isinstance(self.asr_opt, TransformerOptimizer):
                with open(self.optimizer_path, 'rb') as fin:
                    self.asr_opt = pickle.load(fin)
            else:
                self.asr_opt.load_state_dict(torch.load(self.optimizer_path))
            self.dashboard.set_step(self.global_step)
        elif self.paras.pretrain:
            model_dict = self.asr_model.state_dict()
            logger.notice(
                f"Loading pretrained modules [{','.join(self.pretrain_module)}] from {self.pretrain_model_path}"
            )
            pretrain_dict = self.filter_model(
                torch.load(self.pretrain_model_path))
            model_dict.update(pretrain_dict)
            self.asr_model.load_state_dict(model_dict)
            if 'freeze_module' in self.config['solver']:
                logger.warning(
                    "Part of the model will be frozen during fine-tuning")
                self.freeze_module(self.config['solver']['freeze_module'])
            logger.notice("Done!")
        else:  # simple monolingual training from step 0
            logger.notice("Training from scratch")
    def run_task(self, batches):
        self._counter += 1

        self.asr_model.load_state_dict(self._original)
        self.asr_model.train()
        # Instantiate a fresh inner-loop optimizer for each task (e.g. SGD)
        opt_cls = getattr(torch.optim,
                          self.config['asr_model']['inner_optimizer_cls'])
        # TODO: decide the inner-loop learning-rate schedule. Should it follow
        # the Noam schedule with warmup, and does fine-tuning need the same?

        self.asr_opt = opt_cls(self.asr_model.parameters(),
                               lr=self.inner_lr,
                               momentum=self.config['asr_model']['inner_optimizer_opt']['momentum'],
                               nesterov=self.config['asr_model']['inner_optimizer_opt']['nesterov'])

        for cnt, (idx, (x, ilens, ys, olens)) in enumerate(batches):
            # info = self._train(idx, x, ilens, ys, olens)
            self._train(idx, x, ilens, ys, olens)

            grad_norm = nn.utils.clip_grad_norm_(
                self.asr_model.parameters(), GRAD_CLIP)

            if math.isnan(grad_norm):
                logger.warning(f"grad norm NaN @ step {self.global_step}")
            else:
                self.asr_opt.step()

            del x, ilens, ys, olens
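
    # One possible answer to the TODO above, as a hedged sketch: a Noam-style
    # inverse-square-root schedule with linear warmup, applied manually to the
    # inner-loop optimizer. The d_model and warmup defaults are assumed
    # hyperparameters, not values taken from this repo's config.
    def _inner_lr_at(self, step, d_model=512, warmup=4000):
        step = max(step, 1)
        scale = d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
        return self.inner_lr * scale

    def _set_inner_lr(self, step):
        for group in self.asr_opt.param_groups:
            group['lr'] = self._inner_lr_at(step)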
    def freeze_module(self, modules):
        # logger.warning(f"Freeze modules: {','.join(modules)}")

        for module in modules:
            logger.warning(f"Freeze {module}")
            for p in getattr(self.asr_model, module).parameters():
                p.requires_grad = False
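
    # Hedged companion to freeze_module above: count the parameters that still
    # require gradients, to confirm freezing took effect. This helper is an
    # illustrative addition, not part of the original solver.
    def count_trainable(self):
        return sum(p.numel() for p in self.asr_model.parameters()
                   if p.requires_grad)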
    def exec(self):
        if self.decode_mode != 'greedy':
            logger.notice(f"Start decoding with beam search (beam size: {self.config['solver']['beam_decode']['beam_size']})")
            raise NotImplementedError(f"{self.decode_mode} decoding is not supported yet")
            # self._decode = self.beam_decode  # unreachable until beam search is implemented
        else:
            logger.notice("Start greedy decoding")
            if self.batch_size > 1:
                dev = 'gpu' if self.use_gpu else 'cpu'
                logger.log(f"Number of utterance batches to decode: {len(self.eval_set)}, decoding with {self.batch_size} batch_size using {dev}")
                self._decode = self.batch_greedy_decode
                self.njobs = 1
            else:
                logger.log(f"Number of utterances to decode: {len(self.eval_set)}, decoding with {self.njobs} threads using cpu")
                self._decode = self.greedy_decode

        if self.njobs > 1:
            try:
                _ = Parallel(n_jobs=self.njobs)(delayed(self._decode)(i, x, ilen, y, olen) for i, (x, ilen, y, olen) in enumerate(self.eval_set))

            #NOTE: cannot log comet here, since it cannot serialize
            except KeyboardInterrupt:
                logger.warning("Decoding stopped")
            else:
                logger.notice("Decoding done")
                # self.comet_exp.log_other('status','decoded')
        else:

            tbar = get_bar(total=len(self.eval_set), leave=True)

            for cur_b, (xs, ilens, ys, olens) in enumerate(self.eval_set):
                # single-process path always decodes batch-wise
                self.batch_greedy_decode(xs, ilens, ys, olens)
                tbar.update(1)
    def train(self):
        try:
            while self.global_step < self.max_step:
                tbar = get_bar(total=self.eval_ival, \
                               desc=f"Step {self.global_step}", leave=True)

                for _ in range(self.eval_ival):
                    #TODO: we can add sampling method to compare Meta and Multi fair
                    idx, (x, ilens, ys,
                          olens) = self.data_container.get_item()[0]

                    batch_size = len(ys)
                    info = self._train(idx,
                                       x,
                                       ilens,
                                       ys,
                                       olens,
                                       accent_idx=idx)
                    self.train_info.add(info, batch_size)

                    grad_norm = nn.utils.clip_grad_norm_(
                        self.asr_model.parameters(), GRAD_CLIP)

                    if math.isnan(grad_norm):
                        logger.warning(
                            f"grad norm NaN @ step {self.global_step}")
                    else:
                        self.asr_opt.step()

                    if isinstance(self.asr_opt, TransformerOptimizer):
                        self.log_msg(self.asr_opt.lr)
                    else:
                        self.log_msg()
                    self.check_evaluate()

                    self.global_step += 1
                    self.dashboard.step()

                    del x, ilens, ys, olens
                    tbar.update(1)

                    if self.global_step % self.save_ival == 0:
                        self.save_per_steps()
                    self.dashboard.check()
                tbar.close()

        except KeyboardInterrupt:
            logger.warning("Pretraining stopped")
            self.save_per_steps()
            self.dashboard.set_status('pretrained(SIGINT)')
        else:
            logger.notice("Pretraining completed")
            self.dashboard.set_status('pretrained')
    def train(self):
        self.evaluate()
        try:
            if self.save_verbose:
                self.save_init()
            while self.ep < self.max_epoch:
                tbar = get_bar(total=len(self.train_set), \
                               desc=f"Epoch {self.ep}", leave=True)

                for cur_b, (x, ilens, ys, olens) in enumerate(self.train_set):

                    batch_size = len(ys)
                    info = self._train(cur_b, x, ilens, ys, olens)
                    self.train_info.add(info, batch_size)

                    grad_norm = nn.utils.clip_grad_norm_(
                        self.asr_model.parameters(), GRAD_CLIP)

                    if math.isnan(grad_norm):
                        logger.warning(
                            f"grad norm NaN @ step {self.global_step}")
                    else:
                        self.asr_opt.step()

                    if isinstance(self.asr_opt, TransformerOptimizer):
                        self.log_msg(self.asr_opt.lr)
                    else:
                        self.log_msg()
                    self.check_evaluate()

                    self.global_step += 1
                    self.dashboard.step()

                    del x, ilens, ys, olens
                    tbar.update(1)

                self.ep += 1
                self.save_per_epoch()
                self.dashboard.check()
                tbar.close()

                if self.eval_every_epoch:
                    self.evaluate()

        except KeyboardInterrupt:
            logger.warning("Training stopped")
            self.evaluate()
            self.dashboard.set_status('trained(SIGINT)')
        else:
            logger.notice("Training completed")
            self.dashboard.set_status('trained')
    def train(self):

        try:
            task_ids = list(range(self.num_pretrain))
            while self.global_step < self.max_step:
                tbar = get_bar(total=self.eval_ival, \
                               desc=f"Step {self.global_step}", leave=True)
                for _ in range(self.eval_ival):
                    shuffle(task_ids)

                    # FIXME: split tasks into inner-train and inner-test here
                    # (and observe whether performance drops as a result)
                    for accent_id in task_ids[:self.meta_batch_size]:
                        # inner-loop learn
                        tr_batches = self.data_container.get_item(accent_id, self.meta_k)
                        self.run_task(tr_batches)

                        # inner-loop test
                        val_batch = self.data_container.get_item(accent_id)[0]
                        batch_size = len(val_batch[1][2])
                        info = self._train(val_batch[0], *val_batch[1],
                                           accent_idx=val_batch[0])
                        grad_norm = nn.utils.clip_grad_norm_(
                            self.asr_model.parameters(), GRAD_CLIP)

                        if math.isnan(grad_norm):
                            logger.warning(f"grad norm NaN @ step {self.global_step} on {self.accents[accent_id]}, ignore...")

                        self._partial_meta_update()
                        del val_batch
                        self.train_info.add(info, batch_size)

                    self._final_meta_update()

                    self.log_msg(self.meta_opt.lr)
                    self.check_evaluate()
                    self.global_step += 1
                    self.dashboard.step()
                    tbar.update(1)

                    if self.global_step % self.save_ival == 0:
                        self.save_per_steps()
                    self.dashboard.check()
                tbar.close()

        except KeyboardInterrupt:
            logger.warning("Pretraining stopped")
            self.save_per_steps()
            self.dashboard.set_status('pretrained(SIGINT)')
        else:
            logger.notice("Pretraining completed")
            self.dashboard.set_status('pretrained')
random.seed(paras.seed)
np.random.seed(paras.seed)
torch.manual_seed(paras.seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(paras.seed)
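
# Hedged convenience wrapper for the seeding block above; the cudnn flags are
# an assumption for stricter reproducibility and are not set in the original.
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False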

with open(Path('data', 'accent-code.json'), 'r') as fin:
    id2accent = json.load(fin)

if paras.test:
    from src.tester import Tester

    if paras.decode_mode != 'greedy':
        assert paras.decode_batch_size == 1, f"decode_batch_size can only be 1 if decode_mode is {paras.decode_mode}"
        if paras.cuda and torch.cuda.device_count() == 0:
            logger.warning(
                "cuda is set to True, but no GPU was detected; using CPU for decoding"
            )
            paras.cuda = False
    solver = Tester(config, paras, id2accent)
else:
    if paras.model_name == 'blstm':
        from src.blstm_trainer import get_trainer
    elif paras.model_name == 'las':
        from src.las_trainer import get_trainer
    elif paras.model_name == 'transformer':
        from src.transformer_torch_trainer import get_trainer
    else:
        raise NotImplementedError
    solver = get_trainer(MonoASRInterface, config, paras, id2accent)

solver.load_data()
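
# Hedged continuation of the driver script (the exact entry points are
# assumptions based on the methods shown in this file): Tester exposes exec()
# for decoding, while the trainers expose load_model() and train().
if paras.test:
    solver.exec()
else:
    solver.load_model()
    solver.train()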
    def __init__(self, id2char, model_para):
        super(MyTransformer, self).__init__()

        self.idim = model_para['idim']

        #FIXME: remove these hardcoded values later
        self.odim = len(id2char)
        #TODO: check whether sos_id and eos_id should be distinct or shared
        self.sos_id = 0
        self.eos_id = len(id2char)-1
        self.vgg_ch_dim = 128
        
        self.feat_extractor = nn.Sequential(
                nn.Conv2d(1, 64, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2),
                nn.Conv2d(64, self.vgg_ch_dim, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv2d(self.vgg_ch_dim, self.vgg_ch_dim, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2),
        )
        # the two stride-2 max-pools above shrink the frequency axis by 4
        self.vgg_o_dim = self.vgg_ch_dim * floor(self.idim / 4)
        self.vgg2enc = nn.Linear(self.vgg_o_dim, model_para['d_model'])

        self.pos_encoder = PositionalEncoding(model_para['d_model'], model_para['pos_dropout'])

        self.char_trans = nn.Linear(model_para['d_model'], self.odim)
        self.pre_embed = nn.Embedding(self.odim, model_para['d_model'])
        if model_para['tgt_share_weight'] != 0:
            logger.warning("Tie weight of char_trans and embedding")
            self.char_trans.weight = self.pre_embed.weight

        self.d_model = model_para['d_model']
        self.nhead = model_para['nheads']
        encoder_layer = nn.TransformerEncoderLayer(
            d_model = self.d_model,
            nhead = self.nhead,
            dim_feedforward = model_para['d_inner'],
            dropout=model_para['dropout']
        )
        encoder_norm = nn.LayerNorm(model_para['d_model'])
        self.encoder = nn.TransformerEncoder(
            encoder_layer = encoder_layer,
            num_layers = model_para['encoder']['nlayers'],
            norm = encoder_norm
        )

        decoder_layer = nn.TransformerDecoderLayer(
            d_model = self.d_model,
            nhead = self.nhead,
            dim_feedforward = model_para['d_inner'],
            dropout=model_para['dropout']
        )
        decoder_norm = nn.LayerNorm(model_para['d_model'])
        self.decoder = nn.TransformerDecoder(
            decoder_layer = decoder_layer,
            num_layers = model_para['decoder']['nlayers'],
            norm = decoder_norm
        )

        self.init_parameters()
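
    # Hedged shape check for the VGG front end above (an illustrative
    # staticmethod, not part of the original model; idim=80 is an assumed
    # log-mel dimension): two stride-2 max-pools shrink both the time and the
    # frequency axis by 4, so a (B, 1, T, idim) input comes out as
    # (B, 128, T//4, idim//4) and flattens to (B, T//4, 128 * floor(idim/4)),
    # matching self.vgg_o_dim before the vgg2enc projection.
    @staticmethod
    def _check_vgg_shapes(feat_extractor, idim=80, T=120):
        x = torch.randn(2, 1, T, idim)          # (B, 1, T, idim)
        h = feat_extractor(x)                   # (B, 128, T//4, idim//4)
        b, c, t, f = h.shape
        assert (t, f) == (T // 4, idim // 4)
        flat = h.transpose(1, 2).reshape(b, t, c * f)  # (B, T//4, vgg_o_dim)
        return flat.shape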
    def __init__(self, config, paras, id2accent):

        self.config = config
        self.paras = paras
        self.train_type = 'evaluation'
        self.is_memmap = paras.is_memmap
        self.model_name = paras.model_name

        self.njobs = paras.njobs

        if paras.algo == 'no' and paras.pretrain_suffix is None:
            paras.pretrain_suffix = paras.eval_suffix

        ### Set path
        cur_path = Path.cwd()
        self.data_dir = Path(config['solver']['data_root'], id2accent[paras.accent])
        self.log_dir = Path(cur_path, LOG_DIR, self.train_type,
                            config['solver']['setting'], paras.algo,
                            paras.pretrain_suffix, paras.eval_suffix,
                            id2accent[paras.accent], str(paras.runs))
        self.model_path = Path(self.log_dir, paras.test_model)
        assert self.model_path.exists(), f"{self.model_path.as_posix()} does not exist"
        self.decode_dir = Path(self.log_dir, paras.decode_suffix)


        ### Decode
        self.decode_mode = paras.decode_mode
        self.beam_decode_param = config['solver']['beam_decode']
        self.batch_size = paras.decode_batch_size
        self.use_gpu = paras.cuda

        if paras.decode_mode == 'lm_beam':
            assert paras.lm_model_path is not None, "In LM Beam decode mode, lm_model_path should be specified"
            # assert self.model_name == 'blstm', "LM Beam decode is only supported in blstm model"
            self.lm_model_path = paras.lm_model_path
        else:
            self.lm_model_path = None

        # if paras.decode_mode == 'greedy':
            # self._decode = self.greedy_decode
        # elif paras.decode_mode == 'beam' or paras.decode_mode == 'lm_beam':
            # self._decode = self.beam_decode
        # else :
            # raise NotImplementedError
        #####################################################################

        ### Resume Mechanism
        if not paras.resume:
            if self.decode_dir.exists():
                assert paras.overwrite, \
                    f"Path exists ({self.decode_dir}). Use --overwrite or change decode suffix"
                # time.sleep(10)
                logger.warning('Overwriting existing directory')
                rmtree(self.decode_dir)
            self.decode_dir.mkdir(parents=True)
            self.prev_decode_step = -1
        else:
            with open(Path(self.decode_dir, 'best-hyp'), 'r') as f:
                i = -1  # guard against an empty file
                for i, l in enumerate(f):
                    pass
                self.prev_decode_step = i + 1
            logger.notice(f"Decode resume from {self.prev_decode_step}")

        ### Comet
        with open(Path(self.log_dir,'exp_key'),'r') as f:
            exp_key = f.read().strip()
            comet_exp = ExistingExperiment(previous_experiment=exp_key,
                                           project_name=COMET_PROJECT_NAME,
                                           workspace=COMET_WORKSPACE,
                                           auto_output_logging=None,
                                           auto_metric_logging=None,
                                           display_summary_level=0,
                                           )
        comet_exp.log_other('status','decode')
    def check(self):
        if not self.exp.alive:
            logger.warning("Comet logging stopped")
    def __init__(self, config, paras, id2accent):

        ### config setting
        self.config = config
        self.paras = paras
        self.train_type = 'pretrain'
        self.is_memmap = paras.is_memmap
        self.is_bucket = paras.is_bucket
        self.model_name = paras.model_name
        self.eval_ival = config['solver']['eval_ival']
        self.log_ival = config['solver']['log_ival']
        self.save_ival = config['solver']['save_ival']
        self.half_batch_ilen = config['solver']['half_batch_ilen']
        self.dev_max_ilen = config['solver']['dev_max_ilen']

        self.sample_strategy = paras.sample_strategy

        self.best_cer = INIT_BEST_ER
        self.best_wer = INIT_BEST_ER

        if self.paras.model_name == 'transformer':
            self.id2units = [SOS_SYMBOL]
            with open(config['solver']['spm_mapping']) as fin:
                for line in fin.readlines():
                    self.id2units.append(line.rstrip().split(' ')[0])
            self.id2units.append(EOS_SYMBOL)
            self.metric_observer = Metric(config['solver']['spm_model'],
                                          self.id2units, 0,
                                          len(self.id2units) - 1)
        elif self.paras.model_name == 'blstm':
            self.id2units = [BLANK_SYMBOL]
            with open(config['solver']['spm_mapping']) as fin:
                for line in fin.readlines():
                    self.id2units.append(line.rstrip().split(' ')[0])
            self.id2units.append(EOS_SYMBOL)
            self.metric_observer = Metric(config['solver']['spm_model'],
                                          self.id2units,
                                          len(self.id2units) - 1,
                                          len(self.id2units) - 1)
        else:
            raise ValueError(f"Unknown model name {self.paras.model_name}")

        self.accents = [
            id2accent[accent_id] for accent_id in paras.pretrain_accents
        ]
        self.num_pretrain = paras.num_pretrain
        self.tgt_accent = id2accent[paras.tgt_accent]

        self.max_step = paras.max_step if paras.max_step > 0 \
            else config['solver']['total_steps']
        #######################################################################

        ### Set path
        assert self.num_pretrain == len(self.accents),\
            f"num_pretrain is {self.num_pretrain}, but got {len(self.accents)} in pretrain_accents"

        cur_path = Path.cwd()
        self.data_dirs = [
            Path(config['solver']['data_root']).joinpath(accent)
            for accent in self.accents
        ]

        self.log_dir = Path(cur_path, LOG_DIR, self.train_type,
                            config['solver']['setting'], paras.algo,
                            paras.pretrain_suffix, self.tgt_accent,
                            str(paras.runs))

        if not paras.resume:
            if self.log_dir.exists():
                assert paras.overwrite, \
                    f"Path exists ({self.log_dir}). Use --overwrite or change suffix"
                # time.sleep(10)
                logger.warning('Overwriting existing directory')
                rmtree(self.log_dir)

            self.log_dir.mkdir(parents=True)
            self.train_info = RunningAvgDict(decay_rate=0.99)
            self.global_step = 1
        else:
            self.resume_model_path = self.log_dir.joinpath('snapshot.latest')
            info_dict_path = self.log_dir.joinpath('info_dict.latest')
            self.optimizer_path = self.log_dir.joinpath('optimizer.latest')

            assert self.optimizer_path.exists(), \
                f"Optimizer state {self.optimizer_path} does not exist"

            with open(Path(self.log_dir, 'global_step'), 'r') as f:
                self.global_step = int(f.read().strip())

            assert self.resume_model_path.exists(), \
                f"{self.resume_model_path} does not exist"
            assert info_dict_path.exists(), \
                f"Pretraining info {info_dict_path} does not exist"

            with open(info_dict_path, 'rb') as fin:
                self.train_info = pickle.load(fin)
        if paras.use_tensorboard:
            from src.monitor.tb_dashboard import Dashboard
            logger.warning("Use tensorboard instead of comet")
        else:
            from src.monitor.dashboard import Dashboard

        self.dashboard = Dashboard(config, paras, self.log_dir, \
                                   self.train_type, paras.resume)
    def __init__(self, config, paras, id2accent):

        ### config setting
        self.config = config
        self.paras = paras
        self.train_type = 'evaluation'
        self.is_memmap = paras.is_memmap
        self.is_bucket = paras.is_bucket
        self.model_name = paras.model_name
        self.eval_ival = config['solver']['eval_ival']
        self.log_ival = config['solver']['log_ival']
        self.half_batch_ilen = config['solver']['half_batch_ilen']
        self.dev_max_ilen = config['solver']['dev_max_ilen']

        self.best_wer = INIT_BEST_ER
        self.best_cer = INIT_BEST_ER
        if self.paras.model_name == 'transformer':
            self.id2units = [SOS_SYMBOL]
            with open(config['solver']['spm_mapping']) as fin:
                for line in fin.readlines():
                    self.id2units.append(line.rstrip().split(' ')[0])
            self.id2units.append(EOS_SYMBOL)
            self.metric_observer = Metric(config['solver']['spm_model'],
                                          self.id2units, 0,
                                          len(self.id2units) - 1)
        elif self.paras.model_name == 'blstm':
            self.id2units = [BLANK_SYMBOL]
            with open(config['solver']['spm_mapping']) as fin:
                for line in fin.readlines():
                    self.id2units.append(line.rstrip().split(' ')[0])
            self.id2units.append(EOS_SYMBOL)
            self.metric_observer = Metric(config['solver']['spm_model'],
                                          self.id2units,
                                          len(self.id2units) - 1,
                                          len(self.id2units) - 1)
        else:
            raise ValueError(f"Unknown model name {self.paras.model_name}")

        self.save_verbose = paras.save_verbose
        #######################################################################

        ### Set path
        cur_path = Path.cwd()

        if paras.pretrain:
            assert paras.pretrain_suffix or paras.pretrain_model_path, \
                "You should specify a pretrained model path or the corresponding pretrain_suffix"

            if paras.pretrain_model_path:
                self.pretrain_model_path = Path(paras.pretrain_model_path)
            else:
                assert paras.pretrain_suffix and paras.pretrain_setting and paras.pretrain_step > 0, \
                    "Should specify pretrain_suffix, pretrain_setting, and a positive pretrain_step"
                self.pretrain_model_path = Path(cur_path, LOG_DIR, 'pretrain', \
                                                paras.pretrain_setting, paras.algo, \
                                                paras.pretrain_suffix, \
                                                id2accent[paras.pretrain_tgt_accent],\
                                                str(paras.pretrain_runs), f"snapshot.step.{paras.pretrain_step}")

            assert self.pretrain_model_path.exists(), \
                f"Pretrain model path {self.pretrain_model_path} does not exist"
            self.pretrain_module = config['solver']['pretrain_module']
        else:
            assert paras.pretrain_suffix is None and paras.algo == 'no', \
                f"Training from scratch should not use a meta-learning algorithm ({paras.algo}) or a pretrain_suffix"
            paras.pretrain_suffix = paras.eval_suffix

        self.accent = id2accent[paras.accent]
        self.data_dir = Path(config['solver']['data_root'], self.accent)
        self.log_dir = Path(cur_path, LOG_DIR,self.train_type, \
                            config['solver']['setting'], paras.algo, \
                            paras.pretrain_suffix, paras.eval_suffix, \
                            self.accent, str(paras.runs))
        ########################################################################

        ### Resume mechanism
        if not paras.resume:
            if self.log_dir.exists():
                assert paras.overwrite, \
                    f"Path exists ({self.log_dir}). Use --overwrite or change suffix"
                # time.sleep(10)
                logger.warning('Overwriting existing directory')
                rmtree(self.log_dir)

            self.log_dir.mkdir(parents=True)
            self.train_info = RunningAvgDict(decay_rate=0.99)
            self.global_step = 1
            self.ep = 0

        else:
            self.resume_model_path = self.log_dir.joinpath('snapshot.latest')
            info_dict_path = self.log_dir.joinpath('info_dict.latest')
            self.optimizer_path = self.log_dir.joinpath('optimizer.latest')

            assert self.optimizer_path.exists(), \
                f"Optimizer state {self.optimizer_path} does not exist"

            with open(Path(self.log_dir, 'epoch'), 'r') as f:
                self.ep = int(f.read().strip())
            with open(Path(self.log_dir, 'global_step'), 'r') as f:
                self.global_step = int(f.read().strip())
            with open(Path(self.log_dir, 'best_wer'), 'r') as f:
                self.best_wer = float(f.read().strip().split(' ')[1])
            with open(Path(self.log_dir, 'best_cer'), 'r') as f:
                self.best_cer = float(f.read().strip().split(' ')[1])

            assert self.resume_model_path.exists(), \
                f"{self.resume_model_path} does not exist"
            assert info_dict_path.exists(), \
                f"Training info {info_dict_path} does not exist"

            with open(info_dict_path, 'rb') as fin:
                self.train_info = pickle.load(fin)

        if paras.use_tensorboard:
            from src.monitor.tb_dashboard import Dashboard
            logger.warning("Use tensorboard instead of comet")
        else:
            from src.monitor.dashboard import Dashboard

        self.dashboard = Dashboard(config, paras, self.log_dir, \
                                   self.train_type, paras.resume)