コード例 #1
0
def test_get_dataset(data_path):
    """Check the train/test/valid split sizes produced with default settings.

    No explicit ``split`` and no ``k_fold`` are requested; the seed is fixed
    so the split is reproducible.
    """
    train, test, valid = get_datasets(
        data_path=data_path,
        nb_nodes=185,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=None,
        seed=1234,
    )
    # (subset, expected length) pairs, checked in train -> test -> valid order.
    expected = ((train, 164), (test, 24), (valid, 48))
    for subset, size in expected:
        assert size == len(subset)
コード例 #2
0
def test_kfold(data_path):
    """Check the split sizes obtained when ``k_fold=[7, 1, 2]`` is requested."""
    fold_spec = [7, 1, 2]
    train, test, valid = get_datasets(
        data_path=data_path,
        nb_nodes=185,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=fold_spec,
        seed=1234,
    )
    n_train = len(train)
    n_test = len(test)
    n_valid = len(valid)
    assert n_train == 164
    assert n_test == 24
    assert n_valid == 48
コード例 #3
0
ファイル: conftest.py プロジェクト: carrascomj/gcn-prot
def batch(data_path):
    """Define a batch for the KrasHras dataset.

    Builds the training split with ``nb_nodes=7`` and returns the first batch
    (up to 25 samples) from an unshuffled loader that keeps the last partial
    batch. Returns ``None`` if the loader yields no batches at all.
    """
    train, _, _ = get_datasets(
        data_path=data_path,
        nb_nodes=7,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=None,
        seed=1234,
    )
    loader = torch.utils.data.DataLoader(
        train, shuffle=False, batch_size=25, drop_last=False
    )
    return next(iter(loader), None)
コード例 #4
0
def test_indexing(data_path):
    """Check random access into the generated graph dataset.

    The first training item is a sequence of fields whose lengths are
    compared against the expected per-field sizes.
    """
    train, _, _ = get_datasets(
        data_path=data_path,
        nb_nodes=185,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=None,
        seed=1234,
    )
    first_item = train[0]
    # Field order: v, c, m, y.
    field_lengths = list(map(len, first_item))
    assert field_lengths == [185, 185, 185, 2]
コード例 #5
0
def test_gaussian_augmentation(data_path):
    """Check split sizes when ``augment=2`` is passed to ``get_datasets``.

    The splits are unpacked as ``train, test, valid``, matching the order
    used by the other dataset tests. With ``augment=2`` the first and third
    returned splits are expected to double in size relative to the default
    split (164 and 48), while the second split keeps its default size (24).
    """
    augment = 2
    train, test, valid = get_datasets(
        data_path=data_path,
        nb_nodes=185,
        task_type="classification",
        nb_classes=2,
        split=None,
        augment=augment,
        k_fold=None,
        seed=1234,
    )
    assert 164 * augment == len(train)
    assert 48 * augment == len(valid)
    assert 24 == len(test)
コード例 #6
0
def test_validation(data_path, nn_kras):
    """Run ``Validation`` on the valid split and check the reported metrics.

    Builds the splits with ``nb_nodes=7``, validates ``nn_kras`` on the third
    (valid) split, and asserts that ``compute_stats`` reports exactly the
    recall, precision, accuracy and f_score metrics.
    """
    # Only the valid split is needed here.
    _, _, valid = get_datasets(
        data_path=data_path,
        nb_nodes=7,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=None,
        seed=1234,
    )
    validator = Validation(nn_kras, valid)
    validator.validate()
    stats = validator.compute_stats()
    # Must be asserted: a bare comparison is evaluated and discarded,
    # which would make this check a no-op.
    assert set(stats.keys()) == {
        "recall",
        "precision",
        "accuracy",
        "f_score",
    }
コード例 #7
0
def test_dataloading_batch(data_path):
    """Test transformation of input through the DataLoader.

    Takes the first batch of 25 from an unshuffled loader over the training
    split. It checks that each of the four fields has 25 entries and that
    the first field has shape ``(25, 185, 29)``.
    """
    train, _, _ = get_datasets(
        data_path=data_path,
        nb_nodes=185,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=None,
        seed=1234,
    )
    trainloader = torch.utils.data.DataLoader(
        train, shuffle=False, batch_size=25, drop_last=False
    )
    # Take only the first batch. Fail explicitly if the loader is empty rather
    # than crashing later on an unbound local.
    batch = next(iter(trainloader), None)
    assert batch is not None, "training DataLoader yielded no batches"
    batch_dims = [len(tr) for tr in batch]

    v = batch[0]
    assert batch_dims == [25, 25, 25, 2]
    assert v.shape == torch.Size([25, 185, 29])
コード例 #8
0
def test_fit_epoch(data_path, nn_kras):
    """Fit ``nn_kras`` for one epoch using the train and test splits.

    Smoke test: there is no assertion, so it passes as long as
    ``fit_network`` completes without raising.
    """
    train, test, _ = get_datasets(
        data_path=data_path,
        nb_nodes=7,
        task_type="classification",
        nb_classes=2,
        split=None,
        k_fold=None,
        seed=1234,
    )
    fit_network(
        nn_kras,
        train,
        test,
        torch.optim.Adam(nn_kras.parameters()),
        torch.nn.CrossEntropyLoss(),
        # NOTE(review): positional arg whose meaning is defined by
        # fit_network (presumably a batch size) — confirm.
        20,
        epochs=1,
        plot_every=2000,
    )