Example #1
import argparse

import tensorflow as tf

import lpips_tf


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        choices=['net-lin', 'net'],
                        default='net-lin',
                        help='net-lin or net')
    parser.add_argument('--net',
                        choices=['squeeze', 'alex', 'vgg'],
                        default='alex',
                        help='squeeze, alex, or vgg')
    parser.add_argument('--version', type=str, default='0.1')
    args = parser.parse_args()

    # load_image is assumed to be provided by the surrounding module.
    ex_ref = load_image('./PerceptualSimilarity/imgs/ex_ref.png')
    ex_p0 = load_image('./PerceptualSimilarity/imgs/ex_p0.png')
    ex_p1 = load_image('./PerceptualSimilarity/imgs/ex_p1.png')

    session = tf.Session()

    image0_ph = tf.placeholder(tf.float32)
    image1_ph = tf.placeholder(tf.float32)
    lpips_fn = session.make_callable(
        lpips_tf.lpips(image0_ph,
                       image1_ph,
                       model=args.model,
                       net=args.net,
                       version=args.version), [image0_ph, image1_ph])

    ex_d0 = lpips_fn(ex_ref, ex_p0)
    ex_d1 = lpips_fn(ex_ref, ex_p1)

    print('Distances: (%.3f, %.3f)' % (ex_d0, ex_d1))
Example #2
        # `K` is the Keras backend; join_reim_mag_output joins the real and
        # imaginary network outputs into a single magnitude image.
        def image_lpips(y_true, y_pred):
            y_true = join_reim_mag_output(y_true)
            y_pred = join_reim_mag_output(y_pred)

            y_true = K.tile(y_true, [1, 1, 1, 3])
            y_pred = K.tile(y_pred, [1, 1, 1, 3])

            return lpips_tf.lpips(y_true, y_pred, model='net-lin', net='alex')
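A metric factory like this is typically handed to Keras at compile time. A minimal wiring sketch, assuming TF1-style Keras and single-channel NHWC outputs; the model here is illustrative, not from the original source:

import lpips_tf
from keras import backend as K
from keras.layers import Conv2D
from keras.models import Sequential


def image_lpips_rgb(y_true, y_pred):
    # Tile grayscale tensors to 3 channels, as in the example above.
    y_true = K.tile(y_true, [1, 1, 1, 3])
    y_pred = K.tile(y_pred, [1, 1, 1, 3])
    return lpips_tf.lpips(y_true, y_pred, model='net-lin', net='alex')


model = Sequential([Conv2D(1, 3, padding='same', input_shape=(64, 64, 1))])
model.compile(optimizer='adam', loss='mae', metrics=[image_lpips_rgb])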
Example #3
def lpips(input0, input1):
    if input0.shape[-1].value == 1:
        input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
    if input1.shape[-1].value == 1:
        input1 = tf.tile(input1, [1] * (input1.shape.ndims - 1) + [3])

    distance = lpips_tf.lpips(input0, input1)
    # Negated so that larger values mean more perceptually similar.
    return -distance
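Because the distance is negated, this variant acts as a similarity score to maximize (e.g. as a reward or objective term). A quick usage sketch reusing the function above, with random inputs that are purely illustrative:

import numpy as np
import tensorflow as tf

a = tf.constant(np.random.rand(4, 64, 64, 1), dtype=tf.float32)
b = tf.constant(np.random.rand(4, 64, 64, 1), dtype=tf.float32)
score = lpips(a, b)  # single-channel inputs are tiled to RGB internally
with tf.Session() as sess:
    print(sess.run(score))  # shape (4,), all values <= 0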
Example #4
        def image_lpips(y_true, y_pred):
            y_true_image = utils.convert_tensor_to_image_domain(y_true)
            y_pred_image = utils.convert_tensor_to_image_domain(y_pred)
            y_true = join_reim_mag_output(y_true_image)
            y_pred = join_reim_mag_output(y_pred_image)

            y_true = K.tile(y_true, [1, 1, 1, 3])
            y_pred = K.tile(y_pred, [1, 1, 1, 3])
            return lpips_tf.lpips(y_true, y_pred, model='net-lin', net='alex')
Example #5
    def loss(self, y_true, y_pred):
        """Compute loss."""
        y_true_rgb = tf.concat(3 * [y_true], axis=-1)
        y_pred_rgb = tf.concat(3 * [y_pred], axis=-1)

        return lpips_tf.lpips(y_true_rgb,
                              y_pred_rgb,
                              model=self.model,
                              net=self.net)
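Here `tf.concat(3 * [t], axis=-1)` replicates a single-channel tensor across three channels, the same gray-to-RGB trick the earlier examples write with `tf.tile`. A quick equivalence check, as a sketch:

import numpy as np
import tensorflow as tf

t = tf.constant(np.random.rand(2, 8, 8, 1), dtype=tf.float32)
concat_rgb = tf.concat(3 * [t], axis=-1)
tiled_rgb = tf.tile(t, [1, 1, 1, 3])
with tf.Session() as sess:
    a, b = sess.run([concat_rgb, tiled_rgb])
    print(np.allclose(a, b))  # True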
Example #6
    def get_alexnet(inp, otp):
        # https://stackoverflow.com/questions/38376478/changing-the-scale-of-a-tensor-in-tensorflow
        with tf.name_scope('perceptual_loss'):
            axis = None
            inp_normalized = tf.div(
                tf.subtract(inp, tf.reduce_min(inp, axis=axis)),
                tf.subtract(tf.reduce_max(inp, axis=axis),
                            tf.reduce_min(inp, axis=axis)))

            otp_normalized = tf.div(
                tf.subtract(otp, tf.reduce_min(otp, axis=axis)),
                tf.subtract(tf.reduce_max(otp, axis=axis),
                            tf.reduce_min(otp, axis=axis)))

            permutation = [0, 2, 3, 1]
            inp_normalized = tf.transpose(inp_normalized, permutation)
            otp_normalized = tf.transpose(otp_normalized, permutation)
            # pad_val = (224-160)//2
            # paddings = tf.constant([[0,0],[pad_val,pad_val],[pad_val,pad_val],[0,0]])
            # ksizes = [1, 64, 64, 1]
            # strides = [1, 32, 32, 1]
            # rates = [1,1,1,1]
            # padding = 'SAME'
            # patches = tf.image.extract_image_patches(inp_normalized, ksizes, strides, rates, padding)
            # patches_shape = tf.shape(patches)
            # h = tf.shape(inp_normalized)[1]
            # w = tf.shape(inp_normalized)[2]
            # c = tf.shape(inp_normalized)[3]
            # patches = tf.reshape(patches, [tf.reduce_prod(patches_shape[0:3]), h, w, int(c)])
            # inp_normalized = tf.pad(inp_normalized, paddings, 'CONSTANT')
            # otp_normalized = tf.pad(otp_normalized, paddings, 'CONSTANT')
            translator = model_translator(inp_normalized[:, :64, :64, :],
                                          otp_normalized[:, :64, :64, :],
                                          network='alexnet')

            translator.get_weights()
            # model_translator.get_alexnet_diff(inp_normalized[:,:64,:64,:], otp_normalized[:,:64,:64,:])
            # img_content_normalized = (inp - np.min(inp)) / (
            #         np.max(inp) - np.min(inp))
            # img_content_normalized = np.transpose(np.expand_dims(img_content_normalized, axis=0),
            #                                       [0, 2, 3, 1])
            # out_img_content_normalized = (otp - np.min(otp)) / (
            #         np.max(otp) - np.min(otp))
            # print(out_img_content_normalized.shape)
            # out_img_content_normalized = np.transpose(out_img_content_normalized,
            #                                           [0, 2, 3, 1])

            # The tensors were transposed to NHWC above, so the data format
            # passed to lpips must match that layout.
            distance_t = lpips_tf.lpips(inp_normalized,
                                        otp_normalized,
                                        model='net-lin',
                                        net='vgg',
                                        data_format='NHWC')
            return distance_t
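The min-max rescaling above maps both tensors to [0, 1] before the LPIPS call. A standalone sketch of that step; the epsilon is my addition, not in the original code:

import tensorflow as tf


def minmax_normalize(t, eps=1e-8):
    # Rescale a tensor to [0, 1]; eps guards against division by zero
    # for constant-valued inputs.
    t_min = tf.reduce_min(t)
    t_max = tf.reduce_max(t)
    return (t - t_min) / (t_max - t_min + eps)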
Example #7
def perceptual_loss_img(y_true, y_pred, model="net-lin", net="vgg"):
    """Compute the perceptual loss (PL) between two images.

    Parameters
    ----------
    y_true : np.array
        Image 1. Either (h, w) of (N, h, w). If (N, h, w), the decorator `multiple_images_decorator` takes care of
        the sample dimension.

    y_pred : np.array
        Image 2. Either (h, w) of (N, h, w). If (N, h, w), the decorator `multiple_images_decorator` takes care of
        the sample dimension.

    model: str, {'net', 'net-lin'}
        Type of model (cf lpips_tf package).

    net: str, {'vgg', 'alex'}
        Type of network (cf lpips_tf package).

    Return
    ------
    pl: float
        The Perceptual Loss (PL) metric. Loss metric, the lower the more similar the images are.

    Notes
    -----
    We use the decorator just to make sure we do not run out of memory during a forward pass. Also,
    its fully convolutional but if the images are too small then might run into issues.
    """
    # gray2rgb
    y_true = np.stack((y_true, ) * 3, axis=-1)
    y_pred = np.stack((y_pred, ) * 3, axis=-1)

    image0_ph = tf.placeholder(tf.float32)
    image1_ph = tf.placeholder(tf.float32)

    distance_t = lpips_tf.lpips(image0_ph, image1_ph, model=model, net=net)

    with tf.Session() as session:
        pl = session.run(distance_t,
                         feed_dict={
                             image0_ph: y_true,
                             image1_ph: y_pred
                         })

    tf.reset_default_graph()

    return pl.item()
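Since this helper rebuilds the LPIPS graph and opens a fresh session on every call, repeated evaluation is slow. A sketch of a cached variant under the same lpips_tf API, mirroring the `make_callable` pattern from Example #1:

import numpy as np
import tensorflow as tf
import lpips_tf

# Build the graph and session once, then reuse them across calls.
_session = tf.Session()
_ph0 = tf.placeholder(tf.float32)
_ph1 = tf.placeholder(tf.float32)
_lpips_fn = _session.make_callable(
    lpips_tf.lpips(_ph0, _ph1, model='net-lin', net='vgg'), [_ph0, _ph1])


def perceptual_loss_img_cached(y_true, y_pred):
    # gray2rgb, as in the original helper
    y_true = np.stack((y_true, ) * 3, axis=-1)
    y_pred = np.stack((y_pred, ) * 3, axis=-1)
    return float(_lpips_fn(y_true, y_pred))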
Example #8

    def reconstruction_error(self, batch_size=32):
        image_shape = (batch_size, 64, 64, 3)
        image0 = np.random.random(image_shape)
        image1 = np.random.random(image_shape)
        image0_ph = tf.placeholder(tf.float32)
        image1_ph = tf.placeholder(tf.float32)

        distance_t = lpips_tf.lpips(image0_ph,
                                    image1_ph,
                                    model='net-lin',
                                    net='alex')
        to_batch = lambda x: np.concatenate(
            [np.tile(imgs, (1, 1, 1, 3)) for imgs in x], axis=0)
        total_distance = None

        for i in range(self.n_images):
            generated_imgs = to_batch([
                self.converter.generator.predict(
                    [self.content_codes[[j]], self.class_adain_params[[i]]])[0]
                for j in range(self.n_images) if i != j
            ])
            same_content_imgs = to_batch(
                [self.curr_imgs[j] for j in range(self.n_images) if i != j])
            with tf.Session() as session:
                # Evaluate LPIPS in batches and keep a running mean over all
                # (n_images - 1) pairs for this class index.
                for idx in range(
                        int(np.ceil((self.n_images - 1) / batch_size))):
                    distance = session.run(
                        distance_t,
                        feed_dict={
                            image0_ph:
                            generated_imgs[idx * batch_size:(idx + 1) *
                                           batch_size],
                            image1_ph:
                            same_content_imgs[idx * batch_size:(idx + 1) *
                                              batch_size]
                        })
                    if total_distance is None:
                        total_distance = np.mean(distance)
                    else:
                        total_distance = (
                            (total_distance * idx * batch_size) +
                            np.sum(distance)) / ((idx * batch_size) +
                                                 len(distance))
        return total_distance
Example #9
def lpips_distance(model,
                   x,
                   attack_name=None,
                   adv_x=None,
                   norm=2,
                   distance_only=False,
                   run_parallel=0):
    import tensorflow as tf
    from multiprocessing import Pool
    if adv_x is None:
        adv_x = craft_attack(model, x, attack_name, norm)

    batch_size = 32
    image_shape = (batch_size, x.shape[1], x.shape[2], x.shape[3])
    image0_ph = tf.placeholder(tf.float32)
    image1_ph = tf.placeholder(tf.float32)

    distance_t = lpips_tf.lpips(image0_ph,
                                image1_ph,
                                model='net-lin',
                                net='alex',
                                model_dir="./metrics")

    def f(i):
        with tf.Session() as session:
            distance = session.run(distance_t,
                                   feed_dict={
                                       image0_ph: x[i],
                                       image1_ph: adv_x[i]
                                   })

            if distance_only:
                return distance
            else:
                return x[i], adv_x[i], distance

    if run_parallel > 0:
        # A bare `return` inside a generator function discards its value, so
        # the parallel results are yielded as well.
        with Pool(run_parallel) as p:
            for result in p.map(f, range(len(x))):
                yield result
    else:
        for i in range(len(x)):
            yield f(i)
Example #10
def main():
    args, save_dir, load_dir = check_args(parse_arguments())
    global BATCH_SIZE
    BATCH_SIZE = args.batch_size
    config_path = os.path.join(load_dir, 'params.pkl')
    if os.path.exists(config_path):
        config = pickle.load(open(config_path, 'rb'))
        output_width = config['output_width']
        output_height = config['output_height']
        resolution = output_height
        z_dim = config['z_dim']
    else:
        output_width = output_height = 64
        resolution = 64
        z_dim = 100

    ### open session
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True
    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(sess,
                      output_width=output_width,
                      output_height=output_height,
                      batch_size=BATCH_SIZE,
                      sample_num=BATCH_SIZE,
                      z_dim=z_dim)

        load_success, load_counter = dcgan.load(load_dir)
        if not load_success:
            raise Exception("Checkpoint not found in " + load_dir)

        ### initialization
        init_val_ph = None
        init_val = {'pos': None, 'neg': None}
        if args.initialize_type == 'zero':
            z = tf.Variable(tf.zeros([BATCH_SIZE, z_dim], tf.float32),
                            name='latent_z')

        elif args.initialize_type == 'random':
            np.random.seed(RANDOM_SEED)
            init_val_np = np.random.normal(size=(z_dim, ))
            init = np.tile(init_val_np, (BATCH_SIZE, 1)).astype(np.float32)
            z = tf.Variable(init, name='latent_z')

        elif args.initialize_type == 'nn':
            idx = 0
            init_val['pos'] = np.load(os.path.join(args.nn_dir,
                                                   'pos_z.npy'))[:, idx, :]
            init_val['neg'] = np.load(os.path.join(args.nn_dir,
                                                   'neg_z.npy'))[:, idx, :]
            init_val_ph = tf.placeholder(dtype=tf.float32,
                                         name='init_ph',
                                         shape=(BATCH_SIZE, z_dim))
            z = tf.Variable(init_val_ph, name='latent_z')

        else:
            raise NotImplementedError

        ### define variables
        x = tf.placeholder(tf.float32,
                           shape=(BATCH_SIZE, resolution, resolution, 3))
        x_hat = dcgan.generator(z, is_training=False)

        ### loss
        if args.distance == 'l2':
            print('use distance: l2')
            loss_l2 = tf.reduce_mean(tf.square(x_hat - x), axis=[1, 2, 3])
            vec_loss = loss_l2
            vec_losses = {'l2': loss_l2}

        elif args.distance == 'l2-lpips':
            print('use distance: lpips + l2')
            loss_l2 = tf.reduce_mean(tf.square(x_hat - x), axis=[1, 2, 3])
            loss_lpips = lpips_tf.lpips(x_hat,
                                        x,
                                        normalize=False,
                                        model='net-lin',
                                        net='vgg',
                                        version='0.1')
            vec_losses = {'l2': loss_l2, 'lpips': loss_lpips}
            vec_loss = loss_l2 + LAMBDA2 * loss_lpips
        else:
            raise NotImplementedError

        ## regularizer
        norm = tf.reduce_sum(tf.square(z), axis=1)
        norm_penalty = (norm - z_dim)**2

        if args.if_norm_reg:
            loss = tf.reduce_mean(
                vec_loss) + LAMBDA3 * tf.reduce_mean(norm_penalty)
            vec_losses['norm'] = norm_penalty
        else:
            loss = tf.reduce_mean(vec_loss)

        ### set up optimizer
        opt = tf.contrib.opt.ScipyOptimizerInterface(
            loss,
            var_list=[z],
            method='L-BFGS-B',
            options={'maxfun': args.maxfunc})

        ### load query images
        pos_data_paths = get_filepaths_from_dir(args.pos_data_dir,
                                                ext='png')[:args.data_num]
        pos_query_imgs = np.array(
            [read_image(f, resolution) for f in pos_data_paths])

        neg_data_paths = get_filepaths_from_dir(args.neg_data_dir,
                                                ext='png')[:args.data_num]
        neg_query_imgs = np.array(
            [read_image(f, resolution) for f in neg_data_paths])

        ### run the optimization on query images
        query_loss, query_z, query_xhat = optimize_z(
            sess, z, x, x_hat, init_val_ph, init_val['pos'], pos_query_imgs,
            check_folder(os.path.join(save_dir, 'pos_results')), opt, vec_loss,
            vec_losses)
        save_files(save_dir, ['pos_loss'], [query_loss])

        query_loss, query_z, query_xhat = optimize_z(
            sess, z, x, x_hat, init_val_ph, init_val['neg'], neg_query_imgs,
            check_folder(os.path.join(save_dir, 'neg_results')), opt, vec_loss,
            vec_losses)
        save_files(save_dir, ['neg_loss'], [query_loss])
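The `norm_penalty` regularizer used here (and again in Examples #14 and #16) keeps the optimized latent near the shell where a standard Gaussian concentrates, since E[||z||^2] = z_dim. A quick numeric check, as a sketch:

import numpy as np

# For z ~ N(0, I_d), ||z||^2 averages to d, so (||z||^2 - d)^2 is small
# for typical latents and grows for atypical ones.
d = 100
z = np.random.normal(size=(10000, d))
print(np.mean(np.sum(z ** 2, axis=1)))  # close to d = 100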
Example #11
import numpy as np
import tensorflow as tf
import lpips_tf, glob, os
import scipy.misc


def read_color_img(filename):
    # scipy.misc.imread was removed in SciPy >= 1.2; imageio.imread is the
    # usual drop-in replacement.
    return scipy.misc.imread(filename).astype('float32') / 255


batch_size = 1
image_shape = (batch_size, None, None, 3)
image0_ph = tf.placeholder(tf.float32)
image1_ph = tf.placeholder(tf.float32)

distance_t = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex')

folder = 'E://DepthOfField//test_vis//sr_x4//rgb'
gt_lists = glob.glob(os.path.join(folder, '*HR.png'))
distance = []
with tf.Session() as sess:
    for i_file in range(len(gt_lists)):
        gt_name = gt_lists[i_file]
        img_name = gt_name.replace('HR.png', 'LR4_rcan_real.png')
        label = read_color_img(gt_name)
        output = read_color_img(img_name)
        d = sess.run(distance_t,
                     feed_dict={
                         image0_ph: label,
                         image1_ph: output
                     })
        distance.append(d)
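The loop only accumulates per-image distances; a one-line summary one would typically append after the session block (my addition, using the imports above):

print('mean LPIPS over %d pairs: %.4f' % (len(distance), np.mean(distance)))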
Example #12

def compute_metric(dir1, dir2, mode, mask=None, thre=0):

    if mode == 'perceptual':
        from models import dist_model as dm
        import torch
        from util import util
        global model
        if model is None:
            cwd = os.getcwd()
            os.chdir('../../PerceptualSimilarity')
            model = dm.DistModel()
            #model.initialize(model='net-lin',net='alex',use_gpu=True, spatial=True)
            model.initialize(model='net-lin', net='alex', use_gpu=True)
            print('Model [%s] initialized' % model.name())
            os.chdir(cwd)

    if mode.startswith('perceptual_tf'):
        sys.path += ['../../lpips-tensorflow']
        import lpips_tf
        import tensorflow as tf
        image0_ph = tf.placeholder(tf.float32, [1, None, None, 3])
        image1_ph = tf.placeholder(tf.float32, [1, None, None, 3])
        if mode == 'perceptual_tf':
            distance_t = lpips_tf.lpips(image0_ph,
                                        image1_ph,
                                        model='net-lin',
                                        net='alex')
        elif mode == 'perceptual_tf_vgg':
            distance_t = lpips_tf.lpips(image0_ph,
                                        image1_ph,
                                        model='net-lin',
                                        net='vgg')
        else:
            raise ValueError('unknown perceptual_tf mode: ' + mode)
        sess = tf.Session()

    if mode == 'l2_with_gradient':
        import demo
        import tensorflow as tf
        output = tf.placeholder(tf.float32, shape=[1, None, None, None])
        gradient = demo.image_gradients(output)
        sess = tf.Session()

    files1 = os.listdir(dir1)
    files2 = os.listdir(dir2)
    img_files1 = sorted([
        file for file in files1
        if file.endswith('.png') or file.endswith('.jpg')
    ])
    img_files2 = sorted([
        file for file in files2
        if file.endswith('.png') or file.endswith('.jpg')
    ])

    if '--prefix' in sys.argv:
        prefix_idx = sys.argv.index('--prefix')
        prefix = sys.argv[prefix_idx + 1]
        img_files2 = [file for file in img_files2 if file.startswith(prefix)]

    if mask is not None:
        mask_files = []
        for dir in sorted(mask):
            files3 = os.listdir(dir)
            add_mask_files = sorted([
                os.path.join(dir, file) for file in files3
                if file.startswith('mask')
            ])
            if len(add_mask_files) == 0:
                files_to_dilate = sorted([
                    file for file in files3
                    if file.startswith('g_intermediates')
                ])
                for file in files_to_dilate:
                    mask_arr = numpy.load(os.path.join(dir, file))
                    mask_arr = numpy.squeeze(mask_arr)
                    mask_arr = mask_arr >= thre
                    dilated_mask = dilation(mask_arr, disk(10))
                    dilated_filename = 'mask_' + file
                    numpy.save(os.path.join(dir, dilated_filename),
                               dilated_mask.astype('f'))
                    add_mask_files.append(os.path.join(dir, dilated_filename))
            mask_files += add_mask_files
        print(mask_files)
        assert len(mask_files) == len(img_files1)

    skip_first_n = 0
    if '--skip_first_n' in sys.argv:
        try:
            skip_first_n = int(sys.argv[sys.argv.index('--skip_first_n') + 1])
        except (IndexError, ValueError):
            skip_first_n = 0

    skip_last_n = 0
    if '--skip_last_n' in sys.argv:
        try:
            skip_last_n = int(sys.argv[sys.argv.index('--skip_last_n') + 1])
        except (IndexError, ValueError):
            skip_last_n = 0

    img_files2 = img_files2[skip_first_n:]
    if skip_last_n > 0:
        img_files2 = img_files2[:-skip_last_n]
    assert len(img_files1) == len(img_files2)

    # locate GT gradient directory
    if mode == 'l2_with_gradient':
        head, tail = os.path.split(dir2)
        gradient_gt_dir = os.path.join(head, tail[:-3] + 'grad')
        if not os.path.exists(gradient_gt_dir):
            print('dir not found:', gradient_gt_dir)
            raise IOError('gradient ground-truth directory not found: ' +
                          gradient_gt_dir)
        gradient_gt_files = os.listdir(gradient_gt_dir)
        gradient_gt_files = sorted(
            [file for file in gradient_gt_files if file.endswith('.npy')])
        assert len(img_files1) == len(gradient_gt_files)

    vals = numpy.empty(len(img_files1))
    #if mode == 'perceptual':
    #    global model

    for ind in range(len(img_files1)):
        if mode == 'ssim' or mode == 'l2' or mode == 'l2_with_gradient':
            img1 = skimage.img_as_float(
                skimage.io.imread(os.path.join(dir1, img_files1[ind])))
            img2 = skimage.img_as_float(
                skimage.io.imread(os.path.join(dir2, img_files2[ind])))
            if mode == 'ssim':
                #vals[ind] = skimage.measure.compare_ssim(img1, img2, data_range=img2.max()-img2.min(), multichannel=True)
                metric_val = skimage.measure.compare_ssim(
                    img1,
                    img2,
                    data_range=img2.max() - img2.min(),
                    multichannel=True)
            else:
                #vals[ind] = numpy.mean((img1 - img2) ** 2) * 255.0 * 255.0
                metric_val = ((img1 - img2)**2) * 255.0 * 255.0
            if mode == 'l2_with_gradient':
                metric_val = numpy.mean(metric_val, axis=2)
                gradient_gt = numpy.load(
                    os.path.join(gradient_gt_dir, gradient_gt_files[ind]))
                dx, dy = sess.run(gradient,
                                  feed_dict={
                                      output:
                                      numpy.expand_dims(img1[..., ::-1],
                                                        axis=0)
                                  })
                #is_edge = skimage.feature.canny(skimage.color.rgb2gray(img1))
                dx_ground = gradient_gt[:, :, :, 1:4]
                dy_ground = gradient_gt[:, :, :, 4:]
                edge_ground = gradient_gt[:, :, :, 0]
                gradient_loss_term = numpy.mean(
                    (dx - dx_ground)**2.0 + (dy - dy_ground)**2.0, axis=3)
                metric_val += numpy.squeeze(
                    0.2 * 255.0 * 255.0 * gradient_loss_term * edge_ground *
                    edge_ground.size / numpy.sum(edge_ground))

            #if mode == 'l2' and mask is not None:
            #    img_diff = (img1 - img2) ** 2.0
            #    mask_img = numpy.load(mask_files[ind])
            #    img_diff *= numpy.expand_dims(mask_img, axis=2)
            #    vals[ind] = (numpy.sum(img_diff) / numpy.sum(mask_img * 3)) * 255.0 * 255.0
        elif mode == 'perceptual':
            img1 = util.im2tensor(
                util.load_image(os.path.join(dir1, img_files1[ind])))
            img2 = util.im2tensor(
                util.load_image(os.path.join(dir2, img_files2[ind])))
            #vals[ind] = numpy.mean(model.forward(img1, img2)[0])
            metric_val = numpy.expand_dims(model.forward(img1, img2), axis=2)
        elif mode.startswith('perceptual_tf'):
            img1 = np.expand_dims(skimage.img_as_float(
                skimage.io.imread(os.path.join(dir1, img_files1[ind]))),
                                  axis=0)
            img2 = np.expand_dims(skimage.img_as_float(
                skimage.io.imread(os.path.join(dir2, img_files2[ind]))),
                                  axis=0)
            metric_val = sess.run(distance_t,
                                  feed_dict={
                                      image0_ph: img1,
                                      image1_ph: img2
                                  })
        else:
            raise ValueError('unknown mode: ' + mode)

        if mask is not None:
            assert mode in ['l2', 'perceptual']
            mask_img = numpy.load(mask_files[ind])
            metric_val *= numpy.expand_dims(mask_img, axis=2)
            vals[ind] = numpy.sum(metric_val) / (numpy.sum(mask_img) *
                                                 metric_val.shape[2])
        else:
            vals[ind] = numpy.mean(metric_val)

    mode = mode + ('_mask' if mask is not None else '')
    filename_all = mode + '_all.txt'
    filename_breakdown = mode + '_breakdown.txt'
    filename_single = mode + '.txt'
    numpy.savetxt(os.path.join(dir1, filename_all), vals, fmt="%f, ")
    target = open(os.path.join(dir1, filename_single), 'w')
    target.write("%f" % numpy.mean(vals))
    target.close()
    if len(img_files1) == 30:
        target = open(os.path.join(dir1, filename_breakdown), 'w')
        target.write("%f, %f, %f" % (numpy.mean(
            vals[:5]), numpy.mean(vals[5:10]), numpy.mean(vals[10:])))
        target.close()
    if mode in ['l2_with_gradient', 'perceptual_tf']:
        sess.close()
    return vals
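A hypothetical invocation of `compute_metric`; the directory names are placeholders, not from the original source. The per-image values are written next to `dir1` and also returned:

vals = compute_metric('results/ours', 'results/reference', 'perceptual_tf')
print('mean LPIPS:', vals.mean())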
Example #13
def validate(val_dirs: ValidationDirs, images_iterator: ImagesIterator,
             flags: OutputFlags):
    """
    Saves in val_dirs.log_dir/val/dataset_name/measures.csv:
        - `img_name,bpp,psnr,ms-ssim forall img_name`
    """
    print(_VALIDATION_INFO_STR)

    validated_checkpoints = val_dirs.get_validated_checkpoints(
    )  # :: [10000, 18000, ..., 256000], ie, [int]
    all_ckpts = Saver.all_ckpts_with_iterations(val_dirs.ckpt_dir)
    if len(all_ckpts) == 0:
        print('No checkpoints found in {}'.format(val_dirs.ckpt_dir))
        return
    # if ckpt_step is -1, then all_ckpt[:-1:flags.ckpt_step] === [] because of how strides work
    ckpt_to_check = all_ckpts[:-1:flags.ckpt_step] + [
        all_ckpts[-1]
    ]  # every ckpt_step-th checkpoint plus the last one
    if flags.ckpt_step == -1:
        assert len(ckpt_to_check) == 1
    print('Validating {}/{} checkpoints (--ckpt_step {})...'.format(
        len(ckpt_to_check), len(all_ckpts), flags.ckpt_step))

    missing_checkpoints = [(ckpt_itr, ckpt_path)
                           for ckpt_itr, ckpt_path in ckpt_to_check
                           if ckpt_itr not in validated_checkpoints]
    if len(missing_checkpoints) == 0:
        print('All checkpoints validated, stopping...')
        return

    # ---

    # create networks
    autoencoder_config_path, probclass_config_path = logdir_helpers.config_paths_from_log_dir(
        val_dirs.log_dir,
        base_dirs=[constants.CONFIG_BASE_AE, constants.CONFIG_BASE_PC])
    ae_config, ae_config_rel_path = config_parser.parse(
        autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    x_val_ph = tf.placeholder(tf.uint8, (3, None, None), name='x_val_ph')
    x_val_uint8 = tf.expand_dims(x_val_ph, 0, name='batch')
    x_val = tf.to_float(x_val_uint8, name='x_val')
    x_val_normalized = tf.div(
        tf.subtract(x_val, tf.reduce_min(x_val)),
        tf.subtract(tf.reduce_max(x_val), tf.reduce_min(x_val)))

    enc_out_val = ae.encode(x_val, is_training=False)
    x_out_val = ae.decode(enc_out_val.qhard, is_training=False)
    x_out_val_normalized = tf.div(
        tf.subtract(x_out_val, tf.reduce_min(x_out_val)),
        tf.subtract(tf.reduce_max(x_out_val), tf.reduce_min(x_out_val)))
    bc_val = pc.bitcost(enc_out_val.qbar,
                        enc_out_val.symbols,
                        is_training=False,
                        pad_value=pc.auto_pad_value(ae))
    bpp_val = bits.bitcost_to_bpp(bc_val, x_val)

    x_out_val_uint8 = tf.cast(x_out_val, tf.uint8, name='x_out_val_uint8')
    # Using numpy implementation due to dynamic shapes
    msssim_val = ms_ssim_np.tf_msssim_np(x_val_uint8,
                                         x_out_val_uint8,
                                         data_format='NCHW')
    psnr_val = psnr_np(x_val_uint8, x_out_val_uint8)

    restorer = Saver(val_dirs.ckpt_dir,
                     var_list=Saver.get_var_list_of_ckpt_dir(
                         val_dirs.ckpt_dir))

    # create fetch_dict
    fetch_dict = {
        'bpp': bpp_val,
        'ms-ssim': msssim_val,
        'psnr': psnr_val,
    }

    if flags.real_bpp:
        fetch_dict['sym'] = enc_out_val.symbols  # NCHW

    if flags.save_ours:
        fetch_dict['img_out'] = x_out_val_uint8

    # ---
    fw = tf.summary.FileWriter(val_dirs.out_dir, graph=tf.get_default_graph())

    def full_summary_tag(summary_name):
        return '/'.join(['val', images_iterator.dataset_name, summary_name])

    # Distance
    try:
        codec_distance_ms_ssim = CodecDistance(images_iterator.dataset_name,
                                               codec='bpg',
                                               metric='ms-ssim')
        codec_distance_psnr = CodecDistance(images_iterator.dataset_name,
                                            codec='bpg',
                                            metric='psnr')
    except CodecDistanceReadException as e:  # no codec distance values stored for the current setup
        print('*** Distance to BPG not available for {}:\n{}'.format(
            images_iterator.dataset_name, str(e)))
        codec_distance_ms_ssim = None
        codec_distance_psnr = None

    # Note that for each checkpoint, the structure of the network will be the same. Thus the padding-dependent
    # image loading can be cached.
    lpips_ph1 = tf.placeholder(tf.float32)
    lpips_ph2 = tf.placeholder(tf.float32)

    distance_t = lpips_tf.lpips(lpips_ph1,
                                lpips_ph2,
                                model='net-lin',
                                net='alex')
    distance_t = tf.Print(distance_t, [distance_t])
    # create session
    with tf_helpers.create_session() as sess:
        if flags.real_bpp:
            pred = probclass.PredictionNetwork(pc, pc_config,
                                               ae.get_centers_variable(), sess)
            checker = probclass.ProbclassNetworkTesting(pc, ae, sess)
            bpp_fetcher = bpp_helpers.BppFetcher(pred, checker)

        fetcher = sess.make_callable(fetch_dict, feed_list=[x_val_ph])

        last_ckpt_itr = missing_checkpoints[-1][0]
        for ckpt_itr, ckpt_path in missing_checkpoints:
            if not ckpt_still_exists(ckpt_path):
                # May happen if job is still training
                print('Checkpoint disappeared: {}'.format(ckpt_path))
                continue

            print(_CKPT_ITR_INFO_STR.format(ckpt_itr))

            restorer.restore_ckpt(sess, ckpt_path)

            values_aggregator = ValuesAggregator('bpp', 'ms-ssim', 'psnr')

            # truncates the previous measures.csv file! This way, only the last valid checkpoint is saved.
            measures_writer = MeasuresWriter(val_dirs.out_dir)

            # ----------------------------------------
            # iterate over images
            # images are padded to work with current auto encoder

            for img_i, (img_name, img_content) in enumerate(
                    images_iterator.iter_imgs(
                        pad=ae.get_subsampling_factor())):

                otp = fetcher(img_content)
                measures_writer.append(img_name, otp)
                img_content_normalized = (
                    img_content - np.min(img_content)) / (np.max(img_content) -
                                                          np.min(img_content))
                img_content_normalized = np.transpose(
                    np.expand_dims(img_content_normalized, axis=0),
                    [0, 2, 3, 1])
                out_img_content_normalized = (otp['img_out'] - np.min(
                    otp['img_out'])) / (np.max(otp['img_out']) -
                                        np.min(otp['img_out']))
                #print(out_img_content_normalized.shape)
                out_img_content_normalized = np.transpose(
                    out_img_content_normalized, [0, 2, 3, 1])

                # distance_t is wrapped in tf.Print above, so evaluating it
                # logs the LPIPS value to stderr as a side effect.
                sess.run(distance_t,
                         feed_dict={
                             lpips_ph1: img_content_normalized,
                             lpips_ph2: out_img_content_normalized
                         })

                if flags.real_bpp:
                    # Calculate
                    bpp_real, bpp_theory = bpp_fetcher.get_bpp(
                        otp['sym'],
                        bpp_helpers.num_pixels_in_image(img_content))

                    # Logging
                    bpp_loss = otp['bpp']
                    diff_percent_tr = (bpp_theory / bpp_real) * 100
                    diff_percent_lt = (bpp_loss / bpp_theory) * 100
                    print('BPP: Real         {:.5f}\n'
                          '     Theoretical: {:.5f} [{:5.1f}% of real]\n'
                          '     Loss:        {:.5f} [{:5.1f}% of real]'.format(
                              bpp_real, bpp_theory, diff_percent_tr, bpp_loss,
                              diff_percent_lt))
                    assert abs(
                        bpp_theory - bpp_loss
                    ) < 1e-3, 'Expected bpp_theory to match loss! Got {} and {}'.format(
                        bpp_theory, bpp_loss)

                if flags.save_ours and ckpt_itr == last_ckpt_itr:
                    save_img(img_name, otp['img_out'], val_dirs)

                values_aggregator.update(otp)

                print('{: 10d} {img_name} | Mean: {avgs}'.format(
                    img_i,
                    img_name=img_name,
                    avgs=values_aggregator.averages_str()),
                      end=('\r' if not flags.real_bpp else '\n'),
                      flush=True)

            measures_writer.close()

            print()  # add newline
            avgs = values_aggregator.averages()
            avg_bpp, avg_ms_ssim, avg_psnr = avgs['bpp'], avgs[
                'ms-ssim'], avgs['psnr']

            tf_helpers.log_values(
                fw, [(full_summary_tag('avg_bpp'), avg_bpp),
                     (full_summary_tag('avg_ms_ssim'), avg_ms_ssim),
                     (full_summary_tag('avg_psnr'), avg_psnr)],
                iteration=ckpt_itr)

            if codec_distance_ms_ssim and codec_distance_psnr:
                try:
                    d_ms_ssim = codec_distance_ms_ssim.distance(
                        avg_bpp, avg_ms_ssim)
                    d_psnr = codec_distance_psnr.distance(avg_bpp, avg_psnr)
                    print('Distance to BPG: {:.3f} ms-ssim // {:.3f} psnr'.
                          format(d_ms_ssim, d_psnr))
                    tf_helpers.log_values(
                        fw,
                        [(full_summary_tag('distance_BPG_MS-SSIM'), d_ms_ssim),
                         (full_summary_tag('distance_BPG_PSNR'), d_psnr)],
                        iteration=ckpt_itr)
                except ValueError as e:  # out of range errors from distance calls
                    print(e)

            val_dirs.add_validated_checkpoint(ckpt_itr)

    print('Validation completed {}'.format(val_dirs))
Example #14
def main():
    args, save_dir, load_dir = check_args(parse_arguments())

    ### open session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.as_default()

        ### load pre-trained model
        network_path = sorted(glob(os.path.join(load_dir,
                                                'network-*.pkl')))[-1]
        with open(network_path, 'rb') as file:
            print('Loading networks from "%s"...' % network_path)
            G, D, Gs = pickle.load(file)
        Gs.print_layers()
        D.print_layers()

        ### define variables
        global BATCH_SIZE
        BATCH_SIZE = args.batch_size
        z_dim = G.input_shape[-1]
        resolution = G.output_shape[-1]
        x = tf.placeholder(tf.float32,
                           shape=(BATCH_SIZE, resolution, resolution, 3))
        labels = tf.zeros([BATCH_SIZE, 0], tf.float32)

        ### initialization
        init_val_ph = None
        init_val = {'pos': None, 'neg': None}
        if args.initialize_type == 'zero':
            z = tf.Variable(tf.zeros([BATCH_SIZE, z_dim], tf.float32),
                            name='latent_z')

        elif args.initialize_type == 'random':
            np.random.seed(RANDOM_SEED)
            init_val_np = np.random.normal(size=(z_dim, ))
            init_val_np = init_val_np / np.sqrt(
                np.mean(np.square(init_val_np)) + 1e-8)
            init = np.tile(init_val_np, (BATCH_SIZE, 1)).astype(np.float32)
            z = tf.Variable(init, name='latent_z')

        elif args.initialize_type == 'nn':
            init_val = {}
            init_val['pos'] = np.load(os.path.join(args.nn_dir,
                                                   'pos_z.npy'))[:, 0, :]
            init_val['neg'] = np.load(os.path.join(args.nn_dir,
                                                   'neg_z.npy'))[:, 0, :]
            init_val_ph = tf.placeholder(dtype=tf.float32,
                                         name='init_ph',
                                         shape=(BATCH_SIZE, z_dim))
            z = tf.Variable(init_val_ph, name='latent_z')

        else:
            raise NotImplementedError

        ### get the reconstruction (x_hat)
        with tf.variable_scope(Gs.scope, reuse=True):
            assert tf.get_variable_scope().name == Gs.scope
            with tf.control_dependencies(
                    None):  # ignore surrounding control_dependencies
                inputs = [z, labels]
                x_hat = Gs._build_func(*inputs,
                                       is_template_graph=True,
                                       **Gs.static_kwargs)
                x_hat = tf.transpose(x_hat, perm=[0, 2, 3, 1])
                x_hat = tf.clip_by_value(x_hat, -1., 1.)

        ### loss
        if args.distance == 'l2':
            print('Use distance: l2')
            loss_l2 = tf.reduce_mean(tf.square(x_hat - x), axis=[1, 2, 3])
            vec_loss = loss_l2
            vec_losses = {'l2': loss_l2}

        elif args.distance == 'l2-lpips':
            print('Use distance: lpips + l2')
            loss_l2 = tf.reduce_mean(tf.square(x_hat - x), axis=[1, 2, 3])
            loss_lpips = lpips_tf.lpips(x_hat,
                                        x,
                                        normalize=False,
                                        model='net-lin',
                                        net='vgg',
                                        version='0.1')
            vec_losses = {'l2': loss_l2, 'lpips': loss_lpips}
            vec_loss = loss_l2 + LAMBDA2 * loss_lpips

        else:
            raise NotImplementedError

        ## regularizer
        norm = tf.reduce_sum(tf.square(z), axis=1)
        norm_penalty = (norm - z_dim)**2

        if args.if_norm_reg:
            loss = tf.reduce_mean(
                vec_loss) + LAMBDA3 * tf.reduce_mean(norm_penalty)
            vec_losses['norm'] = norm_penalty
        else:
            loss = tf.reduce_mean(vec_loss)

        ### set up optimizer
        opt = tf.contrib.opt.ScipyOptimizerInterface(
            loss,
            var_list=[z],
            method='L-BFGS-B',
            options={'maxfun': args.maxfunc})

        ### load query images
        pos_data_paths = get_filepaths_from_dir(args.pos_data_dir,
                                                ext='png')[:args.data_num]
        pos_query_imgs = np.array(
            [read_image(f, resolution) for f in pos_data_paths])

        neg_data_paths = get_filepaths_from_dir(args.neg_data_dir,
                                                ext='png')[:args.data_num]
        neg_query_imgs = np.array(
            [read_image(f, resolution) for f in neg_data_paths])

        ### run the optimization on query images
        query_loss, query_z, query_xhat = optimize_z(
            sess, z, x, x_hat, init_val_ph, init_val['pos'], pos_query_imgs,
            check_folder(os.path.join(save_dir, 'pos_results')), opt, vec_loss,
            vec_losses)
        save_files(save_dir, ['pos_loss'], [query_loss])

        query_loss, query_z, query_xhat = optimize_z(
            sess, z, x, x_hat, init_val_ph, init_val['neg'], neg_query_imgs,
            check_folder(os.path.join(save_dir, 'neg_results')), opt, vec_loss,
            vec_losses)
        save_files(save_dir, ['neg_loss'], [query_loss])
Example #15
import tensorflow as tf

import lpips_tf

image0_ph = tf.compat.v1.placeholder(tf.float32)
image1_ph = tf.compat.v1.placeholder(tf.float32)
image2_ph = tf.compat.v1.placeholder(tf.float32)
image3_ph = tf.compat.v1.placeholder(tf.float32)
image4_ph = tf.compat.v1.placeholder(tf.float32)
image5_ph = tf.compat.v1.placeholder(tf.float32)
image6_ph = tf.compat.v1.placeholder(tf.float32)
image7_ph = tf.compat.v1.placeholder(tf.float32)
image8_ph = tf.compat.v1.placeholder(tf.float32)
image9_ph = tf.compat.v1.placeholder(tf.float32)
image10_ph = tf.compat.v1.placeholder(tf.float32)
image11_ph = tf.compat.v1.placeholder(tf.float32)
image12_ph = tf.compat.v1.placeholder(tf.float32)
image13_ph = tf.compat.v1.placeholder(tf.float32)
image14_ph = tf.compat.v1.placeholder(tf.float32)
image15_ph = tf.compat.v1.placeholder(tf.float32)

distance_t1 = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex')
distance_t2 = lpips_tf.lpips(image0_ph, image2_ph, model='net-lin', net='alex')
distance_t3 = lpips_tf.lpips(image0_ph, image3_ph, model='net-lin', net='alex')
distance_t4 = lpips_tf.lpips(image0_ph, image4_ph, model='net-lin', net='alex')
distance_t5 = lpips_tf.lpips(image0_ph, image5_ph, model='net-lin', net='alex')
distance_t6 = lpips_tf.lpips(image0_ph, image6_ph, model='net-lin', net='alex')
distance_t7 = lpips_tf.lpips(image0_ph, image7_ph, model='net-lin', net='alex')
distance_t8 = lpips_tf.lpips(image0_ph, image8_ph, model='net-lin', net='alex')
distance_t9 = lpips_tf.lpips(image0_ph, image9_ph, model='net-lin', net='alex')
distance_t10 = lpips_tf.lpips(image0_ph,
                              image10_ph,
                              model='net-lin',
                              net='alex')
distance_t11 = lpips_tf.lpips(image0_ph,
                              image11_ph,
                              model='net-lin',
                              net='alex')
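The repeated placeholder/distance pairs above can be generated programmatically. A compact sketch under the same lpips_tf API:

import tensorflow as tf
import lpips_tf

image0_ph = tf.compat.v1.placeholder(tf.float32)
# One placeholder and one LPIPS distance tensor per comparison image.
image_phs = [tf.compat.v1.placeholder(tf.float32) for _ in range(15)]
distance_ts = [
    lpips_tf.lpips(image0_ph, ph, model='net-lin', net='alex')
    for ph in image_phs
]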
Example #16
def main():
    args, save_dir, load_dir = check_args(parse_arguments())
    config_path = os.path.join(load_dir, 'params.pkl')
    if os.path.exists(config_path):
        config = pickle.load(open(config_path, 'rb'))
        OUTPUT_SIZE = config['OUTPUT_SIZE']
        GAN_TYPE = config['Architecture']
        Z_DIM = config['Z_DIM']
    else:
        OUTPUT_SIZE = 64
        GAN_TYPE = 'good'
        Z_DIM = 128

    ### set up the generator and the discriminator
    Generator, Discriminator = GeneratorAndDiscriminator(GAN_TYPE)

    ### open session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:

        ### define variables
        global BATCH_SIZE
        BATCH_SIZE = args.batch_size
        x = tf.placeholder(tf.float32,
                           shape=(None, OUTPUT_SIZE, OUTPUT_SIZE, 3),
                           name='x')

        ### initialization
        init_val_ph = None
        init_val = {'pos': None, 'neg': None}
        if args.initialize_type == 'zero':
            z = tf.Variable(tf.zeros([BATCH_SIZE, Z_DIM], tf.float32),
                            name='latent_z')

        elif args.initialize_type == 'random':
            np.random.seed(RANDOM_SEED)
            init_val_np = np.random.normal(size=(Z_DIM, ))
            init = np.tile(init_val_np, (BATCH_SIZE, 1)).astype(np.float32)
            z = tf.Variable(init, name='latent_z')

        elif args.initialize_type == 'nn':
            init_val['pos'] = np.load(os.path.join(args.nn_dir,
                                                   'pos_z.npy'))[:, 0, :]
            init_val['neg'] = np.load(os.path.join(args.nn_dir,
                                                   'neg_z.npy'))[:, 0, :]
            init_val_ph = tf.placeholder(dtype=tf.float32,
                                         name='init_ph',
                                         shape=(BATCH_SIZE, Z_DIM))
            z = tf.Variable(init_val_ph, name='latent_z')

        else:
            raise NotImplementedError

        ### get the reconstruction (x_hat)
        x_hat = Generator(BATCH_SIZE, noise=z, is_training=False, z_dim=Z_DIM)
        x_hat = tf.reshape(x_hat, [-1, 3, OUTPUT_SIZE, OUTPUT_SIZE])
        x_hat = tf.transpose(x_hat, perm=[0, 2, 3, 1])

        ### load model
        vars = [v for v in tf.global_variables() if 'latent_z' not in v.name]
        saver = tf.train.Saver(vars)
        sess.run(tf.variables_initializer(vars))
        if_load, counter = load_model_from_checkpoint(load_dir, saver, sess)
        assert if_load is True

        ### loss
        if args.distance == 'l2':
            print('use distance: l2')
            loss_l2 = tf.reduce_mean(tf.square(x_hat - x), axis=[1, 2, 3])
            vec_loss = loss_l2
            vec_losses = {'l2': loss_l2}

        elif args.distance == 'l2-lpips':
            print('use distance: lpips + l2')
            loss_l2 = tf.reduce_mean(tf.square(x_hat - x), axis=[1, 2, 3])
            loss_lpips = lpips_tf.lpips(x_hat,
                                        x,
                                        normalize=False,
                                        model='net-lin',
                                        net='vgg',
                                        version='0.1')
            vec_losses = {'l2': loss_l2, 'lpips': loss_lpips}
            vec_loss = loss_l2 + LAMBDA2 * loss_lpips
        else:
            raise NotImplementedError

        ## regularizer
        norm = tf.reduce_sum(tf.square(z), axis=1)
        norm_penalty = (norm - Z_DIM)**2

        if args.if_norm_reg:
            loss = tf.reduce_mean(
                vec_loss) + LAMBDA3 * tf.reduce_mean(norm_penalty)
            vec_losses['norm'] = norm_penalty
        else:
            loss = tf.reduce_mean(vec_loss)

        ### set up optimizer
        opt = tf.contrib.opt.ScipyOptimizerInterface(
            loss,
            var_list=[z],
            method='Powell',
            options={'maxiter': args.maxiter})

        ### load query images
        pos_data_paths = get_filepaths_from_dir(args.pos_data_dir,
                                                ext='png')[:args.data_num]
        pos_query_imgs = np.array(
            [read_image(f, OUTPUT_SIZE) for f in pos_data_paths])

        neg_data_paths = get_filepaths_from_dir(args.neg_data_dir,
                                                ext='png')[:args.data_num]
        neg_query_imgs = np.array(
            [read_image(f, OUTPUT_SIZE) for f in neg_data_paths])

        ### run the optimization on query images
        query_loss, query_z, query_xhat = optimize_z(
            sess, z, x, x_hat, init_val_ph, init_val['pos'], pos_query_imgs,
            check_folder(os.path.join(save_dir, 'pos_results')), opt, vec_loss,
            vec_losses)
        save_files(save_dir, ['pos_loss'], [query_loss])

        query_loss, query_z, query_xhat = optimize_z(
            sess, z, x, x_hat, init_val_ph, init_val['neg'], neg_query_imgs,
            check_folder(os.path.join(save_dir, 'neg_results')), opt, vec_loss,
            vec_losses)
        save_files(save_dir, ['neg_loss'], [query_loss])
Example #17
# get test IDs
test_fns = glob.glob(gt_dir + '/%d*.ARW' % d_id)
test_ids = [int(os.path.basename(test_fn)[0:5]) for test_fn in test_fns]

misaligned = [10034, 10045, 10172]

sess = tf.Session()
in_image = tf.placeholder(tf.float32, [None, None, None, None, 4])
in_image_low = tf.placeholder(tf.float32, [None, None, None, None, 4])
gt_image = tf.placeholder(tf.float32, [None, None, None, 3])

coarse_outs = burst_nets.coarse_net(in_image_low)
out_image = burst_nets.fine_net(in_image, coarse_outs)

distance_t = lpips_tf.lpips(gt_image, out_image, model='net-lin', net='alex')

t_vars = tf.trainable_variables()
saver = tf.train.Saver(t_vars)
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

ssim_list = []
psnr_list = []
lpips_list = []
time_list = []
ratio_list = []
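The snippet ends after initializing the metric lists. A sketch of one evaluation step that would populate `lpips_list`; `gt_np`, `burst_np`, and `burst_low_np` stand in for the loaded ground-truth image and input bursts, whose loading code is not shown here:

d = sess.run(distance_t,
             feed_dict={
                 gt_image: gt_np,
                 in_image: burst_np,
                 in_image_low: burst_low_np
             })
lpips_list.append(float(d.mean()))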