Ejemplo n.º 1
0
def evaluate(model, args):
    """Evaluate ``model`` on a paired test set and save its predictions.

    Input images from ``args.test_input_dir`` and ground-truth images from
    ``args.test_gt_dir`` are paired by position after sorting both directory
    listings by file name. Each prediction is written to
    ``args.test_output_dir`` as ``<prefix>_output.png``, where ``<prefix>`` is
    the input file name up to its first underscore.

    Returns:
        Tuple ``(mean_ssim, mean_psnr)`` over all test images.

    Raises:
        ValueError: if the input directory is empty, or the ground-truth
            directory holds fewer files than the input directory (the
            positional pairing would otherwise run off the end of the list
            part-way through the evaluation).
    """
    input_list = sorted(os.listdir(args.test_input_dir))
    gt_list = sorted(os.listdir(args.test_gt_dir))
    num = len(input_list)
    if num == 0:
        raise ValueError('No test images found in %s' % args.test_input_dir)
    if len(gt_list) < num:
        raise ValueError('Found %d input images but only %d ground-truth '
                         'images' % (num, len(gt_list)))
    if len(gt_list) > num:
        # Pairing is purely positional, so surplus ground-truth files may
        # mean the two folders are misaligned; continue, but say so.
        print('Warning: %d ground-truth files for %d input images' %
              (len(gt_list), num))

    psnr_list = []
    ssim_list = []
    for input_name, gt_name in zip(input_list, gt_list):
        prefix = input_name.split('_')[0]
        print('Processing image: %s' % (input_name))
        img = cv2.imread(opj(args.test_input_dir, input_name))
        gt = cv2.imread(opj(args.test_gt_dir, gt_name))
        img = image_align(img)
        gt = image_align(gt)
        result = predict_single(model, img)
        result = np.array(result, dtype='uint8')
        cur_psnr = calc_psnr(result, gt)
        cur_ssim = calc_ssim(result, gt)
        print('PSNR is %.4f and SSIM is %.4f' % (cur_psnr, cur_ssim))
        psnr_list.append(cur_psnr)
        ssim_list.append(cur_ssim)
        out_name = prefix + "_" + "output.png"
        cv2.imwrite(opj(args.test_output_dir, out_name), result)
    print('In testing dataset, PSNR is %.4f and SSIM is %.4f' %
          (sum(psnr_list) / num, sum(ssim_list) / num))
    df = pd.DataFrame(np.array([psnr_list, ssim_list]).T,
                      columns=['psnr', 'ssim'])
    # ``status`` is defined elsewhere; it is applied column-wise to the scores.
    print(df.apply(status))
    return np.mean(ssim_list), np.mean(psnr_list)
Ejemplo n.º 2
0
def validate(test_list, arch, model, epoch, n_epochs, log_path='ConSSFCNN.txt'):
    """Score ``model`` on the held-out batch and append the metrics to a log.

    Args:
        test_list: ``(test_ref, test_lr, test_hr)``; ``test_ref`` is the
            reference the prediction is scored against, ``test_lr`` and
            ``test_hr`` are the two model inputs.
        arch: architecture name; selects which of the model's six outputs is
            scored. Both the 'SSRSpat'/'SSRSpec' spellings and the
            'SpatRNET'/'SpecRNET' names that ``main()`` passes to ``SSRNET``
            are accepted; every other name scores output 0.
        model: network called as ``model(lr, hr)`` returning six outputs.
        epoch: written as the first field of the log line.
        n_epochs: unused; kept for call-site compatibility.
        log_path: file the CSV line is appended to (defaults to the previously
            hard-coded 'ConSSFCNN.txt').

    Returns:
        The PSNR computed by ``calc_psnr``.
    """
    test_ref, test_lr, test_hr = test_list
    model.eval()

    with torch.no_grad():
        # Set mini-batch dataset
        ref = to_var(test_ref).detach()
        lr = to_var(test_lr).detach()
        hr = to_var(test_hr).detach()
        if arch in ('SSRSpat', 'SpatRNET'):
            _, out, _, _, _, _ = model(lr, hr)
        elif arch in ('SSRSpec', 'SpecRNET'):
            _, _, out, _, _, _ = model(lr, hr)
        else:
            # 'SSRNet' / 'SSRNET' and all other architectures use output 0.
            out, _, _, _, _, _ = model(lr, hr)

        ref = ref.detach().cpu().numpy()
        out = out.detach().cpu().numpy()

        rmse = calc_rmse(ref, out)
        psnr = calc_psnr(ref, out)
        ergas = calc_ergas(ref, out)
        sam = calc_sam(ref, out)

        with open(log_path, 'a') as f:
            # Same "epoch,rmse,psnr,ergas,sam," line format (trailing comma
            # included) as before, so existing log readers keep working.
            f.write(','.join(str(v) for v in (epoch, rmse, psnr, ergas, sam))
                    + ',\n')

    return psnr
Ejemplo n.º 3
0
def _first_image_to_uint8(images):
    """Return ``images[0]`` as a channels-last uint8 array.

    ``images`` is indexed as (N, C, H, W). Values are multiplied by 255
    (presumably they are in [0, 1] — confirm against ``process.process``) and
    clipped to [0, 255] before the cast, so out-of-range values saturate
    instead of wrapping around.
    """
    image = images[0, :, :, :].cpu().data.numpy().transpose((1, 2, 0))
    return np.array(np.clip(image * 255.0, 0.0, 255.0), dtype='uint8')


def validate(model, inputs, labels):
    """Denoise one RAW batch and score the first image in RGB space.

    Args:
        model: network called as ``model(noisy, variance)``; it is switched
            to eval mode here.
        inputs: mapping with 'noisy_img', 'variance', 'red_gain',
            'blue_gain' and 'cam2rgb' entries, each moved to the GPU as a
            FloatTensor.
        labels: ground-truth RAW data, converted the same way.

    Returns:
        ``(psnr, ssim)`` between the processed prediction and the processed
        ground truth of the first batch element.
    """
    model.eval()

    raw_image_in = Variable(torch.FloatTensor(inputs['noisy_img'])).cuda()
    raw_image_var = Variable(torch.FloatTensor(inputs['variance'])).cuda()
    raw_image_gt = Variable(torch.FloatTensor(labels)).cuda()
    red_gain = Variable(torch.FloatTensor(inputs['red_gain'])).cuda()
    blue_gain = Variable(torch.FloatTensor(inputs['blue_gain'])).cuda()
    cam2rgb = Variable(torch.FloatTensor(inputs['cam2rgb'])).cuda()

    with torch.no_grad():
        raw_image_out = model(raw_image_in, raw_image_var)

    # Process RAW images to RGB. Only the prediction and the ground truth are
    # scored, so the noisy input is not converted.
    rgb_image_out = _first_image_to_uint8(
        process.process(raw_image_out, red_gain, blue_gain, cam2rgb))
    rgb_image_gt = _first_image_to_uint8(
        process.process(raw_image_gt, red_gain, blue_gain, cam2rgb))

    cur_psnr = calc_psnr(rgb_image_out, rgb_image_gt)
    cur_ssim = calc_ssim(rgb_image_out, rgb_image_gt)

    return cur_psnr, cur_ssim
Ejemplo n.º 4
0
def predict(args):
    """Run the generator in 'demo' or 'test' mode.

    Weights are loaded from ``args.model_dir`` / ``args.g_weights``.

    * ``demo``: every image in ``args.input_dir`` is aligned, restored and
      written to ``args.output_dir`` as ``<stem>.jpg``.
    * ``test``: inputs are paired with ``args.gt_dir`` images by position in
      the sorted listings; per-image and mean PSNR/SSIM are printed and the
      per-image lists are pickled to ``../try/psnr_list`` and
      ``../try/ssim_list``.

    Any other mode prints 'Mode Invalid!'.

    Raises:
        ValueError: in test mode, if ``args.input_dir`` is empty or
            ``args.gt_dir`` holds fewer files than it.
    """
    model = Generator().cuda()
    model.load_state_dict(torch.load(opj(args.model_dir, args.g_weights)))

    if args.mode == 'demo':
        for input_name in sorted(os.listdir(args.input_dir)):
            print('Processing image: %s' % (input_name))
            img = cv2.imread(opj(args.input_dir, input_name))
            img = image_align(img)
            result = predict_single(model, img)
            # splitext strips only the final extension ("a.b.png" -> "a.b"),
            # so inputs that share a first dotted segment do not overwrite
            # each other's output.
            img_name = os.path.splitext(input_name)[0]
            cv2.imwrite(opj(args.output_dir, img_name + '.jpg'), result)

    elif args.mode == 'test':
        input_list = sorted(os.listdir(args.input_dir))
        gt_list = sorted(os.listdir(args.gt_dir))
        num = len(input_list)
        if num == 0:
            raise ValueError('No test images found in %s' % args.input_dir)
        if len(gt_list) < num:
            raise ValueError('Found %d input images but only %d ground-truth '
                             'images' % (num, len(gt_list)))
        if len(gt_list) > num:
            # Pairing is positional; surplus ground-truth files may mean the
            # folders are misaligned.
            print('Warning: %d ground-truth files for %d input images' %
                  (len(gt_list), num))

        psnr_list = []
        ssim_list = []
        for input_name, gt_name in zip(input_list, gt_list):
            print('Processing image: %s' % (input_name))
            img = cv2.imread(opj(args.input_dir, input_name))
            gt = cv2.imread(opj(args.gt_dir, gt_name))
            img = image_align(img)
            gt = image_align(gt)
            result = predict_single(model, img)
            result = np.array(result, dtype='uint8')
            cur_psnr = calc_psnr(result, gt)
            cur_ssim = calc_ssim(result, gt)
            print('PSNR is %.4f and SSIM is %.4f' % (cur_psnr, cur_ssim))
            psnr_list.append(cur_psnr)
            ssim_list.append(cur_ssim)
        print('In testing dataset, PSNR is %.4f and SSIM is %.4f' %
              (sum(psnr_list) / num, sum(ssim_list) / num))

        # Create the output folder up front so a missing directory does not
        # throw away a full evaluation run at the very end.
        os.makedirs('../try', exist_ok=True)
        with open('../try/psnr_list', 'wb') as fout:
            pkl.dump(psnr_list, fout)
        with open('../try/ssim_list', 'wb') as fout:
            pkl.dump(ssim_list, fout)

        df = pd.DataFrame(np.array([psnr_list, ssim_list]).T,
                          columns=['psnr', 'ssim'])
        # ``status`` is defined elsewhere; applied column-wise to the scores.
        print(df.apply(status))
    else:
        print('Mode Invalid!')
Ejemplo n.º 5
0
def split_result(args):
    """Bucket per-image test results into folders by SSIM.

    The generator (weights from ``args.model_dir`` / ``args.g_weights``) is
    run on each sorted ``args.input_dir`` image. The result is scored against
    the ``args.gt_dir`` image at the same sorted position, and input, ground
    truth and prediction are handed to ``write_interval`` for the matching
    folder under ``args.split_dir``:

    * SSIM < 0.5            -> interval_1
    * 0.5  <= SSIM < 0.82   -> interval_2
    * 0.82 <= SSIM < 0.87   -> interval_3
    * SSIM >= 0.87 (or NaN) -> interval_4

    Raises:
        ValueError: if ``args.input_dir`` is empty or ``args.gt_dir`` holds
            fewer files than it.
    """
    import bisect

    # Boundaries between consecutive intervals (see the table above).
    split_point = [0.5, 0.82, 0.87]

    input_list = sorted(os.listdir(args.input_dir))
    gt_list = sorted(os.listdir(args.gt_dir))
    num = len(input_list)
    if num == 0:
        raise ValueError('No test images found in %s' % args.input_dir)
    if len(gt_list) < num:
        raise ValueError('Found %d input images but only %d ground-truth '
                         'images' % (num, len(gt_list)))

    model = Generator().cuda()
    model.load_state_dict(torch.load(opj(args.model_dir, args.g_weights)))

    interval_dirs = [opj(args.split_dir, 'interval_%d' % (k + 1))
                     for k in range(len(split_point) + 1)]
    for interval_dir in interval_dirs:
        # makedirs also creates args.split_dir itself when it is missing,
        # which a plain os.mkdir could not.
        os.makedirs(interval_dir, exist_ok=True)

    cumulative_psnr = 0
    cumulative_ssim = 0
    for input_name, gt_name in zip(input_list, gt_list):
        print('Processing image: %s' % (input_name))
        img = cv2.imread(opj(args.input_dir, input_name))
        gt = cv2.imread(opj(args.gt_dir, gt_name))
        img = image_align(img)
        gt = image_align(gt)
        result = predict_single(model, img)
        result = np.array(result, dtype='uint8')
        cur_psnr = calc_psnr(result, gt)
        cur_ssim = calc_ssim(result, gt)
        print('PSNR is %.4f and SSIM is %.4f' % (cur_psnr, cur_ssim))
        cumulative_psnr += cur_psnr
        cumulative_ssim += cur_ssim

        prefix = input_name.split('_')[0]
        # bisect_right counts the boundaries <= cur_ssim, which reproduces
        # strict "<" comparisons against each boundary in turn: exactly 0.5
        # goes to interval_2, and NaN (False against every boundary) goes to
        # interval_4.
        bucket = bisect.bisect_right(split_point, cur_ssim)
        write_interval(interval_dirs[bucket], prefix, img, gt, result)

    print('In testing dataset, PSNR is %.4f and SSIM is %.4f' %
          (cumulative_psnr / num, cumulative_ssim / num))
Ejemplo n.º 6
0
def get_analysis(img_path, gt_path, dsize=(720, 480)):
    """Compute PSNR/SSIM of the module-level ``generator`` on one image pair.

    Both images are resized to ``dsize``, which follows OpenCV's
    ``(width, height)`` order. The input is converted with
    ``prepare_img_to_tensor`` and passed to ``generator`` together with the
    module-level ``times_in_attention`` and ``device``. The last element of
    the generator's output is taken as the prediction.

    Args:
        img_path: path of the degraded input image.
        gt_path: path of the ground-truth image.
        dsize: resize target; defaults to the previously hard-coded (720, 480).

    Returns:
        ``(psnr, ssim)`` between the uint8 prediction and the resized ground
        truth.

    Raises:
        OSError: if either image cannot be read (``cv2.imread`` returned None).
    """
    img = cv2.imread(img_path)
    gt = cv2.imread(gt_path)
    for path, image in ((img_path, img), (gt_path, gt)):
        if image is None:
            raise OSError('Could not read image: %s' % path)
    img = cv2.resize(img, dsize)
    gt = cv2.resize(gt, dsize)

    img_tensor = prepare_img_to_tensor(img)
    with torch.no_grad():
        out = generator(img_tensor, times_in_attention, device)[-1]
    # (N, C, H, W) -> (N, H, W, C), then keep the first image.
    out = out.cpu().data.numpy().transpose((0, 2, 3, 1))
    # Clip before the cast so values outside [0, 255] saturate rather than
    # wrap around in uint8.
    out = np.array(np.clip(out[0, :, :, :] * 255., 0, 255), dtype='uint8')
    cur_psnr = calc_psnr(out, gt)
    cur_ssim = calc_ssim(out, gt)
    return cur_psnr, cur_ssim
            img = cv2.imread(args.input_dir + input_list[i])
            img = align_to_four(img)
            result = predict(img)
            img_name = input_list[i].split('.')[0]
            cv2.imwrite(args.output_dir + img_name + '.jpg', result)

    elif args.mode == 'test':
        input_list = sorted(os.listdir(args.input_dir))
        gt_list = sorted(os.listdir(args.gt_dir))
        num = len(input_list)
        cumulative_psnr = 0
        cumulative_ssim = 0
        for i in range(num):
            print('Processing image: %s' % (input_list[i]))
            img = cv2.imread(args.input_dir + input_list[i])
            gt = cv2.imread(args.gt_dir + gt_list[i])
            img = align_to_four(img)
            gt = align_to_four(gt)
            result = predict(img)
            result = np.array(result, dtype='uint8')
            cur_psnr = calc_psnr(result, gt)
            cur_ssim = calc_ssim(result, gt)
            print('PSNR is %.4f and SSIM is %.4f' % (cur_psnr, cur_ssim))
            cumulative_psnr += cur_psnr
            cumulative_ssim += cur_ssim
        print('In testing dataset, PSNR is %.4f and SSIM is %.4f' %
              (cumulative_psnr / num, cumulative_ssim / num))

    else:
        print('Mode Invalid!')
Ejemplo n.º 8
0
# Number of spectral bands of each supported dataset.
_N_BANDS = {
    'PaviaU': 103,
    'Pavia': 102,
    'Botswana': 145,
    'KSC': 176,
    'Urban': 162,
    'IndianP': 200,
    'Washington': 191,
}

# (red, green, blue) band indices used to render false-colour previews.
_RGB_BANDS = {
    'Botswana': (47, 14, 3),
    'PaviaU': (66, 28, 0),
    'Pavia': (66, 28, 0),
    'KSC': (28, 14, 3),
    'Urban': (25, 10, 0),
    'Washington': (54, 34, 10),
    'IndianP': (28, 14, 3),
}


def _false_color(cube, red, green, blue):
    """Stack three bands of a (bands, H, W) array into an (H, W, 3) image.

    Channels are stacked blue, green, red (OpenCV's BGR order) and min-max
    scaled jointly to [0, 255]. A constant image maps to all zeros instead
    of dividing by zero.
    """
    bgr = np.stack((cube[blue], cube[green], cube[red]), axis=2)
    low, high = np.min(bgr), np.max(bgr)
    if high == low:
        return np.zeros_like(bgr)
    return 255 * (bgr - low) / (high - low)


def _diff_heatmap(image, reference):
    """Colour-map 1.5x the absolute difference of two BGR images.

    The scaled difference is clipped to [0, 255] before the uint8 cast, so
    large errors saturate instead of wrapping around to small values.
    """
    diff = np.uint8(np.clip(1.5 * np.abs(image - reference), 0, 255))
    gray = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    return cv2.applyColorMap(gray, cv2.COLORMAP_JET)


def main():
    """Evaluate a trained fusion model on the test split and save figures.

    Configuration comes from the module-level ``args`` namespace. The
    function prints model size, inference time and RMSE/PSNR/ERGAS/SAM, then
    writes false-colour previews and difference heatmaps to ``./figs``.

    Raises:
        ValueError: if ``args.arch`` is not a known architecture, or (after
            the metrics are printed) if no RGB band order is defined for
            ``args.dataset``.
    """
    # Unknown datasets keep whatever ``args.n_bands`` already holds.
    if args.dataset in _N_BANDS:
        args.n_bands = _N_BANDS[args.dataset]

    # Custom dataloader
    train_list, test_list = build_datasets(args.root, args.dataset,
                                           args.image_size,
                                           args.n_select_bands,
                                           args.scale_ratio)

    # Build the models. Each constructor is only looked up in its own branch.
    model_args = (args.scale_ratio, args.n_select_bands, args.n_bands)
    if args.arch == 'SSFCNN':
        model = SSFCNN(*model_args)
    elif args.arch == 'ConSSFCNN':
        model = ConSSFCNN(*model_args)
    elif args.arch == 'TFNet':
        model = TFNet(*model_args)
    elif args.arch == 'ResTFNet':
        model = ResTFNet(*model_args)
    elif args.arch == 'MSDCNN':
        model = MSDCNN(*model_args)
    elif args.arch in ('SSRNET', 'SpatRNET', 'SpecRNET'):
        model = SSRNET(args.arch, *model_args)
    elif args.arch == 'SpatCNN':
        model = SpatCNN(*model_args)
    elif args.arch == 'SpecCNN':
        model = SpecCNN(*model_args)
    else:
        raise ValueError('Unknown architecture: {}'.format(args.arch))

    # Load the trained model parameters
    model_path = args.model_path.replace('dataset', args.dataset) \
                                .replace('arch', args.arch)
    has_checkpoint = os.path.exists(model_path)
    if has_checkpoint:
        model.load_state_dict(torch.load(model_path), strict=False)
        print('Load the checkpoint of {}'.format(model_path))

    test_ref, test_lr, test_hr = test_list
    model.eval()

    # Set mini-batch dataset
    ref = test_ref.float().detach()
    lr = test_lr.float().detach()
    hr = test_hr.float().detach()

    begin_time = time()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if args.arch == 'SpatRNET':
            _, out, _, _, _, _ = model(lr, hr)
        elif args.arch == 'SpecRNET':
            _, _, out, _, _, _ = model(lr, hr)
        else:
            out, _, _, _, _, _ = model(lr, hr)
    end_time = time()
    run_time = (end_time - begin_time) * 1000

    if has_checkpoint:
        model_size = np.around(os.path.getsize(model_path) // 1024 / 1024.0,
                               decimals=2)
    else:
        # No checkpoint was loaded, so there is no file to measure.
        model_size = 'N/A'

    print()
    print()
    print('Dataset:   {}'.format(args.dataset))
    print('Arch:   {}'.format(args.arch))
    print('ModelSize(M):   {}'.format(model_size))
    print('Time(Ms):   {}'.format(np.around(run_time, decimals=2)))
    print()

    ref = ref.detach().cpu().numpy()
    out = out.detach().cpu().numpy()

    psnr = calc_psnr(ref, out)
    rmse = calc_rmse(ref, out)
    ergas = calc_ergas(ref, out)
    sam = calc_sam(ref, out)
    print('RMSE:   {:.4f};'.format(rmse))
    print('PSNR:   {:.4f};'.format(psnr))
    print('ERGAS:   {:.4f};'.format(ergas))
    print('SAM:   {:.4f}.'.format(sam))

    # bands order
    if args.dataset not in _RGB_BANDS:
        raise ValueError(
            'No RGB band order defined for dataset {}'.format(args.dataset))
    red, green, blue = _RGB_BANDS[args.dataset]

    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs('./figs', exist_ok=True)

    # ``out`` is squeezed and then indexed [band, :, :] below, so axes 2 and 3
    # are height and width; cv2.resize expects dsize as (width, height).
    out_height, out_width = out.shape[2], out.shape[3]

    lr = _false_color(np.squeeze(test_lr.detach().cpu().numpy()),
                      red, green, blue)
    lr = cv2.resize(lr, (out_width, out_height),
                    interpolation=cv2.INTER_NEAREST)
    cv2.imwrite('./figs/{}_lr.jpg'.format(args.dataset), lr)

    out = _false_color(np.squeeze(out), red, green, blue)
    cv2.imwrite('./figs/{}_{}_out.jpg'.format(args.dataset, args.arch), out)

    ref = _false_color(np.squeeze(ref), red, green, blue)
    cv2.imwrite('./figs/{}_ref.jpg'.format(args.dataset), ref)

    cv2.imwrite('./figs/{}_lr_dif.jpg'.format(args.dataset),
                _diff_heatmap(lr, ref))
    cv2.imwrite('./figs/{}_{}_out_dif.jpg'.format(args.dataset, args.arch),
                _diff_heatmap(out, ref))