def forward(self, outputs, ray_batch, scalars_to_log):
    '''training criterion'''
    pred_rgb = outputs['rgb']
    pred_mask = outputs['mask'].float()
    gt_rgb = ray_batch['rgb']

    loss = img2mse(pred_rgb, gt_rgb, pred_mask)

    return loss, scalars_to_log
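# img2mse() above is defined elsewhere in the codebase. A minimal sketch of the assumed
# behavior (per-pixel mean squared error, optionally averaged only over rays marked valid
# by the mask) is given below; the name _img2mse_sketch and the 1e-6 guard against an
# empty mask are illustrative assumptions, not the repository's exact implementation.
def _img2mse_sketch(pred, gt, mask=None):
    if mask is None:
        return torch.mean((pred - gt) ** 2)
    # broadcast the per-ray mask over the RGB channels and normalize by the number of
    # valid entries (guarded so an all-zero mask cannot produce a division by zero)
    mask = mask.float().unsqueeze(-1)
    return torch.sum(((pred - gt) ** 2) * mask) / (torch.sum(mask) * pred.shape[-1] + 1e-6)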
def ddp_train_nerf(rank, args):
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    logger = logging.getLogger(__package__)
    setup_logger()

    ###### decide chunk size according to gpu memory
    logger.info('gpu_mem: {}'.format(torch.cuda.get_device_properties(rank).total_memory))
    if torch.cuda.get_device_properties(rank).total_memory / 1e9 > 14:
        logger.info('setting batch size according to 24G gpu')
        args.N_rand = 1024
        args.chunk_size = 8192
    else:
        logger.info('setting batch size according to 12G gpu')
        args.N_rand = 512
        args.chunk_size = 4096

    ###### Create log dir and copy the config file
    if rank == 0:
        os.makedirs(os.path.join(args.basedir, args.expname), exist_ok=True)
        f = os.path.join(args.basedir, args.expname, 'args.txt')
        with open(f, 'w') as file:
            for arg in sorted(vars(args)):
                attr = getattr(args, arg)
                file.write('{} = {}\n'.format(arg, attr))
        if args.config is not None:
            f = os.path.join(args.basedir, args.expname, 'config.txt')
            with open(f, 'w') as file:
                file.write(open(args.config, 'r').read())
    torch.distributed.barrier()

    ray_samplers = load_data_split(args.datadir, args.scene, split='train',
                                   try_load_min_depth=args.load_min_depth)
    val_ray_samplers = load_data_split(args.datadir, args.scene, split='validation',
                                       try_load_min_depth=args.load_min_depth, skip=args.testskip)

    # write training image names for autoexposure
    if args.optim_autoexpo:
        f = os.path.join(args.basedir, args.expname, 'train_images.json')
        with open(f, 'w') as file:
            img_names = [ray_samplers[i].img_path for i in range(len(ray_samplers))]
            json.dump(img_names, file, indent=2)

    ###### create network and wrap in ddp; each process should do this
    start, models = create_nerf(rank, args)

    ##### important!!!
    # make sure different processes sample different rays
    np.random.seed((rank + 1) * 777)
    # make sure different processes have different perturbations in depth samples
    torch.manual_seed((rank + 1) * 777)

    ##### only main process should do the logging
    if rank == 0:
        writer = SummaryWriter(os.path.join(args.basedir, 'summaries', args.expname))

    # start training
    what_val_to_log = 0  # helper variable for parallel rendering of an image
    what_train_to_log = 0
    for global_step in range(start + 1, start + 1 + args.N_iters):
        time0 = time.time()
        scalars_to_log = OrderedDict()
        ### Start of core optimization loop
        scalars_to_log['resolution'] = ray_samplers[0].resolution_level
        # randomly sample rays and move to device
        i = np.random.randint(low=0, high=len(ray_samplers))
        ray_batch = ray_samplers[i].random_sample(args.N_rand, center_crop=False)
        for key in ray_batch:
            if torch.is_tensor(ray_batch[key]):
                ray_batch[key] = ray_batch[key].to(rank)

        # forward and backward
        dots_sh = list(ray_batch['ray_d'].shape[:-1])  # number of rays
        all_rets = []                                  # results on different cascade levels
        for m in range(models['cascade_level']):
            optim = models['optim_{}'.format(m)]
            net = models['net_{}'.format(m)]

            # sample depths
            N_samples = models['cascade_samples'][m]
            if m == 0:
                # foreground depth
                fg_far_depth = intersect_sphere(ray_batch['ray_o'], ray_batch['ray_d'])  # [...,]
                fg_near_depth = ray_batch['min_depth']  # [...,]
                step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
                fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)],
                                       dim=-1)         # [..., N_samples]
                fg_depth = perturb_samples(fg_depth)   # random perturbation during training

                # background depth
                bg_depth = torch.linspace(0., 1., N_samples).view(
                    [1, ] * len(dots_sh) + [N_samples, ]).expand(dots_sh + [N_samples, ]).to(rank)
                bg_depth = perturb_samples(bg_depth)   # random perturbation during training
            else:
                # sample pdf and concat with earlier samples
                fg_weights = ret['fg_weights'].clone().detach()
                fg_depth_mid = .5 * (fg_depth[..., 1:] + fg_depth[..., :-1])  # [..., N_samples-1]
                fg_weights = fg_weights[..., 1:-1]                            # [..., N_samples-2]
                fg_depth_samples = sample_pdf(bins=fg_depth_mid, weights=fg_weights,
                                              N_samples=N_samples, det=False)  # [..., N_samples]
                fg_depth, _ = torch.sort(torch.cat((fg_depth, fg_depth_samples), dim=-1))

                # sample pdf and concat with earlier samples
                bg_weights = ret['bg_weights'].clone().detach()
                bg_depth_mid = .5 * (bg_depth[..., 1:] + bg_depth[..., :-1])
                bg_weights = bg_weights[..., 1:-1]                            # [..., N_samples-2]
                bg_depth_samples = sample_pdf(bins=bg_depth_mid, weights=bg_weights,
                                              N_samples=N_samples, det=False)  # [..., N_samples]
                bg_depth, _ = torch.sort(torch.cat((bg_depth, bg_depth_samples), dim=-1))

            optim.zero_grad()
            ret = net(ray_batch['ray_o'], ray_batch['ray_d'], fg_far_depth, fg_depth, bg_depth,
                      img_name=ray_batch['img_name'])
            all_rets.append(ret)

            rgb_gt = ray_batch['rgb'].to(rank)
            if 'autoexpo' in ret:
                scale, shift = ret['autoexpo']
                scalars_to_log['level_{}/autoexpo_scale'.format(m)] = scale.item()
                scalars_to_log['level_{}/autoexpo_shift'.format(m)] = shift.item()
                # rgb_gt = scale * rgb_gt + shift
                rgb_pred = (ret['rgb'] - shift) / scale
                rgb_loss = img2mse(rgb_pred, rgb_gt)
                loss = rgb_loss + args.lambda_autoexpo * (torch.abs(scale - 1.) + torch.abs(shift))
            else:
                rgb_loss = img2mse(ret['rgb'], rgb_gt)
                loss = rgb_loss
            scalars_to_log['level_{}/loss'.format(m)] = rgb_loss.item()
            scalars_to_log['level_{}/psnr'.format(m)] = mse2psnr(rgb_loss.item())
            loss.backward()
            optim.step()

            # # clean unused memory
            # torch.cuda.empty_cache()

        ### end of core optimization loop
        dt = time.time() - time0
        scalars_to_log['iter_time'] = dt

        ### only main process should do the logging
        if rank == 0 and (global_step % args.i_print == 0 or global_step < 10):
            logstr = '{} step: {} '.format(args.expname, global_step)
            for k in scalars_to_log:
                logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
                writer.add_scalar(k, scalars_to_log[k], global_step)
            logger.info(logstr)

        ### each process should do this; but only main process merges the results
        if global_step % args.i_img == 0 or global_step == start + 1:
            #### critical: make sure each process is working on the same random image
            if len(val_ray_samplers) != 0:
                time0 = time.time()
                idx = what_val_to_log % len(val_ray_samplers)
                log_data = render_single_image(rank, args.world_size, models,
                                               val_ray_samplers[idx], args.chunk_size)
                what_val_to_log += 1
                dt = time.time() - time0
                if rank == 0:  # only main process should do this
                    logger.info('Logged a random validation view in {} seconds'.format(dt))
                    log_view_to_tb(writer, global_step, log_data,
                                   gt_img=val_ray_samplers[idx].get_img(),
                                   mask=None, prefix='val/')

            time0 = time.time()
            idx = what_train_to_log % len(ray_samplers)
            log_data = render_single_image(rank, args.world_size, models,
                                           ray_samplers[idx], args.chunk_size)
            what_train_to_log += 1
            dt = time.time() - time0
            if rank == 0:  # only main process should do this
                logger.info('Logged a random training view in {} seconds'.format(dt))
                log_view_to_tb(writer, global_step, log_data,
                               gt_img=ray_samplers[idx].get_img(),
                               mask=None, prefix='train/')

            del log_data
            torch.cuda.empty_cache()

        if rank == 0 and (global_step % args.i_weights == 0 and global_step > 0):
            # saving checkpoints and logging
            fpath = os.path.join(args.basedir, args.expname,
                                 'model_{:06d}.pth'.format(global_step))
            to_save = OrderedDict()
            for m in range(models['cascade_level']):
                name = 'net_{}'.format(m)
                to_save[name] = models[name].state_dict()

                name = 'optim_{}'.format(m)
                to_save[name] = models[name].state_dict()
            torch.save(to_save, fpath)

    # clean up for multi-processing
    cleanup()
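# ddp_train_nerf() relies on setup()/cleanup() (defined elsewhere) to create and tear down
# the torch.distributed process group, and is meant to run as one process per GPU. The
# sketch below shows how such a worker could be spawned; the function name, the rendezvous
# address, and the single-node settings are illustrative assumptions, not the repository's
# exact launcher.
def _launch_ddp_train_nerf_sketch(args):
    import torch.multiprocessing as mp
    # assumed rendezvous settings for single-node training (inherited by the workers)
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # start args.world_size worker processes; each receives its rank as the first argument
    mp.spawn(ddp_train_nerf, args=(args,), nprocs=args.world_size, join=True)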
def train(args):
    device = "cuda:{}".format(args.local_rank)
    out_folder = os.path.join(args.rootdir, 'out', args.expname)
    print('outputs will be saved to {}'.format(out_folder))
    os.makedirs(out_folder, exist_ok=True)

    # save the args and config files
    f = os.path.join(out_folder, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))

    if args.config is not None:
        f = os.path.join(out_folder, 'config.txt')
        if not os.path.isfile(f):
            shutil.copy(args.config, f)

    # create training dataset
    train_dataset, train_sampler = create_training_dataset(args)
    # currently only support batch_size=1 (i.e., one set of target and source views) for each GPU node
    # please use distributed parallel on multiple GPUs to train multiple target views per batch
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1,
                                               worker_init_fn=lambda _: np.random.seed(),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               shuffle=True if train_sampler is None else False)

    # create validation dataset
    val_dataset = dataset_dict[args.eval_dataset](args, 'validation',
                                                  scenes=args.eval_scenes)
    val_loader = DataLoader(val_dataset, batch_size=1)
    val_loader_iterator = iter(cycle(val_loader))

    # Create IBRNet model
    model = IBRNetModel(args, load_opt=not args.no_load_opt,
                        load_scheduler=not args.no_load_scheduler)
    # create projector
    projector = Projector(device=device)

    # Create criterion
    criterion = Criterion()
    tb_dir = os.path.join(args.rootdir, 'logs/', args.expname)
    if args.local_rank == 0:
        writer = SummaryWriter(tb_dir)
        print('saving tensorboard files to {}'.format(tb_dir))
    scalars_to_log = {}

    global_step = model.start_step + 1
    epoch = 0
    while global_step < model.start_step + args.n_iters + 1:
        np.random.seed()
        for train_data in train_loader:
            time0 = time.time()

            if args.distributed:
                train_sampler.set_epoch(epoch)

            # Start of core optimization loop

            # load training rays
            ray_sampler = RaySamplerSingleImage(train_data, device)
            N_rand = int(1.0 * args.N_rand * args.num_source_views / train_data['src_rgbs'][0].shape[0])
            ray_batch = ray_sampler.random_sample(N_rand,
                                                  sample_mode=args.sample_mode,
                                                  center_ratio=args.center_ratio,
                                                  )

            featmaps = model.feature_net(ray_batch['src_rgbs'].squeeze(0).permute(0, 3, 1, 2))

            ret = render_rays(ray_batch=ray_batch,
                              model=model,
                              projector=projector,
                              featmaps=featmaps,
                              N_samples=args.N_samples,
                              inv_uniform=args.inv_uniform,
                              N_importance=args.N_importance,
                              det=args.det,
                              white_bkgd=args.white_bkgd)

            # compute loss
            model.optimizer.zero_grad()
            loss, scalars_to_log = criterion(ret['outputs_coarse'], ray_batch, scalars_to_log)

            if ret['outputs_fine'] is not None:
                fine_loss, scalars_to_log = criterion(ret['outputs_fine'], ray_batch, scalars_to_log)
                loss += fine_loss

            loss.backward()
            scalars_to_log['loss'] = loss.item()
            model.optimizer.step()
            model.scheduler.step()

            scalars_to_log['lr'] = model.scheduler.get_last_lr()[0]
            # end of core optimization loop
            dt = time.time() - time0

            # Rest is logging
            if args.local_rank == 0:
                if global_step % args.i_print == 0 or global_step < 10:
                    # write mse and psnr stats
                    mse_error = img2mse(ret['outputs_coarse']['rgb'], ray_batch['rgb']).item()
                    scalars_to_log['train/coarse-loss'] = mse_error
                    scalars_to_log['train/coarse-psnr-training-batch'] = mse2psnr(mse_error)
                    if ret['outputs_fine'] is not None:
                        mse_error = img2mse(ret['outputs_fine']['rgb'], ray_batch['rgb']).item()
                        scalars_to_log['train/fine-loss'] = mse_error
                        scalars_to_log['train/fine-psnr-training-batch'] = mse2psnr(mse_error)

                    logstr = '{} Epoch: {} step: {} '.format(args.expname, epoch, global_step)
                    for k in scalars_to_log.keys():
                        logstr += ' {}: {:.6f}'.format(k, scalars_to_log[k])
                        writer.add_scalar(k, scalars_to_log[k], global_step)
                    print(logstr)
                    print('each iter time {:.05f} seconds'.format(dt))

                if global_step % args.i_weights == 0:
                    print('Saving checkpoints at {} to {}...'.format(global_step, out_folder))
                    fpath = os.path.join(out_folder, 'model_{:06d}.pth'.format(global_step))
                    model.save_model(fpath)

                if global_step % args.i_img == 0:
                    print('Logging a random validation view...')
                    val_data = next(val_loader_iterator)
                    tmp_ray_sampler = RaySamplerSingleImage(val_data, device,
                                                            render_stride=args.render_stride)
                    H, W = tmp_ray_sampler.H, tmp_ray_sampler.W
                    gt_img = tmp_ray_sampler.rgb.reshape(H, W, 3)
                    log_view_to_tb(writer, global_step, args, model, tmp_ray_sampler, projector,
                                   gt_img, render_stride=args.render_stride, prefix='val/')
                    torch.cuda.empty_cache()

                    print('Logging current training view...')
                    tmp_ray_train_sampler = RaySamplerSingleImage(train_data, device,
                                                                  render_stride=1)
                    H, W = tmp_ray_train_sampler.H, tmp_ray_train_sampler.W
                    gt_img = tmp_ray_train_sampler.rgb.reshape(H, W, 3)
                    log_view_to_tb(writer, global_step, args, model, tmp_ray_train_sampler, projector,
                                   gt_img, render_stride=1, prefix='train/')
            global_step += 1
            if global_step > model.start_step + args.n_iters + 1:
                break
        epoch += 1
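# cycle() above wraps the finite validation DataLoader in an endless iterator so that a
# fresh validation batch can be pulled with next() at any logging step. The helper is
# defined elsewhere in the codebase; a minimal sketch of the assumed behavior:
def _cycle_sketch(iterable):
    # illustrative assumption: simply restart iteration whenever the loader is exhausted
    while True:
        for item in iterable:
            yield item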