Example 1
0
def prepare_data(dataset, seed):
    """Download (if needed) and load a dataset in the expected format.

    :param dataset: name of the dataset used ('Cora', 'PubMed',
        'Citeseer', 'Amazon', 'Reddit' or 'PPI')
    :param seed: seed number, forwarded to the PPI preprocessing
    :return: data, in the correct format (a single graph with
        train/val/test masks, or a multi-graph container for PPI)
    :raises ValueError: if the dataset name is not recognised
    """
    # Retrieve main path of project (two levels above this file)
    dirname = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    # Download and store dataset at chosen location
    if dataset in ('Cora', 'PubMed', 'Citeseer'):
        path = os.path.join(dirname, 'data')
        data = Planetoid(path, name=dataset, split='full')[0]
        # Labels are 0..C-1, so the number of classes is max label + 1
        data.num_classes = (max(data.y) + 1).item()

    elif dataset == 'Amazon':
        path = os.path.join(dirname, 'data', 'Amazon')
        data = Amazon(path, 'photo')[0]
        data.num_classes = (max(data.y) + 1).item()
        # Amazon has no predefined split; build one from the labels
        # (Amazon: 4896 train, 1224 val, 1530 test)
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())

    elif dataset == 'Reddit':
        # NOTE(review): fixed folder-name typo (was 'Reedit'); an existing
        # download under the old folder name will be re-fetched once.
        path = os.path.join(dirname, 'data', 'Reddit')
        data = Reddit(path)[0]
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())

    elif dataset == 'PPI':
        path = os.path.join(dirname, 'data', 'PPI')
        data = ppi_prepoc(path, seed)
        # PPI is multi-label: num_classes is the width of the label vector
        data.x = data.graphs[0].x
        data.num_classes = data.graphs[0].y.size(1)
        for graph in data.graphs:
            graph.num_classes = data.num_classes

    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError on 'data' below
        raise ValueError('Unknown dataset: {}'.format(dataset))

    # PPI is a multi-graph dataset and carries no single train mask
    if dataset != 'PPI':
        print('Train mask is of size: ',
              data.train_mask[data.train_mask == True].shape)

    return data
Example 2
0
def prepare_data(dataset, train_ratio=0.8, input_dim=None, seed=10):
    """Import, save and process dataset

    Args:
            dataset (str): name of the dataset used ('Cora', 'PubMed',
                    'Citeseer', 'Amazon', 'syn1', 'syn2', 'syn4', 'syn5',
                    'syn6' or 'Mutagenicity')
            train_ratio (float): fraction of samples used for training
                    (synthetic and graph-classification datasets only)
            input_dim (int): input feature dimension for the synthetic
                    node-classification datasets
            seed (int): seed number, used for the Amazon split

    Returns:
            [torch_geometric.Data]: dataset in the correct format
            with required attributes and train/test/val split

    Raises:
            ValueError: if the dataset name is not recognised
    """
    # Retrieve main path of project (two levels above this file)
    dirname = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    # Download and store dataset at chosen location
    if dataset in ('Cora', 'PubMed', 'Citeseer'):
        path = os.path.join(dirname, 'data')
        data = Planetoid(path, name=dataset, split='full')[0]
        data.name = dataset
        # Labels are 0..C-1, so the number of classes is max label + 1
        data.num_classes = (max(data.y) + 1).item()

    elif dataset == 'Amazon':
        path = os.path.join(dirname, 'data', 'Amazon')
        data = Amazon(path, 'photo')[0]
        data.name = dataset
        data.num_classes = (max(data.y) + 1).item()
        # Amazon has no predefined split; build one from the labels
        # (Amazon: 4896 train, 1224 val, 1530 test)
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy(), seed=seed)

    elif dataset in ['syn1', 'syn2', 'syn4', 'syn5']:
        data = synthetic_data(dataset, dirname, train_ratio, input_dim)

    elif dataset in ['syn6', 'Mutagenicity']:
        # Both are graph-classification datasets sharing the same loader;
        # merged the two previously duplicated branches
        data = gc_data(dataset, dirname, train_ratio)

    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError on 'data' below
        raise ValueError('Unknown dataset: {}'.format(dataset))

    return data