Пример #1
0
def gen_fake_dataset():
    """Build a single FakeDataset and return a zero-argument getter for it.

    The dataset is constructed exactly once, when this function is called;
    every call of the returned function hands back that same object.

    Returns:
        A callable taking no arguments that returns the pre-built dataset.
    """
    # The fake dataset is randomized, so it is created a single time here
    # (on the driver) and the identical instance is what each training
    # worker receives, rather than every worker drawing its own random data.
    # The split holds out 10% of nodes for validation and 10% for testing.
    node_split = RandomNodeSplit(num_val=0.1, num_test=0.1)
    shared_dataset = FakeDataset(transform=node_split)

    def gen_dataset():
        # Closes over the dataset built above; nothing is regenerated here.
        return shared_dataset

    return gen_dataset
Пример #2
0
def test_random_node_split_on_hetero_data():
    """Check RandomNodeSplit with default arguments on a HeteroData object.

    Only the ``'paper'`` node type carries labels (``y``). The test expects
    that store to gain three masks, while the unlabeled ``'author'`` store
    keeps a single attribute.
    """
    num_papers, num_authors, num_features = 2000, 300, 16

    data = HeteroData()
    data['paper'].x = torch.randn(num_papers, num_features)
    data['paper'].y = torch.randint(4, (num_papers, ), dtype=torch.long)
    data['author'].x = torch.randn(num_authors, num_features)

    transform = RandomNodeSplit()
    assert str(transform) == 'RandomNodeSplit(split=train_rest)'

    data = transform(data)
    assert len(data) == 5

    # 'author' has no labels and is expected to be left with just `x`;
    # 'paper' should now hold x, y and the three split masks.
    assert len(data['author']) == 1
    paper = data['paper']
    assert len(paper) == 5

    # Expected partition of the 2000 paper nodes under the default settings
    # (presumably the transform's default val/test sizes -- confirm against
    # the RandomNodeSplit signature).
    expected_sizes = (
        ('train_mask', 500),
        ('val_mask', 500),
        ('test_mask', 1000),
    )
    for mask_name, expected in expected_sizes:
        assert getattr(paper, mask_name).sum() == expected
Пример #3
0
def test_random_node_split(num_splits):
    """Check RandomNodeSplit mask shapes and sizes for every split mode.

    The ``'train_rest'``, ``'test_rest'`` and ``'random'`` modes are each run
    twice on a 1000-node, 4-class graph: once with absolute validation/test
    counts and once with fractional sizes that are expected to resolve to
    the same counts.

    Args:
        num_splits: Number of splits requested from the transform. Supplied
            by the caller (presumably a parametrized test argument -- the
            decorator is not visible here).
    """
    num_nodes, num_classes = 1000, 4
    features = torch.randn(num_nodes, 16)
    labels = torch.randint(num_classes, (num_nodes, ), dtype=torch.long)
    data = Data(x=features, y=labels)

    def as_columns(mask):
        # With a single split the mask is treated as 1-D, so a trailing
        # axis is added to let every case be indexed as mask[:, i].
        if num_splits == 1:
            return mask.unsqueeze(-1)
        return mask

    def collect_masks(graph):
        # Train/val/test masks, each normalised to one column per split.
        return (as_columns(graph.train_mask), as_columns(graph.val_mask),
                as_columns(graph.test_mask))

    def check_masks(masks, num_train, num_val, num_test, num_covered):
        # Per split: exact set sizes, no node shared by all three masks,
        # and the expected number of nodes covered by their union.
        train, val, test = masks
        for col in range(train.size(-1)):
            tr, va, te = train[:, col], val[:, col], test[:, col]
            assert tr.sum() == num_train
            assert va.sum() == num_val
            assert te.sum() == num_test
            assert (tr & va & te).sum() == 0
            assert (tr | va | te).sum() == num_covered

    train_per_class = 10
    labeled_train = train_per_class * num_classes

    # --- 'train_rest' with absolute counts: train is expected to receive
    # every node not assigned to validation or test.
    transform = RandomNodeSplit(split='train_rest', num_splits=num_splits,
                                num_val=100, num_test=200)
    assert str(transform) == 'RandomNodeSplit(split=train_rest)'
    data = transform(data)
    assert len(data) == 5

    masks = collect_masks(data)
    for mask in masks:
        assert mask.size() == (num_nodes, num_splits)
    check_masks(masks, num_nodes - 100 - 200, 100, 200, num_nodes)

    # --- 'train_rest' with fractions: 0.1 / 0.2 of 1000 nodes should give
    # the same 100 / 200 counts as above.
    transform = RandomNodeSplit(split='train_rest', num_splits=num_splits,
                                num_val=0.1, num_test=0.2)
    data = transform(data)
    check_masks(collect_masks(data), num_nodes - 100 - 200, 100, 200,
                num_nodes)

    # --- 'test_rest' with absolute counts: a fixed number of training
    # nodes per class, test is expected to receive the remainder.
    transform = RandomNodeSplit(split='test_rest', num_splits=num_splits,
                                num_train_per_class=train_per_class,
                                num_val=100)
    assert str(transform) == 'RandomNodeSplit(split=test_rest)'
    data = transform(data)
    assert len(data) == 5
    check_masks(collect_masks(data), labeled_train, 100,
                num_nodes - labeled_train - 100, num_nodes)

    # --- 'test_rest' with a fractional validation size.
    transform = RandomNodeSplit(split='test_rest', num_splits=num_splits,
                                num_train_per_class=train_per_class,
                                num_val=0.1)
    data = transform(data)
    check_masks(collect_masks(data), labeled_train, 100,
                num_nodes - labeled_train - 100, num_nodes)

    # --- 'random' with absolute counts: all three sizes are fixed, so the
    # union is expected to cover only train + val + test nodes, not the
    # whole graph.
    transform = RandomNodeSplit(split='random', num_splits=num_splits,
                                num_train_per_class=train_per_class,
                                num_val=100, num_test=200)
    assert str(transform) == 'RandomNodeSplit(split=random)'
    data = transform(data)
    assert len(data) == 5
    check_masks(collect_masks(data), labeled_train, 100, 200,
                labeled_train + 100 + 200)

    # --- 'random' with fractional validation/test sizes.
    transform = RandomNodeSplit(split='random', num_splits=num_splits,
                                num_train_per_class=train_per_class,
                                num_val=0.1, num_test=0.2)
    assert str(transform) == 'RandomNodeSplit(split=random)'
    data = transform(data)
    check_masks(collect_masks(data), labeled_train, 100, 200,
                labeled_train + 100 + 200)