# Standard-library and third-party imports used by these tests; the
# project-specific names (SHDataset, BaseTrainer, batch_to_xyz,
# xyz_to_batch, ...) are assumed to be imported from the package under test.
import os
import shutil
import time

import numpy as np
import torch


def test_batch_xyz(ADNI_names, path_dicts, signal_params):
    # Check that xyz_to_batch(batch_to_xyz(data)) == data
    # for several overlap coefficients
    for overlap_coeff in range(1, 4):
        signal_params['overlap_coeff'] = overlap_coeff

        dataset = SHDataset(path_dicts,
                            patch_size=signal_params['patch_size'],
                            signal_parameters=signal_params,
                            transformations=None,
                            cache_dir=None)

        patient = np.random.choice(ADNI_names)
        data_patient = dataset.get_data_by_name(patient)
        data_batch = torch.FloatTensor(data_patient['sh'])

        data_xyz = batch_to_xyz(data_batch,
                                data_patient['real_size'],
                                overlap_coeff=signal_params['overlap_coeff'],
                                empty=data_patient['empty'])
        assert data_xyz.shape[:3] == data_patient['real_size']

        data_batch_2, number_of_patches = xyz_to_batch(
            data_xyz,
            signal_params['patch_size'],
            overlap_coeff=signal_params['overlap_coeff'])
        # Drop the empty patches so both batches cover the same voxels
        data_batch_2 = data_batch_2[~data_patient['empty']]

        assert torch.isclose(data_batch, data_batch_2,
                             rtol=0.05, atol=1e-6).all()
def test_training(path_dicts, signal_params, net):
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        n_jobs=-1,
                        cache_dir=None)

    trainer = BaseTrainer(
        net,
        optimizer_parameters={
            "lr": 0.01,
            "weight_decay": 1e-8,
        },
        loss_specs={
            "type": "mse",
            "parameters": {}
        },
        metrics=["acc", "mse_gfa", "mse"],
        metric_to_maximize="mse",
        patience=100,
        save_folder=None,
    )

    trainer.train(dataset,
                  dataset,
                  num_epochs=5,
                  batch_size=128,
                  validation=True)
def test_cache_no_cache(path_dicts, signal_params, cache_directory):
    shutil.rmtree(cache_directory, ignore_errors=True)

    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=None)
    assert not os.path.isdir(cache_directory)

    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=cache_directory)
    assert os.path.isdir(cache_directory)
def test_dataset(ADNI_names, path_dicts, signal_params):
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        n_jobs=-1,
                        cache_dir=None)

    sh_order = signal_params['sh_order']
    # Number of coefficients of a symmetric SH basis of the given order;
    # integer division keeps the expected shape entry an int
    ncoef = (sh_order + 2) * (sh_order + 1) // 2

    signal, mask = dataset[300]
    assert list(signal.shape) == signal_params['patch_size'] + [ncoef]
    assert len(dataset) == 608

    patient = np.random.choice(ADNI_names)
    dataset.get_data_by_name(patient)
def test_parallel_is_faster(path_dicts, signal_params, cache_directory):
    shutil.rmtree(cache_directory, ignore_errors=True)

    t1 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=None)
    t1 = time.time() - t1

    shutil.rmtree(cache_directory, ignore_errors=True)

    t2 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=1,
              cache_dir=None)
    t2 = time.time() - t2

    # The parallel run (n_jobs=-1) should be faster than the serial one
    assert t2 > t1
def test_cache_is_faster(path_dicts, signal_params, cache_directory):
    shutil.rmtree(cache_directory, ignore_errors=True)

    t1 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=cache_directory)
    t1 = time.time() - t1

    # We don't delete the cache, so the second dataset can reuse it
    t2 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=cache_directory)
    t2 = time.time() - t2

    assert t2 < t1
def test_normalization(ADNI_names, path_dicts, signal_params):
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        normalize_data=False,
                        cache_dir=None)

    patient = np.random.choice(ADNI_names)
    data_raw = dataset.get_data_by_name(patient)['sh']

    dataset.normalize_data()
    mean, std = dataset.mean, dataset.std
    data_normalized = dataset.get_data_by_name(patient)['sh']

    # Undoing the normalization should recover the raw data
    assert np.isclose(data_normalized * std + mean, data_raw,
                      rtol=0.05, atol=1e-6).all()
print("Train dataset size :", len(path_train)) print("Test dataset size :", len(path_test)) print("Validation dataset size", len(path_validation)) with open(save_folder + "train_val_test.txt", "w") as output: output.write(str([x['name'] for x in path_train]) + '\n') output.write(str([x['name'] for x in path_validation]) + '\n') output.write(str([x['name'] for x in path_test])) SIGNAL_PARAMETERS['overlap_coeff'] = 1 train_dataset = SHDataset(path_train, patch_size=SIGNAL_PARAMETERS["patch_size"], signal_parameters=SIGNAL_PARAMETERS, transformations=None, normalize_data=True, n_jobs=8, cache_dir="./") validation_dataset = SHDataset(path_validation, patch_size=SIGNAL_PARAMETERS["patch_size"], signal_parameters=SIGNAL_PARAMETERS, transformations=None, normalize_data=True, mean=train_dataset.mean, std=train_dataset.std, b0_mean=train_dataset.b0_mean, b0_std=train_dataset.b0_std, n_jobs=8, cache_dir="./")
SIGNAL_PARAMETERS['overlap_coeff'] = 2

paths, _ = get_paths_SIMON()  # get_paths_ADNI()
paths = [d for d in paths if d['name'] == dwi_file]

mean, std, b0_mean, b0_std = np.load(save_folder + 'mean_std.npy',
                                     allow_pickle=True)

# Create the dataset
dataset = SHDataset(paths,
                    patch_size=SIGNAL_PARAMETERS["patch_size"],
                    signal_parameters=SIGNAL_PARAMETERS,
                    transformations=None,
                    normalize_data=True,
                    mean=mean,
                    std=std,
                    b0_mean=b0_mean,
                    b0_std=b0_std,
                    n_jobs=8)

# Load the network
net, _ = ENet.load(save_folder + net_file)
net = net.to("cuda")
net.eval()

# Get the dMRI name
dwi_name = dataset.names[0]
data = dataset.get_data_by_name(dwi_name)
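# A minimal sketch of the denoising pass that would follow, mirroring the
# patch round trip exercised in test_batch_xyz; the chunk size of 32 and
# ENet's forward signature (a batch of SH patches in, a batch out) are
# assumptions:
with torch.no_grad():
    sh_batch = torch.FloatTensor(data['sh']).to("cuda")
    denoised = torch.cat([net(sh_batch[i:i + 32]).cpu()
                          for i in range(0, sh_batch.shape[0], 32)])

# Reassemble the patch batch into a full volume, as in test_batch_xyz
denoised_xyz = batch_to_xyz(denoised,
                            data['real_size'],
                            overlap_coeff=SIGNAL_PARAMETERS['overlap_coeff'],
                            empty=data['empty'])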