type=int, default=1, metavar='S', help='random seed (default: 1)') args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = kl2012.dataloader( args.datapath) TrainImgLoader = torch.utils.data.DataLoader(kl.myImageFloder( all_left_img, all_right_img, all_left_disp, True), batch_size=3, shuffle=True, num_workers=8, drop_last=False) TestImgLoader = torch.utils.data.DataLoader(kl.myImageFloder( test_left_img, test_right_img, test_left_disp, False), batch_size=3, shuffle=False, num_workers=4, drop_last=False) if args.model == 'stackhourglass': model = stackhourglass(args.maxdisp) elif args.model == 'basic':
from dataloader import KITTIloader_VirtualKT2 as ls train_file_list = args.vkt2_train_list val_file_list = args.vkt2_val_list virtual_kitti2 = True else: raise Exception("No suitable KITTI found ...") print('[??] args.datapath = ', args.datapath) all_left_img, all_right_img, all_left_disp, test_left_img, \ test_right_img, test_left_disp = ls.dataloader( args.datapath, train_file_list, val_file_list) TrainImgLoader = torch.utils.data.DataLoader(DA.myImageFloder( all_left_img, all_right_img, all_left_disp, training=True, virtual_kitti2=virtual_kitti2), batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=False) TestImgLoader = torch.utils.data.DataLoader(DA.myImageFloder( test_left_img, test_right_img, test_left_disp, training=False, virtual_kitti2=virtual_kitti2), batch_size=args.batch_size, shuffle=False,
def main():
    """Fine-tune AnyNet on the dataset found at ``args.datapath``.

    Builds the train/test loaders, optionally loads pretrained weights and/or
    resumes from a checkpoint, then trains for epochs
    ``[args.start_epoch, args.epochs)``. After every epoch the model and
    optimizer state are written to ``<save_path>/checkpoint.tar`` and the
    test split is evaluated; one more evaluation runs after the loop.

    Reads (and updates ``start_epoch`` on) the module-level ``args``.
    """
    global args
    # The log file lives inside save_path, so make sure the directory exists
    # before the logger is pointed at it (a file-backed logger would
    # presumably fail on a fresh save_path otherwise).
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = logger.setup_logger(args.save_path + '/training.log')

    (train_left_img, train_right_img, train_left_disp,
     test_left_img, test_right_img, test_left_disp) = ls.dataloader(
         args.datapath, log)

    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(train_left_img, train_right_img, train_left_disp,
                         True),
        batch_size=args.train_bsize, shuffle=True, num_workers=4,
        drop_last=False)
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, test_left_disp,
                         False),
        batch_size=args.test_bsize, shuffle=False, num_workers=4,
        drop_last=False)

    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    model = models.anynet.AnyNet(args)
    model = nn.DataParallel(model).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.9, 0.999))
    log.info('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained)
            model.load_state_dict(checkpoint['state_dict'])
            log.info("=> loaded pretrained model '{}'".format(
                args.pretrained))
        else:
            log.info("=> no pretrained model found at '{}'".format(
                args.pretrained))
            log.info("=> Will start from scratch.")

    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # The checkpoint is written after an epoch finishes and records
            # that epoch, so continue with the next one rather than
            # re-running everything from epoch 0.
            args.start_epoch = checkpoint['epoch'] + 1
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    else:
        log.info('Not Resume')

    cudnn.benchmark = True

    start_full_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        log.info('This is {}-th epoch'.format(epoch))
        adjust_learning_rate(optimizer, epoch)

        train(TrainImgLoader, model, optimizer, log, epoch)

        savefilename = args.save_path + '/checkpoint.tar'
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, savefilename)

        # Evaluate after every epoch.
        test(TestImgLoader, model, log)

    # Final evaluation (also covers the case where no epoch ran).
    test(TestImgLoader, model, log)
    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
def main():
    """Fine-tune the module-level ``model`` and keep the best checkpoint.

    Runs epochs ``0 .. args.epochs`` (inclusive). Each epoch:

    * trains over every batch of the training split via ``train``;
    * evaluates the 3-pixel error over the validation split via ``test``;
    * if validation accuracy (``(1 - mean 3-px error) * 100``) improved,
      saves model/optimizer state to ``./kitti15.tar``.

    Relies on module-level ``args``, ``model``, ``optimizer``, ``train`` and
    ``test``.
    """
    epoch_start = 0
    max_acc = 0
    max_epo = 0

    # Dataset
    (all_left_img, all_right_img, all_left_disp,
     test_left_img, test_right_img, test_left_disp) = ls.dataloader(
         args.datapath)

    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
        batch_size=args.batch_size, shuffle=True, num_workers=12,
        drop_last=True)
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, test_left_disp,
                         False),
        batch_size=args.batch_size, shuffle=False, num_workers=4,
        drop_last=False)

    start_full_time = time.time()
    for epoch in range(epoch_start, args.epochs + 1):
        print("epoch:", epoch)

        # training
        total_train_loss = 0
        for batch_idx, (imgL_crop, imgR_crop,
                        disp_crop_L) in enumerate(TrainImgLoader):
            loss = train(imgL_crop, imgR_crop, disp_crop_L)
            total_train_loss += loss
            print('epoch:{}, step:{}, loss:{}'.format(epoch, batch_idx, loss))
        print('epoch %d average training loss = %.3f' %
              (epoch, total_train_loss / len(TrainImgLoader)))

        # test
        total_test_three_pixel_error_rate = 0
        for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):
            test_three_pixel_error_rate = test(imgL, imgR, disp_L)
            total_test_three_pixel_error_rate += test_three_pixel_error_rate
        print('epoch %d total 3-px error in val = %.3f' %
              (epoch,
               total_test_three_pixel_error_rate / len(TestImgLoader) * 100))

        acc = (1 - total_test_three_pixel_error_rate / len(TestImgLoader)) * 100
        if acc > max_acc:
            max_acc = acc
            max_epo = epoch
            savefilename = './kitti15.tar'
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'total_train_loss': total_train_loss,
                    'epoch': epoch + 1,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'max_acc': max_acc,
                    'max_epoch': max_epo
                }, savefilename)
            print("-- max acc checkpoint saved --")
        print('MAX epoch %d test 3 pixel correct rate = %.3f' %
              (max_epo, max_acc))

    print('full finetune time = %.2f HR' %
          ((time.time() - start_full_time) / 3600))
    print(max_epo)
    print(max_acc)
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()

# GPU visibility must be fixed before CUDA is first queried:
# torch.cuda.is_available() can initialise the CUDA runtime, after which
# later changes to CUDA_VISIBLE_DEVICES are not picked up.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)

# Only the last three lists (the evaluation split) are used below; the first
# three (presumably the training split) are kept under their original names.
(test_left_img1, test_right_img1, test_left_disp1,
 test_left_img, test_right_img, test_left_disp) = ls.dataloader(args.datapath)

TestImgLoader = torch.utils.data.DataLoader(
    DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
    batch_size=1, shuffle=False, num_workers=4, drop_last=False)

if args.model == 'ShuffleStereo8':
    model = MABNet_origin(args.maxdisp)
elif args.model == 'ShuffleStereo16':
    model = ShuffleStereo16(args.maxdisp)
else:
    # Fail fast with a clear message instead of a NameError on `model` below.
    raise ValueError('no model: {}'.format(args.model))

if args.cuda:
    model = nn.DataParallel(model)
    model.cuda()
def main():
    """Evaluate AnyNet on the validation split, using semantic label lists.

    Loads image/disparity/semantic file lists via ``ls.dataloader`` (using
    ``args.split_file``), builds the model and optimizer, optionally loads
    pretrained weights (non-strict) and/or a resume checkpoint, and then runs
    a single ``test`` pass. No training happens in this entry point.

    Reads the module-level ``args`` (and sets ``args.start_epoch`` to 0).
    """
    global args
    # Create save_path before the logger is pointed at a file inside it.
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = logger.setup_logger(args.save_path + '/training.log')

    (train_left_img, train_right_img, train_left_disp, test_left_img,
     test_right_img, left_val_disp, val_fn, left_train_semantic,
     left_val_semantic) = ls.dataloader(args.datapath, log, args.split_file)

    # NOTE(review): TrainImgLoader is built but not used by this
    # evaluation-only entry point; kept in case dataset construction has
    # side effects.
    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(train_left_img, train_right_img, train_left_disp,
                         left_train_semantic, True),
        batch_size=args.train_bsize, shuffle=True, num_workers=4,
        drop_last=False)
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, left_val_disp,
                         left_val_semantic, False, val_fn),
        batch_size=args.test_bsize, shuffle=False, num_workers=4,
        drop_last=False)

    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    model = models.anynet.AnyNet(args)
    model = nn.DataParallel(model).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.9, 0.999))
    log.info('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained)
            # Non-strict: the checkpoint may not match every parameter.
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            log.info("=> loaded pretrained model '{}'".format(
                args.pretrained))
        else:
            log.info("=> no pretrained model found at '{}'".format(
                args.pretrained))
            log.info("=> Will start from scratch.")

    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    else:
        log.info('Not Resume')

    cudnn.benchmark = True
    test(TestImgLoader, model, log)
def main(log):
    """Fine-tune a stereo network on KITTI 2012/2015 with apex AMP (and DDP).

    Selects the model from ``args.model_types`` and the file lists from
    ``args.datatype`` ('2015', '2012' or 'mix'), wraps the model with apex
    AMP (and apex DDP when ``WORLD_SIZE`` > 1), optionally loads pretrained
    weights / resumes, then trains for ``[args.start_epoch, args.epochs)``.
    Validation runs every 50 epochs and a checkpoint is written every 100.

    Args:
        log: logger for progress messages; messages are emitted only when
            ``args.local_rank == 0``.

    Raises:
        ValueError: if ``args.model_types`` or ``args.datatype`` is unknown.
    """
    global args
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # ---- distributed init (WORLD_SIZE presumably set by the launcher) ----
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.world_size = 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, \
        "Amp requires cudnn backend to be enabled."

    # ---- model selection ----
    if args.model_types == "PSMNet":
        model = PSMNet(args)
        args.loss_weights = [0.5, 0.7, 1.]
    elif args.model_types == "PSMNet_DSM":
        model = PSMNet_DSM(args)
        args.loss_weights = [0.5, 0.7, 1.]
    elif args.model_types in ("Hybrid_Net_DSM", "Hybrid_Net"):
        # Membership test: `x == "a" or "b"` would always be truthy and send
        # every unknown name here.
        model = Hybrid_Net(args)
        args.loss_weights = [0.5, 0.7, 1., 1., 1.]
    else:
        raise ValueError("model error: unknown model_types '{}'".format(
            args.model_types))

    # ---- file lists ----
    if args.datatype == '2015':
        (all_left_img, all_right_img, all_left_disp, test_left_img,
         test_right_img, test_left_disp) = ls.dataloader2015(
             args.datapath2015, split=args.split_for_val)
    elif args.datatype == '2012':
        (all_left_img, all_right_img, all_left_disp, test_left_img,
         test_right_img, test_left_disp) = ls.dataloader2012(
             args.datapath2012, split=False)
    elif args.datatype == 'mix':
        lists_2015 = ls.dataloader2015(args.datapath2015, split=False)
        lists_2012 = ls.dataloader2012(args.datapath2012, split=False)
        # Concatenate each of the six path lists: 2015 entries, then 2012.
        (all_left_img, all_right_img, all_left_disp, test_left_img,
         test_right_img, test_left_disp) = [
             a + b for a, b in zip(lists_2015, lists_2012)]
    else:
        raise ValueError(
            "please define the finetune dataset (unknown datatype '{}')".format(
                args.datatype))

    train_set = DA.myImageFloder(all_left_img, all_right_img, all_left_disp,
                                 True)
    val_set = DA.myImageFloder(test_left_img, test_right_img, test_left_disp,
                               False)

    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_set)
                     if args.distributed else None)
    # DataLoader does not allow shuffle=True together with a sampler, so only
    # shuffle explicitly when no DistributedSampler is doing it.
    TrainImgLoader = torch.utils.data.DataLoader(
        train_set, batch_size=args.train_bsize,
        shuffle=(train_sampler is None), num_workers=4, pin_memory=True,
        sampler=train_sampler, drop_last=False)
    # No sampler: every rank iterates the full validation set.
    TestImgLoader = torch.utils.data.DataLoader(
        val_set, batch_size=args.test_bsize, shuffle=False, num_workers=4,
        pin_memory=True, sampler=None, drop_last=False)

    num_train = len(TrainImgLoader)
    num_test = len(TestImgLoader)

    if args.local_rank == 0:
        for key, value in sorted(vars(args).items()):
            log.info(str(key) + ': ' + str(value))

    stages = len(args.loss_weights)

    # ---- AMP / sync-BN / DDP wrapping ----
    if args.sync_bn:
        import apex
        model = apex.parallel.convert_syncbn_model(model)
        if args.local_rank == 0:
            log.info(
                "using apex synced BN-----------------------------------------------------"
            )

    model = model.cuda()
    # lr here is a placeholder; adjust_learning_rate sets it every epoch.
    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
    model, optimizer = amp.initialize(
        model, optimizer, opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        model = DDP(model, delay_allreduce=True)
        if args.local_rank == 0:
            log.info(
                "using distributed-----------------------------------------------------"
            )

    # ---- weights ----
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            checkpoint = torch.load(args.pretrained, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            if args.local_rank == 0:
                log.info("=> loaded pretrained model '{}'".format(
                    args.pretrained))
        elif args.local_rank == 0:
            log.info("=> no pretrained model found at '{}'".format(
                args.pretrained))
            log.info("=> Will start from scratch.")

    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            if args.local_rank == 0:
                log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.local_rank == 0:
                log.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        elif args.local_rank == 0:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    elif args.local_rank == 0:
        log.info('Not Resume')

    if args.local_rank == 0:
        log.info('Number of model parameters: {}'.format(
            sum(p.numel() for p in model.parameters())))

    # ---- training loop ----
    start_full_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        total_train_loss = 0
        total_d1 = 0
        total_epe = 0
        adjust_learning_rate(optimizer, epoch)
        losses = [AverageMeter() for _ in range(stages)]

        for batch_idx, (imgL_crop, imgR_crop,
                        disp_crop_L) in enumerate(TrainImgLoader):
            loss = train(imgL_crop, imgR_crop, disp_crop_L, model, optimizer)
            # Track each stage's loss with its weight divided back out.
            for idx in range(stages):
                losses[idx].update(loss[idx].item() / args.loss_weights[idx])

            info_str = '\t'.join(
                'Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val,
                                                   losses[x].avg)
                for x in range(stages))
            if args.local_rank == 0:
                log.info('losses Epoch{} [{}/{}] {}'.format(
                    epoch, batch_idx, len(TrainImgLoader), info_str))
            total_train_loss += loss[-1]

        if args.local_rank == 0:
            log.info('epoch %d total training loss = %.3f' %
                     (epoch, total_train_loss / num_train))

        if epoch % 50 == 0:
            # ---- validation ----
            inference_time = 0
            for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):
                epe, d1, single_inference_time = test(imgL, imgR, disp_L,
                                                      model)
                inference_time += single_inference_time
                total_d1 += d1
                total_epe += epe

            if args.distributed:
                total_epe = reduce_tensor(total_epe.data)
                total_d1 = reduce_tensor(total_d1.data)

            if args.local_rank == 0:
                log.info('epoch %d avg_3-px error in val = %.3f' %
                         (epoch, total_d1 / num_test * 100))
                log.info('epoch %d avg_epe in val = %.3f' %
                         (epoch, total_epe / num_test))
                log.info(('=> Mean inference time for %d images: %.3fs' %
                          (num_test, inference_time / num_test)))

        # epoch % 100 == 0 implies epoch % 50 == 0, so total_d1 is fresh here.
        if args.local_rank == 0 and epoch % 100 == 0:
            savefilename = args.save_path + '/finetune_' + str(epoch) + '.tar'
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'train_loss': total_train_loss / len(TrainImgLoader),
                    'test_loss': total_d1 / len(TestImgLoader) * 100,
                }, savefilename)

    if args.local_rank == 0:
        log.info('full finetune time = %.2f HR' %
                 ((time.time() - start_full_time) / 3600))
def main():
    """Train AnyNet, also exporting the model object and its graph.

    Besides the usual setup, this entry point:

    * pickles the freshly built (unwrapped) model to ``./model_para.pth``;
    * writes the network graph to TensorBoard with ``SummaryWriter.add_graph``,
      using a random 1x3x256x512 stereo pair.

    Training is delegated to ``train`` (which receives the training-set size
    and the test loader). A final ``test`` pass runs afterwards.

    Reads the module-level ``args`` (and sets ``args.start_epoch``).
    """
    global args
    # Create save_path before the logger is pointed at a file inside it.
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = logger.setup_logger(args.save_path + '/training.log')

    (train_left_img, train_right_img, train_left_disp, test_left_img,
     test_right_img, test_left_disp) = ls.dataloader(
         args.datapath, log, args.split_file)
    n_train = len(train_left_img)

    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(train_left_img, train_right_img, train_left_disp,
                         True),
        batch_size=args.train_bsize, shuffle=True, num_workers=4,
        drop_last=False)
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, test_left_disp,
                         False),
        batch_size=args.test_bsize, shuffle=False, num_workers=4,
        drop_last=False)

    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    model = models.anynet.AnyNet(args)
    torch.save(model, './model_para.pth')
    model = nn.DataParallel(model).cuda()

    # Dummy stereo pair for graph tracing. NOTE: this reseeds the global
    # torch RNG with a fixed value.
    torch.manual_seed(2)
    left = torch.randn(1, 3, 256, 512)
    right = torch.randn(1, 3, 256, 512)
    with SummaryWriter(comment='AnyNet_model_stracture') as w:
        w.add_graph(model, (left, right))

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.9, 0.999))
    log.info('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    args.start_epoch = 0
    if args.resume:
        # Resume after an interrupted run.
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])  # model weights
            optimizer.load_state_dict(checkpoint['optimizer'])  # optimizer state
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> Will start from scratch.")
    else:
        log.info('Not Resume')

    cudnn.benchmark = True
    start_full_time = time.time()
    # Start model training.
    train(TrainImgLoader, model, optimizer, log, n_train, TestImgLoader)
    test(TestImgLoader, model, log)
    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
def main():
    """Train StereoNet, keeping the checkpoint with the lowest test error.

    Restricts visible GPUs to ``args.gpu``, builds the loaders, initialises
    StereoNet with ``weights_init`` and trains for
    ``[args.start_epoch, args.epoch)``. The LR is decayed by ``args.gamma``
    at epoch 200. After every epoch the test error is computed; whenever it
    improves, the weights are saved to
    ``<save_path>/finetune_checkpoint_<epoch>.pth``.

    Reads the module-level ``args``.
    """
    global args
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    (train_left_img, train_right_img, train_left_disp, test_left_img,
     test_right_img, test_left_disp) = ls.dataloader(args.datapath)

    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(train_left_img, train_right_img, train_left_disp,
                         True),
        batch_size=args.train_bsize, shuffle=True, num_workers=1,
        drop_last=False)
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, test_left_disp,
                         False),
        batch_size=args.test_bsize, shuffle=False, num_workers=4,
        drop_last=False)

    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    # os.path.join works whether or not save_path ends with a separator;
    # plain concatenation wrote "<save_path>training.log" next to the dir.
    log = logger.setup_logger(os.path.join(args.save_path, 'training.log'))
    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ':' + str(value))

    model = StereoNet(maxdisp=args.maxdisp)
    model = nn.DataParallel(model).cuda()
    model.apply(weights_init)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[200],
                                         gamma=args.gamma)

    log.info('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info("=> no checkpoint found at '{}'".format(args.resume))
            log.info("=> will start from scratch.")
    else:
        log.info("Not Resume")

    # Best (lowest) test error seen so far and the epoch it occurred in.
    min_erro = float('inf')
    max_epo = 0
    start_full_time = time.time()
    for epoch in range(args.start_epoch, args.epoch):
        log.info('This is {}-th epoch'.format(epoch))
        train(TrainImgLoader, model, optimizer, log, epoch)
        scheduler.step()

        erro = test(TestImgLoader, model, log)
        if erro < min_erro:
            max_epo = epoch
            min_erro = erro
            savefilename = os.path.join(
                args.save_path, 'finetune_checkpoint_{}.pth'.format(max_epo))
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict()
            }, savefilename)
        log.info('MIN epoch %d total test erro = %.3f' % (max_epo, min_erro))

    log.info('full training time = {:.2f} Hours'.format(
        (time.time() - start_full_time) / 3600))
help='random seed (default: 1)') args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) if args.datatype == '2015': from dataloader import KITTIloader2015 as ls elif args.datatype == '2012': from dataloader import KITTIloader2012 as ls all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(args.datapath) TrainImgLoader = torch.utils.data.DataLoader( DA.myImageFloder(all_left_img,all_right_img,all_left_disp, True), batch_size= 2, shuffle= True, num_workers= 4, drop_last=False) TestImgLoader = torch.utils.data.DataLoader( DA.myImageFloder(test_left_img,test_right_img,test_left_disp, False), batch_size= 2, shuffle= False, num_workers= 4, drop_last=False) if args.model == 'stackhourglass': model = stackhourglass(args.maxdisp) elif args.model == 'basic': model = basic(args.maxdisp) else: print('no model') if args.cuda: model = nn.DataParallel(model)
if args.datatype == '2015': from dataloader import KITTIloader2015 as ls elif args.datatype == '2012': from dataloader import KITTIloader2012 as ls #读取数据路径(默认读取2012的数据,返回图片的路径) all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader( args.datapath) #读数据,并且讲数据以12个batch进行封装,并且图片全部shuffle #TrainImgLoader是一个list包括了imgL_crop, imgR_crop, disp_crop_L,list[i]表示一个batch_size #kitti2012_data = DA.myImageFloder(all_left_img,all_right_img,all_left_disp, True) TrainImgLoader = DA.ImgLoader(DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True), BATCH_SIZE=1) #TestImgLoader = DA.ImgLoader( # DA.myImageFloder(test_left_img,test_right_img,test_left_disp, False), # BATCH_SIZE= 8) """ #读取已有的模型(后续在写) if args.loadmodel is not None: state_dict = torch.load(args.loadmodel) model.load_state_dict(state_dict['state_dict']) print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) """ """ def test(imgL,imgR,disp_true):
from dataloader import KITTILoader as DA elif args.datatype == 'sceneflow': from dataloader import listflowfile as ls from dataloader import SecenFlowLoader as DA elif args.datatype == 'kitti_object': from dataloader.KITTIObjectLoader import KITTIObjectLoader else: print('unknown datatype: ', args.datatype) sys.exit() if args.datatype == 'kitti_object': dataloader = KITTIObjectLoader(args.datapath, 'trainval') else: all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader( args.datapath) dataloader = DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False) if args.model == 'stackhourglass': model = stackhourglass(args.maxdisp) elif args.model == 'basic': model = basic(args.maxdisp) else: print('no model') refine_model = unet_refine.resnet34(pretrained=True, rgbd=True) #refine_model = unet_refine.resnet18(pretrained=True) model = nn.DataParallel(model, device_ids=[0]) model.cuda() refine_model = nn.DataParallel(refine_model) refine_model.cuda()
help='load image as RGB or gray') args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) if args.datatype == '2015': from dataloader import KITTIloader2015 as ls elif args.datatype == '2012': from dataloader import KITTIloader2012 as ls all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.dataloader(args.datapath) TrainImgLoader = torch.utils.data.DataLoader( DA.myImageFloder(all_left_img,all_right_img,all_left_disp, True, colormode=args.colormode), batch_size= 12, shuffle= True, num_workers= 8, drop_last=False) TestImgLoader = torch.utils.data.DataLoader( DA.myImageFloder(test_left_img,test_right_img,test_left_disp, False, colormode=args.colormode), batch_size= 3, shuffle= False, num_workers= 4, drop_last=False) if args.model == 'stackhourglass': model = stackhourglass(args.maxdisp, colormode=args.colormode) elif args.model == 'basic': model = basic(args.maxdisp, colormode=args.colormode) else: print('no model') if args.cuda: model = nn.DataParallel(model)
# Parse the command-line options of this KITTI submission script.
args = parser.parse_args()
# Choose the submission file-list loader: args.KITTI == '2015' uses
# KITTI_submission_loader, any other value the 2012 variant.
if args.KITTI == '2015':
    from dataloader import KITTI_submission_loader as DA
else:
    from dataloader import KITTI_submission_loader2012 as DA
# Left/right image lists to run inference on.
test_left_img, test_right_img = DA.dataloader(args.datapath)
# Additionally build a shuffled training loader from the KITTI 2015 lists
# under the same args.datapath, regardless of args.KITTI.
# NOTE(review): the role of TrainImgLoader in a submission script is not
# visible in this chunk, and the 2015 loader is used even when args.KITTI
# selects 2012 — confirm this is intended.
from dataloader import KITTIloader2015 as lt
from dataloader import KITTILoader as DA_tmp
all_left_img, all_right_img, all_left_disp, test_left_img_tmp, test_right_img_tmp, test_left_disp = lt.dataloader(args.datapath)
TrainImgLoader = torch.utils.data.DataLoader(
    DA_tmp.myImageFloder(all_left_img,all_right_img,all_left_disp, True),
    batch_size= 1, shuffle= True, num_workers= 8, drop_last=False)
# Use CUDA only when it is available and args.no_cuda is not set; seed the
# CPU RNG and (if CUDA is used) the current CUDA device's RNG.
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
# Select the network implementation by name.
# NOTE(review): no fallback branch is visible here; an unrecognised
# args.model leaves `model` unbound unless handled after this chunk.
if args.model == 'stackhourglass':
    model = psm_net.PSMNet(args.maxdisp)
elif args.model == 'basic':
    model = basic_net.PSMNet(args.maxdisp)
elif args.model == 'concatNet':
    model = concatNet.PSMNet(args.maxdisp)