def forward(self, coords, features):
        r"""
            Aggregate local features around every point.

            Two (spatial encoding -> pooling) stages are applied on top of a
            first MLP; the result goes through a second MLP and is added to a
            shortcut branch computed from the raw input features.

            Parameters
            ----------
            coords: torch.Tensor, shape (B, N, 3)
                coordinates of the point cloud
            features: torch.Tensor, shape (B, d_in, N, 1)
                features of the point cloud

            Returns
            -------
            torch.Tensor, shape (B, 2*d_out, N, 1)
        """
        # The neighbour search is run on a contiguous CPU copy of the
        # coordinates; the same neighbourhood is shared by both stages.
        cpu_points = coords.cpu().contiguous()
        neighborhood = knn(cpu_points, cpu_points, self.num_neighbors)

        hidden = self.mlp1(features)
        stages = (
            (self.lse1, self.pool1),
            (self.lse2, self.pool2),
        )
        for encode, pool in stages:
            hidden = pool(encode(coords, hidden, neighborhood))

        # Residual connection: main branch plus shortcut on the raw features.
        residual_sum = self.mlp2(hidden) + self.shortcut(features)
        return self.lrelu(residual_sum)
    def forward(self, inputs):
        r"""
            Encode every point relative to its nearest neighbours, then run
            the linear -> (optional norm) -> ReLU -> max-pool stage.

            Parameters
            ----------
            inputs: mapping
                must provide ``inputs['points']``; its first three columns
                are taken as the point coordinates

            Returns
            -------
            torch.Tensor
                the max over dim 1 (kept as a singleton dim) when
                ``self.last_vfe`` is set; otherwise the features concatenated
                along dim 2 with that maximum repeated along dim 1

            Notes
            -----
            NOTE(review): the neighbour encoding below unpacks ``idx`` as
            (B, N, K) and therefore assumes batched (B, N, 3) coordinates --
            confirm the layout of ``inputs['points']``.
            NOTE(review): ``concat`` is assembled as a 4-D tensor while the
            norm / repeat steps use 3-D ``permute`` / ``repeat`` arguments --
            confirm the intended shapes.
        """
        # finding neighboring points (the search runs on a CPU copy)
        coords = inputs['points'][:, :3]
        cpu_coords = coords.cpu().contiguous()
        idx, dist = knn(cpu_coords, cpu_coords, self.num_neighbors)
        B, N, K = idx.size()

        # idx(B, N, K), coords(B, N, 3)
        # neighbors[b, i, n, k] = coords[b, idx[b, n, k], i]
        #                       = extended_coords[b, i, extended_idx[b, i, n, k], k]
        # NOTE(review): presumably coords already lives on the same device as
        # idx (CPU) for the gather -- confirm against the caller.
        extended_idx = idx.unsqueeze(1).expand(B, 3, N, K)
        extended_coords = coords.transpose(-2, -1).unsqueeze(-1).expand(B, 3, N, K)
        neighbors = torch.gather(extended_coords, 2, extended_idx)  # shape (B, 3, N, K)

        # relative point position encoding: [B, 3+3+3+1, N, K]
        concat = torch.cat((
            extended_coords,
            neighbors,
            extended_coords - neighbors,
            dist.unsqueeze(-3)
        ), dim=-3).to(self.device)

        batch = concat.shape[0]
        if batch > self.part:
            # The linear layer is applied chunk by chunk for large batches
            # (the original author observed erratic nn.Linear results
            # otherwise). Stepping by ``self.part`` never yields an empty
            # trailing chunk when ``batch`` is an exact multiple of it.
            chunks = [self.linear(concat[start:start + self.part])
                      for start in range(0, batch, self.part)]
            x = torch.cat(chunks, dim=0)
        else:
            x = self.linear(concat)

        if self.use_norm:
            # cuDNN is switched off only around the norm call and the
            # caller's previous setting is restored afterwards, even if the
            # norm raises, instead of unconditionally forcing it to True.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled
        x = F.relu(x)

        # maxpooling
        x_max = torch.max(x, dim=1, keepdim=True)[0]

        if self.last_vfe:
            return x_max

        # ``inputs`` is indexed like a mapping above, so it has no ``.shape``;
        # the repeat count must equal x's own dim-1 size for the
        # concatenation along dim 2 to be valid.
        x_repeat = x_max.repeat(1, x.shape[1], 1)
        return torch.cat([x, x_repeat], dim=2)
    def forward(self, input):
        r"""
            Compute per-point segmentation scores.

            Parameters
            ----------
            input: torch.Tensor, shape (B, N, d_in)
                input points; the first three channels are used as
                coordinates

            Returns
            -------
            torch.Tensor, shape (B, num_classes, N)
                segmentation scores for each point
        """
        num_points = input.size(1)
        ratio = self.decimation

        xyz = input[..., :3].clone().cpu()
        feats = self.fc_start(input)
        feats = feats.transpose(-2, -1).unsqueeze(-1)
        feats = self.bn_start(feats)  # shape (B, d, N, 1)

        # Shuffle the points once; from here on, keeping a prefix of the
        # point axis acts as the random down-sampling step.
        shuffle = torch.randperm(num_points)
        xyz = xyz[:, shuffle]
        feats = feats[:, :, shuffle]

        # ---- encoder: aggregate, remember the result, then decimate ----
        skips = []
        stride = 1
        for layer in self.encoder:
            feats = layer(xyz[:, :num_points // stride], feats)
            skips.append(feats.clone())
            stride *= ratio
            feats = feats[:, :, :num_points // stride]

        feats = self.mlp(feats)

        # ---- decoder: nearest-neighbour upsampling + skip connection ----
        for layer in self.decoder:
            coarse = xyz[:, :num_points // stride].cpu().contiguous()
            fine = xyz[:, :ratio * num_points // stride].cpu().contiguous()
            nearest, _ = knn(coarse, fine, 1)
            nearest = nearest.to(self.device)

            # Broadcast the neighbour index over the channel axis so it can
            # drive a gather along the point axis.
            gather_idx = nearest.unsqueeze(1).expand(-1, feats.size(1), -1, 1)
            upsampled = torch.gather(feats, -2, gather_idx)

            merged = torch.cat((upsampled, skips.pop()), dim=1)
            feats = layer(merged)

            stride //= ratio

        # Undo the initial shuffle so scores line up with the input order.
        restore = torch.argsort(shuffle)
        scores = self.fc_end(feats[:, :, restore])

        return scores.squeeze(-1)
    def forward(self, input):
        r"""
            Run the encoder / decoder network and score every input point.

            Parameters
            ----------
            input: torch.Tensor, shape (B, N, d_in)
                input points; channels 0..2 are read as coordinates

            Returns
            -------
            torch.Tensor, shape (B, num_classes, N)
                segmentation scores for each point
        """
        total = input.size(1)
        step = self.decimation  # down-sampling factor applied per encoder layer

        positions = input[..., :3].clone().cpu()
        feat = self.bn_start(
            self.fc_start(input).transpose(-2, -1).unsqueeze(-1)
        )  # shape (B, d, N, 1)

        # A single random permutation up front turns "keep the first k
        # points" into the random sampling operation used below.
        order = torch.randperm(total)
        positions, feat = positions[:, order], feat[:, :, order]

        def leading(count):
            # First `count` permuted points as a contiguous CPU tensor,
            # which is the form handed to the neighbour search.
            return positions[:, :count].cpu().contiguous()

        # <<<<<<<<<< ENCODER
        saved = []  # encoder outputs, consumed in reverse by the decoder
        scale = 1
        for block in self.encoder:
            feat = block(positions[:, :total // scale], feat)
            saved.append(feat.clone())
            scale *= step
            feat = feat[:, :, :total // scale]  # random sampling operation
        # >>>>>>>>>> ENCODER

        feat = self.mlp(feat)

        # <<<<<<<<<< DECODER
        for block in self.decoder:
            # For each point of the denser set, look up its single nearest
            # point in the current (sparser) set.
            sparse_set = leading(total // scale)
            dense_set = leading(step * total // scale)
            nn_idx, _ = knn(sparse_set, dense_set, 1)
            nn_idx = nn_idx.to(self.device)

            # Repeat the index across channels and gather along the point axis.
            channel_idx = nn_idx.unsqueeze(1).expand(-1, feat.size(1), -1, 1)
            interpolated = torch.gather(feat, -2, channel_idx)

            # skip connection with the matching encoder output
            feat = block(torch.cat((interpolated, saved.pop()), dim=1))

            scale //= step
        # >>>>>>>>>> DECODER

        # inverse permutation: put the points back in their original order
        feat = feat[:, :, torch.argsort(order)]

        return self.fc_end(feat).squeeze(-1)