예제 #1
0
    def test_pn2(self):
        """Check the module layout and output width of a multiscale PointNet2 U-Net.

        Builds the model with ``config=None``, asserts the number of down,
        inner and up modules, then runs one forward pass on the first mock
        sample and checks that dimension 1 of the output features equals
        ``output_nc``.
        """
        from torch_points3d.applications.pointnet2 import PointNet2

        input_nc, num_layers, output_nc = 2, 3, 5
        net = PointNet2(
            architecture="unet",
            num_layers=num_layers,
            input_nc=input_nc,
            output_nc=output_nc,
            multiscale=True,
            config=None,
        )
        samples = MockDataset(input_nc, num_points=512)

        # Expected number of entries in each registered submodule container.
        expected_sizes = (
            ("down_modules", num_layers - 1),
            ("inner_modules", 1),
            ("up_modules", num_layers),
        )
        for container, size in expected_sizes:
            self.assertEqual(len(net._modules[container]), size)

        try:
            result = net.forward(samples[0])
            self.assertEqual(result.x.shape[1], output_nc)
        except Exception as err:
            # Dump the architecture to make the failure easier to diagnose,
            # then propagate the same exception.
            print("Model failing:")
            print(net)
            raise err
예제 #2
0
import torch
from torch_points3d.applications.pointnet2 import PointNet2
from torch_geometric.data import Batch, Data

# Sizes used throughout this example.
num_points = 1024
num_classes = 10
input_nc = 5

# Random coordinates and per-point features, each given a leading axis of size 1.
pos = torch.unsqueeze(torch.randn(num_points, 3), 0)
T = torch.unsqueeze(torch.randn(num_points, input_nc), 0)

# Batch the very same Data object twice.
data = Batch.from_data_list([Data(pos=pos, x=T)] * 2)
print(data)
# Expected output:
# Batch(batch=[2], pos=[2, 1024, 3], x=[2, 1024, 5])

# Encoder variant of PointNet2 with three layers.
model = PointNet2(
    architecture="encoder",
    num_layers=3,
    input_nc=input_nc,
    output_nc=num_classes,
)

res = model(data)
print(res)
# Expected output:
# Data(x=[2, 10, 1])
예제 #3
0
import torch
from torch_points3d.applications.pointnet2 import PointNet2
from torch_geometric.data import Batch, Data

# Sizes used throughout this example.
num_points = 1024
num_classes = 10
input_nc = 5

# Random coordinates and per-point features, each given a leading axis of size 1.
pos = torch.unsqueeze(torch.randn(num_points, 3), 0)
T = torch.unsqueeze(torch.randn(num_points, input_nc), 0)

# Batch the very same Data object twice.
data = Batch.from_data_list([Data(pos=pos, x=T)] * 2)

print(data)
# Expected output:
# Batch(batch=[2], pos=[2, 1024, 3], x=[2, 1024, 5])

# pos gets concatenated with x features
model = PointNet2(
    architecture="unet",
    num_layers=3,
    input_nc=input_nc,
    output_nc=num_classes,
)

res = model(data)
print(res)
# Expected output:
# Data(pos=[2, 1024, 3], x=[2, 10, 1024])