def data_and_model_loader(config, pretrained_depth_path, pretrained_pose_path, seq=None, load_depth=True):
    """Build the test-split DataLoader and the depth/pose networks in eval mode.

    Args:
        config: Experiment configuration dict. Read directly here:
            'test_seq' (used when ``seq`` is None) and 'minibatch'
            (DataLoader batch size). It is also passed through to
            ``KittiLoaderPytorch``, ``get_data_transforms`` and the
            ``models`` constructors.
        pretrained_depth_path: Path to a saved depth-network state_dict, or
            None to keep the freshly constructed weights. Ignored when
            ``load_depth`` is False.
        pretrained_pose_path: Path to a saved pose-network state_dict, or
            None to keep the freshly constructed weights.
        seq: A single sequence identifier, which is wrapped in a
            one-element list. When None, ``config['test_seq']`` is used
            unchanged.
        load_depth: When False, no depth network is built and None is
            returned in its place.

    Returns:
        Tuple ``(test_loader, [depth_model, pose_model], device)``:
        ``test_loader`` is an unshuffled DataLoader over the test split,
        ``depth_model`` is None when ``load_depth`` is False, and ``device``
        is CUDA when available, otherwise CPU.
    """
    if seq is None:
        seq = config['test_seq']
    else:
        seq = [seq]

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # The same sequence list fills all three list slots (presumably the
    # train/val/test splits -- confirm against KittiLoaderPytorch); only
    # mode='test' is materialised here.
    test_dset = KittiLoaderPytorch(
        config, [seq, seq, seq], mode='test',
        transform_img=get_data_transforms(config)['test'])
    test_dset_loaders = torch.utils.data.DataLoader(
        test_dset, batch_size=config['minibatch'], shuffle=False, num_workers=6)

    # Depth net is constructed before the pose net, as in the original, so
    # any weight-init RNG consumption happens in the same order.
    depth_model = None
    if load_depth:
        depth_model = models.depth_model(config).to(device)
        if pretrained_depth_path is not None:
            # map_location lets a GPU-saved checkpoint load on a CPU-only host
            # (device falls back to CPU above).
            depth_model.load_state_dict(
                torch.load(pretrained_depth_path, map_location=device))
        depth_model.eval()  # equivalent to train(False)

    pose_model = models.pose_model(config).to(device)
    if pretrained_pose_path is not None:
        pose_model.load_state_dict(
            torch.load(pretrained_pose_path, map_location=device))
    pose_model.eval()  # equivalent to train(False)

    return test_dset_loaders, [depth_model, pose_model], device
args = parser.parse_args() config={ 'num_frames': None, 'skip':1, ### if not one, we skip every 'skip' samples that are generated ({1,2}, {2,3}, {3,4} becomes {1,2}, {3,4}) 'correction_rate': 1, ### if not one, only perform corrections every 'correction_rate' frames (samples become {1,3},{3,5},{5,7} when 2) 'minibatch':15, ##minibatch size 'load_pretrained_depth': True, 'freeze_depthnet': True, } for k in args.__dict__: config[k] = args.__dict__[k] print(config) print(args.train_seq, args.test_seq, args.val_seq) args.data_dir = '{}/{}_res'.format(args.data_dir, config['img_resolution']) config['data_dir'] = '{}/{}_res'.format(config['data_dir'], config['img_resolution']) dsets = {x: KittiLoaderPytorch(config, [args.train_seq, args.val_seq, args.test_seq], mode=x, transform_img=get_data_transforms(config)[x], \ augment=config['augment_motion'], skip=config['skip']) for x in ['train', 'val']} dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=config['minibatch'], shuffle=True, num_workers=8) for x in ['train', 'val']} val_dset = KittiLoaderPytorch(config, [args.train_seq, args.val_seq, args.test_seq], mode='val', transform_img=get_data_transforms(config)['val']) val_dset_loaders = torch.utils.data.DataLoader(val_dset, batch_size=config['minibatch'], shuffle=False, num_workers=8) test_dset = KittiLoaderPytorch(config, [args.train_seq, args.val_seq, args.test_seq], mode='test', transform_img=get_data_transforms(config)['test']) test_dset_loaders = torch.utils.data.DataLoader(test_dset, batch_size=config['minibatch'], shuffle=False, num_workers=8) eval_dsets = {'val': val_dset_loaders, 'test':test_dset_loaders} def main(): results = {} results['pose_output_type'] = config['pose_output_type'] results['estimator'] = config['estimator_type'] config['device'] = device start = time.time()
'{}/**depth**best-loss-val_seq-**-test_seq-**.pth'.format(dir))[0] pretrained_pose_path = glob.glob( '{}/**pose**best-loss-val_seq-**-test_seq-**.pth'.format(dir))[0] config['augment_motion'] = False config['augment_backwards'] = False config['test_seq'] = [seq] config['minibatch'] = 5 device = config['device'] config['data_dir'] = path_to_dset_downsized + config[ 'img_resolution'] + '_res/' ### dataset and model loading from data.kitti_loader_stereo import KittiLoaderPytorch test_dset = KittiLoaderPytorch( config, [[seq], [seq], [seq]], mode='test', transform_img=get_data_transforms(config)['test']) test_dset_loaders = torch.utils.data.DataLoader( test_dset, batch_size=config['minibatch'], shuffle=False, num_workers=6) import models.packetnet_depth_and_egomotion as models_packetnet import models.depth_and_egomotion as models depth_model = models.depth_model(config).to(device) pose_model = models_packetnet.pose_model(config).to(device) pretrained_depth_path = glob.glob( '{}/**depth**best-loss-val_seq-**-test_seq-{}**.pth'.format(dir, ''))[0]
date = 'best_stereo' pretrained_path = '{}/{}/2019-6-24-13-4-most_loop_closures-val_seq-00-test_seq-05.pth'.format( model_dirs, date) output_dir = '{}{}/'.format(model_dirs, date) args.data_dir = '{}/{}'.format(args.data_dir, args.mode) seq = [args.val_seq] #model.replace(output_dir,'').replace('/','').replace figures_output_dir = '{}figs'.format(output_dir) os.makedirs(figures_output_dir, exist_ok=True) os.makedirs(figures_output_dir + '/imgs', exist_ok=True) os.makedirs(figures_output_dir + '/depth', exist_ok=True) os.makedirs(figures_output_dir + '/exp_mask', exist_ok=True) test_dset = KittiLoaderPytorch( args.data_dir, config, [seq, seq, seq], mode='test', transform_img=get_data_transforms(config)['val']) test_dset_loaders = torch.utils.data.DataLoader(test_dset, batch_size=config['minibatch'], shuffle=False, num_workers=4) eval_dsets = {'test': test_dset_loaders} Reconstructor = stn.Reconstructor().to(device) model = mono_model_joint.joint_model( num_img_channels=(6 + 2 * config['use_flow']), output_exp=args.exploss, dropout_prob=config['dropout_prob']).to(device) model.load_state_dict(torch.load(pretrained_path)) _, _, _, _, _, corr_traj, corr_traj_rot, est_traj, gt_traj, _, _, _ = test_trajectory(