def demo(args):
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        # sintel images
        image1 = load_image('images/sintel_0.png')
        image2 = load_image('images/sintel_1.png')

        flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
        display(image1[0], image2[0], flow_predictions[-1][0])

        # kitti images
        image1 = load_image('images/kitti_0.png')
        image2 = load_image('images/kitti_1.png')

        flow_predictions = model(image1, image2, iters=16)
        display(image1[0], image2[0], flow_predictions[-1][0])

        # davis images
        image1 = load_image('images/davis_0.jpg')
        image2 = load_image('images/davis_1.jpg')

        flow_predictions = model(image1, image2, iters=16)
        display(image1[0], image2[0], flow_predictions[-1][0])
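A minimal sketch of a launcher for demo() above; the flag names (--model, --small, --iters) are assumptions chosen to match the attributes demo() reads, not definitions taken from this file.

# Hypothetical driver for demo(); flag names are illustrative assumptions.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="path to the checkpoint to restore")
    parser.add_argument('--small', action='store_true', help='use the small model variant')
    parser.add_argument('--iters', type=int, default=16, help='number of refinement iterations')
    args = parser.parse_args()
    demo(args)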
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)

        eppCbck = eppConstrainer_background(height=args.image_size[0], width=args.image_size[1],
                                            bz=args.batch_size)
        eppCbck.to(f'cuda:{args.gpu}')

        eppconcluer = eppConcluer()
        eppconcluer.to(f'cuda:{args.gpu}')
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.eval()
    if args.stage != 'chairs':
        model.module.freeze_bn()

    _, evaluation_entries = read_splits()

    eval_dataset = KITTI_eigen(split='evaluation', root=args.dataset_root, entries=evaluation_entries,
                               semantics_root=args.semantics_root, depth_root=args.depth_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                  shuffle=(eval_sampler is None), num_workers=4,
                                  drop_last=True, sampler=eval_sampler)

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))
    return
def RAFT(pretrained=False, model_name="chairs+things", device=None, **kwargs):
    """
    RAFT model (https://arxiv.org/abs/2003.12039)
    model_name (str): One of 'chairs+things', 'sintel', 'kitti' and 'small';
    note that for 'small', the architecture is smaller.
    """
    model_list = ["chairs+things", "sintel", "kitti", "small"]
    if model_name not in model_list:
        raise ValueError("Model should be one of " + str(model_list))

    model_args = argparse.Namespace(**kwargs)
    model_args.small = "small" in model_name
    model = RAFT_module(model_args)

    if device is None:
        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    if device != "cpu":
        model = torch.nn.DataParallel(model, device_ids=[device])
    else:
        model = torch.nn.DataParallel(model)
        model.device_ids = None

    if pretrained:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, "checkpoints", "models_RAFT")
        model_path = os.path.join(model_dir, "models", model_name + ".pth")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
            response = urllib.request.urlopen(models_url, timeout=10)
            z = zipfile.ZipFile(io.BytesIO(response.read()))
            z.extractall(model_dir)
        else:
            # Give the models time to finish downloading and unzipping
            time.sleep(10)

        map_location = torch.device('cpu') if device == "cpu" else None
        model.load_state_dict(torch.load(model_path, map_location=map_location))

    model = model.to(device)
    model.eval()
    return model
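A sketch of calling this factory, assuming CPU inference and two float image tensors of shape (N, 3, H, W); the tensor shapes, the 'iters' keyword, and indexing the last prediction mirror how demo() above consumes the model, but any extra keyword arguments RAFT_module may require are left out here.

# Hypothetical usage of the factory above; this is an illustrative sketch, not the
# repository's documented interface.
import torch

raft = RAFT(pretrained=True, model_name="chairs+things", device="cpu")
image1 = torch.randint(0, 255, (1, 3, 368, 496)).float()  # H and W assumed divisible by 8
image2 = torch.randint(0, 255, (1, 3, 368, 496)).float()
with torch.no_grad():
    flow_predictions = raft(image1, image2, iters=20)  # list of per-iteration flow estimates
flow = flow_predictions[-1]                            # final (N, 2, H, W) flow field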
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()
    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()

    aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
    train_dataset = VirtualKITTI2(aug_params, split='training', root=args.dataset_root, entries=train_entries)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=False,
                                   shuffle=(train_sampler is None), num_workers=args.num_workers,
                                   drop_last=True, sampler=train_sampler)

    eval_dataset = VirtualKITTI2(split='evaluation', root=args.dataset_root, entries=evaluation_entries)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=args.batch_size, pin_memory=False,
                                  shuffle=(eval_sampler is None), num_workers=args.num_workers,
                                  drop_last=True, sampler=eval_sampler)

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    scaler = GradScaler(enabled=args.mixed_precision)

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger_evaluation = Logger(model, scheduler,
                                   os.path.join(args.logroot, 'evaluation_VRKitti', args.name))

    VAL_FREQ = 500
    add_noise = True
    epoch = 0

    should_keep_training = True
    while should_keep_training:
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()
            image1, image2, flow, valid = data_blob

            image1 = Variable(image1, requires_grad=True)
            image1 = image1.cuda(gpu, non_blocking=True)
            image2 = Variable(image2, requires_grad=True)
            image2 = image2.cuda(gpu, non_blocking=True)
            flow = Variable(flow, requires_grad=True)
            flow = flow.cuda(gpu, non_blocking=True)
            valid = Variable(valid, requires_grad=True)
            valid = valid.cuda(gpu, non_blocking=True)

            if add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions, valid)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                results = validate_VRKitti2(model.module, args, eval_loader, group)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
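The train(gpu, ngpus_per_node, args) signature used throughout these scripts follows the torch.multiprocessing.spawn convention, where each worker receives its GPU index as the first argument. A minimal sketch of such a launcher is below; the args fields it sets (distributed, dist_backend, dist_url) mirror the ones read inside train(), but the launcher itself and its values are assumptions.

# Hypothetical multi-GPU launcher; not part of the original scripts.
import torch
import torch.multiprocessing as mp

def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    args.distributed = ngpus_per_node > 1
    args.dist_backend = 'nccl'                     # assumed backend
    args.dist_url = 'tcp://127.0.0.1:23456'        # assumed rendezvous address
    if args.distributed:
        # spawn one training process per GPU; each gets its gpu index as the first argument
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        train(gpu=0, ngpus_per_node=1, args=args)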
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    train_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.inheight, inwidth=args.inwidth,
                                entries=train_entries, maxinsnum=args.maxinsnum,
                                depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root,
                                prediction_root=args.prediction_root, ins_root=args.ins_root,
                                istrain=True, muteaug=False, banremovedup=False, isgarg=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                                   num_workers=int(args.num_workers / ngpus_per_node),
                                   drop_last=True, sampler=train_sampler)

    eval_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.evalheight, inwidth=args.evalwidth,
                               entries=evaluation_entries, maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root,
                               prediction_root=args.prediction_root, ins_root=args.ins_root,
                               istrain=False, isgarg=True)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                  num_workers=3, drop_last=True, sampler=eval_sampler)

    print("Training splits contain %d images while test splits contain %d images"
          % (train_dataset.__len__(), eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model, int(train_dataset.__len__() / 2))

    total_steps = 0

    if args.gpu == 0:
        logger = Logger(logroot)
        logger_evaluation = Logger(os.path.join(args.logroot, 'evaluation_eigen_background', args.name))
        logger.create_summarywriter()
        logger_evaluation.create_summarywriter()

    VAL_FREQ = 5000
    epoch = 0
    minout = 1

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        train_sampler.set_epoch(epoch)
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu) / 255.0
            image2 = data_blob['img2'].cuda(gpu) / 255.0
            flowmap = data_blob['flowmap'].cuda(gpu)

            outputs = model(image1, image2)
            selector = (flowmap[:, 0, :, :] != 0)

            flow_loss = sequence_loss(outputs, flowmap, selector, gamma=args.gamma, max_flow=MAX_FLOW)

            metrics = dict()
            metrics['flow_loss'] = flow_loss

            loss = flow_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, step=total_steps)
                if total_steps % SUM_FREQ == 0:
                    dr = time.time() - st
                    resths = (args.num_steps - total_steps) * dr / (total_steps + 1) / 60 / 60
                    print("Step: %d, rest hour: %f, flowloss: %f" % (total_steps, resths, flow_loss.item()))
                    logger.write_vls(data_blob, outputs, selector.unsqueeze(1), total_steps)

            if total_steps % VAL_FREQ == 1:
                if args.gpu == 0:
                    results = validate_kitti(model.module, args, eval_loader, logger, group, total_steps)
                else:
                    results = validate_kitti(model.module, args, eval_loader, None, group, None)

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)

                    if minout > results['out']:
                        minout = results['out']
                        PATH = os.path.join(logroot, 'minout.pth')
                        torch.save(model.state_dict(), PATH)
                        print("model saved to %s" % PATH)

                model.train()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return
        # Tail of validate_kitti: KITTI outlier (Fl) criterion counts a pixel as an
        # outlier when its EPE exceeds 3 px and 5% of the ground-truth flow magnitude.
        out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
        epe_list.append(epe[val].mean().item())
        out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    print("Validation KITTI: %f, %f" % (np.mean(epe_list), 100 * np.mean(out_list)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--sintel_iters', type=int, default=50)
    parser.add_argument('--kitti_iters', type=int, default=32)
    args = parser.parse_args()

    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    model.to('cuda')
    model.eval()

    validate_sintel(args, model, args.sintel_iters)
    validate_kitti(args, model, args.kitti_iters)
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)

        eppCbck = eppConstrainer_background(height=args.image_size[0], width=args.image_size[1],
                                            bz=args.batch_size)
        eppCbck.to(f'cuda:{args.gpu}')

        eppconcluer = eppConcluer()
        eppconcluer.to(f'cuda:{args.gpu}')
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()
    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()

    aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
    train_dataset = KITTI_eigen(aug_params, split='training', root=args.dataset_root, entries=train_entries,
                                semantics_root=args.semantics_root, depth_root=args.depth_root)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                                   shuffle=(train_sampler is None),
                                   num_workers=int(args.num_workers / ngpus_per_node),
                                   drop_last=True, sampler=train_sampler)

    eval_dataset = KITTI_eigen(split='evaluation', root=args.dataset_root, entries=evaluation_entries,
                               semantics_root=args.semantics_root, depth_root=args.depth_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                  shuffle=(eval_sampler is None), num_workers=3,
                                  drop_last=True, sampler=eval_sampler)

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot, args.num_steps)
        logger_evaluation = Logger(model, scheduler,
                                   os.path.join(args.logroot, 'evaluation_eigen_background', args.name),
                                   args.num_steps)

    VAL_FREQ = 5000
    add_noise = False
    epoch = 0

    should_keep_training = True

    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))

    while should_keep_training:
        train_sampler.set_epoch(epoch)
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1']
            image1 = Variable(image1, requires_grad=True)
            image1 = image1.cuda(gpu, non_blocking=True)

            image2 = data_blob['img2']
            image2 = Variable(image2, requires_grad=True)
            image2 = image2.cuda(gpu, non_blocking=True)

            flow = data_blob['flow']
            flow = Variable(flow, requires_grad=True)
            flow = flow.cuda(gpu, non_blocking=True)

            valid = data_blob['valid']
            valid = Variable(valid, requires_grad=True)
            valid = valid.cuda(gpu, non_blocking=True)

            E = data_blob['E']
            E = Variable(E, requires_grad=True)
            E = E.cuda(gpu, non_blocking=True)

            semantic_selector = data_blob['semantic_selector']
            semantic_selector = Variable(semantic_selector, requires_grad=True)
            semantic_selector = semantic_selector.cuda(gpu, non_blocking=True)

            if add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            metrics = dict()
            loss_flow, metrics_flow = sequence_flowloss(flow_predictions, flow, valid, args.gamma)
            loss_eppc, metrics_eppc = sequence_eppcloss(eppCbck, flow_predictions, semantic_selector, E, args.gamma)
            metrics.update(metrics_flow)
            metrics.update(metrics_eppc)

            loss = loss_flow + loss_eppc * args.eppcw
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions, valid, data_blob['depth'])

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                results = validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    evaluation_entries = read_splits_mapping()

    eval_dataset = KITTI_eigen_stereo15(root=args.dataset_stereo15_orgned_root, inheight=args.evalheight,
                                        inwidth=args.evalwidth, entries=evaluation_entries,
                                        mdPred_root=args.mdPred_root, maxinsnum=args.maxinsnum,
                                        istrain=True, isgarg=True,
                                        deepv2dpred_root=args.deepv2dpred_root,
                                        prediction_root=args.prediction_root,
                                        flowPred_root=args.flowPred_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                  num_workers=3, drop_last=True, sampler=eval_sampler)

    print("Test splits contain %d images" % (eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    # validate_RAFT_flow(args, model, eval_loader, group, usestreodepth=False)
    validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='mdPred')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='mdPred')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='deepv2d_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='stereo_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='stereo_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='deppv2d_pose')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='deppv2d_pose')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False)
    return
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    eval_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.evalheight, inwidth=args.evalwidth,
                               entries=evaluation_entries, maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root,
                               prediction_root=args.prediction_root, ins_root=args.ins_root,
                               istrain=False, isgarg=True)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                  num_workers=3, drop_last=True, sampler=eval_sampler)

    print("Test splits contain %d images" % (eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    validate_kitti(model.module, args, eval_loader, group)
    return
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=True,
                                                          output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    train_dataset = VirtualKITTI2(args=args, root=args.dataset_root, inheight=args.inheight,
                                  inwidth=args.inwidth, entries=train_entries, istrain=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=False,
                                   shuffle=(train_sampler is None), num_workers=args.num_workers,
                                   drop_last=True, sampler=train_sampler)

    eval_dataset = VirtualKITTI2(args=args, root=args.dataset_root, inheight=args.evalheight,
                                 inwidth=args.evalwidth, entries=evaluation_entries, istrain=False)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=args.batch_size, pin_memory=False,
                                  shuffle=(eval_sampler is None), num_workers=2,
                                  drop_last=True, sampler=eval_sampler)

    print("Training split contains %d images, validation split contains %d images"
          % (len(train_entries), len(evaluation_entries)))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    VAL_ITERINC = 4

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger.create_summarywriter()

        logger_evaluations = dict()
        for num_iters in range(args.iters, args.iters * 2 + 1, VAL_ITERINC):
            logger_evaluation = Logger(model, scheduler,
                                       os.path.join(args.logroot, 'evaluation_VRKitti',
                                                    "{}_iternum{}".format(args.name, str(num_iters).zfill(2))))
            logger_evaluation.create_summarywriter()
            logger_evaluations[num_iters] = logger_evaluation

    VAL_FREQ = 2000
    maxout = 1
    epoch = 0

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu, non_blocking=True)
            image2 = data_blob['img2'].cuda(gpu, non_blocking=True)
            flow = data_blob['flowmap'].cuda(gpu, non_blocking=True)

            # exclude invalid pixels and extremely large displacements
            mag = torch.sum(flow ** 2, dim=1).sqrt()
            valid = ((flow[:, 0] != 0) * (flow[:, 1] != 0) * (mag < MAX_FLOW)).unsqueeze(1)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
            metrics = dict()
            metrics['loss_flow'] = loss.float().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, total_steps)
                if total_steps % SUM_FREQ == 0:
                    dr = time.time() - st
                    resths = (args.num_steps - total_steps) * dr / (total_steps + 1) / 60 / 60
                    print("Step: %d, rest hour: %f, flow loss: %f" % (total_steps, resths, loss.item()))
                    logger.write_vls(data_blob, flow_predictions, valid, total_steps)

            if total_steps % VAL_FREQ == 1:
                for num_iters in range(args.iters, args.iters * 2 + 1, VAL_ITERINC):
                    if args.gpu == 0 and num_iters == 24:
                        results = validate_VRKitti2(model.module, args, eval_loader, num_iters, group, logger, total_steps)
                    else:
                        results = validate_VRKitti2(model.module, args, eval_loader, num_iters, group, None, None)

                    if args.gpu == 0:
                        logger_evaluations[num_iters].write_dict(results, total_steps)

                        if num_iters == 24:
                            if results['out'] < maxout:
                                maxout = results['out']
                                PATH = os.path.join(logroot, 'minout.pth')
                                torch.save(model.state_dict(), PATH)
                                print("model saved to %s" % PATH)

                model.train()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH