def flow_loader(self, flow_path, resize=1.0, n_pre_cache=100, range=(0, 0)):
    """Stream optical-flow files from ``flow_path`` into ``self.flows``.

    Files are processed in sorted filename order. When ``range`` is anything
    other than ``(0, 0)``, only ``names[range[0]:range[1]]`` is used.

    The first selected file is read once up front to fix the target size:
    ``self.h`` / ``self.w`` become its height / width multiplied by ``resize``
    (truncated with ``int``). ``self.N_FRAMES`` is set to the number of flow
    files plus one (presumably N flows link N + 1 frames -- confirm).

    For each file the loader:

    1. waits, polling every 10 ms, while more than ``n_pre_cache`` flows are
       buffered beyond ``self.fid_cur``;
    2. reads the flow;
    3. resizes it to ``(self.h, self.w)`` if needed, multiplying channel 0 by
       the width ratio and channel 1 by the height ratio so the displacements
       stay in target-pixel units;
    4. appends it to ``self.flows`` and increments ``self.flow_loader_pt``.

    NOTE(review): relies on ``self.flows`` and ``self.fid_cur`` being
    initialised elsewhere. The sleep-based back-pressure suggests this runs as
    a background producer thread -- confirm. An empty selection raises
    ``IndexError`` when the first file is read.
    """
    self.flow_loader_pt = 0
    names = sorted(os.listdir(flow_path))
    if range != (0, 0):
        names = names[range[0]:range[1]]
    print(f'{len(names)} flows loaded')

    paths = [os.path.join(flow_path, name) for name in names]

    # Probe the first flow only to learn the native resolution.
    probe = load_flow(paths[0])
    self.N_FRAMES = len(names) + 1
    self.h = int(probe.shape[0] * resize)
    self.w = int(probe.shape[1] * resize)
    target_size = (self.w, self.h)  # cv2 expects (width, height)

    for path in paths:
        # Back-pressure: do not run more than n_pre_cache items ahead of
        # self.fid_cur (presumably the consumer's current frame).
        while len(self.flows) - self.fid_cur > n_pre_cache:
            time.sleep(0.01)

        flow = load_flow(path)
        src_h, src_w = flow.shape[0], flow.shape[1]
        needs_resize = not (src_h == self.h and src_w == self.w)
        if needs_resize:
            scale_x = self.w / src_w
            scale_y = self.h / src_h
            flow = cv2.resize(flow, target_size)
            flow[..., 0] *= scale_x
            flow[..., 1] *= scale_y

        self.flows.append(flow)
        self.flow_loader_pt += 1
def disp_loader(self, disp_path, n_pre_cache=100, range=(0, 0)):
    """Stream disparity maps from ``disp_path`` into ``self.disps``.

    Files are processed in sorted filename order. When ``range`` is anything
    other than ``(0, 0)``, only ``names[range[0]:range[1]]`` is used.

    Before each file, the loader polls every 10 ms while either of these holds:

    * more than ``n_pre_cache`` disparities are buffered beyond
      ``self.fid_cur``;
    * ``self.flow_loader_pt <= 0``, i.e. no flow has been appended yet
      (presumably this guarantees ``self.h`` / ``self.w`` are set -- confirm).

    Supported formats:

    * ``.flo``: the negated channel 0 of the stored flow field.
    * ``.png``: read unchanged by OpenCV and divided by 256.0.

    Maps whose size differs from ``(self.h, self.w)`` are resized. Their
    values are multiplied by the width ratio only.

    Raises:
        ValueError: a file has an unsupported extension.
        OSError: OpenCV could not read a ``.png`` file.
    """
    self.disp_loader_pt = 0
    disp_fn_list = sorted(os.listdir(disp_path))
    if range != (0, 0):
        disp_fn_list = disp_fn_list[range[0]:range[1]]
    print(f'{len(disp_fn_list)} disparities loaded')
    for fn in disp_fn_list:
        while (len(self.disps) - self.fid_cur > n_pre_cache
               or self.flow_loader_pt <= 0):
            time.sleep(0.01)
        path = os.path.join(disp_path, fn)
        if fn.endswith('.flo'):
            disp = -load_flow(path)[..., 0]
            disp = np.ascontiguousarray(disp)
        elif fn.endswith('.png'):
            disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            # cv2.imread signals failure by returning None instead of raising.
            if disp is None:
                raise OSError(f'Could not read disparity image {path}')
            # NOTE(review): the /256 scale looks like the KITTI 16-bit PNG
            # disparity convention -- confirm against the dataset.
            disp = disp.astype(np.float32) / 256.0
        else:
            # Raising a bare string is itself a TypeError in Python 3; raise a
            # real exception so callers see the intended message.
            raise ValueError(f'Unsupported disparity format {fn}')
        if disp.shape[0] != self.h or disp.shape[1] != self.w:
            # Values are scaled by the width ratio only, treating disparity as
            # a horizontal displacement.
            disp_rescale = self.w / disp.shape[1]
            disp = cv2.resize(disp, (self.w, self.h)) * disp_rescale
        self.disps.append(disp)
        self.disp_loader_pt += 1
def __getitem__(self, idx):
    """Load one sample: an image pair and its flow field, as float32 tensors.

    Steps:

    1. Read the two frames with ``imageio.imread`` and the flow with
       ``load_flow``.
    2. If ``self.color == 'gray'``, convert both frames to single-channel.
    3. If ``self.crop_shape`` is set, crop the frames and the flow with the
       same cropper: random if ``self.cropper == 'random'``, else centered.
    4. Resize, either to ``self.resize_shape`` (passed to ``cv2.resize`` as
       ``dsize``, so presumably ``(width, height)`` -- confirm) or by
       ``self.resize_scale``.

    When resized, the flow vectors are scaled by the same width/height
    ratios, so they stay in output-pixel units.

    Returns:
        ``([images], [flow])``.

        * ``images`` has shape ``(C, 2, H, W)``.
        * ``flow`` has shape ``(2, H, W)``, assuming ``load_flow`` returns
          ``(H, W, 2)``.
    """
    img1_path, img2_path, flow_path = self.samples[idx]
    img1, img2 = map(imageio.imread, (img1_path, img2_path))
    flow = load_flow(flow_path)

    if self.color == 'gray':
        img1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
        img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]

    images = [img1, img2]

    if self.crop_shape is not None:
        if self.cropper == 'random':
            cropper = StaticRandomCrop(img1.shape[:2], self.crop_shape)
        else:
            cropper = StaticCenterCrop(img1.shape[:2], self.crop_shape)
        images = list(map(cropper, images))
        flow = cropper(flow)

    resizer = None
    if self.resize_shape is not None:
        # The target size must go in ``dsize``. Passing it as ``dst`` (the
        # output-array argument) left dsize=(0, 0) with no scale factors,
        # which cv2.resize rejects.
        resizer = partial(cv2.resize, dsize=tuple(self.resize_shape))
    elif self.resize_scale is not None:
        resizer = partial(cv2.resize, dsize=(0, 0),
                          fx=self.resize_scale, fy=self.resize_scale)

    if resizer is not None:
        src_h, src_w = flow.shape[0], flow.shape[1]
        images = list(map(resizer, images))
        flow = resizer(flow).astype(np.float32, copy=False)
        # Flow values are pixel displacements, so scale them by the same
        # per-axis factors as the resize.
        flow[..., 0] *= flow.shape[1] / src_w
        flow[..., 1] *= flow.shape[0] / src_h

    # cv2.resize drops a trailing singleton channel axis, e.g. for gray
    # images. Restore it so the pair stacks to (2, H, W, C).
    images = [im[:, :, np.newaxis] if im.ndim == 2 else im for im in images]

    images = np.array(images).transpose(3, 0, 1, 2)
    flow = flow.transpose(2, 0, 1)

    images = torch.from_numpy(images.astype(np.float32))
    flow = torch.from_numpy(flow.astype(np.float32))
    return [images], [flow]