def create_kitti_submission(model, iters=24, output_path='kitti_submission', write_png=False):
    """Write flow predictions for the KITTI ``testing`` split to disk.

    For every sample of ``datasets.KITTI(split='testing')`` the two images
    are padded with ``InputPadder(mode='kitti')``, passed through ``model``
    in test mode, un-padded, and written with ``frame_utils.writeFlowKITTI``
    to ``<output_path>/<frame_id>``.

    Args:
        model: Callable invoked as
            ``model(image1, image2, iters=iters, test_mode=True)``; the
            second element of its return value is used as the prediction.
            It is switched to eval mode here.
        iters: Forwarded to ``model`` as ``iters``.
        output_path: Directory receiving the flow files (created if missing).
        write_png: When true, a visualisation produced by
            ``flow_viz.flow_to_image`` is also written to
            ``<output_path>_png/<frame_id>.png``.
    """
    model.eval()
    test_dataset = datasets.KITTI(split='testing', aug_params=None)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_path, exist_ok=True)
    if write_png:
        out_path_png = output_path + '_png'
        os.makedirs(out_path_png, exist_ok=True)

    # Inference only: without no_grad the prediction may carry autograd
    # history, which wastes memory and makes `.numpy()` raise on a tensor
    # that requires grad.
    with torch.no_grad():
        for test_id in range(len(test_dataset)):
            image1, image2, (frame_id, ) = test_dataset[test_id]

            # [None] prepends a batch axis before moving the images to the GPU.
            padder = InputPadder(image1.shape, mode='kitti')
            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

            _, flow_pr = model(image1, image2, iters=iters, test_mode=True)

            # Drop the batch axis, undo the padding and move the leading axis
            # last before converting to a NumPy array.
            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

            if write_png:
                output_filename_png = os.path.join(out_path_png, frame_id + '.png')
                cv2.imwrite(output_filename_png, flow_viz.flow_to_image(flow))

            output_filename = os.path.join(output_path, frame_id)
            frame_utils.writeFlowKITTI(output_filename, flow)
def validate_kitti(args, model, iters=32):
    """Evaluate a wrapped model on the KITTI validation samples and print the scores.

    The network is called through ``model.module`` (so ``model`` must expose
    a ``module`` attribute, e.g. a ``DataParallel`` wrapper). For each sample
    the last entry of the returned predictions is compared against the
    ground-truth flow. Two numbers are printed: the mean of the per-image
    end-point errors, and the percentage of valid pixels whose error exceeds
    both 3.0 and 5% of the ground-truth magnitude. Nothing is returned.

    NOTE(review): a second ``validate_kitti(model, iters=24)`` is defined
    later in this file; if both live at module level the later one shadows
    this definition -- confirm which one callers expect.
    """
    model.eval()
    kitti_val = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)

    per_image_epe = []
    outlier_flags = []
    with torch.no_grad():
        for idx in range(len(kitti_val)):
            image1, image2, flow_gt, valid_gt = kitti_val[idx]

            # Images get a leading batch axis; targets are moved as they are.
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()
            flow_gt = flow_gt.cuda()
            valid_gt = valid_gt.cuda()

            # Last prediction in the returned sequence, first batch element.
            flow_pr = model.module(image1, image2, iters=iters)[-1][0]

            squared_error = (flow_pr - flow_gt)**2
            epe = torch.sum(squared_error, dim=0).sqrt().view(-1)
            mag = torch.sum(flow_gt**2, dim=0).sqrt().view(-1)
            valid_mask = valid_gt.view(-1) >= 0.5

            # Outlier: error above 3.0 AND above 5% of the ground-truth magnitude.
            is_outlier = (epe > 3.0) & ((epe/mag) > 0.05)
            outliers = is_outlier.float()

            per_image_epe.append(epe[valid_mask].mean().item())
            outlier_flags.append(outliers[valid_mask].cpu().numpy())

    epe_array = np.array(per_image_epe)
    outlier_array = np.concatenate(outlier_flags)
    print("Validation KITTI: %f, %f" % (np.mean(epe_array), 100*np.mean(outlier_array)))
def fetch_dataloader(args):
    """Build the shuffled training ``DataLoader`` selected by ``args.dataset``.

    Supported values are ``'chairs'``, ``'things'``, ``'sintel'`` and
    ``'kitti'``. For ``'things'`` and ``'sintel'`` the two passes of the
    dataset are added together. Any other value raises
    ``NotImplementedError``.

    Reads ``args.dataset_root``, ``args.image_size``, ``args.batch_size``
    and ``args.num_of_workers``. Returns the loader after printing the
    number of training pairs.
    """

    def build(dataset_cls, **extra):
        # Every training dataset shares the same root / crop-size arguments.
        return dataset_cls(args, root=args.dataset_root, image_size=args.image_size, **extra)

    which = args.dataset
    if which == 'chairs':
        train_dataset = build(datasets.FlyingChairs)

    elif which == 'things':
        clean_dataset = build(datasets.SceneFlow, dstype='frames_cleanpass')
        final_dataset = build(datasets.SceneFlow, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif which == 'sintel':
        clean_dataset = build(datasets.MpiSintel_Train, dstype='clean')
        final_dataset = build(datasets.MpiSintel_Train, dstype='final')
        # Sanity check on the expected size of each pass.
        assert len(clean_dataset) == 908 and len(final_dataset) == 908
        train_dataset = clean_dataset + final_dataset

    elif which == 'kitti':
        train_dataset = build(datasets.KITTI, is_val=False)

    else:
        raise NotImplementedError

    num_workers = args.num_of_workers
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
def validate_kitti(model, iters=24):
    """Perform validation using the KITTI-2015 (train) split.

    Each sample of ``datasets.KITTI(split='training')`` is padded with
    ``InputPadder(mode='kitti')``, run through ``model`` in test mode and
    compared with the ground-truth flow on the CPU.

    Args:
        model: Callable invoked as
            ``model(image1, image2, iters=iters, test_mode=True)`` and
            returning a pair whose second element is the flow prediction.
            It is switched to eval mode here.
        iters: Forwarded to ``model`` as ``iters``.

    Returns:
        dict with ``'kitti-epe'`` (mean of the per-image end-point errors
        over valid pixels) and ``'kitti-f1'`` (percentage of valid pixels
        whose error exceeds both 3.0 and 5% of the ground-truth magnitude).

    NOTE(review): an earlier ``validate_kitti(args, model, iters=32)`` is
    defined in this file; if both live at module level this definition
    shadows it -- confirm which one callers expect.
    """
    model.eval()
    val_dataset = datasets.KITTI(split='training')

    out_list, epe_list = [], []
    # Validation never back-propagates; no_grad stops the forward passes
    # from building an autograd graph and holding on to GPU memory.
    with torch.no_grad():
        for val_id in range(len(val_dataset)):
            image1, image2, flow_gt, valid_gt = val_dataset[val_id]

            # [None] prepends a batch axis before moving the images to the GPU.
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()

            padder = InputPadder(image1.shape, mode='kitti')
            image1, image2 = padder.pad(image1, image2)

            _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            # Ground truth stays on the CPU, so bring the prediction back too.
            flow = padder.unpad(flow_pr[0]).cpu()

            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt().view(-1)
            mag = torch.sum(flow_gt**2, dim=0).sqrt().view(-1)
            val = valid_gt.view(-1) >= 0.5

            # Outlier: error above 3.0 AND above 5% of the ground-truth magnitude.
            out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()

            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    f1 = 100 * np.mean(out_list)

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}
def main():
    """Command-line entry point: train or evaluate the ``dicl_wrapper`` model.

    Steps, in order:
      1. Parse CLI arguments, optionally load a config file, and create the
         output directory plus a ``log.txt`` logger inside it.
      2. Build the train/test datasets and loaders for ``args.dataset``.
      3. Build the model and optimizer, optionally restoring weights (and
         optimizer state) from ``args.pretrained``.
      4. If ``args.evaluate`` is set, run one validation pass and return.
      5. Otherwise train for ``args.epochs`` epochs, validating every
         ``args.eval_freq`` epochs and writing checkpoints.

    Writes the module-level globals ``args``, ``best_EPE`` and ``save_path``.

    Raises:
        AssertionError: if the config ``TAG`` does not match the config file
            name, or ``args.solver`` is not ``'adam'`` / ``'sgd'``.
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    global args, best_EPE, save_path
    args = parser.parse_args()

    # Load config file
    if args.cfg is not None:
        cfg_from_file(args.cfg)
        assert cfg.TAG == os.path.splitext(os.path.basename(
            args.cfg))[0], 'TAG name should be file name'

    # Build save_path, which can be specified by out_dir and exp_dir
    save_path = '{},{}epochs{},b{},lr{}'.format(
        'dicl_wrapper', args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)
    save_path = os.path.join(args.exp_dir, save_path)

    if args.out_dir is not None:
        outpath = os.path.join(args.out_dir, args.dataset)
    else:
        outpath = args.dataset
    save_path = os.path.join(outpath, save_path)

    # Both are created explicitly: save_path is not nested under outpath
    # when args.exp_dir is an absolute path.
    os.makedirs(outpath, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)

    # Create logger
    log_file = os.path.join(save_path, 'log.txt')
    logger = create_logger(log_file)
    logger.info('**********************Start logging**********************')
    logger.info('=> will save everything to {}'.format(save_path))

    # Print settings (values only, one per log line)
    for value in vars(args).values():
        logger.info(value)
    save_config_to_file(cfg, logger=logger)
    logger.info(args.pretrained)

    # Set random seed
    # NOTE(review): only the CUDA and NumPy generators are seeded here; the
    # CPU torch generator is left untouched -- confirm this is intended.
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    eval_writer = SummaryWriter(os.path.join(save_path, 'eval'))

    logger.info("=> fetching img pairs in '{}'".format(args.data))

    ########################## DATALOADER ##########################
    if args.dataset == 'flying_chairs':
        if cfg.SIMPLE_AUG:
            train_dataset = datasets.FlyingChairs_SimpleAug(args, root=args.data)
            test_dataset = datasets.FlyingChairs_SimpleAug(args, root=args.data, mode='val')
        else:
            train_dataset = datasets.FlyingChairs(args, image_size=cfg.CROP_SIZE, root=args.data)
            test_dataset = datasets.FlyingChairs(args, root=args.data, mode='val',
                                                 do_augument=False)
    elif args.dataset == 'flying_things':
        train_dataset = datasets.SceneFlow(args, image_size=cfg.CROP_SIZE, root=args.data,
                                           dstype='frames_cleanpass', mode='train')
        test_dataset = datasets.SceneFlow(args, image_size=cfg.CROP_SIZE, root=args.data,
                                          dstype='frames_cleanpass', mode='val',
                                          do_augument=False)
    elif args.dataset == 'mpi_sintel_clean' or args.dataset == 'mpi_sintel_final':
        # Training always uses both passes; only the test pass differs.
        clean_dataset = datasets.MpiSintel(args, image_size=cfg.CROP_SIZE, root=args.data,
                                           dstype='clean')
        final_dataset = datasets.MpiSintel(args, image_size=cfg.CROP_SIZE, root=args.data,
                                           dstype='final')
        train_dataset = torch.utils.data.ConcatDataset([clean_dataset] + [final_dataset])
        if args.dataset == 'mpi_sintel_final':
            test_dataset = datasets.MpiSintel(args, do_augument=False, image_size=None,
                                              root=args.data, dstype='final')
        else:
            test_dataset = datasets.MpiSintel(args, do_augument=False, image_size=None,
                                              root=args.data, dstype='clean')
    elif args.dataset == 'KITTI':
        train_dataset = datasets.KITTI(args, image_size=cfg.CROP_SIZE, root=args.data,
                                       is_val=False, logger=logger)
        if args.data_kitti12 is not None:
            train_dataset12 = datasets.KITTI12(args, image_size=cfg.CROP_SIZE,
                                               root=args.data_kitti12, is_val=False,
                                               logger=logger)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset] + [train_dataset12])
        test_dataset = datasets.KITTI(args, root=args.data, do_augument=False, is_val=True,
                                      do_pad=False)
    else:
        raise NotImplementedError

    logger.info('Training with %d image pairs' % len(train_dataset))
    logger.info('Testing with %d image pairs' % len(test_dataset))

    gpuargs = {'num_workers': args.workers, 'drop_last': cfg.DROP_LAST}
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               pin_memory=True, shuffle=True, **gpuargs)

    if 'KITTI' in args.dataset:
        # We set batch size to 1 since KITTI images have different sizes
        val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                                 num_workers=args.workers,
                                                 pin_memory=True, shuffle=False)
    else:
        val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                                 num_workers=args.workers,
                                                 pin_memory=True, shuffle=False)

    # create model
    if args.pretrained:
        logger.info("=> using pre-trained model '{}'".format(args.pretrained))
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        pretrained_dict = torch.load(args.pretrained)
        if 'state_dict' in pretrained_dict.keys():
            # Rebuilds the mapping as a plain dict; no keys are renamed.
            pretrained_dict['state_dict'] = {
                k: v
                for k, v in pretrained_dict['state_dict'].items()
            }

    model = models.__dict__['dicl_wrapper'](None)

    assert (args.solver in ['adam', 'sgd'])
    logger.info('=> setting {} solver'.format(args.solver))

    if args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     weight_decay=cfg.WEIGHT_DECAY,
                                     betas=(cfg.MOMENTUM, cfg.BETA))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    weight_decay=cfg.WEIGHT_DECAY,
                                    momentum=cfg.MOMENTUM)

    if args.pretrained:
        # The checkpoint is either a wrapper dict or a bare state dict.
        if 'state_dict' in pretrained_dict.keys():
            model.load_state_dict(pretrained_dict['state_dict'], strict=False)
        else:
            model.load_state_dict(pretrained_dict, strict=False)

        if args.reuse_optim:
            # Best effort: a missing or incompatible optimizer state must not
            # abort the run. `Exception` (rather than a bare except) keeps
            # KeyboardInterrupt / SystemExit propagating.
            try:
                optimizer.load_state_dict(pretrained_dict['optimizer_state'])
            except Exception:
                logger.info('do not have optimizer state')
        del pretrained_dict
        torch.cuda.empty_cache()

    model = torch.nn.DataParallel(model)

    if torch.cuda.is_available():
        model = model.cuda()

    # Evaluation
    if args.evaluate:
        with torch.no_grad():
            best_EPE = validate(val_loader, model, 0, None, eval_writer, logger=logger)
        return

    # Learning rate schedule
    milestones = [int(milestone) for milestone in args.milestones]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones,
                                                     gamma=0.5)

    # Skip at least 5 epochs to save memory
    save_freq = max(args.eval_freq, 5)

    ###################################### Training ######################################
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss = train(train_loader, model, optimizer, epoch, train_writer, logger=logger)
        scheduler.step()

        train_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        train_writer.add_scalar('avg_loss', train_loss, epoch)

        if epoch % args.eval_freq == 0 and not args.no_eval:
            with torch.no_grad():
                # NOTE(review): `output_writers` is not defined anywhere in the
                # visible code; unless it is a module-level global this raises
                # NameError (the evaluate-only path above passes None here).
                EPE = validate(val_loader, model, epoch, output_writers, eval_writer,
                               logger=logger)
            eval_writer.add_scalar('mean_EPE', EPE, epoch)

            # A negative best_EPE means "no evaluation recorded yet". The
            # first result must count as an improvement; otherwise a run whose
            # first evaluation is its best never writes checkpoint_best.
            if best_EPE < 0 or EPE < best_EPE:
                best_EPE = EPE
                ckpt_best_file = 'checkpoint_best.pth.tar'
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': 'dicl_wrapper',
                        'state_dict': model.module.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'best_EPE': EPE
                    },
                    False,
                    filename=ckpt_best_file)
            logger.info('Epoch: [{0}] Best EPE: {1}'.format(epoch, best_EPE))

        if epoch % save_freq == 0:
            ckpt_file = 'checkpoint_' + str(epoch) + '.pth.tar'
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': 'dicl_wrapper',
                    'state_dict': model.module.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'best_EPE': best_EPE
                },
                False,
                filename=ckpt_file)