예제 #1
0
def test_len_split():
    """Check that ``len_split`` partitions every length in ``range(2000)``.

    For each requested total length ``n``, the three sizes returned by
    ``len_split`` (train, validation, test) must sum back to ``n``.
    """
    # ``n`` rather than ``len`` so the builtin is not shadowed.
    for n in range(2000):
        train, valid, test = len_split(n)
        total = train + valid + test

        # "expected" is the requested length; "got" is the sum of the parts.
        assert n == total, (
            "Splitting of the dataset wrong, total len expected: {}, got {}"
            .format(n, total)
        )
예제 #2
0
def test_train_MEG_swap():
    """Smoke-test a single training epoch of ``models.MNet`` on one MEG file.

    Loads the recording with ``MEG_Dataset`` (1.0 s windows, no overlap),
    splits it with ``len_split``/``random_split``, builds train/validation
    loaders, infers the model input length from the last dimension of the
    first training batch, and runs ``train`` for one epoch on the CPU.
    The test passes if nothing raises.
    """
    # Raw string: the previous literal only worked because "\D" and "\s" are
    # invalid escapes that Python leaves as-is (a SyntaxWarning on 3.12+).
    # The resulting path value is unchanged.
    dataset_path = [r"Z:\Desktop\sub8\ball1_sss.fif"]

    dataset = MEG_Dataset(dataset_path, duration=1.0, overlap=0.0)

    train_len, valid_len, test_len = len_split(len(dataset))

    # The test split is produced but not used by this smoke test.
    train_dataset, valid_dataset, _test_dataset = random_split(
        dataset, [train_len, valid_len, test_len]
    )

    device = "cpu"

    trainloader = DataLoader(
        train_dataset, batch_size=10, shuffle=False, num_workers=1
    )

    validloader = DataLoader(
        valid_dataset, batch_size=2, shuffle=False, num_workers=1
    )

    epochs = 1

    # Fetch one batch only to size the network. Use the builtin ``next``:
    # the ``.next()`` method is gone from modern DataLoader iterators.
    with torch.no_grad():
        x, _, _ = next(iter(trainloader))
    n_times = x.shape[-1]

    net = models.MNet(n_times)

    optimizer = SGD(net.parameters(), lr=0.0001, weight_decay=5e-4)
    loss_function = torch.nn.MSELoss()

    # NOTE(review): the trailing positional arguments (10, 0, "") are passed
    # through to ``train`` unchanged; their meaning is defined by ``train``'s
    # signature, which is not visible here. Confirm against its definition.
    model, _, _ = train(
        net,
        trainloader,
        validloader,
        optimizer,
        loss_function,
        device,
        epochs,
        10,
        0,
        "",
    )

    print("Test succeeded!")
예제 #3
0
def test_MEG_dataset_shape_2():
    """Check split sizes and first-batch shapes for ``MEG_Dataset2``.

    Loads one MEG recording (1.0 s windows, no overlap), splits it with
    ``len_split``/``random_split``, and asserts:

    * the train/validation/test subsets have 524/112/113 items;
    * the first training batch of 50 yields tensors shaped
      ``(50, 1, 204, 501)`` (data), ``(50, 2, 2)`` (target) and
      ``(50, 204, 6)`` (bp).

    NOTE(review): the expected numbers are specific to this recording and
    to ``MEG_Dataset2``'s windowing. The meaning of each axis is not visible
    here.
    """
    # Raw string: the previous literal only worked because "\D" and "\s" are
    # invalid escapes that Python leaves as-is (a SyntaxWarning on 3.12+).
    # The resulting path value is unchanged.
    dataset_path = [r"Z:\Desktop\sub8\ball1_sss.fif"]

    dataset = MEG_Dataset2(dataset_path, duration=1.0, overlap=0.0)

    train_len, valid_len, test_len = len_split(len(dataset))

    print(len(dataset))
    print("{} {} {}".format(train_len, valid_len, test_len))

    train_dataset, valid_dataset, test_dataset = random_split(
        dataset, [train_len, valid_len, test_len])

    assert len(train_dataset) == 524, (
        "Bad split, train set length expected = 524, got {}".format(
            len(train_dataset)))
    assert len(valid_dataset) == 112, (
        "Bad split, validation set length expected = 112 , got {}".format(
            len(valid_dataset)))
    assert len(test_dataset) == 113, (
        "Bad split, test set length expected = 113 , got {}".format(
            len(test_dataset)))

    trainloader = DataLoader(train_dataset,
                             batch_size=50,
                             shuffle=False,
                             num_workers=1)

    # Builtin ``next``: the ``.next()`` method is gone from modern
    # DataLoader iterators.
    sample_data, sample_target, sample_bp = next(iter(trainloader))

    expected_data_shape = torch.Size([50, 1, 204, 501])
    expected_target_shape = torch.Size([50, 2, 2])
    expected_bp_shape = torch.Size([50, 204, 6])

    assert sample_data.shape == expected_data_shape, (
        "wrong data shape, data shape expected = {}, got {}".format(
            expected_data_shape, sample_data.shape))

    assert sample_target.shape == expected_target_shape, (
        "wrong target shape, target shape expected = {}, got {}".format(
            expected_target_shape, sample_target.shape))

    # Report the bp tensor's own shape on failure, not the target's.
    assert sample_bp.shape == expected_bp_shape, (
        "wrong bp shape, bp shape expected = {}, got {}".format(
            expected_bp_shape, sample_bp.shape))