コード例 #1
0
    def forward(self, input):
        """Run the sparse convolution on ``input`` with a standardized weight.

        On every call a standardized copy of ``self.weight`` is derived
        (mean removed, then divided by a standard deviation plus ``1e-5``);
        the parameter itself is left untouched and the copy is what gets
        handed to the convolution routines.

        Args:
            input: a ``spconv.SparseConvTensor``.

        Returns:
            A new ``spconv.SparseConvTensor`` that shares ``indice_dict``
            and ``grid`` with ``input``.

        Side effects:
            When the indice pairs have to be computed here, they are stored
            in ``input.indice_dict`` under ``self.indice_key``.
        """
        assert isinstance(input, spconv.SparseConvTensor)
        feats = input.features
        dev = feats.device
        coords = input.indices
        in_shape = input.spatial_shape
        n_batch = input.batch_size

        # Submanifold convolutions reuse the input spatial shape; the other
        # modes ask ``ops`` for the resulting shape.
        if self.subm:
            out_shape = in_shape
        elif self.transposed:
            out_shape = ops.get_deconv_output_size(
                in_shape, self.kernel_size, self.stride, self.padding,
                self.dilation, self.output_padding)
        else:
            out_shape = ops.get_conv_output_size(
                in_shape, self.kernel_size, self.stride, self.padding,
                self.dilation)

        # Weight standardization: subtract the mean taken successively over
        # dims 1, 2 and 3, then divide by the std of each dim-0 slice.
        w = self.weight
        w_mean = w.mean(dim=1, keepdim=True)
        w_mean = w_mean.mean(dim=2, keepdim=True)
        w_mean = w_mean.mean(dim=3, keepdim=True)
        w = w - w_mean
        w_std = w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        w = w / w_std.expand_as(w)

        if self.conv1x1:
            # A 1x1 convolution is a plain matrix product on the features;
            # indices and spatial shape carry over unchanged.
            projected = torch.mm(
                input.features, w.view(self.in_channels, self.out_channels))
            if self.bias is not None:
                projected += self.bias
            result = spconv.SparseConvTensor(projected, input.indices,
                                             input.spatial_shape,
                                             input.batch_size)
            result.indice_dict = input.indice_dict
            result.grid = input.grid
            return result

        cached = input.find_indice_pair(self.indice_key)
        if self.inverse:
            # The inverse convolution depends on the pairs recorded by its
            # partner layer; it takes its output ids and output shape from
            # that record.
            assert cached is not None and self.indice_key is not None
            _, out_coords, pair_idx, pair_cnt, out_shape = cached
            assert pair_cnt.shape[0] == np.prod(self.kernel_size), (
                "inverse conv must have same kernel size as its couple conv")
        elif self.indice_key is None or cached is None:
            out_coords, pair_idx, pair_cnt = ops.get_indice_pairs(
                coords,
                n_batch,
                in_shape,
                self.kernel_size,
                self.stride,
                self.padding,
                self.dilation,
                self.output_padding,
                self.subm,
                self.transposed,
                grid=input.grid,
                use_hash=self.use_hash)
            input.indice_dict[self.indice_key] = (out_coords, coords,
                                                  pair_idx, pair_cnt,
                                                  in_shape)
        else:
            out_coords, _, pair_idx, pair_cnt, _ = cached

        if self.fused_bn:
            # The fused routine receives the bias directly, so it must exist.
            assert self.bias is not None
            out_feats = ops.fused_indice_conv(feats, w, self.bias,
                                              pair_idx.to(dev), pair_cnt,
                                              out_coords.shape[0],
                                              self.inverse, self.subm)
        else:
            if self.subm:
                conv_fn = Fsp.indice_subm_conv
            elif self.inverse:
                conv_fn = Fsp.indice_inverse_conv
            else:
                conv_fn = Fsp.indice_conv
            out_feats = conv_fn(feats, w, pair_idx.to(dev), pair_cnt,
                                out_coords.shape[0], self.algo)
            if self.bias is not None:
                out_feats += self.bias

        result = spconv.SparseConvTensor(out_feats, out_coords, out_shape,
                                         n_batch)
        result.indice_dict = input.indice_dict
        result.grid = input.grid
        return result
コード例 #2
0
ファイル: conv.py プロジェクト: Benzlxs/spconv_scannet
    def forward(self, input):
        """Apply the sparse convolution to ``input``.

        Args:
            input: a ``spconv.SparseConvTensor``.

        Returns:
            A new ``spconv.SparseConvTensor`` that shares ``indice_dict``
            and ``grid`` with ``input``. ``input.features`` is never
            overwritten.

        Side effects:
            When the indice pairs have to be computed here, they are stored
            in ``input.indice_dict`` under ``self.indice_key``.
        """
        assert isinstance(input, spconv.SparseConvTensor)
        features = input.features
        device = features.device
        indices = input.indices
        spatial_shape = input.spatial_shape
        batch_size = input.batch_size
        if not self.subm:
            if self.transposed:
                out_spatial_shape = ops.get_deconv_output_size(
                    spatial_shape, self.kernel_size, self.stride, self.padding,
                    self.dilation, self.output_padding)
            else:
                out_spatial_shape = ops.get_conv_output_size(
                    spatial_shape, self.kernel_size, self.stride, self.padding,
                    self.dilation)
        else:
            # Submanifold convolutions keep the input spatial shape.
            out_spatial_shape = spatial_shape

        if self.conv1x1:
            # A 1x1 convolution is a plain matrix product on the features.
            # Build a fresh tensor instead of overwriting ``input.features``
            # so other consumers of ``input`` (e.g. a residual branch) still
            # see the original values.
            features = torch.mm(
                input.features,
                self.weight.view(self.in_channels, self.out_channels))
            # ``if self.bias:`` would take the truth value of a tensor, which
            # raises for more than one element; test for presence instead.
            if self.bias is not None:
                features += self.bias
            out_tensor = spconv.SparseConvTensor(features, input.indices,
                                                 input.spatial_shape,
                                                 input.batch_size)
            out_tensor.indice_dict = input.indice_dict
            out_tensor.grid = input.grid
            return out_tensor

        datas = input.find_indice_pair(self.indice_key)
        if self.inverse:
            # The inverse convolution depends on the pairs recorded by its
            # partner layer; it takes its output ids and output shape from
            # that record.
            assert datas is not None and self.indice_key is not None
            _, outids, indice_pairs, indice_pair_num, out_spatial_shape = datas
            # NOTE(review): presumably the leading dimension of
            # ``indice_pairs`` is the kernel volume in this spconv version;
            # confirm against ``ops.get_indice_pairs``.
            assert indice_pairs.shape[0] == np.prod(
                self.kernel_size
            ), "inverse conv must have same kernel size as its couple conv"
        else:
            if self.indice_key is not None and datas is not None:
                # Reuse the pairs recorded earlier under the same key.
                outids, _, indice_pairs, indice_pair_num, _ = datas
            else:
                outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
                    indices,
                    batch_size,
                    spatial_shape,
                    self.kernel_size,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.output_padding,
                    self.subm,
                    self.transposed,
                    grid=input.grid)
                input.indice_dict[self.indice_key] = (outids, indices,
                                                      indice_pairs,
                                                      indice_pair_num,
                                                      spatial_shape)
        if self.subm:
            out_features = Fsp.indice_subm_conv(features, self.weight,
                                                indice_pairs.to(device),
                                                indice_pair_num,
                                                outids.shape[0])
        else:
            if self.inverse:
                out_features = Fsp.indice_inverse_conv(features, self.weight,
                                                       indice_pairs.to(device),
                                                       indice_pair_num,
                                                       outids.shape[0])
            else:
                out_features = Fsp.indice_conv(features, self.weight,
                                               indice_pairs.to(device),
                                               indice_pair_num,
                                               outids.shape[0])

        # Same presence check as in the 1x1 path (not a tensor truth test).
        if self.bias is not None:
            out_features += self.bias
        out_tensor = spconv.SparseConvTensor(out_features, outids,
                                             out_spatial_shape, batch_size)
        out_tensor.indice_dict = input.indice_dict
        out_tensor.grid = input.grid
        return out_tensor