コード例 #1
0
    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding on Attention decoder only)

        Decodes every batch of ``self.dev_set`` without tracking gradients,
        reports the attention-decoder error rate via ``self.verbose`` and
        writes ``ground_truth<TAB>hypothesis`` lines to
        ``<ckpdir>/dev_att_decode.txt``. When the model also returns CTC
        outputs, their greedy decoding is written to
        ``<ckpdir>/dev_ctc_decode.txt`` in the same format.

        Output files are written as UTF-8 so non-ASCII transcripts do not
        depend on the platform's default encoding. An empty dev set is
        reported and skipped instead of raising ZeroDivisionError.
        '''
        val_cer = 0.0
        val_len = 0
        all_pred, all_true = [], []
        ctc_results = []
        with torch.no_grad():
            for cur_b, (x, y, state_len) in enumerate(self.dev_set):
                self.progress(' '.join(['Valid step - (', str(cur_b), '/', str(len(self.dev_set)), ')']))

                # Prepare data: drop the leading singleton (bucketing) dim if present
                if len(x.shape) == 4: x = x.squeeze(0)
                if len(y.shape) == 3: y = y.squeeze(0)
                x = x.to(device=self.device, dtype=torch.float32)
                y = y.to(device=self.device, dtype=torch.long)
                # Longest label in the batch, counted as non-zero tokens
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # Forward: allow VAL_STEP decoding steps beyond the longest label
                ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len + VAL_STEP, state_len=state_len)
                ctc_pred = torch.argmax(ctc_pred, dim=-1).cpu() if ctc_pred is not None else None
                ctc_results.append(ctc_pred)

                # Result: the reference skips the first token of y
                label = y[:, 1:ans_len + 1].contiguous()
                t1, t2 = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
                all_pred += t1
                all_true += t2
                val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(x.shape[0])
                val_len += int(x.shape[0])

        # Without this guard an empty dev set raises ZeroDivisionError below
        # and IndexError on ctc_results[0].
        if val_len == 0:
            self.verbose('Validation skipped : dev set is empty.')
            return

        # Dump model score to ensure model is corrected
        att_path = os.path.join(self.ckpdir, 'dev_att_decode.txt')
        self.verbose('Validation Error Rate of Current model : {:.4f}      '.format(val_cer / val_len))
        self.verbose('See {} for validation results.'.format(att_path))
        with open(att_path, 'w', encoding='utf-8') as f:
            for hyp, gt in zip(all_pred, all_true):
                f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')

        # Also dump CTC result if available
        if ctc_results[0] is not None:
            ctc_text = []
            for batch in ctc_results:
                for pred in batch:
                    # Greedy CTC decoding: merge repeated frames FIRST, then drop
                    # index 0 (treated here as the CTC blank). Dropping blanks
                    # first would wrongly merge real repeats such as "l _ l".
                    tokens = [k for k, _ in itertools.groupby(pred.tolist())]
                    tokens = [k for k in tokens if k != 0]
                    ctc_text.append(self.mapper.translate(tokens, return_string=True))
            ctc_path = os.path.join(self.ckpdir, 'dev_ctc_decode.txt')
            self.verbose('Also, see {} for CTC validation results.'.format(ctc_path))
            with open(ctc_path, 'w', encoding='utf-8') as f:
                for hyp, gt in zip(ctc_text, all_true):
                    f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')
コード例 #2
0
ファイル: solver.py プロジェクト: Eman22S/Amharic-Seq2Seq
    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)

        Evaluates ``self.asr_model`` on ``self.dev_set`` with gradient tracking
        disabled and logs the dev losses through ``self.write_log``. When the
        attention decoder is used (``ctc_weight < 1``) it also logs the error
        rate, the accuracy, and the attention maps / hypotheses of the last
        batch. If the error rate beats ``self.best_val_ed``, it saves the model
        (plus the CLM when ``apply_clm``) and ``best_hyp.txt`` into
        ``self.ckpdir``. The model is switched back to training mode before
        returning.
        '''
        self.asr_model.eval()

        # Init stats
        val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
        val_len = 0
        all_pred, all_true = [], []

        # Perform validation. The losses are only logged, never
        # back-propagated, so no autograd graph is kept for dev batches.
        with torch.no_grad():
            for cur_b, (x, y) in enumerate(self.dev_set):
                self.progress(' '.join([
                    'Valid step -',
                    str(self.step), '(',
                    str(cur_b), '/',
                    str(len(self.dev_set)), ')'
                ]))

                # Prepare data: drop the leading singleton (bucketing) dim if present
                if len(x.shape) == 4: x = x.squeeze(0)
                if len(y.shape) == 3: y = y.squeeze(0)
                x = x.to(device=self.device, dtype=torch.float32)
                y = y.to(device=self.device, dtype=torch.long)
                # Frames whose features sum to zero are not counted in state_len
                state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
                state_len = [int(sl) for sl in state_len]
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))
                batch_size = int(x.shape[0])

                # Forward
                ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                    x, ans_len + VAL_STEP, state_len=state_len)

                # Compute attention loss & get decoding results
                label = y[:, 1:ans_len + 1].contiguous()
                if self.ctc_weight < 1:
                    seq_loss = self.seq_loss(
                        att_pred[:, :ans_len, :].contiguous().view(
                            -1, att_pred.shape[-1]), label.view(-1))
                    # Sum each utterance and divide by its non-zero token count
                    uttr_len = torch.sum(y != 0, dim=-1).to(
                        device=self.device, dtype=torch.float32)
                    seq_loss = torch.sum(seq_loss.view(batch_size, -1), dim=-1) / uttr_len
                    seq_loss = torch.mean(seq_loss)  # Mean by batch
                    val_att += seq_loss.detach() * batch_size
                    t1, t2 = cal_cer(att_pred,
                                     label,
                                     mapper=self.mapper,
                                     get_sentence=True)
                    all_pred += t1
                    all_true += t2
                    val_acc += cal_acc(att_pred, label) * batch_size
                    val_cer += cal_cer(att_pred, label, mapper=self.mapper) * batch_size

                # Compute CTC loss
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                        torch.LongTensor(state_len), target_len)
                    val_ctc += ctc_loss.detach() * batch_size

                val_len += batch_size

        # Nothing evaluated: avoid dividing by zero and touching per-batch
        # variables (att_pred, label, att_maps) that were never bound.
        if val_len == 0:
            self.verbose('Validation skipped : dev set is empty.')
            self.asr_model.train()
            return

        # Logger
        val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
        loss_log = {}
        for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'],
                        [val_loss, val_ctc, val_att]):
            if v > 0.0: loss_log[k] = v / val_len
        self.write_log('loss', loss_log)

        if self.ctc_weight < 1:
            # Plot attention map to log (last batch only)
            val_hyp, val_txt = cal_cer(att_pred,
                                       label,
                                       mapper=self.mapper,
                                       get_sentence=True)
            val_attmap = draw_att(att_maps, att_pred)

            # Record loss
            self.write_log('error rate', {'dev': val_cer / val_len})
            self.write_log('acc', {'dev': val_acc / val_len})
            for idx, attmap in enumerate(val_attmap):
                self.write_log('att_' + str(idx), attmap)
                self.write_log('hyp_' + str(idx), val_hyp[idx])
                self.write_log('txt_' + str(idx), val_txt[idx])

            # Save model by val er.
            if val_cer / val_len < self.best_val_ed:
                self.best_val_ed = val_cer / val_len
                self.verbose(
                    'Best val er       : {:.4f}       @ step {}'.format(
                        self.best_val_ed, self.step))
                torch.save(self.asr_model, os.path.join(self.ckpdir, 'asr'))
                if self.apply_clm:
                    torch.save(self.clm.clm, os.path.join(self.ckpdir, 'clm'))
                # Save hyps (UTF-8 so non-ASCII transcripts are portable).
                with open(os.path.join(self.ckpdir, 'best_hyp.txt'), 'w',
                          encoding='utf-8') as f:
                    for hyp, gt in zip(all_pred, all_true):
                        f.write(hyp + ',' + gt + '\n')

        self.asr_model.train()
コード例 #3
0
ファイル: solver.py プロジェクト: Eman22S/Amharic-Seq2Seq
    def exec(self):
        ''' Training End-to-end ASR system

        Iterates over ``self.train_set`` until ``self.step`` passes
        ``self.max_step``. On each step it:

        1. decays the teacher-forcing rate linearly from ``tf_start`` to ``tf_end``;
        2. runs the model and computes the attention CE loss and/or CTC loss,
           weighted by ``ctc_weight``;
        3. optionally subtracts the CLM adversarial feedback;
        4. back-propagates, clips gradients, and skips the update on a NaN norm;
        5. logs losses and metrics, and runs ``self.valid()`` every
           ``valid_step`` steps.
        '''
        self.verbose('Training set total ' + str(len(self.train_set)) +
                     ' batches.')

        while self.step < self.max_step:
            for x, y in self.train_set:
                self.progress('Training step - ' + str(self.step))

                # Perform teacher forcing rate decaying
                tf_rate = self.tf_start - self.step * (
                    self.tf_start - self.tf_end) / self.max_step

                # Hack bucket, record state length for each uttr, get longest label seq for decode step
                assert len(
                    x.shape
                ) == 4, 'Bucketing should cause acoustic feature to have shape 1xBxTxD'
                assert len(
                    y.shape
                ) == 3, 'Bucketing should cause label have to shape 1xBxT'
                x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
                y = y.squeeze(0).to(device=self.device, dtype=torch.long)
                # Frames whose features sum to zero are not counted in state_len
                state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0,
                                   axis=-1)
                state_len = [int(sl) for sl in state_len]
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # ASR forwarding
                self.asr_opt.zero_grad()
                ctc_pred, state_len, att_pred, _ = self.asr_model(
                    x,
                    ans_len,
                    tf_rate=tf_rate,
                    teacher=y,
                    state_len=state_len)

                # Calculate loss function
                loss_log = {}
                label = y[:, 1:ans_len + 1].contiguous()
                ctc_loss = 0
                att_loss = 0

                # CE loss on attention decoder
                if self.ctc_weight < 1:
                    b, t, c = att_pred.shape
                    att_loss = self.seq_loss(att_pred.view(b * t, c),
                                             label.view(-1))
                    # Sum each utterance and divide by its non-zero token count
                    uttr_len = torch.sum(y != 0, dim=-1).to(
                        device=self.device, dtype=torch.float32)
                    att_loss = torch.sum(att_loss.view(b, t), dim=-1) / uttr_len
                    att_loss = torch.mean(att_loss)  # Mean by batch
                    loss_log['train_att'] = att_loss

                # CTC loss on CTC decoder
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                        torch.LongTensor(state_len), target_len)
                    loss_log['train_ctc'] = ctc_loss

                asr_loss = (1 - self.ctc_weight
                            ) * att_loss + self.ctc_weight * ctc_loss
                loss_log['train_full'] = asr_loss

                # Adversarial loss from CLM (requires attention-decoder output)
                if (self.apply_clm and att_pred is not None
                        and att_pred.shape[1] >= CLM_MIN_SEQ_LEN):
                    if (self.step % self.clm.update_freq) == 0:
                        # update CLM once in a while
                        clm_log, gp = self.clm.train(att_pred.detach(),
                                                     CLM_MIN_SEQ_LEN)
                        self.write_log('clm_score', clm_log)
                        self.write_log('clm_gp', gp)
                    # Normalize over the class axis. Without an explicit dim,
                    # F.softmax picks dim=0 (the batch axis) for 3-D inputs.
                    adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                    asr_loss -= adv_feedback

                # Backprop
                asr_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    self.verbose('Error : grad norm is NaN @ step ' +
                                 str(self.step))
                else:
                    self.asr_opt.step()

                # Logger (attention metrics need att_pred, so only when it is used)
                self.write_log('loss', loss_log)
                if self.ctc_weight < 1:
                    self.write_log('acc', {'train': cal_acc(att_pred, label)})
                    if self.step % TRAIN_WER_STEP == 0:
                        self.write_log('error rate', {
                            'train':
                            cal_cer(att_pred, label, mapper=self.mapper)
                        })

                # Validation
                if self.step % self.valid_step == 0:
                    self.asr_opt.zero_grad()
                    self.valid()

                self.step += 1
                if self.step > self.max_step: break
コード例 #4
0
    def greedy_decode(self, split):
        '''Perform validation step (!!!NOTE!!! greedy decoding on Attention decoder only)

        Decodes ``self.dev_set`` (``split == 'dev'``) or ``self.test_set``
        (any other value) without tracking gradients. It reports the
        attention-decoder error rate via ``self.verbose`` and writes
        ``ground_truth<TAB>hypothesis`` lines to ``<ckpdir>/dev_att_decode.txt``
        (dev) or ``<ckpdir>/<decode_file>.txt`` (test). For the dev split,
        greedy CTC output, when available, is written to
        ``<ckpdir>/dev_ctc_decode.txt``.

        When ``self.mwer`` is set, per-utterance softmax outputs are saved as
        ``.npy`` files under ``<mwer_dir>/data`` and indexed in
        ``<mwer_dir>/data.csv``. In that case no decode files are written and
        0 is returned.
        '''
        # TODO : Add rnnlm & ctc decode to greedy.
        val_cer = 0.0
        val_len = 0
        all_pred, all_true = [], []
        ctc_results = []

        ds = self.dev_set if split == 'dev' else self.test_set

        # for MWER only
        tf_data = False  # hard-coded: the teacher-forced branch below is not taken
        mwer_file = None
        if self.mwer:
            print("TF={}".format(tf_data))
            idx = 0
            mwer_dir = str(
                self.pred_path) + '_dev' if split == 'dev' else self.pred_path
            # exist_ok=True also creates 'data' when mwer_dir already exists
            # without it (np.save below would otherwise fail).
            os.makedirs(os.path.join(mwer_dir, 'data'), exist_ok=True)
            mwer_file = open(os.path.join(mwer_dir, 'data.csv'), 'a')
        try:
            # Origin start
            with torch.no_grad():
                for cur_b, (x, y) in enumerate(tqdm(ds)):
                    # Prepare data: drop the leading singleton (bucketing) dim if present
                    if len(x.shape) == 4: x = x.squeeze(0)
                    if len(y.shape) == 3: y = y.squeeze(0)
                    x = x.to(device=self.device, dtype=torch.float32)
                    y = y.to(device=self.device, dtype=torch.long)
                    # Frames whose features sum to zero are not counted in state_len
                    state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
                    state_len = [int(sl) for sl in state_len]
                    ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                    # Forward
                    if self.mwer and tf_data:
                        ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                            x, ans_len, tf_rate=1, teacher=y, state_len=state_len)
                    else:
                        # dev: label length + VAL_STEP; test: scaled input length
                        decode_len = ans_len+VAL_STEP if split =='dev' else \
                                  int(np.ceil(max(state_len)*self.decode_step_ratio))
                        ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                            x, decode_len, state_len=state_len)
                    ctc_pred = torch.argmax(
                        ctc_pred, dim=-1).cpu() if ctc_pred is not None else None
                    ctc_results.append(ctc_pred)

                    ### MWER
                    if self.mwer:
                        att_out = F.softmax(att_pred, dim=-1).cpu().numpy()

                        for output, ans in zip(att_out, y):
                            answer = ans.tolist()
                            # Keep tokens after position 0 up to the first 1
                            # (presumably EOS -- TODO confirm), then re-append it.
                            # NOTE(review): raises ValueError if no 1 is present.
                            answer = answer[1:answer.index(1)] + [1]
                            # Truncate the output after the first predicted 1
                            eos_pos = np.where(np.argmax(output, axis=-1) == 1)[0]
                            if len(eos_pos) > 0:
                                output = output[:eos_pos[0] + 1]
                            f_name = str(idx) + '.npy'
                            mwer_file.write("{},{}\n".format(
                                f_name, '_'.join([str(c) for c in answer])))
                            np.save(str(os.path.join(mwer_dir, 'data', f_name)),
                                    output)
                            idx += 1

                    # Result: the reference skips the first token of y
                    label = y[:, 1:ans_len + 1].contiguous()
                    t1, t2 = cal_cer(att_pred,
                                     label,
                                     mapper=self.mapper,
                                     get_sentence=True)
                    all_pred += t1
                    all_true += t2
                    val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(
                        x.shape[0])
                    val_len += int(x.shape[0])
        finally:
            # Close the MWER index even when decoding raises.
            if mwer_file is not None:
                mwer_file.close()

        # Nothing decoded: avoid ZeroDivisionError and IndexError on ctc_results[0].
        if val_len == 0:
            self.verbose('Nothing decoded : {} set is empty.'.format(split))
            return 0 if self.mwer else None

        if split == 'dev':
            # Dump model score to ensure model is corrected
            decode_path = os.path.join(self.ckpdir, 'dev_att_decode.txt')
            er_msg = 'Validation Error Rate of Current model : {:.4f}      '.format(
                val_cer / val_len)
            save_msg = 'See {} for validation results.'.format(decode_path)
        else:
            decode_path = os.path.join(self.ckpdir, self.decode_file + '.txt')
            er_msg = 'Test Error Rate: {:.4f}      '.format(val_cer / val_len)
            save_msg = 'See {} for decoding results.'.format(decode_path)

        self.verbose(er_msg)
        self.verbose(save_msg)

        ## MWER
        if self.mwer:
            return 0

        with open(decode_path, 'w', encoding='utf-8') as f:
            for hyp, gt in zip(all_pred, all_true):
                f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')

        # Also dump CTC result if available
        if ctc_results[0] is not None and split == 'dev':
            ctc_text = []
            for batch in ctc_results:
                for pred in batch:
                    # Greedy CTC decoding: merge repeated frames FIRST, then drop
                    # index 0 (treated here as the CTC blank). Dropping blanks
                    # first would wrongly merge real repeats such as "l _ l".
                    tokens = [k for k, _ in itertools.groupby(pred.tolist())]
                    tokens = [k for k in tokens if k != 0]
                    ctc_text.append(self.mapper.translate(tokens, return_string=True))
            ctc_path = os.path.join(self.ckpdir, 'dev_ctc_decode.txt')
            self.verbose('Also, see {} for CTC validation results.'.format(
                ctc_path))
            with open(ctc_path, 'w', encoding='utf-8') as f:
                for hyp, gt in zip(ctc_text, all_true):
                    f.write(gt.lstrip() + '\t' + hyp.lstrip() + '\n')
コード例 #5
0
    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)

        Evaluates ``self.asr_model`` on ``self.dev_set`` with gradient tracking
        disabled. Hypotheses come from the attention decoder when
        ``ctc_weight < 1`` and from the CTC output otherwise. It logs, via
        ``self.write_log``, the dev losses, error rate and accuracy, plus the
        sample hypotheses of the last batch (and its attention maps for the
        attention decoder). It then passes the dev error rate and all
        hypotheses/references to ``self.maybe_dump_checkpoint``. The model is
        put back into training mode before returning.
        '''
        self.asr_model.eval()

        # Init stats
        val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
        val_len = 0
        all_pred, all_true = [], []

        progress_bar = tqdm(self.dev_set, leave=False)
        # Perform validation. The losses are only logged, never
        # back-propagated, so no autograd graph is kept for dev batches.
        with torch.no_grad():
            for cur_b, (x, y, state_len) in enumerate(progress_bar):
                # Training step, then position within the dev set
                progress_bar.set_description("[valid {} ({}/{})]".format(self.step, cur_b, len(self.dev_set)))

                # Prepare data: drop the leading singleton (bucketing) dim if present
                if len(x.shape) == 4: x = x.squeeze(0)
                if len(y.shape) == 3: y = y.squeeze(0)
                if state_len.dim() == 2: state_len = state_len.squeeze(0)
                x = x.to(device=self.device, dtype=torch.float32)
                y = y.to(device=self.device, dtype=torch.long)
                state_len = state_len.to(device=self.device, dtype=torch.long)
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))
                batch_size = int(x.shape[0])

                # Forward
                ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len + VAL_STEP, state_len=state_len)

                # Compute attention loss & get decoding results
                label = y[:, 1:ans_len + 1].contiguous()
                if self.ctc_weight < 1:
                    seq_loss = self.seq_loss(att_pred[:, :ans_len, :].contiguous().view(-1, att_pred.shape[-1]), label.view(-1))
                    # Sum each utterance and divide by its non-zero token count
                    uttr_len = torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)
                    seq_loss = torch.sum(seq_loss.view(batch_size, -1), dim=-1) / uttr_len
                    seq_loss = torch.mean(seq_loss)  # Mean by batch
                    val_att += seq_loss.detach() * batch_size
                    hyp_source = att_pred
                else:
                    # only ctc
                    hyp_source = ctc_pred
                t1, t2 = cal_cer(hyp_source, label, mapper=self.mapper, get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(hyp_source, label) * batch_size
                val_cer += cal_cer(hyp_source, label, mapper=self.mapper) * batch_size

                # Compute CTC loss
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    # NOTE(review): state_len here is whatever the model returned;
                    # confirm torch.LongTensor accepts it when it lives on GPU.
                    ctc_loss = self.ctc_loss(F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                                             torch.LongTensor(state_len), target_len)
                    val_ctc += ctc_loss.detach() * batch_size

                val_len += batch_size

        # Nothing evaluated: avoid dividing by zero and touching per-batch
        # variables (att_pred, ctc_pred, label) that were never bound.
        if val_len == 0:
            self.verbose('Validation skipped : dev set is empty.')
            self.asr_model.train()
            return

        # Logger
        val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
        loss_log = {}
        for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'], [val_loss, val_ctc, val_att]):
            if v > 0.0: loss_log[k] = v / val_len
        self.write_log('loss', loss_log)

        # attention decoder
        if self.ctc_weight < 1:
            # Plot attention map to log (last batch only)
            val_hyp, val_txt = cal_cer(att_pred, label, mapper=self.mapper, get_sentence=True)
            val_attmap = draw_att(att_maps, att_pred)

            # Record loss
            self.write_log('error rate', {'dev': val_cer / val_len})
            self.write_log('acc', {'dev': val_acc / val_len})
            for idx, attmap in enumerate(val_attmap):
                self.write_log('att_' + str(idx), attmap)
                self.write_log('hyp_' + str(idx), val_hyp[idx])
                self.write_log('txt_' + str(idx), val_txt[idx])

        else:
            # only use ctc
            val_hyp, val_txt = cal_cer(ctc_pred, label, mapper=self.mapper, get_sentence=True)

            # Record loss
            self.write_log('ctc error rate', {'dev': val_cer / val_len})
            self.write_log('ctc acc', {'dev': val_acc / val_len})
            for idx, (hyp, txt) in enumerate(zip(val_hyp, val_txt)):
                self.write_log('hyp_' + str(idx), hyp)
                self.write_log('txt_' + str(idx), txt)

        # Save model by val er.
        self.maybe_dump_checkpoint(val_cer / val_len, all_pred, all_true)

        self.asr_model.train()
コード例 #6
0
    def exec(self):
        ''' Training End-to-end ASR system

        Iterates over ``self.train_set`` until ``self.step`` passes
        ``self.max_step``. On each step it:

        1. decays the teacher-forcing rate linearly;
        2. runs the model; a CTC output is mandatory (asserted);
        3. computes the attention CE and/or CTC losses weighted by ``ctc_weight``;
        4. optionally subtracts the CLM adversarial feedback;
        5. back-propagates with gradient clipping, skipping the update on a NaN norm;
        6. logs metrics;
        7. every 1000 steps, logs an input-feature image;
        8. every ``valid_step`` steps, runs ``self.valid()``.
        '''
        self.verbose('Training set total '+str(len(self.train_set))+' batches.')
        progress_bar = tqdm(range(self.max_step))
        while self.step < self.max_step:
            for x, y, state_len in self.train_set:
                progress_bar.set_description("[training {}]".format(self.step))
                # Perform teacher forcing rate decaying
                tf_rate = self.tf_start - self.step*(self.tf_start-self.tf_end)/self.max_step

                # Hack bucket, record state length for each uttr, get longest label seq for decode step
                assert len(x.shape) == 4,'Bucketing should cause acoustic feature to have shape 1xBxTxD'
                assert len(y.shape) == 3,'Bucketing should cause label have to shape 1xBxT'
                x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
                y = y.squeeze(0).to(device=self.device, dtype=torch.long)
                state_len = state_len.squeeze(0).to(device=self.device, dtype=torch.long)
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # ASR forwarding
                self.asr_opt.zero_grad()
                ctc_pred, state_len, att_pred, _ = self.asr_model(
                    x, ans_len,
                    tf_rate=tf_rate,
                    teacher=y,
                    state_len=state_len
                )

                assert ctc_pred is not None, "for saber, must use ctc_pred!"
                # CTC time axis must equal the input length after encoder downsampling
                assert x.shape[1] // self.asr_model.encoder.downsample_rate == ctc_pred.shape[1]

                # Calculate loss function
                loss_log = {}
                label = y[:, 1:ans_len+1].contiguous()
                ctc_loss = 0
                att_loss = 0

                # CE loss on attention decoder
                if self.ctc_weight < 1:
                    b, t, c = att_pred.shape
                    att_loss = self.seq_loss(att_pred.view(b*t, c), label.view(-1))
                    # Sum each utterance and divide by its non-zero token count
                    uttr_len = torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)
                    att_loss = torch.sum(att_loss.view(b, t), dim=-1) / uttr_len
                    att_loss = torch.mean(att_loss)  # Mean by batch
                    loss_log['train_att'] = att_loss

                # CTC loss on CTC decoder
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                        torch.LongTensor(state_len), target_len
                    )
                    loss_log['train_ctc'] = ctc_loss

                asr_loss = (1-self.ctc_weight)*att_loss+self.ctc_weight*ctc_loss
                loss_log['train_full'] = asr_loss

                # Adversarial loss from CLM
                if (
                    self.apply_clm and
                    (att_pred is not None) and
                    (att_pred.shape[1] >= CLM_MIN_SEQ_LEN)
                ):
                    if (self.step % self.clm.update_freq) == 0:
                        # update CLM once in a while
                        clm_log, gp = self.clm.train(att_pred.detach(), CLM_MIN_SEQ_LEN)
                        self.write_log('clm_score', clm_log)
                        self.write_log('clm_gp', gp)
                    # Normalize over the class axis. Without an explicit dim,
                    # F.softmax picks dim=0 (the batch axis) for 3-D inputs.
                    adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                    asr_loss -= adv_feedback

                # Backprop
                asr_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    self.verbose('Error : grad norm is NaN @ step '+str(self.step))
                else:
                    self.asr_opt.step()

                # Logger
                self.write_log('loss', loss_log)
                if self.ctc_weight < 1:
                    self.write_log('acc', {'train': cal_acc(att_pred, label)})
                    if self.step % TRAIN_WER_STEP == 0:
                        self.write_log('error rate', {'train': cal_cer(att_pred, label, mapper=self.mapper)})
                else:
                    # only ctc
                    self.write_log('ctc acc', {'train': cal_acc(ctc_pred, label)})
                    if self.step % TRAIN_WER_STEP == 0:
                        self.write_log('ctc error rate', {'train': cal_cer(ctc_pred, label, mapper=self.mapper)})

                # visualize inputs
                if self.step % 1000 == 0:
                    example_inputs = x[0].cpu().detach().numpy().transpose(1, 0)
                    img = visualizer.plot(visualizer.plot_item(example_inputs, "inputs-feature"))
                    self.log.add_image("0-inputs", img, global_step=self.step, dataformats="HWC")

                # Validation
                if self.step % self.valid_step == 0:
                    self.asr_opt.zero_grad()
                    self.valid()

                self.step += 1
                progress_bar.update()
                if self.step > self.max_step: break
        # Release the manually-driven progress bar once training ends
        progress_bar.close()