def prepare_data(dataset, seed):
    """Download (if needed) and load a dataset with train/val/test masks.

    NOTE(review): this module defines ``prepare_data`` twice; this first
    definition is shadowed by the later one and is effectively dead code.

    :param dataset: name of the dataset used ('Cora', 'PubMed', 'Citeseer',
        'Amazon', 'Reddit' or 'PPI')
    :param seed: seed forwarded to the PPI preprocessing/split helper
    :return: data, in the correct format (torch_geometric data object with
        ``num_classes`` and, where applicable, train/val/test masks set)
    :raises ValueError: if the dataset name is not recognised
    """
    # Retrieve main path of project (two levels above this file)
    dirname = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    # Download and store dataset at chosen location
    if dataset in ('Cora', 'PubMed', 'Citeseer'):
        path = os.path.join(dirname, 'data')
        data = Planetoid(path, name=dataset, split='full')[0]
        # Number of classes = highest label index + 1
        data.num_classes = (max(data.y) + 1).item()
    elif dataset == 'Amazon':
        path = os.path.join(dirname, 'data', 'Amazon')
        data = Amazon(path, 'photo')[0]
        data.num_classes = (max(data.y) + 1).item()
        # Amazon has no predefined split: 4896 train, 1224 val, 1530 test
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
    elif dataset == 'Reddit':
        # Fixed misspelled cache directory ('Reedit' -> 'Reddit'); data
        # cached under the old misspelled path will be re-downloaded.
        path = os.path.join(dirname, 'data', 'Reddit')
        data = Reddit(path)[0]
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
    elif dataset == 'PPI':
        path = os.path.join(dirname, 'data', 'PPI')
        data = ppi_prepoc(path, seed)
        # PPI is a multi-graph dataset: expose features/classes of the
        # first graph and propagate num_classes to every graph.
        data.x = data.graphs[0].x
        data.num_classes = data.graphs[0].y.size(1)
        for graph in data.graphs:
            graph.num_classes = data.num_classes
    else:
        # Previously an unknown name fell through to 'return data' and
        # crashed with NameError; fail fast with a clear message instead.
        raise ValueError('Unknown dataset: {}'.format(dataset))

    if dataset != 'PPI':
        # Boolean-mask indexing keeps only the True entries; the former
        # '== True' comparison was redundant.
        print('Train mask is of size: ',
              data.train_mask[data.train_mask].shape)
    return data
def prepare_data(dataset, train_ratio=0.8, input_dim=None, seed=10):
    """Import, save and process dataset.

    Args:
        dataset (str): name of the dataset used ('Cora', 'PubMed',
            'Citeseer', 'Amazon', 'syn1'/'syn2'/'syn4'/'syn5', 'syn6',
            'Mutagenicity')
        train_ratio (float): fraction of samples assigned to the train
            split for synthetic / graph-classification datasets
        input_dim (int, optional): input feature dimension forwarded to
            the synthetic-data builder
        seed (int): seed number used for the Amazon random split

    Returns:
        [torch_geometric.Data]: dataset in the correct format
        with required attributes and train/test/val split

    Raises:
        ValueError: if the dataset name is not recognised
    """
    # Retrieve main path of project (two levels above this file)
    dirname = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    # Download and store dataset at chosen location
    if dataset in ('Cora', 'PubMed', 'Citeseer'):
        path = os.path.join(dirname, 'data')
        data = Planetoid(path, name=dataset, split='full')[0]
        data.name = dataset
        # Number of classes = highest label index + 1
        data.num_classes = (max(data.y) + 1).item()
    elif dataset == 'Amazon':
        path = os.path.join(dirname, 'data', 'Amazon')
        data = Amazon(path, 'photo')[0]
        data.name = dataset
        data.num_classes = (max(data.y) + 1).item()
        # Amazon has no predefined split: 4896 train, 1224 val, 1530 test
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy(), seed=seed)
    elif dataset in ('syn1', 'syn2', 'syn4', 'syn5'):
        # Synthetic node-classification benchmarks
        data = synthetic_data(dataset, dirname, train_ratio, input_dim)
    elif dataset in ('syn6', 'Mutagenicity'):
        # Graph-classification datasets; the two formerly duplicated
        # branches were merged (identical handling).
        data = gc_data(dataset, dirname, train_ratio)
    else:
        # Previously an unknown name fell through to 'return data' and
        # crashed with NameError; fail fast with a clear message instead.
        raise ValueError('Unknown dataset: {}'.format(dataset))
    return data