Code example #1
(votes: 0)
def evaluate_single_epoch(config, student_model, teacher_model, dataloader,
                          criterion, epoch, writer, visualizer, postfix_dict,
                          eval_type):
    """Run one evaluation pass of the student (and teacher, for visualization).

    Computes the validation loss and PSNR of the student model over
    ``dataloader``, logs scalars/figures to ``writer`` (TensorBoard-style
    API) and mirrors them into ``postfix_dict`` for the training progress
    bar.

    Args:
        config: namespace-style config; reads ``eval.batch_size``,
            ``data.rgb_range`` and ``data.scale``.
        student_model: model under evaluation; called as
            ``forward(LR=...)`` returning a dict with key ``'hr'``.
        teacher_model: teacher network, only run to feed ``visualizer``.
        dataloader: yields ``(LR_img, HR_img, filepath)`` batches.
        criterion: dict of losses; ``criterion['val']`` is used here.
        epoch: current epoch number, used as the logging global step.
        writer: summary writer or ``None`` to disable logging.
        visualizer: callable producing a matplotlib figure from the
            LR/HR images and both prediction dicts.
        postfix_dict: mutated in place with ``'{eval_type}/loss'`` and
            ``'{eval_type}/psnr'`` entries.
        eval_type: e.g. ``'val'`` or ``'test'``; ``'test'`` enables
            benchmark-style PSNR cropping and figure logging.

    Returns:
        Average PSNR over the epoch (0 if the dataloader is empty).
    """
    teacher_model.eval()
    student_model.eval()
    with torch.no_grad():
        batch_size = config.eval.batch_size
        total_size = len(dataloader.dataset)
        total_step = math.ceil(total_size / batch_size)

        tbar = tqdm.tqdm(enumerate(dataloader), total=total_step)

        total_psnr = 0
        total_loss = 0
        # Track the last batch index explicitly instead of reading the loop
        # variable after the loop: that raised NameError on an empty loader.
        total_iter = 0
        for i, (LR_img, HR_img, filepath) in tbar:
            total_iter = i
            # Only the first channel is evaluated — presumably the Y channel
            # of a YCbCr tensor; confirm against the dataset pipeline.
            HR_img = HR_img[:, :1].to(device)
            LR_img = LR_img[:, :1].to(device)

            student_pred_dict = student_model.forward(LR=LR_img)
            pred_hr = student_pred_dict['hr']
            total_loss += criterion['val'](pred_hr, HR_img).item()

            # PSNR is measured on the quantized (displayable) prediction.
            pred_hr = quantize(pred_hr, config.data.rgb_range)
            total_psnr += get_psnr(pred_hr,
                                   HR_img,
                                   config.data.scale,
                                   config.data.rgb_range,
                                   benchmark=eval_type == 'test')

            f_epoch = epoch + i / total_step
            desc = '{:5s}'.format(eval_type)
            desc += ', {:06d}/{:06d}, {:.2f} epoch'.format(
                i, total_step, f_epoch)
            tbar.set_description(desc)
            tbar.set_postfix(**postfix_dict)

            if writer is not None and eval_type == 'test':
                # The teacher forward is only needed to build the figure, so
                # run it lazily inside this branch rather than every batch.
                teacher_pred_dict = teacher_model.forward(LR=LR_img, HR=HR_img)
                fig = visualizer(LR_img, HR_img, student_pred_dict,
                                 teacher_pred_dict)
                writer.add_figure('{}/{:04d}'.format(eval_type, i),
                                  fig,
                                  global_step=epoch)

        log_dict = {}
        num_batches = total_iter + 1
        avg_loss = total_loss / num_batches
        avg_psnr = total_psnr / num_batches
        log_dict['loss'] = avg_loss
        log_dict['psnr'] = avg_psnr

        for key, value in log_dict.items():
            if writer is not None:
                writer.add_scalar('{}/{}'.format(eval_type, key), value, epoch)
            postfix_dict['{}/{}'.format(eval_type, key)] = value

        return avg_psnr
Code example #2
(votes: 0)
File: evaluate.py  Project: zhwzhong/PISR
def evaluate_single_epoch(config, student_model, dataloader_dict, eval_type):
    """Evaluate the student model on every dataset in ``dataloader_dict``.

    For each named dataloader, computes the average PSNR of the quantized
    student predictions against the HR ground truth and prints it.

    Args:
        config: namespace-style config; reads ``eval.batch_size``,
            ``data.rgb_range`` and ``data.scale``.
        student_model: model called as ``forward(LR=...)`` returning a
            dict with key ``'hr'``.
        dataloader_dict: mapping of dataset name -> dataloader yielding
            ``(LR_img, HR_img, filepath)`` batches.
        eval_type: label for the progress bar; ``'test'`` enables
            benchmark-style PSNR cropping.

    Returns:
        Dict mapping each dataset name to its average PSNR.
    """
    student_model.eval()
    log_dict = {}
    with torch.no_grad():
        for name, dataloader in dataloader_dict.items():
            print('evaluate %s'%(name))
            step_count = math.ceil(
                len(dataloader.dataset) / config.eval.batch_size)

            progress = tqdm.tqdm(enumerate(dataloader), total=step_count)

            psnr_sum = 0
            last_idx = 0
            for idx, (LR_img, HR_img, filepath) in progress:
                HR_img, LR_img = HR_img.to(device), LR_img.to(device)

                # PSNR is measured on the quantized student output.
                sr = quantize(student_model.forward(LR=LR_img)['hr'],
                              config.data.rgb_range)
                psnr_sum += get_psnr(sr, HR_img, config.data.scale,
                                     config.data.rgb_range,
                                     benchmark=eval_type=='test')

                label = '{:5s}'.format(eval_type)
                label += ', {:06d}/{:06d}, {:.2f} epoch'.format(
                    idx, step_count, idx / step_count)
                progress.set_description(label)
                last_idx = idx

            avg = psnr_sum / (last_idx+1)
            log_dict[name] = avg
            print('%s : %.3f'%(name, avg))

    return log_dict