コード例 #1
0
def test_unet(root,
              psf_path,
              method,
              scale,
              model_path,
              visual,
              use_gpu,
              b_size=1):
    """Evaluate a trained UNet on the Poisson-noise test set.

    Loads the checkpoint ``<model_path>/<method>_poisson`` (a dict holding the
    weights under ``'model_state_dict'``), runs the model over a
    ``CellDataset`` built in ``'poisson'`` mode, and prints the mean PSNR and
    SSIM of both the network output and the distorted input, each measured
    against the ground truth.

    Args:
        root: Forwarded to ``CellDataset`` as its first argument.
        psf_path: Forwarded to ``CellDataset`` as its second argument.
        method: Model name prefix; ``'_poisson'`` is appended to form the
            checkpoint file name.
        scale: Poisson peak value. Forwarded to ``CellDataset`` and used, as
            ``int(scale)``, in the output names and the printed summary.
        model_path: Directory containing the checkpoint file.
        visual: When equal to ``1``, every network output is written as an
            8-bit PNG under ``./Results/<method>_poisson_peak_<scale>/``.
        use_gpu: When equal to ``1``, the model and the batches are moved to
            CUDA.
        b_size: Batch size used by the ``DataLoader``.

    Returns:
        None. The metrics are printed and, optionally, images are saved.
    """
    model_name = method + '_poisson'
    save_images_path = './Results/' + model_name + '_peak_' + str(
        int(scale)) + '/'

    test_dataset = CellDataset(root, psf_path, 'poisson', scale, 0.0)
    test_loader = DataLoader(test_dataset,
                             batch_size=b_size,
                             shuffle=False,
                             num_workers=1)

    model = UNet(mode='batch')
    # Deserialize onto the CPU so a checkpoint saved on a GPU still loads on a
    # CPU-only machine; the model is moved to CUDA below when requested.
    checkpoint = torch.load(os.path.join(model_path, model_name),
                            map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    if use_gpu == 1:
        model.cuda()

    model.eval()

    save_images = visual == 1
    if save_images:
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []

    distorted_psnr_test = []
    distorted_ssim_test = []

    with torch.no_grad():
        for (gt, image), _psf, _index, image_name, _peak, _ in tqdm(
                test_loader):
            # -1 rather than b_size: the last batch may hold fewer samples.
            image = image.reshape((-1, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((-1, 1, gt.shape[-2], gt.shape[-1]))

            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()

            # Scale every sample by the maximum of its own ground truth so
            # that the metrics below can clamp to the [0, 1] range.
            for sample_idx in range(gt.shape[0]):
                gt_max = gt[sample_idx].max()
                image[sample_idx] = image[sample_idx] / gt_max
                gt[sample_idx] /= gt_max

            output = model(image)

            distorted_psnr = calc_psnr(image.clamp(0, 1), gt)
            distorted_ssim = ssim(image.clamp(0, 1), gt)

            psnr_test = calc_psnr(output.clamp(0, 1), gt)
            s_sim_test = ssim(output.clamp(0, 1), gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())

            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            # Save every image of the batch, not only the first one.
            if save_images:
                for sample_idx, name in enumerate(image_name):
                    stem = os.path.splitext(str(name))[0]
                    file_name = ('output_' + stem + '_' + str(model_name) +
                                 '_' + str(int(scale)) + '.png')
                    pixels = output[sample_idx][0].detach().cpu().numpy().clip(
                        0, 1) * 255.
                    io.imsave(os.path.join(save_images_path, file_name),
                              np.uint8(pixels))

    # Report the requested peak (`scale`) instead of the value carried by the
    # last batch: that one is undefined for an empty loader and is not a
    # scalar when b_size > 1.
    print('Test on Poisson noise with peak %d: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f' % (
        int(scale),
        np.array(psnr_values_test).mean(),
        np.array(ssim_values_test).mean(),
        np.array(distorted_psnr_test).mean(),
        np.array(distorted_ssim_test).mean()))

    return
コード例 #2
0
from pathlib import Path

import torch

from detection import TrainNet
from networks import UNet
from propagation import GuideCall
if __name__ == "__main__":
    # One device for everything: the network and the loaded weights both end
    # up on GPU 1, whatever device the checkpoint was saved from.
    device = torch.device("cuda:1")
    torch.cuda.set_device(device)

    weight_path = "./weight/best.pth"
    # Input images: every .tif file under ./images/train/ori, in sorted order.
    train_path = Path("./images/train")
    guided_input_path = sorted(train_path.joinpath("ori").glob("*.tif"))

    # guided output
    output_path = Path("output")

    # define model
    net = UNet(n_channels=1, n_classes=1)
    net.to(device)

    # Map the stored tensors straight onto the model's device instead of
    # remapping only "cuda:2" -> "cuda:0", which left the weights on a
    # different GPU than the network and failed for other source devices.
    net.load_state_dict(torch.load(weight_path, map_location=device))

    bp = GuideCall(guided_input_path, output_path, net)
    bp.main()
コード例 #3
0
def test_unet(root,
              psf_path,
              method,
              std,
              model_path,
              visual,
              use_gpu,
              b_size=1):
    """Evaluate a trained UNet on the Gaussian-noise test set.

    Loads the checkpoint ``<model_path>/<method>_gaussian`` (a dict holding the
    weights under ``'model_state_dict'``), runs the model over a
    ``CellDataset`` built in ``'gaussian'`` mode, and prints the mean PSNR and
    SSIM of both the network output and the distorted input, each measured
    against the ground truth.

    Args:
        root: Forwarded to ``CellDataset`` as its first argument.
        psf_path: Forwarded to ``CellDataset`` as its second argument.
        method: Model name prefix; ``'_gaussian'`` is appended to form the
            checkpoint file name.
        std: Noise standard deviation. Forwarded to ``CellDataset`` and used
            in the output names and the printed summary.
        model_path: Directory containing the checkpoint file.
        visual: When equal to ``1``, every network output is written as an
            8-bit PNG under ``./Results/<method>_gaussian_std_<std>/``, where
            ``<std>`` is ``str(std)`` with the dots removed.
        use_gpu: When equal to ``1``, the model and the batches are moved to
            CUDA.
        b_size: Batch size used by the ``DataLoader``.

    Returns:
        None. The metrics are printed and, optionally, images are saved.
    """
    model_name = method + '_gaussian'
    # Same tag for the directory and the file names, so both stay consistent.
    std_tag = str(std).replace('.', '')
    save_images_path = './Results/' + model_name + '_std_' + std_tag + '/'

    test_dataset = CellDataset(root, psf_path, 'gaussian', 1.0, std)
    test_loader = DataLoader(test_dataset,
                             batch_size=b_size,
                             shuffle=False,
                             num_workers=1)

    model = UNet(mode='batch')

    # Deserialize onto the CPU so a checkpoint saved on a GPU still loads on a
    # CPU-only machine; the model is moved to CUDA below when requested.
    checkpoint = torch.load(os.path.join(model_path, model_name),
                            map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if use_gpu == 1:
        model.cuda()

    save_images = visual == 1
    if save_images:
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []

    distorted_psnr_test = []
    distorted_ssim_test = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # The per-batch std yielded by the loader is deliberately not bound to
        # `std`, so the function argument is not overwritten.
        for (gt, image), _psf, _index, image_name, _, _batch_std in tqdm(
                test_loader):
            # -1 rather than b_size: the last batch may hold fewer samples.
            image = image.reshape((-1, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((-1, 1, gt.shape[-2], gt.shape[-1]))

            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()

            # Metrics are computed after clamping to the ground-truth range
            # of the current batch.
            gt_min, gt_max = gt.min(), gt.max()

            distorted_psnr = calc_psnr(image.clamp(gt_min, gt_max), gt)
            distorted_ssim = ssim(image.clamp(gt_min, gt_max), gt)

            output = model(image)

            psnr_test = calc_psnr(output.clamp(gt_min, gt_max), gt)
            s_sim_test = ssim(output.clamp(gt_min, gt_max), gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())

            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            # Save every image of the batch, not only the first one.
            if save_images:
                for sample_idx, name in enumerate(image_name):
                    stem = os.path.splitext(str(name))[0]
                    file_name = ('output_' + stem + '_' + str(model_name) +
                                 '_' + std_tag + '.png')
                    pixels = output[sample_idx][0].detach().cpu().numpy().clip(
                        0, 1) * 255.
                    io.imsave(os.path.join(save_images_path, file_name),
                              np.uint8(pixels))

    print('Test on Gaussian noise with %.3f std: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f' % (
        float(std),
        np.array(psnr_values_test).mean(),
        np.array(ssim_values_test).mean(),
        np.array(distorted_psnr_test).mean(),
        np.array(distorted_ssim_test).mean()))

    return