def evaluate(self):
    """Evaluate the ASR model on every dev loader and track the best scores.

    Each loader in ``self.data_container.dev_loaders`` is run batch by batch
    through ``self._eval`` under ``torch.no_grad()``; the returned info is
    accumulated in one ``RunningAvgDict`` per loader.  Batches whose longest
    input exceeds ``self.dev_max_ilen`` are skipped.  Per-loader results are
    sent to the dashboard and to ``self.write_dev_logs`` under the name
    ``dev_<accent>``.  The per-loader results are then combined into a
    summary (logged as ``"dev"`` / ``"dev_avg"``) whose ``'wer'`` and
    ``'cer'`` entries are compared with ``self.best_wer`` and
    ``self.best_cer``; an improvement updates the stored best value and
    calls ``self.save_best_model``.

    The model is put back into training mode, and every progress bar is
    closed, even if evaluation raises, so a caller that recovers from the
    error does not keep training a model that is still in eval mode.
    """
    self.asr_model.eval()
    try:
        self.write_tr_logs()

        dev_info_ls = [
            RunningAvgDict(decay_rate=1.)
            for _ in range(self.num_pretrain)
        ]

        for idx, dev_loader in enumerate(self.data_container.dev_loaders):
            tbar = get_bar(
                total=len(dev_loader),
                desc=f"Eval on {self.accents[idx]} @ step {self.global_step}")
            try:
                with torch.no_grad():
                    for cur_b, (x, ilens, ys, olens) in enumerate(dev_loader):
                        # Skip batches containing an utterance longer than
                        # the configured limit, but still advance the bar.
                        if ilens.max() > self.dev_max_ilen:
                            tbar.update(1)
                            continue

                        batch_size = len(ys)
                        info = self._eval(idx, x, ilens, ys, olens)
                        dev_info_ls[idx].add(info, batch_size)

                        if cur_b % self.log_ival == 0:
                            logger.log_info(dev_info_ls[idx], prefix='test')

                        # Drop the batch references before the next one
                        # is fetched.
                        del x, ilens, ys, olens
                        tbar.update(1)

                logger.flush()
            finally:
                # Close the bar on success and on failure alike.
                tbar.close()

            self.dashboard.log_info(
                f"dev_{self.accents[idx]}", dev_info_ls[idx])
            self.write_dev_logs(
                f"dev_{self.accents[idx]}", dev_info_ls[idx])

        # Combine the per-loader results; each loader is added once with no
        # explicit weight (presumably an unweighted mean -- confirm against
        # RunningAvgDict.add).
        dev_avg_info = RunningAvgDict(decay_rate=1.0)
        for dev_info in dev_info_ls:
            dev_avg_info.add({k: float(v) for k, v in dev_info.items()})

        self.dashboard.log_info("dev", dev_avg_info)
        self.write_dev_logs("dev_avg", dev_avg_info)

        cur_cer = float(dev_avg_info['cer'])
        cur_wer = float(dev_avg_info['wer'])

        if cur_wer < self.best_wer:
            self.best_wer = cur_wer
            self.save_best_model()
        if cur_cer < self.best_cer:
            self.best_cer = cur_cer
            self.save_best_model('cer', only_stat=True)
    finally:
        # Always restore training mode, including on error.
        self.asr_model.train()
def evaluate(self):
    """Evaluate the ASR model on ``self.dev_set`` and track the best scores.

    Batches are run through ``self._eval`` under ``torch.no_grad()`` and
    accumulated in a ``RunningAvgDict``.  Batches whose longest input
    exceeds ``self.dev_max_ilen`` are skipped.  The accumulated result is
    sent to the dashboard as ``'dev'`` and to ``self.write_logs``.  Its
    ``'wer'`` and ``'cer'`` entries are compared with ``self.best_wer`` and
    ``self.best_cer``; an improvement updates the stored best value and
    calls ``self.save_best_model``.  When ``self.lr_scheduler`` is set, it
    is stepped with the accumulated ``'loss'``.

    The model is put back into training mode, and the progress bar is
    closed, even if evaluation raises, so a caller that recovers from the
    error does not keep training a model that is still in eval mode.
    """
    self.asr_model.eval()
    try:
        dev_info = RunningAvgDict(decay_rate=1.)
        tbar = get_bar(
            total=len(self.dev_set),
            desc=f"Eval @step{self.global_step}",
            leave=True)
        try:
            with torch.no_grad():
                for cur_b, (x, ilens, ys, olens) in enumerate(self.dev_set):
                    # Skip batches containing an utterance longer than the
                    # configured limit, but still advance the bar.
                    if ilens.max() > self.dev_max_ilen:
                        tbar.update(1)
                        continue

                    batch_size = len(ys)
                    info = self._eval(cur_b, x, ilens, ys, olens)
                    dev_info.add(info, batch_size)

                    if cur_b % self.log_ival == 0:
                        logger.log_info(dev_info, prefix='test')

                    # Drop the batch references before the next one is
                    # fetched.
                    del x, ilens, ys, olens
                    tbar.update(1)

            logger.flush()
        finally:
            # Close the bar on success and on failure alike.
            tbar.close()

        self.dashboard.log_info('dev', dev_info)
        self.write_logs(dev_info)

        cur_cer = float(dev_info['cer'])
        cur_wer = float(dev_info['wer'])

        if cur_wer < self.best_wer:
            self.best_wer = cur_wer
            self.save_best_model()
        if cur_cer < self.best_cer:
            self.best_cer = cur_cer
            self.save_best_model('cer', only_stat=True)

        # The scheduler is driven by the dev loss on every evaluation.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step(float(dev_info['loss']))
    finally:
        # Always restore training mode, including on error.
        self.asr_model.train()
def check_evaluate(self):
    """Run ``self.evaluate()`` when the step count hits the eval interval.

    Does nothing unless ``self.global_step`` is a multiple of
    ``self.eval_ival``.  Otherwise the logger is flushed and the optimizer
    gradients are zeroed before evaluation starts.
    """
    # Guard clause: not an evaluation step.
    if self.global_step % self.eval_ival != 0:
        return

    logger.flush()
    self.asr_opt.zero_grad()
    self.evaluate()