Code example #1
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, f1, f2, x1, x2):
        # nb, nc, np, na -> nb, nc, na
        sp1 = sptk.SphericalPointCloud(x1, f1, None)
        sp2 = sptk.SphericalPointCloud(x2, f2, None)

        f1 = self._pooling(sp1)
        f2 = self._pooling(sp2)

        nb = f1.shape[0]
        na = f1.shape[2]

        # expand and concat into metric space (nb, nc*2, na_tgt, na_src)
        f2_expand = f2.unsqueeze(-1).expand(-1, -1, -1, na).contiguous()
        f1_expand = f1.unsqueeze(-2).expand(-1, -1, na, -1).contiguous()

        x_out = torch.cat((f1_expand, f2_expand), 1)

        # fc layers with relu
        for linear in self.linear:
            x_out = linear(x_out)
            x_out = F.relu(x_out)

        attention_wts = self.attention_layer(x_out).view(nb, na, na)
        confidence = F.softmax(attention_wts * self.temperature, dim=1)
        y = self.regressor_layer(x_out)

        # return: [nb, na, na], [nb, n_out, na, na]
        return confidence, y
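
The unsqueeze/expand/cat pattern above builds a pairwise feature grid over anchor pairs. A minimal stand-alone sketch of that step (random tensors and made-up sizes, not the project's real data):

import torch

nb, nc, na = 2, 8, 12                      # hypothetical batch / channel / anchor sizes
f1 = torch.randn(nb, nc, na)               # pooled features of point cloud 1
f2 = torch.randn(nb, nc, na)               # pooled features of point cloud 2

# broadcast every anchor of f1 against every anchor of f2
f2_expand = f2.unsqueeze(-1).expand(-1, -1, -1, na)   # (nb, nc, na, na)
f1_expand = f1.unsqueeze(-2).expand(-1, -1, na, -1)   # (nb, nc, na, na)
pairwise = torch.cat((f1_expand, f2_expand), 1)       # (nb, 2*nc, na, na)

assert pairwise.shape == (nb, 2 * nc, na, na)
# pairwise[b, :, i, j] concatenates f1[b, :, j] with f2[b, :, i]
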
Code example #2
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, frag, clouds):
        x = self.prop(frag, clouds)
        feat = self.norm(x.feats)
        if self.relu is not None:
            feat = self.relu(feat)
        if self.training and self.dropout is not None:
            feat = self.dropout(feat)
        return sptk.SphericalPointCloud(x.xyz, feat, x.anchors)
Code example #3
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
def preprocess_input(x, na, add_center=True):
    has_normals = x.shape[2] == 6
    # prepend the per-batch centroid as the point at index 0, dropping the last point to keep the point count fixed
    if add_center and not has_normals:
        center = x.mean(1, keepdim=True)
        x = torch.cat((center, x), dim=1)[:, :-1]
    xyz = x[:, :, :3]
    return sptk.SphericalPointCloud(
        xyz.permute(0, 2, 1).contiguous(),
        sptk.get_occupancy_features(x, na, add_center), None)
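
A small plain-torch sketch of the center-prepend step above (hypothetical sizes): the per-batch centroid is inserted at index 0 and the last point is dropped, so the point count is unchanged.

import torch

nb, npts, c = 2, 1024, 3                         # hypothetical batch / point / channel sizes
x = torch.randn(nb, npts, c)

center = x.mean(1, keepdim=True)                 # (nb, 1, 3) centroid per batch
x_aug = torch.cat((center, x), dim=1)[:, :-1]    # prepend centroid, drop the last point

assert x_aug.shape == x.shape
assert torch.allclose(x_aug[:, 0], center[:, 0])
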
Code example #4
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, x):
        # [b, 3, p] x [b, c1, p]
        x = self.conv(x)
        feat = self.norm(x.feats)
        if self.relu is not None:
            feat = self.relu(feat)
        if self.training and self.dropout is not None:
            feat = self.dropout(feat)

        # [b, 3, p] x [b, c2, p]
        return sptk.SphericalPointCloud(x.xyz, feat, x.anchors)
Code example #5
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, x, inter_idx=None, inter_w=None):
        inter_idx, inter_w, sample_idx, x = self.conv(x, inter_idx, inter_w)
        feat = self.norm(x.feats)

        if self.relu is not None:
            feat = self.relu(feat)
        # TODO: the explicit self.training check below should be unnecessary (dropout layers already follow the module's training mode)

        if self.training and self.dropout is not None:
            feat = self.dropout(feat)
        return inter_idx, inter_w, sample_idx, sptk.SphericalPointCloud(
            x.xyz, feat, x.anchors)
Code example #6
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, x):
        # nb, nc, np, na -> nb, nc, na

        # attention first
        nb, nc, np, na = x.feats.shape

        attn = self.attention_layer(x.feats)
        attn = F.softmax(attn, dim=3)

        # nb, nc, np, 1
        x_out = (x.feats * attn).sum(-1, keepdim=True)
        x_in = sptk.SphericalPointCloud(x.xyz, x_out, None)

        # nb, nc
        x_out = self.pointnet(x_in).view(nb, -1)

        return F.normalize(x_out, p=2, dim=1), attn
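
A stand-alone sketch of the anchor-attention step above (plain torch, hypothetical shapes, with a 1x1 Conv2d standing in for self.attention_layer): scores are softmax-normalized along the anchor axis and then used to collapse it.

import torch
import torch.nn as nn
import torch.nn.functional as F

nb, nc, npts, na = 2, 16, 128, 60         # hypothetical sizes (na = number of anchors)
feats = torch.randn(nb, nc, npts, na)

attention_layer = nn.Conv2d(nc, 1, 1)     # stand-in for self.attention_layer
attn = F.softmax(attention_layer(feats), dim=3)   # (nb, 1, npts, na), normalized over anchors

pooled = (feats * attn).sum(-1, keepdim=True)     # (nb, nc, npts, 1)
assert pooled.shape == (nb, nc, npts, 1)
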
Code example #7
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, x, inter_idx, inter_w):
        '''
            inter, intra conv with skip connection
        '''
        skip_feature = x.feats
        inter_idx, inter_w, sample_idx, x = self.inter_conv(
            x, inter_idx, inter_w)

        if self.use_intra:
            x = self.intra_conv(x)
        if self.stride > 1:
            skip_feature = sptk.functional.batched_index_select(
                skip_feature, 2, sample_idx.long())
        skip_feature = self.skip_conv(skip_feature)
        skip_feature = self.relu(self.norm(skip_feature))
        x_out = sptk.SphericalPointCloud(x.xyz, x.feats + skip_feature,
                                         x.anchors)
        return inter_idx, inter_w, sample_idx, x_out
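
sptk.functional.batched_index_select is a project helper; assuming it gathers points along dim 2 with per-batch indices (an assumption, not confirmed by this snippet), a rough plain-torch equivalent looks like:

import torch

def batched_index_select_dim2(feat, idx):
    # feat: (nb, nc, np, na), idx: (nb, np_sub) -- shapes are assumed here
    nb, nc, _, na = feat.shape
    np_sub = idx.shape[1]
    idx_exp = idx.view(nb, 1, np_sub, 1).expand(nb, nc, np_sub, na)
    return feat.gather(2, idx_exp)

feat = torch.randn(2, 16, 128, 60)                 # hypothetical skip_feature
idx = torch.randint(0, 128, (2, 32))               # hypothetical sample_idx
assert batched_index_select_dim2(feat, idx).shape == (2, 16, 32, 60)
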
Code example #8
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, x, label=None):
        x_out = x.feats
        norm_cnt = 0
        end = len(self.linear)
        for lid, linear in enumerate(self.linear):
            norm = self.norm[norm_cnt]
            x_out = linear(x_out)
            x_out = F.relu(norm(x_out))
            norm_cnt += 1

        out_feat = x_out
        x_in = sptk.SphericalPointCloud(x.xyz, out_feat, x.anchors)

        x_out = self.pointnet(x_in)

        norm = self.norm[norm_cnt]
        norm_cnt += 1
        x_out = F.relu(norm(x_out))

        # pool according to self.pooling_method
        if self.pooling_method == 'mean':
            x_out = x_out.mean(dim=2)
        elif self.pooling_method == 'debug':
            # for debug only
            x_out = x_out[..., 0].mean(2)
        elif self.pooling_method == 'max':
            # max pooling
            x_out = x_out.max(2)[0]
        elif self.pooling_method.startswith('attention'):
            out_feat = self.attention_layer(x_out)  # Bx1xA or BxCxA
            confidence = F.softmax(out_feat * self.temperature, dim=2)
            x_out = x_out * confidence
            x_out = x_out.sum(-1)
        else:
            raise NotImplementedError(
                f"Pooling mode {self.pooling_method} is not implemented!")

        x_out = self.fc2(x_out)

        return x_out, out_feat.squeeze()
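
A quick sketch of the 'attention' pooling branch above (plain torch, hypothetical sizes, a 1x1 Conv1d standing in for self.attention_layer, and an assumed scalar temperature): per-anchor scores are scaled, softmax-normalized over anchors, and used as weights for a sum over the anchor axis.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, A = 2, 64, 60                       # hypothetical batch / channel / anchor sizes
x_out = torch.randn(B, C, A)
temperature = 3.0                         # assumed value for self.temperature

attention_layer = nn.Conv1d(C, 1, 1)      # stand-in for self.attention_layer (BxCxA -> Bx1xA)
confidence = F.softmax(attention_layer(x_out) * temperature, dim=2)

pooled = (x_out * confidence).sum(-1)     # weighted sum over anchors -> (B, C)
assert pooled.shape == (B, C)
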
Code example #9
File: base_so3conv.py    Project: XYZ-99/EPN_PointCloud
    def forward(self, feats, label=None):
        x_out = feats
        norm_cnt = 0
        end = len(self.linear)
        for lid, linear in enumerate(self.linear):
            norm = self.norm[norm_cnt]
            x_out = linear(x_out)
            x_out = F.relu(norm(x_out))
            norm_cnt += 1

        # mean pool over the point dimension
        out_feat = x_out
        x_out = x_out.mean(2, keepdim=True)

        # group convolution after mean pool
        if hasattr(self, 'intra'):
            x_in = sptk.SphericalPointCloud(None, x_out, None)
            for lid, conv in enumerate(self.intra):
                skip_feat = x_in.feats
                x_in = conv(x_in)

                # skip connection
                norm = self.norm[norm_cnt]
                skip_feat = self.skipconv[lid](skip_feat)
                skip_feat = F.relu(norm(skip_feat))
                x_in = sptk.SphericalPointCloud(None, skip_feat + x_in.feats,
                                                None)
                norm_cnt += 1
            x_out = x_in.feats

        # pool according to self.pooling_method
        if self.pooling_method == 'mean':
            x_out = x_out.mean(dim=3).mean(dim=2)
        elif self.pooling_method == 'debug':
            # for debug only
            x_out = x_out[..., 0].mean(2)
        elif self.pooling_method == 'max':
            # max pooling
            x_out = x_out.mean(2).max(-1)[0]
        ############## DEBUG ONLY ######################
        elif label is not None:

            def to_one_hot(label, num_class):
                '''
                label: [B,...]
                return [B,...,num_class]
                '''
                comp = torch.arange(num_class).long().to(label.device)
                for i in range(label.dim()):
                    comp = comp.unsqueeze(0)
                onehot = label.unsqueeze(-1) == comp
                return onehot.float()

            x_out = x_out.mean(2)
            label = label.squeeze()
            if label.dim() == 2:
                cdim = x_out.shape[1]
                label = label.repeat(1, 5)[:, :cdim]
            confidence = to_one_hot(label, x_out.shape[2])
            if confidence.dim() < 3:
                confidence = confidence.unsqueeze(1)
            x_out = x_out * confidence
            x_out = x_out.sum(-1)
        ####################################################
        elif self.pooling_method.startswith('attention'):
            x_out = x_out.mean(2)
            out_feat = self.attention_layer(x_out)  # Bx1xA or BxCxA
            confidence = F.softmax(out_feat * self.temperature, dim=2)
            x_out = x_out * confidence
            x_out = x_out.sum(-1)
        else:
            raise NotImplementedError(
                f"Pooling mode {self.pooling_method} is not implemented!")

        # fc layers
        for linear in self.fc1:
            x_out = linear(x_out)
            x_out = F.relu(x_out)

        x_out = self.fc2(x_out)

        return x_out, out_feat.squeeze()
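
The nested to_one_hot helper in the debug branch can be cross-checked against torch.nn.functional.one_hot; a small sanity sketch (hypothetical label shape and class count):

import torch
import torch.nn.functional as F

def to_one_hot(label, num_class):
    # same logic as the nested helper above
    comp = torch.arange(num_class).long().to(label.device)
    for _ in range(label.dim()):
        comp = comp.unsqueeze(0)
    return (label.unsqueeze(-1) == comp).float()

label = torch.randint(0, 60, (4,))        # hypothetical labels, 60 classes (e.g. anchors)
assert torch.equal(to_one_hot(label, 60), F.one_hot(label, 60).float())
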