def validate(self):
        ''' Evaluate on the dev set and checkpoint when the error rate improves

        Puts the model (and the embedding decoder, when one is configured) in
        eval mode, collects the per-batch ``cal_er`` score of the attention
        and CTC outputs, logs example hypotheses for the middle batch, saves
        ``best_<task>.pth`` whenever a task's mean score drops below
        ``self.best_wer[task]``, always saves ``latest.pth`` and finally
        switches everything back to train mode.
        '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        scores = {'att': [], 'ctc': []}

        for batch_idx, batch in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(batch_idx + 1,
                                                      len(self.dv_set)))
            feat, feat_len, txt, txt_len = self.fetch_data(batch)

            # Gradient-free forward pass; the decoding budget scales with the
            # longest reference length in the batch
            with torch.no_grad():
                max_decode_step = int(max(txt_len) * self.DEV_STEP_RATIO)
                outputs = self.model(feat, feat_len, max_decode_step,
                                     emb_decoder=self.emb_decoder)
            ctc_output, _enc_len, att_output, att_align, _dec_state = outputs

            scores['att'].append(cal_er(self.tokenizer, att_output, txt))
            scores['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Example hypotheses are logged for the middle batch only
            if batch_idx != len(self.dv_set) // 2:
                continue
            for k in range(min(len(txt), self.DEV_N_EXAMPLE)):
                # Reference text is only logged on the first training step
                if self.step == 1:
                    reference = self.tokenizer.decode(txt[k].tolist())
                    self.write_log('true_text{}'.format(k), reference)
                if att_output is not None:
                    alignment = att_align[k, 0, :, :].cpu().detach()
                    self.write_log('att_align{}'.format(k),
                                   feat_to_fig(alignment))
                    att_tokens = att_output[k].argmax(dim=-1).tolist()
                    self.write_log('att_text{}'.format(k),
                                   self.tokenizer.decode(att_tokens))
                if ctc_output is not None:
                    ctc_tokens = ctc_output[k].argmax(dim=-1).tolist()
                    self.write_log(
                        'ctc_text{}'.format(k),
                        self.tokenizer.decode(ctc_tokens, ignore_repeat=True))

        # Ckpt if performance improves
        for task in ('att', 'ctc'):
            mean_score = sum(scores[task]) / len(scores[task])
            scores[task] = mean_score
            if mean_score < self.best_wer[task]:
                self.best_wer[task] = mean_score
                self.save_checkpoint('best_{}.pth'.format(task), 'wer',
                                     mean_score)
            self.write_log('wer', {'dv_' + task: mean_score})
        self.save_checkpoint('latest.pth', 'wer', scores['att'],
                             show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
示例#2
0
    def validate(self):
        ''' Evaluate RNN-T decoding on the dev set and checkpoint on improvement

        Decodes every dev batch with ``self.model.decode``, averages the
        per-batch ``cal_er`` score, saves ``best_rnnt.pth`` when the mean
        drops below ``self.best_wer['rnnt']``, always saves ``latest.pth``
        and switches the model back to train mode.

        NOTE: only the error rate is tracked here; the loss-based validation
        path that used to sit in this method was disabled and has been
        removed together with its unused ``dev_loss`` accumulator.
        '''
        # Eval mode
        self.model.eval()
        dev_wer = {'rnnt': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data (the reference lengths are not needed for decoding)
            feat, feat_len, txt, _ = self.fetch_data(data)

            with torch.no_grad():
                rnnt_output = self.model.decode(feat, feat_len)
            dev_wer['rnnt'].append(cal_er(self.tokenizer, rnnt_output, txt))

            # Log some examples for the middle batch only
            if i == len(self.dv_set) // 2:
                # Separate index so the batch counter ``i`` is not clobbered
                for k in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text{}'.format(k),
                                       self.tokenizer.decode(txt[k].tolist()))
                    self.write_log('rnnt_text{}'.format(k),
                                   self.tokenizer.decode(rnnt_output[k]))

        # Ckpt if performance improves
        for task in ['rnnt']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format(task), 'wer',
                                     dev_wer[task])
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth',
                             'wer',
                             dev_wer['rnnt'],
                             show_msg=False)
        # Resume training
        self.model.train()
    def exec(self):
        ''' Training End-to-end ASR system

        Iterates over ``self.tr_set`` until ``self.step`` reaches
        ``self.max_step``.  Each step sums the weighted CTC, attention and
        (when ``self.emb_reg`` is set) embedding losses, back-propagates the
        total, logs every ``self.PROGRESS_STEP`` steps and validates every
        ``self.valid_step`` steps.  The logger is closed when training ends.
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        # This path passes flattened int32 CPU targets and
                        # the padded output length for every utterance
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight

                if att_output is not None:
                    b, t, _ = att_output.shape
                    # NOTE(review): fuse_output is only bound when
                    # self.emb_reg is true -- confirm emb_fuse implies emb_reg
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(
                        att_output.contiguous().view(b * t, -1),
                        txt.contiguous().view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('loss', {
                        'tr_ctc': ctc_loss,
                        'tr_att': att_loss
                    })
                    self.write_log('emb_loss', {'tr': emb_loss})
                    self.write_log(
                        'wer', {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt),
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt, ctc=True)
                        })
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                # Stop exactly at max_step: a strict ``>`` here let one extra
                # update run unless an epoch boundary fell on max_step
                if self.step >= self.max_step:
                    break
            n_epochs += 1
        self.log.close()
示例#4
0
    def finetune(self, filename, fixed_text, max_step=5):
        ''' Fine-tune on training batches merged with one extra utterance

        Every step fetches the fine-tuning sample built from ``filename`` /
        ``fixed_text``, merges it into the current training batch, and runs
        one optimisation step on the combined batch.  After the loop the
        model is validated once and the logger is closed.

        Args:
            filename: passed to ``self.fetch_finetune_data``.
            fixed_text: passed to ``self.fetch_finetune_data``.
            max_step: number of optimisation steps to run; a non-positive
                value skips training and goes straight to validation.

        Returns:
            Whatever ``self.validate()`` returns.
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        self.timer.set()
        step = 0
        for data in self.tr_set:
            if max_step <= 0:
                break
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            # NOTE(review): the schedule is queried at a fixed step of
            # 400000 rather than the running step -- presumably to pin
            # tf_rate/lr at their late-training values; confirm
            tf_rate = self.optimizer.pre_step(400000)
            total_loss = 0

            # Fetch data
            finetune_data = self.fetch_finetune_data(filename, fixed_text)
            main_batch = self.fetch_data(data)
            new_batch = self.merge_batch(main_batch, finetune_data)
            feat, feat_len, txt, txt_len = new_batch
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(
                    dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight * emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                             txt, encode_len, txt_len)
                total_loss += ctc_loss * self.model.ctc_weight

            if att_output is not None:
                b, t, _ = att_output.shape
                # NOTE(review): fuse_output is only bound when self.emb_reg
                # is true -- confirm emb_fuse implies emb_reg
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(
                    att_output.contiguous().view(b * t, -1),
                    txt.contiguous().view(-1))
                total_loss += att_loss * (1 - self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            step += 1

            # Logger (every step: fine-tuning runs are short)
            self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                          .format(total_loss.cpu().item(), grad_norm,
                                  self.timer.show()))
            self.write_log(
                'loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
            self.write_log('emb_loss', {'tr': emb_loss})
            self.write_log('wer', {
                'tr_att': cal_er(self.tokenizer, att_output, txt),
                'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)
            })
            if self.emb_fuse:
                if self.emb_decoder.fuse_learnable:
                    self.write_log('fuse_lambda',
                                   {'emb': self.emb_decoder.get_weight()})
                self.write_log(
                    'fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            # Stop after exactly max_step updates (a strict ``>`` ran one
            # update too many, e.g. 6 for the default max_step=5)
            if step >= max_step:
                break
        ret = self.validate()
        self.log.close()
        return ret
示例#5
0
    def exec(self):
        ''' Training End-to-end ASR system

        CTC-only training loop: iterates over ``self.tr_set`` until
        ``self.step`` reaches ``self.max_step``, back-propagating the CTC
        loss each step, logging PER/WER/loss every ``self.PROGRESS_STEP``
        steps and validating every ``self.valid_step`` steps.
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad.
                # The returned tf_rate is not used by this model, but the
                # call itself must still happen every step.
                self.optimizer.pre_step(self.step)

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len = self.model(feat, feat_len)

                # Compute all objectives
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)

                # CTC is the only objective here
                total_loss = ctc_loss

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    # Greedy per-frame labels, trimmed to each utterance's
                    # encoded length; kept under its own name so ctc_output
                    # is not rebound to a different type
                    ctc_pred = [
                        x[:length].argmax(dim=-1)
                        for x, length in zip(ctc_output, encode_len)
                    ]
                    self.write_log(
                        'per', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_pred,
                                   txt,
                                   mode='per',
                                   ctc=True)
                        })
                    self.write_log(
                        'wer', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_pred,
                                   txt,
                                   mode='wer',
                                   ctc=True)
                        })
                    self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                # Stop exactly at max_step: a strict ``>`` here let one extra
                # update run unless an epoch boundary fell on max_step
                if self.step >= self.max_step:
                    break
            n_epochs += 1
示例#6
0
    def validate(self):
        ''' Evaluate CTC greedy decoding on the dev set and checkpoint

        Computes the mean ``per`` and ``wer`` scores (via ``cal_er``) over
        the dev set.  ``best_per.pth`` is saved when PER improves;
        ``best_wer.pth`` only when ``self.eval_target == 'char'`` and WER
        improves.  ``latest.pth`` is always saved, plus a per-step
        checkpoint when ``self.paras.save_every`` is set.  The model is put
        back in train mode at the end.
        '''
        # Eval mode
        self.model.eval()
        per_scores, wer_scores = [], []

        for batch_idx, batch in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(batch_idx + 1,
                                                      len(self.dv_set)))
            feat, feat_len, txt, txt_len = self.fetch_data(batch)

            # Gradient-free forward pass
            with torch.no_grad():
                ctc_output, encode_len = self.model(feat, feat_len)

            # Greedy per-frame labels, trimmed to each utterance's encoded
            # length
            hyps = []
            for frames, length in zip(ctc_output, encode_len):
                hyps.append(frames[:length].argmax(dim=-1))

            per_scores.append(
                cal_er(self.tokenizer, hyps, txt, mode='per', ctc=True))
            wer_scores.append(
                cal_er(self.tokenizer, hyps, txt, mode='wer', ctc=True))

            # Reference / hypothesis pairs are logged for the middle batch
            if batch_idx != len(self.dv_set) // 2:
                continue
            for k in range(min(len(txt), self.DEV_N_EXAMPLE)):
                reference = self.tokenizer.decode(txt[k].tolist())
                self.write_log('true_text{}'.format(k), reference)
                hypothesis = self.tokenizer.decode(hyps[k].tolist(),
                                                   ignore_repeat=True)
                self.write_log('ctc_text{}'.format(k), hypothesis)

        # Ckpt if performance improves
        task = 'ctc'
        mean_wer = sum(wer_scores) / len(wer_scores)
        mean_per = sum(per_scores) / len(per_scores)

        if mean_per < self.best_per[task]:
            self.best_per[task] = mean_per
            self.save_checkpoint('best_{}.pth'.format('per'), 'per', mean_per)
            self.log.log_other('dv_best_per', self.best_per['ctc'])
        # WER only drives checkpointing when evaluating at character level
        if self.eval_target == 'char' and mean_wer < self.best_wer[task]:
            self.best_wer[task] = mean_wer
            self.save_checkpoint('best_{}.pth'.format('wer'), 'wer', mean_wer)
            self.log.log_other('dv_best_wer', self.best_wer['ctc'])

        self.write_log('per', {'dv_' + task: mean_per})
        if self.eval_target == 'char':
            self.write_log('wer', {'dv_' + task: mean_wer})

        self.save_checkpoint('latest.pth', 'per', mean_per, show_msg=False)
        if self.paras.save_every:
            # NOTE(review): the '.path' extension differs from the '.pth'
            # used by every other checkpoint here -- possibly a typo;
            # confirm before renaming, since loaders may rely on it
            self.save_checkpoint(f'{self.step}.path', 'per', mean_per,
                                 show_msg=False)

        # Resume training
        self.model.train()
    def validate(self, _dv_set, _name):
        ''' Evaluate on one named dev set and checkpoint on improvement

        Args:
            _dv_set: iterable of dev batches (must support ``len``).
            _name: label of the dev set; used in log tags, checkpoint file
                names and as the key into ``self.best_wer[task]``.

        For every task (``att`` / ``ctc``) whose output is not ``None`` the
        WER, CER and ``self.val_mode`` score are averaged over the set.  The
        ``self.val_mode`` score decides whether ``best_<task>_<name>.pth``
        is saved; ``last_<task>_<name>.pth`` is saved once
        ``self.step >= self.max_step``.  Train mode (and, for transfer
        learning, the layer freezing) is restored before returning.
        '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()

        wer_scores = {'att': [], 'ctc': []}
        cer_scores = {'att': [], 'ctc': []}
        sel_scores = {'att': [], 'ctc': []}  # metric chosen by self.val_mode

        for batch_idx, batch in enumerate(_dv_set):
            self.progress('Valid step - {}/{}'.format(batch_idx + 1,
                                                      len(_dv_set)))
            feat, feat_len, txt, txt_len = self.fetch_data(batch)

            # Gradient-free forward pass
            with torch.no_grad():
                max_decode_step = int(max(txt_len) * self.DEV_STEP_RATIO)
                outputs = self.model(feat, feat_len, max_decode_step,
                                     emb_decoder=self.emb_decoder)
            ctc_output, _enc_len, att_output, att_align, _dec_state = outputs

            # Score every available output with each of the three metrics
            per_task = (('att', att_output, {}),
                        ('ctc', ctc_output, {'ctc': True}))
            for task, output, extra in per_task:
                if output is None:
                    continue
                metrics = ((wer_scores, 'wer'), (cer_scores, 'cer'),
                           (sel_scores, self.val_mode))
                for store, mode in metrics:
                    store[task].append(
                        cal_er(self.tokenizer, output, txt, mode=mode,
                               **extra))

            # Example hypotheses are logged for the middle batch only
            if batch_idx != len(_dv_set) // 2:
                continue
            for k in range(min(len(txt), self.DEV_N_EXAMPLE)):
                # Reference text is only logged on the first training step
                if self.step == 1:
                    self.write_log('true_text_{}_{}'.format(_name, k),
                                   self.tokenizer.decode(txt[k].tolist()))
                if att_output is not None:
                    alignment = att_align[k, 0, :, :].cpu().detach()
                    self.write_log('att_align_{}_{}'.format(_name, k),
                                   feat_to_fig(alignment))
                    att_tokens = att_output[k].argmax(dim=-1).tolist()
                    self.write_log('att_text_{}_{}'.format(_name, k),
                                   self.tokenizer.decode(att_tokens))
                if ctc_output is not None:
                    ctc_tokens = ctc_output[k].argmax(dim=-1).tolist()
                    self.write_log(
                        'ctc_text_{}_{}'.format(_name, k),
                        self.tokenizer.decode(ctc_tokens, ignore_repeat=True))

        def ckpt_name(prefix, task):
            # File name gets self.save_name appended under transfer learning
            suffix = self.save_name if self.transfer_learning else ''
            return '{}_{}_{}.pth'.format(prefix, task, _name + suffix)

        # Ckpt if performance improves (only tasks that produced scores)
        tasks = [t for t in ('att', 'ctc') if len(sel_scores[t]) > 0]
        for task in tasks:
            mean_sel = sum(sel_scores[task]) / len(sel_scores[task])
            mean_wer = sum(wer_scores[task]) / len(wer_scores[task])
            mean_cer = sum(cer_scores[task]) / len(cer_scores[task])

            if mean_sel < self.best_wer[task][_name]:
                self.best_wer[task][_name] = mean_sel
                self.save_checkpoint(ckpt_name('best', task), self.val_mode,
                                     mean_sel, _name)
            if self.step >= self.max_step:
                self.save_checkpoint(ckpt_name('last', task), self.val_mode,
                                     mean_sel, _name)

            tag = 'dv_' + task + '_' + _name.lower()
            self.write_log(self.WER, {tag: mean_wer})
            self.write_log('cer', {tag: mean_cer})

        # Resume training
        self.model.train()
        if self.transfer_learning:
            # train() is followed by re-applying the configured layer fixing
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        if self.emb_decoder is not None:
            self.emb_decoder.train()
    def exec(self):
        ''' Training End-to-end ASR system

        Iterates over ``self.tr_set`` until ``self.step`` reaches
        ``self.max_step``.  Each step sums the weighted CTC, attention and
        (when ``self.emb_reg`` is set) embedding losses and back-propagates
        the total.  Optionally drops the CTC objective after a fixed number
        of steps ("early stopping"), validates on one or several dev sets
        every ``self.valid_step`` steps, and applies a manual step-wise LR
        decay when no LR scheduler is configured.
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        if self.transfer_learning:
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        # Step-based epoch counter, only used for the progress print below
        self.n_epochs = 0
        # Completed passes over self.tr_set; drives the curriculum switch.
        # The check below read ``n_epochs`` without it ever being defined,
        # raising NameError whenever self.curriculum > 0.
        n_epochs = 0
        self.timer.set()
        # Early stopping for CTC
        self.early_stoping = self.config['hparas']['early_stopping']
        stop_epoch = 10
        batch_size = self.config['data']['corpus']['batch_size']
        stop_step = len(self.tr_set) * stop_epoch // batch_size

        while self.step < self.max_step:
            ctc_loss, att_loss, emb_loss = None, None, None

            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data,
                                                               train=True)

                self.timer.cnt('rd')
                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)
                # Clear not used objects
                del att_align

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss
                else:
                    del dec_state

                # Early stopping for CTC: past stop_step the CTC branch is
                # dropped and its loss weight is zeroed on the model
                if self.early_stoping:
                    if self.step > stop_step:
                        ctc_output = None
                        self.model.ctc_weight = 0

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                    del encode_len

                if att_output is not None:
                    b, t, _ = att_output.shape
                    # NOTE(review): fuse_output is only bound when
                    # self.emb_reg is true -- confirm emb_fuse implies emb_reg
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(att_output.view(b * t, -1),
                                             txt.view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                                  .format(total_loss.cpu().item(), grad_norm,
                                          self.timer.show()))
                    self.write_log('emb_loss', {'tr': emb_loss})
                    if att_output is not None:
                        self.write_log('loss', {'tr_att': att_loss})
                        self.write_log(self.WER, {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt)
                        })
                        self.write_log(
                            'cer', {
                                'tr_att':
                                cal_er(self.tokenizer,
                                       att_output,
                                       txt,
                                       mode='cer')
                            })
                    if ctc_output is not None:
                        self.write_log('loss', {'tr_ctc': ctc_loss})
                        self.write_log(
                            self.WER, {
                                'tr_ctc':
                                cal_er(
                                    self.tokenizer, ctc_output, txt, ctc=True)
                            })
                        self.write_log(
                            'cer', {
                                'tr_ctc':
                                cal_er(self.tokenizer,
                                       ctc_output,
                                       txt,
                                       mode='cer',
                                       ctc=True)
                            })
                        self.write_log(
                            'ctc_text_train',
                            self.tokenizer.decode(
                                ctc_output[0].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation (one pass per dev set when several are given)
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    if isinstance(self.dv_set, list):
                        for dv_id in range(len(self.dv_set)):
                            self.validate(self.dv_set[dv_id],
                                          self.dv_names[dv_id])
                    else:
                        self.validate(self.dv_set, self.dv_names)

                # Step-based epoch bookkeeping.  Guarded so that a training
                # set with fewer than batch_size entries cannot trigger a
                # modulo-by-zero.
                steps_per_epoch = len(self.tr_set) // batch_size
                if steps_per_epoch > 0 and self.step % steps_per_epoch == 0:
                    print('Have finished epoch: ', self.n_epochs)
                    self.n_epochs += 1

                # Manual LR decay, used only when no scheduler is configured:
                # from step 100000 on, multiply the LR by 0.85 every 2000
                # steps
                if self.lr_scheduler is None:
                    lr = self.optimizer.opt.param_groups[0]['lr']

                    if self.step == 1:
                        print(
                            '[INFO]    using lr schedular defined by Daniel, init lr = ',
                            lr)

                    if self.step > 99999 and self.step % 2000 == 0:
                        lr = lr * 0.85
                        for param_group in self.optimizer.opt.param_groups:
                            param_group['lr'] = lr
                        print('[INFO]     at step:', self.step)
                        print('[INFO]   lr reduce to', lr)

                # End of step
                # Empty cuda cache after every step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                # Stop exactly at max_step: a strict ``>`` here let one extra
                # update run unless an epoch boundary fell on max_step
                if self.step >= self.max_step:
                    break
            n_epochs += 1

        self.log.close()
        print('[INFO] Finished training after', human_format(self.max_step),
              'steps.')