def __getitem__(self, i):
    """Load and transform the sample addressed by the flat index ``i``.

    ``i`` is decoded as
    ``(scan_idx * self.num_light + light_idx) * self.num_view + ref_idx``.
    When ``self.fix_light`` is not ``None`` it is used as the light index
    instead of the one decoded from ``i``.

    Args:
        i: Flat sample index.

    Returns:
        Whatever ``self.read`` produces for the assembled file names, after
        every callable in ``self.transforms`` has been applied in order.

    Raises:
        IndexError: If the pair list of the reference view holds fewer than
            ``self.num_src`` source views.
    """
    ref_idx = i % self.num_view
    view_block = i // self.num_view
    scan_idx = view_block // self.num_light
    if self.fix_light is None:
        light_idx = view_block % self.num_light
    else:
        light_idx = self.fix_light

    # Neighbouring views of the reference view, truncated to num_src entries.
    src_idxs = self.pair[ref_idx][:self.num_src]

    # All per-view entries of this scan under the chosen lighting; each entry
    # is indexed below as (image path, camera path, ground-truth path).
    entries = self.data_list[scan_idx][light_idx]
    ref = entries[ref_idx]
    srcs = [entries[src_idx] for src_idx in src_idxs]

    # Fail with an explicit message instead of an opaque "list index out of
    # range" raised from inside a comprehension.
    if len(srcs) < self.num_src:
        raise IndexError(
            f'reference view {ref_idx} has only {len(srcs)} source views, '
            f'but num_src={self.num_src}'
        )

    # Scan number parsed from the second path component of the reference
    # image path, dropping its first four characters.
    # NOTE(review): presumably a "scan" prefix -- confirm against data_list.
    scan = int(ref[0].split('/')[1][4:])
    masks = [
        f'occlusion2/scan{scan}/{ref_idx}_{src_idx}.png'
        for src_idx in src_idxs
    ]

    filenames = {
        'ref': ref[0],
        'ref_cam': ref[1],
        'srcs': [src[0] for src in srcs],
        'srcs_cam': [src[1] for src in srcs],
        'gt': ref[2],
        'masks': masks,
    }
    # The return value is ignored, so this relies on recursive_apply rewriting
    # the dict in place; every path becomes relative to self.root.
    recursive_apply(filenames, lambda fn: os.path.join(self.root, fn))
    # Added after the path rewrite so the flag is not treated as a path.
    filenames['skip'] = 0

    sample = self.read(filenames)
    for transform in self.transforms:
        sample = transform(sample)
    return sample
def __getitem__(self, i):
    """Return the transformed sample stored at position ``i`` of ``self.train_data``.

    The stored entry is deep-copied first so that prefixing its paths with
    ``self.root`` never touches ``self.train_data`` itself. The copy is then
    handed to ``self.read`` and the result is piped through every callable in
    ``self.transforms``, in order.
    """
    def under_root(rel_path):
        # Resolve one stored path against the dataset root.
        return os.path.join(self.root, rel_path)

    entry = copy.deepcopy(self.train_data[i])
    # Return value is ignored: relies on recursive_apply updating `entry` in place.
    recursive_apply(entry, under_root)

    result = self.read(entry)
    for step in self.transforms:
        result = step(result)
    return result
# Depth-map fusion over all views listed in pair['id_list'], in three passes:
# load per-view data, mask depth by probability, then reproject neighbours into
# each reference view. `pair`, `args` and `pthresh` are defined outside this excerpt.
n_views = len(pair['id_list'])
views = {}
# Pass 1: read image, camera, depth and the three probability maps of each view.
for i, id in tqdm(enumerate(pair['id_list']), 'load data', n_views):
    # HxWxC -> CxHxW, then reverse the channel axis (cv2.imread returns BGR order).
    image = cv2.imread(f'{args.data}/{id.zfill(8)}.jpg').transpose(2,0,1)[::-1]
    # NOTE(review): the meaning of 256 and 1 is defined by load_cam -- confirm there.
    cam = load_cam(f'{args.data}/cam_{id.zfill(8)}_flow3.txt', 256, 1)
    # Depth comes from the "_flow3" output; a leading axis is added.
    depth = np.expand_dims(load_pfm(f'{args.data}/{id.zfill(8)}_flow3.pfm'), axis=0)
    # Probability maps of "_flow1".."_flow3", stacked along a new leading axis.
    probs = np.stack([load_pfm(f'{args.data}/{id.zfill(8)}_flow{k+1}_prob.pfm') for k in range(3)], axis=0)
    views[id] = {
        'image': image, # 13hw (after next step)
        'cam': cam, # 1244
        'depth': depth, # 11hw
        'prob': probs, # 13hw
    }
    # Turn every array into a contiguous float tensor with a leading batch axis.
    # The return value is ignored, so this relies on recursive_apply mutating views[id] in place.
    recursive_apply(views[id], lambda arr: torch.from_numpy(np.ascontiguousarray(arr)).float().unsqueeze(0))
# Pass 2: compute the probability mask on the GPU, keep it on the CPU, and
# multiply the depth by it in place.
for i, id in tqdm(enumerate(pair['id_list']), 'prob filter', n_views):
    views[id]['mask'] = prob_filter(views[id]['prob'].cuda(), pthresh).cpu() # 11hw bool
    views[id]['depth'] *= views[id]['mask']
# NOTE(review): `update` is not filled within this excerpt; the loop below
# presumably continues past the visible code.
update = {}
# Pass 3: per reference view, reproject its first args.view neighbours, apply
# the visibility filter and fuse the depths with med_fusion.
for i, id in tqdm(enumerate(pair['id_list']), 'vis filter and med fusion', n_views):
    srcs_id = pair[id]['pair'][:args.view]
    ref_depth_g, ref_cam_g = views[id]['depth'].cuda(), views[id]['cam'].cuda()
    # Source depths and cameras are stacked along dim 1, then moved to the GPU.
    srcs_depth_g, srcs_cam_g = [torch.stack([views[loop_id][attr] for loop_id in srcs_id], dim=1).cuda() for attr in ['depth', 'cam']]
    reproj_xyd_g, in_range_g = get_reproj(ref_depth_g, srcs_depth_g, ref_cam_g, srcs_cam_g)
    # NOTE(review): 1 and 0.01 are thresholds interpreted by vis_filter -- confirm their meaning there.
    vis_masks_g, vis_mask_g = vis_filter(ref_depth_g, reproj_xyd_g, in_range_g, 1, 0.01, args.vthresh)
    ref_depth_med_g = med_fusion(ref_depth_g, reproj_xyd_g, vis_masks_g, vis_mask_g)
model.cuda() # model = amp.initialize(model, opt_level='O0') model = nn.DataParallel(model) print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters() if p.requires_grad]))) load_model(model, args.load_path, args.load_step) print(f'load {os.path.join(args.load_path, str(args.load_step))}') model.eval() pbar = tqdm.tqdm(enumerate(loader), dynamic_ncols=True, total=len(loader)) # pbar = itertools.product(range(num_scan), range(num_ref), range(num_view)) for i, sample in pbar: if sample.get('skip') is not None and np.any(sample['skip']): raise ValueError() ref, ref_cam, srcs, srcs_cam, gt, masks = [sample[attr] for attr in ['ref', 'ref_cam', 'srcs', 'srcs_cam', 'gt', 'masks']] recursive_apply(sample, lambda x: torch.from_numpy(x).float().cuda()) ref_t, ref_cam_t, srcs_t, srcs_cam_t, gt_t, masks_t = [sample[attr] for attr in ['ref', 'ref_cam', 'srcs', 'srcs_cam', 'gt', 'masks']] with torch.no_grad(): # est_depth, prob_map, pair_results = model([ref_t, ref_cam_t, srcs_t, srcs_cam_t], args.max_d, upsample=True, mem=True, mode=args.mode) #MVS outputs, refined_depth, prob_maps = model(sample, cas_depth_num, cas_interv_scale, mode=args.mode) [[est_depth_1, pair_results_1], [est_depth_2, pair_results_2], [est_depth_3, pair_results]] = outputs # est_depth = model([ref_t, ref_cam_t, srcs_t, srcs_cam_t, gt_t], args.max_d) # est_depth, prob_map = [arr.clone().cpu().data.numpy() for arr in [refined_depth, prob_map]] est_depth, *prob_maps = [arr.clone().cpu().data.numpy() for arr in [refined_depth] + prob_maps] recursive_apply(pair_results, lambda x: x.clone().cpu().data.numpy()) #MVS pbar.set_description(f'{est_depth.shape}') if (i % 49 == 0 or True) and (args.show_result or args.write_result): if args.show_result:
for param_group in optimizer.param_groups: param_group['lr'] = curr_lr return curr_lr model.train() pbar = tqdm.tqdm(loader, dynamic_ncols=True) if global_step != 0: pbar.update(global_step) for sample in pbar: if global_step >= total_steps: break if sample.get('skip') is not None and np.any(sample['skip']): continue curr_lr = piecewise_constant() recursive_apply(sample, lambda x: torch.from_numpy(x).float().cuda()) ref, ref_cam, srcs, srcs_cam, gt, masks = [ sample[attr] for attr in ['ref', 'ref_cam', 'srcs', 'srcs_cam', 'gt', 'masks'] ] loss, uncert_loss, less1, less3, l1, losses, outputs, refined_depth = None, None, None, None, None, None, None, None try: outputs, refined_depth, _ = model(sample, cas_depth_num, cas_interv_scale, mode=args.mode) losses = compute_loss([outputs, refined_depth], gt, masks, ref_cam,