Example #1
def test_batch_xyz(ADNI_names, path_dicts, signal_params):
    # Check, for several overlap coefficients, that data == xyz_to_batch(batch_to_xyz(data))
    for overlap_coeff in range(1, 4):
        signal_params['overlap_coeff'] = overlap_coeff

        dataset = SHDataset(path_dicts,
                            patch_size=signal_params['patch_size'],
                            signal_parameters=signal_params,
                            transformations=None,
                            cache_dir=None)

        patient = np.random.choice(ADNI_names)
        data_patient = dataset.get_data_by_name(patient)

        data_batch = torch.FloatTensor(data_patient['sh'])

        data_xyz = batch_to_xyz(data_batch,
                                data_patient['real_size'],
                                overlap_coeff=signal_params['overlap_coeff'],
                                empty=data_patient['empty'])

        assert data_xyz.shape[:3] == data_patient['real_size']

        data_batch_2, number_of_patches = xyz_to_batch(
            data_xyz,
            signal_params['patch_size'],
            overlap_coeff=signal_params['overlap_coeff'])

        data_batch_2 = data_batch_2[~data_patient['empty']]

        assert torch.isclose(data_batch, data_batch_2, rtol=0.05,
                             atol=1e-6).all()
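All of these tests lean on pytest fixtures (ADNI_names, path_dicts, signal_params, net, cache_directory) that are defined outside the snippets shown here. As a rough sketch only, with placeholder paths and values, a minimal conftest.py could look like this:

import pytest

@pytest.fixture
def signal_params():
    # Placeholder values; the real fixture is defined elsewhere
    return {'patch_size': [12, 12, 12],
            'sh_order': 4,
            'overlap_coeff': 1}

@pytest.fixture
def path_dicts():
    # One dict per subject; the keys and paths here are hypothetical
    return [{'name': 'subject_01',
             'dwi': '/data/subject_01/dwi.nii.gz',
             'bval': '/data/subject_01/bvals',
             'bvec': '/data/subject_01/bvecs'}]

@pytest.fixture
def ADNI_names(path_dicts):
    return [d['name'] for d in path_dicts]

@pytest.fixture
def cache_directory(tmp_path):
    # tmp_path keeps the cache tests isolated between runs
    return str(tmp_path / 'sh_cache')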
Example #2
def test_training(path_dicts, signal_params, net):
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        n_jobs=-1,
                        cache_dir=None)

    trainer = BaseTrainer(
        net,
        optimizer_parameters={
            "lr": 0.01,
            "weight_decay": 1e-8,
        },
        loss_specs={
            "type": "mse",
            "parameters": {}
        },
        metrics=["acc", "mse_gfa", "mse"],
        metric_to_maximize="mse",
        patience=100,
        save_folder=None,
    )

    trainer.train(dataset,
                  dataset,
                  num_epochs=5,
                  batch_size=128,
                  validation=True)
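Passing the same dataset as both the training and the validation argument makes this a smoke test of the training loop rather than a measure of generalization. For a real run, one can build separate datasets the way Example #8 does; a minimal sketch:

# Hold out the last subject for validation (illustrative split only)
train_dataset = SHDataset(path_dicts[:-1],
                          patch_size=signal_params['patch_size'],
                          signal_parameters=signal_params,
                          transformations=None,
                          cache_dir=None)
val_dataset = SHDataset(path_dicts[-1:],
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        cache_dir=None)

trainer.train(train_dataset, val_dataset,
              num_epochs=5, batch_size=128, validation=True)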
Example #3
def test_cache_no_cache(path_dicts, signal_params, cache_directory):

    shutil.rmtree(cache_directory, ignore_errors=True)
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=None)
    assert not os.path.isdir(cache_directory)

    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=cache_directory)
    assert os.path.isdir(cache_directory)
Example #4
def test_dataset(ADNI_names, path_dicts, signal_params):
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        n_jobs=-1,
                        cache_dir=None)

    sh_order = signal_params['sh_order']
    # Number of coefficients in a symmetric (even-order) SH basis of this order
    ncoef = (sh_order + 2) * (sh_order + 1) // 2

    signal, mask = dataset[300]

    assert list(signal.shape) == signal_params['patch_size'] + [ncoef]

    assert len(dataset) == 608

    patient = np.random.choice(ADNI_names)
    dataset.get_data_by_name(patient)
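The coefficient count comes from the symmetric (even-order) spherical-harmonic basis: an order-L basis has (L + 1)(L + 2) / 2 coefficients. A quick sanity check:

for sh_order in (0, 2, 4, 6, 8):
    ncoef = (sh_order + 1) * (sh_order + 2) // 2
    print(sh_order, ncoef)  # 0 -> 1, 2 -> 6, 4 -> 15, 6 -> 28, 8 -> 45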
Example #5
def test_parallel_is_faster(path_dicts, signal_params, cache_directory):

    shutil.rmtree(cache_directory, ignore_errors=True)
    t1 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=None)
    t1 = time.time() - t1

    shutil.rmtree(cache_directory, ignore_errors=True)
    t2 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=1,
              cache_dir=None)
    t2 = time.time() - t2

    assert t2 > t1
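Wall-clock assertions like this one (and the cache test below) can be flaky on a busy machine. If that becomes a problem, time.perf_counter() plus a tolerance margin is a little more robust; a sketch of the same comparison:

import time

def timed_build(n_jobs):
    t0 = time.perf_counter()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=n_jobs,
              cache_dir=None)
    return time.perf_counter() - t0

parallel_time = timed_build(-1)
serial_time = timed_build(1)
assert serial_time > 1.1 * parallel_time  # require a 10% margin, tune as needed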
Example #6
def test_cache_is_faster(path_dicts, signal_params, cache_directory):

    shutil.rmtree(cache_directory, ignore_errors=True)
    t1 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=cache_directory)
    t1 = time.time() - t1

    # We don't delete the cache so the new dataset can use it
    t2 = time.time()
    SHDataset(path_dicts,
              patch_size=signal_params['patch_size'],
              signal_parameters=signal_params,
              transformations=None,
              n_jobs=-1,
              cache_dir=cache_directory)
    t2 = time.time() - t2

    assert t2 < t1
Example #7
def test_normalization(ADNI_names, path_dicts, signal_params):
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        normalize_data=False,
                        cache_dir=None)

    patient = np.random.choice(ADNI_names)
    data_raw = dataset.get_data_by_name(patient)['sh']
    dataset.normalize_data()
    mean, std = dataset.mean, dataset.std
    data_normalized = dataset.get_data_by_name(patient)['sh']

    assert np.isclose(data_normalized * std + mean,
                      data_raw,
                      rtol=0.05,
                      atol=1e-6).all()
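The assertion fixes the normalization convention: normalize_data stores z-scored coefficients, so the raw signal is recovered as normalized * std + mean. The same round trip in plain NumPy:

import numpy as np

rng = np.random.default_rng(0)
raw = rng.normal(5.0, 2.0, size=1000)
mean, std = raw.mean(), raw.std()
normalized = (raw - mean) / std
assert np.allclose(normalized * std + mean, raw)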
Example #8
print("Train dataset size :", len(path_train))
print("Test dataset size :", len(path_test))
print("Validation dataset size", len(path_validation))

with open(save_folder + "train_val_test.txt", "w") as output:
    output.write(str([x['name'] for x in path_train]) + '\n')
    output.write(str([x['name'] for x in path_validation]) + '\n')
    output.write(str([x['name'] for x in path_test]))


SIGNAL_PARAMETERS['overlap_coeff'] = 1

train_dataset = SHDataset(path_train,
                          patch_size=SIGNAL_PARAMETERS["patch_size"],
                          signal_parameters=SIGNAL_PARAMETERS,
                          transformations=None,
                          normalize_data=True,
                          n_jobs=8,
                          cache_dir="./")

validation_dataset = SHDataset(path_validation,
                               patch_size=SIGNAL_PARAMETERS["patch_size"],
                               signal_parameters=SIGNAL_PARAMETERS,
                               transformations=None,
                               normalize_data=True,
                               mean=train_dataset.mean,
                               std=train_dataset.std,
                               b0_mean=train_dataset.b0_mean,
                               b0_std=train_dataset.b0_std,
                               n_jobs=8,
                               cache_dir="./")
Example #9
SIGNAL_PARAMETERS['overlap_coeff'] = 2

paths, _ = get_paths_SIMON()  # alternatively: get_paths_ADNI()

paths = [d for d in paths if d['name'] == dwi_file]

mean, std, b0_mean, b0_std = np.load(save_folder + 'mean_std.npy',
                                     allow_pickle=True)

# Create the dataset
dataset = SHDataset(paths,
                    patch_size=SIGNAL_PARAMETERS["patch_size"],
                    signal_parameters=SIGNAL_PARAMETERS,
                    transformations=None,
                    normalize_data=True,
                    mean=mean,
                    std=std,
                    b0_mean=b0_mean,
                    b0_std=b0_std,
                    n_jobs=8)

# Load the network
net, _ = ENet.load(save_folder + net_file)

net = net.to("cuda")
net.eval()

# Get the dMRI name and fetch its data
dwi_name = dataset.names[0]
data = dataset.get_data_by_name(dwi_name)
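From here, the natural next step is to push the SH patches through the network and reassemble the volume with batch_to_xyz, mirroring the round trip from Example #1. A sketch, assuming net maps a batch of SH patches to a denoised batch of the same shape (chunk the batch with torch.split if GPU memory is tight):

import torch

with torch.no_grad():
    sh_batch = torch.FloatTensor(data['sh']).to("cuda")
    denoised = net(sh_batch).cpu()

denoised_xyz = batch_to_xyz(denoised,
                            data['real_size'],
                            overlap_coeff=SIGNAL_PARAMETERS['overlap_coeff'],
                            empty=data['empty'])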