Example #1
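Converts loaded fdf DWI/b0 data to a single NIfTI file, checking header compatibility before concatenating the b0 and DWI volumes and writing the gradient (bval/bvec) files.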
def save_babel(dwi_data,
               dwi_header,
               b0_data,
               b0_header,
               bval_path,
               bvec_path,
               out_path,
               affine=None,
               flip=None,
               swap=None):
    """
    Save a loaded fdf file to nifti.

    Parameters
    ----------
    out_path: Path of the nifti file to be saved
    data: Raw data to be saved
    raw_header: Raw header from fdf files
    bval_path: Path to the bval file to be saved
    bvec_path: Path to the bvec file to be saved
    affine: Affine transformation to save with the data

    Return
    ------
    None
    """
    nifti1_dwi_header = format_raw_header(dwi_header)
    nifti1_b0_header = format_raw_header(b0_header)

    if not is_header_compatible(nifti1_dwi_header, nifti1_b0_header):
        raise Exception("Images are not of the same resolution/affine")

    nifti1_header = nifti1_dwi_header

    if 'orientation' in nifti1_header:
        orientation = np.identity(4)
        orientation[:3, :3] = nifti1_header['orientation'].reshape(3, 3)
        affine = np.linalg.inv(orientation)

    write_gradient_information(dwi_header, b0_header, bval_path, bvec_path,
                               flip, swap)

    data = np.concatenate([b0_data[:, :, :, np.newaxis], dwi_data], axis=3)

    nifti1_header.set_data_shape(data.shape)

    img = nib.nifti1.Nifti1Image(dataobj=data,
                                 header=nifti1_header,
                                 affine=affine)
    vox_dim = [round(num, 3) for num in dwi_header['voxel_dim'][0:4]]
    img.header.set_zooms(vox_dim)
    qform = img.header.get_qform()
    qform[:2, :3] *= -1.

    if 'origin' in nifti1_header:
        qform[:len(nifti1_header['origin']), 3] = -nifti1_header['origin']

    img.header.set_qform(qform)
    img.update_header()
    img.to_filename(out_path)
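All of these examples hinge on the same helper: is_header_compatible from dipy.io.utils. It accepts NIfTI images, .trk files, StatefulTractogram objects or paths to them, and returns True when the spatial attributes (affine, dimensions, voxel sizes, voxel order) of both arguments match. A minimal usage sketch, with hypothetical file names:

import nibabel as nib
from dipy.io.utils import is_header_compatible

# 't1.nii.gz' and 'bundle.trk' are placeholder inputs.
anat = nib.load('t1.nii.gz')
if not is_header_compatible(anat, 'bundle.trk'):
    raise ValueError('Anatomy and bundle do not have compatible headers')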
Example #2
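Static check that two StatefulTractogram objects share spatial attributes, space, origin, and per-point/per-streamline metadata keys.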
    def are_compatible(sft_1, sft_2):
        """ Compatibility verification of two StatefulTractogram to ensure space,
        origin, data_per_point and data_per_streamline consistency """

        are_sft_compatible = True
        if not is_header_compatible(sft_1, sft_2):
            logger.warning('Inconsistent spatial attributes between both sft.')
            are_sft_compatible = False

        if sft_1.space != sft_2.space:
            logger.warning('Inconsistent space between both sft.')
            are_sft_compatible = False
        if sft_1.origin != sft_2.origin:
            logger.warning('Inconsistent origin between both sft.')
            are_sft_compatible = False

        if sft_1.get_data_per_point_keys() != sft_2.get_data_per_point_keys():
            logger.warning(
                'Inconsistent data_per_point between both sft.')
            are_sft_compatible = False
        if sft_1.get_data_per_streamline_keys() != \
                sft_2.get_data_per_streamline_keys():
            logger.warning(
                'Inconsistent data_per_streamline between both sft.')
            are_sft_compatible = False

        return are_sft_compatible
Example #3
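Script that fuses several bundles into a voxel-wise map and, optionally, a majority-vote streamline file, validating every bundle against a reference header.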
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    fusion_streamlines = []
    if args.reference:
        reference_file = args.reference
    else:
        reference_file = args.in_bundles[0]
    sft_list = []
    for name in args.in_bundles:
        tmp_sft = load_tractogram_with_reference(parser, args, name)
        tmp_sft.to_vox()
        tmp_sft.to_corner()

        if not is_header_compatible(reference_file, tmp_sft):
            raise ValueError('Headers are not compatible.')
        sft_list.append(tmp_sft)
        fusion_streamlines.append(tmp_sft.streamlines)

    fusion_streamlines, _ = union_robust(fusion_streamlines)

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    streamlines_vote = dok_matrix(
        (len(fusion_streamlines), len(args.in_bundles)))

    for i in range(len(args.in_bundles)):
        sft = sft_list[i]
        binary = compute_tract_counts_map(sft.streamlines, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = intersection_robust(
                [fusion_streamlines, sft.streamlines])
            streamlines_vote[list(indices), [i]] += 1

    if args.same_tractogram:
        ratio_value = int(args.ratio_streamlines * len(args.in_bundles))
        real_indices = np.where(
            np.sum(streamlines_vote, axis=1) >= ratio_value)[0]
        new_sft = StatefulTractogram.from_sft(fusion_streamlines[real_indices],
                                              sft_list[0])
        save_tractogram(new_sft, output_streamlines_filename)

    volume[volume < int(args.ratio_voxels * len(args.in_bundles))] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Example #4
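Loads a per-node NIfTI file if present, raising when its header is incompatible with the reference image.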
def load_node_nifti(directory, in_label, out_label, ref_img):
    in_filename = os.path.join(directory,
                               '{}_{}.nii.gz'.format(in_label, out_label))

    if os.path.isfile(in_filename):
        if not is_header_compatible(in_filename, ref_img):
            raise IOError('{} does not have a compatible header'.format(
                in_filename))
        return nib.load(in_filename).get_fdata(dtype=np.float64)

    return None
Example #5
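Assigns each bundle point to its closest centroid point and saves the labels and distances as compressed NumPy files.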
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle, args.in_centroid])
    assert_outputs_exist(parser, args,
                         [args.output_label, args.output_distance])

    if not is_header_compatible(args.in_bundle, args.in_centroid):
        parser.error('Headers of {} and {} are not compatible.'.format(
            args.in_bundle, args.in_centroid))

    sft_bundle = load_tractogram_with_reference(parser, args, args.in_bundle)

    sft_centroid = load_tractogram_with_reference(parser, args,
                                                  args.in_centroid)

    if not len(sft_bundle.streamlines):
        logging.error('Empty bundle file {}. Skipping'.format(args.in_bundle))
        raise ValueError

    if not len(sft_centroid.streamlines):
        logging.error('Empty centroid streamline file {}. Skipping'.format(
            args.in_centroid))
        raise ValueError

    min_dist_label, min_dist = min_dist_to_centroid(
        sft_bundle.streamlines.data, sft_centroid.streamlines.data)
    min_dist_label += 1

    # Save assignment in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # assignment = f['arr_0']
    np.savez_compressed(args.output_label, min_dist_label)

    # Save distance in a compressed numpy file
    # You can load this file and access its data using
    # f = np.load('someFile.npz')
    # distance = f['arr_0']
    np.savez_compressed(args.output_distance, min_dist)
Example #6
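Cuts streamlines with a binary mask containing one or two blobs, after verifying tractogram/mask header compatibility.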
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
    if args.step_size is not None:
        sft = resample_streamlines_step_size(sft, args.step_size)

    mask_img = nib.load(args.in_mask)
    binary_mask = get_data_as_mask(mask_img)

    if not is_header_compatible(sft, mask_img):
        parser.error('Incompatible header between the tractogram and mask.')

    bundle_disjoint, _ = ndi.label(binary_mask)
    unique, count = np.unique(bundle_disjoint, return_counts=True)
    if args.biggest_blob:
        val = unique[np.argmax(count[1:]) + 1]
        binary_mask[bundle_disjoint != val] = 0
        unique = [0, val]
    if len(unique) == 2:
        logging.info('The provided mask has 1 entity; the '
                     'cut_outside_of_mask_streamlines function was selected.')
        new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
    elif len(unique) == 3:
        logging.info('The provided mask has 2 entities; the '
                     'cut_between_masks_streamlines function was selected.')
        new_sft = cut_between_masks_streamlines(sft, binary_mask)

    else:
        logging.error('The provided mask has more than 2 entities. Cannot cut '
                      'between >2.')
        return

    if len(new_sft) == 0:
        logging.warning('No streamline intersected the provided mask. '
                        'Saving empty tractogram.')
    elif args.error_rate is not None:
        compressed_strs = [
            compress_streamlines(s, args.error_rate)
            for s in new_sft.streamlines
        ]
        new_sft = StatefulTractogram.from_sft(compressed_strs, sft)

    save_tractogram(new_sft, args.out_tractogram)
Example #7
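Helper raising an exception when the given images do not all share the same resolution/affine.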
def assert_same_resolution(images):
    """
    Check the resolution of multiple images.
    Parameters
    ----------
    images : array of string or string
        List of images or an image.
    """
    if isinstance(images, str):
        images = [images]

    if len(images) == 0:
        raise Exception("Can't check if images are of the same "
                        "resolution/affine. No image has been given")

    for curr_image in images[1:]:
        if not is_header_compatible(images[0], curr_image):
            raise Exception("Images are not of the same resolution/affine")
def main():
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_files)

    all_valid = True
    for filepath in args.in_files:
        _, in_extension = split_name_with_nii(filepath)
        if in_extension not in ['.trk', '.nii', '.nii.gz']:
            parser.error(
                '{} does not have a supported extension'.format(filepath))
        if not is_header_compatible(args.in_files[0], filepath):
            print('{} and {} do not have compatible header.'.format(
                args.in_files[0], filepath))
            all_valid = False
    if all_valid:
        print('All input files have compatible headers.')
Example #9
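Concatenates several DWIs along the fourth dimension together with their bvals/bvecs, enforcing header compatibility with the first DWI.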
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if len(args.in_dwis) != len(args.in_bvals) \
            or len(args.in_dwis) != len(args.in_bvecs):
        parser.error('DWI, bvals and bvecs must have the same length')

    assert_inputs_exist(parser, args.in_dwis + args.in_bvals + args.in_bvecs)
    assert_outputs_exist(parser, args,
                         [args.out_dwi, args.out_bval, args.out_bvec])

    all_bvals = []
    all_bvecs = []
    total_size = 0
    for i in range(len(args.in_dwis)):
        bvals, bvecs = read_bvals_bvecs(args.in_bvals[i], args.in_bvecs[i])
        if len(bvals) != len(bvecs):
            raise ValueError('Paired bvals and bvecs must have the same size.')
        total_size += len(bvals)
        all_bvals.append(bvals)
        all_bvecs.append(bvecs)
    all_bvals = np.concatenate(all_bvals)
    all_bvecs = np.concatenate(all_bvecs)

    ref_dwi = nib.load(args.in_dwis[0])
    all_dwi = np.zeros(ref_dwi.shape[0:3] + (total_size, ),
                       dtype=args.data_type)
    last_count = ref_dwi.shape[-1]
    all_dwi[..., 0:last_count] = ref_dwi.get_fdata()
    for i in range(1, len(args.in_dwis)):
        curr_dwi = nib.load(args.in_dwis[i])
        if not is_header_compatible(curr_dwi, ref_dwi):
            raise ValueError('All DWIs must have a compatible header.')
        curr_size = curr_dwi.shape[-1]
        all_dwi[..., last_count:last_count+curr_size] = \
            curr_dwi.get_fdata()
        last_count += curr_size

    np.savetxt(args.out_bval, all_bvals, '%d')
    np.savetxt(args.out_bvec, all_bvecs.T, '%0.15f')
    nib.save(nib.Nifti1Image(all_dwi, ref_dwi.affine, header=ref_dwi.header),
             args.out_dwi)
Example #10
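Looks up a per-node NIfTI file under either label ordering, validates its header, and falls back to a zero volume when none exists.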
def load_node_nifti(directory, in_label, out_label, ref_filename):
    in_filename_1 = os.path.join(directory,
                                 '{}_{}.nii.gz'.format(in_label, out_label))
    in_filename_2 = os.path.join(directory,
                                 '{}_{}.nii.gz'.format(out_label, in_label))
    in_filename = None
    if os.path.isfile(in_filename_1):
        in_filename = in_filename_1
    elif os.path.isfile(in_filename_2):
        in_filename = in_filename_2

    if in_filename is not None:
        if not is_header_compatible(in_filename, ref_filename):
            logging.error('{} and {} do not have a compatible header'.format(
                in_filename, ref_filename))
            raise IOError
        return nib.load(in_filename).get_fdata()

    _, dims, _, _ = get_reference_info(ref_filename)
    return np.zeros(dims)
Example #11
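Checks that a reference StatefulTractogram is spatially compatible with a list of tractogram or NIfTI files, erroring through the parser otherwise.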
def verify_compatibility_with_reference_sft(ref_sft, files_to_verify, parser,
                                            args):
    """
    Verifies the compatibility of a reference sft with a list of files.

    Parameters
    ----------
    ref_sft: StatefulTractogram
        A tractogram to be used as reference.
    files_to_verify: List[str]
        List of files that should be compatible with the reference sft. Files
        can be either other tractograms or nifti files (ex: masks).
    parser: argument parser
        Will raise an error if a file is not compatible.
    args: Namespace
        Should contain args.reference if any file is a .tck.
    """
    save_ref = args.reference

    for file in files_to_verify:
        if file is not None:
            _, ext = os.path.splitext(file)
            if ext in ['.trk', '.tck', '.fib', '.vtk', '.dpy']:
                # Drop the reference for .trk files: passing it would emit
                # many warnings when it was only needed for some of the
                # files.
                if ext == '.trk':
                    args.reference = None
                else:
                    args.reference = save_ref
                mask = load_tractogram_with_reference(parser,
                                                      args,
                                                      file,
                                                      bbox_check=False)
            else:  # should be a nifti file.
                mask = file
            compatible = is_header_compatible(ref_sft, mask)
            if not compatible:
                parser.error(
                    "Reference tractogram incompatible with {}".format(file))
Example #12
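Warps a tractogram with a deformation field after confirming that both share the same spatial attributes.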
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.moving_tractogram, args.target_file,
                                 args.deformation])
    assert_outputs_exist(parser, args, args.out_tractogram)

    sft = load_tractogram_with_reference(parser, args, args.moving_tractogram,
                                         bbox_check=False)

    deformation = nib.load(args.deformation)
    deformation_data = np.squeeze(deformation.get_fdata())

    if not is_header_compatible(sft, deformation):
        parser.error('Input tractogram/reference do not have the same spatial '
                     'attributes as the deformation field.')

    # Warning: Apply warp in-place
    moved_streamlines = warp_streamlines(sft, deformation_data)
    new_sft = StatefulTractogram(moved_streamlines, args.target_file,
                                 Space.RASMM,
                                 data_per_point=sft.data_per_point,
                                 data_per_streamline=sft.data_per_streamline)

    if args.remove_invalid:
        ori_len = len(new_sft)
        new_sft.remove_invalid_streamlines()
        logging.warning('Removed {} invalid streamlines.'.format(
            ori_len - len(new_sft)))
        save_tractogram(new_sft, args.out_tractogram)
    elif args.keep_invalid:
        if not new_sft.is_bbox_in_vox_valid():
            logging.warning('Saving tractogram with invalid streamlines.')
        save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
    else:
        save_tractogram(new_sft, args.out_tractogram)
Example #13
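Scores a tractogram against ground-truth bundles and endpoint masks, splitting streamlines into true/false/no connections and writing JSON statistics.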
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram] + args.gt_bundles)
    assert_output_dirs_exist_and_empty(parser,
                                       args,
                                       args.out_dir,
                                       create_dir=True)

    if (args.gt_tails and not args.gt_heads) \
            or (args.gt_heads and not args.gt_tails):
        parser.error("Both --gt_heads and --gt_tails are needed.")
    if args.gt_endpoints and (args.gt_tails or args.gt_heads):
        parser.error("Can only provide --gt_endpoints or --gt_tails/gt_heads")
    if not args.gt_endpoints and (not args.gt_tails and not args.gt_heads):
        parser.error(
            "Either input --gt_endpoints or --gt_heads and --gt_tails.")

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    _, ext = os.path.splitext(args.in_tractogram)
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)

    if args.remove_invalid:
        sft.remove_invalid_streamlines()

    initial_count = len(sft)

    logging.info("Verifying compatibility with ground-truth")
    for gt in args.gt_bundles:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    logging.info("Computing ground-truth masks")
    gt_bundle_masks, gt_bundle_inv_masks, affine, dimensions = \
        compute_gt_masks(args.gt_bundles, parser, args)

    # If endpoints without heads/tails are loaded, split them and continue
    # normally after. Q/C of the output is important
    if args.gt_endpoints:
        logging.info("Extracting ground-truth end and tail masks")
        gt_tails, gt_heads, affine, dimensions = \
            extract_tails_heads_from_endpoints(
                args.gt_endpoints, args.out_dir)
    else:
        gt_tails, gt_heads = args.gt_tails, args.gt_heads

    logging.info("Verifying compatibility with endpoints")
    for gt in gt_tails + gt_heads:
        compatible = is_header_compatible(sft, gt)
        if not compatible:
            parser.error("Input tractogram incompatible with" " {}".format(gt))

    # Load the endpoints heads/tails, keep the correct combinations
    # separately from all the possible combinations
    tc_filenames = list(zip(gt_tails, gt_heads))

    length_dict = {}
    if args.gt_config:
        with open(args.gt_config, "r") as json_file:
            length_dict = json.load(json_file)

    tc_streamlines_list = []
    wpc_streamlines_list = []
    fc_streamlines_list = []
    nc_streamlines = []

    logging.info("Scoring true connections")
    for i, (mask_1_filename, mask_2_filename) in enumerate(tc_filenames):

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)

        tc_sft, wpc_sft, fc_sft, nc, sft = extract_true_connections(
            sft, mask_1_filename, mask_2_filename, args.gt_config, length_dict,
            extract_prefix(args.gt_bundles[i]), gt_bundle_inv_masks[i],
            args.dilate_endpoints, args.wrong_path_as_separate)
        nc_streamlines.extend(nc)
        if len(tc_sft) > 0:
            save_tractogram(tc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_tc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(wpc_sft) > 0:
            save_tractogram(wpc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_wpc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        tc_streamlines_list.append(tc_sft.streamlines)
        wpc_streamlines_list.append(wpc_sft.streamlines)
        fc_streamlines_list.append(fc_sft.streamlines)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(tc_sft.streamlines) + len(wpc_sft.streamlines) +
            len(fc_sft.streamlines) + len(nc), prefix_1, prefix_2))

    # Again, keep the correct combinations
    comb_filename = list(
        itertools.combinations(itertools.chain(*zip(gt_tails, gt_heads)), r=2))

    # Remove the true connections from all combinations, leaving only
    # false connections
    for tc_f in tc_filenames:
        comb_filename.remove(tc_f)

    logging.info("Scoring false connections")
    # Go through all the possible combinations of endpoints masks
    for i, roi in enumerate(comb_filename):
        mask_1_filename, mask_2_filename = roi

        # Automatically generate filename for Q/C
        prefix_1 = extract_prefix(mask_1_filename)
        prefix_2 = extract_prefix(mask_2_filename)
        _, ext = os.path.splitext(args.in_tractogram)

        fc_sft, sft = extract_false_connections(sft, mask_1_filename,
                                                mask_2_filename,
                                                args.dilate_endpoints)

        if len(fc_sft) > 0:
            save_tractogram(fc_sft,
                            os.path.join(
                                args.out_dir,
                                "{}_{}_fc{}".format(prefix_1, prefix_2, ext)),
                            bbox_valid_check=False)

        logging.info("Recognized {} streamlines between {} and {}".format(
            len(fc_sft.streamlines), prefix_1, prefix_2))

        fc_streamlines_list.append(fc_sft.streamlines)

    nc_streamlines.extend(sft.streamlines)

    final_results = {}
    no_conn_sft = StatefulTractogram.from_sft(nc_streamlines, sft)
    save_tractogram(no_conn_sft,
                    os.path.join(args.out_dir, "nc{}".format(ext)),
                    bbox_valid_check=False)

    # Total number of streamlines for each category,
    # plus statistics that are not "bundle-wise"
    tc_streamlines_count = len(list(itertools.chain(*tc_streamlines_list)))
    fc_streamlines_count = len(list(itertools.chain(*fc_streamlines_list)))

    if args.wrong_path_as_separate:
        wpc_streamlines_count = len(
            list(itertools.chain(*wpc_streamlines_list)))
    else:
        wpc_streamlines_count = 0

    nc_streamlines_count = len(nc_streamlines)
    total_count = tc_streamlines_count + fc_streamlines_count + \
        wpc_streamlines_count + nc_streamlines_count

    assert total_count == initial_count

    final_results["tractogram_filename"] = str(args.in_tractogram)
    final_results["tractogram_overlap"] = 0.0
    final_results["tc_streamlines"] = tc_streamlines_count
    final_results["fc_streamlines"] = fc_streamlines_count
    final_results["nc_streamlines"] = nc_streamlines_count

    final_results["tc_bundle"] = len([x for x in tc_streamlines_list if x])
    final_results["fc_bundle"] = len([x for x in fc_streamlines_list if x])

    final_results["tc_streamlines_ratio"] = tc_streamlines_count / total_count
    final_results["fc_streamlines_ratio"] = fc_streamlines_count / total_count
    final_results["nc_streamlines_ratio"] = nc_streamlines_count / total_count

    if args.wrong_path_as_separate:
        final_results["wpc_streamlines"] = wpc_streamlines_count
        final_results["wpc_streamlines_ratio"] = \
            wpc_streamlines_count / total_count
        final_results["wpc_bundle"] = len(
            [x for x in wpc_streamlines_list if x])

    final_results["total_streamlines"] = total_count
    final_results["bundle_wise"] = {}
    final_results["bundle_wise"]["true_connections"] = {}
    final_results["bundle_wise"]["false_connections"] = {}
    tractogram_overlap = 0.0

    for i, filename in enumerate(tc_filenames):
        current_tc_streamlines = tc_streamlines_list[i]
        current_tc_voxels, current_tc_endpoints_voxels = get_binary_maps(
            current_tc_streamlines, sft)

        if args.wrong_path_as_separate:
            current_wpc_streamlines = wpc_streamlines_list[i]
            current_wpc_voxels, _ = get_binary_maps(current_wpc_streamlines,
                                                    sft)

        tmp_dict = {}
        tmp_dict["tc_streamlines"] = len(current_tc_streamlines)

        tmp_dict["tc_dice"] = compute_dice_voxel(gt_bundle_masks[i],
                                                 current_tc_voxels)[0]

        bundle_overlap = gt_bundle_masks[i] * current_tc_voxels
        bundle_overreach = np.zeros(dimensions)
        bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                  & (current_tc_voxels >= 1))] = 1
        bundle_lacking = np.zeros(dimensions)
        bundle_lacking[np.where((gt_bundle_masks[i] == 1)
                                & (current_tc_voxels == 0))] = 1

        if args.wrong_path_as_separate:
            tmp_dict["wpc_streamlines"] = len(current_wpc_streamlines)
            tmp_dict["wpc_dice"] = \
                compute_dice_voxel(gt_bundle_masks[i],
                                   current_wpc_voxels)[0]
            # Add wrong path to overreach
            bundle_overreach[np.where((gt_bundle_masks[i] == 0)
                                      & (current_wpc_voxels >= 1))] = 1

        tmp_dict["tc_bundle_overlap"] = np.count_nonzero(bundle_overlap)
        tmp_dict["tc_bundle_overreach"] = \
            np.count_nonzero(bundle_overreach)
        tmp_dict["tc_bundle_lacking"] = np.count_nonzero(bundle_lacking)
        tmp_dict["tc_bundle_overlap_PCT"] = \
            tmp_dict["tc_bundle_overlap"] / \
            (tmp_dict["tc_bundle_overlap"] +
                tmp_dict["tc_bundle_lacking"])
        tractogram_overlap += tmp_dict["tc_bundle_overlap_PCT"]

        endpoints_overlap = \
            gt_bundle_masks[i] * current_tc_endpoints_voxels
        endpoints_overreach = np.zeros(dimensions)
        endpoints_overreach[np.where((gt_bundle_masks[i] == 0)
                                     & (current_tc_endpoints_voxels >= 1))] = 1
        tmp_dict["tc_endpoints_overlap"] = np.count_nonzero(endpoints_overlap)
        tmp_dict["tc_endpoints_overreach"] = np.count_nonzero(
            endpoints_overreach)

        final_results["bundle_wise"]["true_connections"][str(filename)] = \
            tmp_dict

    # Bundle-wise statistics, useful for more complex phantoms
    for i, filename in enumerate(comb_filename):
        current_fc_streamlines = fc_streamlines_list[i]
        current_fc_voxels, _ = get_binary_maps(current_fc_streamlines, sft)

        tmp_dict = {}

        if len(current_fc_streamlines):
            tmp_dict["fc_streamlines"] = len(current_fc_streamlines)
            tmp_dict["fc_voxels"] = np.count_nonzero(current_fc_voxels)

            final_results["bundle_wise"]["false_connections"][str(filename)] =\
                tmp_dict

    final_results["tractogram_overlap"] = \
        tractogram_overlap / len(gt_bundle_masks)

    with open(os.path.join(args.out_dir, "results.json"), "w") as f:
        json.dump(final_results,
                  f,
                  indent=args.indent,
                  sort_keys=args.sort_keys)
Example #14
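Builds an HDF5 connectivity file from a tractogram and a label map, post-processing each label-pair connection for length, loops, outliers and curvature.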
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram, args.in_labels],
                        args.reference)
    assert_outputs_exist(parser, args, args.out_hdf5)

    # HDF5 will not overwrite the file
    if os.path.isfile(args.out_hdf5):
        os.remove(args.out_hdf5)

    if (args.save_raw_connections or args.save_intermediate
            or args.save_discarded) and not args.out_dir:
        parser.error('To save outputs in the streamlines form, provide the '
                     'output directory using --out_dir.')

    if args.out_dir:
        if os.path.abspath(args.out_dir) == os.getcwd():
            parser.error('Do not use the current path as output directory.')
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_dir,
                                           create_dir=True)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)
    set_sft_logger_level('WARNING')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    real_labels = np.unique(data_labels)[1:]
    if args.out_labels_list:
        np.savetxt(args.out_labels_list, real_labels, fmt='%i')

    # Voxel size must be isotropic, for speed/performance considerations
    vox_sizes = img_labels.header.get_zooms()
    if not np.allclose(np.mean(vox_sizes), vox_sizes, atol=1e-03):
        parser.error('Labels must be isotropic')

    logging.info('*** Loading streamlines ***')
    time1 = time.time()
    sft = load_tractogram_with_reference(parser,
                                         args,
                                         args.in_tractogram,
                                         bbox_check=False)
    sft.remove_invalid_streamlines()
    time2 = time.time()
    logging.info('    Loading {} streamlines took {} sec.'.format(
        len(sft), round(time2 - time1, 2)))

    if not is_header_compatible(sft, img_labels):
        raise IOError('{} and {} do not have a compatible header'.format(
            args.in_tractogram, args.in_labels))

    sft.to_vox()
    sft.to_corner()
    # Get all streamlines intersection indices
    logging.info('*** Computing streamlines intersection ***')
    time1 = time.time()

    indices, points_to_idx = uncompress(sft.streamlines, return_mapping=True)

    time2 = time.time()
    logging.info('    Streamlines intersection took {} sec.'.format(
        round(time2 - time1, 2)))

    # Compute the connectivity mapping
    logging.info('*** Computing connectivity information ***')
    time1 = time.time()
    con_info = compute_connectivity(indices, data_labels, real_labels,
                                    extract_longest_segments_from_profile)
    time2 = time.time()
    logging.info('    Connectivity computation took {} sec.'.format(
        round(time2 - time1, 2)))

    # Prepare directories and information needed to save.
    _create_required_output_dirs(args)

    logging.info('*** Starting connection post-processing and saving. ***')
    logging.info('    This can be long, be patient.')
    time1 = time.time()

    # Saving will be done from streamlines already in the right space
    comb_list = list(itertools.combinations(real_labels, r=2))
    comb_list.extend(zip(real_labels, real_labels))

    iteration_counter = 0
    with h5py.File(args.out_hdf5, 'w') as hdf5_file:
        affine, dimensions, voxel_sizes, voxel_order = get_reference_info(sft)
        hdf5_file.attrs['affine'] = affine
        hdf5_file.attrs['dimensions'] = dimensions
        hdf5_file.attrs['voxel_sizes'] = voxel_sizes
        hdf5_file.attrs['voxel_order'] = voxel_order

        # Each connection is processed independently. Multiprocessing would
        # burden the I/O of most SSDs/HDDs.
        for in_label, out_label in comb_list:
            if iteration_counter > 0 and iteration_counter % 100 == 0:
                logging.info('Split {} nodes out of {}'.format(
                    iteration_counter, len(comb_list)))
            iteration_counter += 1

            pair_info = []
            if in_label not in con_info:
                continue
            elif out_label in con_info[in_label]:
                pair_info.extend(con_info[in_label][out_label])

            if out_label not in con_info:
                continue
            elif in_label in con_info[out_label]:
                pair_info.extend(con_info[out_label][in_label])

            if not len(pair_info):
                continue

            connecting_streamlines = []
            connecting_ids = []
            for connection in pair_info:
                strl_idx = connection['strl_idx']
                curr_streamlines = compute_streamline_segment(
                    sft.streamlines[strl_idx], indices[strl_idx],
                    connection['in_idx'], connection['out_idx'],
                    points_to_idx[strl_idx])
                connecting_streamlines.append(curr_streamlines)
                connecting_ids.append(strl_idx)

            # Each step is processed from the previous 'success'
            #   1. raw         -> length pass/fail
            #   2. length pass -> loops pass/fail
            #   3. loops pass  -> outlier detection pass/fail
            #   4. outlier detection pass -> qb curvature pass/fail
            #   5. qb curvature pass == final connections
            connecting_streamlines = ArraySequence(connecting_streamlines)
            raw_dps = sft.data_per_streamline[connecting_ids]
            raw_sft = StatefulTractogram.from_sft(connecting_streamlines,
                                                  sft,
                                                  data_per_streamline=raw_dps,
                                                  data_per_point={})
            _save_if_needed(raw_sft, hdf5_file, args, 'raw', 'raw', in_label,
                            out_label)

            # Doing all post-processing
            if not args.no_pruning:
                valid_length_ids, invalid_length_ids = _prune_segments(
                    raw_sft.streamlines, args.min_length, args.max_length,
                    vox_sizes[0])

                invalid_length_sft = raw_sft[invalid_length_ids]
                valid_length = connecting_streamlines[valid_length_ids]
                _save_if_needed(invalid_length_sft, hdf5_file, args,
                                'discarded', 'invalid_length', in_label,
                                out_label)
            else:
                valid_length = connecting_streamlines
                valid_length_ids = range(len(connecting_streamlines))

            if not len(valid_length):
                continue

            valid_length_sft = raw_sft[valid_length_ids]
            _save_if_needed(valid_length_sft, hdf5_file, args, 'intermediate',
                            'valid_length', in_label, out_label)

            if not args.no_remove_loops:
                no_loop_ids = remove_loops_and_sharp_turns(
                    valid_length, args.loop_max_angle)
                loop_ids = np.setdiff1d(np.arange(len(valid_length)),
                                        no_loop_ids)

                loops_sft = valid_length_sft[loop_ids]
                no_loops = valid_length[no_loop_ids]
                _save_if_needed(loops_sft, hdf5_file, args, 'discarded',
                                'loops', in_label, out_label)
            else:
                no_loops = valid_length
                no_loop_ids = range(len(valid_length))

            if not len(no_loops):
                continue
            no_loops_sft = valid_length_sft[no_loop_ids]
            _save_if_needed(no_loops_sft, hdf5_file, args, 'intermediate',
                            'no_loops', in_label, out_label)

            if not args.no_remove_outliers:
                outliers_ids, inliers_ids = remove_outliers(
                    no_loops,
                    args.outlier_threshold,
                    nb_samplings=10,
                    fast_approx=True)

                outliers_sft = no_loops_sft[outliers_ids]
                inliers = no_loops[inliers_ids]
                _save_if_needed(outliers_sft, hdf5_file, args, 'discarded',
                                'outliers', in_label, out_label)
            else:
                inliers = no_loops
                inliers_ids = range(len(no_loops))

            if not len(inliers):
                continue

            inliers_sft = no_loops_sft[inliers_ids]
            _save_if_needed(inliers_sft, hdf5_file, args, 'intermediate',
                            'inliers', in_label, out_label)

            if not args.no_remove_curv_dev:
                no_qb_curv_ids = remove_loops_and_sharp_turns(
                    inliers,
                    args.loop_max_angle,
                    use_qb=True,
                    qb_threshold=args.curv_qb_distance)
                qb_curv_ids = np.setdiff1d(np.arange(len(inliers)),
                                           no_qb_curv_ids)

                qb_curv_sft = inliers_sft[qb_curv_ids]
                _save_if_needed(qb_curv_sft, hdf5_file, args, 'discarded',
                                'qb_curv', in_label, out_label)
            else:
                no_qb_curv_ids = range(len(inliers))

            no_qb_curv_sft = inliers_sft[no_qb_curv_ids]
            _save_if_needed(no_qb_curv_sft, hdf5_file, args, 'final', 'final',
                            in_label, out_label)

    time2 = time.time()
    logging.info(
        '    Connections post-processing and saving took {} sec.'.format(
            round(time2 - time1, 2)))
Example #15
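Concatenates a list of StatefulTractogram objects, optionally erasing or zero-filling metadata, with compatibility checks.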
def concatenate_sft(sft_list, erase_metadata=False, metadata_fake_init=False):
    """ Concatenate a list of StatefulTractogram together """
    if erase_metadata:
        sft_list[0].data_per_point = {}
        sft_list[0].data_per_streamline = {}

    for sft in sft_list[1:]:
        if erase_metadata:
            sft.data_per_point = {}
            sft.data_per_streamline = {}
        elif metadata_fake_init:
            for dps_key in list(sft.data_per_streamline.keys()):
                if dps_key not in sft_list[0].data_per_streamline.keys():
                    del sft.data_per_streamline[dps_key]
            for dpp_key in list(sft.data_per_point.keys()):
                if dpp_key not in sft_list[0].data_per_point.keys():
                    del sft.data_per_point[dpp_key]

            for dps_key in sft_list[0].data_per_streamline.keys():
                if dps_key not in sft.data_per_streamline:
                    arr_shape = list(
                        sft_list[0].data_per_streamline[dps_key].shape)
                    arr_shape[0] = len(sft)
                    sft.data_per_streamline[dps_key] = np.zeros(arr_shape)
            for dpp_key in sft_list[0].data_per_point.keys():
                if dpp_key not in sft.data_per_point:
                    arr_seq = ArraySequence()
                    arr_seq_shape = list(
                        sft_list[0].data_per_point[dpp_key]._data.shape)
                    arr_seq_shape[0] = len(sft.streamlines._data)
                    arr_seq._data = np.zeros(arr_seq_shape)
                    arr_seq._offsets = sft.streamlines._offsets
                    arr_seq._lengths = sft.streamlines._lengths
                    sft.data_per_point[dpp_key] = arr_seq

        if not metadata_fake_init and \
                not StatefulTractogram.are_compatible(sft, sft_list[0]):
            raise ValueError('Incompatible SFT, check space attributes and '
                             'data_per_point/streamlines.')
        elif not is_header_compatible(sft, sft_list[0]):
            raise ValueError('Incompatible SFT, check space attributes.')

    total_streamlines = 0
    total_points = 0
    lengths = []
    for sft in sft_list:
        total_streamlines += len(sft.streamlines._offsets)
        total_points += len(sft.streamlines._data)
        lengths.extend(sft.streamlines._lengths)
    lengths = np.array(lengths, dtype=np.uint32)
    offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))).astype(np.uint64)

    dpp = {}
    for dpp_key in sft_list[0].data_per_point.keys():
        arr_seq_shape = list(sft_list[0].data_per_point[dpp_key]._data.shape)
        arr_seq_shape[0] = total_points
        dpp[dpp_key] = ArraySequence()
        dpp[dpp_key]._data = np.zeros(arr_seq_shape)
        dpp[dpp_key]._lengths = lengths
        dpp[dpp_key]._offsets = offsets

    dps = {}
    for dps_key in sft_list[0].data_per_streamline.keys():
        arr_seq_shape = list(sft_list[0].data_per_streamline[dps_key].shape)
        arr_seq_shape[0] = total_streamlines
        dps[dps_key] = np.zeros(arr_seq_shape)

    streamlines = ArraySequence()
    streamlines._data = np.zeros((total_points, 3))
    streamlines._lengths = lengths
    streamlines._offsets = offsets

    pts_counter = 0
    strs_counter = 0
    for sft in sft_list:
        pts_curr_len = len(sft.streamlines._data)
        strs_curr_len = len(sft.streamlines._offsets)

        if strs_curr_len == 0 or pts_curr_len == 0:
            continue

        streamlines._data[pts_counter:pts_counter+pts_curr_len] = \
            sft.streamlines._data

        for dpp_key in sft_list[0].data_per_point.keys():
            dpp[dpp_key]._data[pts_counter:pts_counter+pts_curr_len] = \
                sft.data_per_point[dpp_key]._data
        for dps_key in sft_list[0].data_per_streamline.keys():
            dps[dps_key][strs_counter:strs_counter+strs_curr_len] = \
                sft.data_per_streamline[dps_key]
        pts_counter += pts_curr_len
        strs_counter += strs_curr_len

    fused_sft = StatefulTractogram.from_sft(streamlines,
                                            sft_list[0],
                                            data_per_point=dpp,
                                            data_per_streamline=dps)
    return fused_sft
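A minimal usage sketch for concatenate_sft, assuming two .trk files ('a.trk' and 'b.trk' are placeholder names) loaded with dipy's standard loader:

from dipy.io.streamline import load_tractogram

# 'same' uses each file's own header as the spatial reference
# (valid for .trk files).
sft_list = [load_tractogram(f, 'same') for f in ('a.trk', 'b.trk')]
fused = concatenate_sft(sft_list, metadata_fake_init=True)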
Example #16
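Applies a streamline operation (plain or lazy concatenation, or a set operation such as union/intersection) across several tractograms, optionally saving per-file indices.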
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, args.in_tractograms)
    assert_outputs_exist(parser, args, args.out_tractogram,
                         optional=args.save_indices)

    if args.operation == 'lazy_concatenate':
        logging.info('Using lazy_concatenate; no spatial or metadata-related '
                     'checks are performed.\nMetadata will be lost; only '
                     'trk/tck files are supported.')

        def list_generator_from_nib(filenames):
            for in_file in filenames:
                tractogram_file = nib.streamlines.load(in_file, lazy_load=True)
                for s in tractogram_file.streamlines:
                    yield s
        header = None
        for in_file in args.in_tractograms:
            _, ext = os.path.splitext(in_file)
            if ext == '.trk':
                if header is None:
                    header = nib.streamlines.load(
                        in_file, lazy_load=True).header
                elif not is_header_compatible(header, in_file):
                    logging.warning('Incompatible headers in the list.')

        generator = list_generator_from_nib(args.in_tractograms)
        out_tractogram = LazyTractogram(lambda: generator,
                                        affine_to_rasmm=np.eye(4))
        nib.streamlines.save(out_tractogram, args.out_tractogram,
                             header=header)
        return

    # Load all input streamlines.
    sft_list = []
    for f in args.in_tractograms:
        sft_list.append(load_tractogram_with_reference(
            parser, args, f, bbox_check=not args.ignore_invalid))

    # Apply the requested operation to each input file.
    logging.info('Performing operation \'{}\'.'.format(args.operation))
    new_sft = concatenate_sft(sft_list, args.no_metadata, args.fake_metadata)
    if args.operation == 'concatenate':
        indices = np.arange(len(new_sft), dtype=np.uint32)
    else:
        streamlines_list = [sft.streamlines for sft in sft_list]
        op_name = args.operation
        if args.robust:
            op_name += '_robust'
            _, indices = OPERATIONS[op_name](streamlines_list,
                                             precision=args.precision)
        else:
            _, indices = perform_streamlines_operation(
                OPERATIONS[op_name], streamlines_list,
                precision=args.precision)

    # Save the indices to a file if requested.
    if args.save_indices:
        start = 0
        out_dict = {}
        streamlines_len_cumsum = [len(sft) for sft in sft_list]
        for name, nb in zip(args.in_tractograms, streamlines_len_cumsum):
            end = start + nb
            # Convert to plain int for json serialization
            out_dict[name] = [int(i - start)
                              for i in indices if start <= i < end]
            start = end

        with open(args.save_indices, 'wt') as f:
            json.dump(out_dict, f,
                      indent=args.indent,
                      sort_keys=args.sort_keys)

    # Save the new streamlines (and metadata)
    logging.info('Saving {} streamlines to {}.'.format(len(indices),
                                                       args.out_tractogram))
    save_tractogram(new_sft[indices], args.out_tractogram,
                    bbox_valid_check=not args.ignore_invalid)
Example #17
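Renders a mosaic image of bundles over a background volume, first checking each bundle's header against the volume.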
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_volume)
    assert_outputs_exist(parser, args, args.out_image)

    output_names = [
        'axial_superior', 'axial_inferior', 'coronal_posterior',
        'coronal_anterior', 'sagittal_left', 'sagittal_right'
    ]

    for filename in args.in_bundles:
        _, ext = os.path.splitext(filename)
        if ext == '.tck':
            tractogram = load_tractogram_with_reference(parser, args, filename)
        else:
            tractogram = filename
        if not is_header_compatible(args.in_volume, tractogram):
            parser.error('{} does not have a compatible header with {}'.format(
                filename, args.in_volume))
        # Delete the temporary tractogram
        del tractogram

    output_dir = os.path.dirname(args.out_image)
    if output_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           output_dir,
                                           create_dir=True)

    _, extension = os.path.splitext(args.out_image)

    # ----------------------------------------------------------------------- #
    # Mosaic, column 0: orientation names and data description
    # ----------------------------------------------------------------------- #
    width = args.resolution_of_thumbnails
    height = args.resolution_of_thumbnails
    rows = 6
    cols = len(args.in_bundles)
    text_pos_x = 50
    text_pos_y = 50

    # Creates a new empty image, RGB mode
    mosaic = Image.new('RGB', ((cols + 1) * width, (rows + 1) * height))

    # Prepare draw and font objects to render text
    draw = ImageDraw.Draw(mosaic)
    font = get_font(args)

    # Data of the volume used as background
    ref_img = nib.load(args.in_volume)
    data = ref_img.get_fdata(dtype=np.float32)
    affine = ref_img.affine
    mean, std = data[data > 0].mean(), data[data > 0].std()
    value_range = (mean - 0.5 * std, mean + 1.5 * std)

    # First column with rows description
    draw_column_with_names(draw, output_names, text_pos_x, text_pos_y, height,
                           font)

    # ----------------------------------------------------------------------- #
    # Columns with bundles
    # ----------------------------------------------------------------------- #
    random.seed(args.random_coloring)
    for idx_bundle, bundle_file in enumerate(args.in_bundles):

        bundle_file_name = os.path.basename(bundle_file)
        bundle_name, bundle_ext = split_name_with_nii(bundle_file_name)

        i = (idx_bundle + 1) * width

        if not os.path.isfile(bundle_file):
            print('\nInput file {} doesn\'t exist.'.format(bundle_file))

            number_streamlines = 0

            view_number = 6
            j = height * view_number

            draw_bundle_information(draw, bundle_file_name, number_streamlines,
                                    i + text_pos_x, j + text_pos_y, font)

        else:
            if args.uniform_coloring:
                colors = args.uniform_coloring
            elif args.random_coloring is not None:
                colors = random_rgb()
            # Select the streamlines to plot
            if bundle_ext in ['.tck', '.trk']:
                if (args.random_coloring is None
                        and args.uniform_coloring is None):
                    colors = None
                bundle_tractogram_file = nib.streamlines.load(bundle_file)
                streamlines = bundle_tractogram_file.streamlines
                bundle_actor = actor.line(streamlines, colors)
                nbr_of_elem = len(streamlines)
            # Select the volume to plot
            elif bundle_ext in ['.nii.gz', '.nii']:
                if not args.random_coloring and not args.uniform_coloring:
                    colors = [1.0, 1.0, 1.0]
                bundle_img_file = nib.load(bundle_file)
                roi = get_data_as_mask(bundle_img_file)
                bundle_actor = actor.contour_from_roi(roi,
                                                      bundle_img_file.affine,
                                                      colors)
                nbr_of_elem = np.count_nonzero(roi)

            # Render
            ren = window.Scene()
            zoom = args.zoom
            opacity = args.opacity_background

            # Structural data
            slice_actor = actor.slicer(data, affine, value_range)
            slice_actor.opacity(opacity)
            ren.add(slice_actor)

            # Streamlines
            ren.add(bundle_actor)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 0
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.pitch(180)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 1
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.rm(slice_actor)
            slice_actor2 = slice_actor.copy()
            slice_actor2.display(None, slice_actor2.shape[1] // 2, None)
            slice_actor2.opacity(opacity)
            ren.add(slice_actor2)

            ren.pitch(90)
            ren.set_camera(view_up=(0, 0, 1))
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 2
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.pitch(180)
            ren.set_camera(view_up=(0, 0, 1))
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 3
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.rm(slice_actor2)
            slice_actor3 = slice_actor.copy()
            slice_actor3.display(slice_actor3.shape[0] // 2, None, None)
            slice_actor3.opacity(opacity)
            ren.add(slice_actor3)

            ren.yaw(90)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 4
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            ren.yaw(180)
            ren.reset_camera()
            ren.zoom(zoom)
            view_number = 5
            set_img_in_cell(mosaic, ren, view_number, width, height, i)

            view_number = 6
            j = height * view_number
            draw_bundle_information(draw, bundle_file_name, nbr_of_elem,
                                    i + text_pos_x, j + text_pos_y, font)

    # Save image to file
    mosaic.save(args.out_image)
Example #18
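An earlier variant of the bundle-fusion script shown in Example #3, based on perform_streamlines_operation.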
def main():
    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    output_streamlines_filename = '{}streamlines.trk'.format(
        args.output_prefix)
    output_voxels_filename = '{}voxels.nii.gz'.format(args.output_prefix)
    assert_outputs_exist(parser, args,
                         [output_voxels_filename, output_streamlines_filename])

    if not 0 <= args.ratio_voxels <= 1 or not 0 <= args.ratio_streamlines <= 1:
        parser.error('Ratios must be between 0 and 1.')

    fusion_streamlines = []
    for name in args.in_bundles:
        fusion_streamlines.extend(
            load_tractogram_with_reference(parser, args, name).streamlines)

    fusion_streamlines, _ = perform_streamlines_operation(
        union, [fusion_streamlines], 0)
    fusion_streamlines = ArraySequence(fusion_streamlines)
    if args.reference:
        reference_file = args.reference
    else:
        reference_file = args.in_bundles[0]

    transformation, dimensions, _, _ = get_reference_info(reference_file)
    volume = np.zeros(dimensions)
    streamlines_vote = dok_matrix(
        (len(fusion_streamlines), len(args.in_bundles)))

    for i, name in enumerate(args.in_bundles):
        if not is_header_compatible(reference_file, name):
            raise ValueError('Headers are not compatible.')
        sft = load_tractogram_with_reference(parser, args, name)
        bundle = sft.get_streamlines_copy()
        sft.to_vox()
        bundle_vox_space = sft.get_streamlines_copy()
        binary = compute_tract_counts_map(bundle_vox_space, dimensions)
        volume[binary > 0] += 1

        if args.same_tractogram:
            _, indices = perform_streamlines_operation(
                intersection, [fusion_streamlines, bundle], 0)
            streamlines_vote[list(indices), i] += 1

    if args.same_tractogram:
        real_indices = []
        for i in range(len(fusion_streamlines)):
            ratio_value = int(args.ratio_streamlines * len(args.in_bundles))
            if np.sum(streamlines_vote[i]) >= ratio_value:
                real_indices.append(i)

        new_streamlines = fusion_streamlines[real_indices]

        sft = StatefulTractogram(new_streamlines, reference_file, Space.RASMM)
        save_tractogram(sft, output_streamlines_filename)

    volume[volume < int(args.ratio_voxels * len(args.in_bundles))] = 0
    volume[volume > 0] = 1
    nib.save(nib.Nifti1Image(volume.astype(np.uint8), transformation),
             output_voxels_filename)
Example #19
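Filters a tractogram through a sequence of ROIs (drawn masks, atlas labels, planes, .bdo shapes), rebuilding the StatefulTractogram after each filter.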
def main():
    parser = _buildArgsParser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_tractogram])
    assert_outputs_exists(parser, args, [args.out_tractogram])
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    roi_opt_list = prepare_filtering_list(parser, args)

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    for i, roi_opt in enumerate(roi_opt_list):
        # Atlas needs an extra argument (value in the LUT)
        if roi_opt[0] == 'atlas_roi':
            filter_type, filter_arg_1, filter_arg_2, \
                filter_mode, filter_criteria = roi_opt
        else:
            filter_type, filter_arg, filter_mode, filter_criteria = roi_opt
        is_not = filter_criteria != 'include'

        if filter_type == 'drawn_roi':
            img = nib.load(filter_arg)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')

            mask = img.get_fdata()
            filtered_streamlines, indexes = filter_grid_roi(
                sft, mask, filter_mode, is_not)

        elif filter_type == 'atlas_roi':
            img = nib.load(filter_arg_1)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')

            atlas = img.get_fdata().astype(np.uint16)
            mask = np.zeros(atlas.shape, dtype=np.uint16)
            mask[atlas == int(filter_arg_2)] = 1

            filtered_streamlines, indexes = filter_grid_roi(
                sft, mask, filter_mode, is_not)

        # For every case, the input number must be greater than or equal to 0
        # and below the dimension, since this is a voxel-space operation
        elif filter_type in ['x_plane', 'y_plane', 'z_plane']:
            filter_arg = int(filter_arg)
            _, dim, _, _ = sft.space_attributes
            mask = np.zeros(dim, dtype=np.int16)
            error_msg = None
            if filter_type == 'x_plane':
                if 0 <= filter_arg < dim[0]:
                    mask[filter_arg, :, :] = 1
                else:
                    error_msg = 'X plane ' + str(filter_arg)

            elif filter_type == 'y_plane':
                if 0 <= filter_arg < dim[1]:
                    mask[:, filter_arg, :] = 1
                else:
                    error_msg = 'Y plane ' + str(filter_arg)

            elif filter_type == 'z_plane':
                if 0 <= filter_arg < dim[2]:
                    mask[:, :, filter_arg] = 1
                else:
                    error_msg = 'Z plane ' + str(filter_arg)

            if error_msg:
                parser.error('{} is not valid according to the '
                             'tractogram header.'.format(error_msg))

            filtered_streamlines, indexes = filter_grid_roi(
                sft, mask, filter_mode, is_not)

        elif filter_type == 'bdo':
            geometry, radius, center = read_info_from_mb_bdo(filter_arg)
            if geometry == 'Ellipsoid':
                filtered_streamlines, indexes = filter_ellipsoid(
                    sft, radius, center, filter_mode, is_not)
            elif geometry == 'Cuboid':
                filtered_streamlines, indexes = filter_cuboid(
                    sft, radius, center, filter_mode, is_not)

        logging.debug('The filtering options {0} resulted in '
                      '{1} streamlines'.format(roi_opt,
                                               len(filtered_streamlines)))

        data_per_streamline = sft.data_per_streamline[indexes]
        data_per_point = sft.data_per_point[indexes]

        sft = StatefulTractogram(filtered_streamlines,
                                 sft,
                                 Space.RASMM,
                                 data_per_streamline=data_per_streamline,
                                 data_per_point=data_per_point)

    if not filtered_streamlines:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline)".format(
                args.out_tractogram))

            return

        logging.debug('The file {} contains 0 streamlines'.format(
            args.out_tractogram))

    save_tractogram(sft, args.out_tractogram)
Exemplo n.º 20
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_bundle] + args.in_metrics,
                        optional=args.in_centroid)

    if args.nb_pts_per_streamline <= 1:
        parser.error('--nb_pts_per_streamline {} needs to be greater than '
                     '1'.format(args.nb_pts_per_streamline))

    assert_same_resolution(args.in_metrics + [args.in_bundle])
    sft = load_tractogram_with_reference(parser, args, args.in_bundle)

    metrics = [nib.load(m) for m in args.in_metrics]

    bundle_name, _ = os.path.splitext(os.path.basename(args.in_bundle))
    stats = {}
    if len(sft) == 0:
        stats[bundle_name] = None
        print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
        return

    # Centroid - will be used as a reference to reorient each streamline
    if args.in_centroid:
        if not is_header_compatible(args.in_bundle, args.in_centroid):
            parser.error('Headers of the bundle and the centroid are not '
                         'compatible.')
        sft_centroid = load_tractogram_with_reference(parser, args,
                                                      args.in_centroid)
        centroid_streamlines = sft_centroid.streamlines[0]
        nb_pts_per_streamline = len(centroid_streamlines)
    else:
        centroid_streamlines = get_streamlines_centroid(
            sft.streamlines, args.nb_pts_per_streamline)
        nb_pts_per_streamline = args.nb_pts_per_streamline

    resampled_sft = resample_streamlines_num_points(sft, nb_pts_per_streamline)

    # Make sure all streamlines go in the same direction. Point #k of
    # streamline A must be matched with point #k of streamline B, for
    # every k up to nb_pts_per_streamline
    num_streamlines = len(resampled_sft)

    for s in np.arange(num_streamlines):
        streamline = resampled_sft.streamlines[s]
        direct = average_euclidean(centroid_streamlines, streamline)
        flipped = average_euclidean(centroid_streamlines, streamline[::-1])

        if flipped < direct:
            resampled_sft.streamlines[s] = streamline[::-1]

    profiles = get_bundle_metrics_profiles(resampled_sft, metrics)
    t_profiles = np.expand_dims(profiles, axis=1)
    t_profiles = np.rollaxis(t_profiles, 3, 2)

    stats[bundle_name] = {}
    for metric, profile, t_profile in zip(metrics, profiles, t_profiles):
        metric_name, _ = split_name_with_nii(
            os.path.basename(metric.get_filename()))
        stats[bundle_name][metric_name] = {
            'mean': np.mean(profile, axis=0).tolist(),
            'std': np.std(profile, axis=0).tolist(),
            'bundleprofile': t_profile.tolist()
        }

    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
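The reorientation step above compares each streamline to the centroid in both
directions and keeps the closer one. A minimal numpy sketch of the same idea,
with average_euclidean replaced by an explicit mean point-wise distance and
synthetic points (names and values are illustrative only):

import numpy as np

def mean_pointwise_distance(a, b):
    # Mean Euclidean distance between matched points of two streamlines
    return np.mean(np.linalg.norm(a - b, axis=1))

centroid = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
streamline = np.array([[2., 1., 0.], [1., 1., 0.], [0., 1., 0.]])

direct = mean_pointwise_distance(centroid, streamline)
flipped = mean_pointwise_distance(centroid, streamline[::-1])
if flipped < direct:
    streamline = streamline[::-1]  # flip so both run in the same direction
print(streamline)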
Exemplo n.º 21
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_outputs_exist(parser, args, args.out_image)

    # Binary operations require specific verifications
    binary_op = [
        'union', 'intersection', 'difference', 'invert', 'dilation', 'erosion',
        'closing', 'opening'
    ]

    if args.operation not in OPERATIONS:
        parser.error('Operation {} is not implemented.'.format(args.operation))

    # Find at least one image for reference
    for input_arg in args.in_images:
        if not is_float(input_arg):
            ref_img = nib.load(input_arg)
            mask = np.zeros(ref_img.shape)
            break

    # Load all input masks.
    input_data = []
    for input_arg in args.in_images:
        if not is_float(input_arg) and \
                not is_header_compatible(ref_img, input_arg):
            parser.error('Inputs do not have a compatible header.')
        data = load_data(input_arg)

        if isinstance(data, np.ndarray) and \
            data.dtype != ref_img.get_data_dtype() and \
                not args.data_type:
            parser.error('Inputs do not have a compatible data type.\n'
                         'Use --data_type to specify output datatype.')
        if args.operation in binary_op and isinstance(data, np.ndarray):
            unique = np.unique(data)
            if len(unique) > 2:
                parser.error('Binary operations can only be performed with '
                             'binary masks.')

            if len(unique) == 2 and not (unique == [0, 1]).all():
                logging.warning('Input data for binary operations is not a '
                                'binary array; it will be converted. '
                                'Non-zero values will be set to one.')
                data[data != 0] = 1

        if isinstance(data, np.ndarray):
            data = data.astype(np.float64)
            mask[data > 0] = 1
        input_data.append(data)

    if args.operation == 'convert' and not args.data_type:
        parser.error('Convert operation must be used with --data_type.')

    try:
        output_data = OPERATIONS[args.operation](input_data)
    except ValueError:
        logging.error('{} operation failed.'.format(
            args.operation.capitalize()))
        return

    if args.data_type:
        output_data = output_data.astype(args.data_type)
        ref_img.header.set_data_dtype(args.data_type)
    else:
        output_data = output_data.astype(ref_img.get_data_dtype())

    if args.exclude_background:
        output_data[mask == 0] = 0

    new_img = nib.Nifti1Image(output_data,
                              ref_img.affine,
                              header=ref_img.header)
    nib.save(new_img, args.out_image)
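The OPERATIONS table itself is not part of this snippet; a minimal sketch of
the dispatch pattern it implies, with hypothetical union and intersection
entries operating on numpy arrays:

import numpy as np

# Hypothetical dispatch table mirroring the OPERATIONS lookup above
OPERATIONS = {
    'union': lambda inputs: np.clip(np.sum(inputs, axis=0), 0, 1),
    'intersection': lambda inputs: np.prod(inputs, axis=0),
}

masks = [np.array([1, 1, 0, 0]), np.array([0, 1, 1, 0])]
print(OPERATIONS['union'](masks))         # [1 1 1 0]
print(OPERATIONS['intersection'](masks))  # [0 1 0 0]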
Exemplo n.º 22
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    # The number of label maps must equal the number of bundles; the second
    # half of args.in_bundles actually holds the label maps, so split the
    # combined list back in two
    tmp = args.in_bundles + args.in_labels
    args.in_labels = args.in_bundles[(len(tmp) // 2):] + args.in_labels
    args.in_bundles = args.in_bundles[0:len(tmp) // 2]
    assert_inputs_exist(parser, args.in_bundles + args.in_labels)
    assert_output_dirs_exist_and_empty(parser,
                                       args, [],
                                       optional=args.save_rendering)

    stats = {}
    num_digits_labels = 3
    scene = window.Scene()
    scene.background(tuple(map(int, args.background)))
    for i, filename in enumerate(args.in_bundles):
        sft = load_tractogram_with_reference(parser, args, filename)
        sft.to_vox()
        sft.to_corner()
        img_labels = nib.load(args.in_labels[i])

        # Same subject or coregistered subjects: all headers must be identical
        if not is_header_compatible(sft, args.in_bundles[0]) \
                or not is_header_compatible(img_labels, args.in_bundles[0]):
            parser.error('All headers must be identical.')

        data_labels = img_labels.get_fdata()
        bundle_name, _ = os.path.splitext(os.path.basename(filename))
        unique_labels = np.unique(data_labels)[1:].astype(int)

        # Empty bundle should at least return a json
        if not len(sft):
            tmp_dict = {}
            for label in unique_labels:
                tmp_dict['{}'.format(label).zfill(num_digits_labels)] \
                    = {'mean': 0.0, 'std': 0.0}
            stats[bundle_name] = {'diameter': tmp_dict}
            continue

        counter = 0
        labels_dict = {label: ([], []) for label in unique_labels}
        pts_labels = map_coordinates(data_labels,
                                     sft.streamlines._data.T - 0.5,
                                     order=0)
        # For each label, all positions and directions are needed to get
        # a tube estimation per label.
        for streamline in sft.streamlines:
            direction = np.gradient(streamline, axis=0).tolist()
            curr_labels = pts_labels[counter:counter +
                                     len(streamline)].tolist()

            for j, label in enumerate(curr_labels):
                if label > 0:
                    labels_dict[label][0].append(streamline[j])
                    labels_dict[label][1].append(direction[j])

            counter += len(streamline)

        centroid = np.zeros((len(unique_labels), 3))
        radius = np.zeros((len(unique_labels), 1))
        error = np.zeros((len(unique_labels), 1))
        for key in unique_labels:
            key = int(key)
            c, d, e = fit_circle_in_space(labels_dict[key][0],
                                          labels_dict[key][1],
                                          args.fitting_func)
            centroid[key - 1], radius[key - 1], error[key - 1] = c, d, e

        # Spatial smoothing to avoid degenerate estimation
        centroid_smooth = gaussian_filter(centroid,
                                          sigma=[1, 0],
                                          mode='nearest')
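        # The step len(centroid) - 1 keeps only the first and last rows, so
        # the unsmoothed endpoints of the centroid are restored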
        centroid_smooth[::len(centroid) - 1] = centroid[::len(centroid) - 1]
        radius = gaussian_filter(radius, sigma=1, mode='nearest')
        error = gaussian_filter(error, sigma=1, mode='nearest')

        tmp_dict = {}
        for label in unique_labels:
            tmp_dict['{}'.format(label).zfill(num_digits_labels)] \
                = {'mean': float(radius[label-1])*2,
                   'std': float(error[label-1])}
        stats[bundle_name] = {'diameter': tmp_dict}

        if args.show_rendering or args.save_rendering:
            tube_actor = create_tube_with_radii(
                centroid_smooth,
                radius,
                error,
                wireframe=args.wireframe,
                error_coloring=args.error_coloring)
            scene.add(tube_actor)
            cmap = plt.get_cmap('jet')
            coloring = cmap(pts_labels / np.max(pts_labels))[:, 0:3]
            streamlines_actor = actor.streamtube(sft.streamlines,
                                                 linewidth=args.width,
                                                 opacity=args.opacity,
                                                 colors=coloring)
            scene.add(streamlines_actor)

            slice_actor = actor.slicer(data_labels, np.eye(4))
            slice_actor.opacity(0.0)
            scene.add(slice_actor)

    # If there are actually streamlines to display
    if args.show_rendering:
        showm = window.ShowManager(scene, reset_camera=True)
        showm.initialize()
        showm.start()
    elif args.save_rendering:
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'superior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.pitch(180)
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'inferior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.pitch(90)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'posterior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.pitch(180)
        scene.set_camera(view_up=(0, 0, 1))
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'anterior.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.yaw(90)
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'right.png'),
                 size=(1920, 1080),
                 offscreen=True)

        scene.yaw(180)
        scene.reset_camera()
        snapshot(scene,
                 os.path.join(args.save_rendering, 'left.png'),
                 size=(1920, 1080),
                 offscreen=True)
    print(json.dumps(stats, indent=args.indent, sort_keys=args.sort_keys))
Exemplo n.º 23
0
def _processing_wrapper(args):
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]
    min_lesion_vol = args[7]

    hdf5_file = h5py.File(hdf5_filename, 'r')
    key = '{}_{}'.format(in_label, out_label)
    if key not in hdf5_file:
        return
    streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)
    if len(streamlines) == 0:
        return

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
            and np.array_equal(hdf5_file.attrs['dimensions'], dimensions)):
        raise ValueError('The provided hdf5 file and labels image have '
                         'incompatible headers.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        streamlines_copy = list(streamlines)
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(streamlines_copy))*voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute or
              'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute and
              'streamline_count' in measures_to_compute))):

        density = compute_tract_counts_map(streamlines,
                                           dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        density_sim = load_node_nifti(similarity_directory,
                                      in_label, out_label,
                                      labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    for measure in measures_to_compute:
        # Maps
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname,
                                       in_label, out_label,
                                       labels_img)
            measures_to_return[map_dirname] = np.average(
                map_data[map_data > 0])
        elif isinstance(measure, tuple):
            if not isinstance(measure[0], tuple) \
                    and os.path.isfile(measure[0]):
                metric_filename = measure[0]
                metric_img = measure[1]
                if not is_header_compatible(metric_img, labels_img):
                    logging.error('{} does not have a compatible '
                                  'header'.format(metric_filename))
                    raise IOError

                metric_data = metric_img.get_fdata(dtype=np.float64)
                if weighted:
                    avg_value = np.average(metric_data, weights=density)
                else:
                    avg_value = np.average(metric_data[density > 0])
                measures_to_return[metric_filename] = avg_value
            # lesion
            else:
                lesion_filename = measure[0][0]
                computed_lesion_labels = measure[0][1]
                lesion_img = measure[1]
                if not is_header_compatible(lesion_img, labels_img):
                    logging.error('{} does not have a compatible '
                                  'header'.format(lesion_filename))
                    raise IOError

                voxel_sizes = lesion_img.header.get_zooms()[0:3]
                lesion_img.set_filename('tmp.nii.gz')
                lesion_atlas = get_data_as_label(lesion_img)
                tmp_dict = compute_lesion_stats(
                    density.astype(bool), lesion_atlas,
                    voxel_sizes=voxel_sizes, single_label=True,
                    min_lesion_vol=min_lesion_vol,
                    precomputed_lesion_labels=computed_lesion_labels)

                tmp_ind = _streamlines_in_mask(list(streamlines),
                                               lesion_atlas.astype(np.uint8),
                                               np.eye(3), [0, 0, 0])
                streamlines_count = len(np.where(tmp_ind == 1)[0])

                if tmp_dict:
                    measures_to_return[lesion_filename+'vol'] = \
                        tmp_dict['lesion_total_volume']
                    measures_to_return[lesion_filename+'count'] = \
                        tmp_dict['lesion_count']
                    measures_to_return[lesion_filename+'sc'] = \
                        streamlines_count
                else:
                    measures_to_return[lesion_filename+'vol'] = 0
                    measures_to_return[lesion_filename+'count'] = 0
                    measures_to_return[lesion_filename+'sc'] = 0

    if include_dps:
        for dps_key in hdf5_file[key].keys():
            if dps_key not in ['data', 'offsets', 'lengths']:
                out_file = os.path.join(include_dps, dps_key)
                if 'commit' in dps_key:
                    measures_to_return[out_file] = np.sum(
                        hdf5_file[key][dps_key])
                else:
                    measures_to_return[out_file] = np.average(
                        hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
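The density weighting above is just a weighted mean of the metric, using the
streamline count per voxel as the weight. A minimal sketch with synthetic
arrays, contrasting it with the unweighted branch:

import numpy as np

metric = np.array([1.0, 2.0, 3.0, 4.0])
density = np.array([0, 1, 3, 0])  # streamline count per voxel

weighted = np.average(metric, weights=density)  # (1*2.0 + 3*3.0) / 4 = 2.75
unweighted = np.average(metric[density > 0])    # (2.0 + 3.0) / 2 = 2.5
print(weighted, unweighted)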
Exemplo n.º 24
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    image_files = []
    indices_per_volume = []
    # Separate argument per volume
    used_indices_all = False
    for v_args in args.volume_ids:
        if len(v_args) < 2:
            logging.error("No indices were given for a volume")

        image_files.append(v_args[0])
        if "all" in v_args:
            used_indices_all = True
            indices_per_volume.append("all")
        else:
            indices_per_volume.append(np.asarray(v_args[1:], dtype=int))

    if used_indices_all and args.out_labels_ids:
        logging.error("'all' indices cannot be used with 'out_labels_ids'")

    # Check inputs / output
    assert_inputs_exist(parser, image_files)
    assert_outputs_exist(parser, args, args.output)

    # Load volume and do checks
    data_list = []
    first_img = nib.load(image_files[0])
    for i in range(len(image_files)):
        # Load images
        volume_nib = nib.load(image_files[i])
        data = np.round(np.asanyarray(volume_nib.dataobj)).astype(int)
        data_list.append(data)
        assert is_header_compatible(first_img, image_files[i])

        if (isinstance(indices_per_volume[i], str)
                and indices_per_volume[i] == "all"):
            indices_per_volume[i] = np.unique(data)

    filtered_ids_per_vol = []
    # Remove background labels
    for id_list in indices_per_volume:
        id_list = np.asarray(id_list)
        new_ids = id_list[~np.in1d(id_list, args.background)]
        filtered_ids_per_vol.append(new_ids)
    # Prepare output indices
    if args.out_labels_ids:
        out_labels = args.out_labels_ids
        if len(out_labels) != len(np.hstack(indices_per_volume)):
            logging.error("--out_labels_ids, requires the same amount"
                          " of total given input indices")
    elif args.unique:
        stack = np.hstack(filtered_ids_per_vol)
        ids = np.arange(len(stack) + 1)
        out_labels = np.setdiff1d(ids, args.background)[:len(stack)]
    elif args.group_in_m:
        m_list = []
        for i in range(len(filtered_ids_per_vol)):
            prefix = i * 1000000
            m_list.append(prefix + np.asarray(filtered_ids_per_vol[i]))
        out_labels = np.hstack(m_list)
    else:
        out_labels = np.hstack(filtered_ids_per_vol)

    if len(np.unique(out_labels)) != len(out_labels):
        logging.error("The same output label number was used "
                      "for multiple inputs")

    # Create the resulting volume
    current_id = 0
    resulting_labels = (np.ones_like(data_list[0], dtype=int) *
                        args.background)
    for i in range(len(image_files)):
        # Add Given labels for each volume
        for index in filtered_ids_per_vol[i]:
            mask = data_list[i] == index
            resulting_labels[mask] = out_labels[current_id]
            current_id += 1

            if np.count_nonzero(mask) == 0:
                logging.warning("Label {} was not in the volume".format(index))

    # Save final combined volume
    nii = nib.Nifti1Image(resulting_labels, first_img.affine, first_img.header)
    nib.save(nii, args.output)
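The --group_in_m branch above offsets each volume's labels by a million-based
prefix so labels from different inputs stay disjoint. A minimal sketch of
that relabeling with illustrative values:

import numpy as np

filtered_ids_per_vol = [np.array([2, 5]), np.array([2, 7])]
m_list = []
for i, ids in enumerate(filtered_ids_per_vol):
    m_list.append(i * 1000000 + ids)  # volume 0 keeps 2 and 5
out_labels = np.hstack(m_list)
print(out_labels)  # [      2       5 1000002 1000007]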
Exemplo n.º 25
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, [args.in_hdf5, args.in_labels],
                        args.force_labels_list)

    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    logging.basicConfig(level=log_level)
    coloredlogs.install(level=log_level)

    measures_to_compute = []
    measures_output_filename = []
    if args.volume:
        measures_to_compute.append('volume')
        measures_output_filename.append(args.volume)
    if args.streamline_count:
        measures_to_compute.append('streamline_count')
        measures_output_filename.append(args.streamline_count)
    if args.length:
        measures_to_compute.append('length')
        measures_output_filename.append(args.length)
    if args.similarity:
        measures_to_compute.append('similarity')
        measures_output_filename.append(args.similarity[1])

    dict_maps_out_name = {}
    if args.maps is not None:
        for in_folder, out_name in args.maps:
            measures_to_compute.append(in_folder)
            dict_maps_out_name[in_folder] = out_name
            measures_output_filename.append(out_name)

    dict_metrics_out_name = {}
    if args.metrics is not None:
        for in_name, out_name in args.metrics:
            # Verify that all metrics are compatible with each other
            if not is_header_compatible(args.metrics[0][0], in_name):
                raise IOError('Metrics {} and {} do not share a compatible '
                              'header'.format(args.metrics[0][0], in_name))

            # This is necessary to support more than one map for weighting
            measures_to_compute.append((in_name, nib.load(in_name)))
            dict_metrics_out_name[in_name] = out_name
            measures_output_filename.append(out_name)

    dict_lesion_out_name = {}
    if args.lesion_load is not None:
        in_name = args.lesion_load[0]
        lesion_img = nib.load(in_name)
        lesion_data = get_data_as_mask(lesion_img, dtype=bool)
        lesion_atlas, _ = ndi.label(lesion_data)
        measures_to_compute.append(((in_name, np.unique(lesion_atlas)[1:]),
                                    nib.Nifti1Image(lesion_atlas,
                                                    lesion_img.affine)))

        out_name_1 = os.path.join(args.lesion_load[1], 'lesion_vol.npy')
        out_name_2 = os.path.join(args.lesion_load[1], 'lesion_count.npy')
        out_name_3 = os.path.join(args.lesion_load[1], 'lesion_sc.npy')

        dict_lesion_out_name[in_name+'vol'] = out_name_1
        dict_lesion_out_name[in_name+'count'] = out_name_2
        dict_lesion_out_name[in_name+'sc'] = out_name_3
        measures_output_filename.extend([out_name_1, out_name_2, out_name_3])

    assert_outputs_exist(parser, args, measures_output_filename)
    if not measures_to_compute:
        parser.error('No connectivity measures were selected, nothing '
                     'to compute.')

    logging.info('The following measures will be computed and saved: '
                 '{}'.format(measures_output_filename))

    if args.include_dps:
        if not os.path.isdir(args.include_dps):
            os.makedirs(args.include_dps)
        logging.info('data_per_streamline weighting is activated.')

    img_labels = nib.load(args.in_labels)
    data_labels = get_data_as_label(img_labels)
    if not args.force_labels_list:
        labels_list = np.unique(data_labels)[1:].tolist()
    else:
        labels_list = np.loadtxt(
            args.force_labels_list, dtype=np.int16).tolist()

    comb_list = list(itertools.combinations(labels_list, r=2))
    if not args.no_self_connection:
        comb_list.extend(zip(labels_list, labels_list))

    nbr_cpu = validate_nbr_processes(parser, args)
    measures_dict_list = []
    if nbr_cpu == 1:
        for comb in comb_list:
            measures_dict_list.append(_processing_wrapper([args.in_hdf5,
                                                           img_labels, comb,
                                                           measures_to_compute,
                                                           args.similarity,
                                                           args.density_weighting,
                                                           args.include_dps,
                                                           args.min_lesion_vol]))
    else:
        pool = multiprocessing.Pool(nbr_cpu)
        measures_dict_list = pool.map(_processing_wrapper,
                                      zip(itertools.repeat(args.in_hdf5),
                                          itertools.repeat(img_labels),
                                          comb_list,
                                          itertools.repeat(
                                              measures_to_compute),
                                          itertools.repeat(args.similarity),
                                          itertools.repeat(
                                          args.density_weighting),
                                          itertools.repeat(args.include_dps),
                                          itertools.repeat(args.min_lesion_vol)))
        pool.close()
        pool.join()

    # Removing None entries (combinations that do not exist)
    # Fusing the multiprocessing outputs into a single dictionary
    measures_dict_list = [it for it in measures_dict_list if it is not None]
    if not measures_dict_list:
        raise ValueError('Empty matrix, no entries to save.')
    measures_dict = measures_dict_list[0]
    for dix in measures_dict_list[1:]:
        measures_dict.update(dix)

    if args.no_self_connection:
        total_elem = len(labels_list)**2 - len(labels_list)
        results_elem = len(measures_dict.keys())*2 - len(labels_list)
    else:
        total_elem = len(labels_list)**2
        results_elem = len(measures_dict.keys())*2

    logging.info('Out of {} possible nodes, {} contain values'.format(
        total_elem, results_elem))

    # Filling out all the matrices (symmetric) in the order of labels_list
    nbr_of_measures = len(list(measures_dict.values())[0])
    matrix = np.zeros((len(labels_list), len(labels_list), nbr_of_measures))

    for in_label, out_label in measures_dict:
        curr_node_dict = measures_dict[(in_label, out_label)]
        measures_ordering = list(curr_node_dict.keys())

        for i, measure in enumerate(curr_node_dict):
            in_pos = labels_list.index(in_label)
            out_pos = labels_list.index(out_label)
            matrix[in_pos, out_pos, i] = curr_node_dict[measure]
            matrix[out_pos, in_pos, i] = curr_node_dict[measure]

    # Saving the matrices separately with the specified name or dps
    for i, measure in enumerate(measures_ordering):
        if measure == 'volume':
            matrix_basename = args.volume
        elif measure == 'streamline_count':
            matrix_basename = args.streamline_count
        elif measure == 'length':
            matrix_basename = args.length
        elif measure == 'similarity':
            matrix_basename = args.similarity[1]
        elif measure in dict_metrics_out_name:
            matrix_basename = dict_metrics_out_name[measure]
        elif measure in dict_maps_out_name:
            matrix_basename = dict_maps_out_name[measure]
        elif measure in dict_lesion_out_name:
            matrix_basename = dict_lesion_out_name[measure]
        else:
            matrix_basename = measure

        np.save(matrix_basename, matrix[:, :, i])
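The matrix-filling loop above mirrors each (in_label, out_label) value across
the diagonal. A minimal sketch with a hypothetical single-measure dictionary:

import numpy as np

labels_list = [1, 2, 8]
measures = {(1, 2): 0.5, (2, 8): 1.5, (1, 1): 3.0}

matrix = np.zeros((len(labels_list), len(labels_list)))
for (in_label, out_label), value in measures.items():
    in_pos = labels_list.index(in_label)
    out_pos = labels_list.index(out_label)
    matrix[in_pos, out_pos] = value
    matrix[out_pos, in_pos] = value  # symmetric connectivity matrix
print(matrix)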
Exemplo n.º 26
0
def main():
    # Callback required for FURY
    def keypress_callback(obj, _):
        key = obj.GetKeySym().lower()
        nonlocal clusters_linewidth, background_linewidth
        nonlocal curr_streamlines_actor, concat_streamlines_actor, show_curr_actor
        iterator = len(accepted_streamlines) + len(rejected_streamlines)
        renwin = interactor_style.GetInteractor().GetRenderWindow()
        renderer = interactor_style.GetCurrentRenderer()

        if key == 'c' and iterator < len(sft_accepted_on_size):
            if show_curr_actor:
                renderer.rm(concat_streamlines_actor)
                renwin.Render()
                show_curr_actor = False
                logging.info('Streamlines rendering OFF')
            else:
                renderer.add(concat_streamlines_actor)
                renderer.rm(curr_streamlines_actor)
                renderer.add(curr_streamlines_actor)
                renwin.Render()
                show_curr_actor = True
                logging.info('Streamlines rendering ON')
            return

        if key == 'q':
            show_manager.exit()
            if iterator < len(sft_accepted_on_size):
                logging.warning(
                    'Early exit, everything remaining will be rejected.')
            return

        if key in ['a', 'r'] and iterator < len(sft_accepted_on_size):
            if key == 'a':
                accepted_streamlines.append(iterator)
                choices.append('a')
                logging.info('Accepted file %s',
                             filename_accepted_on_size[iterator])
            elif key == 'r':
                rejected_streamlines.append(iterator)
                choices.append('r')
                logging.info('Rejected file %s',
                             filename_accepted_on_size[iterator])
            iterator += 1

        if key == 'z':
            if iterator > 0:
                last_choice = choices.pop()
                if last_choice == 'r':
                    rejected_streamlines.pop()
                else:
                    accepted_streamlines.pop()
                logging.info('Rewound one step.')

                iterator -= 1
            else:
                logging.warning('Cannot rewind, first element.')

        if key in ['a', 'r', 'z'] and iterator < len(sft_accepted_on_size):
            renderer.rm(curr_streamlines_actor)
            curr_streamlines = sft_accepted_on_size[iterator].streamlines
            curr_streamlines_actor = actor.line(curr_streamlines,
                                                opacity=0.8,
                                                linewidth=clusters_linewidth)
            renderer.add(curr_streamlines_actor)

        if iterator == len(sft_accepted_on_size):
            print('No more clusters, press q to exit')
            renderer.rm(curr_streamlines_actor)

        renwin.Render()

    parser = _build_args_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_bundles)
    assert_outputs_exist(parser, args, [args.out_accepted, args.out_rejected])

    if args.out_accepted_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_accepted_dir,
                                           create_dir=True)
    if args.out_rejected_dir:
        assert_output_dirs_exist_and_empty(parser,
                                           args,
                                           args.out_rejected_dir,
                                           create_dir=True)

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    if args.min_cluster_size < 1:
        parser.error('Minimum cluster size must be at least 1.')

    clusters_linewidth = args.clusters_linewidth
    background_linewidth = args.background_linewidth

    # To accelerate the procedure, clusters can be discarded based on size.
    # The concatenation gives spatial context during the selection.
    sft_accepted_on_size, filename_accepted_on_size = [], []
    sft_rejected_on_size, filename_rejected_on_size = [], []
    concat_streamlines = []
    for filename in args.in_bundles:
        if not is_header_compatible(args.in_bundles[0], filename):
            parser.error('Headers of {} and {} are not compatible.'.format(
                args.in_bundles[0], filename))
        basename = os.path.basename(filename)
        sft = load_tractogram_with_reference(parser,
                                             args,
                                             filename,
                                             bbox_check=False)
        if len(sft) >= args.min_cluster_size:
            sft_accepted_on_size.append(sft)
            filename_accepted_on_size.append(basename)
            concat_streamlines.extend(sft.streamlines)
        else:
            logging.info('File %s has %s streamlines, automatically rejected.',
                         filename, len(sft))
            sft_rejected_on_size.append(sft)
            filename_rejected_on_size.append(basename)

    if not filename_accepted_on_size:
        parser.error('No cluster survived the cluster_size threshold.')

    logging.info('%s clusters to be classified.', len(sft_accepted_on_size))
    # The clusters are sorted by size for simplicity/efficiency
    tuple_accepted = zip(
        *sorted(zip(sft_accepted_on_size, filename_accepted_on_size),
                key=lambda x: len(x[0]),
                reverse=True))
    sft_accepted_on_size, filename_accepted_on_size = tuple_accepted

    # Initialize the actors, scene, window, observer
    concat_streamlines_actor = actor.line(concat_streamlines,
                                          colors=(1, 1, 1),
                                          opacity=args.background_opacity,
                                          linewidth=background_linewidth)
    curr_streamlines_actor = actor.line(sft_accepted_on_size[0].streamlines,
                                        opacity=0.8,
                                        linewidth=clusters_linewidth)

    scene = window.Scene()
    interactor_style = interactor.CustomInteractorStyle()
    show_manager = window.ShowManager(scene,
                                      size=(800, 800),
                                      reset_camera=False,
                                      interactor_style=interactor_style)
    scene.add(concat_streamlines_actor)
    scene.add(curr_streamlines_actor)
    interactor_style.AddObserver('KeyPressEvent', keypress_callback)

    # Launch the rendering and selection procedure
    choices, accepted_streamlines, rejected_streamlines = [], [], []
    show_curr_actor = True
    show_manager.start()

    # Early exit means everything else is rejected
    missing = len(args.in_bundles) - len(choices) - len(sft_rejected_on_size)
    len_accepted = len(sft_accepted_on_size)
    rejected_streamlines.extend(range(len_accepted - missing, len_accepted))
    if missing > 0:
        logging.info('%s clusters automatically rejected from early exit',
                     missing)

    # Save accepted clusters (by GUI)
    accepted_streamlines = save_clusters(sft_accepted_on_size,
                                         accepted_streamlines,
                                         args.out_accepted_dir,
                                         filename_accepted_on_size)

    accepted_sft = StatefulTractogram(accepted_streamlines,
                                      sft_accepted_on_size[0], Space.RASMM)
    save_tractogram(accepted_sft, args.out_accepted, bbox_valid_check=False)

    # Save rejected clusters (by GUI)
    rejected_streamlines = save_clusters(sft_accepted_on_size,
                                         rejected_streamlines,
                                         args.out_rejected_dir,
                                         filename_accepted_on_size)

    # Save rejected clusters (by size)
    rejected_streamlines.extend(
        save_clusters(sft_rejected_on_size, range(len(sft_rejected_on_size)),
                      args.out_rejected_dir, filename_rejected_on_size))

    rejected_sft = StatefulTractogram(rejected_streamlines,
                                      sft_accepted_on_size[0], Space.RASMM)
    save_tractogram(rejected_sft, args.out_rejected, bbox_valid_check=False)
Exemplo n.º 27
0
"""

affine, dimensions, voxel_sizes, voxel_order = get_reference_info(
    reference_anatomy)
print(affine)
print(dimensions)
print(voxel_sizes)
print(voxel_order)
"""
If you have a Trk file that was generated using a particular anatomy, all
fields must correspond between the headers for it to be considered valid.
This can easily be verified using this function, which also accepts the
same variety of inputs as ``get_reference_info``
"""

print(is_header_compatible(reference_anatomy, bundles_filename[0]))
"""
If a TRK was generated with a valid header but the reference NIFTI was lost,
a header can be created and used to write a fake NIFTI file.

If you wish to manually save Trk and Tck files using the nibabel streamlines
API for more freedom of action (not recommended for beginners), you can
create a valid header using ``create_tractogram_header``
"""

nifti_header = create_nifti_header(affine, dimensions, voxel_sizes)
nib.save(nib.Nifti1Image(np.zeros(dimensions), affine, nifti_header),
         'fake.nii.gz')
nib.save(reference_anatomy, os.path.basename(ref_anat_filename))
"""
Once loaded, no matter the original file format, the stateful tractogram is
Exemplo n.º 28
0
def _processing_wrapper(args):
    hdf5_filename = args[0]
    labels_img = args[1]
    in_label, out_label = args[2]
    measures_to_compute = copy.copy(args[3])
    if args[4] is not None:
        similarity_directory = args[4][0]
    weighted = args[5]
    include_dps = args[6]

    hdf5_file = h5py.File(hdf5_filename, 'r')
    key = '{}_{}'.format(in_label, out_label)
    if key not in hdf5_file:
        return
    streamlines = reconstruct_streamlines_from_hdf5(hdf5_file, key)

    affine, dimensions, voxel_sizes, _ = get_reference_info(labels_img)
    measures_to_return = {}

    if not (np.allclose(hdf5_file.attrs['affine'], affine, atol=1e-03)
            and np.array_equal(hdf5_file.attrs['dimensions'], dimensions)):
        raise ValueError('The provided hdf5 file and labels image have '
                         'incompatible headers.')

    # Precompute to save one transformation, insert later
    if 'length' in measures_to_compute:
        streamlines_copy = list(streamlines)
        # scil_decompose_connectivity.py requires isotropic voxels
        mean_length = np.average(length(streamlines_copy)) * voxel_sizes[0]

    # If density is not required, do not compute it
    # Only required for volume, similarity and any metrics
    if not ((len(measures_to_compute) == 1 and
             ('length' in measures_to_compute
              or 'streamline_count' in measures_to_compute)) or
            (len(measures_to_compute) == 2 and
             ('length' in measures_to_compute
              and 'streamline_count' in measures_to_compute))):

        density = compute_tract_counts_map(streamlines, dimensions)

    if 'volume' in measures_to_compute:
        measures_to_return['volume'] = np.count_nonzero(density) * \
            np.prod(voxel_sizes)
        measures_to_compute.remove('volume')
    if 'streamline_count' in measures_to_compute:
        measures_to_return['streamline_count'] = len(streamlines)
        measures_to_compute.remove('streamline_count')
    if 'length' in measures_to_compute:
        measures_to_return['length'] = mean_length
        measures_to_compute.remove('length')
    if 'similarity' in measures_to_compute and similarity_directory:
        density_sim = load_node_nifti(similarity_directory, in_label,
                                      out_label, labels_img)
        if density_sim is None:
            ba_vox = 0
        else:
            ba_vox = compute_bundle_adjacency_voxel(density, density_sim)

        measures_to_return['similarity'] = ba_vox
        measures_to_compute.remove('similarity')

    for measure in measures_to_compute:
        if isinstance(measure, str) and os.path.isdir(measure):
            map_dirname = measure
            map_data = load_node_nifti(map_dirname, in_label, out_label,
                                       labels_img)
            measures_to_return[map_dirname] = np.average(
                map_data[map_data > 0])
        elif isinstance(measure, tuple) and os.path.isfile(measure[0]):
            metric_filename = measure[0]
            metric_img = measure[1]
            if not is_header_compatible(metric_img, labels_img):
                logging.error('{} does not have a compatible '
                              'header'.format(metric_filename))
                raise IOError

            metric_data = metric_img.get_fdata(dtype=np.float64)
            if weighted:
                density = density / np.max(density)
                voxels_value = metric_data * density
                voxels_value = voxels_value[voxels_value > 0]
            else:
                voxels_value = metric_data[density > 0]

            measures_to_return[metric_filename] = np.average(voxels_value)

    if include_dps:
        for dps_key in hdf5_file[key].keys():
            if dps_key not in ['data', 'offsets', 'lengths']:
                out_file = os.path.join(include_dps, dps_key)
                measures_to_return[out_file] = np.average(
                    hdf5_file[key][dps_key])

    return {(in_label, out_label): measures_to_return}
Exemplo n.º 29
0
def load_tractogram(filename,
                    reference,
                    to_space=Space.RASMM,
                    shifted_origin=False,
                    bbox_valid_check=True,
                    trk_header_check=True):
    """ Load the stateful tractogram from any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading.
    shifted_origin : bool
        Information on the position of the origin,
        False is Trackvis standard, default (center of the voxel)
        True is NIFTI standard (corner of the voxel)
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same spatial attributes
        (header) as the input tractogram when a Trk file is loaded

    Returns
    -------
    output : StatefulTractogram
        The tractogram to load (must have been saved properly)
    """
    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        logging.error('Input filename is not one of the supported formats')
        return False

    if to_space not in Space:
        logging.error('Space MUST be one of the 3 choices (Enum)')
        return False

    if reference == 'same':
        if extension == '.trk':
            reference = filename
        else:
            logging.error('Reference must be provided, "same" is only '
                          'available for Trk files.')
            return False

    if trk_header_check and extension == '.trk':
        if not is_header_compatible(filename, reference):
            logging.error('Trk file header does not match the provided ' +
                          'reference')
            return False

    timer = time.time()
    data_per_point = None
    data_per_streamline = None
    if extension in ['.trk', '.tck']:
        tractogram_obj = nib.streamlines.load(filename).tractogram
        streamlines = tractogram_obj.streamlines
        if extension == '.trk':
            data_per_point = tractogram_obj.data_per_point
            data_per_streamline = tractogram_obj.data_per_streamline

    elif extension in ['.vtk', '.fib']:
        streamlines = load_vtk_streamlines(filename)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='r')
        streamlines = list(dpy_obj.read_tracks())
        dpy_obj.close()
    logging.debug('Load %s with %s streamlines in %s seconds', filename,
                  len(streamlines), round(time.time() - timer, 3))

    sft = StatefulTractogram(streamlines,
                             reference,
                             Space.RASMM,
                             shifted_origin=shifted_origin,
                             data_per_point=data_per_point,
                             data_per_streamline=data_per_streamline)

    if to_space == Space.VOX:
        sft.to_vox()
    elif to_space == Space.VOXMM:
        sft.to_voxmm()

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot '
                         'load a valid file if some coordinates are invalid. '
                         'Please set bbox_valid_check to False and then use '
                         'the function remove_invalid_streamlines to discard '
                         'invalid streamlines.')

    return sft
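A minimal usage sketch for the function above (file names are hypothetical,
and Space comes from dipy.io.stateful_tractogram as in the signature):

# A trk can act as its own reference, loaded straight into voxel space
sft = load_tractogram('bundle.trk', 'same', to_space=Space.VOX)

# A tck carries no spatial header, so a reference NIFTI is mandatory
sft = load_tractogram('bundle.tck', 'reference.nii.gz')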
Exemplo n.º 30
0
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    assert_inputs_exist(parser, args.in_tractogram)
    assert_outputs_exist(parser, args, args.out_tractogram)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    roi_opt_list, only_filtering_list = prepare_filtering_list(parser, args)
    o_dict = {}

    sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

    # Streamline count before filtering
    o_dict['streamline_count_before_filtering'] = len(sft.streamlines)

    for i, roi_opt in enumerate(roi_opt_list):
        curr_dict = {}
        # Atlas needs an extra argument (value in the LUT)
        if roi_opt[0] == 'atlas_roi':
            filter_type, filter_arg, filter_arg_2, \
                filter_mode, filter_criteria = roi_opt
        else:
            filter_type, filter_arg, filter_mode, filter_criteria = roi_opt

        curr_dict['filename'] = os.path.abspath(filter_arg)
        curr_dict['type'] = filter_type
        curr_dict['mode'] = filter_mode
        curr_dict['criteria'] = filter_criteria

        is_exclude = filter_criteria != 'include'

        if filter_type == 'drawn_roi' or filter_type == 'atlas_roi':
            img = nib.load(filter_arg)
            if not is_header_compatible(img, sft):
                parser.error('Headers from the tractogram and the mask are '
                             'not compatible.')
            if filter_type == 'drawn_roi':
                mask = get_data_as_mask(img)
            else:
                atlas = get_data_as_label(img)
                mask = np.zeros(atlas.shape, dtype=np.uint16)
                mask[atlas == int(filter_arg_2)] = 1
            filtered_sft, indexes = filter_grid_roi(sft, mask,
                                                    filter_mode, is_exclude)

        # For every case, the input number must be greater than or equal to 0
        # and below the dimension, since this is a voxel-space operation
        elif filter_type in ['x_plane', 'y_plane', 'z_plane']:
            filter_arg = int(filter_arg)
            _, dim, _, _ = sft.space_attributes
            mask = np.zeros(dim, dtype=np.int16)
            error_msg = None
            if filter_type == 'x_plane':
                if 0 <= filter_arg < dim[0]:
                    mask[filter_arg, :, :] = 1
                else:
                    error_msg = 'X plane ' + str(filter_arg)

            elif filter_type == 'y_plane':
                if 0 <= filter_arg < dim[1]:
                    mask[:, filter_arg, :] = 1
                else:
                    error_msg = 'Y plane ' + str(filter_arg)

            elif filter_type == 'z_plane':
                if 0 <= filter_arg < dim[2]:
                    mask[:, :, filter_arg] = 1
                else:
                    error_msg = 'Z plane ' + str(filter_arg)

            if error_msg:
                parser.error('{} is not valid according to the '
                             'tractogram header.'.format(error_msg))

            filtered_sft, indexes = filter_grid_roi(sft, mask,
                                                    filter_mode, is_exclude)

        elif filter_type == 'bdo':
            geometry, radius, center = read_info_from_mb_bdo(filter_arg)
            if geometry == 'Ellipsoid':
                filtered_sft, indexes = filter_ellipsoid(sft,
                                                         radius, center,
                                                         filter_mode, is_exclude)
            elif geometry == 'Cuboid':
                filtered_sft, indexes = filter_cuboid(sft,
                                                      radius, center,
                                                      filter_mode, is_exclude)

        logging.debug('The filtering options {0} resulted in '
                      '{1} streamlines'.format(roi_opt, len(filtered_sft)))

        sft = filtered_sft

        if only_filtering_list:
            filtering_name = 'Filter_' + str(i)
            curr_dict['streamline_count_after_filtering'] = len(sft.streamlines)
            o_dict[filtering_name] = curr_dict

    # Streamline count after filtering
    o_dict['streamline_count_final_filtering'] = len(sft.streamlines)
    if args.display_counts:
        print(json.dumps(o_dict, indent=args.indent))

    if not filtered_sft:
        if args.no_empty:
            logging.debug("The file {} won't be written (0 streamline)".format(
                args.out_tractogram))

            return

        logging.debug('The file {} contains 0 streamlines'.format(
            args.out_tractogram))

    save_tractogram(sft, args.out_tractogram)
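The plane filters above build a one-voxel-thick mask at a given index before
handing it to filter_grid_roi. A minimal sketch of the x_plane case with a
synthetic dimension (no tractogram involved):

import numpy as np

dim = (4, 5, 6)   # illustrative volume dimensions
filter_arg = 2    # plane index, must satisfy 0 <= filter_arg < dim[0]

mask = np.zeros(dim, dtype=np.int16)
mask[filter_arg, :, :] = 1  # every voxel of the x == 2 plane
print(int(mask.sum()))      # 30 voxels (5 * 6)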