示例#1
0
    def __init__(self):
        """Build the multi-GPU evaluation graph for the synthetic test set.

        Sizes the run from ``args.test_filenames_file`` (three passes over the
        data), creates a ``Dataloader`` from ``test_dataloader_params`` with
        shuffling disabled, splits every batch evenly across ``args.num_gpus``
        towers that share variables, and builds one ``HomographyModel``
        (mode='test') per tower.

        Aggregated outputs:
          * ``total_h_loss`` -- mean of the towers' ``bounded_h_loss``.
          * ``total_num_fail`` -- sum of the towers' ``num_fail``.
          * ``total_rec_loss``, ``total_ssim_loss``, ``total_l1_loss``,
            ``total_l1_smooth_loss`` -- means over the towers.
          * ``summary_opt`` -- the merged TensorBoard summaries.
          * ``num_total_steps`` -- number of batches to evaluate.

        The image/prediction tensors exposed as attributes (``pred_I2``,
        ``I2``, ``H_mat``, ``I1``, ``I1_aug``, ``I2_aug``, ``I``, ``I_prime``,
        ``pts1``, ``gt``, ``pred_h4p``) are taken from the LAST tower only.
        """
        # Default device for the ops built here; each tower below pins its own GPU.
        with tf.device('/gpu:0'):
            # Count the number of evaluation samples to size the run.
            num_data = utils.count_text_lines(args.test_filenames_file)
            print('===> Test: There are totally %d Test files' % (num_data))

            steps_per_epoch = np.ceil(num_data / args.batch_size).astype(
                np.int32)
            self.num_total_steps = 3 * steps_per_epoch  # Test 3 epochs

            # Evaluation reads the test data in a fixed order (no shuffling).
            data_loader = Dataloader(test_dataloader_params,
                                     shuffle=False)  # No shuffle

            # Batched inputs, in the positional order HomographyModel takes them.
            input_batches = [
                data_loader.I1_batch,
                data_loader.I2_batch,
                data_loader.I1_aug_batch,
                data_loader.I2_aug_batch,
                data_loader.I_batch,
                data_loader.I_prime_batch,
                data_loader.pts1_batch,
                data_loader.gt_batch,
                data_loader.patch_indices_batch,
            ]
            # Split every input evenly along axis 0, one slice per GPU.
            input_splits = [
                tf.split(batch, args.num_gpus, 0) for batch in input_batches
            ]

            # Per-tower losses, aggregated after all towers are built.
            h_losses = []
            rec_losses = []
            ssim_losses = []
            l1_losses = []
            l1_smooth_losses = []
            num_fails = []
            model_params = homography_model_params(
                mode='test',
                batch_size=int(args.batch_size / args.num_gpus),
                patch_size=args.patch_size,
                img_h=args.img_h,
                img_w=args.img_w,
                loss_type=args.loss_type,
                use_batch_norm=args.use_batch_norm,
                augment_list=args.augment_list,
                leftright_consistent_weight=args.leftright_consistent_weight)

            # The first tower creates the variables; later towers reuse them.
            reuse_variables = None
            with tf.variable_scope(tf.get_variable_scope()):
                for i in range(args.num_gpus):
                    with tf.device('/gpu:%d' % i):
                        model = HomographyModel(
                            model_params,
                            *[splits[i] for splits in input_splits],
                            reuse_variables=reuse_variables,
                            model_index=i)
                        reuse_variables = True

                        # NOTE(review): these attributes are overwritten by
                        # every tower, so with num_gpus > 1 only the last
                        # tower's slice of the batch is exposed.
                        self.pred_I2 = model.pred_I2
                        self.I2 = model.I2
                        self.H_mat = model.H_mat
                        self.I1 = model.I1
                        self.I1_aug = model.I1_aug
                        self.I2_aug = model.I2_aug
                        self.I = model.I
                        self.I_prime = model.I_prime
                        self.pts1 = model.pts_1
                        self.gt = model.gt
                        self.pred_h4p = model.pred_h4p

                        # In testing, use bounded_h_loss (under successful condition)
                        h_losses.append(model.bounded_h_loss)
                        rec_losses.append(model.rec_loss)
                        ssim_losses.append(model.ssim_loss)
                        l1_losses.append(model.l1_loss)
                        l1_smooth_losses.append(model.l1_smooth_loss)
                        num_fails.append(model.num_fail)

            self.total_h_loss = tf.reduce_mean(h_losses)
            self.total_num_fail = tf.reduce_sum(num_fails)
            self.total_rec_loss = tf.reduce_mean(rec_losses)
            self.total_ssim_loss = tf.reduce_mean(ssim_losses)
            self.total_l1_loss = tf.reduce_mean(l1_losses)
            self.total_l1_smooth_loss = tf.reduce_mean(l1_smooth_losses)
            with tf.name_scope('Losses'):
                tf.summary.scalar('Total_h_loss', self.total_h_loss)
                tf.summary.scalar('Total_rec_loss', self.total_rec_loss)
                tf.summary.scalar('Total_ssim_loss', self.total_ssim_loss)
                tf.summary.scalar('Total_l1_loss', self.total_l1_loss)
                tf.summary.scalar('Total_l1_smooth_loss',
                                  self.total_l1_smooth_loss)
            self.summary_opt = tf.summary.merge_all()
示例#2
0
def _show_training_sample(index, normalized, pred_I2_value, I2_value,
                          I1_aug_value, I2_aug_value, I_value):
    """Plot sample ``index`` of one fetched training batch with matplotlib.

    Three panels are drawn: predicted vs. target I2, the augmented I1/I2 pair,
    and the image I. The grey-scale panels use channel 0 (last axis) of each
    batch. I is drawn in colour when its last axis has 3 channels, otherwise
    its channel 0 is shown.

    Args:
        index: position of the sample inside the batch.
        normalized: True when the input pipeline normalised the images; they
            are then mapped back with ``utils.denorm_img`` before display.
        pred_I2_value, I2_value, I1_aug_value, I2_aug_value, I_value: numpy
            batches returned by ``sess.run``.
    """
    def to_uint8(img):
        if normalized:
            img = utils.denorm_img(img)
        return img.astype(np.uint8)

    pred_I2_img = to_uint8(pred_I2_value[index, :, :, 0])
    I2_img = to_uint8(I2_value[index, :, :, 0])
    I1_aug_img = to_uint8(I1_aug_value[index, :, :, 0])
    I2_aug_img = to_uint8(I2_aug_value[index, :, :, 0])
    I_img = to_uint8(I_value[index, ...])

    plt.subplot(3, 1, 1)
    plt.imshow(np.concatenate([pred_I2_img, I2_img], 1), cmap='gray')
    plt.title('Pred I2 vs I2')
    plt.subplot(3, 1, 2)
    plt.imshow(np.concatenate([I1_aug_img, I2_aug_img], 1), cmap='gray')
    plt.title('I1_aug vs I2_aug')
    plt.subplot(3, 1, 3)
    plt.imshow(I_img if I_img.shape[2] == 3 else I_img[:, :, 0])
    plt.title('I')
    plt.show()
    plt.pause(0.05)


def train():
    """Train the homography network with TF1 multi-GPU data parallelism.

    Builds a queue-fed ``Dataloader`` from ``train_dataloader_params`` and
    splits each batch across ``args.num_gpus`` towers that share variables.
    The per-tower gradients of the loss named by ``args.loss_type`` are
    averaged and applied with Adam under an exponentially decaying learning
    rate. Batch-norm update ops are run with every training step.

    Training runs for 150000 steps starting from the current ``global_step``.
    When ``args.resume`` is set, the step is restored from the latest
    checkpoint in ``args.model_dir``, and ``args.retrain`` resets it to 0.
    Summaries go to ``args.log_dir``. Checkpoints are written every 1000 steps
    and at the end, with prefix ``os.path.join(args.model_dir, args.model_name)``.
    The session, queue-runner threads and summary writer are always released
    on exit.

    Raises:
        ValueError: if ``args.loss_type`` is not a supported loss.
    """
    import os  # used only to build the checkpoint path

    num_total_steps = 150000
    # Supported values of args.loss_type. Each one is also the name of the
    # matching HomographyModel attribute and of its 'Total_<name>' summary.
    loss_names = ('h_loss', 'rec_loss', 'ssim_loss', 'l1_loss',
                  'l1_smooth_loss', 'ncc_loss')
    # Losses reported by the progress bar in the 'l1_loss' branch.
    l1_report_names = ('h_loss', 'l1_loss', 'l1_smooth_loss')

    # Validate before building any graph. Raise rather than exit with a
    # success status.
    if args.loss_type not in loss_names:
        print('===> Loss type does not exist!')
        raise ValueError('Unknown loss_type %r; expected one of: %s' %
                         (args.loss_type, ', '.join(loss_names)))
    print('====> Use loss type: ', args.loss_type)

    # Build everything in a fresh graph; ops default to GPU 0.
    with tf.Graph().as_default(), tf.device('/gpu:0'):
        global_step = tf.Variable(0, trainable=False)

        # Count the number of training data
        num_data = utils.count_text_lines(args.filenames_file)
        print('===> Train: There are totally %d training files' % (num_data))

        # Exponential decay: decayed_lr = lr * decay_rate ^ (global_step / decay_steps).
        # decay_steps is solved so that the learning rate reaches args.min_lr
        # after num_total_steps steps.
        decay_rate = 0.96
        decay_steps = (math.log(decay_rate) * num_total_steps) / math.log(
            args.min_lr * 1.0 / args.lr)
        print('args lr:', args.lr, args.min_lr)
        print('===> Decay steps:', decay_steps)
        learning_rate = tf.train.exponential_decay(args.lr,
                                                   global_step,
                                                   int(decay_steps),
                                                   decay_rate,
                                                   staircase=True)
        opt_step = tf.train.AdamOptimizer(learning_rate)

        # Load data
        data_loader = Dataloader(train_dataloader_params,
                                 shuffle=True)  # shuffle

        # Batched inputs, in the positional order HomographyModel takes them.
        input_batches = [
            data_loader.I1_batch,
            data_loader.I2_batch,
            data_loader.I1_aug_batch,
            data_loader.I2_aug_batch,
            data_loader.I_batch,
            data_loader.I_prime_batch,
            data_loader.pts1_batch,
            data_loader.gt_batch,
            data_loader.patch_indices_batch,
        ]
        # Split every input evenly along axis 0, one slice per GPU.
        input_splits = [
            tf.split(batch, args.num_gpus, 0) for batch in input_batches
        ]

        model_params = homography_model_params(
            mode=args.mode,
            batch_size=int(args.batch_size / args.num_gpus),
            patch_size=args.patch_size,
            img_h=args.img_h,
            img_w=args.img_w,
            loss_type=args.loss_type,
            use_batch_norm=args.use_batch_norm,
            augment_list=args.augment_list,
            leftright_consistent_weight=args.leftright_consistent_weight)

        # Build one tower per GPU. The first tower creates the variables and
        # later towers reuse them.
        multi_grads = []
        tower_losses = {name: [] for name in loss_names}
        reuse_variables = None
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(args.num_gpus):
                with tf.device('/gpu:%d' % i):
                    model = HomographyModel(
                        model_params,
                        *[splits[i] for splits in input_splits],
                        reuse_variables=reuse_variables,
                        model_index=i)
                    reuse_variables = True
                    for name in loss_names:
                        tower_losses[name].append(getattr(model, name))
                    multi_grads.append(
                        opt_step.compute_gradients(
                            getattr(model, args.loss_type)))

        # Tensors fetched for visualisation come from the last tower.
        pred_I2 = model.pred_I2
        I2 = model.I2
        I1_aug = model.I1_aug
        I2_aug = model.I2_aug
        I_full = model.I

        # Take average of the grads
        grads = utils.get_average_grads(multi_grads)
        # Batch-norm layers register their moving-statistics updates in
        # UPDATE_OPS while the towers are built. The collection must therefore
        # be read here, after the towers, and attached to the training op.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            apply_grad_opt = opt_step.apply_gradients(grads,
                                                      global_step=global_step)

        total_losses = {
            name: tf.reduce_mean(tower_losses[name]) for name in loss_names
        }
        with tf.name_scope('Losses'):
            tf.summary.scalar('Learning_rate', learning_rate)
            for name in loss_names:
                tf.summary.scalar('Total_' + name, total_losses[name])
        summary_opt = tf.summary.merge_all()

        # allow_growth: allocate GPU memory on demand instead of up front.
        # allow_soft_placement: fall back to CPU when a GPU op can't be placed.
        gpu_options = tf.GPUOptions(allow_growth=True)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=gpu_options)
        sess = tf.Session(config=config)

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)
        train_saver = tf.train.Saver(max_to_keep=5)  # Keep maximum 5 models
        checkpoint_prefix = os.path.join(args.model_dir, args.model_name)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # Start the input-queue threads; they are stopped in the finally below.
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
        try:
            if args.resume:
                train_saver.restore(sess,
                                    tf.train.latest_checkpoint(args.model_dir))
                if args.retrain:
                    sess.run(global_step.assign(0))

            start_step = global_step.eval(session=sess)
            print('===> Start step:', start_step)
            last_step = start_step + num_total_steps - 1

            index = 0  # sample of the batch displayed in visual mode
            normalized = 'normalize' in args.augment_list
            loss_sums = dict.fromkeys(loss_names, 0.0)

            for step in range(start_step, last_step + 1):
                steps_done = step - start_step + 1
                if args.visual:
                    (_, pred_I2_value, I2_value, I1_aug_value, I2_aug_value,
                     I_value) = sess.run([
                         apply_grad_opt, pred_I2, I2, I1_aug, I2_aug, I_full
                     ])
                elif args.loss_type == 'l1_loss':
                    fetched = sess.run(
                        [apply_grad_opt] +
                        [total_losses[name] for name in l1_report_names] +
                        [learning_rate])
                    for name, value in zip(l1_report_names, fetched[1:-1]):
                        loss_sums[name] += value
                    lr_value = fetched[-1]
                    if step % 100 == 0:
                        utils.progress_bar(
                            step, last_step,
                            'Train: 1, h_loss %4.3f, l1_loss %.6f, l1_smooth_loss %.6f, lr %.6f'
                            % tuple([loss_sums[name] / steps_done
                                     for name in l1_report_names] +
                                    [lr_value]))
                else:
                    fetched = sess.run(
                        [apply_grad_opt] +
                        [total_losses[name] for name in loss_names] +
                        [learning_rate])
                    for name, value in zip(loss_names, fetched[1:-1]):
                        loss_sums[name] += value
                    lr_value = fetched[-1]
                    if step % 100 == 0:
                        utils.progress_bar(
                            step, last_step,
                            'Train: 1, h_loss %4.3f, rec_loss %4.3f, ssim_loss %.6f. l1_loss %.6f, l1_smooth_loss %.6f, ncc_loss %.6f, lr %.6f'
                            % tuple([loss_sums[name] / steps_done
                                     for name in loss_names] + [lr_value]))

                # Tensorboard
                if step % 1000 == 0:
                    summary_writer.add_summary(sess.run(summary_opt),
                                               global_step=step)
                if step and step % 1000 == 0:
                    train_saver.save(sess, checkpoint_prefix, global_step=step)

                if args.visual:
                    _show_training_sample(index, normalized, pred_I2_value,
                                          I2_value, I1_aug_value,
                                          I2_aug_value, I_value)

            # Save the final model
            train_saver.save(sess, checkpoint_prefix, global_step=last_step)
        finally:
            coordinator.request_stop()
            coordinator.join(threads)
            summary_writer.close()
            sess.close()
示例#3
0
def train(args):
    """Train the PyTorch HomographyModel on SyntheticDataset, minimising the L1 loss.

    Args:
        args: parsed options. Fields read here: data_path, mode, img_h, img_w,
            patch_size, do_augment, batch_size, use_batch_norm, resume,
            model_dir, model_name, lr, min_lr, max_epochs, log_dir.

    Adam starts at ``args.lr``. The rate is multiplied by 0.96 every
    ``step_size`` epochs (at least 1), where ``step_size`` is derived from
    ``args.min_lr / args.lr`` and ``args.max_epochs``.

    Running loss averages are printed every 100 iterations and at the end of
    each epoch. TensorBoard scalars and images are written every 1000
    iterations. The whole module is pickled into ``args.model_dir`` every
    4000 iterations and once more after the final iteration.
    """
    # Load data
    train_dataset = SyntheticDataset(data_path=args.data_path,
                                     mode=args.mode,
                                     img_h=args.img_h,
                                     img_w=args.img_w,
                                     patch_size=args.patch_size,
                                     do_augment=args.do_augment)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    print('===> Train: There are totally {} training files'.format(len(train_dataset)))

    use_cuda = torch.cuda.is_available()
    net = HomographyModel(args.use_batch_norm)
    if args.resume:
        model_path = os.path.join(args.model_dir, args.model_name)
        # Checkpoints are whole pickled modules (see torch.save below), so only
        # load trusted files. map_location='cpu' lets a GPU-saved checkpoint be
        # restored without a GPU; the net is moved to the GPU afterwards.
        ckpt = torch.load(model_path, map_location='cpu')
        net.load_state_dict(ckpt.state_dict())
    if use_cuda:
        net = net.cuda()

    optimizer = optim.Adam(net.parameters(), lr=args.lr)  # default as 0.0001
    decay_rate = 0.96
    step_size = (math.log(decay_rate) * args.max_epochs) / math.log(args.min_lr * 1.0 / args.lr)
    print('args lr:', args.lr, args.min_lr)
    print('===> Decay steps:', step_size)
    # StepLR evaluates last_epoch % step_size, so step_size must be >= 1 even
    # when the formula above yields a value below 1.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, int(step_size)), gamma=decay_rate)

    print("start training")
    writer = SummaryWriter(logdir=args.log_dir, flush_secs=60)
    score_print_fre = 100
    summary_fre = 1000
    model_save_fre = 4000
    glob_iter = 0
    t0 = time.time()

    try:
        for epoch in range(args.max_epochs):
            net.train()
            epoch_start = time.time()
            train_l1_loss = 0.0
            train_l1_smooth_loss = 0.0
            train_h_loss = 0.0
            num_accumulated = 0  # iterations summed into the running losses

            for i, batch_value in enumerate(train_loader):
                (I1_batch, I2_batch, I1_aug_batch, I2_aug_batch, I_batch,
                 I_prime_batch, pts1_batch, gt_batch,
                 patch_indices_batch) = [t.float() for t in batch_value[:9]]

                # Only the network inputs go to the GPU; the other batches are
                # read on the CPU for logging.
                if use_cuda:
                    I1_aug_batch = I1_aug_batch.cuda()
                    I2_aug_batch = I2_aug_batch.cuda()
                    I_batch = I_batch.cuda()
                    pts1_batch = pts1_batch.cuda()
                    gt_batch = gt_batch.cuda()
                    patch_indices_batch = patch_indices_batch.cuda()

                # forward, backward, update weights
                optimizer.zero_grad()
                batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch)
                h_loss = batch_out['h_loss']
                rec_loss = batch_out['rec_loss']
                ssim_loss = batch_out['ssim_loss']
                l1_loss = batch_out['l1_loss']
                l1_smooth_loss = batch_out['l1_smooth_loss']
                ncc_loss = batch_out['ncc_loss']
                pred_I2 = batch_out['pred_I2']

                # Only the L1 loss is optimised; the other losses are logged.
                loss = l1_loss
                loss.backward()
                optimizer.step()

                train_l1_loss += loss.item()
                train_l1_smooth_loss += l1_smooth_loss.item()
                train_h_loss += h_loss.item()
                num_accumulated += 1

                # The learning rate actually in use. scheduler.get_lr() can
                # report an extra decay on epochs where StepLR steps.
                current_lr = optimizer.param_groups[0]['lr']

                if (i + 1) % score_print_fre == 0 or (i + 1) == len(train_loader):
                    # Average over the iterations actually accumulated; the
                    # last window of an epoch may be shorter than score_print_fre.
                    print(
                        "Training: Epoch[{:0>3}/{:0>3}] Iter[{:0>3}]/[{:0>3}] l1 loss: {:.4f} "
                        "l1 smooth loss: {:.4f} h loss: {:.4f} lr={:.8f}".format(
                            epoch + 1, args.max_epochs, i + 1, len(train_loader),
                            train_l1_loss / num_accumulated,
                            train_l1_smooth_loss / num_accumulated,
                            train_h_loss / num_accumulated, current_lr))
                    train_l1_loss = 0.0
                    train_l1_smooth_loss = 0.0
                    train_h_loss = 0.0
                    num_accumulated = 0

                if glob_iter % summary_fre == 0:
                    writer.add_scalar('learning_rate', current_lr, glob_iter)
                    writer.add_scalar('h_loss', h_loss, glob_iter)
                    writer.add_scalar('rec_loss', rec_loss, glob_iter)
                    writer.add_scalar('ssim_loss', ssim_loss, glob_iter)
                    writer.add_scalar('l1_loss', l1_loss, glob_iter)
                    writer.add_scalar('l1_smooth_loss', l1_smooth_loss, glob_iter)
                    writer.add_scalar('ncc_loss', ncc_loss, glob_iter)

                    writer.add_image('I', utils.denorm_img(I_batch[0, ...].cpu().numpy()).astype(np.uint8)[:, :, ::-1],
                                     glob_iter, dataformats='HWC')
                    writer.add_image('I_prime',
                                     utils.denorm_img(I_prime_batch[0, ...].numpy()).astype(np.uint8)[:, :, ::-1],
                                     glob_iter, dataformats='HWC')

                    writer.add_image('I1_aug', utils.denorm_img(I1_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8),
                                     glob_iter, dataformats='HW')
                    writer.add_image('I2_aug', utils.denorm_img(I2_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8),
                                     glob_iter, dataformats='HW')
                    writer.add_image('pred_I2',
                                     utils.denorm_img(pred_I2[0, 0, ...].cpu().detach().numpy()).astype(np.uint8),
                                     glob_iter, dataformats='HW')

                    writer.add_image('I2', utils.denorm_img(I2_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter,
                                     dataformats='HW')
                    writer.add_image('I1', utils.denorm_img(I1_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter,
                                     dataformats='HW')

                # save model
                if glob_iter % model_save_fre == 0 and glob_iter != 0:
                    filename = 'model' + '_iter_' + str(glob_iter) + '.pth'
                    torch.save(net, os.path.join(args.model_dir, filename))

                glob_iter += 1
            scheduler.step()
            print("Epoch: {} epoch time: {:.1f}s".format(epoch, time.time() - epoch_start))

        # Also keep the weights after the last iteration. The periodic save
        # above only fires on multiples of model_save_fre. The name uses the
        # same iteration index the periodic save would have used.
        if glob_iter > 0:
            filename = 'model' + '_iter_' + str(glob_iter - 1) + '.pth'
            torch.save(net, os.path.join(args.model_dir, filename))
    finally:
        writer.close()

    elapsed_time = time.time() - t0
    print("Finished Training in {:.0f}h {:.0f}m {:.0f}s.".format(
        elapsed_time // 3600, (elapsed_time % 3600) // 60, (elapsed_time % 3600) % 60))
    def __init__(self):
        """Assemble the multi-GPU evaluation graph for the real-image test set.

        Sizes the run from ``args.test_filenames_file`` (two passes over the
        data) and reads batches from a non-shuffling ``Dataloader`` built with
        ``test_dataloader_params``. Every batch is sliced across
        ``args.num_gpus`` towers that share variables, with one
        ``HomographyModel`` (mode='test') per tower.

        Exposed attributes:
          * ``pred_I2``, ``I``, ``I_prime``, ``I1_aug``, ``I2_aug``,
            ``pred_h4p``, ``gt_corr``, ``pts1`` -- per-tower outputs joined
            along axis 0 in tower order.
          * ``full_I``, ``full_I_prime`` -- the loader's unsplit
            ``full_I_batch`` / ``full_I_prime_batch``.
          * ``total_rec_loss``, ``total_ssim_loss``, ``total_l1_loss``,
            ``total_l1_smooth_loss`` -- means over the towers' losses.
          * ``summary_opt`` -- the merged TensorBoard summaries.
          * ``num_total_steps`` -- number of batches to evaluate.

        ``gt_corr`` holds correspondences between the image pair. These are
        used to evaluate the estimated homography on real data and differ from
        the four-point displacement the model predicts (``pred_h4p``).
        """
        # Ops default to the first GPU; each tower pins its own device below.
        with tf.device('/gpu:0'):
            sample_count = utils.count_text_lines(args.test_filenames_file)
            print('===> Test: There are totally %d Test files' % (sample_count))

            batches_per_pass = np.ceil(
                sample_count / args.batch_size).astype(np.int32)
            self.num_total_steps = 2 * batches_per_pass  # two passes over the data

            loader = Dataloader(test_dataloader_params, shuffle=False)
            full_I_batch = loader.full_I_batch
            full_I_prime_batch = loader.full_I_prime_batch

            # Model inputs, in the positional order HomographyModel takes them.
            # Each one is cut along axis 0 into one slice per GPU.
            per_gpu_inputs = [
                tf.split(tensor, args.num_gpus, 0)
                for tensor in (loader.I1_batch, loader.I2_batch,
                               loader.I1_aug_batch, loader.I2_aug_batch,
                               loader.I_batch, loader.I_prime_batch,
                               loader.pts1_batch, loader.gt_batch,
                               loader.patch_indices_batch)
            ]

            model_params = homography_model_params(
                mode='test',
                batch_size=int(args.batch_size / args.num_gpus),
                patch_size=args.patch_size,
                img_h=args.img_h,
                img_w=args.img_w,
                loss_type=args.loss_type,
                use_batch_norm=args.use_batch_norm,
                augment_list=args.augment_list,
                leftright_consistent_weight=args.leftright_consistent_weight)

            # (attribute on self, attribute on each tower's model)
            joined_outputs = (
                ('pred_I2', 'pred_I2'),
                ('I', 'I'),
                ('I_prime', 'I_prime'),
                ('I1_aug', 'I1_aug'),
                ('I2_aug', 'I2_aug'),
                ('pred_h4p', 'pred_h4p'),
                ('gt_corr', 'gt'),
                ('pts1', 'pts_1'),
            )
            loss_names = ('rec_loss', 'ssim_loss', 'l1_loss', 'l1_smooth_loss')
            tower_losses = {name: [] for name in loss_names}

            # Tower 0 creates the variables; every later tower reuses them.
            reuse = None
            with tf.variable_scope(tf.get_variable_scope()):
                for gpu in range(args.num_gpus):
                    with tf.device('/gpu:%d' % gpu):
                        model = HomographyModel(
                            model_params,
                            *[slices[gpu] for slices in per_gpu_inputs],
                            reuse_variables=reuse,
                            model_index=gpu)
                        reuse = True

                        # Start each output from tower 0, then append every
                        # later tower's slice along the batch axis.
                        for own_attr, model_attr in joined_outputs:
                            piece = getattr(model, model_attr)
                            if gpu > 0:
                                piece = tf.concat(
                                    [getattr(self, own_attr), piece], axis=0)
                            setattr(self, own_attr, piece)

                        for name in loss_names:
                            tower_losses[name].append(getattr(model, name))

            self.total_rec_loss = tf.reduce_mean(tower_losses['rec_loss'])
            self.total_ssim_loss = tf.reduce_mean(tower_losses['ssim_loss'])
            self.total_l1_loss = tf.reduce_mean(tower_losses['l1_loss'])
            self.total_l1_smooth_loss = tf.reduce_mean(
                tower_losses['l1_smooth_loss'])
            self.full_I = full_I_batch
            self.full_I_prime = full_I_prime_batch

            with tf.name_scope('Losses'):
                tf.summary.scalar('Total_rec_loss', self.total_rec_loss)
                tf.summary.scalar('Total_ssim_loss', self.total_ssim_loss)
                tf.summary.scalar('Total_l1_loss', self.total_l1_loss)
                tf.summary.scalar('Total_l1_smooth_loss',
                                  self.total_l1_smooth_loss)
            self.summary_opt = tf.summary.merge_all()
示例#5
0
def test(args):
    """Evaluate a trained HomographyModel on the test split and print metrics.

    Builds a ``SyntheticDataset`` from ``args`` and wraps it in a
    ``DataLoader`` with batch size 1. Loads the model weights from
    ``os.path.join(args.model_dir, args.model_name)``. Runs the network over
    every test sample with gradients disabled and prints the per-sample
    losses. Finally, prints the mean ``h_loss`` / ``l1_loss`` and the
    percentile summary returned by ``utils.find_percentile``.

    Args:
        args: Parsed options. Fields read here: ``data_path``, ``mode``,
            ``img_h``, ``img_w``, ``patch_size``, ``do_augment``,
            ``model_dir``, ``model_name``, ``save_visual`` and, when
            ``save_visual`` is truthy, ``results_dir``.
    """
    # Load data
    TestDataset = SyntheticDataset(data_path=args.data_path,
                                   mode=args.mode,
                                   img_h=args.img_h,
                                   img_w=args.img_w,
                                   patch_size=args.patch_size,
                                   do_augment=args.do_augment)
    test_loader = DataLoader(TestDataset, batch_size=1)
    print('===> Test: There are totally {} testing files'.format(len(TestDataset)))

    # Load model
    net = HomographyModel()
    model_path = os.path.join(args.model_dir, args.model_name)
    # The checkpoint appears to be a pickled model object, since it exposes
    # ``state_dict()``. Map it to the CPU first so that a checkpoint saved
    # on a GPU still loads on a CPU-only host. The network is moved to the
    # GPU below when one is available.
    state = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state.state_dict())
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    print("start testing")

    num_batches = len(test_loader)
    if num_batches == 0:
        # Nothing to evaluate; the averages below would divide by zero.
        print('===> Test: no testing files found, nothing to evaluate')
        return

    test_l1_loss = 0.0
    test_h_loss = 0.0
    h_losses_array = []
    with torch.no_grad():
        net.eval()
        for i, batch_value in enumerate(test_loader):
            # NOTE(review): the positional layout of a sample is defined by
            # SyntheticDataset. Entries 0 and 1 are not used by this function.
            I1_aug_batch = batch_value[2].float()
            I2_aug_batch = batch_value[3].float()
            I_batch = batch_value[4].float()
            I_prime_batch = batch_value[5].float()  # only used for visuals
            pts1_batch = batch_value[6].float()
            gt_batch = batch_value[7].float()
            patch_indices_batch = batch_value[8].float()

            if use_cuda:
                I1_aug_batch = I1_aug_batch.cuda()
                I2_aug_batch = I2_aug_batch.cuda()
                I_batch = I_batch.cuda()
                pts1_batch = pts1_batch.cuda()
                gt_batch = gt_batch.cuda()
                patch_indices_batch = patch_indices_batch.cuda()

            batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch)
            # Extract each scalar loss once as a Python float.
            h_loss = batch_out['h_loss'].item()
            rec_loss = batch_out['rec_loss'].item()
            ssim_loss = batch_out['ssim_loss'].item()
            l1_loss = batch_out['l1_loss'].item()
            pred_h4p_value = batch_out['pred_h4p']

            test_h_loss += h_loss
            test_l1_loss += l1_loss
            h_losses_array.append(h_loss)

            if args.save_visual:
                # Take the first (and only, since batch_size=1) item of each
                # batch. Move it to the host before converting to NumPy.
                I_sample = utils.denorm_img(I_batch[0].cpu().numpy()).astype(np.uint8)
                I_prime_sample = utils.denorm_img(I_prime_batch[0].cpu().numpy()).astype(np.uint8)
                # The corner points and offsets are reshaped to 4 (x, y) pairs.
                pts1_sample = pts1_batch[0].cpu().numpy().reshape([4, 2]).astype(np.float32)
                gt_h4p_sample = gt_batch[0].cpu().numpy().reshape([4, 2]).astype(np.float32)

                # Ground-truth corner positions in the second image.
                pts2_sample = pts1_sample + gt_h4p_sample

                # Predicted corner positions in the second image.
                pred_h4p_sample = pred_h4p_value[0].cpu().numpy().reshape([4, 2]).astype(np.float32)
                pred_pts2_sample = pts1_sample + pred_h4p_sample

                # Save, using the zero-padded sample index as the file name.
                visual_file_name = ('%s' % i).zfill(4) + '.jpg'
                utils.save_correspondences_img(I_prime_sample, I_sample, pts1_sample, pts2_sample, pred_pts2_sample,
                                               args.results_dir, visual_file_name)

            print("Testing: h_loss: {:4.3f}, rec_loss: {:4.3f}, ssim_loss: {:4.3f}, l1_loss: {:4.3f}".format(
                h_loss, rec_loss, ssim_loss, l1_loss
            ))

    print('|Test size  |   h_loss   |    l1_loss   |')
    print(num_batches, test_h_loss / num_batches, test_l1_loss / num_batches)

    tops_list = utils.find_percentile(h_losses_array)
    print('===> Percentile Values: (20, 50, 80, 100):')
    print(tops_list)
    print('======> End! ====================================')