示例#1
0
    def test_sumpooling(self):
        """Check MinkowskiSumPooling forward on CPU/GPU and its backward on CPU.

        A 2-channel, 2-D sparse tensor built from ``data_loader`` is pooled
        with a kernel of size 3 and stride 2. The backward pass is validated
        with ``gradcheck`` on CPU; the forward pass is then repeated on CUDA
        device 0, but only when a GPU is available.
        """
        in_channels, D = 2, 2
        coords, feats, _ = data_loader(in_channels)
        # Double precision + requires_grad so gradcheck can compare numerical
        # and analytical gradients with respect to the features.
        feats = feats.double()
        feats.requires_grad_()
        sinput = SparseTensor(feats, coords=coords)
        pool = MinkowskiSumPooling(kernel_size=3, stride=2, dimension=D)
        output = pool(sinput)
        print(output)

        # Check backward. NOTE(review): sum pooling is verified through
        # MinkowskiAvgPoolingFunction; the literal False below is presumably
        # its `average` flag (sum instead of mean) — confirm against the
        # function's forward signature. None lets the function pick the
        # output coordinate key.
        fn = MinkowskiAvgPoolingFunction()
        self.assertTrue(
            gradcheck(
                fn,
                (sinput.F, sinput.tensor_stride, pool.stride, pool.kernel_size,
                 pool.dilation, pool.region_type_, pool.region_offset_, False,
                 sinput.coords_key, None, sinput.coords_man)))

        # The GPU pass needs CUDA; stop after the CPU checks instead of
        # erroring out on machines without a GPU.
        if not torch.cuda.is_available():
            return

        device = torch.device('cuda')
        with torch.cuda.device(0):
            sinput = sinput.to(device)
            pool = pool.to(device)
            output = pool(sinput)
            print(output)
示例#2
0
    def test_sumpooling(self):
        """Check MinkowskiSumPooling forward and backward on CPU, then on GPU.

        A 2-channel, 2-D sparse tensor built from ``data_loader`` is pooled
        with a kernel of size 3 and stride 2, and the pooling autograd function
        ``MinkowskiLocalPoolingFunction`` is validated with ``gradcheck``. The
        same forward/backward check is repeated with the tensor placed on
        CUDA device 0, but only when a GPU is available.
        """
        import torch  # local import: only needed for the CUDA availability check

        in_channels, D = 2, 2
        coords, feats, _ = data_loader(in_channels)
        # Double precision + requires_grad so gradcheck can compare numerical
        # and analytical gradients with respect to the features.
        feats = feats.double()
        feats.requires_grad_()
        pool = MinkowskiSumPooling(kernel_size=3, stride=2, dimension=D)
        fn = MinkowskiLocalPoolingFunction()

        def check_forward_backward(sinput):
            # Run the pooling layer on `sinput`, then gradcheck the autograd
            # function over the same input/output coordinate maps and manager.
            soutput = pool(sinput)
            print(soutput)
            self.assertTrue(
                gradcheck(
                    fn,
                    (
                        sinput.F,
                        pool.pooling_mode,
                        pool.kernel_generator,
                        sinput.coordinate_map_key,
                        soutput.coordinate_map_key,
                        sinput._manager,
                    ),
                )
            )

        check_forward_backward(SparseTensor(feats, coords))

        # device=0 requires CUDA; stop after the CPU checks instead of
        # erroring out on machines without a GPU.
        if not torch.cuda.is_available():
            return

        check_forward_backward(SparseTensor(feats, coords, device=0))