Exemplo n.º 1
0
def _apply_random_split(data, test_size=0.4, seed=42):
    """Attach boolean train/val/test node masks from a random 60/20/20 split.

    The held-out portion (``test_size``) is split in half into validation
    and test sets. Uses a fixed ``random_state`` so splits are reproducible.
    """
    num_nodes = data.x.shape[0]
    idx_train, idx_test = train_test_split(list(range(num_nodes)),
                                           test_size=test_size,
                                           random_state=seed)
    idx_val, idx_test = train_test_split(idx_test,
                                         test_size=0.5,
                                         random_state=seed)
    # Build boolean masks rather than storing the raw index tensors:
    # PyG convention (and the other get_dataset in this file) expects
    # *_mask to be bool tensors of length num_nodes, so that boolean
    # indexing and mask.sum() behave as intended.
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[idx_train] = True
    val_mask[idx_val] = True
    test_mask[idx_test] = True
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask
    return data


def get_dataset(dataset_name):
    """
    Retrieves the dataset corresponding to the given name.

    Parameters
    ----------
    dataset_name : str
        One of 'reddit', 'ppi', 'github', 'amazon_comp', 'amazon_photo',
        'Cora', 'CiteSeer', 'PubMed'.

    Returns
    -------
    A PyTorch Geometric dataset; GitHub and Amazon datasets get random
    60/20/20 boolean train/val/test masks attached to their data object.

    Raises
    ------
    NotImplementedError
        If *dataset_name* is not recognized.
    """
    print("Getting dataset...")
    path = join('dataset', dataset_name)
    if dataset_name == 'reddit':
        dataset = Reddit(path)
    elif dataset_name == 'ppi':
        dataset = PPI(path)
    elif dataset_name == 'github':
        # GitHub ships without predefined splits — create them.
        dataset = GitHub(path)
        dataset.data = _apply_random_split(dataset.data)
    elif dataset_name in ['amazon_comp', 'amazon_photo']:
        amazon_name = "Computers" if dataset_name == 'amazon_comp' else "Photo"
        dataset = Amazon(path, amazon_name, T.NormalizeFeatures())
        # Amazon datasets also lack predefined splits.
        dataset.data = _apply_random_split(dataset.data)
    elif dataset_name in ["Cora", "CiteSeer", "PubMed"]:
        dataset = Planetoid(path,
                            name=dataset_name,
                            split="full",
                            transform=T.NormalizeFeatures())
    else:
        raise NotImplementedError

    print("Dataset ready!")
    return dataset
def get_dataset(dataset_name):
    """
    Retrieves the dataset corresponding to the given name.

    Supported names: 'reddit', 'amazon_comp', 'Cora', 'CiteSeer',
    'PubMed'. For 'amazon_comp' a random 60/20/20 train/val/test node
    split is generated (seeded, hence reproducible) and stored as
    boolean masks on the data object.

    Raises NotImplementedError for any other name.
    """
    path = 'dataset'

    if dataset_name == 'reddit':
        return Reddit(path)

    if dataset_name == 'amazon_comp':
        dataset = Amazon(path, name="Computers")
        data = dataset.data
        num_nodes = data.x.shape[0]

        # 60% train; remaining 40% halved into val and test.
        idx_train, idx_test = train_test_split(list(range(num_nodes)),
                                               test_size=0.4,
                                               random_state=42)
        idx_val, idx_test = train_test_split(idx_test,
                                             test_size=0.5,
                                             random_state=42)

        # Build one boolean mask per split from its index list.
        masks = {}
        for split_name, indices in (('train', idx_train),
                                    ('val', idx_val),
                                    ('test', idx_test)):
            mask = torch.tensor([False] * num_nodes)
            mask[indices] = True
            masks[split_name] = mask

        data.train_mask = masks['train']
        data.val_mask = masks['val']
        data.test_mask = masks['test']
        dataset.data = data
        return dataset

    if dataset_name in ["Cora", "CiteSeer", "PubMed"]:
        return Planetoid(
            path,
            name=dataset_name,
            split="full",
        )

    raise NotImplementedError