Ejemplo n.º 1
0
    def forward(self, data):
        """Aggregate votes around sampled seeds and decode proposal logits.

        Args:
            data: batch object exposing at least ``pos``, ``seed_pos`` and
                ``seed_inds``. ``data.pos`` must be 3-dimensional (presumably
                (B, N, 3) — TODO confirm). The whole object is also handed to
                ``self.vote_aggregation``.

        Returns:
            VoteNetResults produced by ``VoteNetResults.from_logits`` from the
            proposal-head output and the aggregated positions.

        Raises:
            ValueError: if ``data.pos`` is not 3-dimensional, or if
                ``self.sampling`` is anything other than ``"seed_fps"``.
        """
        if data.pos.dim() != 3:
            raise ValueError("This method only supports dense convolutions for now")

        # Furthest-point sampling over the seed positions is the only strategy implemented.
        if self.sampling != "seed_fps":
            raise ValueError("Unknown sampling strategy: %s. Exiting!" % (self.sampling))
        proposal_idx = tp.furthest_point_sample(data.seed_pos, self.num_proposal)

        aggregated = self.vote_aggregation(data, sampled_idx=proposal_idx)

        # --------- PROPOSAL GENERATION ---------
        # Two conv -> batch-norm -> ReLU stages, then a plain conv producing the logits.
        hidden = aggregated.x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = F.relu(norm(conv(hidden)))
        # presumably (batch_size, 2+3+num_heading_bin*2+num_size_cluster*4, num_proposal) — TODO confirm
        logits = self.conv3(hidden)

        return VoteNetResults.from_logits(
            data.seed_inds,
            data.pos,
            data.seed_pos,
            aggregated.pos,
            logits,
            self.num_class,
            self.num_heading_bin,
            self.mean_size_arr,
        )
    def sample(self, pos, **kwargs):
        """ Sample pos with furthest point sampling.

        Arguments:
            pos -- dense position tensor with exactly 3 dimensions, expected
                [B, N, 3]. ``pos.shape[1]`` is passed to
                ``self._get_num_to_sample`` to decide how many points to keep.
            **kwargs -- accepted for interface compatibility; not used here.

        Returns:
            indexes -- [B, num_sample], as returned by
            ``tp.furthest_point_sample``.

        Raises:
            ValueError: if ``pos`` does not have exactly 3 dimensions.
        """
        if len(pos.shape) != 3:
            # The message must state the rank the check actually requires (3), and
            # report what was received to make shape mistakes easy to diagnose.
            raise ValueError(
                "This class is for dense data and expects the pos tensor to be of dimension 3 (got %d)"
                % len(pos.shape)
            )
        return tp.furthest_point_sample(pos, self._get_num_to_sample(pos.shape[1]))