Example #1
0
def reduce_contrasts(
    components: str = 'components_453_gm',
    studies: Union[str, List[str]] = 'all',
    masked_dir='unmasked',
    output_dir='reduced',
    n_jobs=1,
    lstsq=False,
    batch_size=200,
):
    """Project the masked data of each study onto a MODL dictionary.

    For every study, ``data_<study>.pt`` is loaded from ``masked_dir``. Its
    rows are reduced batch by batch with ``single_reduce`` against the masked
    dictionary. The result is dumped, together with the unchanged targets, to
    ``<output_dir>/data_<study>.pt``.

    Parameters
    ----------
    components : str
        Key of the dictionary to use in the atlas returned by
        ``fetch_atlas_modl()``.
    studies : str or list of str
        ``'all'`` for every study in ``STUDY_LIST``, a single study name, or
        a list of study names.
    masked_dir : str
        Directory holding the input ``data_<study>.pt`` files.
    output_dir : str
        Directory receiving the reduced files; created if missing.
    n_jobs : int
        Number of parallel workers passed to ``Parallel``.
    lstsq : bool
        Forwarded to ``single_reduce``.
    batch_size : int
        Number of rows handled by each parallel task.
    """
    os.makedirs(output_dir, exist_ok=True)

    if studies == 'all':
        studies = STUDY_LIST
    elif isinstance(studies, str):
        # A bare study name would otherwise be iterated character by character.
        studies = [studies]

    modl_atlas = fetch_atlas_modl()
    mask = fetch_mask()
    masker = NiftiMasker(mask_img=mask).fit()
    # Masked dictionary matrix; named apart from the ``components`` key.
    dictionary = masker.transform(modl_atlas[components])

    for study in studies:
        data, targets = load(join(masked_dir, 'data_%s.pt' % study))
        batches = list(gen_batches(data.shape[0], batch_size))
        reduced = Parallel(n_jobs=n_jobs,
                           verbose=10,
                           backend='multiprocessing',
                           mmap_mode='r')(
            delayed(single_reduce)(dictionary, data[batch], lstsq=lstsq)
            for batch in batches)
        reduced = np.concatenate(reduced, axis=0)

        dump((reduced, targets), join(output_dir, 'data_%s.pt' % study))
Example #2
0
def unmask(dataset, output_dir=None, n_jobs=1, batch_size=1000):
    """Fetch the contrast images of ``dataset`` and pass them on for storage.

    The images returned by the dataset-specific fetcher are handed, with the
    mask from ``fetch_mask()``, to ``create_raw_contrast_data``. The target
    directory is ``<get_output_dir(output_dir)>/unmasked/<dataset>``.

    Parameters
    ----------
    dataset : str
        One of 'hcp', 'archi', 'brainomics', 'la5c', 'human_voice',
        'camcan', 'brainpedia'.
    output_dir : str or None
        Forwarded to ``get_output_dir``.
    n_jobs, batch_size : int
        Forwarded to ``create_raw_contrast_data``.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the known dataset names.
    """
    fetchers = {
        'hcp': fetch_hcp,
        'archi': fetch_archi,
        'brainomics': fetch_brainomics,
        'la5c': fetch_la5c,
        'human_voice': fetch_human_voice,
        'camcan': fetch_camcan,
        'brainpedia': fetch_brainpedia,
    }
    try:
        fetch_data = fetchers[dataset]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are also unknown datasets.
        raise ValueError('Unknown dataset %r; expected one of: %s'
                         % (dataset, ', '.join(sorted(fetchers)))) from None

    imgs = fetch_data()
    if dataset == 'hcp':
        # The HCP fetcher returns an object whose images live in `.contrasts`.
        imgs = imgs.contrasts
    mask = fetch_mask()

    artifact_dir = join(get_output_dir(output_dir), 'unmasked', dataset)

    create_raw_contrast_data(imgs,
                             mask,
                             artifact_dir,
                             n_jobs=n_jobs,
                             batch_size=batch_size)
Example #3
0
def mask_contrasts(studies: Union[str, List[str]] = 'all',
                   output_dir: str = 'masked',
                   use_raw=False,
                   n_jobs: int = 1,
                   batch_size: int = 10):
    """Mask (with 4mm smoothing) the z-maps of each study and dump them.

    For each ``study`` group of the fetched contrast table, the ``z_map``
    images are masked batch by batch with ``single_mask``. The stacked
    result is dumped with the study's rows to ``<output_dir>/data_<study>.pt``.

    Parameters
    ----------
    studies : str or list of str
        Forwarded to ``fetch_contrasts`` (or ``'all'``).
    output_dir : str
        Directory receiving the masked files; created if missing.
    use_raw : bool
        Use ``fetch_all()`` instead of ``fetch_contrasts``. Only honoured
        when ``studies == 'all'``.
    n_jobs : int
        Number of parallel workers passed to ``Parallel``.
    batch_size : int
        Number of images handled by each parallel task.
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): use_raw is silently ignored for an explicit study
    # selection — confirm this is intended.
    if use_raw and studies == 'all':
        data = fetch_all()
    else:
        data = fetch_contrasts(studies)
    mask = fetch_mask()
    masker = NiftiMasker(smoothing_fwhm=4,
                         mask_img=mask,
                         verbose=0,
                         memory_level=1,
                         memory=None).fit()

    for study, study_df in data.groupby('study'):
        imgs = study_df['z_map'].values
        targets = study_df.reset_index()

        batches = list(gen_batches(study_df.shape[0], batch_size))
        masked = Parallel(n_jobs=n_jobs,
                          verbose=10,
                          backend='multiprocessing',
                          mmap_mode='r')(
            delayed(single_mask)(masker, imgs[batch])
            for batch in batches)
        masked = np.concatenate(masked, axis=0)

        dump((masked, targets), join(output_dir, 'data_%s.pt' % study))
Example #4
0
def compute_classifs(estimator, standard_scaler, config, return_type='img'):
    """Compute the classification maps of a fitted estimator.

    How the maps are obtained depends on the estimator type:

    - ``'multi_study'`` / ``'ensemble'``: the curated torch module is
      evaluated on an identity matrix, and its output for a zero input is
      subtracted. Each study's maps are then centred along their first axis
      and transposed into numpy arrays.
    - ``'logistic'``: ``estimator.coef_`` is used directly.

    Parameters
    ----------
    estimator : object
        Fitted estimator matching ``config['model']['estimator']``.
    standard_scaler : object or None
        If not None, each study's maps are divided by
        ``standard_scaler.scs_[study].scale_``.
    config : dict
        Reads ``config['model']['estimator']`` and ``config['data']['reduced']``.
        When ``reduced`` is true, the maps are multiplied by the masked
        ``components_453_gm`` MODL dictionary.
    return_type : {'img', 'arrays'}
        ``'img'`` returns one image built from all studies' maps
        concatenated along axis 0. ``'arrays'`` returns a dict mapping each
        study to its array.

    Raises
    ------
    ValueError
        If the estimator type or ``return_type`` is not supported.
    """
    # Validate before doing any expensive work.
    if return_type not in ('img', 'arrays'):
        raise ValueError("return_type must be 'img' or 'arrays', got %r"
                         % (return_type,))

    estimator_type = config['model']['estimator']
    if estimator_type in ['multi_study', 'ensemble']:
        module = curate_module(estimator)
        module.eval()

        studies = module.classifiers.keys()
        in_features = module.embedder.in_features

        with torch.no_grad():
            # Response to each unit input minus the response to a zero
            # input: presumably the linear part, if the module is affine.
            classifs = module(
                {study: torch.eye(in_features)
                 for study in studies},
                logits=True)
            biases = module(
                {study: torch.zeros((1, in_features))
                 for study in studies},
                logits=True)
            classifs = {
                study: classifs[study] - biases[study]
                for study in studies
            }
            classifs = {
                study: classif - classif.mean(dim=0, keepdim=True)
                for study, classif in classifs.items()
            }
        classifs = {
            study: classif.numpy().T
            for study, classif in classifs.items()
        }

    elif estimator_type == 'logistic':
        # NOTE(review): the code below calls ``.items()`` on ``classifs``.
        # Unless ``coef_`` is a dict keyed by study, the scaler, reduced and
        # 'img' paths fail for logistic estimators — confirm.
        classifs = estimator.coef_
    else:
        raise ValueError('Wrong config file')

    if standard_scaler is not None:
        # Build a new dict rather than writing into ``classifs``, which may
        # be the estimator's own ``coef_``.
        classifs = {
            study: classif / standard_scaler.scs_[study].scale_[None, :]
            for study, classif in classifs.items()
        }

    reduced = config['data']['reduced']
    if reduced or return_type == 'img':
        # Needed both by the dictionary projection and by the image output.
        mask = fetch_mask()
        masker = NiftiMasker(mask_img=mask).fit()

    if reduced:
        modl_atlas = fetch_atlas_modl()
        dictionary = masker.transform(modl_atlas['components_453_gm'])
        classifs = {
            study: classif.dot(dictionary)
            for study, classif in classifs.items()
        }

    if return_type == 'arrays':
        return classifs
    return masker.inverse_transform(
        np.concatenate(list(classifs.values()), axis=0))
Example #5
0
def compute_rec():
    """Build and save the reconstruction matrix of three MODL atlases.

    The process has three steps:

    1. The ``positive_new_components{16,64,512}`` maps of the MODL atlas are
       masked with ``fetch_mask()``.
    2. The masked maps are passed to ``make_projection_matrix(...,
       scale_bases=True)``.
    3. The returned ``rec`` matrix is dumped to
       ``<get_output_dir()>/benchmark/rec.pkl``.
    """
    import os

    mask_img = fetch_mask()
    masker = MultiNiftiMasker(mask_img=mask_img).fit()
    atlas = fetch_atlas_modl()
    components_imgs = [
        atlas.positive_new_components16, atlas.positive_new_components64,
        atlas.positive_new_components512
    ]
    components = masker.transform(components_imgs)
    # Only the reconstruction matrix is persisted; the projections are unused.
    _, _, rec = make_projection_matrix(components, scale_bases=True)

    benchmark_dir = join(get_output_dir(), 'benchmark')
    # Create the directory so the dump does not fail on a fresh output dir.
    os.makedirs(benchmark_dir, exist_ok=True)
    dump(rec, join(benchmark_dir, 'rec.pkl'))
Example #6
0
def compute_components(estimator, config, return_type='img'):
    """Compute components from a FactoredClassifier estimator.

    The components are the weights of ``module.embedder.linear`` of the
    curated module. When ``config['data']['reduced']`` is true, they are
    multiplied by the masked ``components_453_gm`` MODL dictionary.

    Parameters
    ----------
    estimator : object
        Fitted estimator accepted by ``curate_module``.
    config : dict
        Reads ``config['data']['reduced']``.
    return_type : {'img', 'arrays'}
        ``'img'`` returns the components unmasked into an image with
        ``fetch_mask()``. ``'arrays'`` returns the numpy array.

    Raises
    ------
    ValueError
        If ``return_type`` is not supported.
    """
    # Validate before doing any expensive work.
    if return_type not in ('img', 'arrays'):
        raise ValueError("return_type must be 'img' or 'arrays', got %r"
                         % (return_type,))

    module = curate_module(estimator)
    components = module.embedder.linear.weight.detach().numpy()

    reduced = config['data']['reduced']
    if reduced or return_type == 'img':
        # Needed both by the dictionary projection and by the image output.
        mask = fetch_mask()
        masker = NiftiMasker(mask_img=mask).fit()

    if reduced:
        modl_atlas = fetch_atlas_modl()
        dictionary = masker.transform(modl_atlas['components_453_gm'])
        components = components.dot(dictionary)

    if return_type == 'arrays':
        return components
    return masker.inverse_transform(components)
Example #7
0
def plot_4d_image(img, names=None, output_dir=None,
                  colors=None,
                  view_types=None,
                  threshold=True,
                  n_jobs=1, verbose=10):
    """Plot every volume of a 4D image in parallel with ``plot_single``.

    Parameters
    ----------
    img : str or 4D Niimg-like
        Image to plot; checked with ``check_niimg(..., ensure_ndim=4)``.
    names : None, str or sequence of str
        Per-volume names.
        - A string is expanded with ``numbered_names``.
        - None uses the base name of ``img`` without ``.nii.gz``, which
          requires ``img`` to be a path.
        - A sequence must have one entry per volume.
    output_dir : str
        Directory receiving the plots; created if missing. Required.
    colors : iterable or None
        Per-volume colors forwarded to ``plot_single``. None passes None for
        every volume.
    view_types : sequence of str or None
        View types forwarded to ``plot_single``; defaults to
        ``['stat_map']``. The fsaverage5 surface is fetched if a
        ``surf_stat_map_*`` view is requested.
    threshold : bool
        If true, each plot is thresholded at the ``100 * (1 - 1 / n)``
        percentile of the absolute masked values, where ``n`` is the number
        of volumes. Otherwise the threshold is 0.
    n_jobs, verbose : int
        Forwarded to ``Parallel``.

    Returns
    -------
    list
        The ``plot_single`` results, one per volume.

    Raises
    ------
    TypeError
        If ``output_dir`` is not given.
    ValueError
        If ``names`` is a sequence whose length differs from the volume count.
    """
    if output_dir is None:
        raise TypeError('plot_4d_image() requires an output_dir')
    os.makedirs(output_dir, exist_ok=True)

    if view_types is None:
        view_types = ['stat_map']
    if colors is None:
        colors = repeat(None)

    if 'surf_stat_map_right' in view_types or 'surf_stat_map_left' in view_types:
        fetch_surf_fsaverage5()

    source = img
    img = check_niimg(img, ensure_ndim=4)
    # Loads the data once up front (result discarded); presumably so the
    # parallel workers receive in-memory data. TODO confirm.
    img.get_data()
    if names is None or isinstance(names, str):
        if names is None:
            names = os.path.basename(source).replace('.nii.gz', '')
        names = numbered_names(names)
    elif len(names) != img.shape[3]:
        raise ValueError('Expected %d names, got %d'
                         % (img.shape[3], len(names)))

    if threshold:
        # Masking is only needed to derive the threshold value.
        masker = NiftiMasker(mask_img=fetch_mask()).fit()
        masked = masker.transform(img)
        threshold_value = np.percentile(np.abs(masked),
                                        100. * (1 - 1. / len(masked)))
    else:
        threshold_value = 0

    return Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(plot_single)(volume, name, output_dir, view_types, color,
                             threshold=threshold_value)
        for name, volume, color in zip(names, iter_img(img), colors))
Example #8
0
    ]
    components = masker.transform(components_imgs)
    proj, proj_inv, rec = make_projection_matrix(components, scale_bases=True)
    dump(rec, join(get_output_dir(), 'benchmark', 'rec.pkl'))


def load_rec():
    """Return the reconstruction matrix stored in benchmark/rec.pkl."""
    rec_path = join(get_output_dir(), 'benchmark', 'rec.pkl')
    return load(rec_path)


# Run compute_rec() first if benchmark/rec.pkl has not been generated yet.

exp_dirs = join(get_output_dir(), 'single_exp', '8')
rec = load_rec()
mask_img = fetch_mask()
masker = MultiNiftiMasker(mask_img=mask_img).fit()

for exp_dir in [exp_dirs]:
    # Read from the loop variable so that adding directories to the list
    # processes each of them.
    estimator = load(join(exp_dir, 'estimator.pkl'))
    transformer = load(join(exp_dir, 'transformer.pkl'))
    # Print every (dataset, class) pair found in ``transformer.lbins_``.
    print([(dataset, this_class)
           for dataset, lbin in transformer.lbins_.items()
           for this_class in lbin.classes_])
    coef = estimator.coef_
    # Pass the coefficients through the reconstruction matrix so they match
    # the masked space expected by ``masker.inverse_transform``.
    coef_rec = coef.dot(rec)
    maps_path = join(exp_dir, 'maps.nii.gz')
    print(maps_path)
    imgs = masker.inverse_transform(coef_rec)
    imgs.to_filename(maps_path)
plot_stat_map(index_img(imgs, 10))
plt.show()
Example #9
0
model = load(join(artifact_dir, 'estimator.pkl'))
maps = model.coef_

source = config['source']

if source == 'craddock':
    components = fetch_craddock_parcellation().parcellate400
    data = np.ones_like(check_niimg(components).get_data())
    mask = new_img_like(components, data)
    label_masker = NiftiLabelsMasker(labels_img=components,
                                     smoothing_fwhm=0,
                                     mask_img=mask).fit()
    maps_img = label_masker.inverse_transform(maps)
else:
    mask = fetch_mask()
    masker = MultiNiftiMasker(mask_img=mask).fit()

    if source == 'msdl':
        components = fetch_atlas_msdl()['maps']
        components = masker.transform(components)
    elif source in ['hcp_rs', 'hcp_rs_concat', 'hcp_rs_positive']:
        data = fetch_atlas_modl()
        if source == 'hcp_rs':
            components_imgs = [data.nips2017_components64]
        elif source == 'hcp_rs_concat':
            components_imgs = [
                data.nips2017_components16, data.nips2017_components64,
                data.nips2017_components256
            ]
        else:
Example #10
0
def test_fetch_mask():
    """Smoke test: fetching the brain mask must complete without raising."""
    _ = fetch_mask()