def transform_boxes_to_corners(boxes):
    # boxes is B x N x 9; returns corners, shaped B x N x 8 x 3
    B, N, D = list(boxes.shape)
    assert D == 9
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    boxes_ = __p(boxes)
    corners_ = transform_boxes_to_corners_single(boxes_)
    corners = __u(corners_)
    return corners
def transform_corners_to_boxes(corners):
    # corners is B x N x 8 x 3
    B, N, C, D = corners.shape
    assert C == 8
    assert D == 3
    # do them all at once
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    corners_ = __p(corners)
    boxes_ = transform_corners_to_boxes_single(corners_)
    boxes_ = boxes_.cuda()
    boxes = __u(boxes_)
    return boxes
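# The __p/__u lambdas used throughout fold the sequence (or box) dim into the
# batch dim, run a batched op, then unfold. A minimal sketch of what
# utils_basic.pack_seqdim/unpack_seqdim are assumed to do (hypothetical
# reimplementation for illustration only; the real helpers live in utils_basic):

def _pack_seqdim_sketch(x, B):
    # B x S x ... -> (B*S) x ...
    shapelist = list(x.shape)
    assert shapelist[0] == B
    return x.reshape([B * shapelist[1]] + shapelist[2:])

def _unpack_seqdim_sketch(x, B):
    # (B*S) x ... -> B x S x ...
    shapelist = list(x.shape)
    assert shapelist[0] % B == 0
    return x.reshape([B, shapelist[0] // B] + shapelist[1:])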
def rotate_tensor_along_y_axis(tensor, gamma):
    B = tensor.shape[0]
    tensor = tensor.to("cpu")
    assert tensor.ndim == 6, "Tensor should have 6 dimensions."
    tensor = tensor.float()  # B, S, C, D, H, W
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    tensor_ = __p(tensor)
    # make it BS, C, H, D, W (i.e. BS, C, y, z, x)
    tensor_ = tensor_.permute(0, 1, 3, 2, 4)
    BS, C, H, D, W = tensor_.shape
    # merge the y dimension into the channel dimension, then rotate in the z-x plane by gamma
    tensor_y_reduced = tensor_.reshape(BS, C * H, D, W)
    # # gammas would be rotation angles along the y axis:
    # gammas = torch.arange(10, 360, 10)
    # define the rotation center
    center = torch.ones(1, 2)
    center[..., 0] = tensor_y_reduced.shape[3] / 2  # x
    center[..., 1] = tensor_y_reduced.shape[2] / 2  # z
    # define the scale factor
    scale = torch.ones(1)
    gamma_ = torch.ones(1) * gamma
    # compute the transformation matrix
    M = kornia.get_rotation_matrix2d(center, gamma_, scale)
    M = M.repeat(BS, 1, 1)
    # apply the transformation to the original image
    # st()
    tensor_y_reduced_warped = kornia.warp_affine(tensor_y_reduced, M, dsize=(D, W))
    tensor_y_reduced_warped = tensor_y_reduced_warped.reshape(BS, C, H, D, W)
    tensor_y_reduced_warped = tensor_y_reduced_warped.permute(0, 1, 3, 2, 4)
    tensor_y_reduced_warped = __u(tensor_y_reduced_warped)
    return tensor_y_reduced_warped.cuda()
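# Usage sketch for rotate_tensor_along_y_axis (assumes kornia is installed and
# a CUDA device is available, since the function returns a .cuda() tensor;
# the shapes below are made up for illustration):
#
#   B, S, C, D, H, W = 2, 3, 4, 16, 16, 16
#   vox = torch.rand(B, S, C, D, H, W)
#   vox_rot = rotate_tensor_along_y_axis(vox, 45.0)
#   assert vox_rot.shape == vox.shape  # same grid, contents rotated about y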
def __getitem__(self, index):
    if hyp.dataset_name in ('kitti', 'clevr', 'real', 'bigbird', 'carla',
                            'carla_mix', 'replica', 'clevr_vqa', 'carla_det'):
        # print(index)
        # st()
        filename = self.records[index]
        d = pickle.load(open(filename, "rb"))
        d = dict(d)
        d_empty = pickle.load(open(self.empty_scene, "rb"))
        d_empty = dict(d_empty)
        # st()
    # elif hyp.dataset_name == "carla":
    #     filename = self.records[index]
    #     d = np.load(filename)
    #     d = dict(d)
    #     d['rgb_camXs_raw'] = d['rgb_camXs']
    #     d['pix_T_cams_raw'] = d['pix_T_cams']
    #     d['tree_seq_filename'] = "dummy_tree_filename"
    #     d['origin_T_camXs_raw'] = d['origin_T_camXs']
    #     d['camR_T_origin_raw'] = utils_geom.safe_inverse(torch.from_numpy(d['origin_T_camRs'])).numpy()
    #     d['xyz_camXs_raw'] = d['xyz_camXs']
    else:
        assert False  # reader not ready yet

    if hyp.do_empty:
        item_names = [
            'pix_T_cams_raw',
            'origin_T_camXs_raw',
            'camR_T_origin_raw',
            'rgb_camXs_raw',
            'xyz_camXs_raw',
            'empty_rgb_camXs_raw',
            'empty_xyz_camXs_raw',
        ]
    else:
        item_names = [
            'pix_T_cams_raw',
            'origin_T_camXs_raw',
            'camR_T_origin_raw',
            'rgb_camXs_raw',
            'xyz_camXs_raw',
        ]

    if hyp.use_gt_occs:
        # aggregate a complete GT occupancy in camR coordinates
        __p = lambda x: utils_basic.pack_seqdim(x, 1)
        __u = lambda x: utils_basic.unpack_seqdim(x, 1)
        B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
        PH, PW = hyp.PH, hyp.PW
        K = hyp.K
        BOX_SIZE = hyp.BOX_SIZE
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
        D = 9
        pix_T_cams = torch.from_numpy(d["pix_T_cams_raw"]).unsqueeze(0).cuda().to(torch.float)
        camRs_T_origin = torch.from_numpy(d["camR_T_origin_raw"]).unsqueeze(0).cuda().to(torch.float)
        origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
        origin_T_camXs = torch.from_numpy(d["origin_T_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
        camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
        camRs_T_camXs = __u(torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                                         __p(origin_T_camXs)))
        camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
        camX0_T_camRs = camXs_T_camRs[:, 0]
        camX1_T_camRs = camXs_T_camRs[:, 1]
        camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)
        xyz_camXs = torch.from_numpy(d["xyz_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
        xyz_camRs = __u(utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(__p(pix_T_cams), __p(xyz_camXs), H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, __p(pix_T_cams))
        occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))
        occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
        occRs_half = torch.max(occRs_half, dim=1).values.squeeze(0)
        occ_complete = occRs_half.cpu().numpy()

    # if hyp.do_time_flip:
    #     d = random_time_flip_single(d, item_names)
    # if the sequence length > 2, select S frames
    # filename = d['raw_seq_filename']
    original_filename = filename
    original_filename_empty = self.empty_scene
    # st()

    if hyp.dataset_name == "clevr_vqa":
        d['tree_seq_filename'] = "temp"
        pix_T_cams = d['pix_T_cams_raw']
        num_cams = pix_T_cams.shape[0]
        # padding_1 = torch.zeros([num_cams, 1, 3])
        # padding_2 = torch.zeros([num_cams, 4, 1])
        # padding_2[:, 3] = 1.0
        # pix_T_cams = torch.cat([pix_T_cams, padding_1], dim=1)
        # pix_T_cams = torch.cat([pix_T_cams, padding_2], dim=2)
        # st()
        shape_name = d['shape_list']
        color_name = d['color_list']
        material_name = d['material_list']
        all_name = []
        all_style = []
        for index in range(len(shape_name)):
            name = shape_name[index] + "/" + color_name[index] + "_" + material_name[index]
            style_name = color_name[index] + "_" + material_name[index]
            all_name.append(name)
            all_style.append(style_name)
        # st()
        if hyp.do_shape:
            class_name = shape_name
        elif hyp.do_color:
            class_name = color_name
        elif hyp.do_material:
            class_name = material_name
        elif hyp.do_style:
            class_name = all_style
        else:
            class_name = all_name
        object_category = class_name
        bbox_origin = d['bbox_origin']
        # bbox_origin = torch.cat([bbox_origin], dim=0)
        bbox_origin_empty = np.zeros_like(bbox_origin)
        object_category_empty = ['0']

    # st()
    if hyp.dataset_name != "clevr_vqa":
        filename = d['tree_seq_filename']
        filename_empty = d_empty['tree_seq_filename']

    if hyp.fixed_view:
        d, indexes = non_random_select_single(d, item_names, num_samples=hyp.S)
        d_empty, indexes_empty = specific_select_single_empty(
            d_empty, item_names, d['origin_T_camXs_raw'], num_samples=hyp.S)

    filename_g = "/".join([original_filename, str(indexes[0])])
    filename_e = "/".join([original_filename, str(indexes[1])])
    filename_g_empty = "/".join([original_filename_empty, str(indexes[0])])
    filename_e_empty = "/".join([original_filename_empty, str(indexes[1])])

    rgb_camXs = d['rgb_camXs_raw']
    rgb_camXs_empty = d_empty['rgb_camXs_raw']
    # move channel dim inward, like pytorch wants
    # rgb_camRs = np.transpose(rgb_camRs, axes=[0, 3, 1, 2])
    rgb_camXs = np.transpose(rgb_camXs, axes=[0, 3, 1, 2])
    rgb_camXs = rgb_camXs[:, :3]
    rgb_camXs = utils_improc.preprocess_color(rgb_camXs)
    rgb_camXs_empty = np.transpose(rgb_camXs_empty, axes=[0, 3, 1, 2])
    rgb_camXs_empty = rgb_camXs_empty[:, :3]
    rgb_camXs_empty = utils_improc.preprocess_color(rgb_camXs_empty)

    if hyp.dataset_name == "clevr_vqa":
        num_boxes = bbox_origin.shape[0]
        bbox_origin = np.array(bbox_origin)
        score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
        bbox_origin = np.pad(bbox_origin, [[0, hyp.N - num_boxes], [0, 0], [0, 0]])
        # pad the class-name lists with the placeholder label "0"
        object_category = np.pad(object_category, [[0, hyp.N - num_boxes]],
                                 mode='constant', constant_values="0")
        object_category_empty = np.pad(object_category_empty, [[0, hyp.N - 1]],
                                       mode='constant', constant_values="0")
        # st()
        score_empty = np.zeros_like(score)
        bbox_origin_empty = np.zeros_like(bbox_origin)
        d['gt_box'] = np.stack([bbox_origin.astype(np.float32), bbox_origin_empty])
        d['gt_scores'] = np.stack([score.astype(np.float32), score_empty])
        try:
            d['classes'] = np.stack([object_category, object_category_empty]).tolist()
        except Exception as e:
            st()

    # stack the scene with its empty counterpart along a new leading dim
    d['rgb_camXs_raw'] = np.stack([rgb_camXs, rgb_camXs_empty])
    d['pix_T_cams_raw'] = np.stack([d["pix_T_cams_raw"], d_empty["pix_T_cams_raw"]])
    d['origin_T_camXs_raw'] = np.stack([d["origin_T_camXs_raw"], d_empty["origin_T_camXs_raw"]])
    d['camR_T_origin_raw'] = np.stack([d["camR_T_origin_raw"], d_empty["camR_T_origin_raw"]])
    d['xyz_camXs_raw'] = np.stack([d["xyz_camXs_raw"], d_empty["xyz_camXs_raw"]])
    # d['rgb_camXs_raw'] = rgb_camXs
    # d['tree_seq_filename'] = filename
    if hyp.dataset_name != "clevr_vqa":
        d['tree_seq_filename'] = [filename, "invalid_tree"]
    else:
        d['tree_seq_filename'] = ["temp"]
    # st()
    d['filename_e'] = ["temp"]
    d['filename_g'] = ["temp"]
    if hyp.use_gt_occs:
        d['occR_complete'] = np.expand_dims(occ_complete, axis=0)
    return d
def __getitem__(self, index):
    if hyp.dataset_name in ('kitti', 'clevr', 'real', 'bigbird', 'carla',
                            'carla_mix', 'carla_det', 'replica', 'clevr_vqa'):
        # print(index)
        filename = self.records[index]
        d = pickle.load(open(filename, "rb"))
        d = dict(d)
    # elif hyp.dataset_name == "carla":
    #     filename = self.records[index]
    #     d = np.load(filename)
    #     d = dict(d)
    #     d['rgb_camXs_raw'] = d['rgb_camXs']
    #     d['pix_T_cams_raw'] = d['pix_T_cams']
    #     d['tree_seq_filename'] = "dummy_tree_filename"
    #     d['origin_T_camXs_raw'] = d['origin_T_camXs']
    #     d['camR_T_origin_raw'] = utils_geom.safe_inverse(torch.from_numpy(d['origin_T_camRs'])).numpy()
    #     d['xyz_camXs_raw'] = d['xyz_camXs']
    else:
        assert False  # reader not ready yet

    # if hyp.save_gt_occs:
    #     pickle.dump(d, open(filename, "wb"))
    # st()
    if hyp.use_gt_occs:
        # aggregate a complete GT occupancy in camR coordinates
        __p = lambda x: utils_basic.pack_seqdim(x, 1)
        __u = lambda x: utils_basic.unpack_seqdim(x, 1)
        B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
        PH, PW = hyp.PH, hyp.PW
        K = hyp.K
        BOX_SIZE = hyp.BOX_SIZE
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
        D = 9
        pix_T_cams = torch.from_numpy(d["pix_T_cams_raw"]).unsqueeze(0).cuda().to(torch.float)
        camRs_T_origin = torch.from_numpy(d["camR_T_origin_raw"]).unsqueeze(0).cuda().to(torch.float)
        origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
        origin_T_camXs = torch.from_numpy(d["origin_T_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
        camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
        camRs_T_camXs = __u(torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                                         __p(origin_T_camXs)))
        camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
        camX0_T_camRs = camXs_T_camRs[:, 0]
        camX1_T_camRs = camXs_T_camRs[:, 1]
        camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)
        xyz_camXs = torch.from_numpy(d["xyz_camXs_raw"]).unsqueeze(0).cuda().to(torch.float)
        xyz_camRs = __u(utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(__p(pix_T_cams), __p(xyz_camXs), H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, __p(pix_T_cams))
        occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))
        occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
        occRs_half = torch.max(occRs_half, dim=1).values.squeeze(0)
        occ_complete = occRs_half.cpu().numpy()

    # st()
    if hyp.do_empty:
        item_names = [
            'pix_T_cams_raw',
            'origin_T_camXs_raw',
            'camR_T_origin_raw',
            'rgb_camXs_raw',
            'xyz_camXs_raw',
            'empty_rgb_camXs_raw',
            'empty_xyz_camXs_raw',
        ]
    else:
        item_names = [
            'pix_T_cams_raw',
            'origin_T_camXs_raw',
            'camR_T_origin_raw',
            'rgb_camXs_raw',
            'xyz_camXs_raw',
        ]

    # if hyp.do_time_flip:
    #     d = random_time_flip_single(d, item_names)
    # if the sequence length > 2, select S frames
    # filename = d['raw_seq_filename']
    original_filename = filename

    if hyp.dataset_name == "carla_mix" or hyp.dataset_name == "carla_det":
        bbox_origin_gt = d['bbox_origin']
        if 'bbox_origin_predicted' in d:
            bbox_origin_predicted = d['bbox_origin_predicted']
        else:
            bbox_origin_predicted = []
        classes = d['obj_name']
        if isinstance(classes, str):
            classes = [classes]
        # st()
        d['tree_seq_filename'] = "temp"

    if hyp.dataset_name == "replica":
        d['tree_seq_filename'] = "temp"
        object_category = d['object_category_names']
        bbox_origin = d['bbox_origin']

    if hyp.dataset_name == "clevr_vqa":
        d['tree_seq_filename'] = "temp"
        pix_T_cams = d['pix_T_cams_raw']
        num_cams = pix_T_cams.shape[0]
        # padding_1 = torch.zeros([num_cams, 1, 3])
        # padding_2 = torch.zeros([num_cams, 4, 1])
        # padding_2[:, 3] = 1.0
        # pix_T_cams = torch.cat([pix_T_cams, padding_1], dim=1)
        # pix_T_cams = torch.cat([pix_T_cams, padding_2], dim=2)
        # st()
        shape_name = d['shape_list']
        color_name = d['color_list']
        material_name = d['material_list']
        all_name = []
        all_style = []
        for index in range(len(shape_name)):
            name = shape_name[index] + "/" + color_name[index] + "_" + material_name[index]
            style_name = color_name[index] + "_" + material_name[index]
            all_name.append(name)
            all_style.append(style_name)
        # st()
        if hyp.do_shape:
            class_name = shape_name
        elif hyp.do_color:
            class_name = color_name
        elif hyp.do_material:
            class_name = material_name
        elif hyp.do_style:
            class_name = all_style
        else:
            class_name = all_name
        object_category = class_name
        bbox_origin = d['bbox_origin']

    # st()
    if hyp.dataset_name == "carla":
        camR_index = d['camR_index']
        rgb_camtop = d['rgb_camXs_raw'][camR_index:camR_index + 1]
        origin_T_camXs_top = d['origin_T_camXs_raw'][camR_index:camR_index + 1]
        # predicted_box = d['bbox_origin_predicted']
        predicted_box = []
        filename = d['tree_seq_filename']

    if hyp.do_2d_style_munit:
        d, indexes = non_random_select_single(d, item_names, num_samples=hyp.S)
    # st()
    if hyp.fixed_view:
        d, indexes = non_random_select_single(d, item_names, num_samples=hyp.S)
    elif self.shuffle or hyp.randomly_select_views:
        d, indexes = random_select_single(d, item_names, num_samples=hyp.S)
    else:
        d, indexes = non_random_select_single(d, item_names, num_samples=hyp.S)

    filename_g = "/".join([original_filename, str(indexes[0])])
    filename_e = "/".join([original_filename, str(indexes[1])])

    rgb_camXs = d['rgb_camXs_raw']
    # move channel dim inward, like pytorch wants
    # rgb_camRs = np.transpose(rgb_camRs, axes=[0, 3, 1, 2])
    rgb_camXs = np.transpose(rgb_camXs, axes=[0, 3, 1, 2])
    rgb_camXs = rgb_camXs[:, :3]
    rgb_camXs = utils_improc.preprocess_color(rgb_camXs)

    if hyp.dataset_name == "carla":
        rgb_camtop = np.transpose(rgb_camtop, axes=[0, 3, 1, 2])
        rgb_camtop = rgb_camtop[:, :3]
        rgb_camtop = utils_improc.preprocess_color(rgb_camtop)
        d['rgb_camtop'] = rgb_camtop
        d['origin_T_camXs_top'] = origin_T_camXs_top
        if len(predicted_box) == 0:
            predicted_box = np.zeros([hyp.N, 6])
            score = np.zeros([hyp.N]).astype(np.float32)
        else:
            num_boxes = predicted_box.shape[0]
            score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
            predicted_box = np.pad(predicted_box, [[0, hyp.N - num_boxes], [0, 0]])
        d['predicted_box'] = predicted_box.astype(np.float32)
        d['predicted_scores'] = score.astype(np.float32)

    if hyp.dataset_name == "clevr_vqa":
        num_boxes = bbox_origin.shape[0]
        bbox_origin = np.array(bbox_origin)
        score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
        bbox_origin = np.pad(bbox_origin, [[0, hyp.N - num_boxes], [0, 0], [0, 0]])
        # pad the class-name list with the placeholder label "0"
        object_category = np.pad(object_category, [[0, hyp.N - num_boxes]],
                                 mode='constant', constant_values="0")
        d['gt_box'] = bbox_origin.astype(np.float32)
        d['gt_scores'] = score.astype(np.float32)
        d['classes'] = list(object_category)

    if hyp.dataset_name == "replica":
        if len(bbox_origin) == 0:
            score = np.zeros([hyp.N])
            bbox_origin = np.zeros([hyp.N, 6])
            object_category = ["0"] * hyp.N
            object_category = np.array(object_category)
        else:
            num_boxes = len(bbox_origin)
            bbox_origin = torch.stack(bbox_origin).numpy().squeeze(1).squeeze(1).reshape([num_boxes, 6])
            bbox_origin = np.array(bbox_origin)
            score = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
            bbox_origin = np.pad(bbox_origin, [[0, hyp.N - num_boxes], [0, 0]])
            object_category = np.pad(object_category, [[0, hyp.N - num_boxes]],
                                     mode='constant', constant_values="0")
        d['gt_box'] = bbox_origin.astype(np.float32)
        d['gt_scores'] = score.astype(np.float32)
        d['classes'] = list(object_category)

    # st()
    if hyp.dataset_name == "carla_mix" or hyp.dataset_name == "carla_det":
        bbox_origin_predicted = bbox_origin_predicted[:3]
        if len(bbox_origin_gt.shape) == 1:
            bbox_origin_gt = np.expand_dims(bbox_origin_gt, 0)
        num_boxes = bbox_origin_gt.shape[0]
        # st()
        score_gt = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
        bbox_origin_gt = np.pad(bbox_origin_gt, [[0, hyp.N - num_boxes], [0, 0]])
        # st()
        classes = np.pad(classes, [[0, hyp.N - num_boxes]],
                         mode='constant', constant_values="0")
        if len(bbox_origin_predicted) == 0:
            bbox_origin_predicted = np.zeros([hyp.N, 6])
            score_pred = np.zeros([hyp.N]).astype(np.float32)
        else:
            num_boxes = bbox_origin_predicted.shape[0]
            score_pred = np.pad(np.ones([num_boxes]), [0, hyp.N - num_boxes])
            bbox_origin_predicted = np.pad(bbox_origin_predicted, [[0, hyp.N - num_boxes], [0, 0]])
        d['predicted_box'] = bbox_origin_predicted.astype(np.float32)
        d['predicted_scores'] = score_pred.astype(np.float32)
        d['gt_box'] = bbox_origin_gt.astype(np.float32)
        d['gt_scores'] = score_gt.astype(np.float32)
        d['classes'] = list(classes)

    d['rgb_camXs_raw'] = rgb_camXs

    if hyp.dataset_name != "carla" and hyp.do_empty:
        empty_rgb_camXs = d['empty_rgb_camXs_raw']
        # move channel dim inward, like pytorch wants
        empty_rgb_camXs = np.transpose(empty_rgb_camXs, axes=[0, 3, 1, 2])
        empty_rgb_camXs = empty_rgb_camXs[:, :3]
        empty_rgb_camXs = utils_improc.preprocess_color(empty_rgb_camXs)
        d['empty_rgb_camXs_raw'] = empty_rgb_camXs
    # st()
    if hyp.use_gt_occs:
        d['occR_complete'] = occ_complete
    d['tree_seq_filename'] = filename
    d['filename_e'] = filename_e
    d['filename_g'] = filename_g
    return d
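# Both __getitem__ variants above pad variable-length box lists out to a fixed
# hyp.N so samples can be collated into batches. A standalone sketch of that
# pattern (hypothetical helper, not part of the codebase):

import numpy as np

def _pad_boxes_to_N_sketch(boxes, N):
    # boxes: num_boxes x 6 -> (N x 6 boxes, N scores); real slots get score 1,
    # padded slots get score 0, mirroring the gt_box/gt_scores logic above
    num_boxes = boxes.shape[0]
    scores = np.pad(np.ones([num_boxes]), [0, N - num_boxes])
    boxes = np.pad(boxes, [[0, N - num_boxes], [0, 0]])
    return boxes.astype(np.float32), scores.astype(np.float32)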
def forward(self, clist_cam, energy_map, occ_mems, summ_writer):
    total_loss = torch.tensor(0.0).cuda()
    B, S, C, Z, Y, X = list(occ_mems.shape)
    B2, S, D = list(clist_cam.shape)
    assert B == B2
    traj_past = clist_cam[:, :self.T_past]
    traj_futu = clist_cam[:, self.T_past:]
    # keep just xz
    traj_past = torch.stack([traj_past[:, :, 0], traj_past[:, :, 2]], dim=2)  # xz
    traj_futu = torch.stack([traj_futu[:, :, 0], traj_futu[:, :, 2]], dim=2)  # xz
    feat = occ_mems[:, 0].permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
    mask = 1.0 - (feat == 0).all(dim=1, keepdim=True).float().cuda()
    halfgrid = utils_basic.meshgrid2D(B, int(Z / 2), int(X / 2),
                                      stack=True, norm=True).permute(0, 3, 1, 2)
    feat_map, _ = self.compressor(feat, mask, halfgrid)
    pred_map = self.conv2d(feat_map)
    # these are B x C x Z x X
    K = 12  # number of samples
    traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)
    feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
    pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
    # to sample the K trajectories in parallel, we'll pack K onto the batch dim
    __p = lambda x: utils_basic.pack_seqdim(x, K)
    __u = lambda x: utils_basic.unpack_seqdim(x, K)
    traj_past_ = __p(traj_past)
    feat_map_ = __p(feat_map)
    pred_map_ = __p(pred_map)
    base_sample_ = torch.randn(K * B, self.T_futu, 2).cuda()
    traj_futu_e_ = self.compute_forward_mapping(feat_map_, pred_map_, base_sample_, traj_past_)
    traj_futu_e = __u(traj_futu_e_)
    # this is K x B x T x 2
    # print('traj_futu_e', traj_futu_e.shape, traj_futu_e[0,0])
    if summ_writer.save_this:
        o = []
        for k in list(range(K)):
            o.append(utils_improc.preprocess_color(summ_writer.summ_traj_on_occ(
                '',
                utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]), Z, Y, X),
                occ_mems[:, 0],
                already_mem=True,
                only_return=True)))
            summ_writer.summ_traj_on_occ(
                'rponet/traj_futu_sample_%d' % k,
                utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]), Z, Y, X),
                occ_mems[:, 0],
                already_mem=True)
        mean_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
        summ_writer.summ_rgb('rponet/traj_futu_e_mean', mean_vis)
        summ_writer.summ_traj_on_occ(
            'rponet/traj_futu_g',
            utils_vox.Ref2Mem(self.add_fake_y(traj_futu), Z, Y, X),
            occ_mems[:, 0],
            already_mem=True)
    # forward loss: neg logprob of GT samples under the model
    # reverse loss: neg logprob of estim samples under the (approx) GT (i.e., spatial prior)
    forward_loss, reverse_loss = self.compute_loss(
        feat_map[0], pred_map[0], traj_past[0], traj_futu, traj_futu_e, energy_map)
    total_loss = utils_misc.add_loss('rpo/forward_loss', total_loss, forward_loss,
                                     hyp.rpo2D_forward_coeff, summ_writer)
    total_loss = utils_misc.add_loss('rpo/reverse_loss', total_loss, reverse_loss,
                                     hyp.rpo2D_reverse_coeff, summ_writer)
    return total_loss
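# To draw K trajectory samples in one pass, the forward above tiles its inputs
# K times and folds K into the batch dim before calling the sampler. A compact
# sketch of that pattern with dummy shapes (illustrative only):

import torch

def _k_sample_pack_sketch(traj_past, K=12):
    # traj_past: B x T x 2 -> K parallel copies, packed as one big batch
    B, T, C = traj_past.shape
    tiled = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)  # K x B x T x C
    packed = tiled.reshape(K * B, T, C)                # fold K into batch
    # ... run the conditional sampler on `packed` here ...
    return packed.reshape(K, B, T, C)                  # unfold back to K samples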
def forward(self, clist_cam, occs, summ_writer, vox_util, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, S, C, Z, Y, X = list(occs.shape)
    B2, S2, D = list(clist_cam.shape)
    assert B == B2 and S == S2
    assert D == 3
    if summ_writer.save_this:
        summ_writer.summ_traj_on_occ('motioncost/actual_traj', clist_cam,
                                     occs[:, self.T_past], vox_util, sigma=2)
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    # occs_ = occs.reshape(B*S, C, Z, Y, X)
    occs_ = __p(occs)
    feats_ = occs_.permute(0, 1, 3, 2, 4).reshape(B * S, C * Y, Z, X)
    masks_ = 1.0 - (feats_ == 0).all(dim=1, keepdim=True).float().cuda()
    halfgrids_ = utils_basic.meshgrid2D(B * S, int(Z / 2), int(X / 2),
                                        stack=True, norm=True).permute(0, 3, 1, 2)
    # feats_ = torch.cat([feats_, grids_], dim=1)
    feats = __u(feats_)
    masks = __u(masks_)
    halfgrids = __u(halfgrids_)
    input_feats = feats[:, :self.T_past]
    input_masks = masks[:, :self.T_past]
    input_halfgrids = halfgrids[:, :self.T_past]
    dense_feats_, _ = self.densifier(__p(input_feats), __p(input_masks), __p(input_halfgrids))
    dense_feats = __u(dense_feats_)
    super_feat = dense_feats.reshape(B, self.T_past * self.dense_dim, int(Z / 2), int(X / 2))
    cost_maps = self.motioncoster(super_feat)
    cost_maps = F.interpolate(cost_maps, scale_factor=4, mode='bilinear')
    # this is B x T_futu x Z x X
    cost_maps = cost_maps.clamp(-1000, 1000)  # raquel says this adds stability
    summ_writer.summ_histogram('motioncost/cost_maps_hist', cost_maps)
    summ_writer.summ_oneds('motioncost/cost_maps', torch.unbind(cost_maps.unsqueeze(2), dim=1))

    # next i need to sample some trajectories
    N = hyp.motioncost_num_negs
    sampled_trajs_cam = self.sample_trajs(N, clist_cam)
    # this is B x N x S x 3
    if summ_writer.save_this:
        # for n in list(range(np.min([N, 3]))):
        #     # this is 1 x S x 3
        #     summ_writer.summ_traj_on_occ('motioncost/sample%d_clist' % n,
        #                                  sampled_trajs_cam[0, n].unsqueeze(0),
        #                                  occs[:, self.T_past],
        #                                  # torch.max(occs, dim=1)[0],
        #                                  # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
        #                                  already_mem=False)
        o = []
        for n in list(range(N)):
            o.append(utils_improc.preprocess_color(summ_writer.summ_traj_on_occ(
                '', sampled_trajs_cam[0, n].unsqueeze(0), occs[0:1, self.T_past],
                vox_util, only_return=True, sigma=0.5)))
        summ_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
        summ_writer.summ_rgb('motioncost/all_sampled_trajs', summ_vis)

    # smooth loss
    cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
    dz, dx = gradient2D(cost_maps_, absolute=True)
    dt = torch.abs(cost_maps[:, 1:] - cost_maps[:, 0:-1])
    smooth_spatial = torch.mean(dx + dz, dim=1, keepdim=True)
    smooth_time = torch.mean(dt, dim=1, keepdim=True)
    summ_writer.summ_oned('motioncost/smooth_loss_spatial', smooth_spatial)
    summ_writer.summ_oned('motioncost/smooth_loss_time', smooth_time)
    smooth_loss = torch.mean(smooth_spatial) + torch.mean(smooth_time)
    total_loss = utils_misc.add_loss('motioncost/smooth_loss', total_loss, smooth_loss,
                                     hyp.motioncost_smooth_coeff, summ_writer)

    # def clamp_xyz(xyz, X, Y, Z):
    #     x, y, z = torch.unbind(xyz, dim=-1)
    #     x = x.clamp(0, X)
    #     y = y.clamp(0, Y)
    #     z = z.clamp(0, Z)
    #     # if zero_y:
    #     #     y = torch.zeros_like(y)
    #     xyz = torch.stack([x, y, z], dim=-1)
    #     return xyz

    def clamp_xz(xz, X, Z):
        # clamp each coordinate to its own grid extent
        x, z = torch.unbind(xz, dim=-1)
        x = x.clamp(0, X)
        z = z.clamp(0, Z)
        xz = torch.stack([x, z], dim=-1)
        return xz

    clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
    # this is B x S x 3
    # sampled_trajs_cam is B x N x S x 3
    sampled_trajs_cam_ = sampled_trajs_cam.reshape(B, N * S, 3)
    sampled_trajs_mem_ = utils_vox.Ref2Mem(sampled_trajs_cam_, Z, Y, X)
    sampled_trajs_mem = sampled_trajs_mem_.reshape(B, N, S, 3)
    # this is B x N x S x 3
    xyz_pos_ = clist_mem[:, self.T_past:].reshape(B * self.T_futu, 1, 3)
    xyz_neg_ = sampled_trajs_mem[:, :, self.T_past:].permute(0, 2, 1, 3).reshape(B * self.T_futu, N, 3)
    # get rid of y
    xz_pos_ = torch.stack([xyz_pos_[:, :, 0], xyz_pos_[:, :, 2]], dim=2)
    xz_neg_ = torch.stack([xyz_neg_[:, :, 0], xyz_neg_[:, :, 2]], dim=2)
    xz_ = torch.cat([xz_pos_, xz_neg_], dim=1)
    xz_ = clamp_xz(xz_, X, Z)
    cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
    cost_ = utils_samp.bilinear_sample2D(cost_maps_, xz_[:, :, 0], xz_[:, :, 1]).squeeze(1)
    # cost_ is B*T_futu x 1+N
    cost_pos = cost_[:, 0:1]  # B*T_futu x 1
    cost_neg = cost_[:, 1:]   # B*T_futu x N
    cost_pos = cost_pos.unsqueeze(2)  # B*T_futu x 1 x 1
    cost_neg = cost_neg.unsqueeze(1)  # B*T_futu x 1 x N
    utils_misc.add_loss('motioncost/mean_cost_pos', 0, torch.mean(cost_pos), 0, summ_writer)
    utils_misc.add_loss('motioncost/mean_cost_neg', 0, torch.mean(cost_neg), 0, summ_writer)
    utils_misc.add_loss('motioncost/mean_margin', 0, torch.mean(cost_neg - cost_pos), 0, summ_writer)
    xz_pos = xz_pos_.unsqueeze(2)  # B*T_futu x 1 x 1 x 2
    xz_neg = xz_neg_.unsqueeze(1)  # B*T_futu x 1 x N x 2
    dist = torch.norm(xz_pos - xz_neg, dim=3)  # B*T_futu x 1 x N
    dist = dist / float(Z) * 5.0  # normalize for resolution, but upweight it a bit
    margin = F.relu(cost_pos - cost_neg + dist)
    margin = margin.reshape(B, self.T_futu, N)
    # mean over time (in the paper this is a sum)
    margin = torch.mean(margin, dim=1)
    # max over the negatives
    maxmargin = torch.max(margin, dim=1)[0]  # B
    maxmargin_loss = torch.mean(maxmargin)
    total_loss = utils_misc.add_loss('motioncost/maxmargin_loss', total_loss, maxmargin_loss,
                                     hyp.motioncost_maxmargin_coeff, summ_writer)

    # now let's see some top k
    # we'll do this for the first el of the batch
    cost_neg = cost_neg.reshape(B, self.T_futu, N)[0].detach().cpu().numpy()
    futu_mem = sampled_trajs_mem[:, :, self.T_past:].reshape(B, N, self.T_futu, 3)[0:1]
    cost_neg = np.reshape(cost_neg, [self.T_futu, N])
    cost_neg = np.sum(cost_neg, axis=0)
    inds = np.argsort(cost_neg, axis=0)
    for n in list(range(2)):
        xyzlist_e_mem = futu_mem[0:1, inds[n]]
        xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
        # this is B x S x 3
        if summ_writer.save_this and n == 0:
            print('xyzlist_e_cam', xyzlist_e_cam[0:1])
            print('xyzlist_g_cam', clist_cam[0:1, self.T_past:])
        dist = torch.norm(clist_cam[0:1, self.T_past:] - xyzlist_e_cam[0:1], dim=2)
        # this is B x T_futu
        meandist = torch.mean(dist)
        utils_misc.add_loss('motioncost/xyz_dist_%d' % n, 0, meandist, 0, summ_writer)
    if summ_writer.save_this:
        # plot the best and worst trajs
        # print('sorted costs:', cost_neg[inds])
        for n in list(range(2)):
            ind = inds[n]
            print('plotting good traj with cost %.2f' % (cost_neg[ind]))
            xyzlist_e_mem = sampled_trajs_mem[:, ind]
            # this is 1 x S x 3
            summ_writer.summ_traj_on_occ('motioncost/best_sampled_traj%d' % n,
                                         xyzlist_e_mem[0:1],
                                         occs[0:1, self.T_past],
                                         vox_util,
                                         already_mem=True,
                                         sigma=2)
        for n in list(range(2)):
            ind = inds[-(n + 1)]
            print('plotting bad traj with cost %.2f' % (cost_neg[ind]))
            xyzlist_e_mem = sampled_trajs_mem[:, ind]
            # this is 1 x S x 3
            summ_writer.summ_traj_on_occ('motioncost/worst_sampled_traj%d' % n,
                                         xyzlist_e_mem[0:1],
                                         occs[0:1, self.T_past],
                                         vox_util,
                                         already_mem=True,
                                         sigma=2)

    # xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
    # xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
    # summ_writer.summ_traj_on_occ('motioncost/traj_e',
    #                              xyzlist_e_mem,
    #                              torch.max(occs, dim=1)[0],
    #                              already_mem=True,
    #                              sigma=2)
    # summ_writer.summ_traj_on_occ('motioncost/traj_g',
    #                              xyzlist_g_mem,
    #                              torch.max(occs, dim=1)[0],
    #                              already_mem=True,
    #                              sigma=2)
    # scorelist_here = scorelist[:, self.num_given:, 0]
    # sql2 = torch.sum((vel_g - vel_e)**2, dim=2)
    # ## yes weightmask
    # weightmask = torch.arange(0, self.num_need, dtype=torch.float32, device=torch.device('cuda'))
    # weightmask = torch.exp(-weightmask**(1./4))
    # # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
    # # 0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
    # weightmask = weightmask.reshape(1, self.num_need)
    # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask)
    # utils_misc.add_loss('motioncost/l2_loss', 0, l2_loss, 0, summ_writer)
    # # no weightmask:
    # # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)
    # # total_loss = utils_misc.add_loss('motioncost/l2_loss', total_loss, l2_loss, hyp.motioncost_l2_coeff, summ_writer)
    # dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
    # meandist = utils_basic.reduce_masked_mean(dist, scorelist[:, :, 0])
    # utils_misc.add_loss('motioncost/xyz_dist_0', 0, meandist, 0, summ_writer)
    # l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
    # # utils_misc.add_loss('motioncost/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
    # total_loss = utils_misc.add_loss('motioncost/l2_loss_noexp', total_loss, l2_loss_noexp, hyp.motioncost_l2_coeff, summ_writer)
    return total_loss
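# The core of the loss above is a hinge between the cost of the GT future
# trajectory and the costs of N sampled negatives, with a distance-dependent
# margin. A self-contained sketch with dummy tensors (illustrative only):

import torch
import torch.nn.functional as F

def _maxmargin_sketch(B=2, T=4, N=8):
    cost_pos = torch.randn(B * T, 1, 1)  # cost at the GT waypoint
    cost_neg = torch.randn(B * T, 1, N)  # costs at the sampled negatives
    dist = torch.rand(B * T, 1, N)       # per-negative margin (distance-scaled)
    margin = F.relu(cost_pos - cost_neg + dist)   # violated unless GT is cheaper by >= dist
    margin = margin.reshape(B, T, N).mean(dim=1)  # mean over future timesteps
    maxmargin = margin.max(dim=1)[0]              # hardest negative per batch el
    return maxmargin.mean()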
def run_test(self, feed):
    results = dict()
    global_step = feed['global_step']
    total_loss = torch.tensor(0.0).cuda()
    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)
    self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(self.lrt_camX0s)
    self.original_centroid = self.scene_centroid.clone()
    obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
    obj_length = obj_lengths[:, 0]
    for b in list(range(self.B)):
        if self.score_s[b, 0] < 1.0:
            # we need the template to exist
            print('returning early, since score_s[%d,0] = %.1f' % (b, self.score_s[b, 0].cpu().numpy()))
            return total_loss, results, True
        # if torch.sum(self.score_s[b]) < (self.S/2):
        if not (torch.sum(self.score_s[b]) == self.S):
            # the full traj should be valid
            print('returning early, since sum(score_s) = %d, while S = %d' % (
                torch.sum(self.score_s).cpu().numpy(), self.S))
            return total_loss, results, True

    if hyp.do_feat3D:
        feat_memX0_input = torch.cat([
            self.occ_memX0s[:, 0],
            self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
        ], dim=1)
        _, feat_memX0, valid_memX0 = self.featnet3D(feat_memX0_input)
        B, C, Z, Y, X = list(feat_memX0.shape)
        S = self.S
        obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
            self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
        # only take the occupied voxels
        occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y, X)
        # obj_mask_memX0 = obj_mask_memX0s[:, 0] * occ_memX0
        obj_mask_memX0 = obj_mask_memX0s[:, 0]
        # discard the known freespace
        _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs[:, 0:1], self.xyz_camXs[:, 0:1], Z, Y, X, agg=True)
        free_memX0 = free_memX0_.squeeze(1)
        obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)
        for b in list(range(self.B)):
            if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                print('returning early, since there are not enough valid object points')
                return total_loss, results, True
        # for b in list(range(self.B)):
        #     sum_b = torch.sum(obj_mask_memX0[b])
        #     print('sum_b', sum_b.detach().cpu().numpy())
        #     if sum_b > 1000:
        #         obj_mask_memX0[b] *= occ_memX0[b]
        #         sum_b = torch.sum(obj_mask_memX0[b])
        #         print('reducing this to', sum_b.detach().cpu().numpy())

        feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
        # this is B x C x huge
        feat0_vec = feat0_vec.permute(0, 2, 1)
        # this is B x huge x C
        obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
        occ_mask0_vec = occ_memX0.reshape(B, -1).round()
        free_mask0_vec = free_memX0.reshape(B, -1).round()
        # these are B x huge
        orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
        # this is B x huge x 3
        obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
        obj_length = obj_lengths[:, 0]
        cam0_T_obj = cams_T_obj0[:, 0]
        # this is B x S x 4 x 4
        mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
        cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)
        lrt_camIs_g = self.lrt_camX0s.clone()
        lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
        # we will fill this up
        ious = torch.zeros([B, S]).float().cuda()
        point_counts = np.zeros([B, S])
        inb_counts = np.zeros([B, S])
        feat_vis = []
        occ_vis = []

        for s in range(self.S):
            if not (s == 0):
                # remake the vox util and all the mem data
                self.scene_centroid = utils_geom.get_clist_from_lrtlist(lrt_camIs_e[:, s - 1:s])[:, 0]
                delta = self.scene_centroid - self.original_centroid
                self.vox_util = vox_util.Vox_util(
                    self.Z, self.Y, self.X, self.set_name,
                    scene_centroid=self.scene_centroid, assert_cube=True)
                self.occ_memXs = __u(self.vox_util.voxelize_xyz(__p(self.xyz_camXs),
                                                                self.Z, self.Y, self.X))
                self.occ_memX0s = __u(self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                                 self.Z, self.Y, self.X))
                self.unp_memXs = __u(self.vox_util.unproject_rgb_to_mem(
                    __p(self.rgb_camXs), self.Z, self.Y, self.X, __p(self.pix_T_cams)))
                self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(self.camX0s_T_camXs, self.unp_memXs)
                self.summ_writer.summ_occ('track/reloc_occ_%d' % s, self.occ_memX0s[:, s])
            else:
                self.summ_writer.summ_occ('track/init_occ_%d' % s, self.occ_memX0s[:, s])
                delta = torch.zeros([B, 3]).float().cuda()
            # print('scene centroid:', self.scene_centroid.detach().cpu().numpy())
            occ_vis.append(self.summ_writer.summ_occ('', self.occ_memX0s[:, s], only_return=True))

            # inb = __u(self.vox_util.get_inbounds(__p(self.xyz_camX0s), self.Z4, self.Y4, self.X, already_mem=False))
            inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s], self.Z4, self.Y4, self.X,
                                             already_mem=False)
            num_inb = torch.sum(inb.float(), axis=1)
            # print('num_inb', num_inb, num_inb.shape)
            inb_counts[:, s] = num_inb.cpu().numpy()

            feat_memI_input = torch.cat([
                self.occ_memX0s[:, s],
                self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
            ], dim=1)
            _, feat_memI, valid_memI = self.featnet3D(feat_memI_input)
            self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s, feat_memI_input, pca=True)
            self.summ_writer.summ_feat('3D_feats/feat_%d' % s, feat_memI, pca=True)
            feat_vis.append(self.summ_writer.summ_feat('', feat_memI, pca=True, only_return=True))

            # collect freespace here, to discard bad matches
            _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, s:s + 1], self.xyz_camXs[:, s:s + 1], Z, Y, X, agg=True)
            free_memI = free_memI_.squeeze(1)

            feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat_vec = feat_vec.permute(0, 2, 1)
            # this is B x huge x C

            memI_T_mem0 = utils_geom.eye_4x4(B)
            # we will fill this up

            # # put these on cpu, to save mem
            # feat0_vec = feat0_vec.detach().cpu()
            # feat_vec = feat_vec.detach().cpu()

            # to simplify the impl, we will iterate over the batch dim
            for b in list(range(B)):
                feat_vec_b = feat_vec[b]
                feat0_vec_b = feat0_vec[b]
                obj_mask0_vec_b = obj_mask0_vec[b]
                occ_mask0_vec_b = occ_mask0_vec[b]
                free_mask0_vec_b = free_mask0_vec[b]
                orig_xyz_b = orig_xyz[b]
                # these are huge x C

                careful = False
                if careful:
                    # start with occ points, since these are definitely observed
                    obj_inds_b = torch.where((occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # these are med x C
                    # also take random non-free non-occ points in the mask
                    ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (1.0 - free_mask0_vec_b)
                    alt_inds_b = torch.where(ok_mask > 0)
                    alt_vec_b = feat0_vec_b[alt_inds_b]
                    alt_xyz0 = orig_xyz_b[alt_inds_b]
                    # these are med x C
                    # issues arise when "med" is too large
                    num = len(alt_xyz0)
                    max_pts = 2000
                    if num > max_pts:
                        # print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                        perm = np.random.permutation(num)
                        alt_vec_b = alt_vec_b[perm[:max_pts]]
                        alt_xyz0 = alt_xyz0[perm[:max_pts]]
                    obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                    xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                    if s == 0:
                        print('have %d pts in total' % (len(xyz0)))
                else:
                    # take any points within the mask
                    obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # these are med x C
                    # issues arise when "med" is too large; trim down to max_pts
                    num = len(xyz0)
                    max_pts = 2000
                    if num > max_pts:
                        print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                        perm = np.random.permutation(num)
                        obj_vec_b = obj_vec_b[perm[:max_pts]]
                        xyz0 = xyz0[perm[:max_pts]]

                obj_vec_b = obj_vec_b.permute(1, 0)
                # this is C x med
                corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                # this is huge x med
                heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                # this is med x 1 x Z x Y x X

                # # for numerical stability, we sub the max, and mult by the resolution
                # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                # heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                # heat_b = heat_b - heat_b_max
                # heat_b = heat_b * float(len(heat_b[0].reshape(-1)))
                # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                # # heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                # heat_b_min = (torch.min(heat_b_).values)
                # free_b = free_memI[b:b+1]
                # print('free_b', free_b.shape)
                # print('heat_b', heat_b.shape)
                # heat_b[free_b > 0.0] = heat_b_min

                # make the min zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_min
                # zero out the freespace
                heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                # make the max zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_max
                # scale up, for numerical stability
                heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                # xyzI = utils_basic.argmax3D(heat_b*float(Z*10), hard=False, stack=True)
                # this is med x 3
                xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y, X)
                xyzI_cam += delta
                xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)
                # record #points, since ransac depends on this
                point_counts[b, s] = len(xyz0)
            # done stepping through batch

            mem0_T_memI = utils_geom.safe_inverse(memI_T_mem0)
            cam0_T_camI = utils_basic.matmul3(cam_T_mem, mem0_T_memI, mem_T_cam)

            # eval
            camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0, mem_T_cam, cam0_T_obj)
            # this is B x 4 x 4
            lrt_camIs_e[:, s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:, s:s + 1]).squeeze(1)
        results['ious'] = ious
        # if ious[0, -1] > 0.5:
        #     print('returning early, since acc is too high')
        #     return total_loss, results, True

        self.summ_writer.summ_rgbs('track/feats', feat_vis)
        self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)
        for s in range(self.S):
            self.summ_writer.summ_scalar('track/mean_iou_%02d' % s,
                                         torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())
        self.summ_writer.summ_scalar('track/point_counts', np.mean(point_counts))
        # self.summ_writer.summ_scalar('track/inb_counts', torch.mean(inb_counts).cpu().item())
        self.summ_writer.summ_scalar('track/inb_counts', np.mean(inb_counts))

        lrt_camX0s_e = lrt_camIs_e.clone()
        lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(self.camXs_T_camX0s, lrt_camX0s_e)

        if self.include_vis:
            visX_e = []
            for s in list(range(self.S)):
                visX_e.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_e' % s, self.rgb_camXs[:, s],
                    lrt_camXs_e[:, s:s + 1], self.score_s[:, s:s + 1],
                    self.tid_s[:, s:s + 1], self.pix_T_cams[:, 0], only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
            visX_g = []
            for s in list(range(self.S)):
                visX_g.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_g' % s, self.rgb_camXs[:, s],
                    self.lrt_camXs[:, s:s + 1], self.score_s[:, s:s + 1],
                    self.tid_s[:, s:s + 1], self.pix_T_cams[:, 0], only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
        dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
        # this is B x S
        mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
        median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
        # this is []
        self.summ_writer.summ_scalar('track/centroid_dist_mean', mean_dist.cpu().item())
        self.summ_writer.summ_scalar('track/centroid_dist_median', median_dist.cpu().item())

        # if self.include_vis:
        if True:
            self.summ_writer.summ_traj_on_occ('track/traj_e', obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0], self.vox_util,
                                              already_mem=False, sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g', self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0], self.vox_util,
                                              already_mem=False, sigma=2)
        total_loss += mean_dist
        # we won't backprop, but it's nice to plot and print this anyway
    else:
        ious = torch.zeros([self.B, self.S]).float().cuda()
        for s in list(range(self.S)):
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                self.lrt_camX0s[:, 0:1], self.lrt_camX0s[:, s:s + 1]).squeeze(1)
        results['ious'] = ious
        for s in range(self.S):
            self.summ_writer.summ_scalar('track/mean_iou_%02d' % s,
                                         torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())
        lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
        self.summ_writer.summ_traj_on_occ('track/traj_e', obj_clist_camX0_e,
                                          self.occ_memX0s[:, 0], self.vox_util,
                                          already_mem=False, sigma=2)
        self.summ_writer.summ_traj_on_occ('track/traj_g', self.obj_clist_camX0,
                                          self.occ_memX0s[:, 0], self.vox_util,
                                          already_mem=False, sigma=2)

    self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
    return total_loss, results, False
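# The tracker above localizes correspondences with a soft argmax over a 3D
# correlation volume (utils_basic.argmax3D with hard=False). A sketch under the
# assumption that this computes a softmax-weighted expected coordinate
# (hypothetical reimplementation, for illustration only):

import torch
import torch.nn.functional as F

def _soft_argmax3d_sketch(heat):
    # heat: B x 1 x Z x Y x X -> B x 3 soft coordinates (z, y, x order here)
    B, _, Z, Y, X = heat.shape
    prob = F.softmax(heat.reshape(B, -1), dim=1).reshape(B, Z, Y, X)
    gz = torch.arange(Z, dtype=torch.float32).reshape(1, Z, 1, 1)
    gy = torch.arange(Y, dtype=torch.float32).reshape(1, 1, Y, 1)
    gx = torch.arange(X, dtype=torch.float32).reshape(1, 1, 1, X)
    z = (prob * gz).sum(dim=[1, 2, 3])
    y = (prob * gy).sum(dim=[1, 2, 3])
    x = (prob * gx).sum(dim=[1, 2, 3])
    return torch.stack([z, y, x], dim=1)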
def run_train(self, feed):
    results = dict()
    global_step = feed['global_step']
    total_loss = torch.tensor(0.0).cuda()
    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

    if hyp.do_feat3D:
        feat_memX0s_input = torch.cat([
            self.occ_memX0s,
            self.unp_memX0s * self.occ_memX0s,
        ], dim=2)
        feat3D_loss, feat_memX0s_, valid_memX0s_ = self.featnet3D(
            __p(feat_memX0s_input[:, 1:]),
            self.summ_writer,
        )
        feat_memX0s = __u(feat_memX0s_)
        valid_memX0s = __u(valid_memX0s_)
        total_loss += feat3D_loss
        feat_memX0 = utils_basic.reduce_masked_mean(
            feat_memX0s,
            valid_memX0s.repeat(1, 1, hyp.feat3D_dim, 1, 1, 1),
            dim=1)
        valid_memX0 = torch.sum(valid_memX0s, dim=1).clamp(0, 1)
        self.summ_writer.summ_feat('3D_feats/feat_memX0', feat_memX0,
                                   valid=valid_memX0, pca=True)
        self.summ_writer.summ_feat('3D_feats/valid_memX0', valid_memX0, pca=False)
        if hyp.do_emb3D:
            _, altfeat_memX0, altvalid_memX0 = self.featnet3D_slow(feat_memX0s_input[:, 0])
            self.summ_writer.summ_feat('3D_feats/altfeat_memX0', altfeat_memX0,
                                       valid=altvalid_memX0, pca=True)
            self.summ_writer.summ_feat('3D_feats/altvalid_memX0', altvalid_memX0, pca=False)

    if hyp.do_emb3D:
        if hyp.do_feat3D:
            _, _, Z_, Y_, X_ = list(feat_memX0.shape)
        else:
            Z_, Y_, X_ = self.Z2, self.Y2, self.X2
            # Z_, Y_, X_ = self.Z, self.Y, self.X
        occ_memX0s, free_memX0s, _, _ = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs, self.xyz_camXs, Z_, Y_, X_, agg=False)
        not_ok = torch.zeros_like(occ_memX0s[:, 0])
        # it's not ok for a voxel to be marked occ only once
        not_ok += (torch.sum(occ_memX0s, dim=1) == 1.0).float()
        # it's not ok for a voxel to be marked occ AND free
        occ_agg = torch.sum(occ_memX0s, dim=1).clamp(0, 1)
        free_agg = torch.sum(free_memX0s, dim=1).clamp(0, 1)
        have_either = (occ_agg + free_agg).clamp(0, 1)
        have_both = occ_agg * free_agg
        not_ok += have_either * have_both
        # it's not ok for a voxel to be totally unobserved
        not_ok += (have_either == 0.0).float()
        not_ok = not_ok.clamp(0, 1)
        self.summ_writer.summ_occ('rely/not_ok', not_ok)
        self.summ_writer.summ_occ(
            'rely/not_ok_occ', not_ok * torch.max(self.occ_memX0s_half, dim=1)[0])
        self.summ_writer.summ_occ(
            'rely/ok_occ', (1.0 - not_ok) * torch.max(self.occ_memX0s_half, dim=1)[0])
        self.summ_writer.summ_occ(
            'rely/aggressive_occ', torch.max(self.occ_memX0s_half, dim=1)[0])
        be_safe = False
        if hyp.do_feat3D and be_safe:
            # update the valid masks
            valid_memX0 = valid_memX0 * (1.0 - not_ok)
            altvalid_memX0 = altvalid_memX0 * (1.0 - not_ok)

    if hyp.do_occ:
        _, _, Z_, Y_, X_ = list(feat_memX0.shape)
        occ_memX0_sup, free_memX0_sup, _, free_memX0s = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs, self.xyz_camXs, Z_, Y_, X_, agg=True)
        self.summ_writer.summ_occ('occ_sup/occ_sup', occ_memX0_sup)
        self.summ_writer.summ_occ('occ_sup/free_sup', free_memX0_sup)
        self.summ_writer.summ_occs('occ_sup/freeX0s_sup', torch.unbind(free_memX0s, dim=1))
        self.summ_writer.summ_occs('occ_sup/occX0s_sup',
                                   torch.unbind(self.occ_memX0s_half, dim=1))
        occ_loss, occ_memX0_pred = self.occnet(altfeat_memX0, occ_memX0_sup,
                                               free_memX0_sup, altvalid_memX0,
                                               self.summ_writer)
        total_loss += occ_loss

    if hyp.do_emb3D:
        # compute 3D ML
        emb_loss_3D = self.embnet3D(feat_memX0, altfeat_memX0,
                                    valid_memX0.round(), altvalid_memX0.round(),
                                    self.summ_writer)
        total_loss += emb_loss_3D

    self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
    return total_loss, results, False
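# run_train marks a half-resolution voxel unreliable if it is (a) marked
# occupied in exactly one view, (b) marked both occupied and free across views,
# or (c) never observed at all. The same logic, isolated on dummy tensors
# (illustrative sketch only):

import torch

def _not_ok_mask_sketch(occs, frees):
    # occs, frees: B x S x 1 x Z x Y x X binary per-view supervision
    not_ok = (occs.sum(dim=1) == 1.0).float()       # occupied only once
    occ_agg = occs.sum(dim=1).clamp(0, 1)
    free_agg = frees.sum(dim=1).clamp(0, 1)
    have_either = (occ_agg + free_agg).clamp(0, 1)
    have_both = occ_agg * free_agg
    not_ok = not_ok + have_either * have_both       # occupied AND free
    not_ok = not_ok + (have_either == 0.0).float()  # totally unobserved
    return not_ok.clamp(0, 1)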
def prepare_common_tensors(self, feed, prep_summ=True):
    results = dict()
    if prep_summ:
        self.summ_writer = utils_improc.Summ_writer(
            writer=feed['writer'],
            global_step=feed['global_step'],
            log_freq=feed['set_log_freq'],
            fps=8,
            just_gif=feed['just_gif'],
        )
    else:
        self.summ_writer = None
    self.include_vis = hyp.do_include_vis
    self.B = feed["set_batch_size"]
    self.S = feed["set_seqlen"]
    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)
    self.H, self.W, self.V, self.N = hyp.H, hyp.W, hyp.V, hyp.N
    self.PH, self.PW = hyp.PH, hyp.PW
    self.K = hyp.K
    self.set_name = feed['set_name']
    # print('set_name', self.set_name)
    if self.set_name == 'test':
        self.Z, self.Y, self.X = hyp.Z_test, hyp.Y_test, hyp.X_test
    else:
        self.Z, self.Y, self.X = hyp.Z, hyp.Y, hyp.X
    # print('Z, Y, X = %d, %d, %d' % (self.Z, self.Y, self.X))
    self.Z2, self.Y2, self.X2 = int(self.Z / 2), int(self.Y / 2), int(self.X / 2)
    self.Z4, self.Y4, self.X4 = int(self.Z / 4), int(self.Y / 4), int(self.X / 4)

    self.rgb_camXs = feed["rgb_camXs"]
    self.pix_T_cams = feed["pix_T_cams"]
    self.origin_T_camXs = feed["origin_T_camXs"]
    self.cams_T_velos = feed["cams_T_velos"]
    self.camX0s_T_camXs = utils_geom.get_camM_T_camXs(self.origin_T_camXs, ind=0)
    self.camXs_T_camX0s = __u(utils_geom.safe_inverse(__p(self.camX0s_T_camXs)))
    self.xyz_veloXs = feed["xyz_veloXs"]
    self.xyz_camXs = __u(utils_geom.apply_4x4(__p(self.cams_T_velos), __p(self.xyz_veloXs)))
    self.xyz_camX0s = __u(utils_geom.apply_4x4(__p(self.camX0s_T_camXs), __p(self.xyz_camXs)))

    if self.set_name == 'test':
        self.boxlist_camXs = feed["boxlists"]
        self.scorelist_s = feed["scorelists"]
        self.tidlist_s = feed["tidlists"]
        boxlist_camXs_ = __p(self.boxlist_camXs)
        scorelist_s_ = __p(self.scorelist_s)
        tidlist_s_ = __p(self.tidlist_s)
        boxlist_camXs_, tidlist_s_, scorelist_s_ = utils_misc.shuffle_valid_and_sink_invalid_boxes(
            boxlist_camXs_, tidlist_s_, scorelist_s_)
        self.boxlist_camXs = __u(boxlist_camXs_)
        self.scorelist_s = __u(scorelist_s_)
        self.tidlist_s = __u(tidlist_s_)
        # self.boxlist_camXs[:,0], self.scorelist_s[:,0], self.tidlist_s[:,0] = utils_misc.shuffle_valid_and_sink_invalid_boxes(
        #     self.boxlist_camXs[:,0], self.tidlist_s[:,0], self.scorelist_s[:,0])
        # self.score_s = feed["scorelists"]
        # self.tid_s = torch.ones_like(self.score_s).long()
        # self.lrt_camRs = utils_geom.convert_boxlist_to_lrtlist(self.box_camRs)
        # self.lrt_camXs = utils_geom.apply_4x4s_to_lrts(self.camXs_T_camRs, self.lrt_camRs)
        # self.lrt_camX0s = utils_geom.apply_4x4s_to_lrts(self.camX0s_T_camXs, self.lrt_camXs)
        # self.lrt_camR0s = utils_geom.apply_4x4s_to_lrts(self.camR0s_T_camRs, self.lrt_camRs)
        # boxlist_camXs_ = __p(self.boxlist_camXs)
        # lrtlist_camXs = __u(utils_geom.convert_boxlist_to_lrtlist(__p(self.boxlist_camXs))).reshape(
        #     self.B, self.S, self.N, 19)
        self.lrtlist_camXs = __u(utils_geom.convert_boxlist_to_lrtlist(__p(self.boxlist_camXs)))
        # print('lrtlist_camXs', lrtlist_camXs.shape)
        # # lrtlist_camXs = __u(utils_geom.apply_4x4_to_lrtlist(__p(camXs_T_camRs), __p(lrtlist_camRs)))
        # self.summ_writer.summ_lrtlist('2D_inputs/lrtlist_camX0', self.rgb_camXs[:,0], lrtlist_camXs[:,0],
        #                               self.scorelist_s[:,0], self.tidlist_s[:,0], self.pix_T_cams[:,0])
        # self.summ_writer.summ_lrtlist('2D_inputs/lrtlist_camX1', self.rgb_camXs[:,1], lrtlist_camXs[:,1],
        #                               self.scorelist_s[:,1], self.tidlist_s[:,1], self.pix_T_cams[:,1])
        (
            self.lrt_camXs,
            self.box_camXs,
            self.score_s,
        ) = utils_misc.collect_object_info(
            self.lrtlist_camXs, self.boxlist_camXs, self.tidlist_s, self.scorelist_s,
            1, mod='X', do_vis=False, summ_writer=None)
        self.lrt_camXs = self.lrt_camXs.squeeze(0)
        self.score_s = self.score_s.squeeze(0)
        self.tid_s = torch.ones_like(self.score_s).long()
        self.lrt_camX0s = utils_geom.apply_4x4s_to_lrts(self.camX0s_T_camXs, self.lrt_camXs)

        if prep_summ and self.include_vis:
            visX_g = []
            for s in list(range(self.S)):
                visX_g.append(self.summ_writer.summ_lrtlist(
                    '', self.rgb_camXs[:, s], self.lrtlist_camXs[:, s],
                    self.scorelist_s[:, s], self.tidlist_s[:, s],
                    self.pix_T_cams[:, 0], only_return=True))
            self.summ_writer.summ_rgbs('2D_inputs/box_camXs', visX_g)
        # visX_g = []
        # for s in list(range(self.S)):
        #     visX_g.append(self.summ_writer.summ_lrtlist(
        #         'track/box_camX%d_g' % s, self.rgb_camXs[:,s], self.lrt_camXs[:,s:s+1],
        #         self.score_s[:,s:s+1], self.tid_s[:,s:s+1], self.pix_T_cams[:,0], only_return=True))
        # self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

    if self.set_name == 'test':
        # center on an object, so that it does not fall out of bounds
        self.scene_centroid = utils_geom.get_clist_from_lrtlist(self.lrt_camXs)[:, 0]
        self.vox_util = vox_util.Vox_util(
            self.Z, self.Y, self.X, self.set_name,
            scene_centroid=self.scene_centroid, assert_cube=True)
    else:
        # center randomly
        scene_centroid_x = np.random.uniform(-8.0, 8.0)
        scene_centroid_y = np.random.uniform(-1.5, 3.0)
        scene_centroid_z = np.random.uniform(10.0, 26.0)
        scene_centroid = np.array([scene_centroid_x, scene_centroid_y,
                                   scene_centroid_z]).reshape([1, 3])
        self.scene_centroid = torch.from_numpy(scene_centroid).float().cuda()
        # center on a random non-outlier point
        all_ok = False
        num_tries = 0
        while not all_ok:
            scene_centroid_x = np.random.uniform(-8.0, 8.0)
            scene_centroid_y = np.random.uniform(-1.5, 3.0)
            scene_centroid_z = np.random.uniform(10.0, 26.0)
            scene_centroid = np.array([scene_centroid_x, scene_centroid_y,
                                       scene_centroid_z]).reshape([1, 3])
            self.scene_centroid = torch.from_numpy(scene_centroid).float().cuda()
            num_tries += 1
            # try to vox
            self.vox_util = vox_util.Vox_util(
                self.Z, self.Y, self.X, self.set_name,
                scene_centroid=self.scene_centroid, assert_cube=True)
            all_ok = True
            # we want to ensure this gives us a few points inbound for each batch el
            inb = __u(self.vox_util.get_inbounds(__p(self.xyz_camX0s),
                                                 self.Z4, self.Y4, self.X,
                                                 already_mem=False))
            num_inb = torch.sum(inb.float(), axis=2)
            if torch.min(num_inb) < 100:
                all_ok = False
            if num_tries > 100:
                return False
        self.summ_writer.summ_scalar('zoom_sampling/num_tries', num_tries)
        self.summ_writer.summ_scalar('zoom_sampling/num_inb', torch.mean(num_inb).cpu().item())

    self.occ_memXs = __u(self.vox_util.voxelize_xyz(__p(self.xyz_camXs),
                                                    self.Z, self.Y, self.X))
    self.occ_memX0s = __u(self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                     self.Z, self.Y, self.X))
    self.occ_memX0s_half = __u(self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                          self.Z2, self.Y2, self.X2))
    self.unp_memXs = __u(self.vox_util.unproject_rgb_to_mem(
        __p(self.rgb_camXs), self.Z, self.Y, self.X, __p(self.pix_T_cams)))
    self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(self.camX0s_T_camXs, self.unp_memXs)

    if prep_summ and self.include_vis:
        self.summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(self.rgb_camXs, dim=1))
        self.summ_writer.summ_occs('3D_inputs/occ_memXs', torch.unbind(self.occ_memXs, dim=1))
        self.summ_writer.summ_occs('3D_inputs/occ_memX0s', torch.unbind(self.occ_memX0s, dim=1))
        self.summ_writer.summ_rgb('2D_inputs/rgb_camX0', self.rgb_camXs[:, 0])
        # self.summ_writer.summ_oned('2D_inputs/depth_camX0', self.depth_camXs[:,0], maxval=20.0)
        # self.summ_writer.summ_oned('2D_inputs/valid_camX0', self.valid_camXs[:,0], norm=False)
    return True
def forward(self, feat, obj_lrtlist_cams, obj_scorelist_s, summ_writer, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat.shape)
    N, B2, S, D = list(obj_lrtlist_cams.shape)
    assert B == B2
    # obj_scorelist_s is N x B x S
    obj_lrtlist_cams_ = obj_lrtlist_cams.reshape(N * B, S, 19)
    obj_clist_cam_ = utils_geom.get_clist_from_lrtlist(obj_lrtlist_cams_)
    obj_clist_cam = obj_clist_cam_.reshape(N, B, S, 1, 3)
    # obj_clist_cam is N x B x S x 1 x 3
    obj_clist_cam = obj_clist_cam.squeeze(3)
    # # obj_clist_cam is N x B x S x 3
    # clist_cam = obj_clist_cam.reshape(N*B, S, 3)
    # clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
    # # this is N*B x S x 3
    # clist_mem = clist_mem.reshape(N, B, S, 3)

    # as with prinet, let's do this for a single object first
    traj_past = obj_clist_cam[0, :, :2]
    traj_futu = obj_clist_cam[0, :, 2:]
    T_past = 2
    T_futu = S - 2
    # traj_past is B x T_past x 3
    # traj_futu is B x T_futu x 3
    print('traj_past', traj_past.shape)
    print('traj_futu', traj_futu.shape)

    feat_map = self.compressor(feat)
    pred_map = self.conv3d(feat)
    # these are B x C x Z x Y x X

    # each component of the noise is IID Normal
    # get K samples
    K = 5  # number of samples
    traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)
    feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1, 1)
    pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1, 1)
    # to sample the K trajectories in parallel, we'll pack K onto the batch dim
    __p = lambda x: utils_basic.pack_seqdim(x, K)
    __u = lambda x: utils_basic.unpack_seqdim(x, K)
    traj_past_ = __p(traj_past)
    feat_map_ = __p(feat_map)
    pred_map_ = __p(pred_map)
    base_sample_ = torch.randn(K * B, T_futu, 3).cuda()
    traj_futu_e_ = self.sample_forward(feat_map_, pred_map_, base_sample_, traj_past_)
    # traj_futu_e = __u(traj_futu_e_)
    # # this is K x B x T x 3
    # print('traj_futu_e', traj_futu_e)
    # print(traj_futu_e.shape)

    # energy_vol = self.conv3d(feat)
    # # energy_vol is B x 1 x Z x Y x X
    # summ_writer.summ_oned('rpo/energy_vol', torch.mean(energy_vol, dim=3))
    # summ_writer.summ_histogram('rpo/energy_vol_hist', energy_vol)
    # summ_writer.summ_traj_on_occ('traj/obj%d_clist' % k,
    #                              obj_clist, occ_memXs[:,0], already_mem=False)
    return total_loss
def forward(self, feed):
    results = dict()
    if 'log_freq' not in feed.keys():
        feed['log_freq'] = None
    start_time = time.time()

    summ_writer = utils_improc.Summ_writer(writer=feed['writer'],
                                           global_step=feed['global_step'],
                                           set_name=feed['set_name'],
                                           log_freq=feed['log_freq'],
                                           fps=8)
    writer = feed['writer']
    global_step = feed['global_step']

    total_loss = torch.tensor(0.0).cuda()

    # unpack the hyperparams first, so the packing lambdas below close over
    # well-defined names
    B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
    PH, PW = hyp.PH, hyp.PW
    K = hyp.K
    BOX_SIZE = hyp.BOX_SIZE
    Z, Y, X = hyp.Z, hyp.Y, hyp.X
    Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
    Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
    D = 9

    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    __pb = lambda x: utils_basic.pack_boxdim(x, hyp.N)
    __ub = lambda x: utils_basic.unpack_boxdim(x, hyp.N)
    if hyp.aug_object_ent_dis:
        __pb_a = lambda x: utils_basic.pack_boxdim(
            x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
        __ub_a = lambda x: utils_basic.unpack_boxdim(
            x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
    else:
        __pb_a = lambda x: utils_basic.pack_boxdim(x, hyp.max_obj_aug)
        __ub_a = lambda x: utils_basic.unpack_boxdim(x, hyp.max_obj_aug)

    tids = torch.from_numpy(np.reshape(np.arange(B * N), [B, N]))

    rgb_camXs = feed["rgb_camXs_raw"]
    pix_T_cams = feed["pix_T_cams_raw"]
    camRs_T_origin = feed["camR_T_origin_raw"]
    origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
    origin_T_camXs = feed["origin_T_camXs_raw"]

    camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
    camRs_T_camXs = __u(
        torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                     __p(origin_T_camXs)))
    camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
    camX0_T_camRs = camXs_T_camRs[:, 0]
    camX1_T_camRs = camXs_T_camRs[:, 1]
    camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)

    xyz_camXs = feed["xyz_camXs_raw"]
    depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(
        __p(pix_T_cams), __p(xyz_camXs), H, W)
    dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_,
                                                   __p(pix_T_cams))

    xyz_camRs = __u(utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
    xyz_camX0s = __u(utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs)))

    occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))
    occXs_to_Rs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, occXs)
    occXs_to_Rs_45 = cross_corr.rotate_tensor_along_y_axis(occXs_to_Rs, 45)
    occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2))
    occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
    occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2))

    unpXs = __u(
        utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X,
                                       __p(pix_T_cams)))
    unpXs_half = __u(
        utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z2, Y2, X2,
                                       __p(pix_T_cams)))
    unpX0s_half = __u(
        utils_vox.unproject_rgb_to_mem(
            __p(rgb_camXs), Z2, Y2, X2,
            utils_basic.matmul2(
                __p(pix_T_cams),
                utils_geom.safe_inverse(__p(camX0_T_camXs)))))
    unpRs = __u(
        utils_vox.unproject_rgb_to_mem(
            __p(rgb_camXs), Z, Y, X,
            utils_basic.matmul2(
                __p(pix_T_cams),
                utils_geom.safe_inverse(__p(camRs_T_camXs)))))
    unpRs_half = __u(
        utils_vox.unproject_rgb_to_mem(
            __p(rgb_camXs), Z2, Y2, X2,
            utils_basic.matmul2(
                __p(pix_T_cams),
                utils_geom.safe_inverse(__p(camRs_T_camXs)))))

    dense_xyz_camRs_ = utils_geom.apply_4x4(__p(camRs_T_camXs),
                                            dense_xyz_camXs_)
    inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float()
    inbound_camXs_ = torch.reshape(inbound_camXs_, [B * S, 1, H, W])
    depth_camXs = __u(depth_camXs_)
    valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)

    summ_writer.summ_oneds('2D_inputs/depth_camXs',
                           torch.unbind(depth_camXs, dim=1),
                           maxdepth=21.0)
    summ_writer.summ_oneds('2D_inputs/valid_camXs',
                           torch.unbind(valid_camXs, dim=1))
    summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                          torch.unbind(rgb_camXs, dim=1))
    summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1))
    summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1),
                          torch.unbind(occXs, dim=1))
    occRs = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X))

    if hyp.do_eval_boxes:
        if hyp.dataset_name == "clevr_vqa":
            gt_boxes_origin_corners = feed['gt_box']
            gt_scores_origin = feed['gt_scores'].detach().cpu().numpy()
            classes = feed['classes']
            scores = gt_scores_origin
            tree_seq_filename = feed['tree_seq_filename']
            gt_boxes_origin = nlu.get_ends_of_corner(gt_boxes_origin_corners)
            gt_boxes_origin_end = torch.reshape(gt_boxes_origin,
                                                [hyp.B, hyp.N, 2, 3])
            gt_boxes_origin_theta = nlu.get_alignedboxes2thetaformat(
                gt_boxes_origin_end)
            gt_boxes_origin_corners = utils_geom.transform_boxes_to_corners(
                gt_boxes_origin_theta)
            gt_boxesR_corners = __ub(
                utils_geom.apply_4x4(camRs_T_origin[:, 0],
                                     __pb(gt_boxes_origin_corners)))
            gt_boxesR_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesR_corners)
            gt_boxesR_end = nlu.get_ends_of_corner(gt_boxesR_corners)
        else:
            tree_seq_filename = feed['tree_seq_filename']
            tree_filenames = [
                join(hyp.root_dataset, i) for i in tree_seq_filename
                if i != "invalid_tree"
            ]
            invalid_tree_filenames = [
                join(hyp.root_dataset, i) for i in tree_seq_filename
                if i == "invalid_tree"
            ]
            num_empty = len(invalid_tree_filenames)
            trees = [pickle.load(open(i, "rb")) for i in tree_filenames]

            len_valid = len(trees)
            if len_valid > 0:
                gt_boxesR, scores, classes = nlu.trees_rearrange(trees)

            if num_empty > 0:
                # empty_gt_boxesR / empty_scores / empty_classes are expected
                # to be provided for the empty-scene records
                gt_boxesR = np.concatenate(
                    [gt_boxesR,
                     empty_gt_boxesR]) if len_valid > 0 else empty_gt_boxesR
                scores = np.concatenate(
                    [scores, empty_scores]) if len_valid > 0 else empty_scores
                classes = np.concatenate(
                    [classes,
                     empty_classes]) if len_valid > 0 else empty_classes

            gt_boxesR = torch.from_numpy(
                gt_boxesR).cuda().float()  # torch.Size([2, 3, 6])
            gt_boxesR_end = torch.reshape(gt_boxesR, [hyp.B, hyp.N, 2, 3])
            gt_boxesR_theta = nlu.get_alignedboxes2thetaformat(
                gt_boxesR_end)  # torch.Size([2, 3, 9])
            gt_boxesR_corners = utils_geom.transform_boxes_to_corners(
                gt_boxesR_theta)

        class_names_ex_1 = "_".join(classes[0])
        summ_writer.summ_text('eval_boxes/class_names', class_names_ex_1)

        gt_boxesRMem_corners = __ub(
            utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z2, Y2, X2))
        gt_boxesRMem_end = nlu.get_ends_of_corner(gt_boxesRMem_corners)
        gt_boxesRMem_theta = utils_geom.transform_corners_to_boxes(
            gt_boxesRMem_corners)
        gt_boxesRUnp_corners = __ub(
            utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z, Y, X))
        gt_boxesRUnp_end = nlu.get_ends_of_corner(gt_boxesRUnp_corners)

        gt_boxesX0_corners = __ub(
            utils_geom.apply_4x4(camX0_T_camRs, __pb(gt_boxesR_corners)))
        gt_boxesX0Mem_corners = __ub(
            utils_vox.Ref2Mem(__pb(gt_boxesX0_corners), Z2, Y2, X2))
        gt_boxesX0Mem_theta = utils_geom.transform_corners_to_boxes(
            gt_boxesX0Mem_corners)
        gt_boxesX0Mem_end = nlu.get_ends_of_corner(gt_boxesX0Mem_corners)
        gt_boxesX0_end = nlu.get_ends_of_corner(gt_boxesX0_corners)

        gt_cornersX0_pix = __ub(
            utils_geom.apply_pix_T_cam(pix_T_cams[:, 0],
                                       __pb(gt_boxesX0_corners)))

        rgb_camX0 = rgb_camXs[:, 0]
        rgb_camX1 = rgb_camXs[:, 1]

        summ_writer.summ_box_by_corners('eval_boxes/gt_boxescamX0', rgb_camX0,
                                        gt_boxesX0_corners,
                                        torch.from_numpy(scores), tids,
                                        pix_T_cams[:, 0])
        unps_vis = utils_improc.get_unps_vis(unpX0s_half, occX0s_half)
        unp_vis = torch.mean(unps_vis, dim=1)
        unps_visRs = utils_improc.get_unps_vis(unpRs_half, occRs_half)
        unp_visRs = torch.mean(unps_visRs, dim=1)
        unps_visRs_full = utils_improc.get_unps_vis(unpRs, occRs)
        unp_visRs_full = torch.mean(unps_visRs_full, dim=1)
        summ_writer.summ_box_mem_on_unp('eval_boxes/gt_boxesR_mem', unp_visRs,
                                        gt_boxesRMem_end, scores, tids)

        unpX0s_half = torch.mean(unpX0s_half, dim=1)
        unpX0s_half = nlu.zero_out(unpX0s_half, gt_boxesX0Mem_end, scores)
        occX0s_half = torch.mean(occX0s_half, dim=1)
        occX0s_half = nlu.zero_out(occX0s_half, gt_boxesX0Mem_end, scores)
        summ_writer.summ_unp('3D_inputs/unpX0s', unpX0s_half, occX0s_half)

    if hyp.do_feat:
        featXs_input = torch.cat([occXs, occXs * unpXs], dim=2)
        featXs_input_ = __p(featXs_input)

        freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), __p(occXs_half))
        freeXs = __u(freeXs_)
        visXs = torch.clamp(occXs_half + freeXs, 0.0, 1.0)

        mask_ = None
        if mask_ is not None:
            assert (list(mask_.shape)[2:5] == list(featXs_input_.shape)[2:5])
        featXs_, feat_loss = self.featnet(featXs_input_,
                                          summ_writer,
                                          mask=__p(occXs))  # mask_)
        total_loss += feat_loss

        validXs = torch.ones_like(visXs)
        _validX00 = validXs[:, 0:1]
        _validX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                 validXs[:, 1:])
        validX0s = torch.cat([_validX00, _validX01], dim=1)
        validRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, validXs)
        visRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, visXs)

        featXs = __u(featXs_)
        _featX00 = featXs[:, 0:1]
        _featX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                featXs[:, 1:])
        featX0s = torch.cat([_featX00, _featX01], dim=1)

        emb3D_e = torch.mean(featX0s[:, 1:], dim=1)
        vis3D_e_R = torch.max(visRs[:, 1:], dim=1)[0]
        emb3D_g = featX0s[:, 0]
        vis3D_g_R = visRs[:, 0]
        validR_combo = torch.min(validRs, dim=1).values

        summ_writer.summ_feats('3D_feats/featXs_input',
                               torch.unbind(featXs_input, dim=1),
                               pca=True)
        summ_writer.summ_feats('3D_feats/featXs_output',
                               torch.unbind(featXs, dim=1),
                               valids=torch.unbind(validXs, dim=1),
                               pca=True)
        summ_writer.summ_feats('3D_feats/featX0s_output',
                               torch.unbind(featX0s, dim=1),
                               valids=torch.unbind(torch.ones_like(validRs),
                                                   dim=1),
                               pca=True)
        summ_writer.summ_feats('3D_feats/validRs',
                               torch.unbind(validRs, dim=1),
                               pca=False)
        summ_writer.summ_feat('3D_feats/vis3D_e_R', vis3D_e_R, pca=False)
        summ_writer.summ_feat('3D_feats/vis3D_g_R', vis3D_g_R, pca=False)

    if hyp.do_munit:
        object_classes, filenames = nlu.create_object_classes(
            classes, [tree_seq_filename, tree_seq_filename], scores)
        if hyp.do_munit_fewshot:
            emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
            emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
            emb3D_R = emb3D_e_R
            emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
            emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2
            content, style = self.munitnet.net.gen_a.encode(emb3D_R_object)
            objects_taken, _ = self.munitnet.net.gen_a.decode(content, style)
            styles = style
            contents = content
        elif hyp.do_3d_style_munit:
            emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
            emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
            emb3D_R = emb3D_e_R
            # st()
            emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
            emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2

            camX1_T_R = camXs_T_camRs[:, 1]
            camX0_T_R = camXs_T_camRs[:, 0]
            assert hyp.B == 2
            assert emb3D_e_R_object.shape[0] == 2
            munit_loss, sudo_input_0, sudo_input_1, recon_input_0, recon_input_1, sudo_input_0_cycle, sudo_input_1_cycle, styles, contents, adin = self.munitnet(
                emb3D_R_object[0:1], emb3D_R_object[1:2])

            if hyp.store_content_style_range:
                if self.max_content is None:
                    self.max_content = torch.zeros_like(
                        contents[0][0]).cuda() - 100000000
                if self.min_content is None:
                    self.min_content = torch.zeros_like(
                        contents[0][0]).cuda() + 100000000
                if self.max_style is None:
                    self.max_style = torch.zeros_like(
                        styles[0][0]).cuda() - 100000000
                if self.min_style is None:
                    self.min_style = torch.zeros_like(
                        styles[0][0]).cuda() + 100000000
                self.max_content = torch.max(
                    torch.max(self.max_content, contents[0][0]),
                    contents[1][0])
                self.min_content = torch.min(
                    torch.min(self.min_content, contents[0][0]),
                    contents[1][0])
                self.max_style = torch.max(
                    torch.max(self.max_style, styles[0][0]), styles[1][0])
                self.min_style = torch.min(
                    torch.min(self.min_style, styles[0][0]), styles[1][0])
                data_to_save = {
                    'max_content': self.max_content.cpu().numpy(),
                    'min_content': self.min_content.cpu().numpy(),
                    'max_style': self.max_style.cpu().numpy(),
                    'min_style': self.min_style.cpu().numpy()
                }
                with open('content_style_range.p', 'wb') as f:
                    pickle.dump(data_to_save, f)
            elif hyp.is_contrastive_examples:
                if hyp.normalize_contrast:
                    content0 = (contents[0] - self.min_content) / (
                        self.max_content - self.min_content + 1e-5)
                    content1 = (contents[1] - self.min_content) / (
                        self.max_content - self.min_content + 1e-5)
                    style0 = (styles[0] - self.min_style) / (
                        self.max_style - self.min_style + 1e-5)
                    style1 = (styles[1] - self.min_style) / (
                        self.max_style - self.min_style + 1e-5)
                else:
                    content0 = contents[0]
                    content1 = contents[1]
                    style0 = styles[0]
                    style1 = styles[1]

                # euclid_dist_content = torch.sum(torch.sqrt((content0 - content1)**2))/torch.prod(torch.tensor(content0.shape))
                # euclid_dist_style = torch.sum(torch.sqrt((style0-style1)**2))/torch.prod(torch.tensor(style0.shape))
                euclid_dist_content = (content0 -
                                       content1).norm(2) / (content0.numel())
                euclid_dist_style = (style0 -
                                     style1).norm(2) / (style0.numel())

                content_0_pooled = torch.mean(
                    content0.reshape(list(content0.shape[:2]) + [-1]), dim=-1)
                content_1_pooled = torch.mean(
                    content1.reshape(list(content1.shape[:2]) + [-1]), dim=-1)
                euclid_dist_content_pooled = (
                    content_0_pooled -
                    content_1_pooled).norm(2) / (content_0_pooled.numel())

                content_0_normalized = content0 / content0.norm()
                content_1_normalized = content1 / content1.norm()
                style_0_normalized = style0 / style0.norm()
                style_1_normalized = style1 / style1.norm()
                content_0_pooled_normalized = content_0_pooled / content_0_pooled.norm()
                content_1_pooled_normalized = content_1_pooled / content_1_pooled.norm()

                cosine_dist_content = torch.sum(content_0_normalized *
                                                content_1_normalized)
                cosine_dist_style = torch.sum(style_0_normalized *
                                              style_1_normalized)
                cosine_dist_content_pooled = torch.sum(
                    content_0_pooled_normalized * content_1_pooled_normalized)

                print("euclid dist [content, pooled-content, style]: ",
                      euclid_dist_content, euclid_dist_content_pooled,
                      euclid_dist_style)
                print("cosine sim [content, pooled-content, style]: ",
                      cosine_dist_content, cosine_dist_content_pooled,
                      cosine_dist_style)

            if hyp.run_few_shot_on_munit:
                if (global_step % 300) == 1 or (global_step % 300) == 0:
                    wrong = False
                    try:
                        precision_style = float(self.tp_style) / self.all_style
                        precision_content = float(
                            self.tp_content) / self.all_content
                    except ZeroDivisionError:
                        wrong = True
                    if not wrong:
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_style',
                            precision_style)
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_content',
                            precision_content)
                    # st()
                    self.embed_list_style = defaultdict(lambda: [])
                    self.embed_list_content = defaultdict(lambda: [])
                    self.tp_style = 0
                    self.all_style = 0
                    self.tp_content = 0
                    self.all_content = 0
                    self.check = False
                elif not self.check and not nlu.check_fill_dict(
                        self.embed_list_content, self.embed_list_style):
                    print("Filling \n")
                    for index, class_val in enumerate(object_classes):
                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split("/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]
                        print(len(self.embed_list_style.keys()), "style class",
                              len(self.embed_list_content), "content class",
                              self.embed_list_content.keys())
                        if len(self.embed_list_style[class_val_style]) < hyp.few_shot_nums:
                            self.embed_list_style[class_val_style].append(
                                styles[index].squeeze())
                        if len(self.embed_list_content[class_val_content]) < hyp.few_shot_nums:
                            if hyp.avg_3d:
                                content_val = contents[index]
                                content_val = torch.mean(
                                    content_val.reshape(
                                        [content_val.shape[1], -1]),
                                    dim=-1)
                                # st()
                                self.embed_list_content[
                                    class_val_content].append(content_val)
                            else:
                                self.embed_list_content[
                                    class_val_content].append(
                                        contents[index].reshape([-1]))
                else:
                    self.check = True
                    try:
                        print(float(self.tp_content) / self.all_content)
                        print(float(self.tp_style) / self.all_style)
                    except Exception as e:
                        pass
                    average = True
                    if average:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)
                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)
                    else:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.stack(
                                    val, dim=0)
                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.stack(
                                    val, dim=0)

                    for index, class_val in enumerate(object_classes):
                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split("/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]
                        # nearest-neighbor classification of the style code
                        style_val = styles[index].squeeze().unsqueeze(0)
                        if not average:
                            embed_list_val_style = torch.cat(
                                list(self.embed_list_style.values()), dim=0)
                            embed_list_key_style = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_style.keys()), 1),
                                    hyp.few_shot_nums, 1).reshape([-1]))
                        else:
                            embed_list_val_style = torch.stack(
                                list(self.embed_list_style.values()), dim=0)
                            embed_list_key_style = list(
                                self.embed_list_style.keys())
                        embed_list_val_style = utils_basic.l2_normalize(
                            embed_list_val_style, dim=1).permute(1, 0)
                        style_val = utils_basic.l2_normalize(style_val, dim=1)
                        scores_styles = torch.matmul(style_val,
                                                     embed_list_val_style)
                        index_key = torch.argmax(scores_styles, dim=1).squeeze()
                        selected_class_style = embed_list_key_style[index_key]
                        self.styles_prediction[class_val_style].append(
                            selected_class_style)
                        if class_val_style == selected_class_style:
                            self.tp_style += 1
                        self.all_style += 1

                        # nearest-neighbor classification of the content code
                        if hyp.avg_3d:
                            content_val = contents[index]
                            content_val = torch.mean(
                                content_val.reshape(
                                    [content_val.shape[1], -1]),
                                dim=-1).unsqueeze(0)
                        else:
                            content_val = contents[index].reshape(
                                [-1]).unsqueeze(0)
                        if not average:
                            embed_list_val_content = torch.cat(
                                list(self.embed_list_content.values()), dim=0)
                            embed_list_key_content = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_content.keys()),
                                        1), hyp.few_shot_nums,
                                    1).reshape([-1]))
                        else:
                            embed_list_val_content = torch.stack(
                                list(self.embed_list_content.values()), dim=0)
                            embed_list_key_content = list(
                                self.embed_list_content.keys())
                        embed_list_val_content = utils_basic.l2_normalize(
                            embed_list_val_content, dim=1).permute(1, 0)
                        content_val = utils_basic.l2_normalize(content_val,
                                                               dim=1)
                        scores_content = torch.matmul(content_val,
                                                      embed_list_val_content)
                        index_key = torch.argmax(scores_content,
                                                 dim=1).squeeze()
                        selected_class_content = embed_list_key_content[
                            index_key]
                        self.content_prediction[class_val_content].append(
                            selected_class_content)
                        if class_val_content == selected_class_content:
                            self.tp_content += 1
                        self.all_content += 1
            # st()

            munit_loss = hyp.munit_loss_weight * munit_loss

            recon_input_obj = torch.cat([recon_input_0, recon_input_1], dim=0)
            recon_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, recon_input_obj, gt_boxesRMem_end, scores)

            sudo_input_obj = torch.cat([sudo_input_0, sudo_input_1], dim=0)
            styled_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, sudo_input_obj, gt_boxesRMem_end, scores)

            styled_emb3D_e_X1 = utils_vox.apply_4x4_to_vox(
                camX1_T_R, styled_emb3D_R)
            styled_emb3D_e_X0 = utils_vox.apply_4x4_to_vox(
                camX0_T_R, styled_emb3D_R)

            emb3D_e_X1 = utils_vox.apply_4x4_to_vox(camX1_T_R, recon_emb3D_R)
            emb3D_e_X0 = utils_vox.apply_4x4_to_vox(camX0_T_R, recon_emb3D_R)
            emb3D_e_X1_og = utils_vox.apply_4x4_to_vox(camX1_T_R, emb3D_R)
            emb3D_e_X0_og = utils_vox.apply_4x4_to_vox(camX0_T_R, emb3D_R)

            emb3D_R_aug_diff = torch.abs(emb3D_R - recon_emb3D_R)
            summ_writer.summ_feat('aug_feat/og', emb3D_R)
            summ_writer.summ_feat('aug_feat/og_gen', recon_emb3D_R)
            summ_writer.summ_feat('aug_feat/og_aug_diff', emb3D_R_aug_diff)

            if hyp.cycle_style_view_loss:
                sudo_input_obj_cycle = torch.cat(
                    [sudo_input_0_cycle, sudo_input_1_cycle], dim=0)
                styled_emb3D_R_cycle = nlu.update_scene_with_objects(
                    emb3D_R, sudo_input_obj_cycle, gt_boxesRMem_end, scores)
                styled_emb3D_e_X0_cycle = utils_vox.apply_4x4_to_vox(
                    camX0_T_R, styled_emb3D_R_cycle)
                styled_emb3D_e_X1_cycle = utils_vox.apply_4x4_to_vox(
                    camX1_T_R, styled_emb3D_R_cycle)

            summ_writer.summ_scalar('munit_loss', munit_loss.cpu().item())
            total_loss += munit_loss

    if hyp.do_occ and hyp.occ_do_cheap:
        occX0_sup, freeX0_sup, _, freeXs = utils_vox.prep_occs_supervision(
            camX0_T_camXs, xyz_camXs, Z2, Y2, X2, agg=True)

        summ_writer.summ_occ('occ_sup/occ_sup', occX0_sup)
        summ_writer.summ_occ('occ_sup/free_sup', freeX0_sup)
        summ_writer.summ_occs('occ_sup/freeXs_sup',
                              torch.unbind(freeXs, dim=1))
        summ_writer.summ_occs('occ_sup/occXs_sup',
                              torch.unbind(occXs_half, dim=1))

        occ_loss, occX0s_pred_ = self.occnet(
            torch.mean(featX0s[:, 1:], dim=1), occX0_sup, freeX0_sup,
            torch.max(validX0s[:, 1:], dim=1)[0], summ_writer)
        occX0s_pred = __u(occX0s_pred_)
        total_loss += occ_loss

    if hyp.do_view:
        assert (hyp.do_feat)
        # decode the perspective volume into an image
        PH, PW = hyp.PH, hyp.PW
        sy = float(PH) / float(hyp.H)
        sx = float(PW) / float(hyp.W)
        assert (sx == 0.5)  # else we need a fancier downsampler
        assert (sy == 0.5)
        projpix_T_cams = __u(
            utils_geom.scale_intrinsics(__p(pix_T_cams), sx, sy))
        # st()
        if hyp.do_munit:
            feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                projpix_T_cams[:, 0],
                camX0_T_camXs[:, 1],
                emb3D_e_X1,  # use feat1 to predict rgb0
                hyp.view_depth,
                PH,
                PW)
            feat_projX00_og = utils_vox.apply_pixX_T_memR_to_voxR(
                projpix_T_cams[:, 0],
                camX0_T_camXs[:, 1],
                emb3D_e_X1_og,  # use feat1 to predict rgb0
                hyp.view_depth,
                PH,
                PW)
            # only for checking the style
            styled_feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                projpix_T_cams[:, 0],
                camX0_T_camXs[:, 1],
                styled_emb3D_e_X1,  # use feat1 to predict rgb0
                hyp.view_depth,
                PH,
                PW)
            if hyp.cycle_style_view_loss:
                styled_feat_projX00_cycle = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    styled_emb3D_e_X1_cycle,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)
        else:
            feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                projpix_T_cams[:, 0],
                camX0_T_camXs[:, 1],
                featXs[:, 1],  # use feat1 to predict rgb0
                hyp.view_depth,
                PH,
                PW)
        rgb_X00 = utils_basic.downsample(rgb_camXs[:, 0], 2)
        rgb_X01 = utils_basic.downsample(rgb_camXs[:, 1], 2)
        valid_X00 = utils_basic.downsample(valid_camXs[:, 0], 2)

        view_loss, rgb_e, emb2D_e = self.viewnet(feat_projX00, rgb_X00,
                                                 valid_X00, summ_writer, "rgb")
        if hyp.do_munit:
            _, rgb_e, emb2D_e = self.viewnet(feat_projX00_og, rgb_X00,
                                             valid_X00, summ_writer, "rgb_og")
        if hyp.do_munit:
            styled_view_loss, styled_rgb_e, styled_emb2D_e = self.viewnet(
                styled_feat_projX00, rgb_X00, valid_X00, summ_writer,
                "recon_style")
            if hyp.cycle_style_view_loss:
                styled_view_loss_cycle, styled_rgb_e_cycle, styled_emb2D_e_cycle = self.viewnet(
                    styled_feat_projX00_cycle, rgb_X00, valid_X00,
                    summ_writer, "recon_style_cycle")

            rgb_input_1 = torch.cat(
                [rgb_X01[1], rgb_X01[0], styled_rgb_e[0]], dim=2)
            rgb_input_2 = torch.cat(
                [rgb_X01[0], rgb_X01[1], styled_rgb_e[1]], dim=2)
            complete_vis = torch.cat([rgb_input_1, rgb_input_2], dim=1)
            summ_writer.summ_rgb('munit/munit_recons_vis',
                                 complete_vis.unsqueeze(0))

        if not hyp.do_munit:
            total_loss += view_loss
        else:
            if hyp.basic_view_loss:
                total_loss += view_loss
            if hyp.style_view_loss:
                total_loss += styled_view_loss
            if hyp.cycle_style_view_loss:
                total_loss += styled_view_loss_cycle

    summ_writer.summ_scalar('loss', total_loss.cpu().item())

    if hyp.save_embed_tsne:
        for index, class_val in enumerate(object_classes):
            class_val_content, class_val_style = class_val.split("/")
            style_val = styles[index].squeeze().unsqueeze(0)
            self.cluster_pool.update(style_val, [class_val_style])
            print(self.cluster_pool.num)

        if self.cluster_pool.is_full():
            embeds, classes = self.cluster_pool.fetch()
            with open("offline_cluster" + '/%st.txt' % 'classes', 'w') as f:
                for class_val in classes:
                    f.write("%s\n" % class_val)
            with open("offline_cluster" + '/%st.txt' % 'embeddings', 'w') as f:
                for index, embed in enumerate(embeds):
                    # embed = utils_basic.l2_normalize(embed,dim=0)
                    print("writing {} embed".format(index))
                    embed_l_s = [str(i) for i in embed.tolist()]
                    embed_str = '\t'.join(embed_l_s)
                    f.write("%s\n" % embed_str)
            # st()  # leftover debugger breakpoint; keep disabled

    return total_loss, results
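# Illustrative sketch only (not called anywhere). The view branch above halves
# the render resolution, so the pinhole intrinsics must be scaled to match.
# This shows the standard operation that utils_geom.scale_intrinsics is
# assumed to perform; the 4x4 layout (fx, fy on the diagonal, cx, cy in the
# third column) is an assumption, not taken from the utils_geom source.
def _example_scale_intrinsics():
    pix_T_cam = torch.eye(4).unsqueeze(0)  # B x 4 x 4
    pix_T_cam[:, 0, 0] = 450.0  # fx
    pix_T_cam[:, 1, 1] = 450.0  # fy
    pix_T_cam[:, 0, 2] = 128.0  # cx
    pix_T_cam[:, 1, 2] = 128.0  # cy
    sx = sy = 0.5
    scaled = pix_T_cam.clone()
    scaled[:, 0] = pix_T_cam[:, 0] * sx  # scale the x row: fx and cx
    scaled[:, 1] = pix_T_cam[:, 1] * sy  # scale the y row: fy and cy
    # a pixel at (u, v) in the full image maps to (u*sx, v*sy) after scaling
    assert scaled[0, 0, 0] == 225.0 and scaled[0, 0, 2] == 64.0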
def get_synth_flow_v2(xyz_cam0,
                      occ0,
                      unp0,
                      summ_writer,
                      sometimes_zero=False,
                      do_vis=False):
    # this version re-voxelizes occ1, rather than warping
    B, C, Z, Y, X = list(unp0.shape)
    assert (C == 3)
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    cam1_T_cam0 = [
        utils_geom.get_random_rt(B, r_amount=0.0,
                                 t_amount=3.0),  # large motion
        utils_geom.get_random_rt(
            B,
            r_amount=0.0,
            t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero)
    ]
    cam1_T_cam0 = random.sample(cam1_T_cam0, k=1)[0]

    xyz_cam1 = utils_geom.apply_4x4(cam1_T_cam0, xyz_cam0)
    occ1 = utils_vox.voxelize_xyz(xyz_cam1, Z, Y, X)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)

    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X
    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

    if do_vis:
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1, flow,
                                                   binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    return occs, unps, flow, cam1_T_cam0
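# Illustrative usage sketch (not called anywhere): shows the expected shapes
# when calling get_synth_flow_v2. The inputs are random placeholders; in the
# real pipeline xyz_cam0 / occ0 / unp0 come from the dataloader. Two
# assumptions here: summ_writer=None is safe only because do_vis=False (the
# function touches summ_writer only under do_vis), and the .cuda() placement
# follows this file's convention and may need adjusting to match the devices
# that utils_geom / utils_vox produce internally.
def _example_get_synth_flow_v2():
    B, Z, Y, X = 2, 32, 32, 32
    xyz_cam0 = torch.randn(B, 10000, 3).cuda()  # pointcloud in cam0 coords
    occ0 = utils_vox.voxelize_xyz(xyz_cam0, Z, Y, X)
    unp0 = torch.randn(B, 3, Z, Y, X).cuda()    # unprojected rgb volume
    occs, unps, flow, cam1_T_cam0 = get_synth_flow_v2(
        xyz_cam0, occ0, unp0, summ_writer=None, do_vis=False)
    # occs is B x 2 x 1 x Z x Y x X; unps is B x 2 x 3 x Z x Y x X;
    # flow is the rigid scene flow in mem coordinates, B x 3 x Z x Y x X
    assert flow.shape == (B, 3, Z, Y, X)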