예제 #1
0
def _show_training_sample(index, denormalize, pred_I2_value, I2_value,
                          I1_aug_value, I2_aug_value, I_value):
    """Plot one sample of a fetched training batch with matplotlib.

    Draws three rows: predicted I2 next to I2, I1_aug next to I2_aug, and I.
    Only channel 0 of the patch arrays is shown; I is drawn as-is when its
    third axis has 3 entries, otherwise its channel 0 is drawn.

    Args:
        index: position of the sample within the fetched batch.
        denormalize: if True, every image goes through ``utils.denorm_img``
            before the uint8 cast (used when 'normalize' augmentation is on).
        pred_I2_value, I2_value, I1_aug_value, I2_aug_value: fetched 4-D
            arrays, indexed here as ``[index, :, :, 0]``.
        I_value: fetched batch of images, indexed here as ``[index, ...]``.
    """
    def to_uint8(img):
        if denormalize:
            img = utils.denorm_img(img)
        return img.astype(np.uint8)

    pred_I2_sample = to_uint8(pred_I2_value[index, :, :, 0])
    I2_sample = to_uint8(I2_value[index, :, :, 0])
    I1_aug_sample = to_uint8(I1_aug_value[index, :, :, 0])
    I2_aug_sample = to_uint8(I2_aug_value[index, :, :, 0])
    I_sample = to_uint8(I_value[index, ...])

    plt.subplot(3, 1, 1)
    plt.imshow(np.concatenate([pred_I2_sample, I2_sample], 1), cmap='gray')
    plt.title('Pred I2 vs I2')
    plt.subplot(3, 1, 2)
    plt.imshow(np.concatenate([I1_aug_sample, I2_aug_sample], 1), cmap='gray')
    plt.title('I1_aug vs I2_aug')
    plt.subplot(3, 1, 3)
    plt.imshow(I_sample if I_sample.shape[2] == 3 else I_sample[:, :, 0])
    plt.title('I')
    plt.show()
    plt.pause(0.05)


def train():
    """Build the multi-GPU homography training graph and run training.

    All configuration comes from the module-level ``args`` namespace. Each
    input batch is split along axis 0 into ``args.num_gpus`` towers that
    share variables. The gradients of the loss named by ``args.loss_type``
    are averaged across towers (``utils.get_average_grads``) and applied
    with Adam under a staircase exponential learning-rate decay. Training
    runs ``num_total_steps`` steps starting from the current
    ``global_step``. It writes TensorBoard summaries and a checkpoint every
    1000 steps, plus a final checkpoint at the end.

    Raises:
        ValueError: if ``args.loss_type`` is not a supported loss name.
    """
    # Losses exposed by every HomographyModel tower; --loss_type selects the
    # one whose gradients are applied.
    loss_names = ('h_loss', 'rec_loss', 'ssim_loss', 'l1_loss',
                  'l1_smooth_loss', 'ncc_loss')
    # Validate before the (slow) graph/data-pipeline build, and fail with an
    # error rather than exiting with a success status.
    if args.loss_type not in loss_names:
        print('===> Loss type does not exist!')
        raise ValueError('Unsupported loss_type %r; expected one of: %s' %
                         (args.loss_type, ', '.join(loss_names)))
    print('====> Use loss type: ', args.loss_type)

    # Build everything in a fresh graph; ops default to GPU 0.
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        global_step = tf.Variable(0, trainable=False)

        # Count the number of training data
        num_data = utils.count_text_lines(args.filenames_file)
        print('===> Train: There are totally %d training files' % (num_data))

        num_total_steps = 150000

        # Exponential decay:
        #   decayed_lr = lr * decay_rate ^ (global_step / decay_steps)
        # decay_steps solves lr * decay_rate ^ (num_total_steps / decay_steps)
        # == min_lr, so the rate approaches args.min_lr by the last step
        # (approximately, since staircase=True and int() truncate).
        decay_rate = 0.96
        decay_steps = (math.log(decay_rate) * num_total_steps) / math.log(
            args.min_lr * 1.0 / args.lr)
        print('args lr:', args.lr, args.min_lr)
        print('===> Decay steps:', decay_steps)
        learning_rate = tf.train.exponential_decay(args.lr,
                                                   global_step,
                                                   int(decay_steps),
                                                   decay_rate,
                                                   staircase=True)
        opt_step = tf.train.AdamOptimizer(learning_rate)

        # Load data
        data_loader = Dataloader(train_dataloader_params, shuffle=True)

        # Split every input batch along axis 0 into one slice per GPU.
        def split(batch):
            return tf.split(batch, args.num_gpus, 0)

        I1_splits = split(data_loader.I1_batch)
        I2_splits = split(data_loader.I2_batch)
        I1_aug_splits = split(data_loader.I1_aug_batch)
        I2_aug_splits = split(data_loader.I2_aug_batch)
        I_splits = split(data_loader.I_batch)
        I_prime_splits = split(data_loader.I_prime_batch)
        pts1_splits = split(data_loader.pts1_batch)
        gt_splits = split(data_loader.gt_batch)
        patch_indices_splits = split(data_loader.patch_indices_batch)

        # Train on multiple GPUs.
        multi_grads = []
        tower_losses = {name: [] for name in loss_names}
        reuse_variables = None
        model_params = homography_model_params(
            mode=args.mode,
            batch_size=int(args.batch_size / args.num_gpus),
            patch_size=args.patch_size,
            img_h=args.img_h,
            img_w=args.img_w,
            loss_type=args.loss_type,
            use_batch_norm=args.use_batch_norm,
            augment_list=args.augment_list,
            leftright_consistent_weight=args.leftright_consistent_weight)
        # Deal with sharable variables
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(args.num_gpus):
                with tf.device('/gpu:%d' % i):
                    model = HomographyModel(model_params,
                                            I1_splits[i],
                                            I2_splits[i],
                                            I1_aug_splits[i],
                                            I2_aug_splits[i],
                                            I_splits[i],
                                            I_prime_splits[i],
                                            pts1_splits[i],
                                            gt_splits[i],
                                            patch_indices_splits[i],
                                            reuse_variables=reuse_variables,
                                            model_index=i)
                    # Towers after the first reuse the first tower's variables.
                    reuse_variables = True

                    # Read each loss tensor once; the same tensor is used for
                    # both the reported averages and the gradients.
                    losses = {name: getattr(model, name) for name in loss_names}
                    for name in loss_names:
                        tower_losses[name].append(losses[name])

                    # Tensors fetched for visualization; the last tower's
                    # tensors are the ones left after the loop.
                    pred_I2 = model.pred_I2
                    I2 = model.I2
                    I = model.I
                    I1_aug = model.I1_aug
                    I2_aug = model.I2_aug

                    multi_grads.append(
                        opt_step.compute_gradients(losses[args.loss_type]))

        # If the model's batch-norm layers register their moving-average
        # updates in tf.GraphKeys.UPDATE_OPS (slim.batch_norm's default), those
        # ops exist only after the towers above are built. So the collection
        # is read here and attached to the op that runs every training step.
        # Read earlier, it would be empty, and wrapping the optimizer's
        # constructor in control_dependencies creates no dependency at all.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Take average of the grads
        grads = utils.get_average_grads(multi_grads)
        with tf.control_dependencies(update_ops):
            apply_grad_opt = opt_step.apply_gradients(grads,
                                                      global_step=global_step)

        total_h_loss = tf.reduce_mean(tower_losses['h_loss'])
        total_rec_loss = tf.reduce_mean(tower_losses['rec_loss'])
        total_ssim_loss = tf.reduce_mean(tower_losses['ssim_loss'])
        total_l1_loss = tf.reduce_mean(tower_losses['l1_loss'])
        total_l1_smooth_loss = tf.reduce_mean(tower_losses['l1_smooth_loss'])
        total_ncc_loss = tf.reduce_mean(tower_losses['ncc_loss'])
        with tf.name_scope('Losses'):
            tf.summary.scalar('Learning_rate', learning_rate)
            tf.summary.scalar('Total_h_loss', total_h_loss)
            tf.summary.scalar('Total_rec_loss', total_rec_loss)
            tf.summary.scalar('Total_ssim_loss', total_ssim_loss)
            tf.summary.scalar('Total_l1_loss', total_l1_loss)
            tf.summary.scalar('Total_l1_smooth_loss', total_l1_smooth_loss)
            tf.summary.scalar('Total_ncc_loss', total_ncc_loss)
        summary_opt = tf.summary.merge_all()

        # Create a session. allow_growth: do not pre-allocate GPU memory up
        # front. allow_soft_placement: let ops fall back to another device
        # (e.g. CPU) when the requested GPU is unavailable.
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=gpu_options)
        sess = tf.Session(config=config)

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)
        train_saver = tf.train.Saver(max_to_keep=5)  # Keep maximum 5 models

        # Initialize
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # Threads coordinator
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

        try:
            # Restore
            if args.resume:
                train_saver.restore(sess,
                                    tf.train.latest_checkpoint(args.model_dir))
                if args.retrain:
                    sess.run(global_step.assign(0))

            index = 0  # Index within the batch of the sample to display
            h_total_loss_value = 0
            rec_total_loss_value = 0
            ssim_total_loss_value = 0
            l1_total_loss_value = 0
            l1_smooth_total_loss_value = 0
            ncc_total_loss_value = 0

            start_step = global_step.eval(session=sess)
            print('===> Start step:', start_step)

            # Start training
            for step in range(start_step, start_step + num_total_steps):
                # Steps run so far; divides the running loss sums.
                seen = step - start_step + 1
                if args.visual:
                    (_, pred_I2_value, I2_value, I1_aug_value, I2_aug_value,
                     I_value) = sess.run(
                         [apply_grad_opt, pred_I2, I2, I1_aug, I2_aug, I])
                elif args.loss_type == 'l1_loss':
                    (_, h_loss_value, l1_loss_value, l1_smooth_loss_value,
                     lr_value) = sess.run([
                         apply_grad_opt, total_h_loss, total_l1_loss,
                         total_l1_smooth_loss, learning_rate
                     ])
                    h_total_loss_value += h_loss_value
                    l1_total_loss_value += l1_loss_value
                    l1_smooth_total_loss_value += l1_smooth_loss_value
                    if step % 100 == 0:
                        utils.progress_bar(
                            step, num_total_steps + start_step - 1,
                            'Train: 1, h_loss %4.3f, l1_loss %.6f, l1_smooth_loss %.6f, lr %.6f'
                            % (h_total_loss_value / seen,
                               l1_total_loss_value / seen,
                               l1_smooth_total_loss_value / seen, lr_value))
                else:
                    (_, h_loss_value, rec_loss_value, ssim_loss_value,
                     l1_loss_value, l1_smooth_loss_value, ncc_loss_value,
                     lr_value) = sess.run([
                         apply_grad_opt, total_h_loss, total_rec_loss,
                         total_ssim_loss, total_l1_loss, total_l1_smooth_loss,
                         total_ncc_loss, learning_rate
                     ])
                    h_total_loss_value += h_loss_value
                    rec_total_loss_value += rec_loss_value
                    ssim_total_loss_value += ssim_loss_value
                    l1_total_loss_value += l1_loss_value
                    l1_smooth_total_loss_value += l1_smooth_loss_value
                    ncc_total_loss_value += ncc_loss_value
                    if step % 100 == 0:
                        utils.progress_bar(
                            step, num_total_steps + start_step - 1,
                            'Train: 1, h_loss %4.3f, rec_loss %4.3f, ssim_loss %.6f. l1_loss %.6f, l1_smooth_loss %.6f, ncc_loss %.6f, lr %.6f'
                            % (h_total_loss_value / seen,
                               rec_total_loss_value / seen,
                               ssim_total_loss_value / seen,
                               l1_total_loss_value / seen,
                               l1_smooth_total_loss_value / seen,
                               ncc_total_loss_value / seen, lr_value))

                # Tensorboard
                if step % 1000 == 0:
                    # NOTE(review): this is a separate sess.run, so it
                    # presumably pulls an extra batch from the input pipeline
                    # that is not trained on — confirm this is acceptable.
                    summary_str = sess.run(summary_opt)
                    summary_writer.add_summary(summary_str, global_step=step)
                if step and step % 1000 == 0:
                    train_saver.save(sess,
                                     args.model_dir + args.model_name,
                                     global_step=step)

                if args.visual:
                    _show_training_sample(index,
                                          'normalize' in args.augment_list,
                                          pred_I2_value, I2_value,
                                          I1_aug_value, I2_aug_value, I_value)

            # Save the final model
            train_saver.save(sess,
                             args.model_dir + args.model_name,
                             global_step=step)
        finally:
            # Stop the input-queue threads and release the session/writer even
            # if training raises, so the process can shut down cleanly.
            coordinator.request_stop()
            try:
                coordinator.join(threads)
            finally:
                summary_writer.close()
                sess.close()
예제 #2
0
    def __init__(self):
        """Build the evaluation graph with one HomographyModel tower per GPU.

        Configuration is read from the module-level ``args``. Sets:

        * ``self.num_total_steps``: three passes over the test file list
          (``3 * ceil(num_files / args.batch_size)``).
        * ``self.total_h_loss``: mean of each tower's ``bounded_h_loss``.
        * ``self.total_num_fail``: sum of each tower's ``num_fail``.
        * ``self.total_rec_loss``, ``self.total_ssim_loss``,
          ``self.total_l1_loss`` and ``self.total_l1_smooth_loss``: means
          over the towers.
        * ``self.summary_opt``: merged summary op.
        * Per-sample tensors ``self.pred_I2``, ``self.I2``, ``self.H_mat``,
          ``self.I1``, ``self.I1_aug``, ``self.I2_aug``, ``self.I``,
          ``self.I_prime``, ``self.pts1``, ``self.gt`` and ``self.pred_h4p``.
          These are reassigned for every tower, so with ``args.num_gpus > 1``
          they hold the LAST tower's slice only.
        """
        # Default device for the ops built below (each tower picks its own
        # GPU). No new graph is created: ops go into the current default graph.
        with tf.device('/gpu:0'):
            num_test_files = utils.count_text_lines(args.test_filenames_file)
            print('===> Test: There are totally %d Test files' % (num_test_files))
            batches_per_pass = np.ceil(num_test_files / args.batch_size).astype(
                np.int32)
            self.num_total_steps = 3 * batches_per_pass  # three test passes

            # NOTE(review): the test data is shuffled here — confirm intended.
            loader = Dataloader(test_dataloader_params, shuffle=True)

            # Input streams, listed in HomographyModel's positional order.
            input_names = ('I1', 'I2', 'I1_aug', 'I2_aug', 'I', 'I_prime',
                           'pts1', 'gt', 'patch_indices')
            input_batches = [
                getattr(loader, name + '_batch') for name in input_names
            ]
            # For each stream, its list of per-GPU slices along axis 0.
            per_gpu_inputs = [
                tf.split(batch, args.num_gpus, 0) for batch in input_batches
            ]

            # Tower outputs aggregated after the loop (test uses the bounded
            # homography loss, i.e. under the success condition).
            loss_keys = ('bounded_h_loss', 'rec_loss', 'ssim_loss', 'l1_loss',
                         'l1_smooth_loss', 'num_fail')
            gathered = dict((key, []) for key in loss_keys)
            # (attribute set on self, attribute read from the tower model)
            exported = (('pred_I2', 'pred_I2'), ('I2', 'I2'),
                        ('H_mat', 'H_mat'), ('I1', 'I1'),
                        ('I1_aug', 'I1_aug'), ('I2_aug', 'I2_aug'),
                        ('I', 'I'), ('I_prime', 'I_prime'),
                        ('pts1', 'pts_1'), ('gt', 'gt'),
                        ('pred_h4p', 'pred_h4p'))

            model_params = homography_model_params(
                mode='test',
                batch_size=int(args.batch_size / args.num_gpus),
                patch_size=args.patch_size,
                img_h=args.img_h,
                img_w=args.img_w,
                loss_type=args.loss_type,
                use_batch_norm=args.use_batch_norm,
                augment_list=args.augment_list,
                leftright_consistent_weight=args.leftright_consistent_weight)

            share = None  # first tower creates variables, later ones reuse
            with tf.variable_scope(tf.get_variable_scope()):
                for gpu in range(args.num_gpus):
                    with tf.device('/gpu:%d' % gpu):
                        tower_inputs = [stream[gpu] for stream in per_gpu_inputs]
                        model = HomographyModel(model_params,
                                                *tower_inputs,
                                                reuse_variables=share,
                                                model_index=gpu)
                        share = True
                        for key in loss_keys:
                            gathered[key].append(getattr(model, key))
                        for own_name, model_name in exported:
                            setattr(self, own_name, getattr(model, model_name))

            self.total_h_loss = tf.reduce_mean(gathered['bounded_h_loss'])
            self.total_num_fail = tf.reduce_sum(gathered['num_fail'])
            self.total_rec_loss = tf.reduce_mean(gathered['rec_loss'])
            self.total_ssim_loss = tf.reduce_mean(gathered['ssim_loss'])
            self.total_l1_loss = tf.reduce_mean(gathered['l1_loss'])
            self.total_l1_smooth_loss = tf.reduce_mean(
                gathered['l1_smooth_loss'])
            with tf.name_scope('Losses'):
                for tag, tensor in (
                        ('Total_h_loss', self.total_h_loss),
                        ('Total_rec_loss', self.total_rec_loss),
                        ('Total_ssim_loss', self.total_ssim_loss),
                        ('Total_l1_loss', self.total_l1_loss),
                        ('Total_l1_smooth_loss', self.total_l1_smooth_loss)):
                    tf.summary.scalar(tag, tensor)
            self.summary_opt = tf.summary.merge_all()
    def __init__(self):
        """Build the real-image evaluation graph with one tower per GPU.

        Configuration is read from the module-level ``args``. Sets:

        * ``self.num_total_steps``: two passes over the test file list
          (``2 * ceil(num_files / args.batch_size)``).
        * ``self.total_rec_loss``, ``self.total_ssim_loss``,
          ``self.total_l1_loss`` and ``self.total_l1_smooth_loss``: means
          over the towers.
        * ``self.summary_opt``: merged summary op.
        * ``self.full_I`` and ``self.full_I_prime``: the loader's un-split
          full-image batches.
        * Per-sample tensors ``self.pred_I2``, ``self.I``, ``self.I_prime``,
          ``self.I1_aug``, ``self.I2_aug``, ``self.pred_h4p``,
          ``self.gt_corr`` and ``self.pts1``. Each is concatenated across
          towers along axis 0.
        """
        # Default device for the ops built below (each tower picks its own
        # GPU). No new graph is created: ops go into the current default graph.
        with tf.device('/gpu:0'):
            num_test_files = utils.count_text_lines(args.test_filenames_file)
            print('===> Test: There are totally %d Test files' % (num_test_files))
            batches_per_pass = np.ceil(num_test_files / args.batch_size).astype(
                np.int32)
            self.num_total_steps = 2 * batches_per_pass  # two test passes

            # Keep the file order fixed (no shuffling).
            loader = Dataloader(test_dataloader_params, shuffle=False)

            # Every batch tensor the loader exposes, read in a fixed order.
            stream_names = ('I1', 'I2', 'I1_aug', 'I2_aug', 'I', 'I_prime',
                            'full_I', 'full_I_prime', 'pts1', 'gt',
                            'patch_indices')
            batches = dict((name, getattr(loader, name + '_batch'))
                           for name in stream_names)
            # Model inputs in HomographyModel's positional order, each split
            # along axis 0 into one slice per GPU.
            model_input_names = ('I1', 'I2', 'I1_aug', 'I2_aug', 'I',
                                 'I_prime', 'pts1', 'gt', 'patch_indices')
            per_gpu_inputs = [
                tf.split(batches[name], args.num_gpus, 0)
                for name in model_input_names
            ]

            loss_keys = ('rec_loss', 'ssim_loss', 'l1_loss', 'l1_smooth_loss')
            gathered = dict((key, []) for key in loss_keys)
            # (attribute set on self, attribute read from the tower model)
            exported = (('pred_I2', 'pred_I2'), ('I', 'I'),
                        ('I_prime', 'I_prime'), ('I1_aug', 'I1_aug'),
                        ('I2_aug', 'I2_aug'), ('pred_h4p', 'pred_h4p'),
                        ('gt_corr', 'gt'), ('pts1', 'pts_1'))

            model_params = homography_model_params(
                mode='test',
                batch_size=int(args.batch_size / args.num_gpus),
                patch_size=args.patch_size,
                img_h=args.img_h,
                img_w=args.img_w,
                loss_type=args.loss_type,
                use_batch_norm=args.use_batch_norm,
                augment_list=args.augment_list,
                leftright_consistent_weight=args.leftright_consistent_weight)

            share = None  # first tower creates variables, later ones reuse
            with tf.variable_scope(tf.get_variable_scope()):
                for gpu in range(args.num_gpus):
                    with tf.device('/gpu:%d' % gpu):
                        # The ground truth gt here is the set of correspondences
                        # between the image pair. It is not the four-point
                        # delta movement the model estimates; it is used to
                        # evaluate the estimated homography on real image data.
                        tower_inputs = [stream[gpu] for stream in per_gpu_inputs]
                        model = HomographyModel(model_params,
                                                *tower_inputs,
                                                reuse_variables=share,
                                                model_index=gpu)
                        share = True
                        tower_losses = [getattr(model, key) for key in loss_keys]
                        for own_name, model_name in exported:
                            piece = getattr(model, model_name)
                            if gpu > 0:
                                # Append this tower's slice along axis 0.
                                piece = tf.concat(
                                    [getattr(self, own_name), piece], axis=0)
                            setattr(self, own_name, piece)
                        for key, loss in zip(loss_keys, tower_losses):
                            gathered[key].append(loss)

            self.total_rec_loss = tf.reduce_mean(gathered['rec_loss'])
            self.total_ssim_loss = tf.reduce_mean(gathered['ssim_loss'])
            self.total_l1_loss = tf.reduce_mean(gathered['l1_loss'])
            self.total_l1_smooth_loss = tf.reduce_mean(
                gathered['l1_smooth_loss'])
            self.full_I = batches['full_I']
            self.full_I_prime = batches['full_I_prime']

            with tf.name_scope('Losses'):
                for tag, tensor in (
                        ('Total_rec_loss', self.total_rec_loss),
                        ('Total_ssim_loss', self.total_ssim_loss),
                        ('Total_l1_loss', self.total_l1_loss),
                        ('Total_l1_smooth_loss', self.total_l1_smooth_loss)):
                    tf.summary.scalar(tag, tensor)
            self.summary_opt = tf.summary.merge_all()