    def evaluate(self):
        """Evaluate on each per-accent dev loader and checkpoint on best WER/CER."""
        self.asr_model.eval()
        self.write_tr_logs()

        # One running-average tracker per pretraining accent; decay_rate=1.0
        # yields a plain count-weighted average over the whole dev pass.
        dev_info_ls = [
            RunningAvgDict(decay_rate=1.) for _ in range(self.num_pretrain)
        ]
        for idx, dev_loader in enumerate(self.data_container.dev_loaders):
            tbar = get_bar(
                total=len(dev_loader),
                desc=f"Eval on {self.accents[idx]} @ step {self.global_step}")
            with torch.no_grad():
                for cur_b, (x, ilens, ys, olens) in enumerate(dev_loader):

                    # Skip batches whose longest input exceeds the eval length cap.
                    if ilens.max() > self.dev_max_ilen:
                        tbar.update(1)
                        continue

                    batch_size = len(ys)
                    info = self._eval(idx, x, ilens, ys, olens)
                    dev_info_ls[idx].add(info, batch_size)

                    if cur_b % self.log_ival == 0:
                        logger.log_info(dev_info_ls[idx], prefix='test')

                    del x, ilens, ys, olens
                    tbar.update(1)

                logger.flush()
                tbar.close()

                self.dashboard.log_info(f"dev_{self.accents[idx]}",
                                        dev_info_ls[idx])
                self.write_dev_logs(f"dev_{self.accents[idx]}",
                                    dev_info_ls[idx])

        # Fold the per-accent averages into one equally weighted dev summary.
        dev_avg_info = RunningAvgDict(decay_rate=1.0)
        for dev_info in dev_info_ls:
            dev_avg_info.add({k: float(v) for k, v in dev_info.items()})

        self.dashboard.log_info("dev", dev_avg_info)
        self.write_dev_logs("dev_avg", dev_avg_info)
        cur_cer = float(dev_avg_info['cer'])
        cur_wer = float(dev_avg_info['wer'])
        if cur_wer < self.best_wer:
            self.best_wer = cur_wer
            self.save_best_model()
        if cur_cer < self.best_cer:
            self.best_cer = cur_cer
            self.save_best_model('cer', only_stat=True)

        self.asr_model.train()
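
Both snippets rely on repo-local helpers that are not shown here: get_bar (a tqdm-style progress bar), logger and dashboard (logging sinks), and RunningAvgDict. Below is a minimal sketch of the RunningAvgDict interface the snippets assume. The semantics, in particular that decay_rate=1.0 means a count-weighted running mean and anything smaller means an exponential moving average, are inferred from usage on this page, not taken from the original repository.

class RunningAvgDict:
    """Minimal sketch (inferred interface): a running average per metric key."""

    def __init__(self, decay_rate=1.0):
        self.decay_rate = decay_rate
        self._avg = {}    # metric name -> current average
        self._count = {}  # metric name -> total weight seen

    def add(self, info, count=1):
        # Fold a dict of per-batch metrics into the running averages.
        for k, v in info.items():
            v = float(v)
            if k not in self._avg:
                self._avg[k], self._count[k] = v, count
            elif self.decay_rate < 1.0:
                # Exponential moving average when decay_rate < 1.
                self._avg[k] = (self.decay_rate * self._avg[k]
                                + (1.0 - self.decay_rate) * v)
            else:
                # decay_rate == 1.0: plain count-weighted mean.
                total = self._count[k] + count
                self._avg[k] = (self._avg[k] * self._count[k] + v * count) / total
                self._count[k] = total

    def __getitem__(self, key):
        return self._avg[key]

    def items(self):
        return self._avg.items()

# Quick sanity check of the assumed semantics:
#   d = RunningAvgDict(decay_rate=1.0)
#   d.add({'wer': 0.5}, 2); d.add({'wer': 0.2}, 2)
#   d['wer']  # -> 0.35 (count-weighted mean)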
Example 2
    def evaluate(self):
        """Evaluate on the single dev set, checkpoint on best WER/CER, and step the LR scheduler."""
        self.asr_model.eval()

        dev_info = RunningAvgDict(decay_rate=1.)
        tbar = get_bar(total=len(self.dev_set),
                       desc=f"Eval @step{self.global_step}",
                       leave=True)

        with torch.no_grad():
            for cur_b, (x, ilens, ys, olens) in enumerate(self.dev_set):

                if ilens.max() > self.dev_max_ilen:
                    tbar.update(1)
                    continue

                batch_size = len(ys)
                info = self._eval(cur_b, x, ilens, ys, olens)
                dev_info.add(info, batch_size)

                if cur_b % self.log_ival == 0:
                    logger.log_info(dev_info, prefix='test')

                del x, ilens, ys, olens
                tbar.update(1)

            logger.flush()
            tbar.close()

            self.dashboard.log_info('dev', dev_info)
            self.write_logs(dev_info)

            cur_cer = float(dev_info['cer'])
            cur_wer = float(dev_info['wer'])
            if cur_wer < self.best_wer:
                self.best_wer = cur_wer
                self.save_best_model()

            if cur_cer < self.best_cer:
                self.best_cer = cur_cer
                self.save_best_model('cer', only_stat=True)

            if self.lr_scheduler is not None:
                # Plateau-style scheduler: step on the averaged dev loss.
                self.lr_scheduler.step(float(dev_info['loss']))

        self.asr_model.train()

    def check_evaluate(self):
        if self.global_step % self.eval_ival == 0:
            logger.flush()
            # Discard any partially accumulated gradients before evaluating.
            self.asr_opt.zero_grad()
            self.evaluate()
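
For context, here is a hypothetical training loop showing where check_evaluate would sit. The attribute train_loader and the helper _train_step are placeholder names for illustration only; they are not taken from the original trainer.

    def train(self):
        self.asr_model.train()
        for x, ilens, ys, olens in self.train_loader:   # hypothetical loader attribute
            self._train_step(x, ilens, ys, olens)       # hypothetical forward/backward/step helper
            self.global_step += 1
            self.check_evaluate()  # runs evaluate() every eval_ival steps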