def forward(
    self, xyz: torch.Tensor, features: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Sample centroids, group neighbourhoods at every scale and fuse them.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) tensor of the xyz coordinates of the features
    features : torch.Tensor
        (B, C, N) tensor of the descriptors of the features

    Returns
    -------
    new_xyz : torch.Tensor
        (B, npoint, 3) tensor of the new features' xyz, or None when
        ``self.npoint`` is None
    new_features : torch.Tensor
        When scales are concatenated: (B, \sum_k(mlps[k][-1]), npoint).
        When ``self.fuse == 'add'`` the per-scale descriptors are summed
        instead, giving (B, mlps[k][-1], npoint) (all scales must then
        share the same output width).

    Raises
    ------
    ValueError
        If the module has no grouper/mlp scales.
    """
    new_xyz = None
    if self.npoint is not None:
        xyz_flipped = xyz.transpose(1, 2).contiguous()  # (B, 3, N) for gather_operation
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
        new_xyz = pointnet2_utils.gather_operation(
            xyz_flipped, fps_idx).transpose(1, 2).contiguous()  # (B, npoint, 3)

    # Shared per-scale pipeline: group -> shared MLP -> max over neighbours.
    per_scale = []
    for i, grouper in enumerate(self.groupers):
        new_features = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
        new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
        new_features = F.max_pool2d(
            new_features, kernel_size=[1, new_features.size(3)])  # (B, mlp[-1], npoint, 1)
        per_scale.append(new_features.squeeze(-1))  # (B, mlp[-1], npoint)

    if not per_scale:
        raise ValueError('forward() needs at least one grouper/mlp scale')

    if self.fuse == 'add':
        # Accumulate out of place so the first scale's tensor is never mutated.
        fused = per_scale[0]
        for scale_features in per_scale[1:]:
            fused = fused + scale_features
        return new_xyz, fused
    return new_xyz, torch.cat(per_scale, dim=1)
def sample_and_group(npoint, radius, nsample, xyz, points):
    """Sample ``npoint`` centroids with FPS and group their k nearest neighbours.

    Neighbours are found with k-NN (``knn_point``); ``radius`` is accepted only
    for interface compatibility with the ball-query variant and is unused.

    Input:
        npoint: number of centroids sampled by furthest point sampling
        radius: unused (see above)
        nsample: number of neighbours per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: grouped features, [B, npoint, nsample, 2*D] -- the
            neighbour features minus the centroid feature, concatenated with
            the centroid feature repeated nsample times
    """
    batch_size = xyz.shape[0]
    xyz = xyz.contiguous()  # hand a contiguous tensor to the FPS op
    fps_idx = pointnet2_utils.furthest_point_sample(xyz, npoint).long()  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
    new_points = index_points(points, fps_idx)  # [B, npoint, D]

    idx = knn_point(nsample, xyz, new_xyz)  # [B, npoint, nsample]
    grouped_points = index_points(points, idx)  # [B, npoint, nsample, D]

    anchor = new_points.view(batch_size, npoint, 1, -1)  # [B, npoint, 1, D]
    grouped_points_norm = grouped_points - anchor
    new_points = torch.cat(
        [grouped_points_norm, anchor.repeat(1, 1, nsample, 1)], dim=-1)
    return new_xyz, new_points
def sample_and_group(npoint, nsample, xyz, points):
    """Sample centroids with FPS and gather the raw features of their k-NN.

    Input:
        npoint: number of selected FPS points
        nsample: number of k-nn
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        grouped_points: neighbour features, [B, npoint, nsample, D]; the
            features are returned as-is (not re-centred on the anchor)
    """
    xyz = xyz.contiguous()  # hand a contiguous tensor to the FPS op
    fps_idx = pointnet2_utils.furthest_point_sample(xyz, npoint).long()  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
    idx = knn_point(nsample, xyz, new_xyz)  # [B, npoint, nsample]
    grouped_points = index_points(points, idx)  # [B, npoint, nsample, D]
    return new_xyz, grouped_points
def pool(xyz, points, k, npoint):
    """Downsample a point cloud with FPS and max-pool features over k-NN.

    Args:
        xyz: point coordinates, [B, N, 3].
        points: point features, [B, N, C].
        k: number of nearest neighbours to pool over.
        npoint: number of points to keep.

    Returns:
        new_xyz: sampled coordinates, [B, npoint, 3].
        new_points: max-pooled neighbour features, presumably [B, npoint, C]
            assuming this file's knn_point returns idx as [B, npoint, k] --
            TODO confirm.
    """
    xyz = xyz.contiguous()
    xyz_flipped = xyz.transpose(1, 2).contiguous()  # [B, 3, N] for gather_operation
    # FPS must see the (B, N, 3) layout; feeding it the flipped (B, 3, N)
    # tensor would sample among 3 meaningless "points".
    fps_idx = pointnet2_utils.furthest_point_sample(xyz, npoint)
    new_xyz = pointnet2_utils.gather_operation(
        xyz_flipped, fps_idx).transpose(1, 2).contiguous()  # [B, npoint, 3]
    # NOTE(review): this knn_point is unpacked as (dist, idx) -- confirm it
    # differs from the idx-only knn_point used by other blocks.
    _, idx = knn_point(k, xyz, new_xyz)
    grouped = pointnet2_utils.grouping_operation(
        points.permute(0, 2, 1).contiguous(),
        idx.int().permute(0, 2, 1).contiguous())
    new_points = torch.max(grouped.permute(0, 3, 2, 1), dim=2).values
    return new_xyz, new_points
def fps_subsample(pcd, n_points=2048):
    """
    Downsample a point cloud with furthest point sampling.

    Args
        pcd: (b, N, 3) point cloud, e.g. N = 16384
        n_points: number of points to keep

    returns
        new_pcd: (b, n_points, 3)
    """
    # Make the input contiguous before the FPS op (a no-op if it already is).
    pcd = pcd.contiguous()
    fps_idx = furthest_point_sample(pcd, n_points)
    new_pcd = gather_operation(pcd.permute(0, 2, 1).contiguous(), fps_idx)  # (b, 3, n_points)
    return new_pcd.permute(0, 2, 1).contiguous()
def forward(self, x):
    """
    Pick ``self.points`` centres by FPS and gather the ``self.knn`` nearest
    input points around each one.

    :param x: input points [b, n, 3+c]; the first 3 channels are coordinates
    :return: grouped points [b, points, knn, 3+c]

    Notice: each centre is itself an input point, so it is normally its own
    first neighbour, i.e. the centres are grouped_points[:, :, 0, :]
    (presumably barring duplicate points).
    """
    coords = x[:, :, :3]
    centre_idx = furthest_point_sample(coords.contiguous(), self.points).long()
    centres = index_points(x, centre_idx)  # [b, points, 3+c]
    # distances from every centre to every input point, centres included
    dist = square_distance(centres[:, :, :3], coords)
    neighbour_idx = dist.argsort()[:, :, :self.knn]  # [b, points, knn]
    return index_points(x, neighbour_idx)
def forward(self, x, feature, num_pool):
    """Downsample with FPS, then max-pool coordinates and MLP features over ball neighbourhoods.

    Args:
        x: coordinates in channel-first layout, presumably (B, 3, N) -- TODO confirm.
        feature: per-point features in the layout grouping_operation expects,
            presumably (B, C, N).
        num_pool: number of points to keep.

    Returns:
        (x, feature): neighbourhood max of the grouped coordinates and of the
        MLP-transformed grouped features.
    """
    x = x.contiguous()
    points_last = x.transpose(1, 2).contiguous()  # layout used for FPS / ball query
    fps_idx = pointnet2_utils.furthest_point_sample(points_last, num_pool)
    centres = pointnet2_utils.gather_operation(x, fps_idx).transpose(1, 2).contiguous()

    radius = 0.2 * self.nlayer  # ball radius scales with self.nlayer
    neighbour_idx = pointnet2_utils.ball_query(radius, self.k, points_last, centres)

    pooled_x = pointnet2_utils.grouping_operation(x, neighbour_idx).max(dim=-1)[0]
    grouped_feature = self.mlp(pointnet2_utils.grouping_operation(feature, neighbour_idx))
    pooled_feature = grouped_feature.max(dim=-1)[0]
    return pooled_x, pooled_feature
def forward(self, xyz, points):
    """FPS-sample ``self.groups`` anchors and build anchor-normalised k-NN groups.

    Args:
        xyz: coordinates [B, N, 3].
        points: per-point features [B, N, D].

    Returns:
        new_xyz: anchor coordinates [B, groups, 3].
        new_points: [B, groups, kneighbors, 2*D] -- neighbour features minus
            the anchor feature, concatenated with the repeated anchor feature.
    """
    batch, _, _ = xyz.shape
    n_groups, k = self.groups, self.kneighbors
    xyz = xyz.contiguous()

    anchor_idx = pointnet2_utils.furthest_point_sample(xyz, n_groups).long()  # [B, groups]
    new_xyz = index_points(xyz, anchor_idx)
    anchors = index_points(points, anchor_idx).view(batch, n_groups, 1, -1)

    neighbour_idx = knn_point(k, xyz, new_xyz)
    neighbours = index_points(points, neighbour_idx)  # [B, groups, k, D]

    new_points = torch.cat(
        [neighbours - anchors, anchors.repeat(1, 1, k, 1)], dim=-1)
    return new_xyz, new_points
def __init__(
    self,
    transforms=None,
    train=True,
    self_supervision=False,
    num_points=2048,
    root='/home/zhangniansong/data/ScanObjNN/main_split_nobg/',
):
    """Load a ScanObjectNN split from HDF5 and FPS-subsample every object.

    Args:
        transforms: stored on ``self.transforms``.
        train: load the training split when True, otherwise the test split.
        self_supervision: load the training split without labels
            (``self.labels`` is then None).
        num_points: points kept per object by furthest point sampling.
        root: directory containing ``training_objectdataset.h5`` and
            ``test_objectdataset.h5`` (defaults to the original hard-coded path).
    """
    import os

    super().__init__()
    self.transforms = transforms
    self.self_supervision = self_supervision
    self.train = train
    self.num_points = num_points

    # Self-supervised pre-training reuses the training split but ignores labels.
    if self.self_supervision or train:
        file_name = 'training_objectdataset.h5'
    else:
        file_name = 'test_objectdataset.h5'
    # The context manager closes the file even if reading a dataset fails.
    with h5py.File(os.path.join(root, file_name), 'r') as h5:
        self.points = np.array(h5['data']).astype(np.float32)
        if self.self_supervision:
            self.labels = None
        else:
            self.labels = np.array(h5['label']).astype(int)

    # FPS runs once on the GPU over the whole split.  # maybe modify to `to(device)`
    points_gpu = torch.tensor(self.points).cuda()
    fps_idx = pointnet2_utils.furthest_point_sample(points_gpu, self.num_points)
    points_gpu = index_points(points_gpu, fps_idx.long())
    # BUG FIX (kept): data left on the GPU inside the dataset raises
    # CUDA_NOT_INITIALIZED errors in DataLoader workers, so move it back.
    self.points = points_gpu.cpu()
    del points_gpu, fps_idx
    torch.cuda.empty_cache()
    # Count points, not labels: labels is None in self-supervision mode.
    print('Successfully load ScanObjectNN with', len(self.points), 'instances')
def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
    """FPS-sample centroids and group ball-query neighbourhoods (channel-first layout).

    Args:
        xyz: Tensor, (B, 3, N)
        points: Tensor, (B, f, N), or None
        npoint: int, number of centroids
        nsample: int, neighbours per centroid
        radius: float, ball-query radius
        use_xyz: boolean, prepend the relative xyz to the grouped features

    Returns:
        new_xyz: Tensor, (B, 3, npoint)
        new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
        idx_local: Tensor, (B, npoint, nsample)
        grouped_xyz: Tensor, (B, 3, npoint, nsample), relative to each centroid
    """
    points_last = xyz.permute(0, 2, 1).contiguous()  # (B, N, 3)
    centroids = gather_operation(xyz, furthest_point_sample(points_last, npoint))  # (B, 3, npoint)
    idx = ball_query(radius, nsample, points_last,
                     centroids.permute(0, 2, 1).contiguous())  # (B, npoint, nsample)
    # neighbour coordinates expressed relative to their centroid (broadcast over nsample)
    rel_xyz = grouping_operation(xyz, idx) - centroids.unsqueeze(3)  # (B, 3, npoint, nsample)

    if points is None:
        features = rel_xyz
    else:
        grouped_points = grouping_operation(points, idx)  # (B, f, npoint, nsample)
        features = torch.cat([rel_xyz, grouped_points], 1) if use_xyz else grouped_points
    return centroids, features, idx, rel_xyz
def sample_and_ball_group(s, radius, n, coords, features):
    """
    Sampling by FPS and grouping by ball query.

    Input:
        s[int]: number of points to be sampled by FPS
        radius[float]: radius of the ball query
        n[int]: fixed number of points in each ball neighbourhood
        coords[tensor]: input points coordinates data with size of [B, N, 3]
        features[tensor]: input points features data with size of [B, N, D]

    Returns:
        new_coords[tensor]: coordinates of the FPS-sampled points with size of [B, s, 3]
        new_features[tensor]: grouped features with size of [B, s, n, 2D] --
            neighbour features minus the centre feature, concatenated with
            the centre feature repeated n times
    """
    batch_size = coords.shape[0]
    coords = coords.contiguous()

    # FPS sampling
    fps_idx = pointnet2_utils.furthest_point_sample(coords, s).long()  # [B, s]
    new_coords = index_points(coords, fps_idx)  # [B, s, 3]
    new_features = index_points(features, fps_idx)  # [B, s, D]

    # ball_query grouping
    idx = query_ball_point(radius, n, coords, new_coords)  # [B, s, n]
    grouped_features = index_points(features, idx)  # [B, s, n, D]

    # Subtract the centre feature from each neighbour
    centre = new_features.view(batch_size, s, 1, -1)  # [B, s, 1, D]
    grouped_features_norm = grouped_features - centre  # [B, s, n, D]

    # Concat; the exact aggregation may differ between networks
    aggregated_features = torch.cat(
        [grouped_features_norm, centre.repeat(1, 1, n, 1)], dim=-1)  # [B, s, n, 2D]

    return new_coords, aggregated_features  # [B, s, 3], [B, s, n, 2D]
def sample_and_knn_group(s, k, coords, features):
    """
    Sampling by FPS and grouping by KNN.

    Input:
        s[int]: number of points to be sampled by FPS
        k[int]: number of points to be grouped into a neighbor by KNN
        coords[tensor]: input points coordinates data with size of [B, N, 3]
        features[tensor]: input points features data with size of [B, N, D]

    Returns:
        new_coords[tensor]: coordinates of the FPS-sampled points with size of [B, s, 3]
        new_features[tensor]: grouped features with size of [B, s, k, 2D] --
            neighbour features minus the centre feature, concatenated with
            the centre feature repeated k times
    """
    batch_size = coords.shape[0]
    coords = coords.contiguous()

    # FPS sampling
    fps_idx = pointnet2_utils.furthest_point_sample(coords, s).long()  # [B, s]
    new_coords = index_points(coords, fps_idx)  # [B, s, 3]
    new_features = index_points(features, fps_idx)  # [B, s, D]

    # K-nn grouping
    idx = knn_point(k, coords, new_coords)  # [B, s, k]
    grouped_features = index_points(features, idx)  # [B, s, k, D]

    # Subtract the centre feature from each neighbour
    centre = new_features.view(batch_size, s, 1, -1)  # [B, s, 1, D]
    grouped_features_norm = grouped_features - centre  # [B, s, k, D]

    # Concat
    aggregated_features = torch.cat(
        [grouped_features_norm, centre.repeat(1, 1, k, 1)], dim=-1)  # [B, s, k, 2D]

    return new_coords, aggregated_features  # [B, s, 3], [B, s, k, 2D]
# Debug/visualisation snippet: renders `pcd`, captures a screenshot, and
# FPS-subsamples points in between.
# NOTE(review): `vis`, `pcd`, `i`, `x`, `num_pool` and `xyz_flipped` are
# defined outside this view.
vis.add_geometry(pcd)
ctr = vis.get_view_control()
ctr.rotate(10.0, 0.0)
# NOTE(review): update_geometry() is called with no argument here but with
# `pcd` further down -- confirm which signature the installed Open3D expects.
vis.update_geometry()
vis.poll_events()
vis.update_renderer()
time.sleep(1)
vis.capture_screen_image('{}b.png'.format(i))
vis.remove_geometry(pcd)
# NOTE(review): if `pcd` is an open3d PointCloud, `pcd.points` is not a numpy
# array and may need np.asarray(...) before torch.from_numpy -- TODO confirm.
# `point` is not used in the visible code.
point = torch.from_numpy(pcd.points)
# FPS-subsample `num_pool` points and gather their coordinates.
# NOTE(review): `x_sub` is not used in the visible code.
x_sub = (
    pointnet2_utils.gather_operation(
        xyz_flipped, pointnet2_utils.furthest_point_sample(x, num_pool)
    ).transpose(1, 2).contiguous()
)
vis.update_geometry(pcd)
vis.poll_events()
vis.update_renderer()
time.sleep(1)
# NOTE(review): same file name as the capture above, so it overwrites it.
vis.capture_screen_image('{}b.png'.format(i))
vis.remove_geometry(pcd)
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(data)
# o3d.visualization.draw_geometries([pcd],
#                                   zoom=1,
#                                   front=[1, 1, -1],
#                                   lookat=[0, 0, 0],
def loadAndPredict(self, folder_path):
    """Load one sample, pick a random movable pixel and score a random gripper pose.

    Args:
        folder_path: directory passed to ``loadData``.

    Returns:
        point_cloud: (N, 3) numpy array of the FPS-sampled input points,
            N = ``self.train_conf.num_point_per_shape``.
        result: sigmoid scores from ``self.network.critic`` for every point,
            multiplied by the movable mask (non-movable points score 0).

    Raises:
        SystemExit: if the mask contains no movable pixel (as before).
        ValueError: if no valid point remains besides the query pixel.
    """
    import sys

    num_raw_points = 30000  # cloud size before FPS, including the query point

    # Load the data
    data = loadData(folder_path)
    gt_movable_link_mask = data['gt_movable_link_mask']
    rgb = data['rgb']

    # sample a pixel to interact
    xs, ys = np.where(gt_movable_link_mask > 0)
    if len(xs) == 0:
        print('No Movable Pixel! Quit!')
        sys.exit(1)
    idx = np.random.randint(len(xs))
    x, y = xs[idx], ys[idx]

    # prepare input pc
    cam_XYZA_id1, cam_XYZA_id2, cam_XYZA_pts = data['id1'], data['id2'], data['pc']
    cam_XYZA = compute_XYZA_matrix(cam_XYZA_id1, cam_XYZA_id2, cam_XYZA_pts,
                                   rgb.shape[0], rgb.shape[1])
    pt = cam_XYZA[x, y, :3]
    mask = (cam_XYZA[:, :, 3] > 0.5)
    mask[x, y] = False  # the query point is prepended explicitly below
    pc = cam_XYZA[mask, :3]
    if pc.shape[0] == 0:
        # Without this guard the tiling loop below would never terminate.
        raise ValueError('No valid points besides the query pixel in %s' % folder_path)
    pc_movable = (gt_movable_link_mask > 0)[mask]

    # Shuffle, tile until large enough, and keep num_raw_points - 1 points.
    idx = np.arange(pc.shape[0])
    np.random.shuffle(idx)
    while len(idx) < num_raw_points:
        idx = np.concatenate([idx, idx])
    idx = idx[:num_raw_points - 1]
    pc = pc[idx, :]
    pc_movable = pc_movable[idx]
    pc = np.vstack([pt, pc])
    pc_movable = np.append(True, pc_movable)
    pc[:, 0] -= 5  # fixed x-offset applied to the whole cloud

    pc = torch.from_numpy(pc).unsqueeze(0).to(self.device)
    input_pcid = furthest_point_sample(
        pc, self.train_conf.num_point_per_shape).long().reshape(-1)
    pc = pc[:, input_pcid, :3]  # 1 x N x 3
    pc_movable = pc_movable[input_pcid.cpu().numpy()]  # N

    # push through unet (xyz duplicated to 6 input channels)
    feats = self.network.pointnet2(pc.repeat(1, 1, 2))[0].permute(1, 0)  # N x F

    # sample a random direction to query
    gripper_direction_camera = torch.randn(1, 3).to(self.device)
    gripper_direction_camera = F.normalize(gripper_direction_camera, dim=1)
    gripper_forward_direction_camera = torch.randn(1, 3).to(self.device)
    gripper_forward_direction_camera = F.normalize(gripper_forward_direction_camera, dim=1)

    # Make `forward` orthogonal to `up`.
    up = gripper_direction_camera
    forward = gripper_forward_direction_camera
    left = torch.cross(up, forward, dim=1)
    forward = torch.cross(left, up, dim=1)
    forward = F.normalize(forward, dim=1)

    num_point = self.train_conf.num_point_per_shape
    dirs1 = up.repeat(num_point, 1)
    dirs2 = forward.repeat(num_point, 1)
    input_queries = torch.cat([dirs1, dirs2], dim=1)
    net = self.network.critic(feats, input_queries)
    result = torch.sigmoid(net).detach().cpu().numpy()
    result *= pc_movable
    point_cloud = pc.cpu().numpy()[0]
    return point_cloud, result
def forward(self, xyz, rgb, istrain=False):
    """Run the graph-conv + set-abstraction stack and pool to embeddings.

    Args:
        xyz: point coordinates, presumably (B, N, 3) -- TODO confirm.
        rgb: initial per-point features (B, N, C); used directly as ``h``.
        istrain: enables graph jitter / drop-connect when configured.

    Returns:
        If ``self.npart == 1``: the tensor produced by ``self.proj_output``.
        Otherwise: a list with one tensor per part from ``self.proj_outputs``.
    """
    hs = []
    last_point = -1
    h = rgb
    s2_count = 0
    for i in range(self.num_layers):
        h_input = h.clone()
        xyz_input = xyz.clone()
        batch_size, n_points, _ = h.shape

        # --- build (or reuse) the k-NN graph ---
        if self.k > 1:
            if i == self.num_layers - 1:
                if self.cluster == 'xyz':
                    g = self.nng(xyz, istrain=istrain and self.graph_jitter)
                elif self.cluster == 'rgb':
                    g = self.nng(h, istrain=istrain and self.graph_jitter)
                elif self.cluster == 'xyzrgb':
                    g = self.nng(torch.cat((xyz, h), 2),
                                 istrain=istrain and self.graph_jitter)
            elif i == 0 or n_points != last_point:
                # the xyz graph is only rebuilt when the point count changes
                g = self.nng(xyz, istrain=istrain and self.graph_jitter)
        last_point = n_points

        # --- graph convolution on flattened (B * N, C) features ---
        h = h.view(batch_size * n_points, -1)
        if self.k == 1:
            h = self.conv[i](h)
        elif self.conv_type == 'GatedGCN':
            h = self.conv[i](g, h, g.edata['feat'],
                             snorm_n=1 / g.number_of_nodes(),
                             snorm_e=1 / g.number_of_edges())
        else:
            h = self.conv[i](g, h)
        h = F.leaky_relu(h, 0.2)
        h = h.view(batch_size, n_points, -1)

        # --- set abstraction (may downsample) ---
        h = h.transpose(1, 2).contiguous()
        xyz, h = self.sa[i](xyz_input, h)
        h = h.transpose(1, 2).contiguous()

        # --- residual connection on the later layers ---
        if self.id_skip and h.shape[1] <= self.init_points // 4:
            if istrain and self.drop_connect_rate > 0:
                h = drop_connect(h, p=self.drop_connect_rate, training=istrain)
            if h.shape[1] == n_points:
                h = h_input + self.res_scale * h  # idea borrowed from Inception-ResNet-v2
            elif h.shape[1] == n_points // 2:
                # FPS-downsample the shortcut so it matches the halved h.
                h_input_s2 = pointnet2_utils.gather_operation(
                    h_input.transpose(1, 2).contiguous(),
                    pointnet2_utils.furthest_point_sample(xyz_input, h.shape[1]),
                ).transpose(1, 2).contiguous()
                h = self.conv_s2[s2_count](h_input_s2) + self.res_scale * h
                s2_count += 1

    if self.npart == 1:
        # Pooling
        h_max, _ = torch.max(h, 1)
        h_avg = torch.mean(h, 1)
        hs.append(h_max)
        hs.append(h_avg)
        h = torch.cat(hs, 1)
        h = self.embs[0](h)
        h = self.bn_embs[0](h)
        h = self.dropouts[0](h)
        h = self.proj_output(h)
    else:
        # Sort each cloud's points by their y coordinate before part pooling.
        # argsort yields per-cloud indices in [0, N), so gather inside each
        # batch element; indexing a flattened (B * N, C) view with them would
        # read every cloud's features from the first cloud.
        order = torch.argsort(xyz[:, :, 1], dim=1)  # (B, N)
        h = torch.gather(h, 1, order.unsqueeze(-1).expand(-1, -1, h.shape[2]))
        h = h.transpose(1, 2)
        # Part Pooling
        h = self.partpool(h)
        for i in range(self.npart):
            part_h = h[:, :, i]
            part_h = self.embs[i](part_h)
            part_h = self.bn_embs[i](part_h)
            part_h = self.dropouts[i](part_h)
            part_h = self.proj_outputs[i](part_h)
            hs.append(part_h)
        h = hs
    return h
def forward(self, xyz, rgb, istrain=False):
    """Dynamic-graph-conv backbone with set-abstraction downsampling and pooled embeddings.

    Each layer runs ``self.num_conv`` graph convolutions. When the output
    width changes (compared with the previous layer), it downsamples with
    ``self.sa[s2_count]``. A scaled residual is added either before
    (``id_skip == 1``) or after (``id_skip == 2``) the downsampling.

    Args:
        xyz: point coordinates, presumably (B, N, 3) -- TODO confirm.
        rgb: initial per-point features (B, N, C); used directly as ``h``.
        istrain: enables drop-connect when ``self.drop_connect_rate > 0``.

    Returns:
        If ``self.npart == 1``: the tensor produced by ``self.proj_output``.
        Otherwise: a list with one tensor per part from ``self.proj_outputs``.
    """
    hs = []
    #xyz_copy = xyz.clone()
    #rgb_copy = rgb.clone()
    batch_size, n_points, _ = xyz.shape
    # NOTE(review): part_length and last_point are assigned but never read.
    part_length = n_points // self.npart
    last_point = -1
    last_feature_dim = -1  # -1 forces a downsampling step on the first layer
    #h = self.proj_in(rgb)
    h = rgb
    s2_count = 0  # index into self.sa / self.conv_s2, advanced per width change
    for i in range(self.num_layers):
        h_input = h.clone()
        xyz_input = xyz.clone()
        batch_size, n_points, feature_dim = h.shape
        ######## Build Graph #########
        last_point = n_points
        ######### Dynamic Graph Conv #########
        # Switch to channel-first layout for the graph-feature / conv ops.
        xyz = xyz.transpose(1, 2).contiguous()
        #print(h.shape) # batchsize x point_number x feature_dim
        h = h.transpose(1, 2).contiguous()
        for j in range(self.num_conv):
            index = self.num_conv * i + j  # flat index into self.conv / self.bn
            ####### BN + ReLU (pre-activation variant) #####
            if self.pre_act == True:
                if self.norm == 'ln':
                    # LayerNorm normalises the last dim, so swap channels there and back.
                    h = h.transpose(1, 2).contiguous()
                    h = self.bn[index](h)
                    h = h.transpose(1, 2).contiguous()
                else:
                    h = self.bn[index](h)
                h = F.leaky_relu(h, 0.2)
            ####### Graph Feature ###########
            if self.k == 1 and j == 0:
                h = h.unsqueeze(-1)  # no neighbours: add a singleton neighbour dim
            else:
                if i == self.num_layers - 1:
                    # NOTE(review): for other cluster values (e.g. 'rgb') no
                    # graph feature is built on the last layer -- confirm intended.
                    if self.cluster == 'xyz':
                        h = get_graph_feature(xyz, h, k=self.k)
                    elif self.cluster == 'xyzrgb' or self.cluster == 'allxyzrgb':
                        h = get_graph_feature(torch.cat((xyz, h), 1), h, k=self.k)
                else:  # Common Layers
                    if self.cluster == 'allxyzrgb':
                        h = get_graph_feature(torch.cat((xyz, h), 1), h, k=self.k)
                    else:
                        h = get_graph_feature(xyz, h, k=self.k)
            ####### Conv ##########
            if self.light == True and i > 0:
                #shuffle after the first layer
                h = channel_shuffle(h, 2)
                h = self.conv[index](h)
            else:
                h = self.conv[index](h)
            # Max over the last (neighbour) dimension.
            h = h.max(dim=-1, keepdim=False)[0]
            ####### BN + ReLU (post-activation variant) #####
            if self.pre_act == False:
                if self.norm == 'ln':
                    h = h.transpose(1, 2).contiguous()
                    h = self.bn[index](h)
                    h = h.transpose(1, 2).contiguous()
                else:
                    h = self.bn[index](h)
                h = F.leaky_relu(h, 0.2)
        # Back to (batch, points, channels).
        h = h.transpose(1, 2).contiguous()
        #print(h.shape) # batchsize x point_number x feature_dim
        batch_size, n_points, feature_dim = h.shape
        ######### Residual Before Downsampling#############
        if self.id_skip == 1:
            if istrain and self.drop_connect_rate > 0:
                h = drop_connect(h, p=self.drop_connect_rate, training=istrain)
            if feature_dim != last_feature_dim:
                # widths differ: project the shortcut with conv_s2
                h_input = self.conv_s2[s2_count](h_input)
            h = h_input + self.res_scale * h
        ######### PointNet++ MSG ########
        if feature_dim != last_feature_dim:
            # Width changed: downsample points (and coordinates) with set abstraction.
            h = h.transpose(1, 2).contiguous()
            xyz, h = self.sa[s2_count](xyz_input, h)
            h = h.transpose(1, 2).contiguous()
            if self.id_skip == 2:
                # FPS-downsample the shortcut to the new point count.
                h_input = pointnet2_utils.gather_operation(
                    h_input.transpose(1, 2).contiguous(),
                    pointnet2_utils.furthest_point_sample(
                        xyz_input, h.shape[1])).transpose(1, 2).contiguous()
        else:
            # No downsampling: restore xyz to (batch, points, 3).
            xyz = xyz.transpose(1, 2).contiguous()
        ######### Residual After Downsampling (Paper) #############
        if self.id_skip == 2:
            if istrain and self.drop_connect_rate > 0:
                h = drop_connect(h, p=self.drop_connect_rate, training=istrain)
            if feature_dim != last_feature_dim:
                h_input = self.conv_s2[s2_count](h_input)
            h = h_input + self.res_scale * h
        if feature_dim != last_feature_dim:
            s2_count += 1
            last_feature_dim = feature_dim
        #print(xyz.shape, h.shape)
    if self.npart == 1:
        # Pooling: concatenate global max- and mean-pooled features.
        h_max, _ = torch.max(h, 1)
        h_avg = torch.mean(h, 1)
        hs.append(h_max)
        hs.append(h_avg)
        h = torch.cat(hs, 1)
        h = self.embs[0](h)
        h = self.bn_embs[0](h)
        h = self.dropouts[0](h)
        h = self.proj_output(h)
    else:
        # Sort
        #batch_size, n_points, _ = h.shape
        #y_index = torch.argsort(xyz[:, :, 1],dim = 1).view(batch_size * n_points, -1)
        #h = h.view(batch_size * n_points, -1)
        #h = h[y_index, :].view(batch_size, n_points, -1)
        h = h.transpose(1, 2)
        # Part Pooling: one embedding head per part.
        h = self.partpool(h)
        for i in range(self.npart):
            part_h = h[:, :, i]
            part_h = self.embs[i](part_h)
            part_h = self.bn_embs[i](part_h)
            part_h = self.dropouts[i](part_h)
            part_h = self.proj_outputs[i](part_h)
            hs.append(part_h)
        h = hs
    return h
def forward(batch, data_features, network, conf,
            is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0,
            log_console=False, log_tb=False, tb_writer=None, lr=None):
    """Run one critic step on a batch, log it, and optionally dump validation visuals.

    Args:
        batch: collated batch; fields are looked up via ``data_features.index(name)``.
        data_features: list of field names describing ``batch``.
        network: model returning ``(pred_result_logits, pred_whole_feats)``;
            ``network.critic.get_ce_loss`` gives the per-sample loss.
        conf: experiment config (device, num_point_per_shape, logging/visu options).
        is_val: tag the step as validation and allow visualisation dumps.
        step, epoch, batch_ind, num_batch, start_time, lr: progress info for logging.
        log_console, log_tb, tb_writer: logging switches / TensorBoard writer.

    Returns:
        (total_loss, pred_whole_feats, input_pcs, input_pxids, input_movables);
        every element except the loss is detached.
    """
    import subprocess
    import sys

    # prepare input
    input_pcs = torch.cat(batch[data_features.index('pcs')], dim=0).to(conf.device)  # B x 3N x 3
    input_pxids = torch.cat(batch[data_features.index('pc_pxids')], dim=0).to(conf.device)  # B x 3N x 2
    input_movables = torch.cat(batch[data_features.index('pc_movables')], dim=0).to(conf.device)  # B x 3N
    batch_size = input_pcs.shape[0]

    # FPS-subsample every shape: pcid1 is the batch index, pcid2 the point index.
    input_pcid1 = torch.arange(batch_size).unsqueeze(1).repeat(
        1, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcid2 = furthest_point_sample(
        input_pcs, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcs = input_pcs[input_pcid1, input_pcid2, :].reshape(
        batch_size, conf.num_point_per_shape, -1)
    input_pxids = input_pxids[input_pcid1, input_pcid2, :].reshape(
        batch_size, conf.num_point_per_shape, -1)
    input_movables = input_movables[input_pcid1, input_pcid2].reshape(
        batch_size, conf.num_point_per_shape)

    input_dirs1 = torch.cat(
        batch[data_features.index('gripper_direction_camera')], dim=0).to(conf.device)  # B x 3
    input_dirs2 = torch.cat(
        batch[data_features.index('gripper_forward_direction_camera')], dim=0).to(conf.device)  # B x 3

    # forward through the network
    pred_result_logits, pred_whole_feats = network(
        input_pcs, input_dirs1, input_dirs2)  # B x 2, B x F x N

    # prepare gt
    gt_result = torch.Tensor(batch[data_features.index('result')]).long().to(conf.device)  # B
    gripper_img_target = torch.cat(
        batch[data_features.index('gripper_img_target')], dim=0).to(conf.device)  # B x 3 x H x W

    # for each type of loss, compute losses per data
    result_loss_per_data = network.critic.get_ce_loss(pred_result_logits, gt_result)
    # for each type of loss, compute avg loss per batch
    result_loss = result_loss_per_data.mean()
    # compute total loss
    total_loss = result_loss

    # display information
    data_split = 'val' if is_val else 'train'
    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(
                conf.flog,
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}% '''
                f'''{lr:>5.2E} '''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_pc_dir = os.path.join(out_dir, 'input_pc')
            gripper_img_target_dir = os.path.join(out_dir, 'gripper_img_target')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders; tolerate existing ones (e.g. a re-run into the same exp_dir)
                for folder in (out_dir, input_pc_dir, gripper_img_target_dir, info_dir):
                    os.makedirs(folder, exist_ok=True)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)
                    render_utils.render_pts(os.path.join(BASE_DIR, input_pc_dir, fn),
                                            input_pcs[i].cpu().numpy(), highlight_id=0)
                    cur_gripper_img_target = (
                        gripper_img_target[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(cur_gripper_img_target).save(
                        os.path.join(gripper_img_target_dir, fn))
                    with open(os.path.join(info_dir, fn.replace('.png', '.txt')), 'w') as fout:
                        fout.write('cur_dir: %s\n' % batch[data_features.index('cur_dir')][i])
                        fout.write('pred: %s\n' % utils.print_true_false(
                            (pred_result_logits[i] > 0).cpu().numpy()))
                        fout.write('gt: %s\n' % utils.print_true_false(gt_result[i].cpu().numpy()))
                        fout.write('result_loss: %f\n' % result_loss_per_data[i].item())

            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_pc,gripper_img_target,info'
                # Run without a shell (cwd= instead of `cd`) so paths with spaces
                # or shell metacharacters are passed verbatim; use the current
                # interpreter rather than whatever `python` is on PATH.
                subprocess.call(
                    [sys.executable, os.path.join(BASE_DIR, 'gen_html_hierachy_local.py'),
                     '.', '10', 'htmls', sublist, sublist],
                    cwd=out_dir, stdout=subprocess.DEVNULL)
                utils.printout(conf.flog, 'DONE')

    return total_loss, pred_whole_feats.detach(), input_pcs.detach(), \
        input_pxids.detach(), input_movables.detach()
pc_movable = (gt_movable_link_mask > 0)[mask] idx = np.arange(pc.shape[0]) np.random.shuffle(idx) while len(idx) < 30000: idx = np.concatenate([idx, idx]) idx = idx[:30000-1] pc = pc[idx, :] pc_movable = pc_movable[idx] pcids = pcids[idx, :] pc = np.vstack([pt, pc]) pcids = np.vstack([ptid, pcids]) pc_movable = np.append(True, pc_movable) pc[:, 0] -= 5 pc = torch.from_numpy(pc).unsqueeze(0).to(device) input_pcid = furthest_point_sample(pc, train_conf.num_point_per_shape).long().reshape(-1) pc = pc[:, input_pcid, :3] # 1 x N x 3 pc_movable = pc_movable[input_pcid.cpu().numpy()] # N pcids = pcids[input_pcid.cpu().numpy()] pccolors = rgb[pcids[:, 0], pcids[:, 1]] # push through unet feats = network.pointnet2(pc.repeat(1, 1, 2))[0].permute(1, 0) # N x F # setup robot robot_urdf_fn = './robots/panda_gripper.urdf' robot_material = env.get_material(4, 4, 0.01) robot = Robot(env, robot_urdf_fn, robot_material) def plot_figure(up, forward): # cam to world