def test_len_split():
    """Check that ``len_split`` partitions every total length exactly.

    For each total in ``range(2000)``, the three sizes returned by
    ``len_split`` (train, valid, test) must add back up to that total.
    """
    # Named ``total`` rather than ``len`` so the builtin is not shadowed.
    for total in range(2000):
        train, valid, test = len_split(total)
        got = train + valid + test
        # Message arguments are ordered (expected, got) to match the text.
        assert total == got, (
            "Splitting of the dataset wrong, total len expected: {}, got {}".format(
                total, got
            )
        )
def test_train_MEG_swap():
    """Smoke-test one epoch of training ``models.MNet`` on a MEG recording.

    Loads one hard-coded ``.fif`` file into ``MEG_Dataset`` (``duration=1.0``,
    ``overlap=0.0``) and splits it with ``len_split`` + ``random_split``. It
    then infers the number of time samples from the last axis of the first
    training batch and calls ``train`` for one epoch on CPU. The test passes
    if no exception is raised.
    """
    # Raw string: in a plain literal "\D" and "\s" are invalid escape
    # sequences (warning on modern Python); the resulting path is unchanged.
    dataset_path = [r"Z:\Desktop\sub8\ball1_sss.fif"]
    dataset = MEG_Dataset(dataset_path, duration=1.0, overlap=0.0)

    train_len, valid_len, test_len = len_split(len(dataset))
    train_dataset, valid_dataset, test_dataset = random_split(
        dataset, [train_len, valid_len, test_len]
    )

    device = "cpu"

    trainloader = DataLoader(
        train_dataset, batch_size=10, shuffle=False, num_workers=1
    )
    validloader = DataLoader(
        valid_dataset, batch_size=2, shuffle=False, num_workers=1
    )

    epochs = 1

    with torch.no_grad():
        # Builtin next(): DataLoader iterators implement __next__; the legacy
        # ``.next()`` alias is not available in recent PyTorch releases.
        x, _, _ = next(iter(trainloader))
        # Assumes the time axis is last — TODO confirm against MEG_Dataset.
        n_times = x.shape[-1]

    net = models.MNet(n_times)
    optimizer = SGD(net.parameters(), lr=0.0001, weight_decay=5e-4)
    loss_function = torch.nn.MSELoss()

    # The trailing positional arguments (10, 0, "") are passed through exactly
    # as before; their meaning is defined by ``train``. The returned model is
    # not inspected — this is a smoke test.
    _, _, _ = train(
        net,
        trainloader,
        validloader,
        optimizer,
        loss_function,
        device,
        epochs,
        10,
        0,
        "",
    )

    print("Test succeeded!")
def test_MEG_dataset_shape_2():
    """Check split sizes and first-batch shapes produced by ``MEG_Dataset2``.

    Loads one hard-coded ``.fif`` file into ``MEG_Dataset2`` (``duration=1.0``,
    ``overlap=0.0``) and splits it with ``len_split`` + ``random_split``. It
    asserts the split sizes 524 / 112 / 113. It then asserts the shapes of the
    (data, target, bp) tensors in the first training batch of size 50.
    """
    # Raw string: in a plain literal "\D" and "\s" are invalid escape
    # sequences (warning on modern Python); the resulting path is unchanged.
    dataset_path = [r"Z:\Desktop\sub8\ball1_sss.fif"]
    dataset = MEG_Dataset2(dataset_path, duration=1.0, overlap=0.0)

    train_len, valid_len, test_len = len_split(len(dataset))
    print(len(dataset))
    print("{} {} {}".format(train_len, valid_len, test_len))

    train_dataset, valid_dataset, test_dataset = random_split(
        dataset, [train_len, valid_len, test_len]
    )

    assert len(train_dataset) == 524, (
        "Bad split, train set length expected = 524, got {}".format(
            len(train_dataset)
        )
    )
    assert len(valid_dataset) == 112, (
        "Bad split, validation set length expected = 112 , got {}".format(
            len(valid_dataset)
        )
    )
    assert len(test_dataset) == 113, (
        "Bad split, test set length expected = 113 , got {}".format(
            len(test_dataset)
        )
    )

    trainloader = DataLoader(
        train_dataset, batch_size=50, shuffle=False, num_workers=1
    )
    # Builtin next(): DataLoader iterators implement __next__; the legacy
    # ``.next()`` alias is not available in recent PyTorch releases.
    sample_data, sample_target, sample_bp = next(iter(trainloader))

    expected_data_shape = torch.Size([50, 1, 204, 501])
    expected_target_shape = torch.Size([50, 2, 2])
    expected_bp_shape = torch.Size([50, 204, 6])

    assert sample_data.shape == expected_data_shape, (
        "wrong data shape, data shape expected = {}, got {}".format(
            expected_data_shape, sample_data.shape
        )
    )
    assert sample_target.shape == expected_target_shape, (
        "wrong target shape, data shape expected = {}, got {}".format(
            expected_target_shape, sample_target.shape
        )
    )
    # Previously reported sample_target.shape here, hiding the real bp shape.
    assert sample_bp.shape == expected_bp_shape, (
        "wrong bp shape, bp shape expected = {}, got {}".format(
            expected_bp_shape, sample_bp.shape
        )
    )