Code example #1
    def set_model(self):
        ''' Setup ASR (and CLM if enabled)'''
        self.verbose('Init ASR model. Note: validation is done through greedy decoding w/ attention decoder.')

        # Build attention end-to-end ASR
        self.asr_model = Seq2Seq(self.sample_x, self.mapper.get_dim(), self.config['asr_model']).to(self.device)
        if 'VGG' in self.config['asr_model']['encoder']['enc_type']:
            self.verbose('VGG Extractor in Encoder is enabled, time subsample rate = 4.')
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='none').to(self.device)

        # Involve CTC
        self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean')
        self.ctc_weight = self.config['asr_model']['optimizer']['joint_ctc']

        # TODO: load pre-trained model
        if self.paras.load:
            raise NotImplementedError

        # Setup optimizer
        if self.apex and self.config['asr_model']['optimizer']['type'] == 'Adam':
            import apex
            self.asr_opt = apex.optimizers.FusedAdam(self.asr_model.parameters(), lr=self.config['asr_model']['optimizer']['learning_rate'])
        else:
            self.asr_opt = getattr(torch.optim, self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(), lr=self.config['asr_model']['optimizer']['learning_rate'], eps=1e-8)

        # Apply CLM
        if self.apply_clm:
            self.clm = CLM_wrapper(self.mapper.get_dim(), self.config['clm']).to(self.device)
            clm_data_config = self.config['solver']
            clm_data_config['train_set'] = self.config['clm']['source']
            clm_data_config['use_gpu'] = self.paras.gpu
            self.clm.load_text(clm_data_config)
            self.verbose('CLM is enabled with text-only source: '+str(clm_data_config['train_set']))
            self.verbose('Extra text set total '+str(len(self.clm.train_set))+' batches.')
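
For context, set_model reads everything from a nested self.config dict. A minimal sketch of the structure it expects, with key names inferred from the accesses above and all values hypothetical:

config = {
    'asr_model': {
        'encoder': {'enc_type': 'VGGBiRNN'},  # any enc_type containing 'VGG' enables the VGG extractor
        'optimizer': {
            'type': 'Adam',           # any optimizer class name under torch.optim
            'learning_rate': 1e-4,
            'joint_ctc': 0.5,         # CTC weight in [0, 1]; 0 = attention only, 1 = CTC only
        },
    },
    'clm': {'enable': False, 'source': 'text_only_corpus'},  # hypothetical values
    'solver': {'data_path': 'data/', 'dev_step': 2000, 'total_steps': 500000,
               'tf_start': 1.0, 'tf_end': 0.5, 'apex': False},
}
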
Code example #2
File: solver.py  Project: Eman22S/Amharic-Seq2Seq
class Trainer(Solver):
    ''' Handler for the complete training process'''
    def __init__(self, config, paras):
        super(Trainer, self).__init__(config, paras)
        # Logger Settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(self.logdir)
        self.valid_step = config['solver']['dev_step']
        self.best_val_ed = 2.0

        # Training details
        self.step = 0
        self.max_step = config['solver']['total_steps']
        self.tf_start = config['solver']['tf_start']
        self.tf_end = config['solver']['tf_end']
        self.apex = config['solver']['apex']

        # CLM option
        self.apply_clm = config['clm']['enable']

    def load_data(self):
        ''' Load data for training/validation'''
        self.verbose('Loading data from ' + self.config['solver']['data_path'])
        self.train_set = LoadDataset('train',
                                     text_only=False,
                                     use_gpu=self.paras.gpu,
                                     **self.config['solver'])
        self.dev_set = LoadDataset('dev',
                                   text_only=False,
                                   use_gpu=self.paras.gpu,
                                   **self.config['solver'])

        # Get one example for auto-constructing the model
        for self.sample_x, _ in self.train_set:
            break
        if len(self.sample_x.shape) == 4:
            self.sample_x = self.sample_x[0]

    def set_model(self):
        ''' Setup ASR (and CLM if enabled)'''
        self.verbose(
            'Init ASR model. Note: validation is done through greedy decoding w/ attention decoder.'
        )

        # Build attention end-to-end ASR

        self.asr_model = Seq2Seq(self.sample_x, self.mapper.get_dim(),
                                 self.config['asr_model']).to(self.device)
        if 'VGG' in self.config['asr_model']['encoder']['enc_type']:
            self.verbose(
                'VGG Extractor in Encoder is enabled, time subsample rate = 4.'
            )
        self.seq_loss = torch.nn.CrossEntropyLoss(
            ignore_index=0, reduction='none').to(self.device)

        # Involve CTC
        self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean')
        self.ctc_weight = self.config['asr_model']['optimizer']['joint_ctc']

        # TODO: load pre-trained model
        if self.paras.load:
            raise NotImplementedError

        # Setup optimizer
        if self.apex and self.config['asr_model']['optimizer'][
                'type'] == 'Adam':
            import apex
            self.asr_opt = apex.optimizers.FusedAdam(
                self.asr_model.parameters(),
                lr=self.config['asr_model']['optimizer']['learning_rate'])
        else:
            self.asr_opt = getattr(
                torch.optim, self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(
                self.asr_model.parameters(),
                lr=self.config['asr_model']['optimizer']['learning_rate'],
                eps=1e-8)

        # Apply CLM
        if self.apply_clm:
            self.clm = CLM_wrapper(self.mapper.get_dim(),
                                   self.config['clm']).to(self.device)
            clm_data_config = self.config['solver']
            clm_data_config['train_set'] = self.config['clm']['source']
            clm_data_config['use_gpu'] = self.paras.gpu
            self.clm.load_text(clm_data_config)
            self.verbose('CLM is enabled with text-only source: ' +
                         str(clm_data_config['train_set']))
            self.verbose('Extra text set total ' +
                         str(len(self.clm.train_set)) + ' batches.')

    def exec(self):
        ''' Train the end-to-end ASR system'''
        self.verbose('Training set total ' + str(len(self.train_set)) +
                     ' batches.')

        while self.step < self.max_step:
            for x, y in self.train_set:
                self.progress('Training step - ' + str(self.step))

                # Linearly decay the teacher forcing rate from tf_start to tf_end over max_step
                tf_rate = self.tf_start - self.step * (
                    self.tf_start - self.tf_end) / self.max_step

                # Unpack the bucketed batch: record the input length of each utterance and the longest label length for the decode steps
                assert len(x.shape) == 4, 'Bucketing should give acoustic features shape 1xBxTxD'
                assert len(y.shape) == 3, 'Bucketing should give labels shape 1xBxT'
                x = x.squeeze(0).to(device=self.device, dtype=torch.float32)
                y = y.squeeze(0).to(device=self.device, dtype=torch.long)
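                # Per-utterance input lengths: count frames whose feature vector is non-zero (inputs are zero-padded)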
                state_len = np.sum(np.sum(x.cpu().data.numpy(), axis=-1) != 0,
                                   axis=-1)
                state_len = [int(sl) for sl in state_len]
                ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

                # ASR forwarding
                self.asr_opt.zero_grad()
                ctc_pred, state_len, att_pred, _ = self.asr_model(
                    x,
                    ans_len,
                    tf_rate=tf_rate,
                    teacher=y,
                    state_len=state_len)

                # Calculate loss function
                loss_log = {}
                label = y[:, 1:ans_len + 1].contiguous()
                ctc_loss = 0
                att_loss = 0

                # CE loss on attention decoder
                if self.ctc_weight < 1:
                    b, t, c = att_pred.shape
                    att_loss = self.seq_loss(att_pred.view(b * t, c),
                                             label.view(-1))
                    att_loss = torch.sum(att_loss.view(b, t), dim=-1) \
                               / torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)  # Sum over each utterance, divide by its length
                    att_loss = torch.mean(att_loss)  # Mean by batch
                    loss_log['train_att'] = att_loss

                # CTC loss on CTC decoder
                if self.ctc_weight > 0:
                    target_len = torch.sum(y != 0, dim=-1)
                    ctc_loss = self.ctc_loss(
                        F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                        torch.LongTensor(state_len), target_len)
                    loss_log['train_ctc'] = ctc_loss

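                # Joint CTC-attention objective: interpolate the two losses with ctc_weight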
                asr_loss = (1 - self.ctc_weight) * att_loss \
                           + self.ctc_weight * ctc_loss
                loss_log['train_full'] = asr_loss

                # Adversarial loss from CLM
                if self.apply_clm and att_pred.shape[1] >= CLM_MIN_SEQ_LEN:
                    if (self.step % self.clm.update_freq) == 0:
                        # update CLM once in a while
                        clm_log, gp = self.clm.train(att_pred.detach(),
                                                     CLM_MIN_SEQ_LEN)
                        self.write_log('clm_score', clm_log)
                        self.write_log('clm_gp', gp)
                    adv_feedback = self.clm.compute_loss(F.softmax(att_pred, dim=-1))
                    asr_loss -= adv_feedback

                # Backprop
                asr_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    self.verbose('Error : grad norm is NaN @ step ' +
                                 str(self.step))
                else:
                    self.asr_opt.step()

                # Logger
                self.write_log('loss', loss_log)
                if self.ctc_weight < 1:
                    self.write_log('acc', {'train': cal_acc(att_pred, label)})
                if self.step % TRAIN_WER_STEP == 0:
                    self.write_log('error rate', {
                        'train':
                        cal_cer(att_pred, label, mapper=self.mapper)
                    })

                # Validation
                if self.step % self.valid_step == 0:
                    self.asr_opt.zero_grad()
                    self.valid()

                self.step += 1
                if self.step > self.max_step: break

    def write_log(self, val_name, val_dict):
        '''Write log to TensorBoard'''
        if 'att' in val_name:
            self.log.add_image(val_name, val_dict, self.step)
        elif 'txt' in val_name or 'hyp' in val_name:
            self.log.add_text(val_name, val_dict, self.step)
        else:
            self.log.add_scalars(val_name, val_dict, self.step)

    def valid(self):
        '''Perform validation step (!!!NOTE!!! greedy decoding with Attention decoder only)'''
        self.asr_model.eval()

        # Init stats
        val_loss, val_ctc, val_att, val_acc, val_cer = 0.0, 0.0, 0.0, 0.0, 0.0
        val_len = 0
        all_pred, all_true = [], []

        # Perform validation
        for cur_b, (x, y) in enumerate(self.dev_set):
            self.progress(' '.join([
                'Valid step -',
                str(self.step), '(',
                str(cur_b), '/',
                str(len(self.dev_set)), ')'
            ]))

            # Prepare data
            if len(x.shape) == 4: x = x.squeeze(0)
            if len(y.shape) == 3: y = y.squeeze(0)
            x = x.to(device=self.device, dtype=torch.float32)
            y = y.to(device=self.device, dtype=torch.long)
            state_len = torch.sum(torch.sum(x.cpu(), dim=-1) != 0, dim=-1)
            state_len = [int(sl) for sl in state_len]
            ans_len = int(torch.max(torch.sum(y != 0, dim=-1)))

            # Forward
            ctc_pred, state_len, att_pred, att_maps = self.asr_model(
                x, ans_len + VAL_STEP, state_len=state_len)

            # Compute attention loss & get decoding results
            label = y[:, 1:ans_len + 1].contiguous()
            if self.ctc_weight < 1:
                seq_loss = self.seq_loss(
                    att_pred[:, :ans_len, :].contiguous().view(
                        -1, att_pred.shape[-1]), label.view(-1))
                seq_loss = torch.sum(seq_loss.view(x.shape[0], -1), dim=-1) \
                           / torch.sum(y != 0, dim=-1).to(device=self.device, dtype=torch.float32)  # Sum over each utterance, divide by its length
                seq_loss = torch.mean(seq_loss)  # Mean by batch
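                # Weight by batch size so the dataset-level averages below are per utterance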
                val_att += seq_loss.detach() * int(x.shape[0])
                t1, t2 = cal_cer(att_pred,
                                 label,
                                 mapper=self.mapper,
                                 get_sentence=True)
                all_pred += t1
                all_true += t2
                val_acc += cal_acc(att_pred, label) * int(x.shape[0])
                val_cer += cal_cer(att_pred, label, mapper=self.mapper) * int(
                    x.shape[0])

            # Compute CTC loss
            if self.ctc_weight > 0:
                target_len = torch.sum(y != 0, dim=-1)
                ctc_loss = self.ctc_loss(
                    F.log_softmax(ctc_pred.transpose(0, 1), dim=-1), label,
                    torch.LongTensor(state_len), target_len)
                val_ctc += ctc_loss.detach() * int(x.shape[0])

            val_len += int(x.shape[0])

        # Logger
        val_loss = (1 - self.ctc_weight) * val_att + self.ctc_weight * val_ctc
        loss_log = {}
        for k, v in zip(['dev_full', 'dev_ctc', 'dev_att'],
                        [val_loss, val_ctc, val_att]):
            if v > 0.0: loss_log[k] = v / val_len
        self.write_log('loss', loss_log)

        if self.ctc_weight < 1:
            # Plot attention map to log
            val_hyp, val_txt = cal_cer(att_pred,
                                       label,
                                       mapper=self.mapper,
                                       get_sentence=True)
            val_attmap = draw_att(att_maps, att_pred)

            # Record loss
            self.write_log('error rate', {'dev': val_cer / val_len})
            self.write_log('acc', {'dev': val_acc / val_len})
            for idx, attmap in enumerate(val_attmap):
                self.write_log('att_' + str(idx), attmap)
                self.write_log('hyp_' + str(idx), val_hyp[idx])
                self.write_log('txt_' + str(idx), val_txt[idx])

            # Save model when the validation error rate improves
            if val_cer / val_len < self.best_val_ed:
                self.best_val_ed = val_cer / val_len
                self.verbose(
                    'Best val er       : {:.4f}       @ step {}'.format(
                        self.best_val_ed, self.step))
                torch.save(self.asr_model, os.path.join(self.ckpdir, 'asr'))
                if self.apply_clm:
                    torch.save(self.clm.clm, os.path.join(self.ckpdir, 'clm'))
                # Save hyps.
                with open(os.path.join(self.ckpdir, 'best_hyp.txt'), 'w') as f:
                    for t1, t2 in zip(all_pred, all_true):
                        f.write(t1 + ',' + t2 + '\n')

        self.asr_model.train()
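
A minimal, hypothetical driver for this Trainer: the real project presumably builds paras with argparse and loads config from a YAML file, so the flag names below are inferred from the attribute accesses above (paras.logdir, paras.gpu, paras.load) and are only a sketch.

import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument('--config', default='config/asr.yaml')  # hypothetical path
parser.add_argument('--logdir', default='log/')
parser.add_argument('--gpu', action='store_true')
parser.add_argument('--load', default=None)
paras = parser.parse_args()

with open(paras.config) as f:
    config = yaml.safe_load(f)

trainer = Trainer(config, paras)
trainer.load_data()  # build train_set / dev_set and sample one batch for model construction
trainer.set_model()  # build Seq2Seq, losses, optimizer (and the CLM if enabled)
trainer.exec()       # run the training loop until max_step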