Example #1
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, [], optional=args.output_centroids)
    if args.output_clusters_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.output_clusters_dir,
                                           create_dir=True)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines = sft.streamlines
    thresholds = [40, 30, 20, args.dist_thresh]
    clusters = qbx_and_merge(streamlines,
                             thresholds,
                             nb_pts=args.nb_points,
                             verbose=False)

    for i, cluster in enumerate(clusters):
        if len(cluster.indices) > 1:
            cluster_streamlines = itemgetter(*cluster.indices)(streamlines)
        else:
            cluster_streamlines = streamlines[cluster.indices]

        new_sft = StatefulTractogram(cluster_streamlines, sft, Space.RASMM)
        save_tractogram(
            new_sft,
            os.path.join(args.output_clusters_dir, 'cluster_{}.trk'.format(i)))

    if args.output_centroids:
        new_sft = StatefulTractogram(clusters.centroids, sft, Space.RASMM)
        save_tractogram(new_sft, args.output_centroids)
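
The helper _build_arg_parser is not part of the excerpt. A minimal, hypothetical parser covering only the attributes this main() reads (argument names are taken from the code above; which arguments are positional, their defaults and help strings are assumptions) might look like:

import argparse


def _build_arg_parser():
    # Illustrative sketch only; the real parser lives in the original script.
    p = argparse.ArgumentParser(description='Cluster streamlines with '
                                            'QuickBundlesX.')
    p.add_argument('in_tractogram', help='Input tractogram file.')
    p.add_argument('dist_thresh', type=float,
                   help='Last QuickBundlesX distance threshold.')
    p.add_argument('--nb_points', type=int, default=20,
                   help='Number of points per resampled streamline.')
    p.add_argument('--output_centroids',
                   help='Optional output tractogram of centroids.')
    p.add_argument('--output_clusters_dir',
                   help='Output directory for one tractogram per cluster.')
    p.add_argument('--reference',
                   help='Reference anatomy, likely needed for .tck inputs.')
    return p
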
Example #2
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       create_dir=True)

    keys = []
    for filename in args.in_hdf5:
        curr_file = h5py.File(filename, 'r')
        keys.extend(curr_file.keys())
        curr_file.close()

    nbr_cpu = validate_nbr_processes(parser, args, args.nbr_processes)
    if nbr_cpu == 1:
        for key in keys:
            _average_wrapper([args.in_hdf5, key, args.binary, args.out_dir])
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        _ = pool.map(_average_wrapper,
                     zip(itertools.repeat(args.in_hdf5),
                         keys,
                         itertools.repeat(args.binary),
                         itertools.repeat(args.out_dir)))
        pool.close()
        pool.join()
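
The per-key work is delegated to _average_wrapper, which is not shown here. Independently of that helper, the argument-packing pattern used above (constant arguments repeated with itertools.repeat and zipped with the varying keys, so each worker receives one tuple) can be illustrated with a small self-contained sketch:

import itertools
import multiprocessing


def _toy_wrapper(packed_args):
    # One tuple per key, mirroring the packing used above.
    filenames, key, binary, out_dir = packed_args
    return '{}|{}|{}|{}'.format(','.join(filenames), key, binary, out_dir)


if __name__ == '__main__':
    keys = ['1_2', '1_3', '2_3']
    with multiprocessing.Pool(2) as pool:
        results = pool.map(_toy_wrapper,
                           zip(itertools.repeat(['a.h5', 'b.h5']),
                               keys,
                               itertools.repeat(False),
                               itertools.repeat('out')))
    print(results)
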
Example #3
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    hdf5_file = h5py.File(args.in_hdf5, 'r')
    for key in hdf5_file.keys():
        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
        header = create_nifti_header(affine, dimensions, voxel_sizes)
        sft = StatefulTractogram(streamlines,
                                 header,
                                 Space.VOX,
                                 origin=Origin.TRACKVIS)
        if args.include_dps:
            for dps_key in hdf5_file[key].keys():
                if dps_key not in ['data', 'offsets', 'lengths']:
                    sft.data_per_streamline[dps_key] = hdf5_file[key][dps_key]

        save_tractogram(sft, '{}.trk'.format(os.path.join(args.out_dir, key)))

    hdf5_file.close()
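
The example above relies on a specific HDF5 layout: file-level attributes (affine, dimensions, voxel_sizes) plus one group per connection containing data, offsets and lengths datasets. The sketch below builds a tiny file with that same structure so the names being read are visible; the values are toy assumptions, not real data:

import h5py
import numpy as np

with h5py.File('toy_decomposed.h5', 'w') as f:
    f.attrs['affine'] = np.eye(4)
    f.attrs['dimensions'] = np.array([10, 10, 10], dtype=np.int32)
    f.attrs['voxel_sizes'] = np.array([1.0, 1.0, 1.0], dtype=np.float32)

    # One connection between (hypothetical) labels 1 and 2:
    # 5 points in total, split into two streamlines of 3 and 2 points.
    grp = f.create_group('1_2')
    grp.create_dataset('data', data=np.random.rand(5, 3).astype(np.float32))
    grp.create_dataset('offsets', data=np.array([0, 3], dtype=np.int64))
    grp.create_dataset('lengths', data=np.array([3, 2], dtype=np.int32))
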
Example #4
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_hdf5)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)
    if args.save_empty and args.labels_list is None:
        parser.error("The option --save_empty requires --labels_list.")

    with h5py.File(args.in_hdf5, 'r') as hdf5_file:
        if args.save_empty:
            all_labels = np.loadtxt(args.labels_list, dtype='str')
            comb_list = list(itertools.combinations(all_labels, r=2))
            comb_list.extend(zip(all_labels, all_labels))
            keys = [i[0] + '_' + i[1] for i in comb_list]
        else:
            keys = hdf5_file.keys()

        if args.edge_keys is not None:
            selected_keys = [key for key in keys if key in args.edge_keys]
        elif args.node_keys is not None:
            selected_keys = []
            for node in args.node_keys:
                selected_keys.extend([
                    key for key in keys
                    if key.startswith(node + '_') or key.endswith('_' + node)
                ])
        else:
            selected_keys = keys

        affine = hdf5_file.attrs['affine']
        dimensions = hdf5_file.attrs['dimensions']
        voxel_sizes = hdf5_file.attrs['voxel_sizes']
        header = create_nifti_header(affine, dimensions, voxel_sizes)
        for key in selected_keys:
            streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)

            if len(streamlines) == 0 and not args.save_empty:
                continue

            sft = StatefulTractogram(streamlines,
                                     header,
                                     Space.VOX,
                                     origin=Origin.TRACKVIS)
            if args.include_dps:
                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        sft.data_per_streamline[dps_key] = hdf5_file[key][
                            dps_key]

            save_tractogram(sft,
                            '{}.trk'.format(os.path.join(args.out_dir, key)))
Example #5
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_matrices,
                        [args.labels_list, args.in_ordering])
    assert_output_dirs_exist_and_empty(parser, args, [], args.out_dir)
    if args.out_dir is None:
        args.out_dir = './'
    if args.out_suffix is None:
        args.out_suffix = ""
    out_filenames = []
    for filename in args.in_matrices:
        basename, _ = os.path.splitext(filename)
        basename = os.path.basename(basename)
        out_filenames.append('{}/{}{}.npy'.format(args.out_dir,
                                                  basename,
                                                  args.out_suffix))

    assert_outputs_exist(parser, args, out_filenames)
    with open(args.in_ordering, 'r') as my_file:
        lines = my_file.readlines()
        ordering = [[int(val) for val in lines[0].split()],
                    [int(val) for val in lines[1].split()]]

    for filename in args.in_matrices:
        basename, _ = os.path.splitext(filename)
        basename = os.path.basename(basename)
        matrix = load_matrix_in_any_format(filename)

        if args.labels_list:
            labels_list = np.loadtxt(args.labels_list, dtype=np.int16).tolist()
            indices_1, indices_2 = [], []
            for j in ordering[0]:
                indices_1.append(labels_list.index(j))
            for j in ordering[1]:
                indices_2.append(labels_list.index(j))
        else:
            indices_1 = ordering[0]
            indices_2 = ordering[1]

        if (np.array(indices_1) > matrix.shape[0]).any() \
                or (np.array(indices_2) > matrix.shape[1]).any():
            raise ValueError('Indices from config higher than matrix size, '
                             'maybe you need a labels list?')
        tmp_matrix = matrix[tuple(indices_1), :]
        tmp_matrix = tmp_matrix[:, tuple(indices_2)]
        save_matrix_in_any_format('{}/{}{}.npy'.format(args.out_dir,
                                                       basename,
                                                       args.out_suffix),
                                  tmp_matrix)
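
The reordering above is two rounds of NumPy fancy indexing: rows first, then columns. A small self-contained illustration with toy indices:

import numpy as np

matrix = np.arange(16).reshape(4, 4)
indices_1 = [2, 0, 1]   # new row order
indices_2 = [3, 1]      # new column order

tmp_matrix = matrix[tuple(indices_1), :]      # reorder / subset rows
tmp_matrix = tmp_matrix[:, tuple(indices_2)]  # reorder / subset columns
print(tmp_matrix)
# [[11  9]
#  [ 3  1]
#  [ 7  5]]
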
Example #6
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle] + args.in_metrics)
    assert_output_dirs_exist_and_empty(parser, args,
                                       args.out_folder,
                                       create_dir=True)

    assert_same_resolution(args.in_metrics)

    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
    sft.to_vox()
    sft.to_corner()

    if len(sft.streamlines) == 0:
        logging.warning('Empty bundle file {}. Skipping'.format(
            args.in_bundle))
        return

    mins, maxs, indices = _process_streamlines(sft.streamlines)

    metrics = [nib.load(metric) for metric in args.in_metrics]
    for metric in metrics:
        data = metric.get_fdata(dtype=np.float32)
        endpoint_metric_map = np.zeros(metric.shape)
        count = np.zeros(metric.shape)
        for cur_min, cur_max, cur_ind, orig_s in zip(mins, maxs,
                                                     indices,
                                                     sft.streamlines):
            streamline_mean = _compute_streamline_mean(cur_ind,
                                                       cur_min,
                                                       cur_max,
                                                       data)

            xyz = orig_s[0, :].astype(int)
            endpoint_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
            count[xyz[0], xyz[1], xyz[2]] += 1

            xyz = orig_s[-1, :].astype(int)
            endpoint_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
            count[xyz[0], xyz[1], xyz[2]] += 1

        endpoint_metric_map[count != 0] /= count[count != 0]
        metric_fname, ext = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        nib.save(nib.Nifti1Image(endpoint_metric_map, metric.affine,
                                 metric.header),
                 os.path.join(args.out_folder,
                              '{}_endpoints_metric{}'.format(metric_fname,
                                                             ext)))
Example #7
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    required = args.in_label
    assert_inputs_exist(parser, required)

    label_img = nib.load(args.in_label)
    label_img_data = get_data_as_label(label_img)

    if args.scilpy_lut:
        with open(os.path.join(get_lut_dir(), args.scilpy_lut + '.json')) as f:
            label_dict = json.load(f)
        (label_indices, label_names) = zip(*label_dict.items())
    else:
        with open(args.custom_lut) as f:
            label_dict = json.load(f)
        (label_indices, label_names) = zip(*label_dict.items())

    output_filenames = []
    for label, name in zip(label_indices, label_names):
        if int(label) != 0:
            if args.out_prefix:
                output_filenames.append(
                    os.path.join(
                        args.out_dir,
                        '{0}_{1}.nii.gz'.format(args.out_prefix, name)))
            else:
                output_filenames.append(
                    os.path.join(args.out_dir, '{0}.nii.gz'.format(name)))

    assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir)
    assert_outputs_exist(parser, args, output_filenames)

    # Extract the voxels that match the label and save them to a file.
    cnt_filename = 0
    for label in label_indices:
        if int(label) != 0:
            split_label = np.zeros(label_img.shape, dtype=np.uint16)
            split_label[np.where(label_img_data == int(label))] = label

            split_image = nib.Nifti1Image(split_label,
                                          label_img.affine,
                                          header=label_img.header)
            nib.save(split_image, output_filenames[cnt_filename])
            cnt_filename += 1
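
The splitting loop above boils down to writing one label value into an otherwise empty volume of the same shape. A toy NumPy illustration of that masking step:

import numpy as np

label_img_data = np.array([[0, 1, 2],
                           [2, 1, 0]])
label = 2

split_label = np.zeros(label_img_data.shape, dtype=np.uint16)
split_label[np.where(label_img_data == int(label))] = label
print(split_label)
# [[0 0 2]
#  [2 0 0]]
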
Example #8
def main():
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.config_file, args.transformation])

    for directory in args.models_directories:
        if not os.path.isdir(directory):
            parser.error('Input folder {0} does not exist'.format(directory))

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    logging.basicConfig(
        filename=os.path.join(args.output, 'logfile.txt'),
        filemode='w',
        format='%(asctime)s, %(name)s %(levelname)s %(message)s',
        datefmt='%H:%M:%S',
        level=args.log_level)

    coloredlogs.install(level=args.log_level)

    transfo = np.loadtxt(args.transformation)
    if args.inverse:
        transfo = np.linalg.inv(transfo)

    with open(args.config_file) as json_data:
        config = json.load(json_data)

    voting = VotingScheme(
        config,
        args.models_directories,
        transfo,
        args.output,
        tractogram_clustering_thr=args.tractogram_clustering_thr,
        minimal_vote_ratio=args.minimal_vote_ratio,
        multi_parameters=args.multi_parameters)

    if args.seeds is None:
        seeds = [random.randint(1, 1000)]
    else:
        seeds = args.seeds

    voting(args.in_tractogram, nbr_processes=args.processes, seeds=seeds)
Example #9
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    required = args.in_labels
    assert_inputs_exist(parser, required)

    label_img = nib.load(args.in_labels)
    label_img_data = get_data_as_label(label_img)

    if args.range:
        label_indices = [item for sublist in args.range for item in sublist]
    else:
        label_indices = np.unique(label_img_data)
    label_names = [str(i) for i in label_indices]

    output_filenames = []
    for label, name in zip(label_indices, label_names):
        if int(label) != 0:
            if args.out_prefix:
                output_filenames.append(
                    os.path.join(
                        args.out_dir,
                        '{0}_{1}.nii.gz'.format(args.out_prefix, name)))
            else:
                output_filenames.append(
                    os.path.join(args.out_dir, '{0}.nii.gz'.format(name)))

    assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir)
    assert_outputs_exist(parser, args, output_filenames)

    # Extract the voxels that match the label and save them to a file.
    cnt_filename = 0
    for label in label_indices:
        if int(label) != 0:
            split_label = np.zeros(label_img.shape, dtype=np.uint16)
            split_label[np.where(label_img_data == int(label))] = label

            split_image = nib.Nifti1Image(split_label,
                                          label_img.affine,
                                          header=label_img.header)
            nib.save(split_image, output_filenames[cnt_filename])
            cnt_filename += 1
Example #10
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    _, out_extension = os.path.splitext(args.in_tractogram)

    assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir)
    # Check only the first potential output filename
    assert_outputs_exist(parser, args, os.path.join(args.out_dir,
                                                    '{}_0{}'.format(args.out_prefix,
                                                                    out_extension)))

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    streamlines_count = len(sft.streamlines)

    if args.nb_chunk:
        chunk_size = int(streamlines_count/args.nb_chunk)
        nb_chunk = args.nb_chunk
    else:
        chunk_size = args.chunk_size
        nb_chunk = int(streamlines_count/chunk_size)+1

    # All chunks will be equal except the last one
    chunk_sizes = np.ones((nb_chunk,), dtype=np.int16) * chunk_size
    chunk_sizes[-1] += (streamlines_count - chunk_size * nb_chunk)
    curr_count = 0
    for i in range(nb_chunk):
        streamlines = sft.streamlines[curr_count:curr_count + chunk_sizes[i]]
        data_per_streamline = sft.data_per_streamline[curr_count:curr_count
                                                      + chunk_sizes[i]]
        data_per_point = sft.data_per_point[curr_count:curr_count
                                            + chunk_sizes[i]]
        curr_count += chunk_sizes[i]
        new_sft = StatefulTractogram.from_sft(streamlines, sft,
                                              data_per_point=data_per_point,
                                              data_per_streamline=data_per_streamline)

        out_name = os.path.join(args.out_dir,
                                '{0}_{1}{2}'.format(args.out_prefix, i,
                                                    out_extension))
        save_tractogram(new_sft, out_name)
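
The chunk-size bookkeeping above (equal chunks, with the remainder absorbed by the last one) can be checked with concrete numbers:

import numpy as np

streamlines_count = 1000
chunk_size = 300
nb_chunk = int(streamlines_count / chunk_size) + 1   # 4

chunk_sizes = np.ones((nb_chunk,), dtype=np.int16) * chunk_size
chunk_sizes[-1] += (streamlines_count - chunk_size * nb_chunk)
print(chunk_sizes)        # [300 300 300 100]
print(chunk_sizes.sum())  # 1000
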
Example #11
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_json)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    if args.fill_color and len(args.fill_color) != 8:
        parser.error('Hexadecimal RGB color should be formatted as 0xRRGGBB')

    with open(args.in_json, 'r+') as f:
        mean_std_per_point = json.load(f)

    for bundle_name, bundle_stats in mean_std_per_point.items():
        for metric, metric_stats in bundle_stats.items():
            nb_points = len(metric_stats)
            num_digits_labels = len(str(nb_points))
            means = []
            stds = []
            for label_int in range(1, nb_points + 1):
                label = str(label_int).zfill(num_digits_labels)
                mean = metric_stats.get(label, {'mean': np.nan})['mean']
                mean = mean if mean else np.nan
                std = metric_stats.get(label, {'std': np.nan})['std']
                std = std if std else np.nan
                means += [mean]
                stds += [std]

            fig = plot_metrics_stats(
                np.array(means),
                np.array(stds),
                title=bundle_name,
                xlabel='Location along the streamline',
                ylabel=metric,
                fill_color=(args.fill_color.replace("0x", "#")
                            if args.fill_color else None))
            fig.savefig(os.path.join(args.out_dir,
                                     '{}_{}.png'.format(bundle_name, metric)),
                        bbox_inches='tight')
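
The script above assumes a JSON file mapping each bundle to its metrics, and each metric to per-point labels holding 'mean' and 'std' values. A toy example of that assumed shape, built in Python for illustration:

import json

mean_std_per_point = {
    'AF_L': {                                  # hypothetical bundle name
        'FA': {                                # hypothetical metric name
            '1': {'mean': 0.42, 'std': 0.05},
            '2': {'mean': 0.45, 'std': 0.04},
            '3': {'mean': 0.47, 'std': 0.06},
        }
    }
}
print(json.dumps(mean_std_per_point, indent=2))
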
Example #12
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_dwi,
                                 args.in_bval, args.in_bvec],
                        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
                                       optional=args.save_kernels)

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with '
                     'ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram,
                                                  dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    if args.threshold_weights == 'None' or args.threshold_weights == 'none':
        args.threshold_weights = None
        if not args.keep_whole_tractogram and ext != '.h5':
            logging.warning('Not thresholding weights with a trk file and '
                            'without the --keep_whole_tractogram option '
                            'will not save any tractogram.')
    else:
        args.threshold_weights = float(args.threshold_weights)

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(hdf5_file.attrs['affine'], dwi_img.affine,
                            atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        offsets_list = [0]
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file,
                                                                key)
            offsets_list.append(len(tmp_streamlines))
            streamlines.extend(tmp_streamlines)

        offsets_list = np.cumsum(offsets_list)

        sft = StatefulTractogram(streamlines, args.in_dwi,
                                 Space.VOX, origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Keeping the input variable, saving trk file for COMMIT internal use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals, args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Launching COMMIT on {} shells found at {}.'.format(
        len(shells_centroids),
        shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           gen_trk=False,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001*10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropic_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropic_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=500, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        mit.load_dictionary(tmp_dir.name,
                            use_mask=args.in_tracking_mask is not None)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=tmp_dir.name)
        mit.fit(tol_fun=1e-3, max_iter=args.nbr_iter, verbose=0)
        mit.save_results()

    # Simplifying output for streamlines and cleaning output directory
    commit_results_dir = os.path.join(tmp_dir.name,
                                      'Results_StickZeppelinBall')
    pk_file = open(os.path.join(commit_results_dir, 'results.pickle'), 'rb')
    commit_output_dict = pickle.load(pk_file)
    nbr_streamlines = lazy_streamlines_count(args.in_tractogram)
    commit_weights = np.asarray(commit_output_dict[2][:nbr_streamlines])
    np.savetxt(os.path.join(commit_results_dir,
                            'commit_weights.txt'),
               commit_weights)

    if ext == '.h5':
        new_filename = os.path.join(commit_results_dir,
                                    'decompose_commit.h5')
        with h5py.File(new_filename, 'w') as new_hdf5_file:
            new_hdf5_file.attrs['affine'] = sft.affine
            new_hdf5_file.attrs['dimensions'] = sft.dimensions
            new_hdf5_file.attrs['voxel_sizes'] = sft.voxel_sizes
            new_hdf5_file.attrs['voxel_order'] = sft.voxel_order
            # Assign the weights into the hdf5, while respecting the ordering of
            # connections/streamlines
            logging.debug('Adding commit weights to {}.'.format(new_filename))
            for i, key in enumerate(hdf5_keys):
                new_group = new_hdf5_file.create_group(key)
                old_group = hdf5_file[key]
                tmp_commit_weights = commit_weights[offsets_list[i]:offsets_list[i+1]]
                if args.threshold_weights is not None:
                    essential_ind = np.where(
                        tmp_commit_weights > args.threshold_weights)[0]
                    tmp_streamlines = reconstruct_streamlines(old_group['data'],
                                                              old_group['offsets'],
                                                              old_group['lengths'],
                                                              indices=essential_ind)

                    # Replacing the data with the one above the threshold
                    # Safe since this hdf5 was a copy in the first place
                    new_group.create_dataset('data',
                                             data=tmp_streamlines.get_data(),
                                             dtype=np.float32)
                    new_group.create_dataset('offsets',
                                             data=tmp_streamlines._offsets,
                                             dtype=np.int64)
                    new_group.create_dataset('lengths',
                                             data=tmp_streamlines._lengths,
                                             dtype=np.int32)

                for dps_key in hdf5_file[key].keys():
                    if dps_key not in ['data', 'offsets', 'lengths']:
                        new_group.create_dataset(dps_key,
                                                 data=hdf5_file[key][dps_key])
                new_group.create_dataset('commit_weights',
                                         data=tmp_commit_weights)

    files = os.listdir(commit_results_dir)
    for f in files:
        shutil.move(os.path.join(commit_results_dir, f), args.out_dir)

    # Save the split tractograms (essential/nonessential) and/or the whole
    # tractogram with data_per_streamline updated
    if args.keep_whole_tractogram or args.threshold_weights is not None:
        # Reload is needed because of COMMIT handling its file by itself
        tractogram_file = nib.streamlines.load(args.in_tractogram)
        tractogram = tractogram_file.tractogram
        tractogram.data_per_streamline['commit_weights'] = commit_weights

        if args.threshold_weights is not None:
            essential_ind = np.where(
                commit_weights > args.threshold_weights)[0]
            nonessential_ind = np.where(
                commit_weights <= args.threshold_weights)[0]
            logging.debug('{} essential streamlines were kept at '
                          'threshold {}'.format(len(essential_ind),
                                                args.threshold_weights))
            logging.debug('{} nonessential streamlines were kept at '
                          'threshold {}'.format(len(nonessential_ind),
                                                args.threshold_weights))

            # TODO PR when Dipy 1.2 is out with sft slicing
            essential_streamlines = tractogram.streamlines[essential_ind]
            essential_dps = tractogram.data_per_streamline[essential_ind]
            essential_dpp = tractogram.data_per_point[essential_ind]
            essential_tractogram = Tractogram(essential_streamlines,
                                              data_per_point=essential_dpp,
                                              data_per_streamline=essential_dps,
                                              affine_to_rasmm=np.eye(4))

            nonessential_streamlines = tractogram.streamlines[nonessential_ind]
            nonessential_dps = tractogram.data_per_streamline[nonessential_ind]
            nonessential_dpp = tractogram.data_per_point[nonessential_ind]
            nonessential_tractogram = Tractogram(nonessential_streamlines,
                                                 data_per_point=nonessential_dpp,
                                                 data_per_streamline=nonessential_dps,
                                                 affine_to_rasmm=np.eye(4))

            nib.streamlines.save(essential_tractogram,
                                 os.path.join(args.out_dir,
                                              'essential_tractogram.trk'),
                                 header=tractogram_file.header)
            nib.streamlines.save(nonessential_tractogram,
                                 os.path.join(args.out_dir,
                                              'nonessential_tractogram.trk'),
                                 header=tractogram_file.header,)
        if args.keep_whole_tractogram:
            output_filename = os.path.join(args.out_dir, 'tractogram.trk')
            logging.debug('Saving tractogram with weights as {}'.format(
                output_filename))
            nib.streamlines.save(tractogram_file, output_filename)

    tmp_dir.cleanup()
Example #13
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_volume)
    assert_outputs_exist(parser, args, args.out_image)

    output_names = [
        'axial_superior', 'axial_inferior', 'coronal_posterior',
        'coronal_anterior', 'sagittal_left', 'sagittal_right'
    ]

    for filename in args.in_bundles:
        _, ext = os.path.splitext(filename)
        if ext == '.tck':
            tractogram = load_tractogram_with_reference(parser, args, filename)
        else:
            tractogram = filename
        if not is_header_compatible(args.in_volume, tractogram):
            parser.error('{} does not have a compatible header with {}'.format(
                filename, args.in_volume))
        else:
            # Delete temporary tractogram
            del tractogram

    output_dir = os.path.dirname(args.out_image)
    if output_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           output_dir,
                                           create_dir=True)

    _, extension = os.path.splitext(args.out_image)

    # ----------------------------------------------------------------------- #
    # Mosaic, column 0: orientation names and data description
    # ----------------------------------------------------------------------- #
    width = args.resolution_of_thumbnails
    height = args.resolution_of_thumbnails
    rows = 6
    cols = len(args.in_bundles)
    text_pos_x = 50
    text_pos_y = 50

    # Creates a new empty image, RGB mode
    mosaic = Image.new('RGB', ((cols + 1) * width, (rows + 1) * height))

    # Prepare draw and font objects to render text
    draw = ImageDraw.Draw(mosaic)
    font = get_font(args)

    # Data of the volume used as background
    ref_img = nib.load(args.in_volume)
    data = ref_img.get_fdata(dtype=np.float32)
    affine = ref_img.affine
    mean, std = data[data > 0].mean(), data[data > 0].std()
    value_range = (mean - 0.5 * std, mean + 1.5 * std)

    # First column with rows description
    draw_column_with_names(draw, output_names, text_pos_x, text_pos_y, height,
                           font)

    # ----------------------------------------------------------------------- #
    # Columns with bundles
    # ----------------------------------------------------------------------- #
    random.seed(args.random_coloring)
    for idx_bundle, bundle_file in enumerate(args.in_bundles):

        bundle_file_name = os.path.basename(bundle_file)
        bundle_name, bundle_ext = split_name_with_nii(bundle_file_name)

        i = (idx_bundle + 1) * width

        if not os.path.isfile(bundle_file):
            print('\nInput file {} doesn\'t exist.'.format(bundle_file))

            number_streamlines = 0

            view_number = 6
            j = height * view_number

            draw_bundle_information(draw, bundle_file_name, number_streamlines,
                                    i + text_pos_x, j + text_pos_y, font)

        else:
            if args.uniform_coloring:
                colors = args.uniform_coloring
            elif args.random_coloring is not None:
                colors = random_rgb()
            # Select the streamlines to plot
            if bundle_ext in ['.tck', '.trk']:
                if (args.random_coloring is None
                        and args.uniform_coloring is None):
                    colors = None
                bundle_tractogram_file = nib.streamlines.load(bundle_file)
                streamlines = bundle_tractogram_file.streamlines
                bundle_actor = actor.line(streamlines, colors)
                nbr_of_elem = len(streamlines)
            # Select the volume to plot
            elif bundle_ext in ['.nii.gz', '.nii']:
                if not args.random_coloring and not args.uniform_coloring:
                    colors = [1.0, 1.0, 1.0]
                bundle_img_file = nib.load(bundle_file)
                roi = get_data_as_mask(bundle_img_file)
                bundle_actor = actor.contour_from_roi(roi,
                                                      bundle_img_file.affine,
                                                      colors)
                nbr_of_elem = np.count_nonzero(roi)

            # Render
            ren = window.Scene()
            zoom = args.zoom
            opacity = args.opacity_background

            # Structural data
            slice_actor = actor.slicer(data, affine, value_range)
            slice_actor.opacity(opacity)
            ren.add(slice_actor)

            # Streamlines
            ren.add(bundle_actor)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 0
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.pitch(180)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 1
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.rm(slice_actor)
            slice_actor2 = slice_actor.copy()
            slice_actor2.display(None, slice_actor2.shape[1] // 2, None)
            slice_actor2.opacity(opacity)
            ren.add(slice_actor2)

            ren.pitch(90)
            ren.set_camera(view_up=(0, 0, 1))
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 2
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.pitch(180)
            ren.set_camera(view_up=(0, 0, 1))
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 3
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.rm(slice_actor2)
            slice_actor3 = slice_actor.copy()
            slice_actor3.display(slice_actor3.shape[0] // 2, None, None)
            slice_actor3.opacity(opacity)
            ren.add(slice_actor3)

            ren.yaw(90)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 4
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.yaw(180)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 5
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            view_number = 6
            j = height * view_number
            draw_bundle_information(draw, bundle_file_name, nbr_of_elem,
                                    i + text_pos_x, j + text_pos_y, font)

    # Save image to file
    mosaic.save(args.out_image)
Example #14
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The number of label maps must equal the number of bundles
    tmp = args.in_bundles + args.in_labels
    args.in_labels = args.in_bundles[(len(tmp) // 2):] + args.in_labels
    args.in_bundles = args.in_bundles[0:len(tmp) // 2]
    assert_inputs_exist(parser, args.in_bundles + args.in_labels)
    assert_output_dirs_exist_and_empty(parser,
                                       args, [],
                                       optional=args.save_rendering)

    stats = {}
    num_digits_labels = 3
    scene = window.Scene()
    scene.background(tuple(map(int, args.background)))
    for i, filename in enumerate(args.in_bundles):
        sft = load_tractogram_with_reference(parser, args, filename)
        sft.to_vox()
        sft.to_corner()
        img_labels = nib.load(args.in_labels[i])

        # Same subject, or coregistered subjects: all headers must be identical
        if not is_header_compatible(sft, args.in_bundles[0]) \
                or not is_header_compatible(img_labels, args.in_bundles[0]):
            parser.error('All headers must be identical.')

        data_labels = img_labels.get_fdata()
        bundle_name, _ = os.path.splitext(os.path.basename(filename))
        unique_labels = np.unique(data_labels)[1:].astype(int)

        # Empty bundle should at least return a json
        if not len(sft):
            tmp_dict = {}
            for label in unique_labels:
                tmp_dict['{}'.format(label).zfill(num_digits_labels)] \
                    = {'mean': 0.0, 'std': 0.0}
            stats[bundle_name] = {'diameter': tmp_dict}
            continue

        counter = 0
        labels_dict = {label: ([], []) for label in unique_labels}
        pts_labels = map_coordinates(data_labels,
                                     sft.streamlines._data.T - 0.5,
                                     order=0)
        # For each label, all positions and directions are needed to get
        # a tube estimation per label.
        for streamline in sft.streamlines:
            direction = np.gradient(streamline, axis=0).tolist()
            curr_labels = pts_labels[counter:counter +
                                     len(streamline)].tolist()

            for j, label in enumerate(curr_labels):
                if label > 0:
                    labels_dict[label][0].append(streamline[j])
                    labels_dict[label][1].append(direction[j])

            counter += len(streamline)

        centroid = np.zeros((len(unique_labels), 3))
        radius = np.zeros((len(unique_labels), 1))
        error = np.zeros((len(unique_labels), 1))
        for key in unique_labels:
            key = int(key)
            c, d, e = fit_circle_in_space(labels_dict[key][0],
                                          labels_dict[key][1],
                                          args.fitting_func)
            centroid[key - 1], radius[key - 1], error[key - 1] = c, d, e

        # Spatial smoothing to avoid degenerate estimation
        centroid_smooth = gaussian_filter(centroid,
                                          sigma=[1, 0],
                                          mode='nearest')
        centroid_smooth[::len(centroid) - 1] = centroid[::len(centroid) - 1]
        radius = gaussian_filter(radius, sigma=1, mode='nearest')
        error = gaussian_filter(error, sigma=1, mode='nearest')

        tmp_dict = {}
        for label in unique_labels:
            tmp_dict['{}'.format(label).zfill(num_digits_labels)] \
                = {'mean': float(radius[label-1])*2,
                   'std': float(error[label-1])}
        stats[bundle_name] = {'diameter': tmp_dict}

        if args.show_rendering or args.save_rendering:
            tube_actor = create_tube_with_radii(
                centroid_smooth,
                radius,
                error,
                wireframe=args.wireframe,
                error_coloring=args.error_coloring)
            scene.add(tube_actor)
            cmap = plt.get_cmap('jet')
            coloring = cmap(pts_labels / np.max(pts_labels))[:, 0:3]
            streamlines_actor = actor.streamtube(sft.streamlines,
                                                 linewidth=args.width,
                                                 opacity=args.opacity,
                                                 colors=coloring)
            scene.add(streamlines_actor)

            slice_actor = actor.slicer(data_labels, np.eye(4))
            slice_actor.opacity(0.0)
            scene.add(slice_actor)

    # If there are actually streamlines to display
    if args.show_rendering:
        showm = window.ShowManager(scene, reset_camera=True)
        showm.initialize()
        showm.start()
    elif args.save_rendering:
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'superior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.pitch(180)
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'inferior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.pitch(90)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'posterior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.pitch(180)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'anterior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.yaw(90)
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'right.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.yaw(180)
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'left.png'),
                 size=(1920, 1080),
                 offscreen=True)
    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
Example #15
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_json)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    if args.fill_color and len(args.fill_color) != 8:
        parser.error('Hexadecimal RGB color should be formatted as 0xRRGGBB')

    with open(args.in_json, 'r+') as f:
        if args.stats_over_population:
            mean_std_per_point = json.load(f)
        else:
            mean_std_per_point = list(json.load(f).values())[0]

    for bundle_name, bundle_stats in mean_std_per_point.items():
        for metric, metric_stats in bundle_stats.items():
            nb_points = args.nb_pts if args.nb_pts is not None \
                else len(metric_stats)
            num_digits_labels = len(list(metric_stats.keys())[0])
            means = []
            stds = []
            for label_int in range(1, nb_points + 1):
                label = str(label_int).zfill(num_digits_labels)
                mean = metric_stats.get(label, {'mean': 0})['mean']
                std = metric_stats.get(label, {'std': 0})['std']
                if not isinstance(mean, list):
                    mean = [mean]
                    std = [std]

                means += [mean]
                stds += [std]

            color = None
            if args.dict_colors:
                with open(args.dict_colors, 'r') as data:
                    dict_colors = json.load(data)
                # Supports variation from rbx-flow
                for key in dict_colors.keys():
                    if key in bundle_name:
                        color = dict_colors[key]
            elif args.fill_color is not None:
                color = args.fill_color
            if color is None:
                color = '0x000000'

            # Robustify for missing data
            means = np.array(
                list(itertools.zip_longest(*means, fillvalue=np.nan))).T
            stds = np.array(
                list(itertools.zip_longest(*stds, fillvalue=np.nan))).T
            for i in range(len(means)):
                _nan = np.isnan(means[i, :])
                if np.count_nonzero(_nan) > 0:
                    if np.count_nonzero(_nan) < len(means[i, :]):
                        means[i, _nan] = np.average(means[i, ~_nan])
                        stds[i, _nan] = np.average(stds[i, ~_nan])
                    else:
                        means[i, _nan] = -1
                        stds[i, _nan] = -1
            if not args.stats_over_population:
                means = np.squeeze(means)
                stds = np.squeeze(stds)
            fig = plot_metrics_stats(means,
                                     stds,
                                     title=bundle_name,
                                     xlabel='Location along the streamline',
                                     ylabel=metric,
                                     fill_color=(color.replace("0x", "#")),
                                     display_means=args.display_means)
            fig.savefig(os.path.join(args.out_dir,
                                     '{}_{}.png'.format(bundle_name, metric)),
                        bbox_inches='tight')
Example #16
def load_and_verify_everything(parser, args):
    """
    - Reads the config file
    - Loads the masks / sft
        - If endpoints were given instead of head + tail, separate into two
          sub-rois.
    - Verifies compatibility
    """
    assert_inputs_exist(parser, args.gt_config)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)
    os.makedirs(os.path.join(args.out_dir, 'segmented_VB'))
    if args.compute_ic:
        os.makedirs(os.path.join(args.out_dir, 'segmented_IB'))
    if args.save_wpc_separately:
        os.makedirs(os.path.join(args.out_dir, 'segmented_WPC'))

    # Read the config file
    (bundle_names, gt_masks_files, limits_masks_files, roi_options, lengths,
     angles, orientation_lengths,
     abs_orientation_lengths) = read_config_file(args)

    # Find all masks to be loaded.
    all_mask_files = list(
        itertools.chain(
            *[list(roi_option.values()) for roi_option in roi_options]))
    all_mask_files = list(dict.fromkeys(all_mask_files))  # Removes duplicates

    # Verify options
    assert_inputs_exist(parser, all_mask_files + [args.in_tractogram],
                        gt_masks_files + limits_masks_files)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    logging.info("Loading tractogram.")
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    logging.info("Verifying compatibility of tractogram with gt_masks and "
                 "limits_masks.")
    all_masks = gt_masks_files + limits_masks_files
    all_masks = list(dict.fromkeys(all_masks))  # Removes duplicates
    verify_compatibility_with_reference_sft(sft, all_masks, parser, args)

    logging.info("Loading and/or computing ground-truth masks and limits "
                 "masks.")
    gt_masks, _, affine, dimensions = \
        compute_masks(gt_masks_files, parser, args)
    limits_masks, limits_inv_masks, _, _ = \
        compute_masks(limits_masks_files, parser, args)

    logging.info("Extracting ground-truth head and tail masks.")
    gt_tails, gt_heads = compute_endpoint_masks(roi_options, affine,
                                                dimensions, args.out_dir)

    # Update all_rois, remove duplicates
    all_rois = gt_tails + gt_heads
    all_rois = list(dict.fromkeys(all_rois))  # Removes duplicates

    logging.info("Verifying tractogram compatibility with endpoint ROIs.")
    for file in all_rois:
        compatible = is_header_compatible(sft, file)
        if not compatible:
            parser.error("Input tractogram incompatible with {}".format(file))

    return (gt_tails, gt_heads, sft, bundle_names, all_rois, lengths, angles,
            orientation_lengths, abs_orientation_lengths, limits_inv_masks,
            gt_masks, dimensions)
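
The dict.fromkeys() idiom used above removes duplicate filenames while preserving their first-seen order, which a plain set() would not guarantee. A quick illustration:

masks = ['head.nii.gz', 'tail.nii.gz', 'head.nii.gz', 'gt.nii.gz']
unique_masks = list(dict.fromkeys(masks))  # order-preserving de-duplication
print(unique_masks)  # ['head.nii.gz', 'tail.nii.gz', 'gt.nii.gz']
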
Example #17
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(
                args.gt_endpoints, args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config, length_dict,
            extract_prefix(args.gt_bundles[i]), gt_bundle_inv_masks[i],
            args.dilate_endpoints, args.wrong_path_as_separate)
        nc_streamlines.extend(nc)
        if len(tc_sft) > 0:
            save_tractogram(tc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_tc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(wpc_sft) > 0:
            save_tractogram(wpc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_wpc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc), prefix_1, prefix_2))

    # Again, keep the correct combinations
    comb_filename = list(
        itertools.combinations(itertools.chain(*zip(gt_tails, gt_heads)), r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for i, roi in enumerate(comb_filename):
        mask_1_filename, mask_2_filename = roi

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)
        _, ext = os.path.splitext(args.in_tractogram)

        fc_sft, sft = extract_false_connections(sft, mask_1_filename,
                                                mask_2_filename,
                                                args.dilate_endpoints)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)

    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft,
                    os.path.join(args.out_dir, "nc{}".format(ext)),
                    bbox_valid_check=False)

    # Total number of streamlines for each category
    # and statistics that are not "bundle-wise"
    tc_streamlines_count = len(list(itertools.chain(*tc_streamlines_list)))
    fc_streamlines_count = len(list(itertools.chain(*fc_streamlines_list)))

    if args.wrong_path_as_separate:
        wpc_streamlines_count = len(
            list(itertools.chain(*wpc_streamlines_list)))
    else:
        wpc_streamlines_count = 0

    nc_streamlines_count = len(nc_streamlines)
    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

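    # Sanity check: all categories together must account for every streamline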
    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

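    # Per-bundle voxel-wise metrics for true connections: Dice, overlap,
    # overreach and lacking voxels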
    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)

        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                  & (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where((gt_bundle_masks[i] == 1)
                                & (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                      & (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] +
                tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

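        # Same overlap/overreach measures, restricted to streamline endpoints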
        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where((gt_bundle_masks[i] == 0)
                                     & (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantoms
    for i, filename in enumerate(comb_filename):
        current_fc_streamlines = fc_streamlines_list[i]
        current_fc_voxels, _ = get_binary_maps(current_fc_streamlines, sft)

        tmp_dict = {}

        if len(current_fc_streamlines):
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)

            final_results["bundle_wise"]["false_connections"][str(filename)] =\
                tmp_dict

    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results,
                  f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
Example #18
0
def main():
    # Callback required for FURY
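    # Key bindings: a = accept, r = reject, z = undo last choice,
    # c = toggle background streamlines rendering, q = quit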
    def keypress_callback(obj, _):
        key = obj.GetKeySym().lower()
        nonlocal clusters_linewidth, background_linewidth
        nonlocal curr_streamlines_actor, concat_streamlines_actor, show_curr_actor
        iterator = len(accepted_streamlines) + len(rejected_streamlines)
        renwin = interactor_style.GetInteractor().GetRenderWindow()
        renderer = interactor_style.GetCurrentRenderer()

        if key == 'c' and iterator < len(sft_accepted_on_size):
            if show_curr_actor:
                renderer.rm(concat_streamlines_actor)
                renwin.Render()
                show_curr_actor = False
                logging.info('Streamlines rendering OFF')
            else:
                renderer.add(concat_streamlines_actor)
                renderer.rm(curr_streamlines_actor)
                renderer.add(curr_streamlines_actor)
                renwin.Render()
                show_curr_actor = True
                logging.info('Streamlines rendering ON')
            return

        if key == 'q':
            show_manager.exit()
            if iterator < len(sft_accepted_on_size):
                logging.warning(
                    'Early exit, all remaining clusters will be rejected.')
            return

        if key in ['a', 'r'] and iterator < len(sft_accepted_on_size):
            if key == 'a':
                accepted_streamlines.append(iterator)
                choices.append('a')
                logging.info('Accepted file %s',
                             filename_accepted_on_size[iterator])
            elif key == 'r':
                rejected_streamlines.append(iterator)
                choices.append('r')
                logging.info('Rejected file %s',
                             filename_accepted_on_size[iterator])
            iterator += 1

        if key == 'z':
            if iterator > 0:
                last_choice = choices.pop()
                if last_choice == 'r':
                    rejected_streamlines.pop()
                else:
                    accepted_streamlines.pop()
                logging.info('Rewinding one step.')

                iterator -= 1
            else:
                logging.warning('Cannot rewind, first element.')

        if key in ['a', 'r', 'z'] and iterator < len(sft_accepted_on_size):
            renderer.rm(curr_streamlines_actor)
            curr_streamlines = sft_accepted_on_size[iterator].streamlines
            curr_streamlines_actor = actor.line(curr_streamlines,
                                                opacity=0.8,
                                                linewidth=clusters_linewidth)
            renderer.add(curr_streamlines_actor)

        if iterator == len(sft_accepted_on_size):
            print('No more clusters, press q to exit')
            renderer.rm(curr_streamlines_actor)

        renwin.Render()

    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_accepted, args.out_rejected])

    if args.out_accepted_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_accepted_dir,
                                           create_dir=True)
    if args.out_rejected_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_rejected_dir,
                                           create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if args.min_cluster_size < 1:
        parser.error('Minimum cluster size must be at least 1.')

    clusters_linewidth = args.clusters_linewidth
    background_linewidth = args.background_linewidth

    # To accelerate procedure, clusters can be discarded based on size
    # Concatenation is to give spatial context
    sft_accepted_on_size, filename_accepted_on_size = [], []
    sft_rejected_on_size, filename_rejected_on_size = [], []
    concat_streamlines = []
    for filename in args.in_bundles:
        if not is_header_compatible(args.in_bundles[0], filename):
            parser.error('Headers of {} and {} are not compatible.'.format(
                args.in_bundles[0], filename))
        basename = os.path.basename(filename)
        sft = load_tractogram_with_reference(parser,
                                             args,
                                             filename,
                                             bbox_check=False)
        if len(sft) >= args.min_cluster_size:
            sft_accepted_on_size.append(sft)
            filename_accepted_on_size.append(basename)
            concat_streamlines.extend(sft.streamlines)
        else:
            logging.info('File %s has %s streamlines, automatically rejected.',
                         filename, len(sft))
            sft_rejected_on_size.append(sft)
            filename_rejected_on_size.append(basename)

    if not filename_accepted_on_size:
        parser.error('No cluster survived the cluster_size threshold.')

    logging.info('%s clusters to be classified.', len(sft_accepted_on_size))
    # The clusters are sorted by size for simplicity/efficiency
    tuple_accepted = zip(
        *sorted(zip(sft_accepted_on_size, filename_accepted_on_size),
                key=lambda x: len(x[0]),
                reverse=True))
    sft_accepted_on_size, filename_accepted_on_size = tuple_accepted

    # Initialize the actors, scene, window, observer
    concat_streamlines_actor = actor.line(concat_streamlines,
                                          colors=(1, 1, 1),
                                          opacity=args.background_opacity,
                                          linewidth=background_linewidth)
    curr_streamlines_actor = actor.line(sft_accepted_on_size[0].streamlines,
                                        opacity=0.8,
                                        linewidth=clusters_linewidth)

    scene = window.Scene()
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(scene,
                                      size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)
    scene.add(concat_streamlines_actor)
    scene.add(curr_streamlines_actor)
    interactor_style.AddObserver('KeyPressEvent', keypress_callback)

    # Launch rendering and selection procedure
    choices, accepted_streamlines, rejected_streamlines = [], [], []
    show_curr_actor = True
    show_manager.start()

    # Early exit means everything else is rejected
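    # 'missing' counts clusters that passed the size threshold but were never
    # classified in the GUI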
    missing = len(args.in_bundles) - len(choices) - len(sft_rejected_on_size)
    len_accepted = len(sft_accepted_on_size)
    rejected_streamlines.extend(range(len_accepted - missing, len_accepted))
    if missing > 0:
        logging.info('%s clusters automatically rejected from early exit',
                     missing)

    # Save accepted clusters (by GUI)
    accepted_streamlines = save_clusters(sft_accepted_on_size,
                                         accepted_streamlines,
                                         args.out_accepted_dir,
                                         filename_accepted_on_size)

    accepted_sft = StatefulTractogram(accepted_streamlines,
                                      sft_accepted_on_size[0], Space.RASMM)
    save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)

    # Save rejected clusters (by GUI)
    rejected_streamlines = save_clusters(sft_accepted_on_size,
                                         rejected_streamlines,
                                         args.out_rejected_dir,
                                         filename_accepted_on_size)

    # Save rejected clusters (by size)
    rejected_streamlines.extend(
        save_clusters(sft_rejected_on_size, range(len(sft_rejected_on_size)),
                      args.out_rejected_dir, filename_rejected_on_size))

    rejected_sft = StatefulTractogram(rejected_streamlines,
                                      sft_accepted_on_size[0], Space.RASMM)
    save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
Example #19
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    required_args = [args.in_json, args.in_participants]
    assert_inputs_exist(parser, required_args)

    req_folder = os.path.join(args.out_dir, 'Graph')
    assert_output_dirs_exist_and_empty(parser, args, req_folder)

    # Generate the stats object
    my_data = data_for_stat(args.in_json, args.in_participants)

    bundles = args.bundles
    metrics = args.metrics
    values = args.values

    if args.bundles == 'all':
        bundles = my_data.get_bundles_list()
    if args.metrics == 'all':
        metrics = my_data.get_metrics_list()
    if args.values == 'all':
        values = my_data.get_values_list()

    alpha_error = float(args.alpha_error)

    my_group_dict = my_data.get_groups_dictionnary(args.group_by)
    # Initialise the result dictionary
    result_dict = {}

    # We do the comparison for every combination of bundle, metric and value
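    # For each measure: check normality per group, then homoscedasticity,
    # then run the appropriate group-difference test and, if a difference is
    # found, a post hoc analysis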
    for b, m, v in product(bundles, metrics, values):
        # First we extract the basic information on that comparison

        curr_comparison_measure = ('_').join([b, m, v])

        logging.debug('______________________')
        logging.debug('Measure to compare: {}'.format(curr_comparison_measure))

        # Check normality of that metric across all groups

        current_normality = {}
        overall_normality = True
        groups_array = []
        for group in my_group_dict:
            curr_sample = get_group_data_sample(my_group_dict, group, b, m, v)
            logging.debug('Group {}'.format(group))
            current_normality[group] = verify_normality(
                curr_sample, alpha_error)
            if not current_normality[group][0]:
                overall_normality = False
            groups_array.append(curr_sample)
        logging.debug('Normality result:')
        logging.debug(current_normality)
        logging.debug('Overall Normality:')
        logging.debug(overall_normality)
        logging.debug('Groups array:')
        logging.debug(groups_array)

        # Generate graph of the metric
        if args.generate_graph:
            visualise_distribution(groups_array,
                                   my_data.get_participants_list(), b, m, v,
                                   args.out_dir,
                                   my_data.get_groups_list(args.group_by))

        # Abort if the data was not separated into at least two groups
        if len(my_data.get_groups_list(args.group_by)) == 1:
            logging.error('There is only 1 group generated. '
                          'We cannot continue the groups comparison')
            raise BaseException('Only 1 group generated from '
                                '{}'.format(args.group_by))

        # Check homoscedasticity
        variance_equality = verify_homoscedasticity(
            groups_array, normality=overall_normality, alpha=alpha_error)
        logging.debug('Equality of variance result:')
        logging.debug(variance_equality)

        # Now we compare the groups population

        difference = verify_group_difference(
            groups_array,
            normality=overall_normality,
            homoscedasticity=variance_equality[1],
            alpha=alpha_error)
        logging.debug('Main test result:')
        logging.debug(difference)

        # Finally, if a difference was found, do a post hoc analysis
        # to explore where this difference lies
        if difference[1] and difference[0] == 'ANOVA':
            diff_2_by_2 = verify_post_hoc(groups_array,
                                          my_data.get_groups_list(
                                              args.group_by),
                                          test='Student',
                                          alpha=alpha_error)
        elif difference[1] and difference[0] == 'Kruskalwallis':
            diff_2_by_2 = verify_post_hoc(groups_array,
                                          my_data.get_groups_list(
                                              args.group_by),
                                          test='Mannwhitneyu',
                                          alpha=alpha_error)
        elif difference[1] and difference[0] == 'Friedmann':
            diff_2_by_2 = verify_post_hoc(groups_array,
                                          my_data.get_groups_list(
                                              args.group_by),
                                          test='Wilcoxon',
                                          alpha=alpha_error)
        else:
            diff_2_by_2 = []

        logging.debug('Summary of difference 2 by 2:')
        logging.debug(diff_2_by_2)

        # Write the current metric in the report
        curr_dict = write_current_dictionnary(curr_comparison_measure,
                                              current_normality,
                                              variance_equality, difference,
                                              diff_2_by_2)
        result_dict[curr_comparison_measure] = curr_dict

    # Save the result dictionary to a json file if requested, otherwise print it
    if args.out_json:
        with open(os.path.join(args.out_dir, args.out_json), 'w') as outfile:
            json.dump(result_dict,
                      outfile,
                      indent=args.indent,
                      sort_keys=args.sort_keys)
    else:
        print(
            json.dumps(result_dict,
                       indent=args.indent,
                       sort_keys=args.sort_keys))
Example #20
0
def main():
    parser = build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.tracks, args.labels])

    if os.path.abspath(args.output) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)

    img_labels = nb.load(args.labels)
    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Ensure that voxel size is isotropic. Currently, for speed considerations,
    # we take the length in voxel space and multiply by the voxel size. For
    # this to work correctly, voxel size must be isotropic.
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    if np.min(img_labels.get_data()) < 0 or \
            np.max(img_labels.get_data()) > args.max_labels:
        parser.error('Invalid labels in labels image')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.tracks)
    sft.to_vox()
    sft.to_corner()
    time2 = time.time()

    logging.info('    Number of streamlines to process: {}'.format(
        len(sft.streamlines)))
    logging.info('    Loading streamlines took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, img_labels.get_data(),
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Symmetrize matrix
    final_con_info = _symmetrize_con_info(con_info)

    # Prepare directories and information needed to save.
    saving_opts = _get_saving_options(args)
    out_paths = _get_output_paths(args.output)
    _create_required_output_dirs(out_paths, args)

    # Here, we use nb_labels + 1 since we want the direct mapping from image
    # label to matrix element. We will remove the first row and column before
    # saving.
    # TODO for other metrics
    # dtype should be adjusted depending on the type of elements
    # stored in the con_mat
    nb_labels = args.max_labels
    con_mat = np.zeros((nb_labels + 1, nb_labels + 1), dtype=np.uint32)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()
    for in_label in list(final_con_info.keys()):
        for out_label in list(final_con_info[in_label].keys()):
            pair_info = final_con_info[in_label][out_label]

            if not len(pair_info):
                continue

            final_strl = []

            for connection in pair_info:
                strl_idx = connection['strl_idx']
                final_strl.append(
                    compute_streamline_segment(sft.streamlines[strl_idx],
                                               indices[strl_idx],
                                               connection['in_idx'],
                                               connection['out_idx'],
                                               points_to_idx[strl_idx]))

            _save_if_needed(final_strl, args, saving_opts, out_paths, 'raw',
                            'raw', in_label, out_label)

            # Doing all post-processing
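            # Chain: length pruning -> loop removal -> outlier removal ->
            # QuickBundles-based loop removal -> final connectivity count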
            if not args.no_pruning:
                pruned_strl, invalid_strl = _prune_segments(
                    final_strl, args.min_length, args.max_length, vox_sizes[0])

                _save_if_needed(invalid_strl, args, saving_opts, out_paths,
                                'discarded', 'removed_length', in_label,
                                out_label)
            else:
                pruned_strl = final_strl

            if not len(pruned_strl):
                continue

            _save_if_needed(pruned_strl, args, saving_opts, out_paths,
                            'intermediate', 'pruned', in_label, out_label)

            if not args.no_remove_loops:
                no_loops, loops = remove_loops_and_sharp_turns(
                    pruned_strl, args.loop_max_angle)
                _save_if_needed(loops, args, saving_opts, out_paths,
                                'discarded', 'loops', in_label, out_label)
            else:
                no_loops = pruned_strl

            if not len(no_loops):
                continue

            _save_if_needed(no_loops, args, saving_opts, out_paths,
                            'intermediate', 'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                no_outliers, outliers = remove_outliers(
                    no_loops, args.outlier_threshold)
                _save_if_needed(outliers, args, saving_opts, out_paths,
                                'discarded', 'outliers', in_label, out_label)
            else:
                no_outliers = no_loops

            if not len(no_outliers):
                continue

            _save_if_needed(no_outliers, args, saving_opts, out_paths,
                            'intermediate', 'no_outliers', in_label, out_label)

            if not args.no_remove_loops_again:
                no_qb_loops_strl, loops2 = remove_loops_and_sharp_turns(
                    no_outliers, args.loop_max_angle, True,
                    args.loop_qb_distance)
                _save_if_needed(loops2, args, saving_opts, out_paths,
                                'discarded', 'qb_loops', in_label, out_label)
            else:
                no_qb_loops_strl = no_outliers

            _save_if_needed(no_qb_loops_strl, args, saving_opts, out_paths,
                            'final', 'final', in_label, out_label)

            # TODO for other metrics
            # This would be where this is modified and the value
            # is computed (e.g. mean FA in the connection).
            con_mat[in_label, out_label] += len(no_qb_loops_strl)

    time2 = time.time()
    logging.info('    Connection post-processing and saving took %0.3f ms',
                 (time2 - time1) * 1000.0)

    # Remove first line and column, since they are index 0 and
    # would represent a connection to non-label voxels. Only used when
    # post-processing to avoid unnecessary -1 on labels for each access.
    con_mat = con_mat[1:, 1:]
    np.save(os.path.join(args.output, 'final_matrix.npy'), con_mat)
Example #21
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(
        parser, [args.in_tractogram, args.in_dwi, args.in_bval, args.in_bvec],
        [args.in_peaks, args.in_tracking_mask])
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    if args.commit2:
        if os.path.splitext(args.in_tractogram)[1] != '.h5':
            parser.error('COMMIT2 requires .h5 file for connectomics.')
        args.ball_stick = True

    if args.load_kernels and not os.path.isdir(args.load_kernels):
        parser.error('Kernels directory does not exist.')

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    if args.load_kernels and args.save_kernels:
        parser.error('Cannot load and save kernels at the same time.')

    if args.ball_stick and args.perp_diff:
        parser.error('Cannot use --perp_diff with ball&stick.')

    if not args.ball_stick and not args.in_peaks:
        parser.error('Stick Zeppelin Ball model requires --in_peaks')

    if args.ball_stick and args.iso_diff and len(args.iso_diff) > 1:
        parser.error('Cannot use more than one --iso_diff with ball&stick.')

    # If it is a trk, check compatibility of header since COMMIT does not do it
    dwi_img = nib.load(args.in_dwi)
    _, ext = os.path.splitext(args.in_tractogram)
    if ext == '.trk' and not is_header_compatible(args.in_tractogram, dwi_img):
        parser.error('{} does not have a compatible header with {}'.format(
            args.in_tractogram, args.in_dwi))

    # COMMIT has some c-level stdout and non-logging print that cannot
    # be easily stopped. Manual redirection of all printed output
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    tmp_dir = tempfile.TemporaryDirectory()
    hdf5_file = None
    offsets_list = None
    if ext == '.h5':
        logging.debug('Reconstructing {} into a tractogram for COMMIT.'.format(
            args.in_tractogram))

        hdf5_file = h5py.File(args.in_tractogram, 'r')
        if not (np.allclose(
                hdf5_file.attrs['affine'], dwi_img.affine, atol=1e-03)
                and np.array_equal(hdf5_file.attrs['dimensions'],
                                   dwi_img.shape[0:3])):
            parser.error('{} does not have a compatible header with {}'.format(
                args.in_tractogram, args.in_dwi))

        # Keep track of the order of connections/streamlines in relation to the
        # tractogram as well as the number of streamlines for each connection.
        bundle_groups_len = []
        hdf5_keys = list(hdf5_file.keys())
        streamlines = []
        for key in hdf5_keys:
            tmp_streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
            streamlines.extend(tmp_streamlines)
            bundle_groups_len.append(len(tmp_streamlines))

        offsets_list = np.cumsum([0] + bundle_groups_len)
        sft = StatefulTractogram(streamlines,
                                 args.in_dwi,
                                 Space.VOX,
                                 origin=Origin.TRACKVIS)
        tmp_tractogram_filename = os.path.join(tmp_dir.name, 'tractogram.trk')

        # Point the input variable to the temporary trk file for COMMIT's use
        save_tractogram(sft, tmp_tractogram_filename)
        args.in_tractogram = tmp_tractogram_filename

    # Writing the scheme file with proper shells
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename,
               shells_centroids[indices_shells],
               newline=' ',
               fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Launching COMMIT on {} shells found at {}.'.format(
        len(shells_centroids), shells_centroids))

    if len(shells_centroids) == 2 and not args.ball_stick:
        parser.error('The DWI data appears to be single-shell.\n'
                     'Use --ball_stick for single-shell.')

    with redirected_stdout:
        # Setting up the tractogram and nifti files
        trk2dictionary.run(filename_tractogram=args.in_tractogram,
                           filename_peaks=args.in_peaks,
                           peaks_use_affine=False,
                           filename_mask=args.in_tracking_mask,
                           ndirs=args.nbr_dir,
                           path_out=tmp_dir.name)

        # Preparation for fitting
        commit.core.setup(ndirs=args.nbr_dir)
        mit = commit.Evaluation('.', '.')

        # FIX for very small values during HCP processing
        # (based on order of magnitude of signal)
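        # Voxels below 0.001 x 10^floor(log10(mean positive signal)) are zeroed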
        img = nib.load(args.in_dwi)
        data = img.get_fdata(dtype=np.float32)
        data[data < (0.001 *
                     10**np.floor(np.log10(np.mean(data[data > 0]))))] = 0
        nib.save(nib.Nifti1Image(data, img.affine),
                 os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'))

        mit.load_data(os.path.join(tmp_dir.name, 'dwi_zero_fix.nii.gz'),
                      tmp_scheme_filename)
        mit.set_model('StickZeppelinBall')

        if args.ball_stick:
            logging.debug('Disabled zeppelin, using the Ball & Stick model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = []
            isotropic_diff = args.iso_diff or [2.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)
        else:
            logging.debug('Using the Stick Zeppelin Ball model.')
            para_diff = args.para_diff or 1.7E-3
            perp_diff = args.perp_diff or [0.85E-3, 0.51E-3]
            isotropic_diff = args.iso_diff or [1.7E-3, 3.0E-3]
            mit.model.set(para_diff, perp_diff, isotropic_diff)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', mit.model.id)
            regenerate_kernels = True
        mit.set_config('ATOMS_path', kernels_dir)

        mit.generate_kernels(ndirs=args.nbr_dir, regenerate=regenerate_kernels)
        if args.compute_only:
            return
        mit.load_kernels()
        use_mask = args.in_tracking_mask is not None
        mit.load_dictionary(tmp_dir.name, use_all_voxels_in_mask=use_mask)
        mit.set_threads(args.nbr_processes)

        mit.build_operator(build_dir=os.path.join(tmp_dir.name, 'build/'))
        tol_fun = 1e-2 if args.commit2 else 1e-3
        mit.fit(tol_fun=tol_fun, max_iter=args.nbr_iter, verbose=False)
        mit.save_results()
        _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                              'commit_1/', False)

        if args.commit2:
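            # COMMIT2: one group of coefficients per bundle, weighted by
            # sqrt(bundle size) / norm of the COMMIT1 solution for that bundle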
            tmp = np.insert(np.cumsum(bundle_groups_len), 0, 0)
            group_idx = np.array(
                [np.arange(tmp[i], tmp[i + 1]) for i in range(len(tmp) - 1)])
            group_w = np.empty_like(bundle_groups_len, dtype=np.float64)
            for k in range(len(bundle_groups_len)):
                group_w[k] = np.sqrt(bundle_groups_len[k]) / \
                    (np.linalg.norm(mit.x[group_idx[k]]) + 1e-12)
            prior_on_bundles = commit.solvers.init_regularisation(
                mit,
                structureIC=group_idx,
                weightsIC=group_w,
                regnorms=[
                    commit.solvers.group_sparsity, commit.solvers.non_negative,
                    commit.solvers.non_negative
                ],
                lambdas=[args.lambda_commit_2, 0.0, 0.0])
            mit.fit(tol_fun=1e-3,
                    max_iter=args.nbr_iter,
                    regularisation=prior_on_bundles,
                    verbose=False)
            mit.save_results()
            _save_results_wrapper(args, tmp_dir, ext, hdf5_file, offsets_list,
                                  'commit_2/', True)

    tmp_dir.cleanup()
Example #22
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_inputs_exist(parser, args.in_wmparc)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_path,
                                       create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    if args.angle <= 0:
        parser.error('Angle "{}" '.format(args.angle) +
                     'must be greater than or equal to 0')
    if args.ctx_dilation_radius < 0:
        parser.error(
            'Cortex dilation radius "{}" '.format(args.ctx_dilation_radius) +
            'must be greater than or equal to 0')

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    img_wmparc = nib.load(args.in_wmparc)
    if not is_header_compatible(img_wmparc, sft):
        parser.error('Headers from the tractogram and the wmparc are '
                     'not compatible.')
    if args.csf_bin:
        img_csf = nib.load(args.csf_bin)
        if not is_header_compatible(img_csf, sft):
            parser.error('Headers from the tractogram and the CSF mask are '
                         'not compatible.')

    if args.minL == 0 and np.isinf(args.maxL):
        logging.debug("You have not specified minL nor maxL. Output will "
                      "not be filtered according to length!")
    if np.isinf(args.angle):
        logging.debug("You have not specified the angle. Loops will "
                      "not be filtered!")
    if args.ctx_dilation_radius == 0:
        logging.debug("You have not specified the cortex dilation radius. "
                      "The wmparc atlas will not be dilated!")

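    # o_dict maps each output filename to its streamline count, for the
    # optional count display and saving below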
    o_dict = {}
    step_dict = ['length', 'no_loops', 'no_end_csf', 'end_in_atlas']
    wm_labels = load_wmparc_labels()

    in_sft_name = os.path.splitext(os.path.basename(args.in_tractogram))[0]
    out_sft_rootname = in_sft_name + "_filtered"
    _, ext = os.path.splitext(args.in_tractogram)
    out_sft_name = os.path.join(args.out_path,
                                out_sft_rootname + ext)

    # STEP 1 - Filter length
    step = step_dict[0]
    steps_combined = step

    new_sft = filter_streamlines_by_length(sft, args.minL, args.maxL)

    # Streamline count before and after filtering lengths
    o_dict[in_sft_name + ext] =\
        dict({'streamline_count': len(sft.streamlines)})
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')
        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 2 - Filter loops
    step = step_dict[1]
    steps_combined += "_" + step

    ids_c = remove_loops_and_sharp_turns(sft.streamlines, args.angle)
    new_sft = filter_tractogram_data(sft, ids_c)

    # Streamline count after filtering loops
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 3 - Filter CSF endings
    step = step_dict[2]
    steps_combined += "_" + step

    # Mask creation
    if args.csf_bin:
        mask = get_data_as_mask(img_csf)
    else:
        atlas = get_data_as_label(img_wmparc)
        mask = binarize_labels(atlas, wm_labels["csf_labels"])

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, mask, 'any', True)

    # Streamline count after filtering CSF endings
    o_dict[in_sft_name + '_' + steps_combined + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        if not args.csf_bin:
            nib.save(
                nib.Nifti1Image(mask, img_wmparc.affine, img_wmparc.header),
                os.path.join(new_path, 'csf_bin' + '.nii.gz'))

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")

            if args.verbose:
                display_count(o_dict, args.indent, args.sort_keys)

            if args.save_counts:
                save_count(o_dict, args.out_path, args.indent, args.sort_keys)

            return

        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

        save_tractogram(new_sft, out_sft_name)

        if args.verbose:
            display_count(o_dict, args.indent, args.sort_keys)

        if args.save_counts:
            save_count(o_dict, args.out_path, args.indent, args.sort_keys)

        return

    sft = new_sft

    # STEP 4 - Filter WM endings
    step = step_dict[3]
    steps_combined += "_" + step

    # Mask creation
    ctx_fs_labels = wm_labels["ctx_lh_fs_labels"] + \
        wm_labels["ctx_rh_fs_labels"]
    vox_size = np.reshape(img_wmparc.header.get_zooms(), (1, 3))
    atlas_wm = get_data_as_label(img_wmparc)
    atlas_shape = atlas_wm.shape
    wmparc_ctx = binarize_labels(atlas_wm, ctx_fs_labels)
    wmparc_nuclei = binarize_labels(atlas_wm, wm_labels["nuclei_fs_labels"])

    # Dilation of cortex
    if args.ctx_dilation_radius:
        ctx_mask = dilate_mask(wmparc_ctx, atlas_shape, vox_size,
                               args.ctx_dilation_radius)
    else:
        ctx_mask = wmparc_ctx

    freesurfer_mask = np.zeros(atlas_shape, dtype=np.uint16)
    freesurfer_mask[np.logical_or(wmparc_nuclei, ctx_mask)] = 1

    # Filter tractogram
    new_sft, _ = filter_grid_roi(sft, freesurfer_mask, 'both_ends', False)

    # Streamline count after final filtering
    o_dict[out_sft_rootname + ext] =\
        dict({'streamline_count': len(new_sft.streamlines)})

    if args.save_volumes:
        new_path = create_dir(args.out_path, step)
        nib.save(
            nib.Nifti1Image(freesurfer_mask, img_wmparc.affine,
                            img_wmparc.header),
            os.path.join(new_path, 'atlas_bin' + '.nii.gz'))

    if args.save_intermediate_tractograms:
        outliers_sft = compute_outliers(sft, new_sft)
        new_path = create_dir(args.out_path, step)
        save_intermediate_sft(new_sft, outliers_sft, new_path, in_sft_name,
                              step, steps_combined, ext, args.no_empty)
        o_dict[in_sft_name + '_' + step + '_outliers' + ext] =\
            dict({'streamline_count': len(outliers_sft.streamlines)})

    # Finish filtering
    if args.verbose:
        display_count(o_dict, args.indent, args.sort_keys)

    if args.save_counts:
        save_count(o_dict, args.out_path, args.indent, args.sort_keys)

    if len(new_sft.streamlines) == 0:
        if args.no_empty:
            logging.debug("The file {} won't be written".format(out_sft_name) +
                          "(0 streamlines after " + step + " filtering).")
            return
        logging.debug(
            'The file {} contains 0 streamlines after '.format(out_sft_name) +
            step + ' filtering')

    sft = new_sft
    save_tractogram(sft, out_sft_name)
Example #23
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
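    # Also consider self-connections (a label paired with itself)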
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
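        # Store the spatial reference as attributes so streamlines can be
        # reconstructed from the hdf5 later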
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would be
        # a burden on the I/O of most SSDs/HDDs
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
Example #24
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.compute_only and not args.save_kernels:
        parser.error('--compute_only must be used with --save_kernels.')

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.mask)

    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       optional=args.save_kernels)

    # AMICO has some C-level stdout and non-logging prints that cannot
    # be easily stopped. Manually redirect all printed output.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    # Generate a scheme file from the bvals and bvecs files
    tmp_dir = tempfile.TemporaryDirectory()
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename,
               shells_centroids[indices_shells],
               newline=' ',
               fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug(
        'Computing FreeWater with AMICO on {} shells found at {}.'.format(
            len(shells_centroids), shells_centroids))

    with redirected_stdout:
        amico.core.setup()
        ae = amico.Evaluation('.', '.')
        # Load the data
        ae.load_data(args.in_dwi,
                     scheme_filename=tmp_scheme_filename,
                     mask_filename=args.mask)

        # Compute the response functions
        ae.set_model("FreeWater")
        model_type = 'Human'
        if args.mouse:
            model_type = 'Mouse'

        ae.model.set(args.para_diff,
                     np.linspace(args.perp_diff_min, args.perp_diff_max, 10),
                     [args.iso_diff], model_type)

        ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', ae.model.id)
            regenerate_kernels = True

        ae.set_config('ATOMS_path', kernels_dir)
        ae.set_config('OUTPUT_path', args.out_dir)
        ae.generate_kernels(regenerate=regenerate_kernels)
        if args.compute_only:
            return

        ae.load_kernels()

        # Set number of processes
        solver_params = ae.get_config('solver_params')
        solver_params['numThreads'] = args.nbr_processes
        ae.set_config('solver_params', solver_params)

        ae.set_config('doNormalizeSignal', True)
        ae.set_config('doKeepb0Intact', False)
        ae.set_config('doComputeNRMSE', True)
        ae.set_config('doSaveCorrectedDWI', True)

        # Model fit
        ae.fit()
        # Save the results
        ae.save_results()

    tmp_dir.cleanup()
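The verbose/quiet branching above decides once whether AMICO's Python-level prints reach the terminal or are swallowed, then wraps the whole fit in a single with block. A standard-library-only sketch of that pattern (redirect_stdout_c, which silences the C-level output, is a scilpy helper and is not reproduced here):

import io
import sys
from contextlib import redirect_stdout

verbose = False
if verbose:
    # Keep prints visible on the terminal
    redirected_stdout = redirect_stdout(sys.stdout)
else:
    # Swallow Python-level prints into an in-memory buffer
    captured = io.StringIO()
    redirected_stdout = redirect_stdout(captured)

with redirected_stdout:
    print('only shown when verbose is True')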
Example #25
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       os.path.join(args.out_dir,
                                                    'Contrasts_MT_maps'),
                                       os.path.join(args.out_dir,
                                                    'MT_native_maps'),
                                       create_dir=True)

    # Merge all echo paths into a list
    maps = [args.in_mtoff, args.in_mton, args.in_t1w]

    maps_flat = (args.in_mtoff + args.in_mton + args.in_t1w)

    jsons = [curr_map.replace('.nii.gz', '.json') for curr_map in maps_flat]

    # check data
    assert_inputs_exist(parser, jsons + maps_flat)
    for curr_map in maps[1:]:
        if len(curr_map) != len(maps[0]):
            parser.error('Not the same number of echoes per contrast')

    # Set TR and FlipAngle parameters for MT (mtoff contrast)
    # and T1w images
    parameters = [
        set_acq_parameters(maps[0][0].replace('.nii.gz', '.json')),
        set_acq_parameters(maps[2][0].replace('.nii.gz', '.json'))
    ]

    # Ignore warnings caused by invalid values and division by zero
    np.seterr(divide='ignore', invalid='ignore')

    # Define reference image for saving maps
    ref_img = nib.load(maps[0][0])

    # Define contrast map names
    contrasts_name = ['mt_off', 'mt_on', 'T1w']

    if args.out_prefix:
        contrasts_name = [
            args.out_prefix + '_' + curr_name for curr_name in contrasts_name
        ]

    # Compute contrast maps
    computed_contrasts = []
    for idx, curr_map in enumerate(maps):
        computed_contrasts.append(compute_contrasts_maps(curr_map))

        nib.save(
            nib.Nifti1Image(computed_contrasts[idx].astype(np.float32),
                            ref_img.affine),
            os.path.join(args.out_dir, 'Contrasts_MT_maps',
                         contrasts_name[idx] + '.nii.gz'))

    # Compute and threshold MT maps
    MTR, MTsat = compute_MT_maps(computed_contrasts, parameters)
    MTR, MTsat = [threshold_MT_maps(curr_map, args.in_mask, 0, 100)
                  for curr_map in (MTR, MTsat)]
    if args.in_B1_map:
        MTR, MTsat = [apply_B1_correction(curr_map, args.in_B1_map)
                      for curr_map in (MTR, MTsat)]

    # Save MT maps
    img_name = ['MTR', 'MTsat']

    if args.in_B1_map:
        img_name = [curr_name + '_B1_corrected' for curr_name in img_name]

    if args.out_prefix:
        img_name = [
            args.out_prefix + '_' + curr_name for curr_name in img_name
        ]

    img_data = MTR, MTsat
    for img_to_save, name in zip(img_data, img_name):
        nib.save(
            nib.Nifti1Image(img_to_save.astype(np.float32), ref_img.affine),
            os.path.join(args.out_dir, 'MT_native_maps', name + '.nii.gz'))
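compute_MT_maps and threshold_MT_maps are scilpy helpers whose internals are not shown here. As a rough illustration only, the conventional magnetization transfer ratio is the relative signal drop between the MT-off and MT-on contrasts, clamped to [0, 100]; the sketch below uses synthetic arrays and is not necessarily the script's actual implementation:

import numpy as np

np.seterr(divide='ignore', invalid='ignore')  # same guard as in the script

# Synthetic MT-off / MT-on "contrasts" (2x2 voxels)
s_mtoff = np.array([[1000.0, 900.0], [0.0, 1100.0]])
s_mton = np.array([[800.0, 720.0], [0.0, 1400.0]])

# Conventional definition MTR = 100 * (S_mtoff - S_mton) / S_mtoff,
# NOT necessarily what scilpy's compute_MT_maps does internally
mtr = 100.0 * (s_mtoff - s_mton) / s_mtoff
mtr = np.clip(np.nan_to_num(mtr), 0, 100)  # threshold to [0, 100]
print(mtr)  # [[20. 20.], [0. 0.]]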
Example #26
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       os.path.join(args.out_dir,
                                                    'Contrasts_ihMT_maps'),
                                       os.path.join(args.out_dir,
                                                    'ihMT_native_maps'),
                                       create_dir=True)

    # Merge all echo paths into a list
    maps = [
        args.in_altnp, args.in_altpn, args.in_mtoff, args.in_negative,
        args.in_positive, args.in_t1w
    ]

    maps_flat = (args.in_altnp + args.in_altpn + args.in_mtoff +
                 args.in_negative + args.in_positive + args.in_t1w)

    jsons = [curr_map.replace('.nii.gz', '.json') for curr_map in maps_flat]

    # Check the number of echoes and the json files
    assert_inputs_exist(parser, jsons + maps_flat)
    for curr_map in maps[1:]:
        if len(curr_map) != len(maps[0]):
            parser.error('Not the same number of echoes per contrast')

    # Set TR and FlipAngle parameters for ihMT (positive contrast)
    # and T1w images
    parameters = [
        set_acq_parameters(maps[4][0].replace('.nii.gz', '.json')),
        set_acq_parameters(maps[5][0].replace('.nii.gz', '.json'))
    ]

    # Ignore warnings caused by invalid values and division by zero
    np.seterr(divide='ignore', invalid='ignore')

    # Define reference image for saving maps
    ref_img = nib.load(maps[4][0])

    # Define contrast map names
    contrasts_name = [
        'altnp', 'altpn', 'reference', 'negative', 'positive', 'T1w'
    ]
    if args.filtering:
        contrasts_name = [
            curr_name + '_filter' for curr_name in contrasts_name
        ]

    if args.out_prefix:
        contrasts_name = [
            args.out_prefix + '_' + curr_name for curr_name in contrasts_name
        ]

    # Compute contrast maps
    computed_contrasts = []
    for idx, curr_map in enumerate(maps):
        computed_contrasts.append(
            compute_contrasts_maps(curr_map, filtering=args.filtering))

        nib.save(
            nib.Nifti1Image(computed_contrasts[idx].astype(np.float32),
                            ref_img.affine),
            os.path.join(args.out_dir, 'Contrasts_ihMT_maps',
                         contrasts_name[idx] + '.nii.gz'))

    # Compute and threshold ihMT maps
    ihMTR, ihMTsat = compute_ihMT_maps(computed_contrasts, parameters)
    ihMTR = threshold_ihMT_maps(ihMTR, computed_contrasts, args.in_mask, 0,
                                100, [4, 3, 1, 0, 2])
    ihMTsat = threshold_ihMT_maps(ihMTsat, computed_contrasts, args.in_mask, 0,
                                  10, [4, 3, 1, 0])

    # Compute and threshold non-ihMT maps
    MTR, MTsat = compute_MT_maps(computed_contrasts, parameters)
    MTR, MTsat = [threshold_ihMT_maps(curr_map, computed_contrasts,
                                      args.in_mask, 0, 100, [4, 2])
                  for curr_map in (MTR, MTsat)]

    # Save ihMT and MT images
    img_name = ['ihMTR', 'ihMTsat', 'MTR', 'MTsat']

    if args.filtering:
        img_name = [curr_name + '_filter' for curr_name in img_name]

    if args.out_prefix:
        img_name = [
            args.out_prefix + '_' + curr_name for curr_name in img_name
        ]

    img_data = ihMTR, ihMTsat, MTR, MTsat
    for img_to_save, name in zip(img_data, img_name):
        nib.save(
            nib.Nifti1Image(img_to_save.astype(np.float32), ref_img.affine),
            os.path.join(args.out_dir, 'ihMT_native_maps', name + '.nii.gz'))
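threshold_ihMT_maps is likewise not shown. Assuming it behaves like a typical clip-and-mask step (clamp values to [lower, upper] and zero voxels outside the brain mask), a hypothetical simplified version looks like this; the real helper also takes the contrast maps and an index list, which are ignored here:

import numpy as np

def threshold_in_mask(map_data, mask, lower, upper):
    # Clamp to the allowed range, then zero everything outside the mask
    out = np.clip(map_data, lower, upper)
    out[mask == 0] = 0
    return out

fake_map = np.array([[-5.0, 50.0], [120.0, 30.0]])
fake_mask = np.array([[1, 1], [1, 0]])
print(threshold_in_mask(fake_map, fake_mask, 0, 100))
# [[  0.  50.]
#  [100.   0.]]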
Example #27
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
                        args.in_mask)
    assert_output_dirs_exist_and_empty(parser, args,
                                       os.path.join(args.out_dir, 'NODDI'),
                                       optional=args.save_kernels)

    # AMICO has some C-level stdout and non-logging prints that cannot
    # be easily stopped. Manually redirect all printed output.
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
        redirected_stdout = redirect_stdout(sys.stdout)
    else:
        f = io.StringIO()
        redirected_stdout = redirect_stdout(f)
        redirect_stdout_c()

    # Generate a scheme file from the bvals and bvecs files
    tmp_dir = tempfile.TemporaryDirectory()
    tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.scheme')
    tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
    bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
    shells_centroids, indices_shells = identify_shells(bvals,
                                                       args.b_thr,
                                                       roundCentroids=True)
    np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
               newline=' ', fmt='%i')
    fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
    logging.debug('Computing NODDI with AMICO on {} shells found '
                  'at {}.'.format(len(shells_centroids), shells_centroids))

    with redirected_stdout:
        amico.core.setup()
        ae = amico.Evaluation('.', '.')
        # Load the data
        ae.load_data(args.in_dwi,
                     tmp_scheme_filename,
                     mask_filename=args.in_mask)
        # Compute the response functions
        ae.set_model("NODDI")

        intra_vol_frac = np.linspace(0.1, 0.99, 12)
        intra_orient_distr = np.hstack((np.array([0.03, 0.06]),
                                        np.linspace(0.09, 0.99, 10)))

        ae.model.set(args.para_diff, args.iso_diff,
                     intra_vol_frac, intra_orient_distr,
                     False)
        ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

        # The kernels are, by default, set to be in the current directory
        # Depending on the choice, manually change the saving location
        if args.save_kernels:
            kernels_dir = os.path.join(args.save_kernels)
            regenerate_kernels = True
        elif args.load_kernels:
            kernels_dir = os.path.join(args.load_kernels)
            regenerate_kernels = False
        else:
            kernels_dir = os.path.join(tmp_dir.name, 'kernels', ae.model.id)
            regenerate_kernels = True

        ae.set_config('ATOMS_path', kernels_dir)
        out_model_dir = os.path.join(args.out_dir, ae.model.id)
        ae.set_config('OUTPUT_path', out_model_dir)
        ae.generate_kernels(regenerate=regenerate_kernels)
        ae.load_kernels()

        # Set number of processes
        solver_params = ae.get_config('solver_params')
        solver_params['numThreads'] = args.nbr_processes
        ae.set_config('solver_params', solver_params)

        # Model fit
        ae.fit()
        # Save the results
        ae.save_results()

    tmp_dir.cleanup()
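identify_shells groups the raw b-values into shell centroids before the scheme file is written. A naive, hypothetical stand-in that merges values within a tolerance shows the shape of the shells_centroids / indices_shells pair used above (this is not scilpy's implementation):

import numpy as np

def naive_identify_shells(bvals, tol=20):
    # Assign each b-value to the first centroid within `tol`,
    # creating a new centroid when none is close enough
    centroids = []
    indices = np.zeros(len(bvals), dtype=int)
    for i, b in enumerate(bvals):
        for j, c in enumerate(centroids):
            if abs(b - c) <= tol:
                indices[i] = j
                break
        else:
            centroids.append(b)
            indices[i] = len(centroids) - 1
    return np.array(centroids), indices

bvals = np.array([0, 5, 995, 1000, 1005, 2000, 1995])
centroids, indices = naive_identify_shells(bvals)
print(centroids)  # [0 995 2000]
print(indices)    # [0 0 1 1 1 2 2]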
def main():
    parser = build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.labels])

    if os.path.abspath(args.output_dir) == os.getcwd():
        parser.error('Do not use the current path as output directory.')

    assert_output_dirs_exist_and_empty(parser, args, args.output_dir,
                                       create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.labels)
    data_labels = img_labels.get_fdata().astype(np.int16)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    if not np.issubdtype(img_labels.get_data_dtype().type, np.integer):
        parser.error("Label image should contain integers for labels.")

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(vox_sizes, vox_sizes[0]):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    logging.info('*** Filtering streamlines ***')
    data_mask = np.zeros(data_labels.shape)
    data_mask[data_labels > 0] = 1

    original_len = len(sft)
    time1 = time.time()

    sft.to_vox()
    sft.to_corner()
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info(
        '    Discarded {} streamlines from filtering in {} sec.'.format(
            original_len - len(sft), round(time2 - time1, 2)))
    logging.info('    Number of streamlines to process: {}'.format(len(sft)))

    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices,
                                    data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    for in_label, out_label in comb_list:
        if iteration_counter > 0 and iteration_counter % 100 == 0:
            logging.info('Processed {} node pairs out of {}'.format(
                iteration_counter, len(comb_list)))
        iteration_counter += 1

        pair_info = []
        if in_label not in con_info:
            continue
        elif out_label in con_info[in_label]:
            pair_info.extend(con_info[in_label][out_label])

        if out_label not in con_info:
            continue
        elif in_label in con_info[out_label]:
            pair_info.extend(con_info[out_label][in_label])

        if not len(pair_info):
            continue

        connecting_streamlines = []
        for connection in pair_info:
            strl_idx = connection['strl_idx']
            curr_streamlines = compute_streamline_segment(
                sft.streamlines[strl_idx],
                indices[strl_idx],
                connection['in_idx'],
                connection['out_idx'],
                points_to_idx[strl_idx])
            connecting_streamlines.append(curr_streamlines)

        _save_if_needed(connecting_streamlines, args, 'raw',
                        'raw', in_label, out_label)

        # Doing all post-processing
        if not args.no_pruning:
            valid_length, invalid_length = _prune_segments(
                connecting_streamlines,
                args.min_length,
                args.max_length,
                vox_sizes[0])

            _save_if_needed(invalid_length, args,
                            'discarded', 'invalid_length',
                            in_label, out_label)
        else:
            valid_length = connecting_streamlines

        if not len(valid_length):
            continue

        _save_if_needed(valid_length, args,
                        'intermediate', 'valid_length', in_label, out_label)

        if not args.no_remove_loops:
            no_loop_ids = remove_loops_and_sharp_turns(valid_length,
                                                       args.loop_max_angle)
            no_loops = [valid_length[i] for i in no_loop_ids]

            loop_ids = np.setdiff1d(np.arange(len(valid_length)), no_loop_ids)
            loops = [valid_length[i] for i in loop_ids]

            _save_if_needed(loops, args,
                            'discarded', 'loops', in_label, out_label)
        else:
            no_loops = valid_length

        if not len(no_loops):
            continue

        _save_if_needed(no_loops, args,
                        'intermediate', 'no_loops', in_label, out_label)

        if not args.no_remove_outliers:
            inliers, outliers = remove_outliers(no_loops,
                                                args.outlier_threshold)
            _save_if_needed(outliers, args,
                            'discarded', 'outliers', in_label, out_label)
        else:
            inliers = no_loops

        if not len(inliers):
            continue

        _save_if_needed(inliers, args,
                        'intermediate', 'inliers', in_label, out_label)

        if not args.no_remove_curv_dev:
            no_qb_curv_ids = remove_loops_and_sharp_turns(
                inliers,
                args.loop_max_angle,
                use_qb=True,
                qb_threshold=args.curv_qb_distance)
            no_qb_curv = [inliers[i] for i in no_qb_curv_ids]

            qb_curv_ids = np.setdiff1d(
                np.arange(len(inliers)), no_qb_curv_ids)
            qb_curv = [inliers[i] for i in qb_curv_ids]

            _save_if_needed(qb_curv, args,
                            'discarded', 'qb_curv', in_label, out_label)
        else:
            no_qb_curv = inliers

        _save_if_needed(no_qb_curv, args,
                        'final', 'final', in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
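The connectivity loop above enumerates every unordered pair of distinct labels plus each self-connection, which together cover the upper triangle and diagonal of the connectivity matrix. A tiny itertools sketch with made-up labels:

import itertools

real_labels = [1, 2, 3]
comb_list = list(itertools.combinations(real_labels, r=2))
comb_list.extend(zip(real_labels, real_labels))
print(comb_list)
# [(1, 2), (1, 3), (2, 3), (1, 1), (2, 2), (3, 3)]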