def test_unet(root,
              psf_path,
              method,
              scale,
              model_path,
              visual,
              use_gpu,
              b_size=1):
    """Evaluate a trained UNet on the Poisson-noise variant of ``CellDataset``.

    Loads the checkpoint ``os.path.join(model_path, method + '_poisson')``,
    runs the model over the test set and prints the mean PSNR/SSIM of the
    network output and of the noisy input. Both are measured against the
    ground truth after every sample has been divided by its own
    ground-truth maximum and clamped to [0, 1].

    Args:
        root: Dataset root, forwarded to ``CellDataset``.
        psf_path: PSF path, forwarded to ``CellDataset``.
        method: Model-name prefix; the checkpoint file is
            ``method + '_poisson'``.
        scale: Poisson peak forwarded to ``CellDataset``. ``int(scale)`` is
            also used in the results directory, output file names and the
            printed summary.
        model_path: Directory holding the checkpoint. The checkpoint is
            expected to be a dict with a ``'model_state_dict'`` entry.
        visual: When 1, every network output is saved as an 8-bit PNG under
            ``./Results/<model_name>_peak_<int(scale)>/``.
        use_gpu: When 1, the model and data are moved to CUDA.
        b_size: DataLoader batch size.

    Returns:
        None; results are printed.
    """
    model_name = method + '_poisson'
    save_images_path = './Results/' + model_name + '_peak_' + str(int(scale)) + '/'

    test_dataset = CellDataset(root, psf_path, 'poisson', scale, 0.0)
    test_loader = DataLoader(test_dataset, batch_size=b_size, shuffle=False, num_workers=1)

    model = UNet(mode='batch')
    # Load onto the CPU first so that a checkpoint saved on any GPU can be
    # evaluated even on a CPU-only machine; the model is moved to CUDA below.
    checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if use_gpu == 1:
        model.cuda()
    model.eval()

    if visual == 1:
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []
    distorted_psnr_test = []
    distorted_ssim_test = []

    with torch.no_grad():
        for (gt, image), _psf, _index, image_name, _peak, _ in tqdm(test_loader):
            # -1 infers the real batch size: the last batch is smaller than
            # b_size whenever len(dataset) is not a multiple of it.
            image = image.reshape((-1, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((-1, 1, gt.shape[-2], gt.shape[-1]))
            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()

            # Normalise every sample (input and target alike) by the maximum
            # of its own ground truth.
            gt_max = gt.amax(dim=(1, 2, 3), keepdim=True)
            image = image / gt_max
            gt = gt / gt_max

            output = model(image)

            distorted_psnr = calc_psnr(image.clamp(0, 1), gt)
            distorted_ssim = ssim(image.clamp(0, 1), gt)
            psnr_test = calc_psnr(output.clamp(0, 1), gt)
            s_sim_test = ssim(output.clamp(0, 1), gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())
            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            if visual == 1:
                outputs = output.cpu().numpy().clip(0, 1)
                # Save every sample of the batch, not just the first one.
                # name[:-4] presumably strips a 3-letter extension such as
                # '.tif' -- kept as-is for file-name compatibility.
                for name, out in zip(image_name, outputs):
                    io.imsave(os.path.join(save_images_path, 'output_' + str(name[:-4]) + '_' +
                                           str(model_name) + '_' + str(int(scale)) + '.png'),
                              np.uint8(out[0] * 255.))

    if not psnr_values_test:
        print('No test samples were evaluated.')
        return

    # Report the configured peak rather than a value leaked from the loop.
    print('Test on Poisson noise with peak %d: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f'
          % (int(scale),
             np.mean(psnr_values_test),
             np.mean(ssim_values_test),
             np.mean(distorted_psnr_test),
             np.mean(distorted_ssim_test)))
    return
from pathlib import Path

import torch

from detection import TrainNet
from networks import UNet
from propagation import GuideCall

if __name__ == "__main__":
    # Run configuration.
    gpu = True         # move the network to CUDA when True
    gpu_id = 1         # CUDA device made current when gpu is True
    weight_path = "./weight/best.pth"

    # Guided inputs: every .tif under ./images/train/ori, in sorted order.
    train_path = Path("./images/train")
    guided_input_path = sorted(train_path.joinpath("ori").glob("*.tif"))
    # Output directory handed to GuideCall.
    output_path = Path("output")

    # Build the network and load its weights. Loading with
    # map_location="cpu" does not depend on which GPU index the checkpoint
    # was saved from; load_state_dict copies the values into the model's
    # parameters, which are then moved to the selected GPU.
    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load(weight_path, map_location="cpu"))
    if gpu:
        torch.cuda.set_device(gpu_id)
        net.cuda()

    bp = GuideCall(guided_input_path, output_path, net)
    bp.main()
def test_unet(root,
              psf_path,
              method,
              std,
              model_path,
              visual,
              use_gpu,
              b_size=1):
    """Evaluate a trained UNet on the Gaussian-noise variant of ``CellDataset``.

    Loads the checkpoint ``os.path.join(model_path, method + '_gaussian')``,
    runs the model over the test set and prints the mean PSNR/SSIM of the
    network output and of the noisy input. Before each metric, the
    prediction is clamped to the [min, max] range of the batch's ground
    truth.

    Args:
        root: Dataset root, forwarded to ``CellDataset``.
        psf_path: PSF path, forwarded to ``CellDataset``.
        method: Model-name prefix; the checkpoint file is
            ``method + '_gaussian'``.
        std: Noise standard deviation forwarded to ``CellDataset``. It is
            also used (with '.' removed) in the results directory name and
            printed in the summary.
        model_path: Directory holding the checkpoint. The checkpoint is
            expected to be a dict with a ``'model_state_dict'`` entry.
        visual: When 1, every network output is saved as an 8-bit PNG under
            ``./Results/<model_name>_std_<std without '.'>/``.
        use_gpu: When 1, the model and data are moved to CUDA.
        b_size: DataLoader batch size.

    Returns:
        None; results are printed.
    """
    model_name = method + '_gaussian'
    save_images_path = './Results/' + model_name + '_std_' + str(std).replace('.', '') + '/'

    test_dataset = CellDataset(root, psf_path, 'gaussian', 1.0, std)
    test_loader = DataLoader(test_dataset, batch_size=b_size, shuffle=False, num_workers=1)

    model = UNet(mode='batch')
    # Load onto the CPU first so that a checkpoint saved on any GPU can be
    # evaluated even on a CPU-only machine; the model is moved to CUDA below.
    checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    if use_gpu == 1:
        model.cuda()

    if visual == 1:
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []
    distorted_psnr_test = []
    distorted_ssim_test = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # batch_std is the per-sample std from the dataset. It is kept
        # separate from the `std` parameter, which the summary line reports.
        for (gt, image), _psf, _index, image_name, _, batch_std in tqdm(test_loader):
            # -1 infers the real batch size: the last batch is smaller than
            # b_size whenever len(dataset) is not a multiple of it.
            image = image.reshape((-1, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((-1, 1, gt.shape[-2], gt.shape[-1]))
            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()

            # Clamp range taken over the whole batch's ground truth.
            gt_min, gt_max = gt.min(), gt.max()

            distorted_psnr = calc_psnr(image.clamp(gt_min, gt_max), gt)
            distorted_ssim = ssim(image.clamp(gt_min, gt_max), gt)

            output = model(image)
            clamped_output = output.clamp(gt_min, gt_max)
            psnr_test = calc_psnr(clamped_output, gt)
            s_sim_test = ssim(clamped_output, gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())
            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            if visual == 1:
                outputs = output.cpu().numpy().clip(0, 1)
                # Save every sample of the batch, not just the first one.
                # name[:-4] presumably strips a 3-letter extension such as
                # '.tif' -- kept as-is for file-name compatibility.
                for name, sample_std, out in zip(image_name, batch_std, outputs):
                    io.imsave(os.path.join(save_images_path, 'output_' + str(name[:-4]) + '_' +
                                           str(model_name) + '_' + str(sample_std.item()).replace('.', '') + '.png'),
                              np.uint8(out[0] * 255.))

    if not psnr_values_test:
        print('No test samples were evaluated.')
        return

    print('Test on Gaussian noise with %.3f std: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f'
          % (float(std),
             np.mean(psnr_values_test),
             np.mean(ssim_values_test),
             np.mean(distorted_psnr_test),
             np.mean(distorted_ssim_test)))
    return