def eval(self, model, val_loader, image_crit, index, aster, aster_info):
    """Evaluate the super-resolution model on a validation loader with ASTER.

    Every batch of LR images is super-resolved by ``model``. PSNR and SSIM
    against the HR images are collected per batch. ASTER recognition on the
    SR output is compared with the ground-truth labels to compute word
    accuracy.

    Args:
        model: Super-resolution network, called as ``model(images_lr)``.
        val_loader: Iterable yielding ``(images_hr, images_lr, label_strs)``.
        image_crit: Image loss; only the value for the last batch is printed.
        index: Forwarded unchanged to ``self.tripple_display`` (presumably
            the current training iteration — TODO confirm).
        aster: ASTER recognizer, called with the dict from
            ``self.parse_aster_data``.
        aster_info: Forwarded as ``dataset=`` to ``get_str_list``.

    Returns:
        dict with ``'psnr'`` / ``'ssim'`` (per-batch values), ``'accuracy'``
        (fraction of images whose filtered prediction equals the filtered
        label, rounded to 4 places) and ``'psnr_avg'`` / ``'ssim_avg'``
        (mean of the per-batch values, rounded to 6 places).

    Raises:
        ValueError: if ``val_loader`` yields no images.

    Note:
        ``requires_grad`` is set to False on every parameter of ``model`` and
        ``aster``, and both are switched to eval mode. Neither is restored
        here, so the caller must re-enable training afterwards if needed.
    """
    # Freeze both networks; intentionally left frozen on return.
    for p in model.parameters():
        p.requires_grad = False
    for p in aster.parameters():
        p.requires_grad = False
    model.eval()
    aster.eval()

    n_correct = 0
    sum_images = 0
    metric_dict = {'psnr': [], 'ssim': [], 'accuracy': 0.0, 'psnr_avg': 0.0, 'ssim_avg': 0.0}

    with torch.no_grad():
        for images_hr, images_lr, label_strs in val_loader:
            val_batch_size = images_lr.shape[0]
            images_lr = images_lr.to(self.device)
            images_hr = images_hr.to(self.device)
            images_sr = model(images_lr)
            metric_dict['psnr'].append(self.cal_psnr(images_sr, images_hr))
            metric_dict['ssim'].append(self.cal_ssim(images_sr, images_hr))

            # Only the first three channels are handed to ASTER; the image
            # tensors may carry extra channels (assumption — TODO confirm).
            aster_dict_sr = self.parse_aster_data(images_sr[:, :3, :, :])
            aster_dict_lr = self.parse_aster_data(images_lr[:, :3, :, :])
            aster_output_lr = aster(aster_dict_lr)
            aster_output_sr = aster(aster_dict_sr)
            pred_rec_lr = aster_output_lr['output']['pred_rec']
            pred_rec_sr = aster_output_sr['output']['pred_rec']
            pred_str_lr, _ = get_str_list(pred_rec_lr, aster_dict_lr['rec_targets'], dataset=aster_info)
            pred_str_sr, _ = get_str_list(pred_rec_sr, aster_dict_sr['rec_targets'], dataset=aster_info)

            # Normalize prediction and label with the same filter so that
            # only a genuine recognition difference counts as an error.
            n_correct += sum(
                str_filt(pred, 'lower') == str_filt(target, 'lower')
                for pred, target in zip(pred_str_sr, label_strs)
            )

            # Overwritten every batch: only the last batch's losses are printed.
            loss_im = image_crit(images_sr, images_hr).mean()
            loss_rec = aster_output_sr['losses']['loss_rec'].mean()
            sum_images += val_batch_size
            torch.cuda.empty_cache()

    if sum_images == 0:
        raise ValueError('val_loader yielded no images; cannot compute metrics')

    # Unweighted mean over batches (a short final batch counts as much as a full one).
    psnr_avg = sum(metric_dict['psnr']) / len(metric_dict['psnr'])
    ssim_avg = sum(metric_dict['ssim']) / len(metric_dict['ssim'])
    print('[{}]\t'
          'loss_rec {:.3f}| loss_im {:.3f}\t'
          'PSNR {:.2f} | SSIM {:.4f}\t'
          .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                  float(loss_rec.item()), float(loss_im.item()),
                  float(psnr_avg), float(ssim_avg)))

    # Displays the last batch processed by the loop.
    print('save display images')
    self.tripple_display(images_lr, images_sr, images_hr, pred_str_lr, pred_str_sr, label_strs, index)

    accuracy = round(n_correct / sum_images, 4)
    print('aster_accuray: %.2f%%' % (accuracy * 100))
    metric_dict['accuracy'] = accuracy
    # float() accepts both one-element tensors and plain Python numbers.
    metric_dict['psnr_avg'] = round(float(psnr_avg), 6)
    metric_dict['ssim_avg'] = round(float(ssim_avg), 6)
    return metric_dict
def test(self):
    """Evaluate the SR model on the configured test set and print the results.

    The LR test images are super-resolved by the model from
    ``self.generator_init()``, and PSNR/SSIM against HR are collected per
    batch. The recognizer selected by ``self.args.rec`` ('moran', 'aster' or
    'crnn') reads the SR output. A prediction counts as correct when it
    equals the label after ``str_filt(..., 'lower')`` is applied to both.

    Prints a dict holding:
    - the word accuracy, keyed by dataset name;
    - the mean per-batch PSNR and SSIM;
    - the end-to-end throughput in images per second.

    Returns None.

    Raises:
        ValueError: if ``self.args.rec`` is not a supported recognizer, or
            the test loader yields no images.
    """
    model = self.generator_init()['model']
    _, test_loader = self.get_test_data(self.test_data_dir)
    # rstrip so a trailing '/' does not produce an empty dataset name.
    data_name = self.args.test_data_dir.rstrip('/').split('/')[-1]
    print('evaling %s' % data_name)

    if self.args.rec == 'moran':
        moran = self.MORAN_init()
        moran.eval()
    elif self.args.rec == 'aster':
        aster, aster_info = self.Aster_init()
        aster.eval()
    elif self.args.rec == 'crnn':
        crnn = self.CRNN_init()
        crnn.eval()
    else:
        # Fail early instead of with a NameError on pred_str_sr mid-loop.
        raise ValueError('unsupported recognizer: %r' % (self.args.rec,))

    if self.args.arch != 'bicubic':
        for p in model.parameters():
            p.requires_grad = False
        model.eval()

    n_correct = 0
    sum_images = 0
    psnr_list = []
    ssim_list = []
    time_begin = time.time()
    # The recognizer's parameters are not frozen here, so without no_grad
    # every forward pass would record an autograd graph that is never used.
    with torch.no_grad():
        for i, (images_hr, images_lr, label_strs) in enumerate(test_loader):
            val_batch_size = images_lr.shape[0]
            images_lr = images_lr.to(self.device)
            images_hr = images_hr.to(self.device)
            images_sr = model(images_lr)
            psnr_list.append(self.cal_psnr(images_sr, images_hr))
            ssim_list.append(self.cal_ssim(images_sr, images_hr))

            # Every recognizer reads only the first three channels of the SR output.
            if self.args.rec == 'moran':
                moran_input = self.parse_moran_data(images_sr[:, :3, :, :])
                moran_output = moran(moran_input[0], moran_input[1], moran_input[2], moran_input[3],
                                     test=True, debug=True)
                preds, _ = moran_output[0]
                _, preds = preds.max(1)
                sim_preds = self.converter_moran.decode(preds.data, moran_input[1].data)
                # Keep only the text before the first '$'.
                pred_str_sr = [pred.split('$')[0] for pred in sim_preds]
            elif self.args.rec == 'aster':
                aster_dict_sr = self.parse_aster_data(images_sr[:, :3, :, :])
                aster_output_sr = aster(aster_dict_sr)
                pred_rec_sr = aster_output_sr['output']['pred_rec']
                pred_str_sr, _ = get_str_list(pred_rec_sr, aster_dict_sr['rec_targets'], dataset=aster_info)
            elif self.args.rec == 'crnn':
                crnn_input = self.parse_crnn_data(images_sr[:, :3, :, :])
                crnn_output = crnn(crnn_input)
                _, preds = crnn_output.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                preds_size = torch.IntTensor([crnn_output.size(0)] * val_batch_size)
                pred_str_sr = self.converter_crnn.decode(preds.data, preds_size.data, raw=False)

            n_correct += sum(
                str_filt(pred, 'lower') == str_filt(target, 'lower')
                for pred, target in zip(pred_str_sr, label_strs)
            )
            sum_images += val_batch_size
            torch.cuda.empty_cache()
            print('Evaluation: [{}][{}/{}]\t'
                  .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                          i + 1, len(test_loader)))
    time_end = time.time()

    if sum_images == 0:
        raise ValueError('test loader yielded no images; cannot compute metrics')

    psnr_avg = sum(psnr_list) / len(psnr_list)
    ssim_avg = sum(ssim_list) / len(ssim_list)
    acc = round(n_correct / sum_images, 4)
    # Throughput of the whole loop (SR + recognition + logging), not SR alone.
    fps = sum_images / (time_end - time_begin)
    current_acc_dict = {data_name: float(acc)}
    result = {'accuracy': current_acc_dict,
              'psnr_avg': round(float(psnr_avg), 6),
              'ssim_avg': round(float(ssim_avg), 6),
              'fps': fps}
    print(result)