Esempio n. 1
0
    def test_device(self):
        """Prune a CUDA ``SparseTensor`` and gradcheck ``MinkowskiPruningFunction``.

        Builds a single-batch sparse tensor directly on the GPU, applies
        ``MinkowskiPruning`` with a random boolean row mask, and then verifies
        the backward pass of the pruning function numerically with
        ``gradcheck``. The test is skipped when CUDA is not available.
        """
        if not torch.cuda.is_available():
            # Everything below targets the GPU; skip rather than error on
            # CPU-only hosts.
            self.skipTest("CUDA is not available")

        in_channels = 2
        device = torch.device("cuda")
        coords, feats, _ = data_loader(in_channels, batch_size=1)
        # gradcheck compares numerical and analytical gradients and is meant
        # to be run in double precision on inputs that require grad.
        feats = feats.double()
        feats.requires_grad_()

        # Random boolean mask with one entry per row of ``feats`` (each entry
        # True with probability 0.5), moved to the same device as the tensor.
        use_feat = (torch.rand(feats.size(0)) < 0.5).to(device)
        pruning = MinkowskiPruning()

        in_tensor = SparseTensor(feats, coords, device=device)
        output = pruning(in_tensor, use_feat)
        print(in_tensor)
        print(output)

        fn = MinkowskiPruningFunction()
        self.assertTrue(
            gradcheck(
                fn,
                (
                    in_tensor.F,
                    use_feat,
                    in_tensor.coordinate_map_key,
                    output.coordinate_map_key,
                    in_tensor.coordinate_manager,
                ),
            ))
Esempio n. 2
0
    def test_empty(self):
        """Prune with an all-False mask and gradcheck the pruning function.

        Exercises the edge case where the boolean mask has no True entries,
        then verifies the backward pass of ``MinkowskiPruningFunction``
        numerically with ``gradcheck``.
        """
        in_channels = 2
        coords, feats, _ = data_loader(in_channels, batch_size=1)
        # gradcheck is meant to be run in double precision on inputs that
        # require grad.
        feats = feats.double()
        feats.requires_grad_()
        in_tensor = SparseTensor(feats, coords)
        # All-False mask, one entry per row of the sparse tensor. Built with
        # torch.zeros instead of the legacy torch.BoolTensor(n) constructor,
        # which allocates uninitialized storage that then has to be zeroed.
        use_feat = torch.zeros(len(in_tensor), dtype=torch.bool)
        pruning = MinkowskiPruning()
        output = pruning(in_tensor, use_feat)
        print(in_tensor)
        print(use_feat)
        print(output)

        # Check backward
        fn = MinkowskiPruningFunction()
        self.assertTrue(
            gradcheck(
                fn,
                (
                    in_tensor.F,
                    use_feat,
                    in_tensor.coordinate_map_key,
                    output.coordinate_map_key,
                    in_tensor.coordinate_manager,
                ),
            ))
Esempio n. 3
0
    def test_pruning(self):
        """Prune with a random mask and gradcheck, on CPU and then on CUDA.

        The CPU forward/backward checks always run. The same checks are then
        repeated after moving the sparse tensor to the GPU, but only when CUDA
        is available; otherwise the test ends after the CPU checks.
        """
        in_channels, D = 2, 2
        coords, feats, _ = data_loader(in_channels)
        # gradcheck is meant to be run in double precision on inputs that
        # require grad.
        feats = feats.double()
        feats.requires_grad_()
        in_tensor = SparseTensor(feats, coords=coords)
        # Random boolean mask with one entry per row of ``feats``.
        use_feat = torch.rand(feats.size(0)) < 0.5
        pruning = MinkowskiPruning(D)
        output = pruning(in_tensor, use_feat)
        print(use_feat, output)

        # Check backward
        fn = MinkowskiPruningFunction()
        self.assertTrue(
            gradcheck(fn, (in_tensor.F, use_feat, in_tensor.coords_key,
                           output.coords_key, in_tensor.coords_man)))

        # The GPU half is optional: on CPU-only hosts the checks above are the
        # whole test instead of an error.
        if not torch.cuda.is_available():
            return

        device = torch.device('cuda')
        with torch.cuda.device(0):
            in_tensor = in_tensor.to(device)
            # NOTE(review): use_feat is intentionally left on the CPU here, as
            # in the original test; confirm the GPU pruning path expects a
            # CPU-resident mask.
            output = pruning(in_tensor, use_feat)
            print(output)

        self.assertTrue(
            gradcheck(fn, (in_tensor.F, use_feat, in_tensor.coords_key,
                           output.coords_key, in_tensor.coords_man)))