def test(model, test_dataset, test_file_num, epoch, writer=None):
    """Evaluate ``model`` on the first ``test_file_num`` samples and report metrics.

    Each sample ``test_dataset[i]`` is indexed with the keys ``'O'`` (fed to
    the model) and ``'B'`` (the reference the clamped output is compared
    against). RMSE, PSNR and SSIM are printed per sample, and their averages
    are printed and optionally logged.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_dataset: Indexable dataset of dict-like samples.
        test_file_num: Number of leading samples to evaluate, or ``None`` to
            evaluate ``len(test_dataset)`` samples.
        epoch: Value used in the printed messages and as the logging step.
        writer: Optional object exposing ``add_scalar(tag, value, step)``.
    """
    model.eval()
    if test_file_num is None:
        test_file_num = len(test_dataset)
    if test_file_num <= 0:
        # Without samples the averages below would divide by zero.
        print('\nepoch {}, no test samples to evaluate\n'.format(epoch))
        return

    rmse_sum, psnr_sum, ssim_sum = 0, 0, 0
    with torch.no_grad():
        for i in range(test_file_num):
            data = test_dataset[i]
            # data['M'] is not needed by any metric here, so it is not
            # converted or copied to the GPU.
            net_input = torch.Tensor(data['O']).unsqueeze(0).cuda()
            target = torch.Tensor(data['B']).unsqueeze(0).cuda()

            # Metrics are computed on the output clamped to [0, 1].
            recon = torch.clamp(model.forward(net_input), 0., 1.)

            rmse = batch_RMSE_G(target, recon)
            psnr = batch_PSNR(target, recon)
            ssim = get_ssim(target, recon)
            print('epoch {}, test img {}, rmse {}, psnr {}, ssim {}'.format(
                epoch, i, rmse, psnr, ssim))

            rmse_sum += rmse
            psnr_sum += psnr
            ssim_sum += ssim

    rmse_avg = rmse_sum / test_file_num
    psnr_avg = psnr_sum / test_file_num
    ssim_avg = ssim_sum / test_file_num
    print('\nepoch {}, avg RMSE {}, avg PSNR {}, avg SSIM {}\n'.format(
        epoch, rmse_avg, psnr_avg, ssim_avg))

    if writer is not None:
        writer.add_scalar('test_RMSE', rmse_avg, epoch)
        writer.add_scalar('test_PSNR', psnr_avg, epoch)
        writer.add_scalar('test_SSIM', ssim_avg, epoch)
def generate_result(model, outf, test_dataset, mat=False, ouput_img=True):
    """Evaluate ``model`` on every sample and write the metrics to a CSV file.

    Each sample ``test_dataset[i]`` is indexed with the keys ``'O'`` (fed to
    the model) and ``'B'`` (the reference the clamped output is compared
    against). Per-sample RMSE/PSNR/SSIM are collected, a final row holding
    the column-wise average is appended, and the table is saved as
    ``<outf>/results.csv``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        outf: Output directory; created if it does not exist.
        test_dataset: Indexable, sized dataset of dict-like samples.
        mat: Accepted for interface compatibility; not used by this function.
        ouput_img: When true, ``generate_imgs`` is called for every sample and
            ``<outf>/output_imgs`` is created beforehand.
    """
    model.eval()
    test_file_num = len(test_dataset)

    # Make sure the CSV destination exists even when no images are written.
    # An empty ``outf`` means the current directory, which needs no creation.
    if outf:
        os.makedirs(outf, exist_ok=True)
    if ouput_img:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(os.path.join(outf, 'output_imgs'), exist_ok=True)

    results = []
    with torch.no_grad():
        for i in range(test_file_num):
            data = test_dataset[i]
            # data['M'] is not needed by any metric here, so it is not
            # converted or copied to the GPU.
            O = torch.Tensor(data['O']).unsqueeze(0).cuda()
            B = torch.Tensor(data['B']).unsqueeze(0).cuda()

            # Metrics and saved images use the output clamped to [0, 1].
            recon = torch.clamp(model.forward(O), 0., 1.)

            rmse = batch_RMSE_G(B, recon)
            psnr = batch_PSNR(B, recon)
            ssim = get_ssim(B, recon)
            print('test img {}, rmse {}, psnr {}, ssim {}'.format(
                i, rmse, psnr, ssim))
            results.append([rmse, psnr, ssim])

            if ouput_img:
                generate_imgs(O, B, recon, outf, i, psnr)

    results = np.array(results)
    result_avg = np.sum(results, axis=0) / test_file_num
    print(results.shape, result_avg.shape)
    print('average results : {}'.format(result_avg))

    # The last CSV row is the average over all samples.
    results = np.append(results, [result_avg], axis=0)
    excel = pd.DataFrame(data=results, columns=['rmse', 'psnr', 'ssim'])
    excel.to_csv(os.path.join(outf, 'results.csv'))
def test(model, test_dataset, epoch, writer):
    """Evaluate ``model`` on ``test_dataset`` and log the average PSNR.

    Every item yielded by ``test_dataset`` is unpacked as
    ``(real_hyper, _, real_rgb)``. ``real_rgb`` is fed to the model, whose
    forward pass is expected to return a pair; the first element is compared
    against ``real_hyper`` with ``batch_PSNR``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_dataset: Sized iterable of 3-tuples whose first and third
            elements support ``.cuda()``.
        epoch: Step under which the average is logged.
        writer: Object exposing ``add_scalar(tag, value, step)``.
    """
    test_image_num = len(test_dataset)
    model.eval()
    if test_image_num == 0:
        # Without samples the average below would divide by zero.
        print('total 0 test images, nothing to evaluate')
        return

    psnr_sum = 0
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every forward pass.
    with torch.no_grad():
        for i, data in enumerate(test_dataset):
            real_hyper, _, real_rgb = data
            real_hyper, real_rgb = real_hyper.cuda(), real_rgb.cuda()
            fake_hyper, _ = model.forward(real_rgb)
            psnr = batch_PSNR(real_hyper, fake_hyper).item()
            print('test img [{}/{}], psnr {}'.format(i, test_image_num, psnr))
            psnr_sum += psnr

    psnr_avg = psnr_sum / test_image_num
    print('total {} test images, avg psnr {}'.format(test_image_num, psnr_avg))
    writer.add_scalar('test_psnr', psnr_avg, epoch)
def train(model, criterion, optimizer, train_loader, epoch, writer):
    """Run one training epoch over ``train_loader``.

    Every batch is unpacked as ``(real_hyper, _, real_rgb)``. ``real_rgb`` is
    fed to the model, whose forward pass returns a pair; the first element is
    scored against ``real_hyper`` with ``criterion``. On every 10th batch the
    PSNR of that batch is printed and logged under ``'train_psnr'``.

    Args:
        model: Network to train; it is switched to train mode.
        criterion: Callable ``criterion(prediction, target)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        train_loader: Sized iterable of 3-tuples whose first and third
            elements support ``.cuda()``.
        epoch: Current epoch index, used in messages and the logging step.
        writer: Object exposing ``add_scalar(tag, value, step)``.
    """
    steps_per_epoch = len(train_loader)
    model.train()

    for step, batch in enumerate(train_loader):
        # Clear gradients through both the model and the optimizer.
        model.zero_grad()
        optimizer.zero_grad()

        target_hyper, _, rgb = batch
        target_hyper = target_hyper.cuda()
        rgb = rgb.cuda()

        predicted_hyper, _ = model.forward(rgb)
        loss = criterion(predicted_hyper, target_hyper)
        loss.backward()
        optimizer.step()

        # Only every 10th batch is reported.
        if step % 10 != 0:
            continue

        psnr_value = batch_PSNR(target_hyper, predicted_hyper.detach()).item()
        print("[epoch {}][{}/{}] psnr: {}".format(
            epoch, step, len(train_loader), psnr_value))
        writer.add_scalar('train_psnr', psnr_value,
                          steps_per_epoch * epoch + step)
def train_epoch(model, optimizer, train_loader, l1_criterion, mask_criterion,
                ssim_criterion, epoch, writer=None, radio=1):
    """Run one training epoch with the loss ``l1 - radio * ssim``.

    Every batch is indexed with the keys ``'O'``, ``'B'`` and ``'M'``; all
    three are moved to the GPU. ``'O'`` is fed to the model and the output,
    clamped to [0, 1], is scored against ``'B'``. On every 10th batch the
    loss and metrics are printed and, if ``writer`` is given, logged.

    Args:
        model: Network to train; it is switched to train mode.
        optimizer: Optimizer stepped once per batch.
        train_loader: Sized iterable of dict-like batches.
        l1_criterion: Callable ``(prediction, target)`` for the L1 term.
        mask_criterion: Accepted but not used by this function.
        ssim_criterion: Callable ``(prediction, target)`` for the SSIM term.
        epoch: Current epoch index, used in messages and the logging step.
        writer: Optional object exposing ``add_scalar(tag, value, step)``.
        radio: Weight of the SSIM term that is subtracted from the L1 loss.
    """
    model.train()
    batches_per_epoch = len(train_loader)

    for batch_idx, batch in enumerate(train_loader):
        # Clear gradients through both the model and the optimizer.
        model.zero_grad()
        optimizer.zero_grad()

        net_input = batch['O'].cuda()
        target = batch['B'].cuda()
        _mask = batch['M'].cuda()  # transferred like the others, not used below

        recon = torch.clamp(model.forward(net_input), 0., 1.)

        l1_loss = l1_criterion(recon, target)
        ssim_loss = ssim_criterion(recon, target)
        # The SSIM term is subtracted, so a larger SSIM lowers the loss.
        loss = l1_loss - radio * ssim_loss
        loss.backward()
        optimizer.step()

        # Only every 10th batch is reported.
        if batch_idx % 10 != 0:
            continue

        with torch.no_grad():
            rmse = batch_RMSE_G(target, recon)
            psnr = batch_PSNR(target, recon)

        print('epoch {}, [{}/{}], loss {}, PSNR {}, SSIM {}, RMSE {}, '.format(
            epoch, batch_idx, batches_per_epoch, loss, psnr,
            ssim_loss.item(), rmse))

        if writer is None:
            continue

        global_step = epoch * batches_per_epoch + batch_idx
        scalars = (
            ('loss', loss),
            ('l1_loss', l1_loss),
            ('ssim', ssim_loss),
            ('RMSE', rmse),
            ('PSNR', psnr),
        )
        for tag, value in scalars:
            writer.add_scalar(tag, value.item(), global_step)
def main(model_path, need_weisout=False):
    """Run a saved FMNet over the test set and store each prediction as a .mat file.

    The last component of ``model_path`` is split on ``'_'``; the fields at
    index 2 and 3 are parsed as ``bNum`` and ``nblocks`` for the FMNet
    constructor. Weights are loaded from ``<model_path>/model.pth``. For each
    test file the prediction is saved under ``<model_path>/result`` with the
    key ``'rad'``, and the PSNR against the reference is printed. The average
    PSNR and average forward time are printed at the end.

    Args:
        model_path: Directory holding ``model.pth``; its name encodes the
            network configuration as described above.
        need_weisout: When true and the model's forward pass returns a tuple,
            the second element is reduced to a NumPy array and saved next to
            the prediction under the key ``'weis_out'``.
    """
    result_dir = os.path.join(model_path, 'result')
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(os.path.join(model_path, 'mask'), exist_ok=True)

    print('load model path : {}'.format(model_path))
    model_name = 'model.pth'
    # normpath drops a trailing separator, which would otherwise leave an
    # empty last component and break the parsing below.
    name = os.path.basename(os.path.normpath(model_path)).split('_')
    bNum = int(name[2])
    nblocks = int(name[3])
    model = FMNet(bNum=bNum,
                  nblocks=nblocks,
                  input_features=31,
                  num_features=64,
                  out_features=31)
    print(bNum, nblocks)
    model.load_state_dict(
        torch.load(os.path.join(model_path, model_name), map_location='cpu'))
    print('model MNet has load !')
    model.eval()
    model.cuda()

    testDataset = HyperDataset(mode='test')
    num = len(testDataset)
    print('test img num : {}'.format(num))
    test_rgb_filename_list = get_testfile_list()
    print('test_rgb_filename_list len : {}'.format(
        len(test_rgb_filename_list)))

    psnr_sum = 0
    all_time = 0
    for i in range(num):
        file_name = test_rgb_filename_list[i].split('/')[-1]
        stem = file_name.split('.')[0]
        # The dataset key is the integer after the last '_' of the file stem.
        key = int(stem.split('_')[-1])
        print(file_name, key)

        real_hyper, _, real_rgb = testDataset.get_data_by_key(str(key))
        real_hyper = torch.unsqueeze(real_hyper, 0).cuda()
        real_rgb = torch.unsqueeze(real_rgb, 0).cuda()

        with torch.no_grad():
            # CUDA kernels run asynchronously; synchronizing before and after
            # makes the interval cover the actual forward computation.
            torch.cuda.synchronize()
            start = time.perf_counter()
            fake_hyper = model.forward(real_rgb)
            torch.cuda.synchronize()
            all_time += time.perf_counter() - start

        # weis_out stays None unless it was both produced and requested, so a
        # placeholder value is never written to the .mat file.
        weis_out = None
        if isinstance(fake_hyper, tuple):
            fake_hyper, raw_weis_out = fake_hyper
            if need_weisout:
                weis_out = np.squeeze(
                    raw_weis_out[0, :, 0, :, :].cpu().numpy())

        psnr = batch_PSNR(real_hyper, fake_hyper).item()
        print('test img [{}/{}], fake hyper shape : {}, psnr : {}'.format(
            i + 1, num, fake_hyper.shape, psnr))
        psnr_sum += psnr

        fake_hyper_mat = fake_hyper[0, :, :, :].cpu().numpy()
        mat_path = os.path.join(result_dir, stem + '.mat')
        if weis_out is None:
            scio.savemat(mat_path, {'rad': fake_hyper_mat})
            print('sucessfully save fake hyper !!!')
        else:
            scio.savemat(mat_path, {
                'rad': fake_hyper_mat,
                'weis_out': weis_out
            })
            print('sucessfully save fake hyper and weis_out !!!')
        print()

    if num == 0:
        # Without samples the averages below would divide by zero.
        print('no test images, nothing to average')
        return
    print('average psnr : {}'.format(psnr_sum / num))
    print('average test time : {}'.format(all_time / num))