def score_auto_extract_auto_IBs(streamlines, bundles_masks, ref_bundles, ROIs, wm,
                                save_segmented=False, save_IBs=False,
                                save_VBs=False, save_VCWPs=False,
                                out_segmented_strl_dir='',
                                base_out_segmented_strl='',
                                ref_anat_fname=''):
    """
    TODO document


    Parameters
    ------------
    streamlines : sequence
        sequence of T streamlines. One streamline is an ndarray of shape (N, 3),
        where N is the number of points in that streamline, and
        ``streamlines[t][n]`` is the n-th point in the t-th streamline. Points
        are of form x, y, z in *voxel* coordinates.
    bundles_masks : sequence
        list of nibabel objects corresponding to mask of bundles
    ROIs : sequence
        list of nibabel objects corresponding to mask of ROIs
    wm : nibabel object
        mask of the white matter
    save_segmented : bool
        if true, returns indices of streamlines composing VC, IC, VCWP and NC

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric
    indices : dict
        dictionnary containing the indices of streamlines composing VC, IC,
        VCWP and NC

    """

    # Load all streamlines, since streamlines is a generator.
    # full_strl = [s for s in streamlines]

    VC_indices, found_vbs_info = _auto_extract_VCs(streamlines, ref_bundles)
    VC = len(VC_indices)
    logging.debug('Found {} candidate VC'.format(VC))

    if save_VBs:
        _save_extracted_VBs(found_vbs_info, streamlines, out_segmented_strl_dir,
                            base_out_segmented_strl, ref_anat_fname)

    # TODO: might be re-added
    # To keep track of streamlines that have been classified
    # classified_streamlines_indices = VC_indices

    # New algorithm
    # Step 1: remove streamlines shorter than the length threshold (currently 35)
    # Step 2: cluster with QuickBundlesX (distance thresholds [30, 25, 20, 15])
    # Step 3: remove singletons
    # Step 4: assign each cluster to the closest pair of ROIs
    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(streamlines)
    candidate_ic_strl_indices = sorted(set(range(total_strl_count)) - VC_indices)

    length_thres = 35.

    candidate_ic_streamlines = []
    rejected_streamlines = []

    for idx in candidate_ic_strl_indices:
        if slength(streamlines[idx]) >= length_thres:
            candidate_ic_streamlines.append(streamlines[idx].astype('f4'))
        else:
            rejected_streamlines.append(streamlines[idx].astype('f4'))

    logging.debug('Found {} candidate IC'.format(len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(len(rejected_streamlines)))

    ic_counts = 0
    ib_pairs = {}

    if len(candidate_ic_streamlines):

        # Fix seed to always generate the same output
        # Shuffle to try to reduce the ordering dependency for QB
        random.seed(0.2)
        random.shuffle(candidate_ic_streamlines)


        # TODO threshold on distance as arg
        qb = QuickBundlesX([30, 25, 20, 15])
        clusters_obj = qb.cluster(candidate_ic_streamlines)
        clusters = clusters_obj.get_clusters(-1)  # Retrieves clusters obtained with the smallest threshold.
        # clusters = qb.cluster(candidate_ic_streamlines)

        logging.debug("Found {} potential IB clusters".format(len(clusters)))

        # TODO this should be better handled
    rois_info = []
    for roi in ROIs:
        # np.asanyarray(roi.dataobj) replaces nibabel's deprecated get_data()
        rois_info.append((get_root_image_name(os.path.basename(roi.get_filename())),
                          np.array(np.where(np.asanyarray(roi.dataobj))).T))

        centroids = nib.streamlines.Tractogram(clusters.centroids)
        centroids.apply_affine(np.linalg.inv(ROIs[0].affine))
        all_centroids_closest_pairs = get_closest_roi_pairs_for_all_streamlines(centroids.streamlines, rois_info)

        for c_idx, c in enumerate(clusters):
            closest_for_cluster = all_centroids_closest_pairs[c_idx]

            if closest_for_cluster not in ib_pairs:
                ib_pairs[closest_for_cluster] = []

            ic_counts += len(c)
            ib_pairs[closest_for_cluster].extend(c.indices)

        # all_ics_closest_pairs = get_closest_roi_pairs_for_all_streamlines(candidate_ic_streamlines, rois_info)

        # for c_idx, c in enumerate(clusters):
        #     closest_for_cluster = [all_ics_closest_pairs[i] for i in clusters[c]['indices']]

        #     if len(clusters[c]['indices']) > 1:
        #         ic_counts += len(clusters[c]['indices'])
        #         occurrences = Counter(closest_for_cluster)

        #         # TODO handle either an equality or maybe a range
        #         most_frequent = occurrences.most_common(1)[0][0]

        #         val = ib_pairs.get(most_frequent)
        #         if val is None:
        #             # Check if flipped pair exists
        #             val1 = ib_pairs.get((most_frequent[1], most_frequent[0]))
        #             if val1 is not None:
        #                 val1.append(c_idx)
        #             else:
        #                 ib_pairs[most_frequent] = [c_idx]
        #         else:
        #             val.append(c_idx)
        #     else:
        #         rejected_streamlines.append(candidate_ic_streamlines[clusters[c]['indices'][0]])

        if save_segmented and save_IBs:
            for k, v in ib_pairs.items():
                # Indexing a list of ragged arrays through np.array() is
                # unreliable; gather the streamlines explicitly instead.
                out_strl = [candidate_ic_streamlines[i] for i in v]
                out_fname = os.path.join(out_segmented_strl_dir,
                                         base_out_segmented_strl +
                                         '_IB_{0}_{1}.tck'.format(k[0], k[1]))

                nib.streamlines.save(nib.streamlines.Tractogram(out_strl, affine_to_rasmm=np.eye(4)), out_fname)
                # ib_f = TCK.create(out_fname)
                # save_tracts_tck_from_dipy_voxel_space(ib_f, ref_anat_fname,
                                                      # out_strl)

    if len(rejected_streamlines) > 0 and save_segmented:
        out_nc_fname = os.path.join(out_segmented_strl_dir,
                                    '{}_NC.tck'.format(base_out_segmented_strl))
        nib.streamlines.save(nib.streamlines.Tractogram(rejected_streamlines, affine_to_rasmm=np.eye(4)), out_nc_fname)
        # out_file = TCK.create(out_nc_fname)
        # save_tracts_tck_from_dipy_voxel_space(out_file, ref_anat_fname,
                                              # rejected_streamlines)

    # TODO: re-add classified_streamlines_indices to validate
    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) - len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    # TODO could have sanity check on global extracted streamlines vs all
    # possible indices

    nb_VB_found = sum(1 for v in found_vbs_info.values() if v['nb_streamlines'] > 0)
    streamlines_per_bundle = {k: v['nb_streamlines'] for k, v in found_vbs_info.items() if v['nb_streamlines'] > 0}

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = len(ib_pairs)
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count

    return scores
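

# A minimal, self-contained sketch of the QuickBundlesX step above, on
# synthetic streamlines. The explicit resampling to 12 points is a
# conservative assumption: dipy's default MDF metric expects streamlines of
# equal length (recent versions can also resample internally).
import numpy as np
from dipy.segment.clustering import QuickBundlesX
from dipy.tracking.streamline import set_number_of_points


def _demo_quickbundlesx(n_streamlines=50, seed=0):
    rng = np.random.RandomState(seed)
    # Synthetic streamlines of varying lengths, as (N, 3) float32 arrays
    streamlines = [np.cumsum(rng.randn(rng.randint(20, 40), 3),
                             axis=0).astype('f4')
                   for _ in range(n_streamlines)]
    resampled = set_number_of_points(streamlines, 12)

    qb = QuickBundlesX([30, 25, 20, 15])
    tree = qb.cluster(resampled)
    # As in the code above, -1 retrieves the clusters built with the
    # smallest (finest) threshold.
    clusters = tree.get_clusters(-1)
    for c in clusters:
        print('{} streamlines in cluster'.format(len(c.indices)))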

# Example 2

def score_submission(streamlines_fname,
                     tracts_attribs,
                     base_data_dir,
                     basic_bundles_attribs,
                     save_full_vc=False,
                     save_full_ic=False,
                     save_full_nc=False,
                     save_IBs=False,
                     save_VBs=False,
                     segmented_out_dir='',
                     segmented_base_name='',
                     verbose=False):
    """
    Score a submission, using the following algorithm:
        1: extract all streamlines that are valid, which are classified as
           Valid Connections (VC) making up Valid Bundles (VB).
        2: remove streamlines shorter than an threshold based on the GT dataset
        3: cluster the remaining streamlines
        4: remove singletons
        5: assign each cluster to the closest ROIs pair. Those make up the
           Invalid Connections (IC), grouped as Invalid Bundles (IB).
        6: streamlines that are neither in VC nor IC are classified as
           No Connection (NC).


    Parameters
    ------------
    streamlines_fname : string
        path to the file containing the streamlines.
    tracts_attribs : dictionary
        contains the attributes of the submission. Must contain the
        'orientation' attribute for .vtk files.
    base_data_dir : string
        path to the direction containing the scoring data.
    basic_bundles_attribs : dictionary
        contains the attributes of the basic bundles (name, list of streamlines,
        segmentation threshold)
    save_full_vc : bool
        indicates if the full set of VC will be saved in an individual file.
    save_full_ic : bool
        indicates if the full set of IC will be saved in an individual file.
    save_full_nc : bool
        indicates if the full set of NC will be saved in an individual file.
    save_IBs : bool
        indicates if the invalid bundles will be saved in individual file for
        each IB.
    save_VBs : bool
        indicates if the valid bundles will be saved in individual file for
        each VB.
    segmented_out_dir : string
        the path to the directory where segmented files will be saved.
    segmented_base_name : string
        the base name to use for saving segmented files.
    verbose : bool
        indicates if the algorithm needs to be verbose when logging messages.

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric
    """

    if verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Prepare needed scoring data
    logging.debug('Preparing GT data')
    masks_dir = os.path.join(base_data_dir, "masks")
    rois_dir = os.path.join(masks_dir, "rois")
    bundles_dir = os.path.join(base_data_dir, "bundles")
    bundles_masks_dir = os.path.join(masks_dir, "bundles")
    ref_anat_fname = os.path.join(masks_dir, "wm.nii.gz")

    ROIs = [
        nib.load(os.path.join(rois_dir, f))
        for f in sorted(os.listdir(rois_dir))
    ]

    ref_bundles = _prepare_gt_bundles_info(bundles_dir, bundles_masks_dir,
                                           basic_bundles_attribs,
                                           ref_anat_fname)

    streamlines_gen = get_tracts_voxel_space_for_dipy(streamlines_fname,
                                                      ref_anat_fname,
                                                      tracts_attribs)

    # Load all streamlines, since streamlines is a generator.
    full_strl = [s for s in streamlines_gen]

    # Extract VCs and VBs
    VC_indices, found_vbs_info, bundles_found = auto_extract_VCs(
        full_strl, ref_bundles)
    VC = len(VC_indices)

    if save_VBs or save_full_vc:
        save_valid_connections(found_vbs_info,
                               full_strl,
                               segmented_out_dir,
                               segmented_base_name,
                               ref_anat_fname,
                               save_vbs=save_VBs,
                               save_full_vc=save_full_vc)

    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(full_strl)
    candidate_ic_strl_indices = sorted(
        set(range(total_strl_count)) - VC_indices)

    candidate_ic_streamlines = []
    rejected_streamlines = []
    rejected_idx = []
    candidate_idx = []

    # Chosen from GT dataset
    length_thres = 35.

    # Filter streamlines that are too short, consider them as NC
    for idx in candidate_ic_strl_indices:
        if slength(full_strl[idx]) >= length_thres:
            candidate_ic_streamlines.append(full_strl[idx].astype('f4'))
            candidate_idx.append(idx)
        else:
            rejected_streamlines.append(full_strl[idx].astype('f4'))
            rejected_idx.append(idx)

    logging.debug('Found {} candidate IC'.format(
        len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(
        len(rejected_streamlines)))

    ic_counts = 0
    nb_ib = 0

    if len(candidate_ic_streamlines):
        additional_rejected, ic_counts, nb_ib, rejected_c_idx = \
            group_and_assign_ibs(candidate_ic_streamlines,
                                 ROIs, save_IBs, save_full_ic,
                                 segmented_out_dir,
                                 segmented_base_name,
                                 ref_anat_fname)

        rejected_streamlines.extend(additional_rejected)
        # extend (not append): rejected_idx is a flat list of indices
        rejected_idx.extend(candidate_idx[c_idx] for c_idx in rejected_c_idx)

    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    if len(rejected_streamlines) > 0 and save_full_nc:
        out_nc_fname = os.path.join(segmented_out_dir,
                                    '{}_NC.tck'.format(segmented_base_name))
        out_file = TCK.create(out_nc_fname)
        save_tracts_tck_from_dipy_voxel_space(out_file, ref_anat_fname,
                                              rejected_streamlines)

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) -
          len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    nb_VB_found = sum(
        1 for v in found_vbs_info.values() if v['nb_streamlines'] > 0)
    streamlines_per_bundle = {
        k: v['nb_streamlines']
        for k, v in found_vbs_info.items() if v['nb_streamlines'] > 0
    }

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = nb_ib
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count
    scores['ami_rejected_indices'] = rejected_idx
    scores['ami_bundles_found'] = bundles_found
    scores['ami_VC_indices'] = list(VC_indices)

    # Get bundle overlap, overreach and f1-score for each bundle.
    scores['overlap_per_bundle'] = {
        k: v["overlap"]
        for k, v in found_vbs_info.items()
    }
    scores['overreach_per_bundle'] = {
        k: v["overreach"]
        for k, v in found_vbs_info.items()
    }
    scores['overreach_norm_gt_per_bundle'] = {
        k: v["overreach_norm"]
        for k, v in found_vbs_info.items()
    }
    scores['f1_score_per_bundle'] = {
        k: v["f1_score"]
        for k, v in found_vbs_info.items()
    }

    # Compute average bundle overlap, overreach and f1-score.
    scores['mean_OL'] = np.mean(list(scores['overlap_per_bundle'].values()))
    scores['mean_OR'] = np.mean(list(scores['overreach_per_bundle'].values()))
    scores['mean_ORn'] = np.mean(
        list(scores['overreach_norm_gt_per_bundle'].values()))
    scores['mean_F1'] = np.mean(list(scores['f1_score_per_bundle'].values()))

    return scores
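

# Hypothetical driver for score_submission. The paths, the scoring-data
# layout and load_attribs come from the surrounding project and are
# assumptions here; adjust them to your setup.
def _demo_score_submission():
    tracts_attribs = {'orientation': 'unknown'}
    basic_bundles_attribs = load_attribs(
        'scoring_data/gt_bundles_attributes.json')

    scores = score_submission('submission.tck',
                              tracts_attribs,
                              'scoring_data',
                              basic_bundles_attribs,
                              save_full_nc=True,
                              segmented_out_dir='segmented',
                              segmented_base_name='submission',
                              verbose=True)

    print('VC: {:.1%}  IC: {:.1%}  NC: {:.1%}  VB: {}  IB: {}'.format(
        scores['VC'], scores['IC'], scores['NC'],
        scores['VB'], scores['IB']))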

# Example 3

    def clean_tractogram(self, tractogram, affine_vox2mask):
        """
        Remove potential "non-connections" by filtering according to
        curvature, length and mask

        Parameters:
        -----------
        tractogram: Tractogram
            Full tractogram

        Returns:
        --------
        tractogram: Tractogram
            Filtered tractogram
        """
        print('Cleaning tractogram ... ', end='', flush=True)

        streamlines = tractogram.streamlines

        # # Filter by curvature
        # dirty_mask = is_flag_set(
        #     stopping_flags, StoppingFlags.STOPPING_CURVATURE)
        dirty_mask = np.zeros(len(streamlines), dtype=bool)

        # Filter by length unless the streamline ends in GM
        # Example case: Bundle 3 of fibercup tends to be shorter than 35

        lengths = [slength(s) for s in streamlines]
        short_lengths = np.asarray([lgt <= self.min_length for lgt in lengths])

        dirty_mask = np.logical_or(short_lengths, dirty_mask)

        # Also drop implausibly long streamlines (hard-coded upper bound of
        # 200, in the same units as slength)
        long_lengths = np.asarray([lgt > 200. for lgt in lengths])

        dirty_mask = np.logical_or(long_lengths, dirty_mask)

        # start_mask = is_inside_mask(
        #     np.asarray([s[0] for s in streamlines])[:, None],
        #     self.target_mask.data, affine_vox2mask, 0.5)

        # assert(np.any(start_mask))

        # end_mask = is_inside_mask(
        #     np.asarray([s[-1] for s in streamlines])[:, None],
        #     self.target_mask.data, affine_vox2mask, 0.5)

        # assert(np.any(end_mask))

        # mask_mask = np.logical_not(np.logical_and(start_mask, end_mask))

        # dirty_mask = np.logical_or(
        #     dirty_mask,
        #     mask_mask)

        # Filter by loops
        # Example: a streamline starting and ending in the same ROI
        looping_mask = np.array([winding(s) for s in streamlines]) > 330
        dirty_mask = np.logical_or(dirty_mask, looping_mask)

        # Only keep valid streamlines
        valid_indices = np.flatnonzero(np.logical_not(dirty_mask))
        clean_streamlines = streamlines[valid_indices]
        clean_dps = tractogram.data_per_streamline[valid_indices]
        print('Done!')

        print('Kept {}/{} streamlines'.format(len(valid_indices),
                                              len(streamlines)))

        return Tractogram(clean_streamlines, data_per_streamline=clean_dps)
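

# A self-contained sketch of the filters clean_tractogram applies, with its
# hard-coded thresholds. slength and winding are assumed to be dipy's length
# and winding metrics, which matches how they are used above.
import numpy as np
from dipy.tracking.streamline import length as slength
from dipy.tracking.metrics import winding


def _demo_filter_streamlines(streamlines, min_length=35., max_length=200.,
                             max_winding=330.):
    lengths = np.array([slength(s) for s in streamlines])
    windings = np.array([winding(s) for s in streamlines])
    # A streamline is "dirty" if too short, too long, or too curvy (loop)
    dirty = ((lengths <= min_length)
             | (lengths > max_length)
             | (windings > max_winding))
    return [s for s, is_dirty in zip(streamlines, dirty) if not is_dirty]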

# Example 4

def score_submission(streamlines_fname,
                     tracts_attribs,
                     base_data_dir,
                     basic_bundles_attribs,
                     save_full_vc=False,
                     save_full_ic=False,
                     save_full_nc=False,
                     save_IBs=False,
                     save_VBs=False,
                     segmented_out_dir='',
                     segmented_base_name='',
                     verbose=False):
    """
    Score a submission, using the following algorithm:
        1: extract all streamlines that are valid, which are classified as
           Valid Connections (VC) making up Valid Bundles (VB).
        2: remove streamlines shorter than an threshold based on the GT dataset
        3: cluster the remaining streamlines
        4: remove singletons
        5: assign each cluster to the closest ROIs pair. Those make up the
           Invalid Connections (IC), grouped as Invalid Bundles (IB).
        6: streamlines that are neither in VC nor IC are classified as
           No Connection (NC).


    Parameters
    ------------
    streamlines_fname : string
        path to the file containing the streamlines.
    tracts_attribs : dictionary
        contains the attributes of the submission. Must contain the
        'orientation' attribute for .vtk files.
    base_data_dir : string
        path to the direction containing the scoring data.
    basic_bundles_attribs : dictionary
        contains the attributes of the basic bundles (name, list of streamlines,
        segmentation threshold)
    save_full_vc : bool
        indicates if the full set of VC will be saved in an individual file.
    save_full_ic : bool
        indicates if the full set of IC will be saved in an individual file.
    save_full_nc : bool
        indicates if the full set of NC will be saved in an individual file.
    save_IBs : bool
        indicates if the invalid bundles will be saved in individual file for
        each IB.
    save_VBs : bool
        indicates if the valid bundles will be saved in individual file for
        each VB.
    segmented_out_dir : string
        the path to the directory where segmented files will be saved.
    segmented_base_name : string
        the base name to use for saving segmented files.
    verbose : bool
        indicates if the algorithm needs to be verbose when logging messages.

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric
    """

    if verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Prepare needed scoring data
    logging.debug('Preparing GT data')
    masks_dir = os.path.join(base_data_dir, "masks")
    rois_dir = os.path.join(masks_dir, "rois")
    bundles_dir = os.path.join(base_data_dir, "bundles")
    bundles_masks_dir = os.path.join(masks_dir, "bundles")
    ref_anat_fname = os.path.join(masks_dir, "wm.nii.gz")

    ROIs = [nib.load(os.path.join(rois_dir, f))
            for f in sorted(os.listdir(rois_dir))]

    ref_bundles = _prepare_gt_bundles_info(bundles_dir,
                                           bundles_masks_dir,
                                           basic_bundles_attribs,
                                           ref_anat_fname)

    streamlines_gen = get_tracts_voxel_space_for_dipy(streamlines_fname,
                                                      ref_anat_fname,
                                                      tracts_attribs)

    # Load all streamlines, since streamlines is a generator.
    full_strl = [s for s in streamlines_gen]

    # Extract VCs and VBs
    VC_indices, found_vbs_info = auto_extract_VCs(full_strl, ref_bundles)
    VC = len(VC_indices)

    if save_VBs or save_full_vc:
        save_valid_connections(found_vbs_info, full_strl, segmented_out_dir,
                               segmented_base_name, ref_anat_fname,
                               save_vbs=save_VBs, save_full_vc=save_full_vc)

    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(full_strl)
    candidate_ic_strl_indices = sorted(set(range(total_strl_count)) - VC_indices)

    candidate_ic_streamlines = []
    rejected_streamlines = []

    # Chosen from GT dataset
    length_thres = 35.

    # Filter streamlines that are too short, consider them as NC
    for idx in candidate_ic_strl_indices:
        if slength(full_strl[idx]) >= length_thres:
            candidate_ic_streamlines.append(full_strl[idx].astype('f4'))
        else:
            rejected_streamlines.append(full_strl[idx].astype('f4'))

    logging.debug('Found {} candidate IC'.format(len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(len(rejected_streamlines)))

    ic_counts = 0
    nb_ib = 0

    if len(candidate_ic_streamlines):
        additional_rejected, ic_counts, nb_ib = \
            group_and_assign_ibs(candidate_ic_streamlines,
                                 ROIs, save_IBs, save_full_ic,
                                 segmented_out_dir,
                                 segmented_base_name,
                                 ref_anat_fname)

        rejected_streamlines.extend(additional_rejected)

    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    if len(rejected_streamlines) > 0 and save_full_nc:
        out_nc_fname = os.path.join(segmented_out_dir,
                                    '{}_NC.tck'.format(segmented_base_name))
        out_file = TCK.create(out_nc_fname)
        save_tracts_tck_from_dipy_voxel_space(out_file, ref_anat_fname,
                                              rejected_streamlines)

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) - len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    nb_VB_found = sum(1 for v in found_vbs_info.values() if v['nb_streamlines'] > 0)
    streamlines_per_bundle = {k: v['nb_streamlines'] for k, v in found_vbs_info.items() if v['nb_streamlines'] > 0}

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = nb_ib
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count

    # Get bundle overlap, overreach and f1-score for each bundle.
    scores['overlap_per_bundle'] = {k: v["overlap"] for k, v in found_vbs_info.items()}
    scores['overreach_per_bundle'] = {k: v["overreach"] for k, v in found_vbs_info.items()}
    scores['overreach_norm_gt_per_bundle'] = {k: v["overreach_norm"] for k, v in found_vbs_info.items()}
    scores['f1_score_per_bundle'] = {k: v["f1_score"] for k, v in found_vbs_info.items()}

    # Compute average bundle overlap, overreach and f1-score.
    scores['mean_OL'] = np.mean(list(scores['overlap_per_bundle'].values()))
    scores['mean_OR'] = np.mean(list(scores['overreach_per_bundle'].values()))
    scores['mean_ORn'] = np.mean(list(scores['overreach_norm_gt_per_bundle'].values()))
    scores['mean_F1'] = np.mean(list(scores['f1_score_per_bundle'].values()))

    return scores
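

# Sketch of saving the NC streamlines with nibabel's streamlines API instead
# of the legacy TCK helper used above, mirroring Example 1. The identity
# affine_to_rasmm is an assumption: the real code converts from voxel space
# to the reference anatomy first.
import numpy as np
import nibabel as nib


def _demo_save_nc(rejected_streamlines, out_nc_fname='segmented_NC.tck'):
    tractogram = nib.streamlines.Tractogram(rejected_streamlines,
                                            affine_to_rasmm=np.eye(4))
    nib.streamlines.save(tractogram, out_nc_fname)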

# Example 6

    def display(
        self,
        valid_tractogram: Tractogram,
        env: BaseEnv,
        valid_reward: float = 0,
        i_episode: int = 0,
        run_tractometer: bool = False,
    ):
        """
        Stats stuff

        There's so much going on in this function, it should be split or
        something

        Parameters
        ----------
        valid_tractogram: Tractogram
            Tractogram containing all the streamlines tracked during the last
            validation run
        env: BaseEnv
            Environment used to render streamlines
        valid_reward: np.ndarray of float of size
            Reward of the last validation run
        i_episode: int
            Current episode
        """

        lens = [slength(s) for s in valid_tractogram.streamlines]

        avg_length = np.mean(lens)  # Euclidean length

        print('---------------------------------------------------')
        print(self.experiment_path)
        print('Episode {} \t avg length: {} \t total reward: {}'.format(
            i_episode,
            avg_length,
            valid_reward))
        print('---------------------------------------------------')

        # Save tractogram so it can be looked at, used by the tractometer
        # and more
        filename = pjoin(self.experiment_path,
                         "tractogram_{}_{}_{}.trk".format(
                             self.experiment, self.name, self.test_subject_id))

        self._save_tractogram(
            valid_tractogram,
            self.reference_file,
            Space.VOX,
            filename)

        if self.comet_experiment is not None:
            if self.run_tractometer and run_tractometer:
                #  Load bundle attributes for tractometer
                # TODO: No need to load this every time, should only be loaded
                # once
                gt_bundles_attribs_path = pjoin(
                    self.ground_truth_folder, 'gt_bundles_attributes.json')
                basic_bundles_attribs = load_attribs(gt_bundles_attribs_path)

                # Score tractogram
                scores = score_submission(
                    filename,
                    {'orientation': 'unknown'},
                    self.ground_truth_folder,
                    basic_bundles_attribs)

                self.vc_monitor.update(scores['VC'])
                self.ic_monitor.update(scores['IC'])
                self.nc_monitor.update(scores['NC'])
                self.vb_monitor.update(scores['VB'])
                self.ib_monitor.update(scores['IB'])
                self.ol_monitor.update(scores['mean_OL'])

                self.vc_monitor.end_epoch(i_episode)
                self.ic_monitor.end_epoch(i_episode)
                self.nc_monitor.end_epoch(i_episode)
                self.vb_monitor.end_epoch(i_episode)
                self.ib_monitor.end_epoch(i_episode)
                self.ol_monitor.end_epoch(i_episode)

            if self.render:
                # Save image of tractogram to be displayed in comet
                env.render(
                    valid_tractogram,
                    '{}.png'.format(i_episode))

            # Update monitors
            self.len_monitor.update(avg_length)
            self.len_monitor.end_epoch(i_episode)

            self.reward_monitor.update(valid_reward)
            self.reward_monitor.end_epoch(i_episode)

            # Update comet
            self.comet_monitor.update(
                self.reward_monitor,
                self.len_monitor,
                self.vc_monitor,
                self.ic_monitor,
                self.nc_monitor,
                self.vb_monitor,
                self.ib_monitor,
                self.ol_monitor,
                i_episode=i_episode)
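

# Minimal sketch of the monitor interface display() relies on. This is an
# assumption inferred from usage only (update() per value, end_epoch() per
# episode); the project's real monitors likely do more, e.g. log to comet.
class _ToyMonitor:
    def __init__(self, name):
        self.name = name
        self._values = []
        self.epochs = {}

    def update(self, value):
        self._values.append(float(value))

    def end_epoch(self, epoch):
        # Average everything recorded since the last epoch boundary
        if self._values:
            self.epochs[epoch] = sum(self._values) / len(self._values)
            self._values = []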