def forward(
            self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """

        new_features_list = []

        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = (pointnet2_utils.gather_operation(
            xyz_flipped, pointnet2_utils.furthest_point_sample(
                xyz, self.npoint)).transpose(1, 2).contiguous()
                   if self.npoint is not None else None)

        if self.fuse == 'add':
            for i in range(len(self.groupers)):
                new_features = self.groupers[i](
                    xyz, new_xyz, features)  # (B, C, npoint, nsample)

                new_features = self.mlps[i](
                    new_features)  # (B, mlp[-1], npoint, nsample)
                new_features = F.max_pool2d(new_features,
                                            kernel_size=[
                                                1, new_features.size(3)
                                            ])  # (B, mlp[-1], npoint, 1)
                new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
                if i == 0:
                    new_features_sum = new_features
                else:
                    new_features_sum += new_features
            return new_xyz, new_features_sum

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features)  # (B, C, npoint, nsample)

            new_features = self.mlps[i](
                new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(new_features,
                                        kernel_size=[
                                            1, new_features.size(3)
                                        ])  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1)
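Note: every example on this page leans on the compiled `pointnet2_utils.furthest_point_sample` CUDA op. For reference, a minimal pure-PyTorch equivalent (a sketch only; much slower, but dependency-free, and seeding the selection with index 0 like the CUDA kernel) could look like this:

import torch

def farthest_point_sample_torch(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
    """xyz: (B, N, 3) coordinates. Returns (B, npoint) int64 indices."""
    B, N, _ = xyz.shape
    idx = torch.zeros(B, npoint, dtype=torch.long, device=xyz.device)
    # distance from each point to its nearest already-selected point
    dist = torch.full((B, N), float('inf'), device=xyz.device)
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)  # seed: index 0
    batch = torch.arange(B, device=xyz.device)
    for i in range(npoint):
        idx[:, i] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)  # (B, 1, 3)
        dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))
        farthest = dist.argmax(dim=-1)  # next pick: farthest from the selected set
    return idx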
Example #2
def sample_and_group(npoint, radius, nsample, xyz, points):
    """
    Input:
        npoint: number of points to sample with FPS
        radius: ball-query radius (unused here; grouping is by kNN)
        nsample: number of neighbors per group
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 2D]
    """
    B, N, C = xyz.shape
    S = npoint
    xyz = xyz.contiguous()

    fps_idx = pointnet2_utils.furthest_point_sample(
        xyz, npoint).long()  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    new_points = index_points(points, fps_idx)
    # new_xyz = xyz[:]
    # new_points = points[:]

    idx = knn_point(nsample, xyz, new_xyz)
    #idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    grouped_points = index_points(points, idx)
    grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1)
    new_points = torch.cat([
        grouped_points_norm,
        new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)
    ],
                           dim=-1)
    return new_xyz, new_points
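Examples #2, #3, #6, #8, #11 and #12 assume the helpers `index_points`, `square_distance` and `knn_point` without defining them. A typical definition, sketched after the widely used PyTorch PointNet++ implementations (some variants of `knn_point` also return the distances, as Example #4 below seems to expect), is:

import torch

def index_points(points, idx):
    """points: (B, N, C); idx: (B, S) or (B, S, K) int64. Gathers along dim 1."""
    B = points.shape[0]
    view_shape = [B] + [1] * (idx.dim() - 1)
    repeat_shape = [1] + list(idx.shape[1:])
    batch_idx = torch.arange(B, device=points.device).view(view_shape).repeat(repeat_shape)
    return points[batch_idx, idx, :]

def square_distance(src, dst):
    """src: (B, M, C); dst: (B, N, C). Returns (B, M, N) squared distances."""
    return torch.cdist(src, dst) ** 2

def knn_point(k, xyz, new_xyz):
    """Indices of the k nearest points in xyz for each query in new_xyz: (B, S, k)."""
    dist = square_distance(new_xyz, xyz)  # (B, S, N)
    return dist.topk(k, dim=-1, largest=False).indices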
Example #3
def sample_and_group(npoint, nsample, xyz, points):
    """
    Input:
        npoint: number of FPS-sampled points
        nsample: number of kNN neighbors
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: grouped points data, [B, npoint, nsample, D]
    """
    B, N, C = xyz.shape
    S = npoint
    xyz = xyz.contiguous()  # xyz: [batch, points, xyz]

    # fps_idx = farthest_point_sample(xyz, npoint).long()
    fps_idx = pointnet2_utils.furthest_point_sample(
        xyz, npoint).long()  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    new_points = index_points(points, fps_idx)

    idx = knn_point(nsample, xyz, new_xyz)
    # idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_points = index_points(points, idx)
    return new_xyz, grouped_points
Example #4
def pool(xyz, points, k, npoint):
    xyz_flipped = xyz.transpose(1, 2).contiguous()
    new_xyz = pointnet2_utils.gather_operation(
        xyz_flipped, pointnet2_utils.furthest_point_sample(
            xyz, npoint)).transpose(1, 2).contiguous()
    _, idx = knn_point(k, xyz, new_xyz)
    new_points = torch.max(pointnet2_utils.grouping_operation(
        points.permute(0, 2, 1).contiguous(),
        idx.int().permute(0, 2, 1).contiguous()).permute(0, 3, 2, 1),
                           dim=2).values

    return new_xyz, new_points
Example #5
def fps_subsample(pcd, n_points=2048):
    """
    Args
        pcd: (b, 16384, 3)

    returns
        new_pcd: (b, n_points, 3)
    """
    new_pcd = gather_operation(
        pcd.permute(0, 2, 1).contiguous(),
        furthest_point_sample(pcd, n_points))
    new_pcd = new_pcd.permute(0, 2, 1).contiguous()
    return new_pcd
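A hypothetical call (assuming the pointnet2 CUDA extension is built and a GPU is available; the op expects float32 CUDA tensors):

pcd = torch.rand(2, 16384, 3, device='cuda')   # (b, 16384, 3)
new_pcd = fps_subsample(pcd, n_points=2048)
print(new_pcd.shape)  # torch.Size([2, 2048, 3])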
Example #6
 def forward(self, x):
     """
     :param x: input data points [b, n, 3+c]; the first 3 dims are coordinates
     :return: grouped_points [b, points, knn, 3+c]
     Note: the sampled points themselves are grouped_points[:, :, 0, :]
     """
     # sampled_points = index_points(x, farthest_point_sample(x[:, :, :3], self.points))  # [b, points, 3]
     idx = furthest_point_sample((x[:, :, :3]).contiguous(), self.points).long()
     sampled_points = index_points(x, idx)  # [b, points, 3+c]
     distances = square_distance(sampled_points[:, :, :3], x[:, :, :3])  # includes each sampled point itself
     knn_idx = distances.argsort()[:, :, :self.knn]
     grouped_points = index_points(x, knn_idx)  # [b, points, knn, 3+c]
     return grouped_points
Example #7
 def forward(self, x, feature, num_pool):
     # x: (b, 3, n) point coordinates; feature: (b, c, n) point descriptors
     x = x.contiguous()
     xyz_flipped = x.transpose(1, 2).contiguous()  # (b, n, 3)
     x_sub = (pointnet2_utils.gather_operation(
         x, pointnet2_utils.furthest_point_sample(xyz_flipped,
                                                  num_pool)).transpose(
                                                      1, 2).contiguous())  # (b, num_pool, 3)
     # sub_index = self.knn(x_sub, xyz_flipped).int()
     sub_index = pointnet2_utils.ball_query(0.2 * self.nlayer, self.k,
                                            xyz_flipped, x_sub)  # (b, num_pool, k)
     x = pointnet2_utils.grouping_operation(x, sub_index)  # (b, 3, num_pool, k)
     x = torch.max(x, dim=-1)[0]  # (b, 3, num_pool)
     # x = soft_pool2d(x, [1, x.shape[-1]]).squeeze(-1)
     feature = pointnet2_utils.grouping_operation(feature, sub_index)  # (b, c, num_pool, k)
     feature = self.mlp(feature)
     feature = torch.max(feature, dim=-1)[0]  # pooled: (b, c_out, num_pool)
     # feature = soft_pool2d(feature, [1, feature.shape[-1]]).squeeze(-1)
     return x, feature
Example #8
    def forward(self, xyz, points):
        B, N, C = xyz.shape
        S = self.groups
        xyz = xyz.contiguous()  # xyz: [batch, points, xyz]

        # fps_idx = farthest_point_sample(xyz, self.groups).long()
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.groups).long() # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)
        new_points = index_points(points, fps_idx)

        idx = knn_point(self.kneighbors, xyz, new_xyz)
        # idx = query_ball_point(radius, nsample, xyz, new_xyz)
        # grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
        grouped_points = index_points(points, idx)
        grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1)
        new_points = torch.cat([grouped_points_norm,
                                new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)]
                               , dim=-1)
        return new_xyz, new_points
Example #9
    def __init__(
            self, transforms=None, train=True, self_supervision=False, num_points=2048,
    ):
        super().__init__()

        self.transforms = transforms

        self.self_supervision = self_supervision

        self.train = train

        self.num_points = num_points

        root = '/home/zhangniansong/data/ScanObjNN/main_split_nobg/'
        if self.self_supervision:
            h5 = h5py.File(root + 'training_objectdataset.h5', 'r')
            points_train = np.array(h5['data']).astype(np.float32)
            h5.close()
            self.points = points_train
            self.labels = None
        elif train:
            h5 = h5py.File(root + 'training_objectdataset.h5', 'r')
            self.points = np.array(h5['data']).astype(np.float32)
            self.labels = np.array(h5['label']).astype(int)
            h5.close()
        else:
            h5 = h5py.File(root + 'test_objectdataset.h5', 'r')
            self.points = np.array(h5['data']).astype(np.float32)
            self.labels = np.array(h5['label']).astype(int)
            h5.close()

        self.points_ = torch.tensor(self.points).cuda() # maybe modify to `to(device)`
        fps_idx = pointnet2_utils.furthest_point_sample(self.points_, self.num_points)
        self.points_ = index_points(self.points_, fps_idx.long())
        
        # BUG FIX: if the data stays on the GPU inside the dataset,
        # DataLoader worker processes raise a CUDA initialization error,
        # so move it back to the CPU here.
        self.points = self.points_.cpu()

        del self.points_
        torch.cuda.empty_cache()

        # use len(self.points): self.labels is None when self_supervision=True
        print('Successfully loaded ScanObjectNN with', len(self.points), 'instances')
Example #10
def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
    """
    Args:
        xyz: Tensor, (B, 3, N)
        points: Tensor, (B, f, N)
        npoint: int
        nsample: int
        radius: float
        use_xyz: boolean

    Returns:
        new_xyz: Tensor, (B, 3, npoint)
        new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
        idx_local: Tensor, (B, npoint, nsample)
        grouped_xyz: Tensor, (B, 3, npoint, nsample)

    """
    xyz_flipped = xyz.permute(0, 2, 1).contiguous()  # (B, N, 3)
    new_xyz = gather_operation(xyz,
                               furthest_point_sample(xyz_flipped,
                                                     npoint))  # (B, 3, npoint)

    idx = ball_query(radius, nsample, xyz_flipped,
                     new_xyz.permute(0, 2,
                                     1).contiguous())  # (B, npoint, nsample)
    grouped_xyz = grouping_operation(xyz, idx)  # (B, 3, npoint, nsample)
    grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, nsample)

    if points is not None:
        grouped_points = grouping_operation(points,
                                            idx)  # (B, f, npoint, nsample)
        if use_xyz:
            new_points = torch.cat([grouped_xyz, grouped_points], 1)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz

    return new_xyz, new_points, idx, grouped_xyz
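Unlike Examples #2 and #3, this variant keeps the channel-first (B, 3, N) layout throughout. A hypothetical shape check (again assuming the CUDA ops and a GPU):

xyz = torch.rand(2, 3, 1024, device='cuda')      # (B, 3, N)
feats = torch.rand(2, 32, 1024, device='cuda')   # (B, f, N)
new_xyz, new_points, idx, grouped_xyz = sample_and_group(
    xyz, feats, npoint=256, nsample=16, radius=0.2)
# new_xyz: (2, 3, 256); new_points: (2, 35, 256, 16) since use_xyz=True (f + 3 = 35)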
Example #11
def sample_and_ball_group(s, radius, n, coords, features):
    """
    Sampling by FPS and grouping by ball query.

    Input:
        s[int]: number of points to be sampled by FPS
        radius[float]: radius of the ball query
        n[int]: fixed number of points in each ball neighborhood
        coords[tensor]: input points coordinates data with size of [B, N, 3]
        features[tensor]: input points features data with size of [B, N, D]

    Returns:
        new_coords[tensor]: coordinates of the FPS-sampled points, with size of [B, s, 3]
        aggregated_features[tensor]: grouped point features, with size of [B, s, n, 2D]
    """
    batch_size = coords.shape[0]
    coords = coords.contiguous()

    # FPS sampling
    fps_idx = pointnet2_utils.furthest_point_sample(coords, s).long()  # [B, s]
    new_coords = index_points(coords, fps_idx)  # [B, s, 3]
    new_features = index_points(features, fps_idx)  # [B, s, D]

    # ball_query grouping
    idx = query_ball_point(radius, n, coords, new_coords)  # [B, s, n]
    grouped_features = index_points(features, idx)  # [B, s, n, D]

    # Matrix sub
    grouped_features_norm = grouped_features - new_features.view(
        batch_size, s, 1, -1)  # [B, s, n, D]

    # Concat; this step may differ across networks
    aggregated_features = torch.cat([
        grouped_features_norm,
        new_features.view(batch_size, s, 1, -1).repeat(1, 1, n, 1)
    ],
                                    dim=-1)  # [B, s, n, 2D]

    return new_coords, aggregated_features  # [B, s, 3], [B, s, n, 2D]
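`query_ball_point` is likewise assumed rather than defined here. The usual pure-PyTorch sketch (building on `square_distance` above) marks out-of-radius points, keeps the first n in-ball indices, and pads under-full balls by repeating the first valid neighbor:

def query_ball_point(radius, nsample, xyz, new_xyz):
    """xyz: (B, N, 3); new_xyz: (B, S, 3). Returns (B, S, nsample) indices."""
    B, N, _ = xyz.shape
    S = new_xyz.shape[1]
    group_idx = torch.arange(N, device=xyz.device).view(1, 1, N).repeat(B, S, 1)
    sqrdists = square_distance(new_xyz, xyz)          # (B, S, N)
    group_idx[sqrdists > radius ** 2] = N             # mark points outside the ball
    group_idx = group_idx.sort(dim=-1).values[:, :, :nsample]
    first = group_idx[:, :, 0:1].repeat(1, 1, nsample)
    mask = group_idx == N                             # under-full balls
    group_idx[mask] = first[mask]
    return group_idx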
Example #12
def sample_and_knn_group(s, k, coords, features):
    """
    Sampling by FPS and grouping by KNN.

    Input:
        s[int]: number of points to be sampled by FPS
        k[int]: number of points to be grouped into a neighbor by KNN
        coords[tensor]: input points coordinates data with size of [B, N, 3]
        features[tensor]: input points features data with size of [B, N, D]
    
    Returns:
        new_coords[tensor]: coordinates of the FPS-sampled points, with size of [B, s, 3]
        aggregated_features[tensor]: grouped point features, with size of [B, s, k, 2D]
    """
    batch_size = coords.shape[0]
    coords = coords.contiguous()

    # FPS sampling
    fps_idx = pointnet2_utils.furthest_point_sample(coords, s).long()  # [B, s]
    new_coords = index_points(coords, fps_idx)  # [B, s, 3]
    new_features = index_points(features, fps_idx)  # [B, s, D]

    # K-nn grouping
    idx = knn_point(k, coords, new_coords)  # [B, s, k]
    grouped_features = index_points(features, idx)  # [B, s, k, D]

    # Matrix sub
    grouped_features_norm = grouped_features - new_features.view(
        batch_size, s, 1, -1)  # [B, s, k, D]

    # Concat
    aggregated_features = torch.cat([
        grouped_features_norm,
        new_features.view(batch_size, s, 1, -1).repeat(1, 1, k, 1)
    ],
                                    dim=-1)  # [B, s, k, 2D]

    return new_coords, aggregated_features  # [B, s, 3], [B, s, k, 2D]
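A hypothetical shape check for the kNN variant (CUDA ops and GPU assumed):

coords = torch.rand(2, 1024, 3, device='cuda')
features = torch.rand(2, 1024, 64, device='cuda')
new_coords, new_features = sample_and_knn_group(s=256, k=32,
                                                coords=coords, features=features)
# new_coords: (2, 256, 3); new_features: (2, 256, 32, 128)  (2D with D = 64)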
Example #13
            vis.add_geometry(pcd)

            ctr = vis.get_view_control()
            ctr.rotate(10.0, 0.0)

            vis.update_geometry()
            vis.poll_events()
            vis.update_renderer()
            time.sleep(1)
            vis.capture_screen_image('{}b.png'.format(i))
            vis.remove_geometry(pcd)

            point = torch.from_numpy(np.asarray(pcd.points))  # pcd.points is an open3d Vector3dVector
            # x and xyz_flipped come from earlier in the original script (not shown here)
            x_sub = (
            pointnet2_utils.gather_operation(
                xyz_flipped, pointnet2_utils.furthest_point_sample(x, num_pool)
                ).transpose(1, 2).contiguous()
            )
            vis.update_geometry(pcd)
            vis.poll_events()
            vis.update_renderer()
            time.sleep(1)
            vis.capture_screen_image('{}b.png'.format(i))
            vis.remove_geometry(pcd)

            # pcd = o3d.geometry.PointCloud()
            # pcd.points = o3d.utility.Vector3dVector(data)
            # o3d.visualization.draw_geometries([pcd],
            #                         zoom=1,
            #                         front=[1, 1, -1],
            #                         lookat=[0, 0, 0],
Example #14
    def loadAndPredict(self, folder_path):
        # Load the data
        data = loadData(folder_path)

        gt_movable_link_mask = data['gt_movable_link_mask']
        rgb = data['rgb']

        # sample a pixel to interact
        xs, ys = np.where(gt_movable_link_mask > 0)
        if len(xs) == 0:
            print('No Movable Pixel! Quit!')
            exit(1)
        idx = np.random.randint(len(xs))
        x, y = xs[idx], ys[idx]
        marked_rgb = (rgb * 255).astype(np.uint8)
        marked_rgb = cv2.circle(marked_rgb, (y, x),
                                radius=3,
                                color=(0, 0, 255),
                                thickness=5)

        # prepare input pc
        cam_XYZA_id1, cam_XYZA_id2, cam_XYZA_pts = data['id1'], data[
            'id2'], data['pc']
        cam_XYZA = compute_XYZA_matrix(cam_XYZA_id1, cam_XYZA_id2,
                                       cam_XYZA_pts, rgb.shape[0],
                                       rgb.shape[1])
        pt = cam_XYZA[x, y, :3]
        ptid = np.array([x, y], dtype=np.int32)
        mask = (cam_XYZA[:, :, 3] > 0.5)
        mask[x, y] = False
        pc = cam_XYZA[mask, :3]
        grid_x, grid_y = np.meshgrid(np.arange(448), np.arange(448))
        grid_xy = np.stack([grid_y, grid_x]).astype(np.int32)  # 2 x 448 x 448
        pcids = grid_xy[:, mask].T
        pc_movable = (gt_movable_link_mask > 0)[mask]
        idx = np.arange(pc.shape[0])
        np.random.shuffle(idx)
        while len(idx) < 30000:
            idx = np.concatenate([idx, idx])
        idx = idx[:30000 - 1]
        pc = pc[idx, :]
        pc_movable = pc_movable[idx]
        pcids = pcids[idx, :]
        pc = np.vstack([pt, pc])
        pcids = np.vstack([ptid, pcids])
        pc_movable = np.append(True, pc_movable)
        pc[:, 0] -= 5
        pc = torch.from_numpy(pc).unsqueeze(0).to(self.device)

        input_pcid = furthest_point_sample(
            pc, self.train_conf.num_point_per_shape).long().reshape(-1)
        pc = pc[:, input_pcid, :3]  # 1 x N x 3
        pc_movable = pc_movable[input_pcid.cpu().numpy()]  # N
        pcids = pcids[input_pcid.cpu().numpy()]
        pccolors = rgb[pcids[:, 0], pcids[:, 1]]

        # push through unet
        feats = self.network.pointnet2(pc.repeat(1, 1,
                                                 2))[0].permute(1, 0)  # N x F

        # sample a random direction to query
        gripper_direction_camera = torch.randn(1, 3).to(self.device)
        gripper_direction_camera = F.normalize(gripper_direction_camera, dim=1)
        gripper_forward_direction_camera = torch.randn(1, 3).to(self.device)
        gripper_forward_direction_camera = F.normalize(
            gripper_forward_direction_camera, dim=1)

        up = gripper_direction_camera
        forward = gripper_forward_direction_camera
        left = torch.cross(up, forward)
        forward = torch.cross(left, up)
        forward = F.normalize(forward, dim=1)

        dirs1 = up.repeat(self.train_conf.num_point_per_shape, 1)
        dirs2 = forward.repeat(self.train_conf.num_point_per_shape, 1)

        input_queries = torch.cat([dirs1, dirs2], dim=1)
        net = self.network.critic(feats, input_queries)
        result = torch.sigmoid(net).detach().cpu().numpy()
        result *= pc_movable

        point_cloud = pc.cpu().numpy()[0]

        return point_cloud, result
Example #15
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points//self.npart
        last_point = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, _ = h.shape
            if self.k>1:
                if i == self.num_layers-1:
                    if self.cluster == 'xyz':
                        g = self.nng(xyz, istrain = istrain and self.graph_jitter)
                    elif self.cluster == 'rgb':
                        g = self.nng(h, istrain=istrain and self.graph_jitter)
                    elif self.cluster == 'xyzrgb':
                        g = self.nng( torch.cat((xyz,h), 2), istrain=istrain and self.graph_jitter)
                elif i==0 or  n_points !=  last_point:
                    g = self.nng(xyz, istrain=istrain and self.graph_jitter)
            last_point = n_points
            h = h.view(batch_size * n_points, -1)

            if self.k==1:
                h = self.conv[i](h)
            elif self.conv_type == 'GatedGCN':
                h = self.conv[i](g, h, g.edata['feat'], snorm_n = 1/g.number_of_nodes() , snorm_e = 1/g.number_of_edges())
            else:
                h = self.conv[i](g, h)
            h = F.leaky_relu(h, 0.2)
            h = h.view(batch_size, n_points, -1)
            h = h.transpose(1, 2).contiguous()
            xyz, h  = self.sa[i](xyz_input, h)
            h = h.transpose(1, 2).contiguous()
            #h = self.conv_s1[i](h)
            if self.id_skip and h.shape[1] <= self.init_points//4:
                # Use an identity mapping here, or apply drop-connect
                if istrain and self.drop_connect_rate>0:
                    h = drop_connect(h, p=self.drop_connect_rate, training=istrain)

                if h.shape[1] == n_points:
                    h = h_input + self.res_scale * h  # Here I borrow the idea from Inception-ResNet-v2
                elif h.shape[1] == n_points//2:
                    h_input_s2 = pointnet2_utils.gather_operation(
                        h_input.transpose(1, 2).contiguous(), 
                        pointnet2_utils.furthest_point_sample(xyz_input, h.shape[1] )
                    ).transpose(1, 2).contiguous()
                    h = self.conv_s2[s2_count](h_input_s2) + self.res_scale * h
                    s2_count += 1
        if self.npart==1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort 
            batch_size, n_points, _ = h.shape
            y_index = torch.argsort(xyz[:, :, 1],dim = 1).view(batch_size * n_points, -1)
            h = h.view(batch_size * n_points, -1)
            h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling            
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:,:,i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h
Example #16
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points // self.npart
        last_point = -1
        last_feature_dim = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, feature_dim = h.shape

            ######## Build Graph #########
            last_point = n_points

            ######### Dynamic Graph Conv #########
            xyz = xyz.transpose(1, 2).contiguous()
            #print(h.shape) # batchsize x point_number x feature_dim
            h = h.transpose(1, 2).contiguous()
            for j in range(self.num_conv):
                index = self.num_conv * i + j
                ####### BN + ReLU #####
                if self.pre_act:
                    if self.norm == 'ln':
                        h = h.transpose(1, 2).contiguous()
                        h = self.bn[index](h)
                        h = h.transpose(1, 2).contiguous()
                    else:
                        h = self.bn[index](h)
                    h = F.leaky_relu(h, 0.2)

                ####### Graph Feature ###########
                if self.k == 1 and j == 0:
                    h = h.unsqueeze(-1)
                else:
                    if i == self.num_layers - 1:
                        if self.cluster == 'xyz':
                            h = get_graph_feature(xyz, h, k=self.k)
                        elif self.cluster == 'xyzrgb' or self.cluster == 'allxyzrgb':
                            h = get_graph_feature(torch.cat((xyz, h), 1),
                                                  h,
                                                  k=self.k)
                    else:
                        # Common Layers
                        if self.cluster == 'allxyzrgb':
                            h = get_graph_feature(torch.cat((xyz, h), 1),
                                                  h,
                                                  k=self.k)
                        else:
                            h = get_graph_feature(xyz, h, k=self.k)

                ####### Conv ##########
                if self.light and i > 0:
                    # shuffle after the first layer
                    h = channel_shuffle(h, 2)
                    h = self.conv[index](h)
                else:
                    h = self.conv[index](h)
                h = h.max(dim=-1, keepdim=False)[0]
                ####### BN + ReLU #####
                if not self.pre_act:
                    if self.norm == 'ln':
                        h = h.transpose(1, 2).contiguous()
                        h = self.bn[index](h)
                        h = h.transpose(1, 2).contiguous()
                    else:
                        h = self.bn[index](h)
                    h = F.leaky_relu(h, 0.2)

            h = h.transpose(1, 2).contiguous()
            #print(h.shape) # batchsize x point_number x feature_dim
            batch_size, n_points, feature_dim = h.shape

            ######### Residual Before Downsampling#############
            if self.id_skip == 1:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h,
                                     p=self.drop_connect_rate,
                                     training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            ######### PointNet++ MSG ########
            if feature_dim != last_feature_dim:
                h = h.transpose(1, 2).contiguous()
                xyz, h = self.sa[s2_count](xyz_input, h)
                h = h.transpose(1, 2).contiguous()
                if self.id_skip == 2:
                    h_input = pointnet2_utils.gather_operation(
                        h_input.transpose(1, 2).contiguous(),
                        pointnet2_utils.furthest_point_sample(
                            xyz_input, h.shape[1])).transpose(1,
                                                              2).contiguous()
            else:
                xyz = xyz.transpose(1, 2).contiguous()

            ######### Residual After Downsampling (Paper) #############
            if self.id_skip == 2:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h,
                                     p=self.drop_connect_rate,
                                     training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            if feature_dim != last_feature_dim:
                s2_count += 1
                last_feature_dim = feature_dim

            #print(xyz.shape, h.shape)
        if self.npart == 1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort
            #batch_size, n_points, _ = h.shape
            #y_index = torch.argsort(xyz[:, :, 1],dim = 1).view(batch_size * n_points, -1)
            #h = h.view(batch_size * n_points, -1)
            #h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:, :, i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h
Example #17
def forward(batch, data_features, network, conf, \
        is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0, \
        log_console=False, log_tb=False, tb_writer=None, lr=None):
    # prepare input
    input_pcs = torch.cat(batch[data_features.index('pcs')],
                          dim=0).to(conf.device)  # B x 3N x 3
    input_pxids = torch.cat(batch[data_features.index('pc_pxids')],
                            dim=0).to(conf.device)  # B x 3N x 2
    input_movables = torch.cat(batch[data_features.index('pc_movables')],
                               dim=0).to(conf.device)  # B x 3N
    batch_size = input_pcs.shape[0]

    input_pcid1 = torch.arange(batch_size).unsqueeze(1).repeat(
        1, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcid2 = furthest_point_sample(
        input_pcs, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcs = input_pcs[input_pcid1,
                          input_pcid2, :].reshape(batch_size,
                                                  conf.num_point_per_shape, -1)
    input_pxids = input_pxids[input_pcid1,
                              input_pcid2, :].reshape(batch_size,
                                                      conf.num_point_per_shape,
                                                      -1)
    input_movables = input_movables[input_pcid1, input_pcid2].reshape(
        batch_size, conf.num_point_per_shape)

    input_dirs1 = torch.cat(
        batch[data_features.index('gripper_direction_camera')],
        dim=0).to(conf.device)  # B x 3
    input_dirs2 = torch.cat(
        batch[data_features.index('gripper_forward_direction_camera')],
        dim=0).to(conf.device)  # B x 3

    # forward through the network
    pred_result_logits, pred_whole_feats = network(
        input_pcs, input_dirs1, input_dirs2)  # B x 2, B x F x N

    # prepare gt
    gt_result = torch.Tensor(batch[data_features.index('result')]).long().to(
        conf.device)  # B
    gripper_img_target = torch.cat(
        batch[data_features.index('gripper_img_target')],
        dim=0).to(conf.device)  # B x 3 x H x W

    # for each type of loss, compute losses per data
    result_loss_per_data = network.critic.get_ce_loss(pred_result_logits,
                                                      gt_result)

    # for each type of loss, compute avg loss per batch
    result_loss = result_loss_per_data.mean()

    # compute total loss
    total_loss = result_loss

    # display information
    data_split = 'train'
    if is_val:
        data_split = 'val'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog, \
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}%      '''
                f'''{lr:>5.2E} '''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (
                not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_pc_dir = os.path.join(out_dir, 'input_pc')
            gripper_img_target_dir = os.path.join(out_dir,
                                                  'gripper_img_target')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders
                os.mkdir(out_dir)
                os.mkdir(input_pc_dir)
                os.mkdir(gripper_img_target_dir)
                os.mkdir(info_dir)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, input_pc_dir, fn),
                                            input_pcs[i].cpu().numpy(),
                                            highlight_id=0)
                    cur_gripper_img_target = (
                        gripper_img_target[i].permute(1, 2, 0).cpu().numpy() *
                        255).astype(np.uint8)
                    Image.fromarray(cur_gripper_img_target).save(
                        os.path.join(gripper_img_target_dir, fn))
                    with open(
                            os.path.join(info_dir, fn.replace('.png', '.txt')),
                            'w') as fout:
                        fout.write('cur_dir: %s\n' %
                                   batch[data_features.index('cur_dir')][i])
                        fout.write('pred: %s\n' % utils.print_true_false(
                            (pred_result_logits[i] > 0).cpu().numpy()))
                        fout.write(
                            'gt: %s\n' %
                            utils.print_true_false(gt_result[i].cpu().numpy()))
                        fout.write('result_loss: %f\n' %
                                   result_loss_per_data[i].item())

            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_pc,gripper_img_target,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (
                    out_dir,
                    os.path.join(BASE_DIR, 'gen_html_hierachy_local.py'),
                    sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return total_loss, pred_whole_feats.detach(), input_pcs.detach(
    ), input_pxids.detach(), input_movables.detach()
Example #18
pc_movable = (gt_movable_link_mask > 0)[mask]
idx = np.arange(pc.shape[0])
np.random.shuffle(idx)
while len(idx) < 30000:
    idx = np.concatenate([idx, idx])
idx = idx[:30000-1]
pc = pc[idx, :]
pc_movable = pc_movable[idx]
pcids = pcids[idx, :]
pc = np.vstack([pt, pc])
pcids = np.vstack([ptid, pcids])
pc_movable = np.append(True, pc_movable)
pc[:, 0] -= 5
pc = torch.from_numpy(pc).unsqueeze(0).to(device)

input_pcid = furthest_point_sample(pc, train_conf.num_point_per_shape).long().reshape(-1)
pc = pc[:, input_pcid, :3]  # 1 x N x 3
pc_movable = pc_movable[input_pcid.cpu().numpy()]     # N
pcids = pcids[input_pcid.cpu().numpy()]
pccolors = rgb[pcids[:, 0], pcids[:, 1]]

# push through unet
feats = network.pointnet2(pc.repeat(1, 1, 2))[0].permute(1, 0)    # N x F

# setup robot
robot_urdf_fn = './robots/panda_gripper.urdf'
robot_material = env.get_material(4, 4, 0.01)
robot = Robot(env, robot_urdf_fn, robot_material)

def plot_figure(up, forward):
    # cam to world