def train(args):
    """Train a RAFT model under ``nn.DataParallel`` and save checkpoints.

    Args:
        args: Options object. This function reads ``restore_ckpt``,
            ``dataset``, ``iters``, ``clip``, ``num_steps`` and ``name``;
            the object is also forwarded to ``RAFT``, ``fetch_dataloader``
            and ``fetch_optimizer``.

    Returns:
        Path of the final checkpoint, ``'checkpoints/<name>.pth'``.
    """
    net = nn.DataParallel(RAFT(args))
    print("Parameter Count: %d" % count_parameters(net))

    # Weights are loaded into the wrapped module, i.e. the checkpoint is
    # expected to carry DataParallel-style keys.
    if args.restore_ckpt is not None:
        net.load_state_dict(torch.load(args.restore_ckpt))

    net.cuda()
    net.train()

    # Batch-norm layers are frozen for every dataset whose name does not
    # contain 'chairs'.
    if 'chairs' not in args.dataset:
        net.module.freeze_bn()

    loader = fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, net)

    step_count = 0
    logger = Logger(net, scheduler)

    finished = False
    while not finished:
        for batch in loader:
            image1, image2, flow, valid = [tensor.cuda() for tensor in batch]

            optimizer.zero_grad()
            flow_predictions = net(image1, image2, iters=args.iters)
            loss, metrics = sequence_loss(flow_predictions, flow, valid)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip)

            optimizer.step()
            scheduler.step()
            step_count += 1

            logger.push(metrics)

            # Periodic snapshot, named after the step that follows this one.
            if step_count % VAL_FREQ == VAL_FREQ - 1:
                snapshot_path = 'checkpoints/%d_%s.pth' % (step_count + 1, args.name)
                torch.save(net.state_dict(), snapshot_path)

            if step_count == args.num_steps:
                finished = True
                break

    final_path = 'checkpoints/%s.pth' % args.name
    torch.save(net.state_dict(), final_path)
    return final_path
def train(gpu, ngpus_per_node, args):
    """Build a RAFT model, optionally restore weights, and run KITTI validation.

    Despite its name this entry point performs no optimisation: it builds the
    model, puts it in eval mode, constructs the evaluation loader and prints
    the result of ``validate_kitti``.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` in distributed mode.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        None.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu
    device = f'cuda:{args.gpu}'

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # The epipolar helpers are passed to validate_kitti unconditionally, so
    # they are built for both execution modes (previously they only existed
    # in the distributed branch and the other branch hit a NameError).
    eppCbck = eppConstrainer_background(height=args.image_size[0],
                                        width=args.image_size[1],
                                        bz=args.batch_size)
    eppCbck.to(device)
    eppconcluer = eppConcluer()
    eppconcluer.to(device)

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    model.eval()
    if args.stage != 'chairs':
        model.module.freeze_bn()

    _, evaluation_entries = read_splits()
    eval_dataset = KITTI_eigen(split='evaluation',
                               root=args.dataset_root,
                               entries=evaluation_entries,
                               semantics_root=args.semantics_root,
                               depth_root=args.depth_root)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  shuffle=(eval_sampler is None),
                                  num_workers=4,
                                  drop_last=True,
                                  sampler=eval_sampler)

    # Always bind `group`; it stays None outside distributed mode.
    # NOTE(review): how validate_kitti treats group=None is not visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))
    return
def train(gpu, ngpus_per_node, args):
    """Train RAFT on VirtualKITTI2 with gradient scaling and periodic validation.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` in distributed mode.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        Path of the final checkpoint, ``<logroot>/<name>/final.pth``.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    model.train()
    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()

    aug_params = {
        'crop_size': args.image_size,
        'min_scale': -0.2,
        'max_scale': 0.4,
        'do_flip': False,
    }
    train_dataset = VirtualKITTI2(aug_params,
                                  split='training',
                                  root=args.dataset_root,
                                  entries=train_entries)
    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                     if args.distributed else None)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=False,
                                   shuffle=(train_sampler is None),
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = VirtualKITTI2(split='evaluation',
                                 root=args.dataset_root,
                                 entries=evaluation_entries)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=False,
                                  shuffle=(eval_sampler is None),
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  sampler=eval_sampler)

    # Always bind `group`; it stays None outside distributed mode (it used to
    # be undefined there, so the first validation raised a NameError).
    # NOTE(review): how validate_VRKitti2 treats group=None is not visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    scaler = GradScaler(enabled=args.mixed_precision)

    # Loggers exist only on the process with gpu index 0.
    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger_evaluation = Logger(
            model, scheduler,
            os.path.join(args.logroot, 'evaluation_VRKitti', args.name))

    VAL_FREQ = 500
    add_noise = True
    epoch = 0

    should_keep_training = True
    while should_keep_training:
        # DistributedSampler derives its shuffle from the epoch number; if
        # set_epoch is never called, every epoch uses the same ordering.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1, image2, flow, valid = data_blob

            image1 = Variable(image1, requires_grad=True)
            image1 = image1.cuda(gpu, non_blocking=True)

            image2 = Variable(image2, requires_grad=True)
            image2 = image2.cuda(gpu, non_blocking=True)

            flow = Variable(flow, requires_grad=True)
            flow = flow.cuda(gpu, non_blocking=True)

            valid = Variable(valid, requires_grad=True)
            valid = valid.cuda(gpu, non_blocking=True)

            if add_noise:
                # One Gaussian noise level per batch, shared by both images.
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid,
                                          args.gamma)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions,
                            valid)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                results = validate_VRKitti2(model.module, args, eval_loader,
                                            group)

                # Restore training mode after validation.
                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(
                        logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()

    # NOTE(review): every process writes this file; confirm whether the final
    # save should be restricted to gpu 0.
    PATH = os.path.join(logroot, 'final.pth')
    torch.save(model.state_dict(), PATH)

    return PATH
def train(gpu, ngpus_per_node, args):
    """Train RAFT on the KITTI Eigen split, keeping the best-'out' checkpoint.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` and ``args.num_workers``.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        None. Checkpoints are written to ``<logroot>/<name>/minout.pth`` and
        ``<logroot>/<name>/final.pth``.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    train_dataset = KITTI_eigen(root=args.dataset_root,
                                inheight=args.inheight,
                                inwidth=args.inwidth,
                                entries=train_entries,
                                maxinsnum=args.maxinsnum,
                                depth_root=args.depth_root,
                                depthvls_root=args.depthvlsgt_root,
                                prediction_root=args.prediction_root,
                                ins_root=args.ins_root,
                                istrain=True,
                                muteaug=False,
                                banremovedup=False,
                                isgarg=True)
    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                     if args.distributed else None)
    # Shuffle in the loader only when no sampler does it; DataLoader rejects
    # shuffle=True combined with a sampler.
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=True,
                                   shuffle=(train_sampler is None),
                                   num_workers=int(args.num_workers / ngpus_per_node),
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = KITTI_eigen(root=args.dataset_root,
                               inheight=args.evalheight,
                               inwidth=args.evalwidth,
                               entries=evaluation_entries,
                               maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root,
                               depthvls_root=args.depthvlsgt_root,
                               prediction_root=args.prediction_root,
                               ins_root=args.ins_root,
                               istrain=False,
                               isgarg=True)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  num_workers=3,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print("Training splits contain %d images while test splits contain %d images" %
          (len(train_dataset), len(eval_dataset)))

    # Always bind `group`; it stays None outside distributed mode (it used to
    # be undefined there).
    # NOTE(review): how validate_kitti treats group=None is not visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    # The third argument is half the training-set size; its meaning is
    # defined by fetch_optimizer, which is not shown here.
    optimizer, scheduler = fetch_optimizer(args, model,
                                           int(len(train_dataset) / 2))

    total_steps = 0

    # Loggers exist only on the process with gpu index 0.
    if args.gpu == 0:
        logger = Logger(logroot)
        logger_evaluation = Logger(
            os.path.join(args.logroot, 'evaluation_eigen_background', args.name))
        logger.create_summarywriter()
        logger_evaluation.create_summarywriter()

    VAL_FREQ = 5000
    epoch = 0
    minout = 1  # best (lowest) results['out'] seen so far

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        # Only a DistributedSampler has set_epoch; without the guard the
        # non-distributed path failed on `None.set_epoch`.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu) / 255.0
            image2 = data_blob['img2'].cuda(gpu) / 255.0
            flowmap = data_blob['flowmap'].cuda(gpu)

            outputs = model(image1, image2)

            # Supervise only pixels whose first flow channel is non-zero.
            selector = (flowmap[:, 0, :, :] != 0)
            flow_loss = sequence_loss(outputs, flowmap, selector,
                                      gamma=args.gamma, max_flow=MAX_FLOW)

            metrics = dict()
            metrics['flow_loss'] = flow_loss

            loss = flow_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, step=total_steps)
                if total_steps % SUM_FREQ == 0:
                    # Linear extrapolation of the remaining time, in hours.
                    dr = time.time() - st
                    resths = (args.num_steps - total_steps) * dr / (total_steps + 1) / 60 / 60
                    print("Step: %d, rest hour: %f, flowloss: %f" %
                          (total_steps, resths, flow_loss.item()))
                    logger.write_vls(data_blob, outputs, selector.unsqueeze(1),
                                     total_steps)

            if total_steps % VAL_FREQ == 1:
                # Every process validates; only gpu 0 passes a logger.
                if args.gpu == 0:
                    results = validate_kitti(model.module, args, eval_loader,
                                             logger, group, total_steps)
                else:
                    results = validate_kitti(model.module, args, eval_loader,
                                             None, group, None)

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    if minout > results['out']:
                        minout = results['out']
                        PATH = os.path.join(logroot, 'minout.pth')
                        torch.save(model.state_dict(), PATH)
                        print("model saved to %s" % PATH)

                # Restore training mode after validation.
                model.train()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()

    # NOTE(review): every process writes this file; confirm whether the final
    # save should be restricted to gpu 0.
    PATH = os.path.join(logroot, 'final.pth')
    torch.save(model.state_dict(), PATH)

    return
def train(gpu, ngpus_per_node, args):
    """Train RAFT on KITTI Eigen with a flow loss plus a weighted epipolar loss.

    Validation is run and printed once before training and then every
    ``VAL_FREQ`` steps.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` and ``args.num_workers``.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        Path of the final checkpoint, ``<logroot>/<name>/final.pth``.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu
    device = f'cuda:{args.gpu}'

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # The epipolar helpers are used by the loss and by validation in both
    # execution modes, so build them unconditionally (previously they only
    # existed in the distributed branch and the other branch hit a NameError).
    eppCbck = eppConstrainer_background(height=args.image_size[0],
                                        width=args.image_size[1],
                                        bz=args.batch_size)
    eppCbck.to(device)
    eppconcluer = eppConcluer()
    eppconcluer.to(device)

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    model.train()
    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()

    aug_params = {
        'crop_size': args.image_size,
        'min_scale': -0.2,
        'max_scale': 0.4,
        'do_flip': False,
    }
    train_dataset = KITTI_eigen(aug_params,
                                split='training',
                                root=args.dataset_root,
                                entries=train_entries,
                                semantics_root=args.semantics_root,
                                depth_root=args.depth_root)
    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                     if args.distributed else None)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=True,
                                   shuffle=(train_sampler is None),
                                   num_workers=int(args.num_workers / ngpus_per_node),
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = KITTI_eigen(split='evaluation',
                               root=args.dataset_root,
                               entries=evaluation_entries,
                               semantics_root=args.semantics_root,
                               depth_root=args.depth_root)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  shuffle=(eval_sampler is None),
                                  num_workers=3,
                                  drop_last=True,
                                  sampler=eval_sampler)

    # Always bind `group`; it stays None outside distributed mode.
    # NOTE(review): how validate_kitti treats group=None is not visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0

    # Loggers exist only on the process with gpu index 0.
    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot, args.num_steps)
        logger_evaluation = Logger(
            model, scheduler,
            os.path.join(args.logroot, 'evaluation_eigen_background', args.name),
            args.num_steps)

    VAL_FREQ = 5000
    add_noise = False
    epoch = 0

    should_keep_training = True

    # Baseline validation before any optimisation step.
    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))

    while should_keep_training:
        # Only a DistributedSampler has set_epoch; without the guard the
        # non-distributed path failed on `None.set_epoch`.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1']
            image1 = Variable(image1, requires_grad=True)
            image1 = image1.cuda(gpu, non_blocking=True)

            image2 = data_blob['img2']
            image2 = Variable(image2, requires_grad=True)
            image2 = image2.cuda(gpu, non_blocking=True)

            flow = data_blob['flow']
            flow = Variable(flow, requires_grad=True)
            flow = flow.cuda(gpu, non_blocking=True)

            valid = data_blob['valid']
            valid = Variable(valid, requires_grad=True)
            valid = valid.cuda(gpu, non_blocking=True)

            E = data_blob['E']
            E = Variable(E, requires_grad=True)
            E = E.cuda(gpu, non_blocking=True)

            semantic_selector = data_blob['semantic_selector']
            semantic_selector = Variable(semantic_selector, requires_grad=True)
            semantic_selector = semantic_selector.cuda(gpu, non_blocking=True)

            if add_noise:
                # One Gaussian noise level per batch, shared by both images.
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            metrics = dict()
            loss_flow, metrics_flow = sequence_flowloss(flow_predictions, flow,
                                                        valid, args.gamma)
            loss_eppc, metrics_eppc = sequence_eppcloss(eppCbck, flow_predictions,
                                                        semantic_selector, E,
                                                        args.gamma)
            metrics.update(metrics_flow)
            metrics.update(metrics_eppc)

            # Total loss: flow term plus the epipolar term weighted by eppcw.
            loss = loss_flow + loss_eppc * args.eppcw
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions,
                            valid, data_blob['depth'])

            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                results = validate_kitti(model.module, args, eval_loader,
                                         eppCbck, eppconcluer, group)

                # Restore training mode after validation.
                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(
                        logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()

    # NOTE(review): every process writes this file; confirm whether the final
    # save should be restricted to gpu 0.
    PATH = os.path.join(logroot, 'final.pth')
    torch.save(model.state_dict(), PATH)

    return PATH
def train(gpu, ngpus_per_node, args):
    """Build a RAFT model, optionally restore weights, and run flow/pose validation.

    Despite its name this entry point performs no optimisation: it builds the
    model and the stereo15 evaluation loader and calls
    ``validate_RAFT_flow_pose`` once.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` in distributed mode.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        None.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    evaluation_entries = read_splits_mapping()

    eval_dataset = KITTI_eigen_stereo15(root=args.dataset_stereo15_orgned_root,
                                        inheight=args.evalheight,
                                        inwidth=args.evalwidth,
                                        entries=evaluation_entries,
                                        mdPred_root=args.mdPred_root,
                                        maxinsnum=args.maxinsnum,
                                        istrain=True,
                                        isgarg=True,
                                        deepv2dpred_root=args.deepv2dpred_root,
                                        prediction_root=args.prediction_root,
                                        flowPred_root=args.flowPred_root)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  num_workers=3,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print("Test splits contain %d images" % (len(eval_dataset)))

    # Always bind `group`; it stays None outside distributed mode (it used to
    # be undefined there, so the validation call raised a NameError).
    # NOTE(review): how validate_RAFT_flow_pose treats group=None is not
    # visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    # Alternative evaluation configurations, kept for reference:
    # validate_RAFT_flow(args, model, eval_loader, group, usestreodepth=False)
    validate_RAFT_flow_pose(args, model, eval_loader, group,
                            usestreodepth=False, scale_info_src='mdPred')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='mdPred')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='deepv2d_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='stereo_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='stereo_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='deppv2d_pose')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='deppv2d_pose')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False)
    return
def train(gpu, ngpus_per_node, args):
    """Build a RAFT model, optionally restore weights, and run KITTI validation.

    Despite its name this entry point performs no optimisation: it builds the
    model and the evaluation loader and calls ``validate_kitti`` once.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` in distributed mode.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        None.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    # NOTE(review): the model is left in training mode before validation;
    # presumably validate_kitti switches modes itself - confirm.
    model.train()

    # Only the evaluation half of the split is needed here.
    _, evaluation_entries = read_splits()

    eval_dataset = KITTI_eigen(root=args.dataset_root,
                               inheight=args.evalheight,
                               inwidth=args.evalwidth,
                               entries=evaluation_entries,
                               maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root,
                               depthvls_root=args.depthvlsgt_root,
                               prediction_root=args.prediction_root,
                               ins_root=args.ins_root,
                               istrain=False,
                               isgarg=True)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  num_workers=3,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print("Test splits contain %d images" % (len(eval_dataset)))

    # Always bind `group`; it stays None outside distributed mode (it used to
    # be undefined there, so the validation call raised a NameError).
    # NOTE(review): how validate_kitti treats group=None is not visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    validate_kitti(model.module, args, eval_loader, group)
    return
def train(gpu, ngpus_per_node, args):
    """Train RAFT on VirtualKITTI2, validating at several refinement-iteration counts.

    Every ``VAL_FREQ`` steps the model is validated once for each iteration
    count in ``range(args.iters, 2 * args.iters + 1, VAL_ITERINC)``. The
    checkpoint with the lowest ``results['out']`` at 24 iterations is kept as
    ``minout.pth``.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is true.
        ngpus_per_node: Number of processes/GPUs; used as the world size and
            to divide ``args.batch_size`` in distributed mode.
        args: Options object. ``args.gpu`` and (in distributed mode)
            ``args.batch_size`` are overwritten in place.

    Returns:
        Path of the final checkpoint, ``<logroot>/<name>/final.pth``.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        # Split the configured batch size across the processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        # strict=False: missing or unexpected keys are tolerated.
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    train_dataset = VirtualKITTI2(args=args,
                                  root=args.dataset_root,
                                  inheight=args.inheight,
                                  inwidth=args.inwidth,
                                  entries=train_entries,
                                  istrain=True)
    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                     if args.distributed else None)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=False,
                                   shuffle=(train_sampler is None),
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = VirtualKITTI2(args=args,
                                 root=args.dataset_root,
                                 inheight=args.evalheight,
                                 inwidth=args.evalwidth,
                                 entries=evaluation_entries,
                                 istrain=False)
    eval_sampler = (torch.utils.data.distributed.DistributedSampler(eval_dataset)
                    if args.distributed else None)
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=False,
                                  shuffle=(eval_sampler is None),
                                  num_workers=2,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print(
        "Training split contains %d images, validation split contained %d images"
        % (len(train_entries), len(evaluation_entries)))

    # Always bind `group`; it stays None outside distributed mode (it used to
    # be undefined there, so the first validation raised a NameError).
    # NOTE(review): how validate_VRKitti2 treats group=None is not visible here.
    group = None
    if args.distributed:
        group = dist.new_group(list(range(ngpus_per_node)))

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    VAL_ITERINC = 4

    # Loggers exist only on the process with gpu index 0: one training logger
    # plus one evaluation logger per validated iteration count.
    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger.create_summarywriter()

        logger_evaluations = dict()
        for num_iters in range(args.iters, args.iters * 2 + 1, VAL_ITERINC):
            logger_evaluation = Logger(
                model, scheduler,
                os.path.join(
                    args.logroot, 'evaluation_VRKitti',
                    "{}_iternum{}".format(args.name, str(num_iters).zfill(2))))
            logger_evaluation.create_summarywriter()
            logger_evaluations[num_iters] = logger_evaluation

    VAL_FREQ = 2000
    maxout = 1  # despite the name: lowest results['out'] seen so far
    epoch = 0

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        # DistributedSampler derives its shuffle from the epoch number; if
        # set_epoch is never called, every epoch uses the same ordering.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu, non_blocking=True)
            image2 = data_blob['img2'].cuda(gpu, non_blocking=True)
            flow = data_blob['flowmap'].cuda(gpu, non_blocking=True)

            # Exclude invalid pixels and extremely large displacements.
            mag = torch.sum(flow**2, dim=1).sqrt()
            valid = ((flow[:, 0] != 0) * (flow[:, 1] != 0) *
                     (mag < MAX_FLOW)).unsqueeze(1)

            flow_predictions = model(image1, image2, iters=args.iters)

            # Only the scalar loss is logged; the metrics returned by
            # sequence_loss were discarded in the original as well.
            loss, _ = sequence_loss(flow_predictions, flow, valid, args.gamma)

            metrics = dict()
            metrics['loss_flow'] = loss.float().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, total_steps)
                if total_steps % SUM_FREQ == 0:
                    # Linear extrapolation of the remaining time, in hours.
                    dr = time.time() - st
                    resths = (args.num_steps - total_steps) * dr / (total_steps + 1) / 60 / 60
                    print("Step: %d, rest hour: %f, flow loss: %f" %
                          (total_steps, resths, loss.item()))
                    logger.write_vls(data_blob, flow_predictions, valid,
                                     total_steps)

            if total_steps % VAL_FREQ == 1:
                for num_iters in range(args.iters, args.iters * 2 + 1, VAL_ITERINC):
                    # Only gpu 0 at 24 iterations passes the training logger.
                    if args.gpu == 0 and num_iters == 24:
                        results = validate_VRKitti2(model.module, args,
                                                    eval_loader, num_iters,
                                                    group, logger, total_steps)
                    else:
                        results = validate_VRKitti2(model.module, args,
                                                    eval_loader, num_iters,
                                                    group, None, None)

                    if args.gpu == 0:
                        logger_evaluations[num_iters].write_dict(
                            results, total_steps)
                        # Model selection uses the 24-iteration result only.
                        if num_iters == 24:
                            if results['out'] < maxout:
                                maxout = results['out']
                                PATH = os.path.join(logroot, 'minout.pth')
                                torch.save(model.state_dict(), PATH)
                                print("model saved to %s" % PATH)

                # Restore training mode after validation.
                model.train()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()

    # NOTE(review): every process writes this file; confirm whether the final
    # save should be restricted to gpu 0.
    PATH = os.path.join(logroot, 'final.pth')
    torch.save(model.state_dict(), PATH)

    return PATH