def draw_corners_on_image(self, rgb, corners_cam, centers_cam, scores, tids, pix_T_cam):
    """Project 3D box corners/centers to pixels and draw them on one image.

    Only the first element of the batch is drawn; the result always has a
    leading batch dimension of 1.

    Args:
        rgb: image tensor of shape (B, 3, H, W).
        corners_cam: box corners in camera coordinates, shape (B, N, 8, 3).
        centers_cam: box centers in camera coordinates; reshaped here to
            (B, N, 3), so it must hold B*N*3 values.
        scores: per-box scores; only ``scores[0]`` is used.
        tids: per-box ids; only ``tids[0]`` is used.
        pix_T_cam: camera-to-pixel projection passed to
            ``utils_geom.apply_pix_T_cam``.

    Returns:
        Tensor of shape (1, 3, H, W) produced by ``preprocess_color`` from
        the image returned by ``self.draw_boxes_on_image_py``.
    """
    batch, channels, height, width = list(rgb.shape)
    assert channels == 3
    batch_boxes, num_boxes, num_corners, num_dims = list(corners_cam.shape)
    assert batch_boxes == batch
    assert num_corners == 8  # 8 corners per box
    assert num_dims == 3  # 3D points

    rgb = back2color(rgb)

    def _project(points_cam, points_per_box):
        # Flatten the per-box points, project to pixels, then restore the
        # (B, N, points_per_box, 2) layout.
        flat_cam = torch.reshape(points_cam, [batch, num_boxes * points_per_box, 3])
        flat_pix = utils_geom.apply_pix_T_cam(pix_T_cam, flat_cam)
        return torch.reshape(flat_pix, [batch, num_boxes, points_per_box, 2])

    def _first_as_numpy(tensor):
        # The drawing routine works on numpy data for a single image.
        return tensor[0].detach().cpu().numpy()

    corners_pix = _project(corners_cam, 8)
    centers_pix = _project(centers_cam, 1)

    drawn = self.draw_boxes_on_image_py(
        _first_as_numpy(rgb),
        _first_as_numpy(corners_pix),
        _first_as_numpy(centers_pix),
        _first_as_numpy(scores),
        _first_as_numpy(tids),
    )

    # The drawn image is indexed channel-last; move channels first and
    # restore the batch dimension before converting the colour range back.
    drawn = torch.from_numpy(drawn).type(torch.ByteTensor).permute(2, 0, 1)
    drawn = preprocess_color(torch.unsqueeze(drawn, dim=0))
    return torch.reshape(drawn, [1, channels, height, width])
def compute_box_3d(obj, P):
    """Project an object's 3D bounding box into the image plane.

    Args:
        obj: object exposing ``box3d``, which must hold 24 values (it is
            reshaped to 8 corners x 3 coordinates).
        P: projection matrix written into the top three rows of a 4x4
            matrix (presumably 3x4 -- confirm against callers).

    Returns:
        A pair ``(corners_2d, corners_3d)``: the projected corners returned
        by ``utils_geom.apply_pix_T_cam`` with the batch dimension removed
        (expected to be (8, 2)), and the (8, 3) array of 3D corners.
    """
    # Promote P to a homogeneous 4x4 matrix; the identity already supplies
    # the [0, 0, 0, 1] bottom row.
    pix_T_cam = np.eye(4)
    pix_T_cam[:3, :] = P

    corners_3d = np.array(obj.box3d).reshape((8, 3))

    # apply_pix_T_cam is called with batched tensors, so add a batch
    # dimension of 1 on the way in and strip it on the way out.
    pix_T_cam_batch = torch.from_numpy(pix_T_cam).unsqueeze(0)
    corners_3d_batch = torch.from_numpy(corners_3d).unsqueeze(0)
    corners_2d = utils_geom.apply_pix_T_cam(pix_T_cam_batch, corners_3d_batch)

    return corners_2d.squeeze(0).numpy(), corners_3d
def forward(self, feat_cam0, feat_cam1, mask_mem0, pix_T_cam0, pix_T_cam1,
            cam1_T_cam0, vox_util, summ_writer=None):
    """Compute the 2D cross-view embedding loss at sampled 3D locations.

    For every batch item, ``self.num_samples`` voxels are drawn at random
    from the positive entries of ``mask_mem0``. The sampled locations are
    converted with ``vox_util.Mem2Ref``, moved into camera 1 with
    ``cam1_T_cam0``, projected into both images, and the two feature maps
    are bilinearly sampled at the projections. The sampled feature vectors
    are fed to ``self.compute_ce_loss``.

    Batch items with fewer than ``self.num_samples`` positive voxels are
    skipped; the remaining items are gathered by index so the sampled
    coordinates, features and camera matrices stay aligned.

    Args:
        feat_cam0: feature map for camera 0, shape (B, C, H, W).
        feat_cam1: feature map for camera 1 (same batch size).
        mask_mem0: voxel mask of shape (B, 1, Z, Y, X); entries > 0 are
            candidate sampling locations.
        pix_T_cam0: per-item projection for camera 0.
        pix_T_cam1: per-item projection for camera 1.
        cam1_T_cam0: per-item transform from camera 0 to camera 1.
        vox_util: helper providing ``Mem2Ref``.
        summ_writer: optional summary writer forwarded to
            ``utils_misc.add_loss``.

    Returns:
        Scalar loss tensor; zero when no batch item has enough positive
        voxels.
    """
    total_loss = torch.tensor(0.0).cuda()

    B, mask_channels, Z, Y, X = list(mask_mem0.shape)
    assert mask_channels == 1
    B2, C, _, _ = list(feat_cam0.shape)
    assert B == B2

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    mask_mem0 = mask_mem0.reshape(B, Z * Y * X)

    # Collect, per usable batch item, its index and its sampled coordinates.
    valid_inds = []
    sampled_coords = []
    for b in range(B):
        xyz_mem0_b = xyz_mem0[b]
        mask_mem0_b = mask_mem0[b]
        xyz_mem0_b = xyz_mem0_b[torch.where(mask_mem0_b > 0)]
        # this is N x 3
        num_candidates = xyz_mem0_b.shape[0]
        if num_candidates >= self.num_samples:
            perm = np.random.permutation(num_candidates)
            # this is num_samples x 3
            sampled_coords.append(xyz_mem0_b[perm[:self.num_samples]])
            valid_inds.append(b)

    valid_batches = len(valid_inds)
    print('valid_batches:', valid_batches)
    if valid_batches == 0:
        # Nothing to supervise on; return early.
        return total_loss

    # Gather the valid items by index. Slicing the first `valid_batches`
    # entries would pick the wrong items whenever an invalid item precedes
    # a valid one.
    sampling_coords_mem0 = torch.stack(sampled_coords, dim=0).float().cuda()
    feat_cam0 = feat_cam0[valid_inds]
    feat_cam1 = feat_cam1[valid_inds]
    pix_T_cam0 = pix_T_cam0[valid_inds]
    pix_T_cam1 = pix_T_cam1[valid_inds]
    cam1_T_cam0 = cam1_T_cam0[valid_inds]

    xyz_cam0 = vox_util.Mem2Ref(sampling_coords_mem0, Z, Y, X)
    xyz_cam1 = utils_geom.apply_4x4(cam1_T_cam0, xyz_cam0)
    # these are valid_batches x num_samples x 3

    # Project both point sets and sample the feature maps there.
    xy_cam0 = utils_geom.apply_pix_T_cam(pix_T_cam0, xyz_cam0)
    xy_cam1 = utils_geom.apply_pix_T_cam(pix_T_cam1, xyz_cam1)
    # these are valid_batches x num_samples x 2

    vec0 = utils_samp.bilinear_sample2D(feat_cam0, xy_cam0[:, :, 0], xy_cam0[:, :, 1])
    vec1 = utils_samp.bilinear_sample2D(feat_cam1, xy_cam1[:, :, 0], xy_cam1[:, :, 1])
    # these are valid_batches x C x num_samples

    # reshape (not view) so this works regardless of the memory layout the
    # sampler returns after the permute.
    vec0 = vec0.permute(0, 2, 1).reshape(valid_batches * self.num_samples, C)
    vec1 = vec1.permute(0, 2, 1).reshape(valid_batches * self.num_samples, C)
    print('vec0', vec0.shape)
    print('vec1', vec1.shape)
    # these are (valid_batches * num_samples) x C

    # Camera-1 features are detached, so gradients only flow through vec0.
    ce_loss = self.compute_ce_loss(vec0, vec1.detach())
    total_loss = utils_misc.add_loss('tri2D/emb_ce_loss', total_loss, ce_loss,
                                     hyp.tri_2D_ce_coeff, summ_writer)

    return total_loss
def draw_boxes_on_rgb(rgb_camX, pix_T_cams, bboxes, visualize=False):
    """Draw the wireframe of the first axis-aligned 3D box onto an image.

    Only ``bboxes[0]`` is used; its first six values are read as
    ``xmin, ymin, zmin, xmax, ymax, zmax``. The eight corners are projected
    with ``utils_geom.apply_pix_T_cam`` and the twelve box edges are drawn
    with ``cv2.line`` directly on ``rgb_camX``.

    Args:
        rgb_camX: image array to draw on (passed to ``cv2.line`` as is).
        pix_T_cams: numpy projection matrix for the camera.
        bboxes: array whose first row holds the box extents.
        visualize: when true, show the untouched and annotated images side
            by side in a blocking matplotlib window.

    Returns:
        The image returned by the last ``cv2.line`` call.
    """
    before = np.copy(rgb_camX)

    xmin, ymin, zmin = bboxes[0, 0:3]
    xmax, ymax, zmax = bboxes[0, 3:6]

    # Corner order: x varies slowest, z fastest.
    bbox_xyz = np.array([[x, y, z]
                         for x in (xmin, xmax)
                         for y in (ymin, ymax)
                         for z in (zmin, zmax)])

    projected = utils_geom.apply_pix_T_cam(
        torch.from_numpy(pix_T_cams).unsqueeze(0),
        torch.from_numpy(bbox_xyz).unsqueeze(0))
    bbox_img_xy = projected.squeeze(0).numpy().astype(int)  # 8 x 2
    corners = [(point[0], point[1]) for point in bbox_img_xy]

    # Pairs of corner indices (into the order above) joined by a box edge.
    edges = (
        (0, 1), (1, 3), (2, 3), (2, 0),  # face at xmin
        (4, 5), (7, 5), (7, 6), (6, 4),  # face at xmax
        (0, 4), (1, 5), (6, 2), (7, 3),  # edges joining the two faces
    )
    line_thickness = 2
    for start, end in edges:
        rgb_camX = cv2.line(rgb_camX, corners[start], corners[end],
                            (255, 0, 0), line_thickness)

    if visualize:
        side_by_side = np.concatenate([before, rgb_camX], axis=1)
        plt.imshow(side_by_side)
        plt.show(block=True)

    return rgb_camX
def get_object_info(self, f, rgb_camX, pix_T_camX, origin_T_camX):
    """Collect the objects that are large and visible enough in this view.

    For every entry of ``f['objects_info']`` whose category is not in
    ``self.ignore_classes``, the axis-aligned 3D box is built from
    ``bbox_center`` / ``bbox_size``, moved into the camera frame, projected
    to pixels and reduced to a 2D extent clamped to the image. Objects are
    kept only if the 2D box area reaches ``self.bbox_area_thresh`` and the
    ratio of the object's pixels in ``f['semantic_camX']`` to that area
    reaches ``self.occlusion_thresh``.

    Args:
        f: mapping with ``'objects_info'`` and ``'semantic_camX'``.
        rgb_camX: image array; ``shape[0]`` / ``shape[1]`` are used as the
            height / width for clamping, and the first three channels are
            displayed when ``self.visualize`` is set.
        pix_T_camX: camera-to-pixel projection (array-like).
        origin_T_camX: camera pose (array-like); its inverse maps origin
            coordinates into the camera frame.

    Returns:
        Dict mapping ``instance_id`` to
        ``(classname, category, instance_id, bbox_origin_ends)`` where
        ``bbox_origin_ends`` is a float tensor of shape (1, 1, 2, 3).
    """
    objs_info = f['objects_info']
    object_dict = {}
    if self.visualize:
        plt.imshow(rgb_camX[..., :3])
        plt.show(block=True)

    # These do not depend on the object, so build them once.
    camX_T_origin = utils_geom.safe_inverse(
        torch.tensor(origin_T_camX).unsqueeze(0)).float()
    pix_T_camX_batch = torch.tensor(pix_T_camX).unsqueeze(0).float()
    img_height, img_width = rgb_camX.shape[0], rgb_camX.shape[1]

    for obj_info in objs_info:
        classname = obj_info['category_name']
        if classname in self.ignore_classes:
            continue
        category = obj_info['category_id']
        instance_id = obj_info['instance_id']
        bbox_center = obj_info['bbox_center']
        bbox_size = obj_info['bbox_size']

        xmin, xmax = (bbox_center[0] - bbox_size[0] / 2.,
                      bbox_center[0] + bbox_size[0] / 2.)
        ymin, ymax = (bbox_center[1] - bbox_size[1] / 2.,
                      bbox_center[1] + bbox_size[1] / 2.)
        zmin, zmax = (bbox_center[2] - bbox_size[2] / 2.,
                      bbox_center[2] + bbox_size[2] / 2.)
        bbox_volume = (xmax - xmin) * (ymax - ymin) * (zmax - zmin)

        bbox_origin_ends = np.array([xmin, ymin, zmin, xmax, ymax, zmax])
        bbox_origin_ends = torch.tensor(bbox_origin_ends).reshape(
            1, 1, 2, 3).float()
        bbox_origin_theta = nlu.get_alignedboxes2thetaformat(bbox_origin_ends)
        bbox_origin_corners = utils_geom.transform_boxes_to_corners(
            bbox_origin_theta).float()

        bbox_corners_camX = utils_geom.apply_4x4(
            camX_T_origin, bbox_origin_corners.squeeze(0).float())
        bbox_corners_pixX = utils_geom.apply_pix_T_cam(
            pix_T_camX_batch, bbox_corners_camX)
        bbox_ends_pixX = nlu.get_ends_of_corner(
            bbox_corners_pixX.permute(0, 2, 1)).permute(0, 2, 1)

        # Column 0 is x and column 1 is y (this matches the crop indexing
        # below), so x is clamped to the width and y to the height.
        bbox_ends_pixX_np = np.clip(
            bbox_ends_pixX.squeeze(0).numpy(), 0,
            [img_width, img_height]).astype(int)

        bbox_area = (bbox_ends_pixX_np[1, 1] - bbox_ends_pixX_np[0, 1]) * (
            bbox_ends_pixX_np[1, 0] - bbox_ends_pixX_np[0, 0])
        print("Volume and area occupied by class {} is {} and {}".format(
            classname, bbox_volume, bbox_area))

        semantic = f['semantic_camX']
        # Plain integer count of the pixels labelled with this instance.
        instance_id_pixel_cnt = int(np.count_nonzero(semantic == instance_id))
        # A degenerate box (zero area or volume) gets ratio 0 instead of
        # dividing by zero; such an object is treated as not visible.
        if bbox_area > 0:
            object_to_bbox_ratio = instance_id_pixel_cnt / bbox_area
        else:
            object_to_bbox_ratio = 0.0
        if bbox_volume > 0:
            pixels_to_volume_ratio = instance_id_pixel_cnt / bbox_volume
        else:
            pixels_to_volume_ratio = 0.0
        print(
            "Num pixels in semantic map {}. Ratio of pixels to bbox area{}. Ratio of pixels to bbox volume {}. "
            .format(instance_id_pixel_cnt, object_to_bbox_ratio,
                    pixels_to_volume_ratio))

        if self.visualize:
            cropped_rgb = rgb_camX[
                bbox_ends_pixX_np[0, 1]:bbox_ends_pixX_np[1, 1],
                bbox_ends_pixX_np[0, 0]:bbox_ends_pixX_np[1, 0], :3]
            plt.imshow(cropped_rgb)
            plt.show(block=True)

        if bbox_area < self.bbox_area_thresh:
            continue
        if object_to_bbox_ratio < self.occlusion_thresh:
            continue

        object_dict[instance_id] = (classname, category, instance_id,
                                    bbox_origin_ends)

    return object_dict
def forward(self, feed): results = dict() if 'log_freq' not in feed.keys(): feed['log_freq'] = None start_time = time.time() summ_writer = utils_improc.Summ_writer(writer=feed['writer'], global_step=feed['global_step'], set_name=feed['set_name'], log_freq=feed['log_freq'], fps=8) writer = feed['writer'] global_step = feed['global_step'] total_loss = torch.tensor(0.0).cuda() __p = lambda x: utils_basic.pack_seqdim(x, B) __u = lambda x: utils_basic.unpack_seqdim(x, B) __pb = lambda x: utils_basic.pack_boxdim(x, hyp.N) __ub = lambda x: utils_basic.unpack_boxdim(x, hyp.N) if hyp.aug_object_ent_dis: __pb_a = lambda x: utils_basic.pack_boxdim( x, hyp.max_obj_aug + hyp.max_obj_aug_dis) __ub_a = lambda x: utils_basic.unpack_boxdim( x, hyp.max_obj_aug + hyp.max_obj_aug_dis) else: __pb_a = lambda x: utils_basic.pack_boxdim(x, hyp.max_obj_aug) __ub_a = lambda x: utils_basic.unpack_boxdim(x, hyp.max_obj_aug) B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N PH, PW = hyp.PH, hyp.PW K = hyp.K BOX_SIZE = hyp.BOX_SIZE Z, Y, X = hyp.Z, hyp.Y, hyp.X Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2) Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4) D = 9 tids = torch.from_numpy(np.reshape(np.arange(B * N), [B, N])) rgb_camXs = feed["rgb_camXs_raw"] pix_T_cams = feed["pix_T_cams_raw"] camRs_T_origin = feed["camR_T_origin_raw"] origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin))) origin_T_camXs = feed["origin_T_camXs_raw"] camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0) camRs_T_camXs = __u( torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)), __p(origin_T_camXs))) camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs))) camX0_T_camRs = camXs_T_camRs[:, 0] camX1_T_camRs = camXs_T_camRs[:, 1] camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs) xyz_camXs = feed["xyz_camXs_raw"] depth_camXs_, valid_camXs_ = utils_geom.create_depth_image( __p(pix_T_cams), __p(xyz_camXs), H, W) dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, 
__p(pix_T_cams)) xyz_camRs = __u( utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs))) xyz_camX0s = __u( utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs))) occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X)) occXs_to_Rs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, occXs) occXs_to_Rs_45 = cross_corr.rotate_tensor_along_y_axis(occXs_to_Rs, 45) occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2)) occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2)) occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2)) unpXs = __u( utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, __p(pix_T_cams))) unpXs_half = __u( utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z2, Y2, X2, __p(pix_T_cams))) unpX0s_half = __u( utils_vox.unproject_rgb_to_mem( __p(rgb_camXs), Z2, Y2, X2, utils_basic.matmul2( __p(pix_T_cams), utils_geom.safe_inverse(__p(camX0_T_camXs))))) unpRs = __u( utils_vox.unproject_rgb_to_mem( __p(rgb_camXs), Z, Y, X, utils_basic.matmul2( __p(pix_T_cams), utils_geom.safe_inverse(__p(camRs_T_camXs))))) unpRs_half = __u( utils_vox.unproject_rgb_to_mem( __p(rgb_camXs), Z2, Y2, X2, utils_basic.matmul2( __p(pix_T_cams), utils_geom.safe_inverse(__p(camRs_T_camXs))))) dense_xyz_camRs_ = utils_geom.apply_4x4(__p(camRs_T_camXs), dense_xyz_camXs_) inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float() inbound_camXs_ = torch.reshape(inbound_camXs_, [B * S, 1, H, W]) depth_camXs = __u(depth_camXs_) valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_) summ_writer.summ_oneds('2D_inputs/depth_camXs', torch.unbind(depth_camXs, dim=1), maxdepth=21.0) summ_writer.summ_oneds('2D_inputs/valid_camXs', torch.unbind(valid_camXs, dim=1)) summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(rgb_camXs, dim=1)) summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1)) summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1), torch.unbind(occXs, dim=1)) occRs = 
__u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X)) if hyp.do_eval_boxes: if hyp.dataset_name == "clevr_vqa": gt_boxes_origin_corners = feed['gt_box'] gt_scores_origin = feed['gt_scores'].detach().cpu().numpy() classes = feed['classes'] scores = gt_scores_origin tree_seq_filename = feed['tree_seq_filename'] gt_boxes_origin = nlu.get_ends_of_corner( gt_boxes_origin_corners) gt_boxes_origin_end = torch.reshape(gt_boxes_origin, [hyp.B, hyp.N, 2, 3]) gt_boxes_origin_theta = nlu.get_alignedboxes2thetaformat( gt_boxes_origin_end) gt_boxes_origin_corners = utils_geom.transform_boxes_to_corners( gt_boxes_origin_theta) gt_boxesR_corners = __ub( utils_geom.apply_4x4(camRs_T_origin[:, 0], __pb(gt_boxes_origin_corners))) gt_boxesR_theta = utils_geom.transform_corners_to_boxes( gt_boxesR_corners) gt_boxesR_end = nlu.get_ends_of_corner(gt_boxesR_corners) else: tree_seq_filename = feed['tree_seq_filename'] tree_filenames = [ join(hyp.root_dataset, i) for i in tree_seq_filename if i != "invalid_tree" ] invalid_tree_filenames = [ join(hyp.root_dataset, i) for i in tree_seq_filename if i == "invalid_tree" ] num_empty = len(invalid_tree_filenames) trees = [pickle.load(open(i, "rb")) for i in tree_filenames] len_valid = len(trees) if len_valid > 0: gt_boxesR, scores, classes = nlu.trees_rearrange(trees) if num_empty > 0: gt_boxesR = np.concatenate([ gt_boxesR, empty_gt_boxesR ]) if len_valid > 0 else empty_gt_boxesR scores = np.concatenate([ scores, empty_scores ]) if len_valid > 0 else empty_scores classes = np.concatenate([ classes, empty_classes ]) if len_valid > 0 else empty_classes gt_boxesR = torch.from_numpy( gt_boxesR).cuda().float() # torch.Size([2, 3, 6]) gt_boxesR_end = torch.reshape(gt_boxesR, [hyp.B, hyp.N, 2, 3]) gt_boxesR_theta = nlu.get_alignedboxes2thetaformat( gt_boxesR_end) #torch.Size([2, 3, 9]) gt_boxesR_corners = utils_geom.transform_boxes_to_corners( gt_boxesR_theta) class_names_ex_1 = "_".join(classes[0]) summ_writer.summ_text('eval_boxes/class_names', 
class_names_ex_1) gt_boxesRMem_corners = __ub( utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z2, Y2, X2)) gt_boxesRMem_end = nlu.get_ends_of_corner(gt_boxesRMem_corners) gt_boxesRMem_theta = utils_geom.transform_corners_to_boxes( gt_boxesRMem_corners) gt_boxesRUnp_corners = __ub( utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z, Y, X)) gt_boxesRUnp_end = nlu.get_ends_of_corner(gt_boxesRUnp_corners) gt_boxesX0_corners = __ub( utils_geom.apply_4x4(camX0_T_camRs, __pb(gt_boxesR_corners))) gt_boxesX0Mem_corners = __ub( utils_vox.Ref2Mem(__pb(gt_boxesX0_corners), Z2, Y2, X2)) gt_boxesX0Mem_theta = utils_geom.transform_corners_to_boxes( gt_boxesX0Mem_corners) gt_boxesX0Mem_end = nlu.get_ends_of_corner(gt_boxesX0Mem_corners) gt_boxesX0_end = nlu.get_ends_of_corner(gt_boxesX0_corners) gt_cornersX0_pix = __ub( utils_geom.apply_pix_T_cam(pix_T_cams[:, 0], __pb(gt_boxesX0_corners))) rgb_camX0 = rgb_camXs[:, 0] rgb_camX1 = rgb_camXs[:, 1] summ_writer.summ_box_by_corners('eval_boxes/gt_boxescamX0', rgb_camX0, gt_boxesX0_corners, torch.from_numpy(scores), tids, pix_T_cams[:, 0]) unps_vis = utils_improc.get_unps_vis(unpX0s_half, occX0s_half) unp_vis = torch.mean(unps_vis, dim=1) unps_visRs = utils_improc.get_unps_vis(unpRs_half, occRs_half) unp_visRs = torch.mean(unps_visRs, dim=1) unps_visRs_full = utils_improc.get_unps_vis(unpRs, occRs) unp_visRs_full = torch.mean(unps_visRs_full, dim=1) summ_writer.summ_box_mem_on_unp('eval_boxes/gt_boxesR_mem', unp_visRs, gt_boxesRMem_end, scores, tids) unpX0s_half = torch.mean(unpX0s_half, dim=1) unpX0s_half = nlu.zero_out(unpX0s_half, gt_boxesX0Mem_end, scores) occX0s_half = torch.mean(occX0s_half, dim=1) occX0s_half = nlu.zero_out(occX0s_half, gt_boxesX0Mem_end, scores) summ_writer.summ_unp('3D_inputs/unpX0s', unpX0s_half, occX0s_half) if hyp.do_feat: featXs_input = torch.cat([occXs, occXs * unpXs], dim=2) featXs_input_ = __p(featXs_input) freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), __p(occXs_half)) freeXs = __u(freeXs_) visXs = 
torch.clamp(occXs_half + freeXs, 0.0, 1.0) mask_ = None if (type(mask_) != type(None)): assert (list(mask_.shape)[2:5] == list( featXs_input_.shape)[2:5]) featXs_, feat_loss = self.featnet(featXs_input_, summ_writer, mask=__p(occXs)) #mask_) total_loss += feat_loss validXs = torch.ones_like(visXs) _validX00 = validXs[:, 0:1] _validX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:], validXs[:, 1:]) validX0s = torch.cat([_validX00, _validX01], dim=1) validRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, validXs) visRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, visXs) featXs = __u(featXs_) _featX00 = featXs[:, 0:1] _featX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:], featXs[:, 1:]) featX0s = torch.cat([_featX00, _featX01], dim=1) emb3D_e = torch.mean(featX0s[:, 1:], dim=1) vis3D_e_R = torch.max(visRs[:, 1:], dim=1)[0] emb3D_g = featX0s[:, 0] vis3D_g_R = visRs[:, 0] validR_combo = torch.min(validRs, dim=1).values summ_writer.summ_feats('3D_feats/featXs_input', torch.unbind(featXs_input, dim=1), pca=True) summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), valids=torch.unbind(validXs, dim=1), pca=True) summ_writer.summ_feats('3D_feats/featX0s_output', torch.unbind(featX0s, dim=1), valids=torch.unbind( torch.ones_like(validRs), dim=1), pca=True) summ_writer.summ_feats('3D_feats/validRs', torch.unbind(validRs, dim=1), pca=False) summ_writer.summ_feat('3D_feats/vis3D_e_R', vis3D_e_R, pca=False) summ_writer.summ_feat('3D_feats/vis3D_g_R', vis3D_g_R, pca=False) if hyp.do_munit: object_classes, filenames = nlu.create_object_classes( classes, [tree_seq_filename, tree_seq_filename], scores) if hyp.do_munit_fewshot: emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e) emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g) emb3D_R = emb3D_e_R emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors( [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end, scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE]) 
emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2 content, style = self.munitnet.net.gen_a.encode(emb3D_R_object) objects_taken, _ = self.munitnet.net.gen_a.decode( content, style) styles = style contents = content elif hyp.do_3d_style_munit: emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e) emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g) emb3D_R = emb3D_e_R # st() emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors( [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end, scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE]) emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2 camX1_T_R = camXs_T_camRs[:, 1] camX0_T_R = camXs_T_camRs[:, 0] assert hyp.B == 2 assert emb3D_e_R_object.shape[0] == 2 munit_loss, sudo_input_0, sudo_input_1, recon_input_0, recon_input_1, sudo_input_0_cycle, sudo_input_1_cycle, styles, contents, adin = self.munitnet( emb3D_R_object[0:1], emb3D_R_object[1:2]) if hyp.store_content_style_range: if self.max_content == None: self.max_content = torch.zeros_like( contents[0][0]).cuda() - 100000000 if self.min_content == None: self.min_content = torch.zeros_like( contents[0][0]).cuda() + 100000000 if self.max_style == None: self.max_style = torch.zeros_like( styles[0][0]).cuda() - 100000000 if self.min_style == None: self.min_style = torch.zeros_like( styles[0][0]).cuda() + 100000000 self.max_content = torch.max( torch.max(self.max_content, contents[0][0]), contents[1][0]) self.min_content = torch.min( torch.min(self.min_content, contents[0][0]), contents[1][0]) self.max_style = torch.max( torch.max(self.max_style, styles[0][0]), styles[1][0]) self.min_style = torch.min( torch.min(self.min_style, styles[0][0]), styles[1][0]) data_to_save = { 'max_content': self.max_content.cpu().numpy(), 'min_content': self.min_content.cpu().numpy(), 'max_style': self.max_style.cpu().numpy(), 'min_style': self.min_style.cpu().numpy() } with open('content_style_range.p', 'wb') as f: 
pickle.dump(data_to_save, f) elif hyp.is_contrastive_examples: if hyp.normalize_contrast: content0 = (contents[0] - self.min_content) / ( self.max_content - self.min_content + 1e-5) content1 = (contents[1] - self.min_content) / ( self.max_content - self.min_content + 1e-5) style0 = (styles[0] - self.min_style) / ( self.max_style - self.min_style + 1e-5) style1 = (styles[1] - self.min_style) / ( self.max_style - self.min_style + 1e-5) else: content0 = contents[0] content1 = contents[1] style0 = styles[0] style1 = styles[1] # euclid_dist_content = torch.sum(torch.sqrt((content0 - content1)**2))/torch.prod(torch.tensor(content0.shape)) # euclid_dist_style = torch.sum(torch.sqrt((style0-style1)**2))/torch.prod(torch.tensor(style0.shape)) euclid_dist_content = (content0 - content1).norm(2) / ( content0.numel()) euclid_dist_style = (style0 - style1).norm(2) / (style0.numel()) content_0_pooled = torch.mean( content0.reshape(list(content0.shape[:2]) + [-1]), dim=-1) content_1_pooled = torch.mean( content1.reshape(list(content1.shape[:2]) + [-1]), dim=-1) euclid_dist_content_pooled = (content_0_pooled - content_1_pooled).norm(2) / ( content_0_pooled.numel()) content_0_normalized = content0 / content0.norm() content_1_normalized = content1 / content1.norm() style_0_normalized = style0 / style0.norm() style_1_normalized = style1 / style1.norm() content_0_pooled_normalized = content_0_pooled / content_0_pooled.norm( ) content_1_pooled_normalized = content_1_pooled / content_1_pooled.norm( ) cosine_dist_content = torch.sum(content_0_normalized * content_1_normalized) cosine_dist_style = torch.sum(style_0_normalized * style_1_normalized) cosine_dist_content_pooled = torch.sum( content_0_pooled_normalized * content_1_pooled_normalized) print("euclid dist [content, pooled-content, style]: ", euclid_dist_content, euclid_dist_content_pooled, euclid_dist_style) print("cosine sim [content, pooled-content, style]: ", cosine_dist_content, cosine_dist_content_pooled, cosine_dist_style) 
if hyp.run_few_shot_on_munit: if (global_step % 300) == 1 or (global_step % 300) == 0: wrong = False try: precision_style = float(self.tp_style) / self.all_style precision_content = float( self.tp_content) / self.all_content except ZeroDivisionError: wrong = True if not wrong: summ_writer.summ_scalar( 'precision/unsupervised_precision_style', precision_style) summ_writer.summ_scalar( 'precision/unsupervised_precision_content', precision_content) # st() self.embed_list_style = defaultdict(lambda: []) self.embed_list_content = defaultdict(lambda: []) self.tp_style = 0 self.all_style = 0 self.tp_content = 0 self.all_content = 0 self.check = False elif not self.check and not nlu.check_fill_dict( self.embed_list_content, self.embed_list_style): print("Filling \n") for index, class_val in enumerate(object_classes): if hyp.dataset_name == "clevr_vqa": class_val_content, class_val_style = class_val.split( "/") else: class_val_content, class_val_style = [ class_val.split("/")[0], class_val.split("/")[0] ] print(len(self.embed_list_style.keys()), "style class", len(self.embed_list_content), "content class", self.embed_list_content.keys()) if len(self.embed_list_style[class_val_style] ) < hyp.few_shot_nums: self.embed_list_style[class_val_style].append( styles[index].squeeze()) if len(self.embed_list_content[class_val_content] ) < hyp.few_shot_nums: if hyp.avg_3d: content_val = contents[index] content_val = torch.mean(content_val.reshape( [content_val.shape[1], -1]), dim=-1) # st() self.embed_list_content[ class_val_content].append(content_val) else: self.embed_list_content[ class_val_content].append( contents[index].reshape([-1])) else: self.check = True try: print(float(self.tp_content) / self.all_content) print(float(self.tp_style) / self.all_style) except Exception as e: pass average = True if average: for key, val in self.embed_list_style.items(): if isinstance(val, type([])): self.embed_list_style[key] = torch.mean( torch.stack(val, dim=0), dim=0) for key, val in 
self.embed_list_content.items(): if isinstance(val, type([])): self.embed_list_content[key] = torch.mean( torch.stack(val, dim=0), dim=0) else: for key, val in self.embed_list_style.items(): if isinstance(val, type([])): self.embed_list_style[key] = torch.stack(val, dim=0) for key, val in self.embed_list_content.items(): if isinstance(val, type([])): self.embed_list_content[key] = torch.stack( val, dim=0) for index, class_val in enumerate(object_classes): class_val = class_val if hyp.dataset_name == "clevr_vqa": class_val_content, class_val_style = class_val.split( "/") else: class_val_content, class_val_style = [ class_val.split("/")[0], class_val.split("/")[0] ] style_val = styles[index].squeeze().unsqueeze(0) if not average: embed_list_val_style = torch.cat(list( self.embed_list_style.values()), dim=0) embed_list_key_style = list( np.repeat( np.expand_dims( list(self.embed_list_style.keys()), 1), hyp.few_shot_nums, 1).reshape([-1])) else: embed_list_val_style = torch.stack(list( self.embed_list_style.values()), dim=0) embed_list_key_style = list( self.embed_list_style.keys()) embed_list_val_style = utils_basic.l2_normalize( embed_list_val_style, dim=1).permute(1, 0) style_val = utils_basic.l2_normalize(style_val, dim=1) scores_styles = torch.matmul(style_val, embed_list_val_style) index_key = torch.argmax(scores_styles, dim=1).squeeze() selected_class_style = embed_list_key_style[index_key] self.styles_prediction[class_val_style].append( selected_class_style) if class_val_style == selected_class_style: self.tp_style += 1 self.all_style += 1 if hyp.avg_3d: content_val = contents[index] content_val = torch.mean(content_val.reshape( [content_val.shape[1], -1]), dim=-1).unsqueeze(0) else: content_val = contents[index].reshape( [-1]).unsqueeze(0) if not average: embed_list_val_content = torch.cat(list( self.embed_list_content.values()), dim=0) embed_list_key_content = list( np.repeat( np.expand_dims( list(self.embed_list_content.keys()), 1), hyp.few_shot_nums, 
1).reshape([-1])) else: embed_list_val_content = torch.stack(list( self.embed_list_content.values()), dim=0) embed_list_key_content = list( self.embed_list_content.keys()) embed_list_val_content = utils_basic.l2_normalize( embed_list_val_content, dim=1).permute(1, 0) content_val = utils_basic.l2_normalize(content_val, dim=1) scores_content = torch.matmul(content_val, embed_list_val_content) index_key = torch.argmax(scores_content, dim=1).squeeze() selected_class_content = embed_list_key_content[ index_key] self.content_prediction[class_val_content].append( selected_class_content) if class_val_content == selected_class_content: self.tp_content += 1 self.all_content += 1 # st() munit_loss = hyp.munit_loss_weight * munit_loss recon_input_obj = torch.cat([recon_input_0, recon_input_1], dim=0) recon_emb3D_R = nlu.update_scene_with_objects( emb3D_R, recon_input_obj, gt_boxesRMem_end, scores) sudo_input_obj = torch.cat([sudo_input_0, sudo_input_1], dim=0) styled_emb3D_R = nlu.update_scene_with_objects( emb3D_R, sudo_input_obj, gt_boxesRMem_end, scores) styled_emb3D_e_X1 = utils_vox.apply_4x4_to_vox( camX1_T_R, styled_emb3D_R) styled_emb3D_e_X0 = utils_vox.apply_4x4_to_vox( camX0_T_R, styled_emb3D_R) emb3D_e_X1 = utils_vox.apply_4x4_to_vox(camX1_T_R, recon_emb3D_R) emb3D_e_X0 = utils_vox.apply_4x4_to_vox(camX0_T_R, recon_emb3D_R) emb3D_e_X1_og = utils_vox.apply_4x4_to_vox(camX1_T_R, emb3D_R) emb3D_e_X0_og = utils_vox.apply_4x4_to_vox(camX0_T_R, emb3D_R) emb3D_R_aug_diff = torch.abs(emb3D_R - recon_emb3D_R) summ_writer.summ_feat(f'aug_feat/og', emb3D_R) summ_writer.summ_feat(f'aug_feat/og_gen', recon_emb3D_R) summ_writer.summ_feat(f'aug_feat/og_aug_diff', emb3D_R_aug_diff) if hyp.cycle_style_view_loss: sudo_input_obj_cycle = torch.cat( [sudo_input_0_cycle, sudo_input_1_cycle], dim=0) styled_emb3D_R_cycle = nlu.update_scene_with_objects( emb3D_R, sudo_input_obj_cycle, gt_boxesRMem_end, scores) styled_emb3D_e_X0_cycle = utils_vox.apply_4x4_to_vox( camX0_T_R, 
styled_emb3D_R_cycle) styled_emb3D_e_X1_cycle = utils_vox.apply_4x4_to_vox( camX1_T_R, styled_emb3D_R_cycle) summ_writer.summ_scalar('munit_loss', munit_loss.cpu().item()) total_loss += munit_loss if hyp.do_occ and hyp.occ_do_cheap: occX0_sup, freeX0_sup, _, freeXs = utils_vox.prep_occs_supervision( camX0_T_camXs, xyz_camXs, Z2, Y2, X2, agg=True) summ_writer.summ_occ('occ_sup/occ_sup', occX0_sup) summ_writer.summ_occ('occ_sup/free_sup', freeX0_sup) summ_writer.summ_occs('occ_sup/freeXs_sup', torch.unbind(freeXs, dim=1)) summ_writer.summ_occs('occ_sup/occXs_sup', torch.unbind(occXs_half, dim=1)) occ_loss, occX0s_pred_ = self.occnet( torch.mean(featX0s[:, 1:], dim=1), occX0_sup, freeX0_sup, torch.max(validX0s[:, 1:], dim=1)[0], summ_writer) occX0s_pred = __u(occX0s_pred_) total_loss += occ_loss if hyp.do_view: assert (hyp.do_feat) PH, PW = hyp.PH, hyp.PW sy = float(PH) / float(hyp.H) sx = float(PW) / float(hyp.W) assert (sx == 0.5) # else we need a fancier downsampler assert (sy == 0.5) projpix_T_cams = __u( utils_geom.scale_intrinsics(__p(pix_T_cams), sx, sy)) # st() if hyp.do_munit: feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], emb3D_e_X1, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) feat_projX00_og = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], emb3D_e_X1_og, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) # only for checking the style styled_feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], styled_emb3D_e_X1, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) if hyp.cycle_style_view_loss: styled_feat_projX00_cycle = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], styled_emb3D_e_X1_cycle, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) else: feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], featXs[:, 1], # use feat1 to predict rgb0 
hyp.view_depth, PH, PW) rgb_X00 = utils_basic.downsample(rgb_camXs[:, 0], 2) rgb_X01 = utils_basic.downsample(rgb_camXs[:, 1], 2) valid_X00 = utils_basic.downsample(valid_camXs[:, 0], 2) view_loss, rgb_e, emb2D_e = self.viewnet(feat_projX00, rgb_X00, valid_X00, summ_writer, "rgb") if hyp.do_munit: _, rgb_e, emb2D_e = self.viewnet(feat_projX00_og, rgb_X00, valid_X00, summ_writer, "rgb_og") if hyp.do_munit: styled_view_loss, styled_rgb_e, styled_emb2D_e = self.viewnet( styled_feat_projX00, rgb_X00, valid_X00, summ_writer, "recon_style") if hyp.cycle_style_view_loss: styled_view_loss_cycle, styled_rgb_e_cycle, styled_emb2D_e_cycle = self.viewnet( styled_feat_projX00_cycle, rgb_X00, valid_X00, summ_writer, "recon_style_cycle") rgb_input_1 = torch.cat( [rgb_X01[1], rgb_X01[0], styled_rgb_e[0]], dim=2) rgb_input_2 = torch.cat( [rgb_X01[0], rgb_X01[1], styled_rgb_e[1]], dim=2) complete_vis = torch.cat([rgb_input_1, rgb_input_2], dim=1) summ_writer.summ_rgb('munit/munit_recons_vis', complete_vis.unsqueeze(0)) if not hyp.do_munit: total_loss += view_loss else: if hyp.basic_view_loss: total_loss += view_loss if hyp.style_view_loss: total_loss += styled_view_loss if hyp.cycle_style_view_loss: total_loss += styled_view_loss_cycle summ_writer.summ_scalar('loss', total_loss.cpu().item()) if hyp.save_embed_tsne: for index, class_val in enumerate(object_classes): class_val_content, class_val_style = class_val.split("/") style_val = styles[index].squeeze().unsqueeze(0) self.cluster_pool.update(style_val, [class_val_style]) print(self.cluster_pool.num) if self.cluster_pool.is_full(): embeds, classes = self.cluster_pool.fetch() with open("offline_cluster" + '/%st.txt' % 'classes', 'w') as f: for index, embed in enumerate(classes): class_val = classes[index] f.write("%s\n" % class_val) f.close() with open("offline_cluster" + '/%st.txt' % 'embeddings', 'w') as f: for index, embed in enumerate(embeds): # embed = utils_basic.l2_normalize(embed,dim=0) print("writing {} 
embed".format(index)) embed_l_s = [str(i) for i in embed.tolist()] embed_str = '\t'.join(embed_l_s) f.write("%s\n" % embed_str) f.close() st() return total_loss, results