import numpy as np
import torch
import torch.nn.functional as F

# NOTE: the geometry helpers used below (batched_pi_inv, batched_inv_pose,
# batched_transpose, x_2d_coords_torch, rescale_scene_coords) and the
# quaternion utilities (quaternion_matrix, quaternion_about_axis) are defined
# elsewhere in this repository and are assumed to be imported here.


def preprocess(x_2d, scene_rgb, scene_depth, scene_K, scene_Tcw, scene_ori_rgb, scene_center=None):
    """Assemble per-frame scene input (RGB + centered scene coordinates).

    Back-projects each scene depth map to 3D, transforms the points to the
    world frame, centers them on the (optionally precomputed) scene centroid,
    and concatenates the result with the RGB frames along the channel dim.
    """
    N, L, C, H, W = scene_rgb.shape
    _, _, _, ori_H, ori_W = scene_ori_rgb.shape
    scene_rgb = scene_rgb.view(N * L, C, H, W)
    scene_depth = scene_depth.view(N * L, 1, H, W)
    scene_K = scene_K.view(N * L, 3, 3)
    scene_Tcw = scene_Tcw.view(N * L, 3, 4)

    # generate 3D world position of scene
    d = scene_depth.view(N * L, H * W, 1)                                      # dim (N*L, H*W, 1)
    X_3d = batched_pi_inv(scene_K, x_2d, d)                                    # dim (N*L, H*W, 3)
    Rwc, twc = batched_inv_pose(R=scene_Tcw[:, :3, :3],
                                t=scene_Tcw[:, :3, 3].squeeze(-1))             # dim (N*L, 3, 3), (N*L, 3, 1)
    X_world = batched_transpose(Rwc.cuda(), twc.cuda(), X_3d)                  # dim (N*L, H*W, 3)
    X_world = X_world.view(N, L * H * W, 3)                                    # dim (N, L*H*W, 3)
    if scene_center is None:
        scene_center = torch.mean(X_world, dim=1)                              # dim (N, 3)
    X_world -= scene_center.view(N, 1, 3)
    X_world = X_world.view(N * L, H, W, 3).permute(0, 3, 1, 2).contiguous()    # dim (N*L, 3, H, W)
    scene_input = torch.cat((scene_rgb, X_world), dim=1)

    return scene_input.cuda(), \
        scene_ori_rgb.view(N * L, 3, ori_H, ori_W).cuda(), \
        X_world.cuda(), \
        torch.gt(scene_depth, 1e-5).cuda().view(N * L, H, W), \
        scene_center.cuda()
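# ---------------------------------------------------------------------------
# For reference, a minimal plain-torch sketch of what the three geometry
# helpers are assumed to compute, inferred only from their call sites in this
# file (these `_sketch_*` functions are illustrative, not the repo's actual
# implementations):
#   pi_inv:    X_cam = d * K^{-1} [u, v, 1]^T        (back-projection)
#   inv_pose:  Rwc = Rcw^T,  twc = -Rcw^T @ tcw      (camera-to-world pose)
#   transpose: Y = R @ X + t                         (batched rigid transform)
def _sketch_batched_pi_inv(K, x_2d, d):
    # K: (B, 3, 3), x_2d: (B, M, 2) pixel coords, d: (B, M, 1) depths
    ones = torch.ones_like(d)
    homo = torch.cat([x_2d, ones], dim=-1)                     # (B, M, 3) homogeneous pixels
    rays = torch.bmm(homo, torch.inverse(K).transpose(1, 2))   # rows are K^{-1} @ pixel
    return rays * d                                            # (B, M, 3) camera-frame points


def _sketch_batched_inv_pose(R, t):
    # R: (B, 3, 3), t: (B, 3); invert world-to-camera into camera-to-world
    R_inv = R.transpose(1, 2)
    t_inv = -torch.bmm(R_inv, t.view(-1, 3, 1))                # (B, 3, 1)
    return R_inv, t_inv


def _sketch_batched_transpose(R, t, X):
    # R: (B, 3, 3), t: (B, 3, 1), X: (B, M, 3); apply Y_i = R @ X_i + t per point
    return torch.bmm(X, R.transpose(1, 2)) + t.view(-1, 1, 3)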
def preprocess_scene(x_2d, scene_rgb, scene_depth, scene_K, scene_Tcw, scene_ori_rgb):
    """Inference-time scene preprocessing.

    Same back-projection and centering as `preprocess`, but keeps the
    (N, L, ...) layout and routes the points through a rotation that is fixed
    to identity here, so the interface matches the augmented training path.
    """
    rand_R = torch.eye(3).unsqueeze(0)  # identity: no augmentation at test time
    N, L, C, H, W = scene_rgb.shape
    # scene_rgb = scene_rgb.view(N, L, C, H, W)
    scene_depth = scene_depth.view(N * L, 1, H, W)
    scene_K = scene_K.view(N * L, 3, 3)
    scene_Tcw = scene_Tcw.view(N * L, 3, 4)

    # generate 3D world position of scene
    d = scene_depth.view(N * L, H * W, 1)                                      # dim (N*L, H*W, 1)
    X_3d = batched_pi_inv(scene_K, x_2d, d)                                    # dim (N*L, H*W, 3)
    Rwc, twc = batched_inv_pose(R=scene_Tcw[:, :3, :3],
                                t=scene_Tcw[:, :3, 3].squeeze(-1))             # dim (N*L, 3, 3), (N*L, 3, 1)
    X_world = batched_transpose(Rwc.cuda(), twc.cuda(), X_3d)                  # dim (N*L, H*W, 3)
    X_world = X_world.view(N, L * H * W, 3)                                    # dim (N, L*H*W, 3)
    scene_center = torch.mean(X_world, dim=1)                                  # dim (N, 3)
    X_world -= scene_center.view(N, 1, 3)
    X_world = batched_transpose(rand_R.cuda().expand(N, 3, 3),
                                torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
                                X_world)                                       # dim (N, L*H*W, 3), data augmentation
    X_world = X_world.view(N, L, H, W, 3).permute(0, 1, 4, 2, 3).contiguous()  # dim (N, L, 3, H, W)
    scene_input = torch.cat((scene_rgb, X_world), dim=2)

    return scene_input.cuda(), scene_ori_rgb.cuda(), X_world.cuda(), \
        torch.gt(scene_depth, 1e-5).cuda().view(N, L, H, W), \
        scene_center.cuda(), rand_R.expand(N, 3, 3).cuda()
def preprocess(sample_dict, pre_x2d, out_dim, rescale_dist=0.0):
    """Training-time preprocessing of one dataloader sample.

    Builds the rotation-augmented scene input plus the multiscale ground-truth
    query scene coordinates, valid masks, normalized query pose, and
    downsampled query intrinsics.
    """
    # random rotation about the y-axis, shared by scene and query (augmentation)
    rand_angle = np.random.random_sample() * 2.0 * np.pi
    rand_R = quaternion_matrix(quaternion_about_axis(rand_angle, (0.0, 1.0, 0.0)))[:3, :3]
    rand_R = torch.FloatTensor(rand_R).unsqueeze(0)

    scene_rgb = sample_dict['frames_img'][:, :5, ...].cuda()
    scene_depth = sample_dict['frames_depth'][:, :5, ...].cuda()
    scene_K = sample_dict['frames_K'][:, :5, ...].cuda()
    scene_Tcw = sample_dict['frames_Tcw'][:, :5, ...]
    scene_ori_rgb = sample_dict['frames_ori_img'][:, :5, ...].cuda()
    scene_neg_tags = sample_dict['frames_neg_tags'][:, :5, ...].cuda()

    N, L, C, H, W = scene_rgb.shape
    # scene_rgb = scene_rgb.view(N, L, C, H, W)
    scene_depth = scene_depth.view(N * L, 1, H, W)
    scene_K = scene_K.view(N * L, 3, 3)
    scene_Tcw = scene_Tcw.view(N * L, 3, 4)

    # generate 3D world position of scene
    d = scene_depth.view(N * L, H * W, 1)                                      # dim (N*L, H*W, 1)
    X_3d = batched_pi_inv(scene_K, pre_x2d, d)                                 # dim (N*L, H*W, 3)
    Rwc, twc = batched_inv_pose(R=scene_Tcw[:, :3, :3],
                                t=scene_Tcw[:, :3, 3].squeeze(-1))             # dim (N*L, 3, 3), (N*L, 3, 1)
    X_world = batched_transpose(Rwc.cuda(), twc.cuda(), X_3d)                  # dim (N*L, H*W, 3)
    X_world = X_world.contiguous().view(N, L * H * W, 3)                       # dim (N, L*H*W, 3)
    scene_center = torch.mean(X_world, dim=1)                                  # dim (N, 3)
    X_world -= scene_center.view(N, 1, 3)
    X_world = batched_transpose(rand_R.cuda().expand(N, 3, 3),
                                torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
                                X_world)                                       # dim (N, L*H*W, 3), data augmentation
    X_world = X_world.view(N, L, H, W, 3).permute(0, 1, 4, 2, 3).contiguous()  # dim (N, L, 3, H, W)

    # query image:
    query_img = sample_dict['img']
    query_ori_img = sample_dict['ori_img']

    # compute multiscale ground truth query_X_worlds & valid_masks
    query_X_worlds = []
    valid_masks = []
    out_H, out_W = out_dim
    query_depth = sample_dict['depth'].cuda()
    ori_query_depth = query_depth.clone()
    N, C, H, W = query_depth.shape
    for i in range(4):
        # masked average pooling: average only valid (depth > 1e-5) pixels per patch
        query_depth_patch = F.unfold(query_depth,
                                     kernel_size=(H // out_H, W // out_W),
                                     stride=(H // out_H, W // out_W)).view(N, -1, out_H, out_W)
        mask = torch.gt(query_depth_patch, 1e-5)
        count = torch.sum(mask.float(), dim=1)
        query_depth_down = torch.sum(query_depth_patch * mask.float(), dim=1) / \
            torch.where(torch.le(count, 1e-5),
                        torch.full(count.shape, 1e6).to(count.device),
                        count)                                                 # (N, out_H, out_W); empty patches -> ~0

        query_Tcw = sample_dict['Tcw']
        query_K = sample_dict['K'].clone().cuda()
        query_K[:, 0, 0] *= out_W / W
        query_K[:, 0, 2] *= out_W / W
        query_K[:, 1, 1] *= out_H / H
        query_K[:, 1, 2] *= out_H / H
        query_d = query_depth_down.view(N, out_H * out_W, 1)                   # dim (N, H*W, 1)
        out_x_2d = x_2d_coords_torch(N, out_H, out_W).cuda().view(N, -1, 2)
        query_X_3d = batched_pi_inv(query_K, out_x_2d, query_d)                # dim (N, H*W, 3)
        query_Rwc, query_twc = batched_inv_pose(R=query_Tcw[:, :3, :3],
                                                t=query_Tcw[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3, 1)
        query_X_world = batched_transpose(query_Rwc.cuda(), query_twc.cuda(),
                                          query_X_3d)                          # dim (N, H*W, 3)
        query_X_world -= scene_center.view(N, 1, 3)
        query_X_world = batched_transpose(rand_R.cuda().expand(N, 3, 3),
                                          torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
                                          query_X_world)                       # dim (N, H*W, 3), data augmentation
        query_X_world = query_X_world.permute(0, 2, 1).view(N, 3, out_H, out_W).contiguous()  # dim (N, 3, H, W)
        query_X_worlds.append(query_X_world.cuda())
        valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))
        if i == 3:
            # duplicate the last (coarsest) level so the pyramid gains one extra entry
            query_X_worlds.append(query_X_world.cuda())
            valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))
        out_H //= 2
        out_W //= 2

    # compute norm_query_Tcw for normalized scene coordinate
    query_twc = query_twc.cuda() - scene_center.view(N, 3, 1)
    norm_query_Twc = torch.cat([query_Rwc.cuda(), query_twc], dim=-1)          # dim (N, 3, 4)
    norm_query_Twc = torch.bmm(rand_R.cuda().expand(N, 3, 3), norm_query_Twc)  # dim (N, 3, 4)
    query_Rcw, query_tcw = batched_inv_pose(R=norm_query_Twc[:, :3, :3],
                                            t=norm_query_Twc[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3, 1)
    norm_query_Tcw = torch.cat([query_Rcw, query_tcw.view(N, 3, 1)], dim=-1)   # dim (N, 3, 4)

    # compute down-sampled query K
    out_H, out_W = out_dim
    query_K = sample_dict['K'].clone().cuda()
    query_K[:, 0, 0] *= out_W / W
    query_K[:, 0, 2] *= out_W / W
    query_K[:, 1, 1] *= out_H / H
    query_K[:, 1, 2] *= out_H / H

    if rescale_dist > 0:
        query_X_worlds, X_world, rescale_factor = rescale_scene_coords(query_X_worlds, X_world,
                                                                       scene_neg_tags, rescale_dist)
    else:
        rescale_factor = torch.ones(N)

    scene_input = torch.cat((scene_rgb, X_world), dim=2)
    # note: the scene-depth view below assumes scene frames and the query depth
    # share the same H, W (true for this dataset's dataloader)
    return scene_input.cuda(), query_img.cuda(), query_X_worlds[::-1], valid_masks[::-1], \
        scene_ori_rgb.cuda(), query_ori_img.cuda(), X_world.cuda(), \
        torch.gt(scene_depth, 1e-5).cuda().view(N, L, H, W), \
        norm_query_Tcw, query_K, scene_neg_tags, rescale_factor.cuda()
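# ---------------------------------------------------------------------------
# The F.unfold step above implements masked average pooling: each output cell
# averages only the valid (depth > 1e-5) pixels of its patch, and a patch with
# no valid pixels is forced to ~0 by dividing by 1e6 instead of by 0. A tiny
# self-contained CPU check of that step (`_demo_masked_depth_pool` and its
# shapes are hypothetical, for illustration only):
def _demo_masked_depth_pool():
    N, H, W, out_H, out_W = 1, 4, 4, 2, 2
    depth = torch.rand(N, 1, H, W)
    depth[0, 0, :2, :2] = 0.0                                  # invalidate the top-left patch
    patch = F.unfold(depth, kernel_size=(H // out_H, W // out_W),
                     stride=(H // out_H, W // out_W)).view(N, -1, out_H, out_W)
    mask = torch.gt(patch, 1e-5)
    count = torch.sum(mask.float(), dim=1)
    down = torch.sum(patch * mask.float(), dim=1) / \
        torch.where(torch.le(count, 1e-5), torch.full(count.shape, 1e6), count)
    assert down[0, 0, 0] == 0.0                                # all-invalid patch -> zero depth
    return down                                                # (N, out_H, out_W)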
def preprocess_query(query_img, query_depth, query_ori_img, query_Tcw, ori_query_K,
                     scene_center, rand_R, out_dim=(48, 64)):
    """Test-time query preprocessing.

    Builds the multiscale ground-truth scene coordinates and valid masks for a
    query frame, reusing the scene_center and rand_R from `preprocess_scene`.
    """
    # compute multiscale ground truth query_X_worlds & valid_masks
    query_X_worlds = []
    valid_masks = []
    out_H, out_W = out_dim
    ori_query_depth = query_depth.clone()
    N, C, H, W = query_depth.shape
    for i in range(5):
        # masked average pooling: average only valid (depth > 1e-5) pixels per patch
        query_depth_patch = F.unfold(query_depth,
                                     kernel_size=(H // out_H, W // out_W),
                                     stride=(H // out_H, W // out_W)).view(N, -1, out_H, out_W)
        mask = torch.gt(query_depth_patch, 1e-5)
        count = torch.sum(mask.float(), dim=1)
        query_depth_down = torch.sum(query_depth_patch * mask.float(), dim=1) / \
            torch.where(torch.le(count, 1e-5),
                        torch.full(count.shape, 1e6).cuda(),
                        count)                                                 # (N, out_H, out_W); empty patches -> ~0

        query_K = ori_query_K.clone().cuda()
        query_K[:, 0, 0] *= out_W / W
        query_K[:, 0, 2] *= out_W / W
        query_K[:, 1, 1] *= out_H / H
        query_K[:, 1, 2] *= out_H / H
        query_d = query_depth_down.view(N, out_H * out_W, 1)                   # dim (N, H*W, 1)
        out_x_2d = x_2d_coords_torch(N, out_H, out_W).cuda().view(N, -1, 2)
        query_X_3d = batched_pi_inv(query_K, out_x_2d, query_d)                # dim (N, H*W, 3)
        query_Rwc, query_twc = batched_inv_pose(R=query_Tcw[:, :3, :3],
                                                t=query_Tcw[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3, 1)
        query_X_world = batched_transpose(query_Rwc.cuda(), query_twc.cuda(),
                                          query_X_3d)                          # dim (N, H*W, 3)
        query_X_world -= scene_center.view(N, 1, 3)
        query_X_world = batched_transpose(rand_R.cuda().expand(N, 3, 3),
                                          torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
                                          query_X_world)                       # dim (N, H*W, 3), data augmentation
        query_X_world = query_X_world.permute(0, 2, 1).view(N, 3, out_H, out_W).contiguous()  # dim (N, 3, H, W)
        query_X_worlds.append(query_X_world.cuda())
        valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))
        # if i == 3:
        #     query_X_worlds.append(query_X_world.cuda())
        #     valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))
        out_H //= 2
        out_W //= 2

    # compute norm_query_Tcw for normalized scene coordinate
    query_twc = query_twc.cuda() - scene_center.view(N, 3, 1)
    norm_query_Twc = torch.cat([query_Rwc.cuda(), query_twc], dim=-1)          # dim (N, 3, 4)
    norm_query_Twc = torch.bmm(rand_R.cuda().expand(N, 3, 3), norm_query_Twc)  # dim (N, 3, 4)
    query_Rcw, query_tcw = batched_inv_pose(R=norm_query_Twc[:, :3, :3],
                                            t=norm_query_Twc[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3, 1)
    norm_query_Tcw = torch.cat([query_Rcw, query_tcw.view(N, 3, 1)], dim=-1)   # dim (N, 3, 4)

    # compute down-sampled query K
    out_H, out_W = out_dim
    query_K = ori_query_K.clone().cuda()
    query_K[:, 0, 0] *= out_W / W
    query_K[:, 0, 2] *= out_W / W
    query_K[:, 1, 1] *= out_H / H
    query_K[:, 1, 2] *= out_H / H

    return query_img.cuda(), query_X_worlds[::-1], valid_masks[::-1], query_ori_img.cuda(), \
        scene_center.cuda(), query_Tcw.cuda(), query_K.cuda(), rand_R.expand(N, 3, 3).cuda()
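# ---------------------------------------------------------------------------
# A hypothetical end-to-end usage of the two inference-path functions above,
# assuming 5 scene frames at 192x256 and batch dicts shaped like the training
# sample_dict; `_demo_inference` and all shapes here are illustrative
# assumptions, not prescribed by this file:
def _demo_inference(scene_batch, query_batch):
    N, L, H, W = 1, 5, 192, 256
    x_2d = x_2d_coords_torch(N * L, H, W).cuda().view(N * L, -1, 2)
    # build the scene representation once, reusing its center and rotation
    scene_input, scene_ori_rgb, X_world, scene_valid, scene_center, rand_R = preprocess_scene(
        x_2d, scene_batch['frames_img'].cuda(), scene_batch['frames_depth'].cuda(),
        scene_batch['frames_K'].cuda(), scene_batch['frames_Tcw'],
        scene_batch['frames_ori_img'].cuda())
    # then preprocess each query against that fixed scene normalization
    return preprocess_query(query_batch['img'], query_batch['depth'].cuda(),
                            query_batch['ori_img'], query_batch['Tcw'],
                            query_batch['K'], scene_center, rand_R)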