예제 #1
0
class Solver(BaseSolver):
    ''' Solver for on-line recognition and quick fine-tuning of a pre-trained ASR model.

    Relies on the training config ``self.src_config`` loaded by the base class.
    It can decode single audio files (``exec`` / ``recognize``) and fine-tune the
    first ``self.finetune_first`` encoder layers on one (audio file, corrected
    transcript) pair that is appended to ordinary training batches (``finetune``).
    '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config['data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data']['corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes must be identical to the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['corpus']['train_split'] = self.src_config['data']['corpus']['train_split']
        self.config['data']['text'] = self.src_config['data']['text']
        self.tokenizer = load_text_encoder(**self.config['data']['text'])
        self.config['model'] = self.src_config['model']
        # Number of leading encoder layers whose parameters finetune() updates
        self.finetune_first = 5
        # Best dev WER per decoding branch; 3.0 is an initial sentinel that any
        # realistic WER will beat.
        self.best_wer = {'att': 3.0, 'ctc': 3.0}
        # Global step counter read by progress()/write_log(); the base class only
        # initialises it in 'train' mode, so make sure it always exists here.
        self.step = 0

        # Output file
        self.output_file = str(self.ckpdir)+'_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        # Feature extractor turning an audio file name into (feat, feat_len)
        self.dealer = Datadealer(self.config['data']['audio'])
        # Pure-CTC decoding: repeated tokens must be collapsed when detokenising
        self.ctc = self.config['decode']['ctc_weight'] == 1.0
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

        # Logger settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(
            self.logdir, flush_secs=self.TB_FLUSH_FREQ)
        self.timer = Timer()

    def fetch_data(self, data):
        ''' Move a loader batch to the device and compute text sequence lengths.

        ``data`` is unpacked as (_, feat, feat_len, txt); the text length is the
        number of non-zero (non-padding) tokens per row.
        Returns (feat, feat_len, txt, txt_len).
        '''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self, batch_size=7):
        ''' Load training/validation sets with a temporary batch size.

        Sets self.tr_set, self.dv_set, self.feat_dim, self.vocab_size and
        self.tokenizer. The configured batch size is restored afterwards, even if
        loading fails.
        '''
        prev_batch_size = self.config['data']['corpus']['batch_size']
        self.config['data']['corpus']['batch_size'] = batch_size
        try:
            self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                load_dataset(self.paras.njobs, self.paras.gpu,
                             self.paras.pin_memory, False, **self.config['data'])
        finally:
            self.config['data']['corpus']['batch_size'] = prev_batch_size
        self.verbose(msg)

    def set_model(self):
        ''' Build the ASR model, losses, optional embedding plug-in, optimizer and beam decoder.

        Only parameters of the first ``self.finetune_first`` encoder layers (plus the
        embedding plug-in, if enabled) are handed to the optimizer; with
        ``finetune_first <= 0`` all model parameters are.
        '''
        # NOTE(review): feature/vocab sizes are hard-coded and override whatever
        # load_data() computed; presumably they must match the checkpoint given in
        # self.paras.load — confirm before changing the training config.
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta, **
                         self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        if self.finetune_first > 0:
            layer_names = ["encoder.layers.%d" % i for i in range(self.finetune_first)]

            def in_finetuned_layer(param_name):
                # Match on a '.' boundary: a bare substring test would make
                # "encoder.layers.1" also select "encoder.layers.10", ".11", ...
                return any(ln + '.' in param_name or param_name.endswith(ln)
                           for ln in layer_names)

            model_paras = [{"params": [p for n, p in self.model.named_parameters()
                                       if in_finetuned_layer(n)]}]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused output is fed to NLLLoss instead of CrossEntropyLoss
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
        # Beam decoder shares self.model / self.emb_decoder, so fine-tuning
        # updates are visible to later recognition calls.
        self.decoder = BeamDecoder(
            self.model, self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        self.decoder.to(self.device)

    def _load_feat(self, filename):
        ''' Extract features for one audio file with self.dealer and move them to the device. '''
        feat, feat_len = self.dealer(filename)
        return feat.to(self.device), feat_len.to(self.device)

    def _decode_file(self, filename):
        ''' Beam-search decode one audio file and return the list of hypothesis strings. '''
        feat, feat_len = self._load_feat(filename)
        with torch.no_grad():
            hyps = self.decoder(feat, feat_len)
        return [self.tokenizer.decode(hyp.outIndex, ignore_repeat=self.ctc)
                for hyp in hyps]

    def exec(self):
        ''' Interactive recognition loop.

        Repeatedly prompts for an audio file name on stdin, decodes it and prints
        every hypothesis. Prints "Invalid file" when decoding fails. Returns on
        "exit", on end of input, or on Ctrl-C at the prompt.
        '''
        while True:
            try:
                filename = input("Input wav file name: ")
            except (EOFError, KeyboardInterrupt):
                # A bare except used to swallow EOFError here, so a closed stdin
                # made this loop spin forever printing "Invalid file".
                return
            if filename == "exit":
                return
            try:
                hyp_txts = self._decode_file(filename)
            except Exception:
                print("Invalid file")
                continue
            for txt in hyp_txts:
                print(txt)

    def recognize(self, filename):
        ''' Decode one audio file and return its best hypothesis, or "Invalid file" on error. '''
        try:
            return self._decode_file(filename)[0]
        except Exception as e:
            print(e)
            # NOTE(review): `app` is not defined in this section of the file;
            # presumably a module-level web-app object — confirm, otherwise this
            # line raises NameError.
            app.logger.debug(e)
            return "Invalid file"

    def fetch_finetune_data(self, filename, fixed_text):
        ''' Build one fine-tuning example from an audio file and its corrected transcript.

        Returns [feat, feat_len, text, text_len]: device tensors for the features,
        a tensor of the token ids produced by self.tokenizer.encode, and the number
        of tokens as a Python int.
        '''
        feat, feat_len = self._load_feat(filename)
        text = torch.tensor(self.tokenizer.encode(fixed_text)).to(self.device)
        return [feat, feat_len, text, len(text)]

    def merge_batch(self, main_batch, attach_batch):
        ''' Append one fine-tuning example to a padded training batch.

        main_batch   - (feat, feat_len, txt, txt_len) as returned by fetch_data
        attach_batch - [feat, feat_len, text, text_len] as returned by
                       fetch_finetune_data
        The attached features/text are cropped or zero-padded to the main batch's
        padded widths so they can be concatenated along the batch axis, and their
        lengths are capped at those widths. Neither input is modified.
        Returns the merged (feat, feat_len, txt, txt_len) tuple.
        '''
        main_feat, main_feat_len, main_txt, main_txt_len = main_batch
        feat, feat_len, text, text_len = attach_batch
        feat_width = main_feat.shape[1]
        text_width = main_txt.shape[1]

        # Acoustic features: crop or zero-pad the time axis to feat_width
        if feat.shape[1] > feat_width:
            feat = feat[:, :feat_width]
        elif feat.shape[1] < feat_width:
            pad = torch.zeros(feat.shape[0], feat_width - feat.shape[1], feat.shape[2],
                              dtype=feat.dtype, device=feat.device)
            feat = torch.cat([feat, pad], dim=1)
        feat_len = feat_len.clamp(max=feat_width)

        # Transcript: crop or zero-pad (0 = padding id) to text_width, then add a
        # batch axis. The reported length must match the (possibly cropped) text.
        text = text[:text_width]
        if text.shape[0] < text_width:
            pad = torch.zeros(text_width - text.shape[0], dtype=text.dtype,
                              device=text.device)
            text = torch.cat([text, pad], dim=0)
        text_len = torch.tensor([min(int(text_len), text_width)],
                                dtype=main_txt_len.dtype, device=main_txt_len.device)

        return (
            torch.cat([main_feat, feat], dim=0),
            torch.cat([main_feat_len, feat_len], dim=0),
            torch.cat([main_txt, text.unsqueeze(0)], dim=0),
            torch.cat([main_txt_len, text_len], dim=0),
        )

    def finetune(self, filename, fixed_text, max_step=5):
        ''' Fine-tune on one (audio file, corrected transcript) pair, then validate.

        For each of at most ``max_step`` training batches, the example built from
        ``filename`` / ``fixed_text`` is appended to the batch and one optimisation
        step is taken on the combined CTC / attention (and optional embedding)
        loss. Returns the summary string produced by validate().
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        finetune_data = None
        step = 0
        self.timer.set()
        for data in self.tr_set:
            # Stop after exactly max_step updates (the previous `step > max_step`
            # check at the end of the loop ran one extra step).
            if step >= max_step:
                break
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            # NOTE(review): 400000 is a hard-coded step for the schedule — confirm.
            tf_rate = self.optimizer.pre_step(400000)
            total_loss = 0

            # Fetch data; the fine-tuning example is loaded once and reused
            if finetune_data is None:
                finetune_data = self.fetch_finetune_data(filename, fixed_text)
            main_batch = self.fetch_data(data)
            feat, feat_len, txt, txt_len = self.merge_batch(main_batch, finetune_data)
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(
                    dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight*emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                             txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                                             [ctc_output.shape[1]] *
                                             len(ctc_output),
                                             txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(
                        0, 1), txt, encode_len, txt_len)
                total_loss += ctc_loss*self.model.ctc_weight

            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(
                    att_output.contiguous().view(b*t, -1), txt.contiguous().view(-1))
                total_loss += att_loss*(1-self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            step += 1
            self.step += 1

            # Logger
            self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                          .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
            self.write_log(
                'loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
            self.write_log('emb_loss', {'tr': emb_loss})
            self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
                                   'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
            if self.emb_fuse:
                if self.emb_decoder.fuse_learnable:
                    self.write_log('fuse_lambda', {
                                   'emb': self.emb_decoder.get_weight()})
                self.write_log(
                    'fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
        ret = self.validate()
        self.log.close()
        return ret

    def validate(self):
        ''' Evaluate on self.dv_set, log WER/examples to TensorBoard and track best WER.

        Returns a human-readable summary with one line per decoding branch
        ('att', 'ctc'). Checkpoints are intentionally not saved here. The model
        (and the embedding plug-in, if any) is switched back to train mode before
        returning.
        '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}
        n_batch = len(self.dv_set)

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i+1, n_batch))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some examples from the middle batch on tensorboard
            if i == n_batch//2:
                for j in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(
                        j), self.tokenizer.decode(txt[j].tolist()))
                    if att_output is not None:
                        self.write_log('att_align{}'.format(j), feat_to_fig(
                            att_align[j, 0, :, :].cpu().detach()))
                        self.write_log('att_text{}'.format(j), self.tokenizer.decode(
                            att_output[j].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log('ctc_text{}'.format(j), self.tokenizer.decode(
                            ctc_output[j].argmax(dim=-1).tolist(), ignore_repeat=True))

        # Track best WER (no checkpoint is written here)
        to_prints = []
        for task in ['att', 'ctc']:
            scores = dev_wer[task]
            if not scores:
                # Empty dev set: nothing to average (used to be a ZeroDivisionError)
                to_print = f"WER of {task}: no validation data"
                print(to_print, flush=True)
                to_prints.append(to_print)
                continue
            dev_wer[task] = sum(scores) / len(scores)
            if dev_wer[task] < self.best_wer[task]:
                to_print = f"WER of {task}: {dev_wer[task]} < prev best ({self.best_wer[task]})"
                self.best_wer[task] = dev_wer[task]
            else:
                to_print = f"WER of {task}: {dev_wer[task]} >= prev best ({self.best_wer[task]})"
            print(to_print, flush=True)
            to_prints.append(to_print)
            self.write_log('wer', {'dv_'+task: dev_wer[task]})

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
        return '\n'.join(to_prints)
예제 #2
0
class BaseSolver():
    '''
    Prototype Solver for all kinds of tasks
    Arguments
        config - yaml-styled config
        paras  - argparse outcome
        mode   - 'train' or 'test'; selects which paths, loggers and configs are set up
    '''
    def __init__(self, config, paras, mode):
        # General Settings
        self.config = config
        self.paras = paras
        self.mode = mode
        for k, v in default_hparas.items():
            setattr(self, k, v)
        self.device = torch.device(
            'cuda') if self.paras.gpu and torch.cuda.is_available(
            ) else torch.device('cpu')
        self.amp = paras.amp
        # Global step counter. progress()/write_log()/backward() read it in every
        # mode, so it must exist even when not training; load_ckpt() restores it
        # from the checkpoint in 'train' mode.
        self.step = 0

        # Name experiment
        self.exp_name = paras.name
        if self.exp_name is None:
            # By default, exp is named after config file
            self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
            if mode == 'train':
                self.exp_name += '_sd{}'.format(paras.seed)

        # Plugin list
        self.emb_decoder = None

        if mode == 'train':
            # Filepath setup
            os.makedirs(paras.ckpdir, exist_ok=True)
            self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
            os.makedirs(self.ckpdir, exist_ok=True)

            # Logger settings
            self.logdir = os.path.join(paras.logdir, self.exp_name)
            self.log = SummaryWriter(self.logdir,
                                     flush_secs=self.TB_FLUSH_FREQ)
            self.timer = Timer()

            # Hyperparameters
            self.valid_step = config['hparas']['valid_step']
            self.max_step = config['hparas']['max_step']

            self.verbose('Exp. name : {}'.format(self.exp_name))
            self.verbose('Loading data... large corpus may took a while.')

        elif mode == 'test':
            # Output path
            os.makedirs(paras.outdir, exist_ok=True)
            self.ckpdir = os.path.join(paras.outdir, self.exp_name)

            # Load training config to get acoustic feat, text encoder and build model.
            # NOTE: yaml FullLoader on a local config path — do not point it at
            # untrusted files.
            with open(config['src']['config'], 'r') as f:
                self.src_config = yaml.load(f, Loader=yaml.FullLoader)
            self.paras.load = config['src']['ckpt']

            self.verbose('Evaluating result of tr. config @ {}'.format(
                config['src']['config']))

    def backward(self, loss):
        '''
        Standard backward step with self.timer and debugger.
        Gradients are clipped to self.GRAD_CLIP; the optimizer step is skipped
        when the gradient norm is NaN.
        Arguments
            loss - the loss to perform loss.backward()
        Returns
            the (pre-clipping) total gradient norm
        '''
        self.timer.set()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.GRAD_CLIP)
        if math.isnan(grad_norm):
            self.verbose('Error : grad norm is NaN @ step ' + str(self.step))
        else:
            self.optimizer.step()
        self.timer.cnt('bw')
        return grad_norm

    def load_ckpt(self):
        ''' Load ckpt if --load option is specified '''
        if self.paras.load:
            # Load weights
            ckpt = torch.load(
                self.paras.load,
                map_location=self.device if self.mode == 'train' else 'cpu')
            self.model.load_state_dict(ckpt['model'])
            if self.emb_decoder is not None:
                self.emb_decoder.load_state_dict(ckpt['emb_decoder'])
            # Load task-dependent items
            if self.mode == 'train':
                self.step = ckpt['global_step']
                self.optimizer.load_opt_state_dict(ckpt['optimizer'])
                self.verbose('Load ckpt from {}, restarting at step {}'.format(
                    self.paras.load, self.step))
            else:
                # Report the last float-valued entry (presumably the metric stored
                # by save_checkpoint). Defaults avoid an UnboundLocalError when the
                # checkpoint holds no float at all.
                metric, score = "None", 0.0
                for k, v in ckpt.items():
                    if type(v) is float:
                        metric, score = k, v
                self.model.eval()
                if self.emb_decoder is not None:
                    self.emb_decoder.eval()
                self.verbose(
                    'Evaluation target = {} (recorded {} = {:.2f} %)'.format(
                        self.paras.load, metric, score))

    def verbose(self, msg):
        ''' Verbose function for print information to stdout (msg may be a str or list of str)'''
        if self.paras.verbose:
            if isinstance(msg, list):
                for m in msg:
                    print('[INFO]', m.ljust(100))
            else:
                print('[INFO]', msg.ljust(100))

    def progress(self, msg):
        ''' Verbose function for updating progress on stdout (do not include newline) '''
        if self.paras.verbose:
            sys.stdout.write("\033[K")  # Clear line
            print('[{}] {}'.format(human_format(self.step), msg), end='\r')

    def write_log(self, log_name, log_dict):
        '''
        Write log to TensorBoard at global step self.step
            log_name - <str> Name of tensorboard variable; names containing
                       'align'/'spec' are logged as images, 'text'/'hyp' as text,
                       anything else as scalars
            log_dict - <dict>/<array>/<str> Value of variable (e.g. dict of losses);
                       None or NaN dict entries are dropped, nothing is written if
                       the value is None or empty
        '''
        if type(log_dict) is dict:
            log_dict = {
                key: val
                for key, val in log_dict.items()
                if (val is not None and not math.isnan(val))
            }
        if log_dict is None:
            pass
        elif len(log_dict) > 0:
            if 'align' in log_name or 'spec' in log_name:
                img, form = log_dict
                self.log.add_image(log_name,
                                   img,
                                   global_step=self.step,
                                   dataformats=form)
            elif 'text' in log_name or 'hyp' in log_name:
                self.log.add_text(log_name, log_dict, self.step)
            else:
                self.log.add_scalars(log_name, log_dict, self.step)

    def save_checkpoint(self, f_name, metric, score, show_msg=True):
        '''
        Ckpt saver
            f_name - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
            metric - <str> key under which score is stored
            score  - <float> The value of metric used to evaluate model
        '''
        ckpt_path = os.path.join(self.ckpdir, f_name)
        full_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.get_opt_state_dict(),
            "global_step": self.step,
            metric: score
        }
        # Additional modules to save (AMP state is not saved yet)
        if self.emb_decoder is not None:
            full_dict['emb_decoder'] = self.emb_decoder.state_dict()

        torch.save(full_dict, ckpt_path)
        if show_msg:
            self.verbose(
                "Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
                format(human_format(self.step), metric, score, ckpt_path))

    def enable_apex(self):
        ''' Wrap model/optimizer with NVIDIA apex AMP (opt_level O1) when self.amp is set. '''
        if self.amp:
            # Enable mixed precision computation (ToDo: Save/Load amp)
            from apex import amp
            self.amp_lib = amp
            self.verbose(
                "AMP enabled (check https://github.com/NVIDIA/apex for more details)."
            )
            self.model, self.optimizer.opt = self.amp_lib.initialize(
                self.model, self.optimizer.opt, opt_level='O1')

    # ----------------------------------- Abtract Methods ------------------------------------------ #
    @abc.abstractmethod
    def load_data(self):
        '''
        Called by main to load all data
        After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def set_model(self):
        '''
        Called by main to set models
        After this call, model related attributes should be setup (e.g. self.l2_loss)
        The followings MUST be setup
            - self.model (torch.nn.Module)
            - self.optimizer (src.Optimizer),
                init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
        Loading pre-trained model should also be performed here 
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def exec(self):
        '''
        Called by main to execute training/inference
        '''
        raise NotImplementedError
예제 #3
0
class BaseSolver():
    ''' 
    Prototype Solver for all kinds of tasks
    Arguments
        config - yaml-styled config
        paras  - argparse outcome; paras.mode specifies training/testing
    '''
    def __init__(self, config, paras):
        # General Settings
        self.config = config
        self.paras = paras
        for k, v in default_hparas.items():
            setattr(self, k, v)
        if self.paras.gpu and torch.cuda.is_available():
            self.gpu = True
            self.device = torch.device('cuda')
        else:
            self.gpu = False
            self.device = torch.device('cpu')

        # Settings for training/testing
        self.mode = self.paras.mode  # legacy, should be removed

        # Name experiment
        self.exp_name = paras.name
        if self.exp_name is None:
            # By default, exp is named after config file
            self.exp_name = paras.config.split('/')[-1].split('.y')[0]
            self.exp_name += '_sd{}'.format(paras.seed)

        # Filepath setup
        os.makedirs(paras.ckpdir, exist_ok=True)
        self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
        os.makedirs(self.ckpdir, exist_ok=True)

        # Logger settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(self.logdir, flush_secs=self.TB_FLUSH_FREQ)
        self.timer = Timer()

        # Hyperparameters
        self.step = 0
        # Current epoch. progress() and save_checkpoint() read it, so it must
        # exist even when no checkpoint is loaded (load_ckpt() restores it in
        # 'train' mode).
        self.cur_epoch = 0
        self.epoch = config['hparas']['epoch']

        self.verbose('Exp. name : {}'.format(self.exp_name))
        self.verbose('Loading data...')

    def backward(self, loss):
        '''
        Standard backward step with timer and debugger.
        Gradients are clipped to self.GRAD_CLIP; the optimizer step is skipped
        when the gradient norm is NaN.
        Arguments
            loss - the loss to perform loss.backward()
        Returns
            the (pre-clipping) total gradient norm
        '''
        self.timer.set()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.GRAD_CLIP)
        if math.isnan(grad_norm):
            self.verbose('Error : grad norm is NaN @ step ' + str(self.step))
        else:
            self.optimizer.step()
        self.timer.cnt('bw')
        return grad_norm

    def load_ckpt(self):
        ''' Load ckpt if --load option is specified '''
        if self.paras.load:
            # Load weights
            ckpt = torch.load(self.paras.load,
                              map_location=self.device
                              if self.paras.mode == 'train' else 'cpu')
            # Strip the leading 'module.' added by (Distributed)DataParallel.
            # Only a true prefix is removed: str.replace would also mangle keys
            # that merely contain it, e.g. 'submodule.weight' -> 'subweight'.
            prefix = 'module.'
            ckpt['model'] = {(k[len(prefix):] if k.startswith(prefix) else k): v
                             for k, v in ckpt['model'].items()}
            self.model.load_state_dict(ckpt['model'])

            # Load task-dependent items; report the last float-valued entry
            # (presumably the metric stored by save_checkpoint)
            metric = "None"
            score = 0.0
            for k, v in ckpt.items():
                if type(v) is float:
                    metric, score = k, v
            if self.paras.mode == 'train':
                self.cur_epoch = ckpt['epoch']
                self.step = ckpt['global_step']
                self.optimizer.load_opt_state_dict(ckpt['optimizer'])
                msg = ('Load ckpt from {}, restarting at step {} '
                       '(recorded {} = {:.2f} %)').format(
                           self.paras.load, self.step, metric, score)
                self.verbose(msg)
            else:
                # Inference
                msg = 'Evaluation target = {} (recorded {} = {:.2f} %)'\
                      .format(self.paras.load, metric, score)
                self.verbose(msg)

    def verbose(self, msg, display_step=False):
        ''' Verbose function for print information to stdout (msg may be a str or list of str)'''
        header = '[' + human_format(
            self.step) + ']' if display_step else '[INFO]'
        if self.paras.verbose:
            if isinstance(msg, list):
                for m in msg:
                    print(header, m.ljust(100))
            else:
                print(header, msg.ljust(100))

    def progress(self, msg):
        ''' Verbose function for updating progress on stdout 
            Do not include newline in msg '''
        if self.paras.verbose:
            sys.stdout.write("\033[K")  # Clear line
            print('[Ep {}] {}'.format(human_format(self.cur_epoch), msg),
                  end='\r')

    def write_log(self, log_name, log_dict, bins=None):
        ''' Write log to TensorBoard at global step self.step
            log_name - <str> Name of tensorboard variable 
            log_dict - <dict>/<array> Value of variable (e.g. dict of losses);
                       dicts are written as scalars (None/NaN entries dropped),
                       (img, dataformats) pairs as images when log_name contains
                       'Hist.' or 'Spec'; anything else raises NotImplementedError
            bins     - unused
        '''
        if log_dict is not None:
            if type(log_dict) is dict:
                log_dict = {
                    key: val
                    for key, val in log_dict.items()
                    if (val is not None and not math.isnan(val))
                }
                self.log.add_scalars(log_name, log_dict, self.step)
            elif 'Hist.' in log_name or 'Spec' in log_name:
                img, form = log_dict
                self.log.add_image(log_name,
                                   img,
                                   global_step=self.step,
                                   dataformats=form)
            else:
                raise NotImplementedError

    def save_checkpoint(self, f_name, metric, score, show_msg=True):
        ''' Ckpt saver
            f_name - <str> the name of ckpt (w/o prefix), overwrite if existed
            metric - <str> key under which score is stored
            score  - <float> The value of metric used to evaluate model
            Returns the path of the written checkpoint.
        '''
        ckpt_path = os.path.join(self.ckpdir, f_name)
        full_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.get_opt_state_dict(),
            "global_step": self.step,
            "epoch": self.cur_epoch,
            metric: score
        }
        torch.save(full_dict, ckpt_path)
        if show_msg:
            msg = "Saved checkpoint (epoch = {}, {} = {:.2f}) and status @ {}"
            self.verbose(
                msg.format(human_format(self.cur_epoch), metric, score,
                           ckpt_path))
        return ckpt_path

    # ----------------------------------- Abtract Methods ------------------- #
    @abc.abstractmethod
    def load_data(self):
        '''
        Called by main to load all data
        After this call, data related attributes should be setup 
        (e.g. self.tr_set, self.dev_set)
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def set_model(self):
        '''
        Called by main to set models
        After this call, model related attributes should be setup 
        The followings MUST be setup
            - self.model (torch.nn.Module)
            - self.optimizer (src.Optimizer),
        Loading pre-trained model should also be performed here 
        No return value
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def exec(self):
        '''
        Called by main to execute training/inference
        '''
        raise NotImplementedError