from torch.nn import Linear as Lin # noqa from torch_geometric.nn.functional import (sparse_voxel_max_pool, dense_voxel_max_pool) # noqa from torch_geometric.visualization.model import show_model # noqa path = os.path.dirname(os.path.realpath(__file__)) path = os.path.join(path, '..', 'data', 'ModelNet10RandPC') transform = CartesianAdj() init_transform = NormalizeScale() train_dataset = ModelNet10RandPC(path, True, transform=init_transform) test_dataset = ModelNet10RandPC(path, False, transform=init_transform) batch_size = 6 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5) self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5) self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5) self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5) self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5) self.fc1 = nn.Linear(8 * 64, 10) self.att1 = Lin(32, 2) self.att2 = Lin(64, 2)
from torch_geometric.nn.modules import SplineConv # noqa from torch_geometric.nn.functional import sparse_voxel_max_pool # noqa from torch_geometric.visualization.model import show_model # noqa path = os.path.dirname(os.path.realpath(__file__)) path = os.path.join(path, '..', 'data', 'ShapeNet2') category = 'Pistol' transform = CartesianAdj() train_transform = Compose([NormalizeScale(), transform]) test_transform = Compose([NormalizeScale(), transform]) train_dataset = ShapeNet2(path, category, 'train', train_transform) val_dataset = ShapeNet2(path, category, 'val', test_transform) test_dataset = ShapeNet2(path, category, 'test', test_transform) train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=8) test_loader = DataLoader(test_dataset, batch_size=8) min_label = train_dataset.set.target.min() max_label = train_dataset.set.target.max() num_classes = max_label - min_label + 1 def upscale(input, cluster): cluster = Variable(cluster) cluster = cluster.view(-1, 1).expand(cluster.size(0), input.size(1)) return torch.gather(input, 0, cluster) class Net(nn.Module):