def score_auto_extract_auto_IBs(streamlines, bundles_masks, ref_bundles, ROIs,
                                wm, save_segmented=False, save_IBs=False,
                                save_VBs=False, save_VCWPs=False,
                                out_segmented_strl_dir='',
                                base_out_segmented_strl='',
                                ref_anat_fname=''):
    """
    Score a set of streamlines as valid connections (VC), invalid
    connections (IC) and no connections (NC).

    Algorithm:
        1: extract the VC with ``_auto_extract_VCs``.
        2: among the remaining streamlines, those with a length below 35
           are classified as NC.
        3: the other candidates are shuffled with a fixed seed and clustered
           with QuickBundlesX (thresholds 30, 25, 20, 15); the clusters of
           the last level are used.
        4: each cluster is assigned to the ROI pair closest to its centroid.
           All clusters sharing a ROI pair form one invalid bundle (IB).
           Singleton clusters are currently kept as IC.

    Parameters
    ------------
    streamlines : sequence
        sequence of T streamlines. One streamline is an ndarray of shape
        (N, 3), where N is the number of points in that streamline, and
        ``streamlines[t][n]`` is the n-th point in the t-th streamline.
        Points are of form x, y, z in *voxel* coordinates. Must support
        ``len()`` and integer indexing.
    bundles_masks : sequence
        list of nibabel objects corresponding to mask of bundles.
        Not used by this function.
    ref_bundles : object
        reference bundles description, forwarded to ``_auto_extract_VCs``.
    ROIs : sequence
        list of nibabel objects corresponding to mask of ROIs
    wm : nibabel object
        mask of the white matter. Not used by this function.
    save_segmented : bool
        if true, the NC streamlines are saved to a .tck file, and the IBs
        are saved as well when ``save_IBs`` is also true.
    save_IBs : bool
        if true (together with ``save_segmented``), saves one .tck file per
        invalid bundle.
    save_VBs : bool
        if true, the valid bundles are saved with ``_save_extracted_VBs``.
    save_VCWPs : bool
        Not used by this function.
    out_segmented_strl_dir : string
        directory where segmented files are saved.
    base_out_segmented_strl : string
        base name of the saved segmented files.
    ref_anat_fname : string
        forwarded to ``_save_extracted_VBs``.

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric: 'version',
        'algo_version', 'VC', 'IC', 'VCWP', 'NC', 'VB', 'IB',
        'streamlines_per_bundle' and 'total_streamlines_count'.

    Raises
    ---------
    ValueError
        if the number of clustered streamlines does not match the number
        of IC candidates.
    ZeroDivisionError
        if ``streamlines`` is empty (ratios are computed against the total
        streamline count).
    """
    VC_indices, found_vbs_info = _auto_extract_VCs(streamlines, ref_bundles)
    VC = len(VC_indices)
    logging.debug('Found {} candidate VC'.format(VC))

    if save_VBs:
        _save_extracted_VBs(found_vbs_info, streamlines,
                            out_segmented_strl_dir, base_out_segmented_strl,
                            ref_anat_fname)

    # TODO might be readded
    # To keep track of streamlines that have been classified
    # classified_streamlines_indices = VC_indices

    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(streamlines)
    candidate_ic_strl_indices = sorted(
        set(range(total_strl_count)) - VC_indices)

    length_thres = 35.

    candidate_ic_streamlines = []
    rejected_streamlines = []

    # Streamlines that are too short are considered as NC.
    for idx in candidate_ic_strl_indices:
        if slength(streamlines[idx]) >= length_thres:
            candidate_ic_streamlines.append(streamlines[idx].astype('f4'))
        else:
            rejected_streamlines.append(streamlines[idx].astype('f4'))

    logging.debug('Found {} candidate IC'.format(
        len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(
        len(rejected_streamlines)))

    ic_counts = 0
    # Maps a ROI pair to the indices (in the shuffled candidate list) of
    # the streamlines assigned to it.
    ib_pairs = {}

    if candidate_ic_streamlines:
        # Fix seed to always generate the same output
        # Shuffle to try to reduce the ordering dependency for QB
        random.seed(0.2)
        random.shuffle(candidate_ic_streamlines)

        # TODO threshold on distance as arg
        qb = QuickBundlesX([30, 25, 20, 15])
        clusters_obj = qb.cluster(candidate_ic_streamlines)
        # Retrieves clusters obtained with the smallest threshold.
        clusters = clusters_obj.get_clusters(-1)

        logging.debug("Found {} potential IB clusters".format(len(clusters)))

        # TODO this should be better handled
        # NOTE(review): get_data() is deprecated in recent nibabel
        # releases -- confirm the installed version still provides it.
        rois_info = [
            (get_root_image_name(os.path.basename(roi.get_filename())),
             np.array(np.where(roi.get_data())).T)
            for roi in ROIs
        ]

        centroids = nib.streamlines.Tractogram(clusters.centroids)
        centroids.apply_affine(np.linalg.inv(ROIs[0].affine))
        all_centroids_closest_pairs = \
            get_closest_roi_pairs_for_all_streamlines(centroids.streamlines,
                                                      rois_info)

        for c_idx, c in enumerate(clusters):
            closest_for_cluster = all_centroids_closest_pairs[c_idx]
            ic_counts += len(c)
            ib_pairs.setdefault(closest_for_cluster, []).extend(c.indices)

        if save_segmented and save_IBs:
            # items() works on both Python 2 and 3; iteritems() does not
            # exist on Python 3 dictionaries.
            for k, v in ib_pairs.items():
                # Index the list directly: streamlines have different
                # numbers of points, so they cannot be safely packed in a
                # single ndarray for fancy indexing.
                out_strl = [candidate_ic_streamlines[i] for i in v]
                out_fname = os.path.join(
                    out_segmented_strl_dir,
                    base_out_segmented_strl +
                    '_IB_{0}_{1}.tck'.format(k[0], k[1]))

                nib.streamlines.save(
                    nib.streamlines.Tractogram(out_strl,
                                               affine_to_rasmm=np.eye(4)),
                    out_fname)

    if len(rejected_streamlines) > 0 and save_segmented:
        out_nc_fname = os.path.join(
            out_segmented_strl_dir,
            '{}_NC.tck'.format(base_out_segmented_strl))
        nib.streamlines.save(
            nib.streamlines.Tractogram(rejected_streamlines,
                                       affine_to_rasmm=np.eye(4)),
            out_nc_fname)

    # TODO readd classifed_steamlines_indices to validate
    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) -
          len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    # TODO could have sanity check on global extracted streamlines vs all
    # possible indices

    # Only bundles with at least one streamline count as found.
    streamlines_per_bundle = {
        k: v['nb_streamlines']
        for k, v in found_vbs_info.items()
        if v['nb_streamlines'] > 0
    }
    nb_VB_found = len(streamlines_per_bundle)

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = len(ib_pairs)
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count

    return scores
def score_submission(streamlines_fname, tracts_attribs, base_data_dir,
                     basic_bundles_attribs, save_full_vc=False,
                     save_full_ic=False, save_full_nc=False, save_IBs=False,
                     save_VBs=False, segmented_out_dir='',
                     segmented_base_name='', verbose=False):
    """
    Score a submission, using the following algorithm:
        1: extract all streamlines that are valid, which are classified as
           Valid Connections (VC) making up Valid Bundles (VB).
        2: remove streamlines shorter than an threshold based on the GT
           dataset
        3: cluster the remaining streamlines
        4: remove singletons
        5: assign each cluster to the closest ROIs pair. Those make up the
           Invalid Connections (IC), grouped as Invalid Bundles (IB).
        6: streamlines that are neither in VC nor IC are classified as
           No Connection (NC).

    Steps 3 to 5 are delegated to ``group_and_assign_ibs``.

    Parameters
    ------------
    streamlines_fname : string
        path to the file containing the streamlines.
    tracts_attribs : dictionary
        contains the attributes of the submission. Must contain the
        'orientation' attribute for .vtk files.
    base_data_dir : string
        path to the direction containing the scoring data.
    basic_bundles_attribs : dictionary
        contains the attributes of the basic bundles
        (name, list of streamlines, segmentation threshold)
    save_full_vc : bool
        indicates if the full set of VC will be saved in an individual file.
    save_full_ic : bool
        indicates if the full set of IC will be saved in an individual file.
    save_full_nc : bool
        indicates if the full set of NC will be saved in an individual file.
    save_IBs : bool
        indicates if the invalid bundles will be saved in individual file
        for each IB.
    save_VBs : bool
        indicates if the valid bundles will be saved in individual file for
        each VB.
    segmented_out_dir : string
        the path to the directory where segmented files will be saved.
    segmented_base_name : string
        the base name to use for saving segmented files.
    verbose : bool
        indicates if the algorithm needs to be verbose when logging messages.

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric, the per-bundle
        overlap / overreach / f1-score and their means, plus the
        'ami_rejected_indices' (flat list of indices of the NC streamlines
        in the loaded tractogram), 'ami_bundles_found' and 'ami_VC_indices'
        entries.

    Raises
    ---------
    ValueError
        if the number of IC does not match the number of candidates left
        once the rejected streamlines are removed.
    ZeroDivisionError
        if the tractogram contains no streamlines (ratios are computed
        against the total streamline count).
    """
    if verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Prepare needed scoring data
    logging.debug('Preparing GT data')
    masks_dir = os.path.join(base_data_dir, "masks")
    rois_dir = os.path.join(masks_dir, "rois")
    bundles_dir = os.path.join(base_data_dir, "bundles")
    bundles_masks_dir = os.path.join(masks_dir, "bundles")
    ref_anat_fname = os.path.join(masks_dir, "wm.nii.gz")

    ROIs = [
        nib.load(os.path.join(rois_dir, f))
        for f in sorted(os.listdir(rois_dir))
    ]

    ref_bundles = _prepare_gt_bundles_info(bundles_dir, bundles_masks_dir,
                                           basic_bundles_attribs,
                                           ref_anat_fname)

    streamlines_gen = get_tracts_voxel_space_for_dipy(streamlines_fname,
                                                      ref_anat_fname,
                                                      tracts_attribs)

    # Load all streamlines, since streamlines is a generator.
    full_strl = list(streamlines_gen)

    # Extract VCs and VBs
    VC_indices, found_vbs_info, bundles_found = auto_extract_VCs(
        full_strl, ref_bundles)
    VC = len(VC_indices)

    if save_VBs or save_full_vc:
        save_valid_connections(found_vbs_info, full_strl, segmented_out_dir,
                               segmented_base_name, ref_anat_fname,
                               save_vbs=save_VBs, save_full_vc=save_full_vc)

    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(full_strl)
    candidate_ic_strl_indices = sorted(
        set(range(total_strl_count)) - VC_indices)

    candidate_ic_streamlines = []
    rejected_streamlines = []
    # Indices, in full_strl, of the candidate / rejected streamlines.
    rejected_idx = []
    candidate_idx = []

    # Chosen from GT dataset
    length_thres = 35.

    # Filter streamlines that are too short, consider them as NC
    for idx in candidate_ic_strl_indices:
        if slength(full_strl[idx]) >= length_thres:
            candidate_ic_streamlines.append(full_strl[idx].astype('f4'))
            candidate_idx.append(idx)
        else:
            rejected_streamlines.append(full_strl[idx].astype('f4'))
            rejected_idx.append(idx)

    logging.debug('Found {} candidate IC'.format(
        len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(
        len(rejected_streamlines)))

    ic_counts = 0
    nb_ib = 0

    if candidate_ic_streamlines:
        additional_rejected, ic_counts, nb_ib, rejected_c_idx = \
            group_and_assign_ibs(candidate_ic_streamlines, ROIs, save_IBs,
                                 save_full_ic, segmented_out_dir,
                                 segmented_base_name, ref_anat_fname)

        rejected_streamlines.extend(additional_rejected)
        # rejected_c_idx indexes the candidate list: map it back to indices
        # in full_strl. extend() keeps rejected_idx a flat list of indices,
        # in line with rejected_streamlines.
        rejected_idx.extend(candidate_idx[c_idx] for c_idx in rejected_c_idx)

    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    if len(rejected_streamlines) > 0 and save_full_nc:
        out_nc_fname = os.path.join(segmented_out_dir,
                                    '{}_NC.tck'.format(segmented_base_name))
        out_file = TCK.create(out_nc_fname)
        save_tracts_tck_from_dipy_voxel_space(out_file, ref_anat_fname,
                                              rejected_streamlines)

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) -
          len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    # Only bundles with at least one streamline count as found.
    # items() is used everywhere: iteritems() does not exist on Python 3
    # dictionaries.
    streamlines_per_bundle = {
        k: v['nb_streamlines']
        for k, v in found_vbs_info.items()
        if v['nb_streamlines'] > 0
    }
    nb_VB_found = len(streamlines_per_bundle)

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = nb_ib
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count
    scores['ami_rejected_indices'] = rejected_idx
    scores['ami_bundles_found'] = bundles_found
    scores['ami_VC_indices'] = list(VC_indices)

    # Get bundle overlap, overreach and f1-score for each bundle.
    scores['overlap_per_bundle'] = {
        k: v["overlap"] for k, v in found_vbs_info.items()
    }
    scores['overreach_per_bundle'] = {
        k: v["overreach"] for k, v in found_vbs_info.items()
    }
    scores['overreach_norm_gt_per_bundle'] = {
        k: v["overreach_norm"] for k, v in found_vbs_info.items()
    }
    scores['f1_score_per_bundle'] = {
        k: v["f1_score"] for k, v in found_vbs_info.items()
    }

    # Compute average bundle overlap, overreach and f1-score.
    scores['mean_OL'] = np.mean(list(scores['overlap_per_bundle'].values()))
    scores['mean_OR'] = np.mean(list(scores['overreach_per_bundle'].values()))
    scores['mean_ORn'] = np.mean(
        list(scores['overreach_norm_gt_per_bundle'].values()))
    scores['mean_F1'] = np.mean(list(scores['f1_score_per_bundle'].values()))

    return scores
def clean_tractogram(self, tractogram, affine_vox2mask):
    """
    Drop streamlines that look like "non-connections", based on their
    length and on how much they loop.

    A streamline is discarded when its length is at most
    ``self.min_length``, when its length is above 200, or when its winding
    is above 330. The curvature and target-mask filters are currently
    disabled.

    Parameters:
    -----------
    tractogram: Tractogram
        Full tractogram
    affine_vox2mask:
        Not used while the mask filter is disabled.

    Returns:
    --------
    tractogram: Tractogram
        Filtered tractogram, built from the kept streamlines and their
        ``data_per_streamline`` entries.
    """
    print('Cleaning tractogram ... ', end='', flush=True)

    tracks = tractogram.streamlines

    # Length criteria.
    # Example case: Bundle 3 of fibercup tends to be shorter than 35
    track_lengths = np.asarray([slength(t) for t in tracks])
    too_short = track_lengths <= self.min_length
    too_long = track_lengths > 200.

    # Loop criterion.
    # For example: A streamline ending and starting in the same roi
    windings = np.array([winding(t) for t in tracks])
    is_looping = windings > 330

    # A streamline is rejected as soon as one criterion flags it.
    rejected = np.logical_or(np.logical_or(too_short, too_long), is_looping)

    # Only keep valid streamlines
    kept = np.nonzero(np.logical_not(rejected))
    kept_streamlines = tracks[kept]
    kept_dps = tractogram.data_per_streamline[kept]

    print('Done !')
    print('Kept {}/{} streamlines'.format(len(kept[0]), len(tracks)))

    return Tractogram(kept_streamlines, data_per_streamline=kept_dps)
def score_submission(streamlines_fname, tracts_attribs, base_data_dir,
                     basic_bundles_attribs, save_full_vc=False,
                     save_full_ic=False, save_full_nc=False, save_IBs=False,
                     save_VBs=False, segmented_out_dir='',
                     segmented_base_name='', verbose=False):
    """
    Score a submission, using the following algorithm:
        1: extract all streamlines that are valid, which are classified as
           Valid Connections (VC) making up Valid Bundles (VB).
        2: remove streamlines shorter than an threshold based on the GT
           dataset
        3: cluster the remaining streamlines
        4: remove singletons
        5: assign each cluster to the closest ROIs pair. Those make up the
           Invalid Connections (IC), grouped as Invalid Bundles (IB).
        6: streamlines that are neither in VC nor IC are classified as
           No Connection (NC).

    Steps 3 to 5 are delegated to ``group_and_assign_ibs``.

    Parameters
    ------------
    streamlines_fname : string
        path to the file containing the streamlines.
    tracts_attribs : dictionary
        contains the attributes of the submission. Must contain the
        'orientation' attribute for .vtk files.
    base_data_dir : string
        path to the direction containing the scoring data.
    basic_bundles_attribs : dictionary
        contains the attributes of the basic bundles
        (name, list of streamlines, segmentation threshold)
    save_full_vc : bool
        indicates if the full set of VC will be saved in an individual file.
    save_full_ic : bool
        indicates if the full set of IC will be saved in an individual file.
    save_full_nc : bool
        indicates if the full set of NC will be saved in an individual file.
    save_IBs : bool
        indicates if the invalid bundles will be saved in individual file
        for each IB.
    save_VBs : bool
        indicates if the valid bundles will be saved in individual file for
        each VB.
    segmented_out_dir : string
        the path to the directory where segmented files will be saved.
    segmented_base_name : string
        the base name to use for saving segmented files.
    verbose : bool
        indicates if the algorithm needs to be verbose when logging messages.

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric, the per-bundle
        overlap / overreach / f1-score and their means.

    Raises
    ---------
    ValueError
        if the number of IC does not match the number of candidates left
        once the rejected streamlines are removed.
    ZeroDivisionError
        if the tractogram contains no streamlines (ratios are computed
        against the total streamline count).
    """
    if verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Prepare needed scoring data
    logging.debug('Preparing GT data')
    masks_dir = os.path.join(base_data_dir, "masks")
    rois_dir = os.path.join(masks_dir, "rois")
    bundles_dir = os.path.join(base_data_dir, "bundles")
    bundles_masks_dir = os.path.join(masks_dir, "bundles")
    ref_anat_fname = os.path.join(masks_dir, "wm.nii.gz")

    ROIs = [nib.load(os.path.join(rois_dir, f))
            for f in sorted(os.listdir(rois_dir))]

    ref_bundles = _prepare_gt_bundles_info(bundles_dir, bundles_masks_dir,
                                           basic_bundles_attribs,
                                           ref_anat_fname)

    streamlines_gen = get_tracts_voxel_space_for_dipy(streamlines_fname,
                                                      ref_anat_fname,
                                                      tracts_attribs)

    # Load all streamlines, since streamlines is a generator.
    full_strl = list(streamlines_gen)

    # Extract VCs and VBs
    VC_indices, found_vbs_info = auto_extract_VCs(full_strl, ref_bundles)
    VC = len(VC_indices)

    if save_VBs or save_full_vc:
        save_valid_connections(found_vbs_info, full_strl, segmented_out_dir,
                               segmented_base_name, ref_anat_fname,
                               save_vbs=save_VBs, save_full_vc=save_full_vc)

    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(full_strl)
    candidate_ic_strl_indices = sorted(
        set(range(total_strl_count)) - VC_indices)

    candidate_ic_streamlines = []
    rejected_streamlines = []

    # Chosen from GT dataset
    length_thres = 35.

    # Filter streamlines that are too short, consider them as NC
    for idx in candidate_ic_strl_indices:
        if slength(full_strl[idx]) >= length_thres:
            candidate_ic_streamlines.append(full_strl[idx].astype('f4'))
        else:
            rejected_streamlines.append(full_strl[idx].astype('f4'))

    logging.debug('Found {} candidate IC'.format(
        len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(
        len(rejected_streamlines)))

    ic_counts = 0
    nb_ib = 0

    if candidate_ic_streamlines:
        additional_rejected, ic_counts, nb_ib = group_and_assign_ibs(
            candidate_ic_streamlines, ROIs, save_IBs, save_full_ic,
            segmented_out_dir, segmented_base_name, ref_anat_fname)

        rejected_streamlines.extend(additional_rejected)

    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    if len(rejected_streamlines) > 0 and save_full_nc:
        out_nc_fname = os.path.join(segmented_out_dir,
                                    '{}_NC.tck'.format(segmented_base_name))
        out_file = TCK.create(out_nc_fname)
        save_tracts_tck_from_dipy_voxel_space(out_file, ref_anat_fname,
                                              rejected_streamlines)

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) -
          len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    # Only bundles with at least one streamline count as found.
    # items() is used everywhere: iteritems() does not exist on Python 3
    # dictionaries.
    streamlines_per_bundle = {
        k: v['nb_streamlines']
        for k, v in found_vbs_info.items()
        if v['nb_streamlines'] > 0
    }
    nb_VB_found = len(streamlines_per_bundle)

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = nb_ib
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count

    # Get bundle overlap, overreach and f1-score for each bundle.
    scores['overlap_per_bundle'] = {
        k: v["overlap"] for k, v in found_vbs_info.items()
    }
    scores['overreach_per_bundle'] = {
        k: v["overreach"] for k, v in found_vbs_info.items()
    }
    scores['overreach_norm_gt_per_bundle'] = {
        k: v["overreach_norm"] for k, v in found_vbs_info.items()
    }
    scores['f1_score_per_bundle'] = {
        k: v["f1_score"] for k, v in found_vbs_info.items()
    }

    # Compute average bundle overlap, overreach and f1-score.
    scores['mean_OL'] = np.mean(list(scores['overlap_per_bundle'].values()))
    scores['mean_OR'] = np.mean(list(scores['overreach_per_bundle'].values()))
    scores['mean_ORn'] = np.mean(
        list(scores['overreach_norm_gt_per_bundle'].values()))
    scores['mean_F1'] = np.mean(list(scores['f1_score_per_bundle'].values()))

    return scores
def score_auto_extract_auto_IBs(streamlines, bundles_masks, ref_bundles, ROIs,
                                wm, save_segmented=False, save_IBs=False,
                                save_VBs=False, save_VCWPs=False,
                                out_segmented_strl_dir='',
                                base_out_segmented_strl='',
                                ref_anat_fname=''):
    """
    Score a set of streamlines as valid connections (VC), invalid
    connections (IC) and no connections (NC).

    Algorithm:
        1: extract the VC with ``_auto_extract_VCs``.
        2: among the remaining streamlines, those with a length below 35
           are classified as NC.
        3: the other candidates are shuffled with a fixed seed and clustered
           with QuickBundlesX (thresholds 30, 25, 20, 15); the clusters of
           the last level are used.
        4: each cluster is assigned to the ROI pair closest to its centroid.
           All clusters sharing a ROI pair form one invalid bundle (IB).
           Singleton clusters are currently kept as IC.

    Parameters
    ------------
    streamlines : sequence
        sequence of T streamlines. One streamline is an ndarray of shape
        (N, 3), where N is the number of points in that streamline, and
        ``streamlines[t][n]`` is the n-th point in the t-th streamline.
        Points are of form x, y, z in *voxel* coordinates. Must support
        ``len()`` and integer indexing.
    bundles_masks : sequence
        list of nibabel objects corresponding to mask of bundles.
        Not used by this function.
    ref_bundles : object
        reference bundles description, forwarded to ``_auto_extract_VCs``.
    ROIs : sequence
        list of nibabel objects corresponding to mask of ROIs
    wm : nibabel object
        mask of the white matter. Not used by this function.
    save_segmented : bool
        if true, the NC streamlines are saved to a .tck file, and the IBs
        are saved as well when ``save_IBs`` is also true.
    save_IBs : bool
        if true (together with ``save_segmented``), saves one .tck file per
        invalid bundle.
    save_VBs : bool
        if true, the valid bundles are saved with ``_save_extracted_VBs``.
    save_VCWPs : bool
        Not used by this function.
    out_segmented_strl_dir : string
        directory where segmented files are saved.
    base_out_segmented_strl : string
        base name of the saved segmented files.
    ref_anat_fname : string
        forwarded to ``_save_extracted_VBs``.

    Returns
    ---------
    scores : dict
        dictionnary containing a score for each metric: 'version',
        'algo_version', 'VC', 'IC', 'VCWP', 'NC', 'VB', 'IB',
        'streamlines_per_bundle' and 'total_streamlines_count'.

    Raises
    ---------
    ValueError
        if the number of clustered streamlines does not match the number
        of IC candidates.
    ZeroDivisionError
        if ``streamlines`` is empty (ratios are computed against the total
        streamline count).
    """
    VC_indices, found_vbs_info = _auto_extract_VCs(streamlines, ref_bundles)
    VC = len(VC_indices)
    logging.debug('Found {} candidate VC'.format(VC))

    if save_VBs:
        _save_extracted_VBs(found_vbs_info, streamlines,
                            out_segmented_strl_dir, base_out_segmented_strl,
                            ref_anat_fname)

    # TODO might be readded
    # To keep track of streamlines that have been classified
    # classified_streamlines_indices = VC_indices

    logging.debug("Starting IC, IB scoring")

    total_strl_count = len(streamlines)
    candidate_ic_strl_indices = sorted(
        set(range(total_strl_count)) - VC_indices)

    length_thres = 35.

    candidate_ic_streamlines = []
    rejected_streamlines = []

    # Streamlines that are too short are considered as NC.
    for idx in candidate_ic_strl_indices:
        if slength(streamlines[idx]) >= length_thres:
            candidate_ic_streamlines.append(streamlines[idx].astype('f4'))
        else:
            rejected_streamlines.append(streamlines[idx].astype('f4'))

    logging.debug('Found {} candidate IC'.format(
        len(candidate_ic_streamlines)))
    logging.debug('Found {} streamlines that were too short'.format(
        len(rejected_streamlines)))

    ic_counts = 0
    # Maps a ROI pair to the indices (in the shuffled candidate list) of
    # the streamlines assigned to it.
    ib_pairs = {}

    if candidate_ic_streamlines:
        # Fix seed to always generate the same output
        # Shuffle to try to reduce the ordering dependency for QB
        random.seed(0.2)
        random.shuffle(candidate_ic_streamlines)

        # TODO threshold on distance as arg
        qb = QuickBundlesX([30, 25, 20, 15])
        clusters_obj = qb.cluster(candidate_ic_streamlines)
        # Retrieves clusters obtained with the smallest threshold.
        clusters = clusters_obj.get_clusters(-1)

        logging.debug("Found {} potential IB clusters".format(len(clusters)))

        # TODO this should be better handled
        # NOTE(review): get_data() is deprecated in recent nibabel
        # releases -- confirm the installed version still provides it.
        rois_info = [
            (get_root_image_name(os.path.basename(roi.get_filename())),
             np.array(np.where(roi.get_data())).T)
            for roi in ROIs
        ]

        centroids = nib.streamlines.Tractogram(clusters.centroids)
        centroids.apply_affine(np.linalg.inv(ROIs[0].affine))
        all_centroids_closest_pairs = \
            get_closest_roi_pairs_for_all_streamlines(centroids.streamlines,
                                                      rois_info)

        for c_idx, c in enumerate(clusters):
            closest_for_cluster = all_centroids_closest_pairs[c_idx]
            ic_counts += len(c)
            ib_pairs.setdefault(closest_for_cluster, []).extend(c.indices)

        if save_segmented and save_IBs:
            # items() works on both Python 2 and 3; iteritems() does not
            # exist on Python 3 dictionaries.
            for k, v in ib_pairs.items():
                # Index the list directly: streamlines have different
                # numbers of points, so they cannot be safely packed in a
                # single ndarray for fancy indexing.
                out_strl = [candidate_ic_streamlines[i] for i in v]
                out_fname = os.path.join(
                    out_segmented_strl_dir,
                    base_out_segmented_strl +
                    '_IB_{0}_{1}.tck'.format(k[0], k[1]))

                nib.streamlines.save(
                    nib.streamlines.Tractogram(out_strl,
                                               affine_to_rasmm=np.eye(4)),
                    out_fname)

    if len(rejected_streamlines) > 0 and save_segmented:
        out_nc_fname = os.path.join(
            out_segmented_strl_dir,
            '{}_NC.tck'.format(base_out_segmented_strl))
        nib.streamlines.save(
            nib.streamlines.Tractogram(rejected_streamlines,
                                       affine_to_rasmm=np.eye(4)),
            out_nc_fname)

    # TODO readd classifed_steamlines_indices to validate
    if ic_counts != len(candidate_ic_strl_indices) - len(rejected_streamlines):
        raise ValueError("Some streamlines were not correctly assigned to NC")

    VC /= total_strl_count
    IC = (len(candidate_ic_strl_indices) -
          len(rejected_streamlines)) / total_strl_count
    NC = len(rejected_streamlines) / total_strl_count
    VCWP = 0

    # TODO could have sanity check on global extracted streamlines vs all
    # possible indices

    # Only bundles with at least one streamline count as found.
    streamlines_per_bundle = {
        k: v['nb_streamlines']
        for k, v in found_vbs_info.items()
        if v['nb_streamlines'] > 0
    }
    nb_VB_found = len(streamlines_per_bundle)

    scores = {}
    scores['version'] = 2
    scores['algo_version'] = 5
    scores['VC'] = VC
    scores['IC'] = IC
    scores['VCWP'] = VCWP
    scores['NC'] = NC
    scores['VB'] = nb_VB_found
    scores['IB'] = len(ib_pairs)
    scores['streamlines_per_bundle'] = streamlines_per_bundle
    scores['total_streamlines_count'] = total_strl_count

    return scores
def display(
    self,
    valid_tractogram: Tractogram,
    env: BaseEnv,
    valid_reward: float = 0,
    i_episode: int = 0,
    run_tractometer: bool = False,
):
    """
    Report the outcome of a validation run.

    Prints a summary, saves the tractogram to the experiment folder and,
    when a comet experiment is set, optionally scores the tractogram,
    optionally renders it and updates the monitors.

    Parameters
    ----------
    valid_tractogram: Tractogram
        Tractogram containing all the streamlines tracked during
        the last validation run
    env: BaseEnv
        Environment used to render streamlines
    valid_reward: float
        Reward of the last validation run
    i_episode: int
        Current episode
    run_tractometer: bool
        Score the saved tractogram with `score_submission`; only
        effective when `self.run_tractometer` is also set
    """
    # Euclidian length
    mean_length = np.mean(
        [slength(strl) for strl in valid_tractogram.streamlines])

    separator = '---------------------------------------------------'
    print(separator)
    print(self.experiment_path)
    print('Episode {} \t avg length: {} \t total reward: {}'.format(
        i_episode, mean_length, valid_reward))
    print(separator)

    # Save tractogram so it can be looked at, used by the tractometer
    # and more
    tractogram_name = "tractogram_{}_{}_{}.trk".format(
        self.experiment, self.name, self.test_subject_id)
    filename = pjoin(self.experiment_path, tractogram_name)
    self._save_tractogram(
        valid_tractogram, self.reference_file, Space.VOX, filename)

    # Everything below only matters when logging to comet.
    if self.comet_experiment is None:
        return

    if self.run_tractometer and run_tractometer:
        # Load bundle attributes for tractometer
        # TODO: No need to load this every time, should only be loaded
        # once
        basic_bundles_attribs = load_attribs(
            pjoin(self.ground_truth_folder, 'gt_bundles_attributes.json'))

        # Score tractogram
        scores = score_submission(
            filename,
            {'orientation': 'unknown'},
            self.ground_truth_folder,
            basic_bundles_attribs)

        # Each tractometer monitor with the score it tracks.
        scored_monitors = (
            (self.vc_monitor, 'VC'),
            (self.ic_monitor, 'IC'),
            (self.nc_monitor, 'NC'),
            (self.vb_monitor, 'VB'),
            (self.ib_monitor, 'IB'),
            (self.ol_monitor, 'mean_OL'),
        )
        for monitor, score_name in scored_monitors:
            monitor.update(scores[score_name])
        for monitor, _ in scored_monitors:
            monitor.end_epoch(i_episode)

    if self.render:
        # Save image of tractogram to be displayed in comet
        env.render(valid_tractogram, '{}.png'.format(i_episode))

    # Update monitors
    self.len_monitor.update(mean_length)
    self.len_monitor.end_epoch(i_episode)

    self.reward_monitor.update(valid_reward)
    self.reward_monitor.end_epoch(i_episode)

    # Update comet
    self.comet_monitor.update(
        self.reward_monitor,
        self.len_monitor,
        self.vc_monitor,
        self.ic_monitor,
        self.nc_monitor,
        self.vb_monitor,
        self.ib_monitor,
        self.ol_monitor,
        i_episode=i_episode)