Code example #1
import numpy as np
import cv2
from skimage.measure import compare_ssim  # skimage < 0.18
# MultiScaleSSIM is assumed to be provided elsewhere in the project.

def image_to_patches(IMG1, IMG2, PATCH_HEIGHT, PATCH_WIDTH, IMG2_Path, IMG1_Path, filename, k, pair_thres, adj_thres):
    """Cut two aligned PIL images into patch pairs and save pairs that match
    each other well (>= pair_thres) but differ from the previously saved patch
    (<= adj_thres). Returns the updated patch counter k."""
    prev_patch = None
    for i in range(0,IMG1.size[1],PATCH_HEIGHT):
        for j in range(0,IMG1.size[0],PATCH_WIDTH):
            if(j + PATCH_WIDTH <= IMG1.size[0] and i + PATCH_HEIGHT <= IMG1.size[1]):
                box = (j, i, j+PATCH_WIDTH, i+PATCH_HEIGHT)
                IMG2_patch = IMG2.crop(box)
                IMG1_patch = IMG1.crop(box)

                # Convert the PIL patch to an OpenCV-style BGR array for compare_ssim.
                IMG1_cv2 = np.array(IMG1_patch.convert('RGB'))
                IMG1_cv2 = cv2.cvtColor(IMG1_cv2, cv2.COLOR_RGB2BGR)
                #pair_eval = compare_ssim(np.array(IMG1_patch), np.array(IMG2_patch), multichannel=True)
                pair_eval = MultiScaleSSIM(np.expand_dims(np.array(IMG1_patch), axis=0),
                                           np.expand_dims(np.array(IMG2_patch), axis=0), max_val=255)

                # Keep the pair if the patches match well and are not
                # near-duplicates of the previously saved patch.
                if pair_eval >= pair_thres and (prev_patch is None or
                        compare_ssim(IMG1_cv2, prev_patch, multichannel=True) <= adj_thres):
                    IMG2_patch.save(IMG2_Path + '(' + str(k) + ").jpg")
                    IMG1_patch.save(IMG1_Path + '(' + str(k) + ").jpg")
                    k = k + 1
                    prev_patch = IMG1_cv2
                        
    return k
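A minimal usage sketch for the function above, assuming two aligned images on disk; the paths, patch size, and thresholds below are hypothetical examples, not values from the original project:

from PIL import Image

# Hypothetical inputs: a phone photo and the aligned DSLR reference.
img1 = Image.open('phone/0001.jpg')
img2 = Image.open('dslr/0001.jpg')

k = image_to_patches(img1, img2,
                     PATCH_HEIGHT=100, PATCH_WIDTH=100,
                     IMG2_Path='patches/dslr/', IMG1_Path='patches/phone/',
                     filename='0001.jpg', k=0,
                     pair_thres=0.9,   # keep only well-matched patch pairs
                     adj_thres=0.8)    # skip near-duplicates of the last saved patch
print('saved %d patch pairs so far' % k)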
Code example #2
def main(args):
    # loading training and test data
    logger.info("Loading test data...")
    test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
    logger.info("Test data was loaded\n")

    logger.info("Loading training data...")
    train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
    logger.info("Training data was loaded\n")

    TEST_SIZE = test_data.shape[0]
    num_test_batches = int(test_data.shape[0] / args.batch_size)

    # defining system architecture
    with tf.Graph().as_default(), tf.Session() as sess:

        # placeholders for training data
        phone_ = tf.placeholder(tf.float32, [None, args.patch_size])
        phone_image = tf.reshape(phone_, [-1, args.patch_height, args.patch_width, 3])

        dslr_ = tf.placeholder(tf.float32, [None, args.patch_size])
        dslr_image = tf.reshape(dslr_, [-1, args.patch_height, args.patch_width, 3])

        adv_ = tf.placeholder(tf.float32, [None, 1])
        enhanced = unet(phone_image)
        [w, h, d] = enhanced.get_shape().as_list()[1:]

        # # learning rate exponential_decay
        # global_step = tf.Variable(0)
        # learning_rate = tf.train.exponential_decay(args.learning_rate, global_step, decay_steps=args.train_size / args.batch_size, decay_rate=0.98, staircase=True)

        ## loss definitions
        '''
        Three options for the content loss:
        1. vgg_loss with a .mat model;
        2. vgg_loss with a .npy model;
        3. IQA model (meon_loss): features and scores.
        '''
        # vgg = vgg19_loss.Vgg19(vgg_path=args.pretrain_weights) #  # load vgg models
        # vgg_content = 2000*tf.reduce_mean(tf.sqrt(tf.reduce_sum(
        #     tf.square((vgg.extract_feature(enhanced) - vgg.extract_feature(dslr_image))))) / (w * h * d))
        # # loss_content = multi_content_loss(args.pretrain_weights, enhanced, dslr_image, args.batch_size) # change another way

        # meon loss
        # with tf.variable_scope('meon_loss') as scope:  # loading the ckpt this way is not convenient.
        MEON_evaluate_model, loss_content = meon_loss(dslr_image, enhanced)

        loss_texture, discim_accuracy = texture_loss(enhanced, dslr_image, args.patch_width, args.patch_height, adv_)
        loss_discrim = -loss_texture

        loss_color = color_loss(enhanced, dslr_image, args.batch_size)
        loss_tv = variation_loss(enhanced, args.patch_width, args.patch_height, args.batch_size)

        loss_psnr = PSNR(enhanced, dslr_image)
        loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

        loss_generator = args.w_content * loss_content + args.w_texture * loss_texture + args.w_tv * loss_tv + 1000 * (
                    1 - loss_ssim) + args.w_color * loss_color

        # optimize parameters of image enhancement (generator) and discriminator networks
        generator_vars = [v for v in tf.global_variables() if v.name.startswith("generator")]
        discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]
        meon_vars = [v for v in tf.global_variables() if v.name.startswith("conv") or v.name.startswith("subtask")]

        # train_step_gen = tf.train.AdamOptimizer(args.learning_rate).minimize(loss_generator, var_list=generator_vars)
        # train_step_disc = tf.train.AdamOptimizer(args.learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

        train_step_gen = tf.train.AdamOptimizer(5e-5).minimize(loss_generator, var_list=generator_vars)
        train_step_disc = tf.train.AdamOptimizer(5e-5).minimize(loss_discrim, var_list=discriminator_vars)

        saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)
        meon_saver = tf.train.Saver(var_list=meon_vars)

        logger.info('Initializing variables')
        sess.run(tf.global_variables_initializer())
        logger.info('Training network')
        train_loss_gen = 0.0
        train_acc_discrim = 0.0
        all_zeros = np.zeros((args.batch_size, 1))
        test_crops = test_data[np.random.randint(0, TEST_SIZE, 5), :]  # pick five test crops for visual inspection

        # summaries: add the scalars you want to track
        tf.summary.scalar('loss_generator', loss_generator)
        tf.summary.scalar('loss_content', loss_content)
        tf.summary.scalar('loss_color', loss_color)
        tf.summary.scalar('loss_texture', loss_texture)
        tf.summary.scalar('loss_tv', loss_tv)
        tf.summary.scalar('discim_accuracy', discim_accuracy)
        tf.summary.scalar('psnr', loss_psnr)
        tf.summary.scalar('ssim', loss_ssim)
        tf.summary.scalar('learning_rate', args.learning_rate)
        merge_summary = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'train', args.exp_name), sess.graph,
                                             filename_suffix=args.exp_name)
        test_writer = tf.summary.FileWriter(os.path.join(args.tesorboard_logs_dir, 'test', args.exp_name), sess.graph,
                                            filename_suffix=args.exp_name)

        '''load ckpt models'''
        ckpt = tf.train.get_checkpoint_state(args.checkpoint_dir)
        start_i = 0
        if ckpt and ckpt.model_checkpoint_path:
            logger.info('loading checkpoint:' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            import re
            start_i = int(re.findall(r"_(\d+)\.ckpt", ckpt.model_checkpoint_path)[0])
        MEON_evaluate_model.initialize(sess, meon_saver,
                                       args.meod_ckpt_path)  # initialize with another model's pretrained weights

        '''start training...'''
        for i in range(start_i, args.iter_max):

            iter_start = time.time()
            # train generator
            idx_train = np.random.randint(0, args.train_size, args.batch_size)
            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]

            [loss_temp, temp] = sess.run([loss_generator, train_step_gen],
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros})
            train_loss_gen += loss_temp / args.eval_step

            # train discriminator
            idx_train = np.random.randint(0, args.train_size, args.batch_size)

            # generate image swaps (dslr or enhanced) for discriminator
            swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])

            phone_images = train_data[idx_train]
            dslr_images = train_answ[idx_train]
            # sess.run(train_step_disc) runs compute_gradients(loss, var) followed by apply_gradients(var)
            [accuracy_temp, temp] = sess.run([discim_accuracy, train_step_disc],
                                             feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
            train_acc_discrim += accuracy_temp / args.eval_step

            if i % args.summary_step == 0:
                # summary intervals
                # enhance_f1_, enhance_f2_, enhance_s_, vgg_content_ = sess.run([enhance_f1, enhance_f2, enhance_s,vgg_content],
                #                          feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                # loss_content1_, loss_content2_, loss_content3_ = sess.run([loss_content1,loss_content2,loss_content3],
                #                          feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                # print("-----------------------------------------------")
                # print(enhance_f1_, enhance_f2_, enhance_s_,vgg_content_,loss_content1_, loss_content2_, loss_content3_)
                # print("-----------------------------------------------")
                train_summary = sess.run(merge_summary,
                                         feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                train_writer.add_summary(train_summary, i)

            if i % args.eval_step == 0:
                # test generator and discriminator CNNs
                test_losses_gen = np.zeros((1, 7))
                test_accuracy_disc = 0.0

                for j in range(num_test_batches):
                    be = j * args.batch_size
                    en = (j + 1) * args.batch_size

                    swaps = np.reshape(np.random.randint(0, 2, args.batch_size), [args.batch_size, 1])
                    phone_images = test_data[be:en]
                    dslr_images = test_answ[be:en]

                    [enhanced_crops, accuracy_disc, losses] = sess.run(
                        [enhanced, discim_accuracy,
                         [loss_generator, loss_content, loss_color,
                          loss_texture, loss_tv, loss_psnr, loss_ssim]],
                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})

                    test_losses_gen += np.asarray(losses) / num_test_batches
                    test_accuracy_disc += accuracy_disc / num_test_batches

                logs_disc = "step %d/%d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
                            (i, args.iter_max, args.dataset, train_acc_discrim, test_accuracy_disc)
                logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, texture: %.4g, tv: %.4g | psnr: %.4g, ssim: %.4g\n" % \
                           (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2],
                            test_losses_gen[0][3], test_losses_gen[0][4], test_losses_gen[0][5], test_losses_gen[0][6])

                logger.info(logs_disc)
                logger.info(logs_gen)

                test_summary = sess.run(merge_summary,
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
                test_writer.add_summary(test_summary, i)

                # save visual results for several test image crops
                if args.save_visual_result:
                    enhanced_crops = sess.run(enhanced,
                                              feed_dict={phone_: test_crops, dslr_: dslr_images, adv_: all_zeros})
                    idx = 0
                    for crop in enhanced_crops:
                        before_after = np.hstack(
                            (np.reshape(test_crops[idx], [args.patch_height, args.patch_width, 3]), crop))
                        misc.imsave(
                            os.path.join(args.checkpoint_dir, str(args.dataset) + str(idx) + '_iteration_' + str(i) +
                                         '.jpg'), before_after)
                        idx += 1

                # save the model that corresponds to the current iteration
                if args.save_ckpt_file:
                    saver.save(sess,
                               os.path.join(args.checkpoint_dir, str(args.dataset) + '_iteration_' + str(i) + '.ckpt'),
                               write_meta_graph=False)

                train_loss_gen = 0.0
                train_acc_discrim = 0.0
                # reload a different batch of training data
                del train_data
                del train_answ
                del test_data
                del test_answ
                test_data, test_answ = load_test_data(args.dataset, args.dataset_dir, args.test_size, args.patch_size)
                train_data, train_answ = load_batch(args.dataset, args.dataset_dir, args.train_size, args.patch_size)
Code example #3
def Mssim_loss(target, prediction):
    """Scaled MS-SSIM loss: 1000 * (1 - MS-SSIM), the same term that appears
    in the generator loss of code example #2."""
    loss_Mssim = 1 - MultiScaleSSIM(target, prediction)
    return loss_Mssim * 1000
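A quick numeric sketch, assuming the numpy MultiScaleSSIM used in code example #1 (batched [N, H, W, C] arrays, default max_val=255); the image size is chosen large enough for the five MS-SSIM scales:

import numpy as np

# Two hypothetical 256x256 RGB batches: a reference and a noisy copy.
target = np.random.randint(0, 256, (1, 256, 256, 3)).astype(np.float32)
prediction = np.clip(target + np.random.normal(0, 5, target.shape), 0, 255).astype(np.float32)

print(Mssim_loss(target, prediction))  # near 0 for similar images, up to 1000 for dissimilar ones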
Code example #4
File: train.py  Project: ligua/cnn_image_enhance
def main(args, data_params):
    procname = os.path.basename(args.checkpoint_dir)

    log.info('Preparing summary and checkpoint directory {}'.format(
        args.checkpoint_dir))
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    tf.set_random_seed(1234)  # Make experiments repeatable

    # Select an architecture

    # Add model parameters to the graph (so they are saved to disk at checkpoint)

    # --- Train/Test datasets ---------------------------------------------------
    data_pipe = getattr(dp, args.data_pipeline)
    with tf.variable_scope('train_data'):
        train_data_pipeline = data_pipe(
            args.data_dir,
            shuffle=True,
            batch_size=args.batch_size,
            nthreads=args.data_threads,
            fliplr=args.fliplr,
            flipud=args.flipud,
            rotate=args.rotate,
            random_crop=args.random_crop,
            params=data_params,
            output_resolution=args.output_resolution,
            scale=args.scale)
        train_samples = train_data_pipeline.samples

    if args.eval_data_dir is not None:
        with tf.variable_scope('eval_data'):
            eval_data_pipeline = data_pipe(
                args.eval_data_dir,
                shuffle=True,
                batch_size=args.batch_size,
                nthreads=args.data_threads,
                fliplr=False,
                flipud=False,
                rotate=False,
                random_crop=False,
                params=data_params,
                output_resolution=args.output_resolution,
                scale=args.scale)
            eval_samples = eval_data_pipeline.samples
    # ---------------------------------------------------------------------------
    swaps = np.reshape(np.random.randint(0, 2, args.batch_size),
                       [args.batch_size, 1])
    swaps = tf.convert_to_tensor(swaps)
    swaps = tf.cast(swaps, tf.float32)
    # Training graph
    with tf.variable_scope('inference'):
        prediction = unet(train_samples['image_input'])
        loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv, discim_accuracy = \
            compute_loss.total_loss(train_samples['image_output'], prediction, swaps, args.batch_size)
        psnr = PSNR(train_samples['image_output'], prediction)
        loss_ssim = MultiScaleSSIM(train_samples['image_output'], prediction)

    # Evaluation graph
    if args.eval_data_dir is not None:
        with tf.name_scope('eval'):
            with tf.variable_scope('inference', reuse=True):
                eval_prediction = unet(eval_samples['image_input'])
            eval_psnr = PSNR(eval_samples['image_output'], eval_prediction)
            eval_ssim = MultiScaleSSIM(eval_samples['image_output'],
                                       eval_prediction)

    # Optimizer
    model_vars1 = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/generator")
    ]
    discriminator_vars1 = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/l2_loss/discriminator")
    ]

    global_step = tf.contrib.framework.get_or_create_global_step()
    with tf.name_scope('optimizer'):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        updates = tf.group(*update_ops, name='update_ops')
        log.info("Adding {} update ops".format(len(update_ops)))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses and args.weight_decay is not None and args.weight_decay > 0:
            print("Regularization losses:")
            for rl in reg_losses:
                print(" ", rl.name)
            opt_loss = loss + args.weight_decay * sum(reg_losses)
        else:
            print("No regularization.")
            opt_loss = loss

        with tf.control_dependencies([updates]):
            opt = tf.train.AdamOptimizer(args.learning_rate)
            minimize = opt.minimize(opt_loss,
                                    name='optimizer',
                                    global_step=global_step,
                                    var_list=model_vars1)
            minimize_discrim = opt.minimize(-loss_texture,
                                            name='discriminator',
                                            global_step=global_step,
                                            var_list=discriminator_vars1)

    # Average loss and psnr for display
    with tf.name_scope("moving_averages"):
        ema = tf.train.ExponentialMovingAverage(decay=0.99)
        update_ma = ema.apply([
            loss, loss_content, loss_texture, loss_color, loss_Mssim, loss_tv,
            discim_accuracy, psnr, loss_ssim
        ])
        loss = ema.average(loss)
        loss_content = ema.average(loss_content)
        loss_texture = ema.average(loss_texture)
        loss_color = ema.average(loss_color)
        loss_Mssim = ema.average(loss_Mssim)
        loss_tv = ema.average(loss_tv)
        discim_accuracy = ema.average(discim_accuracy)
        psnr = ema.average(psnr)
        loss_ssim = ema.average(loss_ssim)

    # Training stepper operation
    train_op = tf.group(minimize, update_ma)
    train_discrim_op = tf.group(minimize_discrim, update_ma)

    # Scalar summaries to display in TensorBoard
    summaries = [
        tf.summary.scalar('loss', loss),
        tf.summary.scalar('loss_content', loss_content),
        tf.summary.scalar('loss_color', loss_color),
        tf.summary.scalar('loss_texture', loss_texture),
        tf.summary.scalar('loss_ssim', loss_Mssim),
        tf.summary.scalar('loss_tv', loss_tv),
        tf.summary.scalar('discim_accuracy', discim_accuracy),
        tf.summary.scalar('psnr', psnr),
        tf.summary.scalar('ssim', loss_ssim),
        tf.summary.scalar('learning_rate', args.learning_rate),
        tf.summary.scalar('batch_size', args.batch_size),
    ]

    log_fetches = {
        "loss_content": loss_content,
        "loss_texture": loss_texture,
        "loss_color": loss_color,
        "loss_Mssim": loss_Mssim,
        "loss_tv": loss_tv,
        "discim_accuracy": discim_accuracy,
        "step": global_step,
        "loss": loss,
        "psnr": psnr,
        "loss_ssim": loss_ssim
    }

    model_vars = [
        v for v in tf.global_variables()
        if not v.name.startswith("inference/l2_loss/discriminator")
    ]
    discriminator_vars = [
        v for v in tf.global_variables()
        if v.name.startswith("inference/l2_loss/discriminator")
    ]

    # Train config
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not cannibalize the entire GPU

    sv = tf.train.Supervisor(
        saver=tf.train.Saver(var_list=model_vars, max_to_keep=100),
        local_init_op=tf.variables_initializer(discriminator_vars),
        logdir=args.checkpoint_dir,
        save_summaries_secs=args.summary_interval,
        save_model_secs=args.checkpoint_interval)
    # Train loop
    with sv.managed_session(config=config) as sess:
        sv.loop(args.log_interval, log_hook, (sess, log_fetches))
        last_eval = time.time()
        while True:
            if sv.should_stop():
                log.info("stopping supervisor")
                break
            try:
                step, _ = sess.run([global_step, train_op])
                _ = sess.run(train_discrim_op)
                since_eval = time.time() - last_eval

                if args.eval_data_dir is not None and since_eval > args.eval_interval:
                    log.info("Evaluating on {} images at step {}".format(
                        3, step))

                    p_ = 0
                    s_ = 0
                    for it in range(3):
                        p_ += sess.run(eval_psnr)
                        s_ += sess.run(eval_ssim)
                    p_ /= 3
                    s_ /= 3

                    sv.summary_writer.add_summary(
                        tf.Summary(value=[tf.Summary.Value(tag="psnr/eval", simple_value=p_)]),
                        global_step=step)

                    sv.summary_writer.add_summary(
                        tf.Summary(value=[tf.Summary.Value(tag="ssim/eval", simple_value=s_)]),
                        global_step=step)

                    log.info("  Evaluation PSNR = {:.2f} dB".format(p_))
                    log.info("  Evaluation SSIM = {:.4f} ".format(s_))

                    last_eval = time.time()

            except tf.errors.AbortedError:
                log.error("Aborted")
                break
            except KeyboardInterrupt:
                break
        chkpt_path = os.path.join(args.checkpoint_dir, 'on_stop.ckpt')
        log.info("Training complete, saving chkpt {}".format(chkpt_path))
        sv.saver.save(sess, chkpt_path)
        sv.request_stop()
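A minimal standalone sketch of the ExponentialMovingAverage display pattern used in the moving_averages scope above: apply() creates shadow variables and returns the update op, and average() reads the smoothed value that the summaries and logs consume.

import tensorflow as tf

raw = tf.Variable(0.0)                               # stand-in for a loss tensor
ema = tf.train.ExponentialMovingAverage(decay=0.99)
update_ma = ema.apply([raw])                         # op that updates the shadow copy
smoothed = ema.average(raw)                          # smoothed read-out for summaries/logs

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for v in [1.0, 2.0, 3.0]:
        sess.run(raw.assign(v))
        sess.run(update_ma)
        print(sess.run(smoothed))                    # lags behind the raw values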
Code example #5
File: train.py  Project: gengmr/TrackA
    enhanced = EDSR(phone_image)
    print(enhanced.shape)

    # loss definitions
    # loss_texture, discim_accuracy = texture_loss(enhanced,dslr_image,PATCH_WIDTH,PATCH_HEIGHT,adv_)
    # loss_discrim = -loss_texture
    # loss_content = content_loss(vgg_dir,enhanced,dslr_image,batch_size)
    # loss_color = color_loss(enhanced, dslr_image, batch_size)
    # loss_tv = variation_loss(enhanced,PATCH_WIDTH,PATCH_HEIGHT,batch_size)

    # loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv
    loss_generator = tf.losses.absolute_difference(labels=dslr_image,
                                                   predictions=enhanced)
    loss_psnr = PSNR(enhanced, dslr_, PATCH_SIZE, batch_size)
    loss_ssim = MultiScaleSSIM(enhanced, dslr_image)

    # optimize parameters of image enhancement (generator) and discriminator networks
    generator_vars = [
        v for v in tf.global_variables() if v.name.startswith("generator")
    ]
    # discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]

    train_step_gen = tf.train.AdamOptimizer(learning_rate).minimize(
        loss_generator, var_list=generator_vars)
    # train_step_disc = tf.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

    saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)

    print('Initializing variables')
    sess.run(tf.global_variables_initializer())
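A hypothetical continuation of this fragment showing one training step, following the batching pattern of code example #2; phone_, dslr_, train_data, train_answ, train_size, and batch_size are assumed to be defined earlier in the surrounding script:

import numpy as np

idx = np.random.randint(0, train_size, batch_size)
feed = {phone_: train_data[idx], dslr_: train_answ[idx]}

# One generator update plus a loss read-out.
_, l1 = sess.run([train_step_gen, loss_generator], feed_dict=feed)
print('L1 loss: %.4g' % l1)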