Exemplo n.º 1
0
class Solver(BaseSolver):
    ''' Solver for self-supervised training of APC / VQ-APC / NPC models '''

    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        # Best dev loss seen so far; validate() stores a checkpoint
        # whenever a lower value is reached
        self.val_loss = 1000
        self.cur_epoch = 0

    def fetch_data(self, data):
        ''' Move data to device

        `data` is unpacked as (file_id, audio_feat, audio_len); only
        audio_feat is moved to GPU (when self.gpu is set), the other two
        items are returned untouched.
        '''
        file_id, audio_feat, audio_len = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
        return file_id, audio_feat, audio_len

    def load_data(self):
        ''' Load data for training/validation '''
        # The third returned item is discarded by this solver
        self.tr_set, self.dv_set, _, self.audio_dim, msg = \
            prepare_data(self.paras.njobs, self.paras.dev_njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup model, loss and optimizer

        Raises NotImplementedError when config['model']['method'] is not
        one of 'apc', 'vqapc' or 'npc'.
        '''
        # Model
        self.method = self.config['model']['method']
        if self.method in ['apc', 'vqapc']:
            self.n_future = self.config['model']['n_future']
            from model.apc import APC as Net
        elif self.method == 'npc':
            from model.npc import NPC as Net
        else:
            raise NotImplementedError(
                'Unknown method: {}'.format(self.method))
        self.model = Net(input_size=self.audio_dim,
                         **self.config['model']['paras'])
        if self.gpu:
            self.model = self.model.cuda()
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Loss
        if 'npc' in self.method:
            # Avoid reduction for NPC for zero-padding
            # (reduced manually over valid frames in _valid_frame_loss)
            self.loss = torch.nn.L1Loss(reduction='none')
        else:
            # APC family have zero-padding with torch API
            self.loss = torch.nn.L1Loss()
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo:  Data Parallel?
        # self.model = torch.nn.DataParallel(self.model)
        self.model.train()

    def _valid_frame_loss(self, loss, audio_len):
        ''' Reduce an unreduced loss over the valid part of each utterance

        For utterance i only the first audio_len[i] frames are used: the
        loss is averaged over the last dimension, summed over those frames
        and finally divided by the total number of valid frames.
        '''
        effective_loss = 0
        for idx, a_len in enumerate(audio_len):
            effective_loss += loss[idx, :a_len, :].mean(dim=-1).sum()
        return effective_loss / sum(audio_len)

    def exec(self):
        ''' Run the self-supervised training loop, validating every epoch '''
        self.verbose('Total training epoch {}.'.format(
            human_format(self.epoch)))
        self.timer.set()
        ep_len = len(self.tr_set)

        for ep in range(self.epoch):
            # Pre-step, decay
            if ep > 0:
                self.optimizer.decay()

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                _, audio_feat, audio_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward real data
                if 'npc' in self.method:
                    # NPC: input = target
                    pred, _ = self.model(audio_feat)
                    # Compute loss on valid part only
                    loss = self._valid_frame_loss(
                        self.loss(pred, audio_feat), audio_len)
                else:
                    # APC: input = shifted target
                    audio_len = [a_len - self.n_future for a_len in audio_len]
                    pred, _ = self.model(audio_feat[:, :-self.n_future, :],
                                         audio_len, testing=False)
                    loss = self.loss(pred, audio_feat[:, self.n_future:, :])
                self.timer.cnt('fw')
                # Backprop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                                  .format(100*float(self.step % ep_len)/ep_len,
                                          loss.cpu().item(),
                                          grad_norm,
                                          self.timer.show()))
                    self.write_log('loss', {'tr': loss})

                if (self.step == 1) or (self.step % self.PLOT_STEP == 0):
                    # Perplexity of P(token)
                    # NOTE(review): called for every method, including plain
                    # 'apc' -- confirm the model handles the non-VQ case
                    g1_ppx, g2_ppx = self.model.report_ppx()
                    self.write_log('ppx', {'group 1': g1_ppx,
                                           'group 2': g2_ppx})
                    g1_usg, g2_usg = self.model.report_usg()  # Empty cache
                    # Plots
                    if self.paras.draw:
                        g1_hist = draw(g1_usg, hist=True)
                        g2_hist = draw(g2_usg, hist=True)
                        self.write_log('VQ Group 1 Hist.', g1_hist)
                        self.write_log('VQ Group 2 Hist.', g2_hist)
                        # Some spectrograms
                        plt_idx = 0
                        self.write_log('Spectrogram (raw)',
                                       draw(audio_feat[plt_idx]))
                        self.write_log('Spectrogram (pred)',
                                       draw(pred[plt_idx]))

                # End of step
                self.timer.set()
            # End of epoch
            self.cur_epoch += 1
            self.validate()
        self.log.close()

    def validate(self):
        ''' Evaluate on the dev set and keep the checkpoint with lowest loss

        The reported dev loss is the plain average of per-batch losses.
        The model is switched back to training mode before returning.
        '''
        # Eval mode
        self.model.eval()
        dev_loss = []
        n_batch = len(self.dv_set)
        for step, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(step+1, n_batch))
            # Fetch data
            _, audio_feat, audio_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                if 'npc' in self.method:
                    pred, _ = self.model(audio_feat, testing=True)
                    # Compute loss on valid part only
                    loss = self._valid_frame_loss(
                        self.loss(pred, audio_feat), audio_len)
                else:
                    audio_len = [a_len - self.n_future for a_len in audio_len]
                    pred, _ = self.model(audio_feat[:, :-self.n_future, :],
                                         audio_len, testing=True)
                    loss = self.loss(pred, audio_feat[:, self.n_future:, :])
                dev_loss.append(loss.cpu().item())

        # Record metric
        if dev_loss:
            dev_loss = sum(dev_loss)/len(dev_loss)
            self.write_log('loss', {'dev': dev_loss})
            if dev_loss < self.val_loss:
                self.val_loss = dev_loss
                self.save_checkpoint('best_loss.pth', 'loss', dev_loss)
        else:
            # Nothing to average; avoid dividing by zero
            self.verbose('Empty dev set, validation skipped.')
        # Resume training
        self.model.train()
Exemplo n.º 2
0
class Solver(BaseSolver):
    ''' Solver for training a phone / speaker classifier on features from a
        pre-trained self-supervised (APC / VQ-APC / NPC) model

    Only the classifier parameters are handed to the optimizer; the feature
    extractor is run under torch.no_grad() in fetch_data.
    '''

    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        self.best_dev_er = 1.0
        self.cur_epoch = 0
        # CPU copy of the best classifier so far, stored by validate();
        # stays None until the dev error rate improves for the first time
        self.best_model = None
        # Configs following self-supervised learning
        self.task = self.paras.task
        assert self.task in ['phn-clf', 'spk-clf'], 'unsupported task'
        # Context manager closes the config file instead of leaking the handle
        # NOTE(review): yaml.load with FullLoader should only be used on
        # trusted config files -- consider yaml.safe_load if that is enough
        with open(self.config['model']['feat']['config'], 'r') as f:
            self.ssl_config = yaml.load(f, Loader=yaml.FullLoader)
        self.feature = self.ssl_config['model']['method']
        if self.feature == 'npc' and 'spec' in self.config['model']['feat']:
            # NPC has additional option to use unmasked feature
            self.feat_spec = self.config['model']['feat']['spec']
        else:
            self.feat_spec = None
        # Audio config is taken from the self-supervised model's config
        self.config['data']['audio'] = self.ssl_config['data']['audio']

    def fetch_data(self, data, train=True):
        ''' Move data to device and extract features

        `data` is unpacked as (file_id, audio_feat, audio_len, label).
        The returned audio_feat is the output of the feature extractor
        (mean-pooled over the valid frames of each utterance for
        'spk-clf'). `train` is currently unused and kept for interface
        compatibility.
        '''
        file_id, audio_feat, audio_len, label = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
            label = label.cuda()
        # Extract feature
        with torch.no_grad():
            if self.feat_spec is not None:
                # Get unmasked feature from particular NPC layer
                # (layer index is the part after the last '-' of the spec)
                n_layer_feat = int(self.feat_spec.split('-')[-1])
                audio_feat = self.feat_extractor.get_unmasked_feat(
                    audio_feat, n_layer_feat)
            elif self.feature == 'npc':
                # Get masked feature from NPC
                _, audio_feat = self.feat_extractor(audio_feat, testing=True)
            else:
                # Get feature from APC based model
                _, audio_feat = self.feat_extractor(audio_feat,
                                                    audio_len,
                                                    testing=True)
            # Mean pool feature for spkr classification
            if self.task == 'spk-clf':
                audio_feat = torch.stack(
                    [a_feat[:a_len].mean(dim=0)
                     for a_feat, a_len in zip(audio_feat, audio_len)],
                    dim=0)
        return file_id, audio_feat, audio_len, label

    def load_data(self):
        ''' Load data for training/validation/testing '''
        self.tr_set, self.dv_set, self.tt_set, self.audio_dim, msg = \
            prepare_data(self.paras.njobs, self.paras.dev_njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup feature extractor, classifier, loss and optimizer

        Raises NotImplementedError when the self-supervised method is not
        one of 'apc', 'vqapc' or 'npc'.
        '''
        # Load SSL models for feature extraction
        self.verbose([' Load feat. extractor ckpt from '
                      + self.config['model']['feat']['ckpt']])
        if self.feature in ['apc', 'vqapc']:
            from model.apc import APC as Net
        elif self.feature == 'npc':
            from model.npc import NPC as Net
            if self.feat_spec is not None:
                self.verbose([' Using specific feature: ' + self.feat_spec])
        else:
            raise NotImplementedError(
                'Unknown feature: {}'.format(self.feature))
        self.feat_extractor = Net(input_size=self.audio_dim,
                                  **self.ssl_config['model']['paras'])
        ckpt = torch.load(
            self.config['model']['feat']['ckpt'],
            map_location=self.device if self.mode == 'train' else 'cpu')
        # Drop the first 'module.' from every key before loading
        # (presumably left by a DataParallel wrapper -- TODO confirm)
        ckpt['model'] = {k.replace('module.', '', 1): v
                         for k, v in ckpt['model'].items()}
        self.feat_extractor.load_state_dict(ckpt['model'])
        # The extractor is only used for inference, so it must be in eval
        # mode on every device (previously this was done on GPU only)
        self.feat_extractor.eval()

        # Classifier model
        self.model = CLF(feat_dim=self.feat_extractor.code_dim,
                         **self.config['model']['clf'])
        if self.gpu:
            self.feat_extractor = self.feat_extractor.cuda()
            self.model = self.model.cuda()
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        ignore_idx = 0 if self.task == 'phn-clf' else -1
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        self.load_ckpt()
        self.model.train()

    def exec(self):
        ''' Train the classifier (in 'train' mode), then evaluate on test set

        In 'train' mode the dev set is evaluated after every epoch. The
        test set is always evaluated once at the end.
        '''
        if self.paras.mode == 'train':
            self.verbose('Total training epoch {}.'.format(
                human_format(self.epoch)))
            self.timer.set()
            ep_len = len(self.tr_set)
            for ep in range(self.epoch):
                if ep > 0:
                    # Lr decay if needed
                    self.optimizer.decay()
                for data in self.tr_set:
                    # Pre-step :  do zero_grad
                    self.optimizer.pre_step(self.step)

                    # Fetch data
                    self.timer.cnt('rd')
                    _, audio_feat, audio_len, label = self.fetch_data(data)

                    # Forward
                    pred = self.model(audio_feat)
                    if self.task == 'phn-clf':
                        pred = pred.permute(0, 2, 1)  # BxCxT for phn clf
                    loss = self.loss(pred, label)
                    self.timer.cnt('fw')

                    # Backprop
                    grad_norm = self.backward(loss)
                    self.step += 1

                    # Logger
                    if (self.step == 1) or (self.step % self.PROGRESS_STEP
                                            == 0):
                        self.progress(
                            ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                            .format(100 * float(self.step % ep_len) / ep_len,
                                    loss.cpu().item(), grad_norm,
                                    self.timer.show()))
                        self.write_log(self.task + '_loss', {'tr': loss})
                        if self.task == 'phn-clf':
                            tr_er = cal_per(pred, label, audio_len)[0]
                        else:
                            # Fraction of misclassified utterances in batch
                            tr_er = (pred.argmax(dim=-1) != label)
                            tr_er = tr_er.sum().detach().cpu().float() / len(
                                label)
                        self.write_log(self.task + '_er', {'tr': tr_er})
                    # End of step
                    self.timer.set()
                # End of epoch
                self.cur_epoch += 1
                self.validate()

        # Test at the end
        self.validate(test=True)
        self.log.close()

    def _write_test_pred(self, pred, label, audio_len):
        ''' Append "prediction<TAB>label" lines to <ckpdir>/<task>.csv

        For 'phn-clf' one line is written per valid frame of every
        utterance; for 'spk-clf' one line is written per utterance.
        '''
        if self.task == 'phn-clf':
            # pred is BxCxT here, so classes are on dim 1
            hyp = pred.argmax(dim=1).detach().cpu()
            ref = label.cpu()
            rows = []
            for p, l, a_len in zip(hyp, ref, audio_len):
                rows.extend(zip(p[:a_len].tolist(), l[:a_len].tolist()))
        else:
            # One label per utterance: slicing by audio length does not
            # apply (labels are 0-dim per item), write predicted class id
            hyp = pred.argmax(dim=-1).detach().cpu()
            rows = list(zip(hyp.tolist(), label.cpu().tolist()))
        with open(os.path.join(self.ckpdir, self.task + '.csv'), 'a') as f:
            for x, y in rows:
                f.write('{}\t{}\n'.format(x, y))

    def validate(self, test=False):
        ''' Evaluate on the dev set (default) or the test set (test=True)

        On the dev set, the model with the lowest error rate is saved as a
        checkpoint and a CPU copy is kept for the final test. The model is
        switched back to training mode before returning.
        '''
        val_loss = []
        split = 'dev'
        val_hit, val_total = 0.0, 0.0
        ds = self.tt_set if test else self.dv_set

        # In training mode, best model is stored in RAM for test
        # ToDo: load ckpt
        if test:
            split = 'test'
            if self.paras.mode == 'train':
                if self.best_model is not None:
                    self.model = self.best_model
                    if self.gpu:
                        self.model = self.model.cuda()
                else:
                    # Dev error never improved, nothing was stored
                    self.verbose(
                        'No best model stored, testing with current model.')

        # Eval mode
        self.model.eval()

        n_batch = len(ds)
        for i, data in enumerate(ds):
            self.progress('Valid step - {}/{}'.format(i + 1, n_batch))
            # Fetch data
            _, audio_feat, audio_len, label = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                # Prediction
                pred = self.model(audio_feat)
                if self.task == 'phn-clf':
                    pred = pred.permute(0, 2, 1)  # BxCxT
                # Accumulate batch result
                val_loss.append(self.loss(pred, label))
                if self.task == 'phn-clf':
                    _, hit, total = cal_per(pred, label, audio_len)
                    val_hit += hit
                    val_total += total
                else:
                    hit = (pred.argmax(dim=-1) == label).sum()
                    val_hit += hit.detach().cpu().float()
                    val_total += len(label)
                # Write testing prediction if needed
                if test and self.paras.write_test:
                    self._write_test_pred(pred, label, audio_len)

        # Record metric, store ckpt by dev error rate
        if val_loss:
            val_loss = sum(val_loss) / len(val_loss)
            val_er = 1.0 - val_hit / val_total
            self.write_log(self.task + '_loss', {split: val_loss})
            self.write_log(self.task + '_er', {split: val_er})
            if split == 'dev' and self.best_dev_er > val_er:
                self.best_dev_er = val_er
                self.save_checkpoint('best.pth', self.task + '_er', val_er)
                # Module.cpu() moves self.model itself to CPU before the
                # copy; it is moved back to GPU below
                self.best_model = copy.deepcopy(self.model.cpu())  # Clone for test
        else:
            # Nothing to average; avoid dividing by zero
            self.verbose('Empty {} set, evaluation skipped.'.format(split))

        # Resume training
        if self.gpu:
            self.model = self.model.cuda()
        self.model.train()