Example #1
    def run_from_file(self):
        if self.args.KITTI == '2015':
            from dataloader import KITTI_submission_loader as DA
        else:
            from dataloader import KITTI_submission_loader2012 as DA
        test_left_img, test_right_img = DA.dataloader(self.args.datapath)

        if not os.path.isdir(self.args.save_path):
            os.makedirs(self.args.save_path)

        for inx in range(len(test_left_img)):
            imgL_o = (skimage.io.imread(test_left_img[inx]).astype('float32'))
            imgR_o = (skimage.io.imread(test_right_img[inx]).astype('float32'))

            img = self.disp_pred_net.run(imgL_o, imgR_o)

            # file output
            print(test_left_img[inx].split('/')[-1])
            if self.args.save_figure:
                skimage.io.imsave(
                    self.args.save_path + '/' +
                    test_left_img[inx].split('/')[-1],
                    (img * 256).astype('uint16'))
            else:
                np.save(
                    self.args.save_path + '/' +
                    test_left_img[inx].split('/')[-1][:-4], img)
Example #2
    def run_from_file(self):
        if self.args_disp.KITTI == '2015':
            from dataloader import KITTI_submission_loader as DA
        else:
            from dataloader import KITTI_submission_loader2012 as DA
        test_left_img, test_right_img = DA.dataloader(self.args_disp.datapath)

        if not os.path.isdir(self.args_disp.save_path):
            os.makedirs(self.args_disp.save_path)

        for inx in range(len(test_left_img)):
            imgL_o = (skimage.io.imread(test_left_img[inx]).astype('float32'))
            imgR_o = (skimage.io.imread(test_right_img[inx]).astype('float32'))

            img = self.disp_pred_net.run(imgL_o, imgR_o)

            # # file output
            # print(test_left_img[inx].split('/')[-1])
            # if self.args.save_figure:
            #     skimage.io.imsave(self.args.save_path+'/'+test_left_img[inx].split('/')[-1],(img*256).astype('uint16'))
            # else:
            #     np.save(self.args.save_path+'/'+test_left_img[inx].split('/')[-1][:-4], img)

            predix = test_left_img[inx].split('/')[-1][:-4]
            calib_file = '{}/{}.txt'.format(self.args_gen_lidar.calib_dir,
                                            predix)
            calib = kitti_util.Calibration(calib_file)

            # quantize/dequantize so the depth matches the precision of a saved 16-bit KITTI PNG
            img = (img * 256).astype(np.uint16) / 256.
            lidar = self.pcl_generator.run(calib, img)

            # append a constant 1 as the intensity channel
            lidar = np.concatenate([lidar, np.ones((lidar.shape[0], 1))], 1)
            lidar = lidar.astype(np.float32)
            lidar.tofile('{}/{}.bin'.format(self.args_gen_lidar.save_dir,
                                            predix))
            print('Finish Depth {}'.format(predix))
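
The pcl_generator.run(calib, img) call above turns the dequantized depth map into a pseudo-LiDAR point cloud; its implementation is not shown in this example. A minimal sketch of the usual back-projection step, assuming calib exposes a project_image_to_velo helper as in the pseudo-LiDAR KITTI utilities (the function name and the max_high cut-off are illustrative):

import numpy as np

def depth_to_pseudo_lidar(calib, depth, max_high=1.0):
    # build (u, v, depth) triples for every pixel of the depth map
    rows, cols = depth.shape
    c, r = np.meshgrid(np.arange(cols), np.arange(rows))
    points = np.stack([c, r, depth]).reshape((3, -1)).T
    # back-project image-plane points into the velodyne frame
    cloud = calib.project_image_to_velo(points)
    # keep points in front of the sensor and below the height cut-off
    valid = (cloud[:, 0] >= 0) & (cloud[:, 2] < max_high)
    return cloud[valid]
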
Example #3
def main():
    global best_RMSE

    lw = utils_func.LossWise(args.api_key, args.losswise_tag, args.epochs - 1)
    # set logger
    log = logger.setup_logger(os.path.join(args.save_path, 'training.log'))
    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    # set tensorboard
    writer = SummaryWriter(args.save_path + '/tensorboardx')

    # Data Loader
    if args.generate_depth_map:
        TrainImgLoader = None
        import dataloader.KITTI_submission_loader as KITTI_submission_loader
        TestImgLoader = torch.utils.data.DataLoader(
            KITTI_submission_loader.SubmiteDataset(args.datapath,
                                                   args.data_list,
                                                   args.dynamic_bs),
            batch_size=args.bval,
            shuffle=False,
            num_workers=args.workers,
            drop_last=False)
    elif args.dataset == 'kitti':
        train_data, val_data = KITTILoader3D.dataloader(
            args.datapath,
            args.split_train,
            args.split_val,
            kitti2015=args.kitti2015)
        TrainImgLoader = torch.utils.data.DataLoader(
            KITTILoader_dataset3d.myImageFloder(train_data,
                                                True,
                                                kitti2015=args.kitti2015,
                                                dynamic_bs=args.dynamic_bs),
            batch_size=args.btrain,
            shuffle=True,
            num_workers=8,
            drop_last=False,
            pin_memory=True)
        TestImgLoader = torch.utils.data.DataLoader(
            KITTILoader_dataset3d.myImageFloder(val_data,
                                                False,
                                                kitti2015=args.kitti2015,
                                                dynamic_bs=args.dynamic_bs),
            batch_size=args.bval,
            shuffle=False,
            num_workers=8,
            drop_last=False,
            pin_memory=True)
    else:
        train_data, val_data = listflowfile.dataloader(args.datapath)
        TrainImgLoader = torch.utils.data.DataLoader(
            SceneFlowLoader.myImageFloder(train_data,
                                          True,
                                          calib=args.calib_value),
            batch_size=args.btrain,
            shuffle=True,
            num_workers=8,
            drop_last=False)
        TestImgLoader = torch.utils.data.DataLoader(
            SceneFlowLoader.myImageFloder(val_data,
                                          False,
                                          calib=args.calib_value),
            batch_size=args.bval,
            shuffle=False,
            num_workers=8,
            drop_last=False)

    # Load Model
    if args.data_type == 'disparity':
        model = disp_models.__dict__[args.arch](maxdisp=args.maxdisp)
    elif args.data_type == 'depth':
        model = models.__dict__[args.arch](maxdepth=args.maxdepth,
                                           maxdisp=args.maxdisp,
                                           down=args.down,
                                           scale=args.scale)
    else:
        log.info('Model is not implemented')
        assert False

    # Number of parameters
    log.info('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    model = nn.DataParallel(model).cuda()
    torch.backends.cudnn.benchmark = True

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_stepsize,
                            gamma=args.lr_gamma)

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            log.info("=> loading pretrain '{}'".format(args.pretrain))
            checkpoint = torch.load(args.pretrain)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            log.info('[Attention]: Do not find checkpoint {}'.format(
                args.pretrain))

    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_RMSE = checkpoint['best_RMSE']
            scheduler.load_state_dict(checkpoint['scheduler'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info('[Attention]: Do not find checkpoint {}'.format(
                args.resume))

    if args.generate_depth_map:
        os.makedirs(args.save_path + '/depth_maps/' + args.data_tag,
                    exist_ok=True)

        tqdm_eval_loader = tqdm(TestImgLoader, total=len(TestImgLoader))
        for batch_idx, (imgL_crop, imgR_crop, calib, H, W,
                        filename) in enumerate(tqdm_eval_loader):
            pred_disp = inference(imgL_crop, imgR_crop, calib, model)
            for idx, name in enumerate(filename):
                np.save(
                    args.save_path + '/depth_maps/' + args.data_tag + '/' +
                    name, pred_disp[idx][-H[idx]:, :W[idx]])
        import sys
        sys.exit()

    # evaluation
    if args.evaluate:
        evaluate_metric = utils_func.Metric()
        ## evaluation ##
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                        calib) in enumerate(TestImgLoader):
            start_time = time.time()
            test(imgL_crop, imgR_crop, disp_crop_L, calib, evaluate_metric,
                 optimizer, model)

            log.info(
                evaluate_metric.print(batch_idx, 'EVALUATE') +
                ' Time:{:.3f}'.format(time.time() - start_time))
        import sys
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()

        ## training ##
        train_metric = utils_func.Metric()
        tqdm_train_loader = tqdm(TrainImgLoader, total=len(TrainImgLoader))
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                        calib) in enumerate(tqdm_train_loader):
            # start_time = time.time()
            train(imgL_crop, imgR_crop, disp_crop_L, calib, train_metric,
                  optimizer, model, epoch)
            # log.info(train_metric.print(batch_idx, 'TRAIN') + ' Time:{:.3f}'.format(time.time() - start_time))
        log.info(train_metric.print(0, 'TRAIN Epoch' + str(epoch)))
        train_metric.tensorboard(writer, epoch, token='TRAIN')
        lw.update(train_metric.get_info(), epoch, 'Train')

        ## testing ##
        is_best = False
        if epoch == 0 or ((epoch + 1) % args.eval_interval) == 0:
            test_metric = utils_func.Metric()
            tqdm_test_loader = tqdm(TestImgLoader, total=len(TestImgLoader))
            for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                            calib) in enumerate(tqdm_test_loader):
                # start_time = time.time()
                test(imgL_crop, imgR_crop, disp_crop_L, calib, test_metric,
                     optimizer, model)
                # log.info(test_metric.print(batch_idx, 'TEST') + ' Time:{:.3f}'.format(time.time() - start_time))
            log.info(test_metric.print(0, 'TEST Epoch' + str(epoch)))
            test_metric.tensorboard(writer, epoch, token='TEST')
            lw.update(test_metric.get_info(), epoch, 'Test')

            # SAVE
            is_best = test_metric.RMSELIs.avg < best_RMSE
            best_RMSE = min(test_metric.RMSELIs.avg, best_RMSE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_RMSE': best_RMSE,
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            epoch,
            folder=args.save_path)
    lw.done()
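
save_checkpoint is defined elsewhere in this script; a minimal sketch of what such a helper typically does, assuming the common convention of writing one checkpoint file per epoch and copying it to model_best.pth.tar when is_best is set (the file names are illustrative):

import os
import shutil
import torch

def save_checkpoint(state, is_best, epoch, folder='.'):
    # persist the full training state (model, optimizer, scheduler, epoch)
    path = os.path.join(folder, 'checkpoint_{}.pth.tar'.format(epoch))
    torch.save(state, path)
    # keep a separate copy of the best checkpoint seen so far
    if is_best:
        shutil.copyfile(path, os.path.join(folder, 'model_best.pth.tar'))
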
Example #4
                    default=1,
                    metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.KITTI == '2015':
    from dataloader import KITTI_submission_loader as DA
else:
    from dataloader import KITTI_submission_loader2012 as DA

test_left_img, test_right_img = DA.dataloader(args.datapath)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    print('no model')

model = nn.DataParallel(model, device_ids=[0])
model.cuda()

if args.loadmodel is not None:
    state_dict = torch.load(args.loadmodel)
    model.load_state_dict(state_dict['state_dict'])
Example #5
parser.add_argument('--datapath2015',
                    default='/data6/wsgan/KITTI/KITTI2015/testing/',
                    help='datapath')
parser.add_argument('--datapath2012',
                    default='/data6/wsgan/KITTI/KITTI2012/testing/',
                    help='datapath')
parser.add_argument('--datatype',
                    default='2015',
                    help='finetune dataset: 2012, 2015')

args = parser.parse_args()

if args.datatype == '2015':
    from dataloader import KITTI_submission_loader as DA

    test_left_img, test_right_img = DA.dataloader2015(args.datapath2015)

elif args.datatype == '2012':

    from dataloader import KITTI_submission_loader as DA
    test_left_img, test_right_img = DA.dataloader2012(args.datapath2012)

else:

    raise AssertionError("Unknown datatype: {}".format(args.datatype))

log = logger.setup_logger(args.save_path + '/training.log')
for key, value in sorted(vars(args).items()):
    log.info(str(key) + ': ' + str(value))

if args.pretrained:
Example #6
                    metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.KITTI == '2015':
    from dataloader import KITTI_submission_loader as DA
else:
    from dataloader import KITTI_submission_loader2012 as DA

# test_left_img, test_right_img = DA.dataloader(args.datapath)
test_left_img, test_right_img, test_left_disp = DA.dataloader_val(
    args.datapath)

if args.model == 'stackhourglass':
    model = stackhourglass(int(args.maxdisp))
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    print('no model')

model = nn.DataParallel(model, device_ids=[0])
model.cuda()

# if args.loadmodel is not None:
#     state_dict = torch.load(args.loadmodel)
#     model.load_state_dict(state_dict['state_dict'])
#     print("Loaded model")
Example #7
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kitti2015 = False
kitti2012 = False
kitti_vkt2 = False

if args.KITTI == '2015':
    print("processing KT15!")
    data_type_str = "kt15"
    from dataloader import KITTI_submission_loader as DA
    test_left_img, test_right_img = DA.dataloader(args.datapath,
                                                  args.file_txt_path)
    test_left_disp = None
    kitti2015 = True
elif args.KITTI == '2012':
    print("processing KT12!")
    data_type_str = "kt12"
    from dataloader import KITTI_submission_loader2012 as DA
    test_left_img, test_right_img = DA.dataloader(args.datapath,
                                                  args.file_txt_path)
    test_left_disp = None
    kitti2012 = True

# added by CCJ on 2020/05/22:
elif args.KITTI == 'virtual_kt_2':
    print("processing Virtual KT 2!")
    data_type_str = "virtual_kt2"
Example #8
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.KITTI == '2015':
    from dataloader import KITTI_submission_loader as DA
else:
    from dataloader import KITTI_submission_loader2012 as DA

# NOTE: this unconditional import rebinds DA, overriding the KITTI loader chosen above
import sintel_loader as DA

test_left_img, test_right_img = DA.dataloader(args.datapath, args.left_dir,
                                              args.right_dir)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp,
                           device=device,
                           dfd_net=args.dfd,
                           dfd_at_end=args.dfd_at_end,
                           right_head=args.right_head)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    print('no model')

model = nn.DataParallel(model, device_ids=[0])
model.cuda()