from torch.nn import Linear as Lin  # noqa
from torch_geometric.nn.functional import (sparse_voxel_max_pool,
                                           dense_voxel_max_pool)  # noqa
from torch_geometric.visualization.model import show_model  # noqa

# Dataset root: <directory of this script>/../data/ModelNet10RandPC
path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                    '..', 'data', 'ModelNet10RandPC')

# Not referenced in this setup section; presumably applied later in the
# script (e.g. per batch) — verify before removing.
transform = CartesianAdj()

# A single NormalizeScale instance is shared by both splits. The positional
# True/False argument presumably selects the train/test split — confirm
# against ModelNet10RandPC. Train is constructed first, then test.
init_transform = NormalizeScale()
train_dataset, test_dataset = (
    ModelNet10RandPC(path, is_train, transform=init_transform)
    for is_train in (True, False))

# Same batch size for both loaders; only the training loader shuffles.
batch_size = 6
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


class Net(nn.Module):
    """Spline-convolution network; only the constructor is defined here.

    NOTE(review): this class body defines no ``forward``. How the layers
    below are wired together (pooling, use of ``att1``/``att2``) cannot be
    determined from this definition; confirm before relying on the
    assumptions noted in the comments.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Five SplineConv layers with (in, out) channel pairs
        # (1, 32), (32, 64), (64, 64), (64, 64), (64, 64); every layer uses
        # dim=3 and kernel_size=5. dim=3 presumably matches the 3-D Cartesian
        # edge attributes produced by CartesianAdj — TODO confirm.
        # NOTE: assignment order fixes parameter registration order and the
        # order in which random initialization consumes the RNG.
        self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5)
        self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5)
        self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5)
        self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5)
        self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5)
        # 8 * 64 = 512 input features -> 10 outputs. Presumably 8 pooled
        # cells x 64 channels in, and one logit per ModelNet10 class out —
        # TODO confirm against the (not visible) forward pass.
        self.fc1 = nn.Linear(8 * 64, 10)

        # Linear heads mapping 32 and 64 features to 2 outputs (``Lin`` is
        # the same ``torch.nn.Linear`` class as ``nn.Linear``). Their role is
        # not visible here; the names suggest attention/gating scores —
        # verify in forward().
        self.att1 = Lin(32, 2)
        self.att2 = Lin(64, 2)
# Example no. 2
# 0
from torch_geometric.nn.modules import SplineConv  # noqa
from torch_geometric.nn.functional import sparse_voxel_max_pool  # noqa
from torch_geometric.visualization.model import show_model  # noqa

# Dataset root: <directory of this script>/../data/ShapeNet2
path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                    '..', 'data', 'ShapeNet2')

# The single ShapeNet2 category loaded for every split.
category = 'Pistol'

transform = CartesianAdj()
# The train and eval pipelines are currently identical. Each gets its own
# Compose and its own NormalizeScale instance.
train_transform = Compose([NormalizeScale(), transform])
test_transform = Compose([NormalizeScale(), transform])

# Splits are constructed in the order train, val, test; val and test share
# test_transform.
train_dataset, val_dataset, test_dataset = (
    ShapeNet2(path, category, split_name, split_transform)
    for split_name, split_transform in (('train', train_transform),
                                        ('val', test_transform),
                                        ('test', test_transform)))

# Batch size 8 for every split; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader, test_loader = (DataLoader(split_dataset, batch_size=8)
                           for split_dataset in (val_dataset, test_dataset))

# The label range is read from the training targets only. num_classes counts
# every integer between the two extremes, i.e. it assumes contiguous labels.
min_label, max_label = (train_dataset.set.target.min(),
                        train_dataset.set.target.max())
num_classes = max_label - min_label + 1


def upscale(input, cluster):
    """Gather rows of ``input`` according to the indices in ``cluster``.

    Row ``i`` of the result equals row ``cluster[i]`` of ``input``. This is a
    gather along dimension 0 with the index repeated across all columns.

    Args:
        input: 2-D tensor of shape ``(num_rows, num_features)``. ``gather``
            with a 2-D index requires ``input`` to be 2-D.
        cluster: 1-D integer tensor; each entry selects one row of
            ``input``. It is wrapped in ``Variable``, presumably so it can be
            combined with a ``Variable`` input under the legacy autograd API.

    Returns:
        Tensor of shape ``(cluster.size(0), input.size(1))``.
    """
    index = Variable(cluster)
    # Turn the index into a column, then broadcast it across every feature
    # column so that gather copies whole rows.
    column = index.view(-1, 1)
    row_index = column.expand(index.size(0), input.size(1))
    return input.gather(0, row_index)


class Net(nn.Module):