示例#1
0
def test_batch_xyz(ADNI_names, path_dicts, signal_params):
    """Check that ``xyz_to_batch`` inverts ``batch_to_xyz``.

    For each overlap coefficient in 1..3 a fresh ``SHDataset`` is built.
    One randomly chosen patient's patch batch is stitched into a volume with
    ``batch_to_xyz``, then split back into patches with ``xyz_to_batch``.
    The non-empty patches of that round trip must match the original batch.
    """
    for overlap_coeff in range(1, 4):
        # Work on a copy so the shared ``signal_params`` fixture is not left
        # holding a modified 'overlap_coeff' that other tests would then see.
        params = dict(signal_params)
        params['overlap_coeff'] = overlap_coeff

        dataset = SHDataset(path_dicts,
                            patch_size=params['patch_size'],
                            signal_parameters=params,
                            transformations=None,
                            cache_dir=None)

        patient = np.random.choice(ADNI_names)
        data_patient = dataset.get_data_by_name(patient)

        data_batch = torch.FloatTensor(data_patient['sh'])

        data_xyz = batch_to_xyz(data_batch,
                                data_patient['real_size'],
                                overlap_coeff=overlap_coeff,
                                empty=data_patient['empty'])

        # The stitched volume must have the patient's spatial size.
        assert data_xyz.shape[:3] == data_patient['real_size']

        data_batch_2, _ = xyz_to_batch(
            data_xyz,
            params['patch_size'],
            overlap_coeff=overlap_coeff)

        # Keep only the patches that were not flagged as empty, i.e. the
        # same subset that data_patient['sh'] holds.
        data_batch_2 = data_batch_2[~data_patient['empty']]

        assert torch.allclose(data_batch, data_batch_2, rtol=0.05, atol=1e-6)
示例#2
0
def test_normalization(ADNI_names, path_dicts, signal_params):
    """Check that ``SHDataset.normalize_data`` can be undone with its statistics.

    The dataset is loaded unnormalized and one random patient's SH data is
    recorded. After ``normalize_data()`` is called, ``normalized * std + mean``
    must recover the recorded raw data.
    """
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        normalize_data=False,
                        cache_dir=None)

    patient = np.random.choice(ADNI_names)
    # Snapshot the raw coefficients. If normalize_data() updates the stored
    # arrays in place, a plain reference would change along with them and the
    # comparison below would no longer be against the raw data.
    data_raw = np.array(dataset.get_data_by_name(patient)['sh'], copy=True)
    dataset.normalize_data()
    mean, std = dataset.mean, dataset.std
    data_normalized = dataset.get_data_by_name(patient)['sh']

    # Undoing the normalization must give back the raw data.
    assert np.allclose(data_normalized * std + mean,
                       data_raw,
                       rtol=0.05,
                       atol=1e-6)
示例#3
0
def test_dataset(ADNI_names, path_dicts, signal_params):
    """Smoke-test ``SHDataset`` construction, indexing and lookup by name.

    Checks the shape of one indexed sample (patch size plus one trailing
    coefficient axis) and the total dataset length. It also checks that
    ``get_data_by_name`` runs for a random patient.
    """
    dataset = SHDataset(path_dicts,
                        patch_size=signal_params['patch_size'],
                        signal_parameters=signal_params,
                        transformations=None,
                        n_jobs=-1,
                        cache_dir=None)

    sh_order = signal_params['sh_order']
    # Expected size of the trailing coefficient axis. Integer division keeps
    # this an int, matching the int entries of ``signal.shape``.
    ncoef = (sh_order + 2) * (sh_order + 1) // 2

    signal, _ = dataset[300]

    # list(...) accepts patch_size given as either a list or a tuple.
    assert list(signal.shape) == list(signal_params['patch_size']) + [ncoef]

    assert len(dataset) == 608

    patient = np.random.choice(ADNI_names)
    # Only checks that name-based lookup runs without raising.
    dataset.get_data_by_name(patient)
示例#4
0
                    normalize_data=True,
                    mean=mean,
                    std=std,
                    b0_mean=b0_mean,
                    b0_std=b0_std,
                    n_jobs=8)

# Load the trained network from the path save_folder + net_file. This is plain
# string concatenation, so save_folder presumably ends with a path separator.
# The second value returned by ENet.load is discarded.
net, _ = ENet.load(save_folder + net_file)

# Move the network to the GPU and switch it to evaluation mode.
# NOTE(review): assumes ENet follows the torch.nn.Module .to()/.eval() API.
net = net.to("cuda")
net.eval()

# Pick the first diffusion MRI volume in the dataset and fetch its data dict
# (the keys used below are 'sh', 'real_size' and 'empty').
dwi_name = dataset.names[0]
data = dataset.get_data_by_name(dwi_name)

# Stitch the ground-truth SH coefficient batch back into a volume of spatial
# size data['real_size']. Then undo the dataset normalization
# (normalize_data=True above) to return to the original scale.
sh_true = batch_to_xyz(
    data['sh'],
    data['real_size'],
    empty=data['empty'],
    overlap_coeff=SIGNAL_PARAMETERS['overlap_coeff'])
sh_true = sh_true * dataset.std + dataset.mean

# Run the network over the whole dataset and keep only this volume's outputs.
# NOTE(review): numpy=False presumably keeps the outputs as torch tensors — confirm.
net_pred = net.predict_dataset(dataset, batch_size=16, numpy=False)[dwi_name]
sh_pred = net_pred['sh_pred']
mean_b0_pred = net_pred['mean_b0_pred']
alpha = net_pred['alpha']
beta = net_pred['beta']
sh_pred = batch_to_xyz(