def pre_filter1(d):
    """Keep only samples that contain at least one node."""
    return d.num_nodes > 0


def pre_filter2(d):
    """Keep non-empty samples whose name does not start with '2007'.

    NOTE(review): presumably this excludes images that overlap with the
    WILLOW evaluation data -- confirm against the dataset documentation.
    """
    return d.num_nodes > 0 and d.name[:4] != '2007'


# Transform pipeline applied to every sample: Delaunay triangulation, faces
# converted to edges, then either `T.Distance` (when `args.isotropic` is
# truthy) or `T.Cartesian` edge attributes.
transform = T.Compose([
    T.Delaunay(),
    T.FaceToEdge(),
    T.Distance() if args.isotropic else T.Cartesian(),
])

# Pre-training data: one PascalVOC training split per category, each paired
# with itself through `ValidPairDataset`.
path = osp.join('..', 'data', 'PascalVOC-WILLOW')
pretrain_datasets = []
for category in PascalVOC.categories:
    # Only 'car' and 'motorbike' use the stricter name-based filter.
    category_filter = (pre_filter2 if category in ['car', 'motorbike']
                       else pre_filter1)
    # `dataset` is deliberately left bound after the loop: the statement at
    # the end of this fragment reads `dataset.num_node_features`.
    dataset = PascalVOC(path, category, train=True, transform=transform,
                        pre_filter=category_filter)
    pretrain_datasets.append(ValidPairDataset(dataset, dataset, sample=True))

pretrain_dataset = torch.utils.data.ConcatDataset(pretrain_datasets)
pretrain_loader = DataLoader(pretrain_dataset, args.batch_size, shuffle=True,
                             follow_batch=['x_s', 'x_t'])

# WILLOW data: one dataset per category, sharing the same transform.
path = osp.join('..', 'data', 'WILLOW')
datasets = [WILLOW(path, cat, transform) for cat in WILLOW.categories]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): the original line ends with a call that is cut off in the
# middle of its argument list. It is kept verbatim below as a comment; restore
# the remaining arguments before re-enabling it.
# psi_1 = SplineCNN(dataset.num_node_features, args.dim,
# Parse the command-line options; `parser` is built earlier in the script.
args = parser.parse_args()


def pre_filter(data):
    """Keep only samples whose `pos` tensor has at least one row."""
    return data.pos.size(0) > 0


# Transform pipeline applied to every sample: Delaunay triangulation, faces
# converted to edges, then either `T.Distance` (when `args.isotropic` is
# truthy) or `T.Cartesian` edge attributes.
transform = T.Compose([
    T.Delaunay(),
    T.FaceToEdge(),
    T.Distance() if args.isotropic else T.Cartesian(),
])

# Build one train and one test pair-dataset per PascalVOC category; each
# split is paired with itself through `ValidPairDataset`.
train_datasets = []
test_datasets = []
path = osp.join('..', 'data', 'PascalVOC')
for category in PascalVOC.categories:
    # `dataset` is rebound twice per iteration, exactly as in the original,
    # so after the loop it refers to the last category's test split.
    dataset = PascalVOC(path, category, train=True, transform=transform,
                        pre_filter=pre_filter)
    train_datasets.append(ValidPairDataset(dataset, dataset, sample=True))
    dataset = PascalVOC(path, category, train=False, transform=transform,
                        pre_filter=pre_filter)
    test_datasets.append(ValidPairDataset(dataset, dataset, sample=True))

# All per-category training pairs are concatenated into one shuffled loader.
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True,
                          follow_batch=['x_s', 'x_t'])