Ejemplo n.º 1
0
    def test_simple(self):
        r"""
        Benchmark ``three_interpolate`` against ``three_interpolate2``.

        Each iteration draws random problem sizes and builds CUDA inputs:

        features : (B, c, m) float tensor to interpolate from
        idx      : (B, n, 3) int tensor of neighbor indices in [0, m)
        weight   : (B, n, 3) non-negative float weights

        It then times one forward + backward pass of each implementation.
        The first ``n_warmup`` iterations are discarded as warm-up. For each
        implementation, the mean, standard deviation and total of the recorded
        wall-clock times are printed.

        NOTE(review): the outputs of the two implementations are never
        compared, so this measures speed only, not correctness.
        """
        n_iters = 100
        n_warmup = 6  # iterations 0..5 are not recorded
        time_history_1 = []
        time_history_2 = []
        for i in range(n_iters):
            batch_size = np.random.randint(2, 8)
            c = np.random.randint(128, 256)
            m = np.random.randint(256, 512)
            n = np.random.randint(2 * m, 4 * m)

            features = np.random.normal(0, 1, (batch_size, c, m))
            idx = np.random.randint(0, m, (batch_size, n, 3))
            weight = np.abs(np.random.normal(0, 1, (batch_size, n, 3)))

            # Make the CUDA float tensor the autograd leaf. Otherwise the
            # timed backward pass would also copy the gradient back into a
            # CPU float64 tensor.
            features = (
                torch.from_numpy(features).float().cuda().requires_grad_()
            )
            idx = torch.from_numpy(idx).int().cuda()
            weight = torch.from_numpy(weight).float().cuda()

            # CUDA kernels launch asynchronously. Synchronize before every
            # timestamp so each interval covers the work actually performed.
            torch.cuda.synchronize()
            t0 = time()
            out = three_interpolate(features, idx, weight)
            out.backward(torch.zeros_like(out))
            torch.cuda.synchronize()
            t1 = time()

            out2 = three_interpolate2(features, idx, weight)
            out2.backward(torch.zeros_like(out2))
            torch.cuda.synchronize()
            t2 = time()

            if i >= n_warmup:
                time_history_1.append(t1 - t0)
                time_history_2.append(t2 - t1)

        print(np.mean(time_history_1), np.std(time_history_1),
              np.sum(time_history_1))
        print(np.mean(time_history_2), np.std(time_history_2),
              np.sum(time_history_2))
Ejemplo n.º 2
0
    def conv(self, pos, pos_skip, x):
        """Propagate the features ``x`` onto the target points ``pos_skip``.

        Parameters
        ----------
        pos : torch.Tensor or None
            Coordinates of the points that ``x`` is attached to, presumably
            (B, m, 3). When ``None``, ``x`` is broadcast instead of being
            interpolated.
        pos_skip : torch.Tensor
            Coordinates of the target points; its dimension 2 must equal 3
            (asserted). Presumably (B, n, 3).
        x : torch.Tensor
            Features to propagate, presumably (B, C, m). When ``pos`` is
            ``None``, its last dimension must be expandable (size 1) to n.

        Returns
        -------
        torch.Tensor
            Features at the ``pos_skip`` points. Without ``pos``, ``x`` is
            expanded to (B, C, n). Otherwise each target point gets a weighted
            sum over its ``three_nn`` neighbours, with weights proportional to
            1 / (dist + 1e-8) and normalised to sum to 1.
        """
        assert pos_skip.shape[2] == 3

        if pos is None:
            # No source coordinates: tile x along the target-point dimension.
            target_shape = tuple(x.size()[:2]) + (pos_skip.size(1),)
            return x.expand(*target_shape)

        dist, idx = tp.three_nn(pos_skip, pos)
        # 1e-8 keeps the reciprocal finite for coincident points.
        inv_dist = 1.0 / (dist + 1e-8)
        weight = inv_dist / inv_dist.sum(dim=2, keepdim=True)
        return tp.three_interpolate(x, idx, weight)

    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor,
            unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Propagate ``known_feats`` onto the ``unknown`` points and refine the
        result with ``self.mlp``.

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor or None
            (B, m, 3) tensor of the xyz positions of the known features.
            If None, ``known_feats`` is broadcast along the point dimension
            instead of being interpolated. Its last dimension must then be
            expandable to n (size 1).
        unknow_feats : torch.Tensor or None
            (B, C1, n) tensor of the features to be propagated to. If None,
            only the interpolated features are used.
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            dist, idx = tp.three_nn(unknown, known)
            # Inverse-distance weights over the 3 neighbours, normalised to
            # sum to 1; 1e-8 keeps the reciprocal finite for coincident points.
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feats = tp.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # No source coordinates: tile known_feats over all n targets.
            interpolated_feats = known_feats.expand(
                *(tuple(known_feats.size()[:2]) + (unknown.size(1),))
            )

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # Add a trailing singleton dim for self.mlp and remove it afterwards.
        # Presumably the MLP is a Conv2d-style stack expecting 4-D input;
        # confirm against its definition.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)