def cumulative_ssim_plot(data_dict, path):
    """Plot cumulative SSIM scores between adversarials and their originals.

    ``data_dict['original'].tensors[0]`` supplies the original images. Every
    other entry's ``['images']`` supplies adversarial images, which are paired
    with the originals by position (``zip`` stops at the shorter sequence).
    For each entry, the empirical CDF of the SSIM scores is drawn as a step
    curve over 1000 equal-width bins on [0, 1]. Curves are labelled with the
    entry names. The figure is saved as ``'{path}cumulative_ssim.png'``;
    ``path`` is used as a verbatim prefix, so include a trailing separator.

    ``data_dict`` itself is not modified.
    """
    originals = data_dict['original'].tensors[0]
    # Filter instead of popping 'original' from the caller's dict. This keeps
    # the argument intact, so the function can be called again on the same data.
    adversarial_sets = {
        name: data for name, data in data_dict.items() if name != 'original'
    }
    for data in adversarial_sets.values():
        ssim_scores = [
            metrics.SSIM(original, adversarial)
            for original, adversarial in zip(originals, data['images'])
        ]
        # SSIM may be negative, and numpy.histogram silently ignores values
        # outside ``range``. Clipping folds such scores into the edge bins so
        # that every score is counted in the CDF.
        clipped = numpy.clip(ssim_scores, 0, 1)
        counts, bins = numpy.histogram(clipped, bins=1000, range=(0, 1))
        cdf = numpy.cumsum(counts) / numpy.sum(counts)
        # x: left/right edge of every bin (b0, b1, b1, b2, ..., b999, b1000);
        # y: each CDF value twice, giving a horizontal step across each bin.
        step_x = numpy.column_stack((bins[:-1], bins[1:])).ravel()
        step_y = numpy.repeat(cdf, 2) * 100
        plt.plot(step_x, step_y)
    plt.title('Cumulative SSIM Scores')
    plt.legend(list(adversarial_sets))
    plt.xlabel('SSIM', size=16)
    plt.ylabel('% below', size=16)
    filename = '{}cumulative_ssim.png'.format(path)
    plt.savefig(filename, transparent=False)
    plt.close()
# Example No. 2 - score: 0
# File: eval.py - Project: Timozen/srcc
# Number of leading entries in main()'s model lists that are Keras .h5 files.
# Any later entry ("Nearest") is upscaled without a network.
_NUM_KERAS_MODELS = 6


def _load_keras_model(model_index, model_path, custom_objects):
    """Clear the Keras session and load model ``model_index``.

    Returns ``None`` for entries at or beyond ``_NUM_KERAS_MODELS``. Returning
    ``None`` avoids keeping the previously loaded model alive after
    ``clear_session``.
    """
    keras.backend.clear_session()
    if model_index < _NUM_KERAS_MODELS:
        return load_model(model_path,
                          custom_objects=custom_objects[model_index])
    return None


def _super_resolve(model_index, model, lr):
    """Return the super-resolved version of ``lr`` for model ``model_index``.

    ``model_index`` is the position in ``main``'s model lists. It selects the
    pre- and post-processing:

    * 0: ``lr`` is fed to ``model`` unchanged. The output is clipped to
      [0, 255] and cast to ``uint8``.
    * 1: ``lr`` is divided by 255 first, and the output is multiplied by 255
      before the same clipping and cast.
    * 2 .. _NUM_KERAS_MODELS - 1: ``lr`` is passed through
      ``rescale_imgs_to_neg1_1`` and the output through ``Utils.denormalize``.
    * otherwise: 4x nearest-neighbour upscaling with OpenCV. ``model`` is
      ignored and may be ``None``.
    """
    if model_index < 2:
        if model_index == 1:
            lr = lr.astype(np.float64) / 255
        sr = np.squeeze(model.predict(np.expand_dims(lr, axis=0)))
        if model_index == 1:
            sr = sr * 255
        return np.clip(sr, 0, 255).astype(np.uint8)
    if model_index < _NUM_KERAS_MODELS:
        batch = np.expand_dims(rescale_imgs_to_neg1_1(lr), axis=0)
        return Utils.denormalize(np.squeeze(model.predict(batch), axis=0))
    return cv2.resize(lr, (0, 0),
                      fx=4,
                      fy=4,
                      interpolation=cv2.INTER_NEAREST)


def _score(hr, sr):
    """Return (MSE, PSNR, SSIM, MSSIM) of ``sr`` against reference ``hr``."""
    return (metrics.MSE(hr, sr), metrics.PSNR(hr, sr), metrics.SSIM(hr, sr),
            metrics.MSSIM(hr, sr))


def _mean_scores(scores):
    """Average a list of 4-tuples column-wise.

    For an empty list, ``np.mean`` of each empty column yields NaN.
    """
    columns = list(zip(*scores)) or [()] * 4
    return tuple(np.mean(column) for column in columns)


def _describe_scores(values):
    """Format a 4-tuple of scores for console output."""
    labels = ("MSE", "PSNR", "SSIM", "MSSIM")
    return ", ".join(
        label + " = " + str(value) for label, value in zip(labels, values))


def _score_line(name, values):
    """Format one report-file line: ``"<name> <mse> <psnr> <ssim> <mssim>"``."""
    return name + " " + " ".join(str(value) for value in values) + "\n"


def _write_report(title, filename, performance):
    """Print ``title`` and each model's scores, and write them to ``filename``."""
    print(title)
    with open(filename, "w") as f:
        for name, values in performance.items():
            print(name + ":   " + _describe_scores(values))
            f.write(_score_line(name, values))


def main():
    """Evaluate the SR models on the DSIDS test images and write reports.

    Each section enabled by the module-level flags ``TILES``, ``WHOLE_LR``
    and ``STITCHED`` computes mean MSE/PSNR/SSIM/MSSIM per model. It prints
    the results and writes them to a text file in the working directory:

    * ``TILES``: metrics per tile, averaged per image and then over images
      (``tile_performance.txt``).
    * ``WHOLE_LR``: the whole low-resolution image is upscaled at once
      (``whole_lr_performance.txt``). A model that fails on any image is
      reported as ``"err"``.
    * ``STITCHED``: tiles are upscaled and stitched back together, both
      without and with overlap (``stitch_performance.txt``, two lines per
      model).

    Raises:
        FileNotFoundError: if the test folder contains no ``.jpg`` images.
    """
    # paths to the models
    model_paths = [
        os.path.join("..", "models", "SRDense-Type-3_ep80.h5"),
        os.path.join("..", "models", "srdense-norm.h5"),
        os.path.join("..", "models", "srresnet85.h5"),
        os.path.join("..", "models", "gen_model90.h5"),
        os.path.join("..", "models", "srgan60.h5"),
        os.path.join("..", "models", "srgan-mse-20.h5"), "Nearest"
    ]

    # corresponding names of the models
    model_names = [
        "SRDense", "SRDense-norm", "SRResNet", "SRGAN-from-scratch",
        "SRGAN-percept.-loss", "SRGAN-mse", "NearestNeighbor"
    ]

    # corresponding (hr tile shape, lr tile shape) pairs
    tile_shapes = [((168, 168), (42, 42)), ((168, 168), (42, 42)),
                   ((504, 504), (126, 126)), ((336, 336), (84, 84)),
                   ((504, 504), (126, 126)), ((504, 504), (126, 126)),
                   ((336, 336), (84, 84))]

    # custom_objects passed to load_model; one independent dict per model
    custom_objects = [{}] + [{"tf": tf} for _ in range(5)] + [{}]

    # creating a list of test images: [(lr, hr)]
    DOWN_SCALING_FACTOR = 4
    INTERPOLATION = cv2.INTER_CUBIC

    test_images = []
    root = os.path.join("..", "DSIDS", "test")
    # sorted for a reproducible evaluation order
    for img in sorted(os.listdir(root)):
        if not img.endswith(".jpg"):
            continue
        bgr = cv2.imread(os.path.join(root, img), cv2.IMREAD_COLOR)
        hr = Utils.crop_into_lr_shape(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB),
                                      shape=(3024, 4032))
        lr = cv2.resize(hr, (0, 0),
                        fx=1 / DOWN_SCALING_FACTOR,
                        fy=1 / DOWN_SCALING_FACTOR,
                        interpolation=INTERPOLATION)
        test_images.append((lr, hr))
    if not test_images:
        raise FileNotFoundError("no .jpg test images found in " + root)

    if TILES:
        # First: performance metrics on single image tiles.
        tile_performance = {}
        for i, mp in tqdm(enumerate(model_paths)):
            model = _load_keras_model(i, mp, custom_objects)
            hr_shape, lr_shape = tile_shapes[i]

            per_image = []
            for lr_img, hr_img in tqdm(test_images):
                lr_tiles = Utils.tile_image(lr_img, shape=lr_shape)
                hr_tiles = Utils.tile_image(hr_img, shape=hr_shape)
                tile_scores = [
                    _score(hr, _super_resolve(i, model, lr))
                    for lr, hr in zip(lr_tiles, hr_tiles)
                ]
                # mean over the tiles of this image
                per_image.append(_mean_scores(tile_scores))

            # mean over all images for this model
            tile_performance[model_names[i]] = _mean_scores(per_image)

        _write_report("Performance on single tiles:", "tile_performance.txt",
                      tile_performance)

    if WHOLE_LR:
        # Second: performance metrics on a single upscaled whole image.
        img_performance = {}
        for i, mp in tqdm(enumerate(model_paths)):
            model = _load_keras_model(i, mp, custom_objects)
            predictor = None
            if model is not None:
                # re-wrap the model with an input layer sized to the lr image
                _in = Input(shape=test_images[0][0].shape)
                predictor = Model(_in, model(_in))

            scores = []
            failed = False
            for lr_img, hr_img in tqdm(test_images):
                try:
                    sr = _super_resolve(i, predictor, lr_img)
                    scores.append(_score(hr_img, sr))
                except Exception as err:
                    # Best effort: a whole-image failure (e.g. out of memory)
                    # marks the model as "err" but lets the other models run.
                    # The error is reported instead of silently swallowed.
                    print(model_names[i] + ": whole-image evaluation failed: " +
                          repr(err))
                    failed = True
                    break

            img_performance[model_names[i]] = (
                ("err", "err", "err", "err") if failed else
                _mean_scores(scores))

        _write_report("Performance on whole lr:", "whole_lr_performance.txt",
                      img_performance)

    if STITCHED:
        # Third: performance metrics on images stitched from upscaled tiles.
        stitch_performance = {}
        for i, mp in tqdm(enumerate(model_paths)):
            model = _load_keras_model(i, mp, custom_objects)
            lr_shape = tile_shapes[i][1]

            simple_scores = []
            overlap_scores = []
            for lr_img, hr_img in tqdm(test_images):
                lr_tiles = Utils.tile_image(lr_img, shape=lr_shape)
                lr_tiles_overlap = Utils.tile_image(lr_img,
                                                    shape=lr_shape,
                                                    overlap=True)
                sr_tiles = [_super_resolve(i, model, lr) for lr in lr_tiles]
                sr_tiles_overlap = [
                    _super_resolve(i, model, lr) for lr in lr_tiles_overlap
                ]

                hr_h, hr_w = hr_img.shape[0], hr_img.shape[1]
                tile_h, tile_w = sr_tiles[0].shape[0], sr_tiles[0].shape[1]
                sr_simple = ImageStitching.stitch_images(
                    sr_tiles, hr_w, hr_h, tile_w, tile_h, hr_w // tile_w,
                    hr_h // tile_h)
                sr_advanced = ImageStitching.stitching(sr_tiles_overlap,
                                                       LR=None,
                                                       image_size=(hr_h, hr_w),
                                                       adjustRGB=False,
                                                       overlap=True)

                simple_scores.append(_score(hr_img, sr_simple))
                overlap_scores.append(_score(hr_img, sr_advanced))

            stitch_performance[model_names[i]] = [
                _mean_scores(simple_scores),
                _mean_scores(overlap_scores)
            ]

        print("Performance on stitched images:")
        with open("stitch_performance.txt", "w") as f:
            for key, (simple, advanced) in stitch_performance.items():
                print("simple stitch:  " + key + ":   " +
                      _describe_scores(simple))
                print("advanced stitch:  " + key + ":   " +
                      _describe_scores(advanced))
                f.write(_score_line(key, simple))
                f.write(_score_line(key, advanced))
# Example No. 3 - score: 0
    if not os.path.exists(save_dirA) or not os.path.exists(save_dirB):
        os.makedirs(save_dirA)
        os.makedirs(save_dirB)
    ### Build model
    netG_A2C = eval(checkA[0].replace('@G2LAB', ''))(1, 1, int(checkA[2][1])).to(opt.device)
    netG_C2B = eval(checkB[0].replace('@G2LAB', ''))(1, 2).to(opt.device)
    # load check point
    netG_A2C.load_state_dict(torch.load(os.path.join(Check_DIR, os.path.basename(args.netGA))))
    netG_C2B.load_state_dict(torch.load(os.path.join(Check_DIR, os.path.basename(args.netGB))))
    netG_A2C.eval()
    netG_C2B.eval()
    print("Starting test Loop...")
    # setup data loader
    data_loader = DataLoader(testset, opt.batch_size, num_workers=opt.num_works,
                             shuffle=False, pin_memory=True, )
    evaluators = [metrics.MSE(), metrics.PSNR(), metrics.AE(), metrics.SSIM()]
    performs = [[] for i in range(len(evaluators))]
    for idx, sample in enumerate(data_loader):
        realA = sample['src'].to(opt.device)
#         realA -= 0.5
        realB = sample['tar'].to(opt.device)
#         realB -= 0.5
        # Y = 0.2125 R + 0.7154 G + 0.0721 B [RGB2Gray, 3=>1 ch]
        realBC = realB[:,:1,:,:]
        sf = int(checkA[2][1])
        realBA = nn.functional.interpolate(realBC, scale_factor=1. / sf, mode='bilinear')
        realBA = nn.functional.interpolate(realBA, scale_factor=sf, mode='bilinear')
#         realAA = nn.functional.interpolate(realA, scale_factor=1. / sf)
        realAA = realA
        fake_AC = netG_A2C(realAA)
        fake_AB = netG_C2B(fake_AC)
# Example No. 4 - score: 0
 def __init__(self):
     """Initialise the parent class and keep an SSIM metric as ``criterion``."""
     super().__init__()
     # The SSIM instance used by this loss; the forward pass is not visible here.
     self.criterion = metrics.SSIM()