def data_and_model_loader(config, pretrained_depth_path, pretrained_pose_path, seq=None, load_depth=True):
    """Build the KITTI test DataLoader and load (optionally pretrained) depth/pose models.

    Args:
        config: experiment config dict; reads 'test_seq' and 'minibatch'.
        pretrained_depth_path: path to a saved depth-model state dict, or None to skip loading.
        pretrained_pose_path: path to a saved pose-model state dict, or None to skip loading.
        seq: single sequence id to evaluate; when None, falls back to config['test_seq'].
        load_depth: when False, no depth model is constructed and None takes its
            place in the returned model list.

    Returns:
        (test_dset_loaders, [depth_model_or_None, pose_model], device)
    """
    # was `seq == None` — identity check is the correct idiom for None (PEP 8)
    if seq is None:
        seq = config['test_seq']
    else:
        seq = [seq]

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # The loader takes [train, val, test] sequence lists; in test mode the same
    # list is passed for all three slots.
    test_dset = KittiLoaderPytorch(
        config, [seq, seq, seq], mode='test',
        transform_img=get_data_transforms(config)['test'])
    test_dset_loaders = torch.utils.data.DataLoader(
        test_dset, batch_size=config['minibatch'], shuffle=False, num_workers=6)

    # Build depth first, then pose, preserving construction (and hence RNG) order.
    depth_model = models.depth_model(config).to(device) if load_depth else None
    pose_model = models.pose_model(config).to(device)

    # was `load_depth == True` — the depth weights can only be loaded when the
    # depth model was actually constructed.
    if depth_model is not None and pretrained_depth_path is not None:
        depth_model.load_state_dict(torch.load(pretrained_depth_path))
    if pretrained_pose_path is not None:
        pose_model.load_state_dict(torch.load(pretrained_pose_path))

    # Evaluation mode: freeze batch-norm statistics / disable dropout.
    # (train(False) and eval() are equivalent; both kept to match the file's style.)
    pose_model.train(False)
    pose_model.eval()
    if depth_model is not None:
        depth_model.train(False)
        depth_model.eval()

    mmodels = [depth_model, pose_model]
    return test_dset_loaders, mmodels, device
### dataset and model loading from data.kitti_loader_stereo import KittiLoaderPytorch test_dset = KittiLoaderPytorch( config, [[seq], [seq], [seq]], mode='test', transform_img=get_data_transforms(config)['test']) test_dset_loaders = torch.utils.data.DataLoader( test_dset, batch_size=config['minibatch'], shuffle=False, num_workers=6) import models.packetnet_depth_and_egomotion as models_packetnet import models.depth_and_egomotion as models depth_model = models.depth_model(config).to(device) pose_model = models_packetnet.pose_model(config).to(device) pretrained_depth_path = glob.glob( '{}/**depth**best-loss-val_seq-**-test_seq-{}**.pth'.format(dir, ''))[0] pretrained_pose_path = glob.glob( '{}/**pose**best-loss-val_seq-**-test_seq-{}**.pth'.format(dir, ''))[0] depth_model.load_state_dict(torch.load(pretrained_depth_path)) pose_model.load_state_dict(torch.load(pretrained_pose_path)) pose_model.train(False).eval() depth_model.train(False).eval() ### Plane Model from models.plane_net import PlaneModel, scale_recovery from losses import Plane_Height_loss plane_loss = Plane_Height_loss(config)
def main():
    """Train the depth + pose networks, log to TensorBoard, and checkpoint the
    best models by validation reconstruction loss.

    NOTE(review): relies on module-level globals not defined in this function:
    `config`, `args`, `eval_dsets`, `dset_loaders`, `device`, the model modules
    (`models_monodepth`, `models_packetnet`), `losses`, `Trainer`,
    `exp_lr_scheduler`, `SummaryWriter`, plotting/testing helpers, and
    `save_obj`/`timeSince`.  Confirm these exist at module scope.
    """
    results = {}
    config['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    start = time.time()
    now = datetime.datetime.now()
    # Timestamp prefix used in every saved checkpoint/result filename.
    ts = '{}-{}-{}-{}-{}'.format(now.year, now.month, now.day, now.hour,
                                 now.minute)
    ''' Load Pretrained Models'''
    pretrained_depth_path, pretrained_pose_path = None, None
    # NOTE(review): glob is called without recursive=True, so '**' matches like
    # '*' (one directory level); [0] raises IndexError if nothing matches.
    if config['load_pretrained_depth']:
        pretrained_depth_path = glob.glob(
            '{}/**depth**best-loss-val_seq-**-test_seq-**.pth'.format(
                config['pretrained_dir']))[0]
    if config['load_pretrained_pose']:
        pretrained_pose_path = glob.glob(
            '{}/**pose**best-loss-val_seq-**-test_seq-**.pth'.format(
                config['pretrained_dir']))[0]
    if config['load_pretrained_pose'] == True and config[
            'load_pretrained_depth'] == True:
        ## skip epoch 0 if pose model is pretrained (epoch 0 initializes with VO)
        epochs = range(1, config['num_epochs'])
    else:
        epochs = range(0, config['num_epochs'])

    ### load the models
    # Depth net from the monodepth module, pose net from the PackNet module.
    depth_model = models_monodepth.depth_model(config).to(config['device'])
    pose_model = models_packetnet.pose_model(config).to(config['device'])
    if pretrained_depth_path is not None and config[
            'load_pretrained_depth'] == True:
        depth_model.load_state_dict(torch.load(pretrained_depth_path))
    if pretrained_pose_path is not None and config[
            'load_pretrained_pose'] == True:
        pose_model.load_state_dict(torch.load(pretrained_pose_path))
    # NOTE(review): this local `models` list shadows any module-level `models`
    # import for the rest of this function.
    models = [depth_model, pose_model]

    ## Load the pretrained plane estimator if using the plane loss
    if config['l_scale_recovery']:
        from models.plane_net import PlaneModel
        plane_model = PlaneModel(config).to(config['device'])
        pretrained_plane_path = glob.glob('{}/**plane**.pth'.format(
            config['pretrained_plane_dir']))[0]
        plane_model.load_state_dict(torch.load(pretrained_plane_path))
        # Plane estimator is used as a fixed (frozen) auxiliary network.
        for param in plane_model.parameters():
            param.requires_grad = False
    else:
        plane_model = None

    if config['freeze_depthnet']: print('Freezing depth network weights.')
    if config['freeze_posenet']: print('Freezing pose network weights.')
    for param in depth_model.parameters():
        param.requires_grad = not config['freeze_depthnet']
    for param in pose_model.parameters():
        param.requires_grad = not config['freeze_posenet']

    # Pose network trains at twice the base learning rate.
    params = [{
        'params': depth_model.parameters()
    }, {
        'params': pose_model.parameters(),
        'lr': 2 * config['lr']
    }]
    loss = losses.Compute_Loss(config, plane_model=plane_model)
    optimizer = torch.optim.Adam(params, lr=config['lr'],
                                 weight_decay=config['wd'])  #, amsgrad=True)
    trainer = Trainer(config, models, loss, optimizer)
    cudnn.benchmark = True

    best_val_loss = {}
    best_loss_epoch = {}
    for key, dset in eval_dsets.items():
        best_val_loss[key] = 1e5  # sentinel "infinity" before the first epoch

    for epoch in epochs:
        optimizer = exp_lr_scheduler(
            optimizer, epoch, lr_decay_epoch=config['lr_decay_epoch']
        )  ## reduce learning rate as training progresses
        print("Epoch {}".format(epoch))
        train_losses = trainer.forward(dset_loaders['train'], epoch, 'train')
        with torch.no_grad():
            val_losses = trainer.forward(dset_loaders['val'], epoch, 'val')

        # if epoch == 0 or (epoch == 1 and (config['load_pretrained_pose'] == True)):
        # NOTE(review): writers are re-created on every epoch iteration, which
        # produces one TensorBoard run directory per epoch — possibly
        # unintentional (the commented-out guard above suggests they were once
        # created only on the first epoch); confirm.
        val_writer = SummaryWriter(
            comment="tw-val-{}-test_seq-{}_val".format(
                args.val_seq[0], args.test_seq[0]))
        train_writer = SummaryWriter(
            comment="tw-val-{}-test_seq-{}_train".format(
                args.val_seq[0], args.test_seq[0]))

        if train_losses is not None and val_losses is not None:
            for key, value in train_losses.items():
                train_writer.add_scalar('{}'.format(key), value, epoch + 1)
                val_writer.add_scalar('{}'.format(key), val_losses[key],
                                      epoch + 1)

        for key, dset in eval_dsets.items():
            print("{} Set, Epoch {}".format(key, epoch))
            if epoch > 0:
                ###plot images, depth map, explainability mask
                img_array, disparity, exp_mask, d = test_depth_and_reconstruction(
                    device, models, dset, config, epoch=epoch)
                # d unpacks into (source disparities, reconstructed disparities,
                # depth-consistency masks)
                source_disp, reconstructed_disp, d_masks = d[0], d[1], d[2]
                img_array = plot_img_array(img_array)
                train_writer.add_image(key + '/imgs', img_array, epoch + 1)
                for i, d in enumerate(disparity):
                    train_writer.add_image(
                        key + '/depth-{}/target-depth'.format(i),
                        plot_disp(d), epoch + 1)
                    ### For depth consistency
                    train_writer.add_image(
                        key + '/depth-{}/source-depth'.format(i),
                        plot_disp(source_disp[i]), epoch + 1)
                    train_writer.add_image(
                        key + '/depth-{}/reconstructed-depth'.format(i),
                        plot_disp(reconstructed_disp[i]), epoch + 1)
                d_masks = plot_img_array(d_masks)
                train_writer.add_image(key + '/depth/masks', d_masks,
                                       epoch + 1)
                exp_mask = plot_img_array(exp_mask)
                train_writer.add_image(key + '/exp_mask', exp_mask, epoch + 1)

                ###evaluate trajectories
                # Trajectory metrics need ground-truth odometry poses.
                if args.data_format == 'odometry':
                    est_lie_alg, gt_lie_alg, est_traj, gt_traj, errors = test_trajectory(
                        config, device, models, dset, epoch)
                    errors = plot_6_by_1(est_lie_alg - gt_lie_alg,
                                         title='6x1 Errors')
                    pose_vec_est = plot_6_by_1(est_lie_alg, title='Est')
                    pose_vec_gt = plot_6_by_1(gt_lie_alg, title='GT')
                    est_traj_img = plot_multi_traj(est_traj, 'Est.', gt_traj,
                                                   'GT', key + ' Set')
                    train_writer.add_image(key + '/est_traj', est_traj_img,
                                           epoch + 1)
                    train_writer.add_image(key + '/errors', errors, epoch + 1)
                    train_writer.add_image(key + '/gt_lie_alg', pose_vec_gt,
                                           epoch + 1)
                    train_writer.add_image(key + '/est_lie_alg', pose_vec_est,
                                           epoch + 1)
                    results[key] = {
                        'val_seq': args.val_seq,
                        'test_seq': args.test_seq,
                        'epochs': epoch + 1,
                        'est_pose_vecs': est_lie_alg,
                        'gt_pose_vecs': gt_lie_alg,
                        'est_traj': est_traj,
                        'gt_traj': gt_traj,
                    }

                if args.save_results:  ##Save the best models
                    os.makedirs('results/{}'.format(config['date']),
                                exist_ok=True)
                    # Checkpoint when the combined forward+inverse
                    # reconstruction loss improves on the best seen so far.
                    if (val_losses['l_reconstruct_forward'] +
                            val_losses['l_reconstruct_inverse']
                        ) < best_val_loss[
                            key] and epoch > 0:  # and epoch > 2*(config['iterations']-1):
                        best_val_loss[key] = (
                            val_losses['l_reconstruct_forward'] +
                            val_losses['l_reconstruct_inverse'])
                        best_loss_epoch[key] = epoch
                        depth_dict_loss = depth_model.state_dict()
                        pose_dict_loss = pose_model.state_dict()
                        # Only the 'val' split decides which weights get saved.
                        if key == 'val':
                            print("Lowest validation loss (saving model)")
                            torch.save(
                                depth_dict_loss,
                                'results/{}/{}-depth-best-loss-val_seq-{}-test_seq-{}.pth'
                                .format(config['date'], ts, args.val_seq[0],
                                        args.test_seq[0]))
                            torch.save(
                                pose_dict_loss,
                                'results/{}/{}-pose-best-loss-val_seq-{}-test_seq-{}.pth'
                                .format(config['date'], ts, args.val_seq[0],
                                        args.test_seq[0]))
                    if args.data_format == 'odometry':
                        results[key]['best_loss_epoch'] = best_loss_epoch[
                            key]
                    # Persist trajectories, config, and the running scale factors.
                    save_obj(
                        results,
                        'results/{}/{}-results-val_seq-{}-test_seq-{}'.
                        format(config['date'], ts, args.val_seq[0],
                               args.test_seq[0]))
                    save_obj(config, 'results/{}/config'.format(config['date']))
                    f = open(
                        "results/{}/config.txt".format(config['date']), "w")
                    f.write(str(config))
                    f.close()
                    save_obj(loss.scale_factor_list,
                             'results/{}/scale_factor'.format(config['date']))

    duration = timeSince(start)
    print("Training complete (duration: {})".format(duration))