Example 1
    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)'''
        self.asr_model.eval()

        # Init stats
        val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
        val_len = 0
        all_pred, all_true = [], []

        # Perform validation
        for cur_b, (x, y) in enumerate(self.dev_set):
            self.progress(' '.join([
                'Valid step -',
                str(self.step), '(',
                str(cur_b), '/',
                str(len(self.dev_set)), ')'
            ]))

            # Prepare data
            if len(x.shape) == 4: x = x.squeeze(0)
            if len(y.shape) == 3: y = y.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
            state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # Forward
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                x, ans_len + VAL_STEP, state_len=state_len)

            # Compute attention loss & get decoding results
            label = y[:, 1:ans_len + 1].contiguous()
            if self.ctc_weight < 1:
                seq_loss = self.seq_loss(
                    att_pred[:, :ans_len, :].contiguous().view(
                        -1, att_pred.shape[-1]), label.view(-1))
                seq_loss = torch.sum(seq_loss.view(x.shape[0],-1),dim=-1)/torch.sum(y!=0,dim=-1)\
                           .to(device = self.device,dtype=torch.float32) # Sum each uttr and divide by length
                seq_loss = torch.mean(seq_loss)  # Mean by batch
                val_att += seq_loss.detach() * int(x.shape[0])
                t1, t2 = cal_cer(att_pred,
                                 label,
                                 mapper=self.mapper,
                                 get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(att_pred, label) * int(x.shape[0])
                val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(
                    x.shape[0])

            # Compute CTC loss
            if self.ctc_weight > 0:
                target_len = torch.sum(y != 0, dim=-1)
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                    torch.LongTensor(state_len), target_len)
                val_ctc += ctc_loss.detach() * int(x.shape[0])

            val_len += int(x.shape[0])

        # Logger
        val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
        loss_log = {}
        for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'],
                        [val_loss, val_ctc, val_att]):
            if v > 0.0: loss_log[k] = v / val_len
        self.write_log('loss', loss_log)

        if self.ctc_weight < 1:
            # Plot attention map to log
            val_hyp, val_txt = cal_cer(att_pred,
                                       label,
                                       mapper=self.mapper,
                                       get_sentence=True)
            val_attmap = draw_att(att_maps, att_pred)

            # Record loss
            self.write_log('error rate', {'dev': val_cer / val_len})
            self.write_log('acc', {'dev': val_acc / val_len})
            for idx, attmap in enumerate(val_attmap):
                self.write_log('att_' + str(idx), attmap)
                self.write_log('hyp_' + str(idx), val_hyp[idx])
                self.write_log('txt_' + str(idx), val_txt[idx])

            # Save model by val er.
            if val_cer / val_len < self.best_val_ed:
                self.best_val_ed = val_cer / val_len
                self.verbose(
                    'Best val er       : {:.4f}       @ step {}'.format(
                        self.best_val_ed, self.step))
                torch.save(self.asr_model, os.path.join(self.ckpdir, 'asr'))
                if self.apply_clm:
                    torch.save(self.clm.clm, os.path.join(self.ckpdir, 'clm'))
                # Save hyps.
                with open(os.path.join(self.ckpdir, 'best_hyp.txt'), 'w') as f:
                    for t1, t2 in zip(all_pred, all_true):
                        f.write(t1 + ',' + t2 + '\n')

        self.asr_model.train()
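
Both versions of valid() depend on the convention that padded frames are all-zero: the state_len computation above counts, per utterance, how many frames have a non-zero feature sum. A minimal, self-contained sketch of that trick on a toy batch (the shapes and values here are illustrative assumptions, not taken from the code above); note it can undercount if a real frame happens to sum to exactly zero:

import torch

# Toy batch: 2 utterances zero-padded to 5 frames of 3-dim features.
x = torch.zeros(2, 5, 3)
x[0, :4] = 1.0  # utterance 0 has 4 real frames
x[1, :2] = 1.0  # utterance 1 has 2 real frames

# Same recipe as valid(): a frame is real iff its feature sum is non-zero.
state_len = torch.sum(torch.sum(x, dim=-1) != 0, dim=-1)
print([int(sl) for sl in state_len])  # [4, 2]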
Example 2
    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)'''
        self.asr_model.eval()

        # Init stats
        val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
        val_len = 0
        all_pred, all_true = [],[]

        progress_bar = tqdm(self.dev_set, leave=False)
        # Perform validation
        for cur_b,(x,y,state_len) in enumerate(progress_bar):
            progress_bar.set_description("[valid step {} - {}/{}]".format(self.step, cur_b, len(self.dev_set)))

            # Prepare data
            if len(x.shape) == 4: x = x.squeeze(0)
            if len(y.shape) == 3: y = y.squeeze(0)
            if state_len.dim() == 2: state_len = state_len.squeeze(0)
            x = x.to(device = self.device,dtype=torch.float32)
            y = y.to(device = self.device,dtype=torch.long)
            state_len = state_len.to(device = self.device,dtype=torch.long)
            # state_len = torch.sum(torch.sum(x.cpu(),dim=-1) != 0,dim=-1)
            # state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # Forward
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(x, ans_len+VAL_STEP,state_len=state_len)

            # Compute attention loss & get decoding results
            label = y[:,1:ans_len+1].contiguous()
            if self.ctc_weight < 1:
                seq_loss = self.seq_loss(att_pred[:,:ans_len,:].contiguous().view(-1,att_pred.shape[-1]),label.view(-1))
                seq_loss = torch.sum(seq_loss.view(x.shape[0],-1),dim=-1)/torch.sum(y != 0, dim=-1)\
                                .to(device = self.device,dtype=torch.float32)  # Sum each uttr and divide by length
                seq_loss = torch.mean(seq_loss)  # Mean by batch
                val_att += seq_loss.detach()*int(x.shape[0])
                t1,t2 = cal_cer(att_pred,label,mapper=self.mapper,get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(att_pred,label)*int(x.shape[0])
                val_cer += cal_cer(att_pred,label,mapper=self.mapper)*int(x.shape[0])
            else:
                # only ctc
                t1,t2 = cal_cer(ctc_pred,label,mapper=self.mapper,get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(ctc_pred,label)*int(x.shape[0])
                val_cer += cal_cer(ctc_pred,label,mapper=self.mapper)*int(x.shape[0])

            # Compute CTC loss
            if self.ctc_weight > 0:
                target_len = torch.sum(y != 0, dim=-1)
                ctc_loss = self.ctc_loss(F.log_softmax(ctc_pred.transpose(0,1), dim=-1), label,
                                         torch.LongTensor(state_len), target_len)
                val_ctc += ctc_loss.detach()*int(x.shape[0])

            val_len += int(x.shape[0])

        # Logger
        val_loss = (1-self.ctc_weight)*val_att + self.ctc_weight*val_ctc
        loss_log = {}
        for k,v in zip(['dev_full','dev_ctc','dev_att'],[val_loss, val_ctc, val_att]):
            if v > 0.0: loss_log[k] = v/val_len
        self.write_log('loss',loss_log)

        # attention decoder
        if self.ctc_weight < 1:
            # Plot attention map to log
            val_hyp,val_txt = cal_cer(att_pred,label,mapper=self.mapper,get_sentence=True)
            val_attmap = draw_att(att_maps,att_pred)

            # Record loss
            self.write_log('error rate',{'dev':val_cer/val_len})
            self.write_log('acc',{'dev':val_acc/val_len})
            for idx,attmap in enumerate(val_attmap):
                self.write_log('att_'+str(idx),attmap)
                self.write_log('hyp_'+str(idx),val_hyp[idx])
                self.write_log('txt_'+str(idx),val_txt[idx])

        else:
            # only use ctc
            val_hyp, val_txt = cal_cer(ctc_pred, label,mapper=self.mapper,get_sentence=True)

            # Record loss
            self.write_log('ctc error rate',{'dev':val_cer/val_len})
            self.write_log('ctc acc',{'dev':val_acc/val_len})
            for idx in range(len(val_hyp)):
                self.write_log('hyp_'+str(idx),val_hyp[idx])
                self.write_log('txt_'+str(idx),val_txt[idx])

        # Save model by val er.
        self.maybe_dump_checkpoint(val_cer / val_len, all_pred, all_true)

        self.asr_model.train()
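
Example 2 factors the best-model checkpointing out of valid() into a maybe_dump_checkpoint helper, where Example 1 still inlines it. A sketch of what that helper could look like, reconstructed from Example 1's inline logic; this body is an assumption, not the actual implementation, and it relies on the same imports (torch, os) and attributes as the class above:

    def maybe_dump_checkpoint(self, val_er, all_pred, all_true):
        '''Save the model and hypotheses whenever the dev error rate improves (sketch).'''
        if val_er >= self.best_val_ed:
            return
        self.best_val_ed = val_er
        self.verbose('Best val er       : {:.4f}       @ step {}'.format(val_er, self.step))
        torch.save(self.asr_model, os.path.join(self.ckpdir, 'asr'))
        if self.apply_clm:
            torch.save(self.clm.clm, os.path.join(self.ckpdir, 'clm'))
        # Dump the best hypotheses alongside the references.
        with open(os.path.join(self.ckpdir, 'best_hyp.txt'), 'w') as f:
            for t1, t2 in zip(all_pred, all_true):
                f.write(t1 + ',' + t2 + '\n')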
Example 3
    def exec(self):
        ''' Training End-to-end ASR system'''
        self.verbose('Training set total ' + str(len(self.train_set)) +
                     ' batches.')

        while self.step < self.max_step:
            for x, y in self.train_set:
                self.progress('Training step - ' + str(self.step))

                # Perform teacher forcing rate decaying
                tf_rate = self.tf_start - self.step * (
                    self.tf_start - self.tf_end) / self.max_step

                # Hack bucket, record state length for each uttr, get longest label seq for decode step
                assert len(
                    x.shape
                ) == 4, 'Bucketing should cause acoustic features to have shape 1xBxTxD'
                assert len(
                    y.shape
                ) == 3, 'Bucketing should cause labels to have shape 1xBxT'
                x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
                y = y.squeeze(0).to(device=self.device, dtype=torch.long)
                state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0,
                                   axis=-1)
                state_len = [int(sl) for sl in state_len]
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # ASR forwarding
                self.asr_opt.zero_grad()
                ctc_pred, state_len, att_pred, _ = self.asr_model(
                    x,
                    ans_len,
                    tf_rate=tf_rate,
                    teacher=y,
                    state_len=state_len)

                # Calculate loss function
                loss_log = {}
                label = y[:, 1:ans_len + 1].contiguous()
                ctc_loss = 0
                att_loss = 0

                # CE loss on attention decoder
                if self.ctc_weight < 1:
                    b, t, c = att_pred.shape
                    att_loss = self.seq_loss(att_pred.view(b * t, c),
                                             label.view(-1))
                    att_loss = torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(y!=0,dim=-1)\
                               .to(device = self.device,dtype=torch.float32) # Sum each uttr and divide by length
                    att_loss = torch.mean(att_loss)  # Mean by batch
                    loss_log['train_att'] = att_loss

                # CTC loss on CTC decoder
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                        torch.LongTensor(state_len), target_len)
                    loss_log['train_ctc'] = ctc_loss

                asr_loss = (1 - self.ctc_weight
                            ) * att_loss + self.ctc_weight * ctc_loss
                loss_log['train_full'] = asr_loss

                # Adversarial loss from CLM
                if self.apply_clm and att_pred.shape[1] >= CLM_MIN_SEQ_LEN:
                    if (self.step % self.clm.update_freq) == 0:
                        # update CLM once in a while
                        clm_log, gp = self.clm.train(att_pred.detach(),
                                                     CLM_MIN_SEQ_LEN)
                        self.write_log('clm_score', clm_log)
                        self.write_log('clm_gp', gp)
                    adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                    asr_loss -= adv_feedback

                # Backprop
                asr_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    self.verbose('Error : grad norm is NaN @ step ' +
                                 str(self.step))
                else:
                    self.asr_opt.step()

                # Logger
                self.write_log('loss', loss_log)
                if self.ctc_weight < 1:
                    self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {
                        'train':
                        cal_cer(att_pred, label, mapper=self.mapper)
                    })

                # Validation
                if self.step % self.valid_step == 0:
                    self.asr_opt.zero_grad()
                    self.valid()

                self.step += 1
                if self.step > self.max_step: break
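
The teacher forcing rate above decays linearly from tf_start at step 0 to tf_end at max_step. A standalone sketch of that schedule (the function name is ours, not from the code):

def tf_rate_at(step, tf_start, tf_end, max_step):
    '''Linear teacher forcing decay, as computed inline in exec().'''
    return tf_start - step * (tf_start - tf_end) / max_step

# The schedule spans exactly [tf_start, tf_end]:
assert tf_rate_at(0, 1.0, 0.5, 10000) == 1.0
assert tf_rate_at(10000, 1.0, 0.5, 10000) == 0.5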
Example 4
    def exec(self):
        ''' Training End-to-end ASR system'''
        self.verbose('Training set total '+str(len(self.train_set))+' batches.')
        progress_bar = tqdm(range(self.max_step))
        while self.step < self.max_step:
            for x,y,state_len in self.train_set:
                progress_bar.set_description("[training {}]".format(self.step))
                # Perform teacher forcing rate decaying
                tf_rate = self.tf_start - self.step*(self.tf_start-self.tf_end)/self.max_step

                # Hack bucket, record state length for each uttr, get longest label seq for decode step
                assert len(x.shape) == 4,'Bucketing should cause acoustic features to have shape 1xBxTxD'
                assert len(y.shape) == 3,'Bucketing should cause labels to have shape 1xBxT'
                x = x.squeeze(0).to(device = self.device,dtype=torch.float32)
                y = y.squeeze(0).to(device = self.device,dtype=torch.long)
                state_len = state_len.squeeze(0).to(device = self.device,dtype=torch.long)
                # state_len = np.sum(np.sum(x.cpu().data.numpy(),axis=-1) != 0,axis=-1)
                # state_len = [int(sl) for sl in state_len]
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # ASR forwarding
                self.asr_opt.zero_grad()
                ctc_pred, state_len, att_pred, _ = self.asr_model(
                    x, ans_len,
                    tf_rate=tf_rate,
                    teacher=y,
                    state_len=state_len
                )

                # print()
                # print("inp", x.shape[1] // self.asr_model.encoder.downsample_rate)
                # print("ctc", ctc_pred.shape[1] if ctc_pred is not None else 0)
                # print("att", att_pred.shape[1] if att_pred is not None else 0)
                assert ctc_pred is not None, "for saber, must use ctc_pred!"
                assert x.shape[1] // self.asr_model.encoder.downsample_rate == ctc_pred.shape[1]

                # Calculate loss function
                loss_log = {}
                label = y[:,1:ans_len+1].contiguous()
                ctc_loss = 0
                att_loss = 0

                # CE loss on attention decoder
                if self.ctc_weight < 1:
                    b,t,c = att_pred.shape
                    att_loss = self.seq_loss(att_pred.view(b*t,c),label.view(-1))
                    att_loss = torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(y != 0, dim=-1)\
                                    .to(device = self.device,dtype=torch.float32)  # Sum each uttr and divide by length
                    att_loss = torch.mean(att_loss)  # Mean by batch
                    loss_log['train_att'] = att_loss

                # CTC loss on CTC decoder
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0,1), dim=-1), label,
                        torch.LongTensor(state_len), target_len
                    )
                    loss_log['train_ctc'] = ctc_loss

                asr_loss = (1-self.ctc_weight)*att_loss+self.ctc_weight*ctc_loss
                loss_log['train_full'] = asr_loss

                # Adversarial loss from CLM
                if (
                    self.apply_clm and
                    (att_pred is not None) and
                    (att_pred.shape[1] >= CLM_MIN_SEQ_LEN)
                ):
                    if (self.step % self.clm.update_freq) == 0:
                        # update CLM once in a while
                        clm_log,gp = self.clm.train(att_pred.detach(),CLM_MIN_SEQ_LEN)
                        self.write_log('clm_score',clm_log)
                        self.write_log('clm_gp',gp)
                    adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                    asr_loss -= adv_feedback

                # Backprop
                asr_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    self.verbose('Error : grad norm is NaN @ step '+str(self.step))
                else:
                    self.asr_opt.step()

                # Logger
                self.write_log('loss',loss_log)
                if self.ctc_weight < 1:
                    self.write_log('acc', {'train':cal_acc(att_pred,label)})
                    if self.step % TRAIN_WER_STEP == 0:
                        self.write_log('error rate', {'train':cal_cer(att_pred,label,mapper=self.mapper)})
                else:
                    # only ctc
                    self.write_log('ctc acc', {'train':cal_acc(ctc_pred,label)})
                    if self.step % TRAIN_WER_STEP == 0:
                        self.write_log('ctc error rate', {'train':cal_cer(ctc_pred,label,mapper=self.mapper)})

                # visualize inputs
                if self.step % 1000 == 0:
                    example_inputs = x[0].cpu().detach().numpy().transpose(1, 0)
                    img = visualizer.plot(visualizer.plot_item(example_inputs, "inputs-feature"))
                    self.log.add_image("0-inputs", img, global_step=self.step, dataformats="HWC")

                # Validation
                if self.step % self.valid_step == 0:
                    self.asr_opt.zero_grad()
                    self.valid()

                self.step += 1
                progress_bar.update()
                if self.step > self.max_step: break
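
In both exec() variants the final objective interpolates the attention CE loss and the CTC loss with a single weight: asr_loss = (1 - ctc_weight) * att_loss + ctc_weight * ctc_loss, so ctc_weight = 0 trains the attention decoder alone, 1 trains pure CTC, and anything in between trains both branches jointly. A toy check of the interpolation (the numbers are made up):

def joint_loss(att_loss, ctc_loss, ctc_weight):
    '''Joint CTC-attention objective used by both exec() variants.'''
    return (1 - ctc_weight) * att_loss + ctc_weight * ctc_loss

assert joint_loss(2.0, 4.0, 0.0) == 2.0  # attention only
assert joint_loss(2.0, 4.0, 1.0) == 4.0  # CTC only
assert joint_loss(2.0, 4.0, 0.5) == 3.0  # equal mix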