Example #1
def _load_streamlines_from_hdf(hdf_group: h5py.Group):
    streamlines = ArraySequence()
    streamlines._data = np.array(hdf_group['data'])
    streamlines._offsets = np.array(hdf_group['offsets'])
    streamlines._lengths = np.array(hdf_group['lengths'])

    return streamlines
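
A complementary writer for the loader above might look like the following; this is a minimal sketch that mirrors the 'data'/'offsets'/'lengths' group layout the loader reads back (the file name, group name and helper name are hypothetical, not part of the original project).

import h5py
import numpy as np
from nibabel.streamlines import ArraySequence

def _save_streamlines_to_hdf(hdf_group, streamlines):
    # Persist the flat point array plus the offsets/lengths bookkeeping
    # that _load_streamlines_from_hdf expects to find.
    hdf_group.create_dataset('data', data=streamlines._data)
    hdf_group.create_dataset('offsets', data=streamlines._offsets)
    hdf_group.create_dataset('lengths', data=streamlines._lengths)

with h5py.File('streamlines.h5', 'w') as f:
    sl = ArraySequence([np.random.rand(10, 3), np.random.rand(5, 3)])
    _save_streamlines_to_hdf(f.create_group('tractogram'), sl)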
Example #2
File: dpy.py Project: MarcCote/dipy
 def read_tracksi(self, indices):
     """ read tracks with specific indices
     """
     tracks = Streamlines()
     for i in indices:
         off0, off1 = self.offsets[i:i + 2]
         tracks.append(self.tracks[off0:off1])
     return tracks
Example #3
 def read_tracksi(self, indices):
     """ read tracks with specific indices
     """
     tracks = Streamlines()
     for i in indices:
         off0, off1 = self.offsets[i:i + 2]
         tracks.append(self.tracks[off0:off1])
     return tracks
Example #4
File: dpy.py Project: MarcCote/dipy
 def read_tracks(self):
     """ read the entire tractography
     """
     I = self.offsets[:]
     TR = self.tracks[:]
     tracks = Streamlines()
     for i in range(len(I) - 1):
         off0, off1 = I[i:i + 2]
         tracks.append(TR[off0:off1])
     return tracks
Example #5
 def read_tracks(self):
     """ read the entire tractography
     """
     I = self.offsets[:]
     TR = self.tracks[:]
     tracks = Streamlines()
     for i in range(len(I) - 1):
         off0, off1 = I[i:i + 2]
         tracks.append(TR[off0:off1])
     return tracks
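
A hedged usage sketch for the two readers above, assuming they are the read_tracksi/read_tracks methods of dipy.io.dpy.Dpy; the .dpy file name is hypothetical.

from dipy.io.dpy import Dpy

dpy = Dpy('tractogram.dpy', mode='r')
some_tracks = dpy.read_tracksi([0, 10, 20])  # only the requested streamlines
all_tracks = dpy.read_tracks()               # the whole tractography
dpy.close()
print(len(some_tracks), len(all_tracks))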
Example #6
def get_seeds_from_wm(wm_path, threshold=0):

    wm_file = nib.load(wm_path)
    wm_img = wm_file.get_fdata()

    seeds = np.argwhere(wm_img > threshold)
    seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

    seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

    n_seeds = len(seeds)

    header = TrkFile.create_empty_header()

    header["voxel_to_rasmm"] = wm_file.affine
    header["dimensions"] = wm_file.header["dim"][1:4]
    header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
    header["voxel_order"] = get_reference_info(wm_file)[3]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_path = os.path.join(os.path.dirname(wm_path), "seeds_from_wm.trk")

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)
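
The voxel-to-RASmm conversion above (hstack with a column of ones, dot with the affine, slice off the homogeneous coordinate) can also be written with nibabel's apply_affine helper; a sketch, with a hypothetical file path.

import numpy as np
import nibabel as nib
from nibabel.affines import apply_affine

wm_file = nib.load('wm.nii.gz')  # hypothetical path
vox_coords = np.argwhere(wm_file.get_fdata() > 0)
ras_coords = apply_affine(wm_file.affine, vox_coords)
seeds = ras_coords.reshape(-1, 1, 3)  # one single-point "streamline" per seed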
Example #7
def _compute_streamline_mean(cur_ind, cur_min, cur_max, data):
    # From the precomputed indices, compute the binary map
    # and use it to weight the metric data for this specific streamline.
    cur_range = tuple(cur_max - cur_min)
    streamline_density = compute_tract_counts_map(ArraySequence([cur_ind]),
                                                  cur_range)
    streamline_data = data[cur_min[0]:cur_max[0], cur_min[1]:cur_max[1],
                           cur_min[2]:cur_max[2]]
    streamline_average = np.average(streamline_data,
                                    weights=streamline_density)
    return streamline_average
Example #8
def multiprocess_subsampling(args):
    streamlines = args[0]
    min_distance = args[1]
    cluster_thr = args[2]
    min_cluster_size = args[3]
    average_streamlines = args[4]

    min_cluster_size = max(min_cluster_size, 1)
    thresholds = [40, 30, 20, cluster_thr]
    cluster_map = qbx_and_merge(ArraySequence(streamlines),
                                thresholds,
                                nb_pts=20,
                                verbose=False)

    return subsample_clusters(cluster_map, streamlines, min_distance,
                              min_cluster_size, average_streamlines)
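
The packed-argument signature above suggests the wrapper is meant to be driven by a multiprocessing pool; here is a sketch under that assumption (the chunking and parameter values are made up, and the wrapper still relies on scilpy's qbx_and_merge/subsample_clusters being importable).

import numpy as np
from multiprocessing import Pool

streamlines = [np.random.rand(n, 3) for n in (20, 30, 25, 40)]  # dummy data
chunks = [streamlines[i::2] for i in range(2)]
jobs = [(chunk, 2.0, 10.0, 5, False) for chunk in chunks]
with Pool(processes=2) as pool:
    results = pool.map(multiprocess_subsampling, jobs)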
Example #9
def _process_streamlines(streamlines):
    # Compute the bounding boxes and indices for all streamlines.
    mins = []
    maxs = []
    offset_streamlines = []

    # Offset the streamlines to compute the indices only in the bounding box.
    # Reduces memory use later on.
    for idx, s in enumerate(streamlines):
        mins.append(np.min(s.astype(int), 0))
        maxs.append(np.max(s.astype(int), 0) + 1)
        offset_streamlines.append((s - mins[-1]).astype(np.float32))

    offset_streamlines = ArraySequence(offset_streamlines)
    indices = uncompress(offset_streamlines)

    return mins, maxs, indices
Example #10
def main():
    parser = build_parser()
    args = parser.parse_args()
    print(args)

    with Timer("Loading streamlines"):
        trk = nib.streamlines.load(args.tractogram)
        losses = trk.tractogram.data_per_streamline['loss']
        del trk.tractogram.data_per_streamline['loss']  # Not supported in MI-Brain for my version.

    with Timer("Coloring streamlines"):
        viridis = plt.get_cmap('RdYlGn')

        losses = -losses[:, 0]
        losses -= losses.mean()

        vmin = losses.min()
        vmax = losses.max()

        if args.normalization == "norm":
            cNorm = colors.Normalize(vmin=vmin, vmax=vmax)
        elif args.normalization == "log":
            cNorm = colors.LogNorm(vmin=vmin, vmax=vmax)
        elif args.normalization == "symlog":
            cNorm = colors.SymLogNorm(linthresh=0.03, linscale=1, vmin=vmin, vmax=vmax)
        else:
            raise ValueError("Unknown normalization: {}".format(args.normalization))

        scalarMap = cm.ScalarMappable(norm=cNorm, cmap=viridis)
        print(scalarMap.get_clim())
        # losses -= losses.mean()
        # losses /= losses.std()
        streamlines_colors = scalarMap.to_rgba(losses, bytes=True)[:, :-1]

        # from dipy.viz import fvtk
        # streamlines_colors = fvtk.create_colormap(-losses[:, 0]) * 255
        colors_per_point = ArraySequence([np.tile(c, (len(s), 1)) for s, c in zip(trk.tractogram.streamlines, streamlines_colors)])
        trk.tractogram.data_per_point['color'] = colors_per_point

    with Timer("Saving streamlines"):
        if args.out is None:
            args.out = args.tractogram[:-4] + "_color_" + args.normalization + args.tractogram[-4:]

        nib.streamlines.save(trk.tractogram, args.out)
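
The per-point coloring pattern used above reduces to tiling one RGB row per point of each streamline; a minimal self-contained sketch with dummy data.

import numpy as np
from nibabel.streamlines import ArraySequence

streamlines = [np.random.rand(12, 3), np.random.rand(7, 3)]
streamline_colors = np.array([[255, 0, 0], [0, 0, 255]], dtype=np.uint8)
colors_per_point = ArraySequence(
    [np.tile(c, (len(s), 1)) for s, c in zip(streamlines, streamline_colors)])
print([len(c) for c in colors_per_point])  # [12, 7]: one colour row per point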
Example #11
def tracking(image,
             bvecs,
             bvals,
             wm,
             seeds,
             fibers,
             prune_length=3,
             rseed=42,
             plot=False,
             proba=False,
             verbose=False):
    # Pipelines transcribed from:
    #   https://dipy.org/documentation/1.1.1./examples_built/tracking_introduction_eudx/#example-tracking-introduction-eudx
    #   https://dipy.org/documentation/1.1.1./examples_built/tracking_probabilistic/

    # Load Images
    dwi_loaded = nib.load(image)
    dwi_data = dwi_loaded.get_fdata()

    wm_loaded = nib.load(wm)
    wm_data = wm_loaded.get_fdata()

    seeds_loaded = nib.load(seeds)
    seeds_data = seeds_loaded.get_fdata()
    seeds = utils.seeds_from_mask(seeds_data, dwi_loaded.affine, density=2)

    # Load B-values & B-vectors
    # NB. Use aligned b-vecs if providing eddy-aligned data
    bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)
    csa_model = CsaOdfModel(gtab, sh_order=6)

    # Set stopping criterion
    gfa = csa_model.fit(dwi_data, mask=wm_data).gfa
    stop_criterion = ThresholdStoppingCriterion(gfa, .25)

    if proba:
        # Establish ODF model
        response, ratio = auto_response(gtab,
                                        dwi_data,
                                        roi_radius=10,
                                        fa_thr=0.7)
        csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order=6)
        csd_fit = csd_model.fit(dwi_data, mask=wm_data)

        # Create Probabilisitic direction getter
        fod = csd_fit.odf(default_sphere)
        pmf = fod.clip(min=0)
        prob_dg = ProbabilisticDirectionGetter.from_pmf(pmf,
                                                        max_angle=30.,
                                                        sphere=default_sphere)
        # Use the probabilisitic direction getter as the dg
        dg = prob_dg

    else:
        # Establish ODF model
        csa_peaks = peaks_from_model(csa_model,
                                     dwi_data,
                                     default_sphere,
                                     relative_peak_threshold=0.8,
                                     min_separation_angle=45,
                                     mask=wm_data)

        # Use the CSA peaks as the dg
        dg = csa_peaks

    # Create generator and perform tracing
    s_generator = LocalTracking(dg,
                                stop_criterion,
                                seeds,
                                dwi_loaded.affine,
                                0.5,
                                random_seed=rseed)
    streamlines = Streamlines(s_generator)

    # Prune streamlines
    streamlines = ArraySequence(
        [strline for strline in streamlines if len(strline) > prune_length])
    sft = StatefulTractogram(streamlines, dwi_loaded, Space.RASMM)

    # Save streamlines
    save_trk(sft, fibers + ".trk")

    # Visualize fibers
    if plot and has_fury:
        from dipy.viz import window, actor, colormap as cmap

        # Create the 3D display.
        r = window.Renderer()
        r.add(actor.line(streamlines, cmap.line_colors(streamlines)))
        window.record(r, out_path=fibers + '.png', size=(800, 800))
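
The pruning step above can equivalently use boolean indexing on the ArraySequence (the same indexing style Example #12 relies on); a sketch with dummy streamlines.

import numpy as np
from nibabel.streamlines import ArraySequence

streamlines = ArraySequence([np.random.rand(n, 3) for n in (2, 8, 15)])
prune_length = 3
keep = np.array([len(s) for s in streamlines]) > prune_length
pruned = streamlines[keep]  # keeps only the 8- and 15-point streamlines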
Example #12
def extract_true_connections(
    sft, mask_1_filename, mask_2_filename, gt_config, length_dict,
    gt_bundle, gt_bundle_inv_mask, dilate_endpoints, wrong_path_as_separate
):
    """
    Extract true connections based on two regions from a tractogram.
    May extract false and no connections if the config is passed.

    Parameters
    ----------
    sft: StatefulTractogram
        Tractogram containing the streamlines to be extracted.
    mask_1_filename: str
        Filename of the "head" of the bundle.
    mask_2_filename: str
        Filename of the "tail" of the bundle.
    gt_config: dict or None
        Dictionary containing the bundle's parameters.
    length_dict: dict or None
        Dictionary containing the bundle's length parameters.
    gt_bundle: str
        Bundle's name.
    gt_bundle_inv_mask: np.ndarray
        Inverse mask of the bundle.
    dilate_endpoints: int or None
        If set, dilate the masks for n iterations.
    wrong_path_as_separate: bool
        If true, save the WPCs as separate from TCs.

    Returns
    -------
    tc_sft: StatefulTractogram
        SFT of true connections.
    wpc_sft: StatefulTractogram
        SFT of wrong-path-connections.
    fc_sft: StatefulTractogram
        SFT of false connections (streamlines that are too long).
    nc_streamlines: StatefulTractogram
        SFT of no connections (streamlines that loop)
    sft: StatefulTractogram
        SFT of remaining streamlines.
    """

    mask_1_img = nib.load(mask_1_filename)
    mask_2_img = nib.load(mask_2_filename)
    mask_1 = get_data_as_mask(mask_1_img)
    mask_2 = get_data_as_mask(mask_2_img)

    if dilate_endpoints:
        mask_1 = binary_dilation(mask_1, iterations=dilate_endpoints)
        mask_2 = binary_dilation(mask_2, iterations=dilate_endpoints)

    # TODO: Handle streamline IDs instead of streamlines
    tmp_sft, sft = extract_streamlines(mask_1, mask_2, sft)

    streamlines = tmp_sft.streamlines
    tc_streamlines = streamlines
    wpc_streamlines = []
    fc_streamlines = []
    nc_streamlines = []

    # Config file for each 'bundle'
    # Loops => no connection (nc) # TODO Is this legit ?
    # Length => false connection (fc) # TODO Is this legit ?
    if gt_config:
        min_len, max_len = \
            length_dict[gt_bundle]['length']

        # Bring streamlines to world coordinates so proper length
        # is calculated
        tmp_sft.to_rasmm()
        streamlines = tmp_sft.streamlines
        lengths = np.array(list(length(streamlines)))
        tmp_sft.to_vox()
        streamlines = tmp_sft.streamlines

        valid_min_length_mask = lengths > min_len
        valid_max_length_mask = lengths < max_len
        valid_length_mask = np.logical_and(valid_min_length_mask,
                                           valid_max_length_mask)
        streamlines = ArraySequence(streamlines)

        val_len_streamlines = streamlines[valid_length_mask]
        fc_streamlines = streamlines[~valid_length_mask]

        angle = length_dict[gt_bundle]['angle']
        tc_streamlines_ids = remove_loops_and_sharp_turns(
            val_len_streamlines, angle)

        loop_ids = np.setdiff1d(
            range(len(val_len_streamlines)), tc_streamlines_ids)

        loops = val_len_streamlines[list(loop_ids)]
        tc_streamlines = val_len_streamlines[list(tc_streamlines_ids)]

        if loops:
            nc_streamlines = loops

    # Streamlines getting out of the bundle mask can be considered
    # separately as wrong path connection (wpc)
    # TODO: Maybe only consider if they cross another GT bundle ?
    if wrong_path_as_separate:
        tmp_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
        _, wp_ids = filter_grid_roi(
            tmp_sft, gt_bundle_inv_mask, 'any', False)
        wpc_streamlines = tmp_sft.streamlines[list(wp_ids)]
        tc_ids = np.setdiff1d(range(len(tmp_sft)), wp_ids)
        tc_streamlines = tmp_sft.streamlines[list(tc_ids)]

    tc_sft = StatefulTractogram.from_sft(tc_streamlines, sft)
    wpc_sft = StatefulTractogram.from_sft([], sft)
    fc_sft = StatefulTractogram.from_sft(fc_streamlines, sft)
    if wrong_path_as_separate and len(wpc_streamlines):
        wpc_sft = StatefulTractogram.from_sft(wpc_streamlines, sft)

    return tc_sft, wpc_sft, fc_sft, nc_streamlines, sft
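
The length-based split in the function above boils down to the following; a sketch with toy streamlines and assumed min/max bounds.

import numpy as np
from dipy.tracking.streamlinespeed import length
from nibabel.streamlines import ArraySequence

streamlines = ArraySequence([np.random.rand(n, 3) * 50 for n in (5, 40, 200)])
lengths = np.asarray(list(length(streamlines)))
valid = (lengths > 20) & (lengths < 200)      # assumed bundle length bounds
kept, rejected = streamlines[valid], streamlines[~valid]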
Example #13
def evaluation_tractogram(hyperparams, model, dataset, batch_size_override,
                          metric):
    loss = loss_factory(hyperparams, model, dataset, loss_type=None)
    batch_scheduler = batch_scheduler_factory(
        hyperparams,
        dataset,
        train_mode=False,
        batch_size_override=batch_size_override,
        use_data_augment=False)

    _ = loss.losses  # Hack to generate update dict in loss :(

    if hyperparams['model'] == 'ffnn_regression':
        timestep_losses, inputs, targets = log_variables(
            batch_scheduler, model, loss.loss_per_time_step,
            dataset.symb_inputs * 1, dataset.symb_targets * 1)
        # Regrouping data into streamlines will only work if the original streamlines were NOT shuffled, resampled or augmented
        timesteps_loss = ArraySequence()
        seq_loss = []
        timesteps_inputs = ArraySequence()
        timesteps_targets = ArraySequence()
        idx = 0
        for length in dataset.streamlines._lengths:
            start = idx
            idx = end = idx + length
            timesteps_loss.extend(timestep_losses[start:end])
            seq_loss.append(np.mean(timestep_losses[start:end]))  # one scalar per streamline; extend() would fail on a scalar
            timesteps_inputs.extend(inputs[start:end])
            timesteps_targets.extend(targets[start:end])
    else:

        timestep_losses, seq_losses, inputs, targets, masks = log_variables(
            batch_scheduler, model, loss.loss_per_time_step, loss.loss_per_seq,
            dataset.symb_inputs * 1, dataset.symb_targets * 1,
            dataset.symb_mask * 1)
        timesteps_loss = ArraySequence([
            l[:int(m.sum())]
            for l, m in zip(chain(*timestep_losses), chain(*masks))
        ])
        seq_loss = np.array(list(chain(*seq_losses)))
        timesteps_inputs = ArraySequence(
            [i[:int(m.sum())] for i, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([
            np.squeeze(t[:int(m.sum())])
            for t, m in zip(chain(*targets), chain(*masks))
        ])

    if metric == 'sequence':
        # Color is based on sequence loss
        values = seq_loss
    elif metric == 'timestep' or metric == 'cumul_avg':
        # Color is based on timestep loss
        values = np.concatenate(timesteps_loss)
    else:
        raise ValueError("Unrecognized metric: {}".format(metric))

    cmap = cm.get_cmap('bwr')
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin,
                                                            vmax=vmax),
                                   cmap=cmap)

    streamlines = []
    colors = []

    for i, t, l, seq_l in zip(timesteps_inputs, timesteps_targets,
                              timesteps_loss, seq_loss):
        pts = np.r_[i[:, :3], [i[-1, :3] + t[-1]]]

        color = np.zeros_like(pts)
        if metric == 'sequence':
            # Streamline color is based on sequence loss
            color[:, :] = scalar_map.to_rgba(seq_l, bytes=True)[:3]
        elif metric == 'timestep':
            # Streamline color is based on timestep loss
            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(l, bytes=True)[:, :3]
        elif metric == 'cumul_avg':
            # Streamline color is based on timestep loss

            # Compute cumulative average
            cumul_avg = np.cumsum(l) / np.arange(1, len(l) + 1)

            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(cumul_avg, bytes=True)[:, :3]
        else:
            raise ValueError("Unrecognized metric: {}".format(metric))

        streamlines.append(pts)
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines,
                                            data_per_point={"colors": colors})
    return tractogram
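
The cumulative-average coloring above reduces to a single vectorized line; illustrative values.

import numpy as np

l = np.array([4.0, 2.0, 6.0])
cumul_avg = np.cumsum(l) / np.arange(1, len(l) + 1)  # -> [4.0, 3.0, 4.0]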
Example #14
def prediction_tractogram(hyperparams, model, dataset, batch_size_override,
                          prediction_method):
    loss = loss_factory(hyperparams,
                        model,
                        dataset,
                        loss_type=prediction_method)
    batch_scheduler = batch_scheduler_factory(
        hyperparams,
        dataset,
        train_mode=False,
        batch_size_override=batch_size_override,
        use_data_augment=False)

    _ = loss.losses  # Hack to generate update dict in loss :(
    predictions = loss.samples

    predict, timestep_losses, inputs, targets, masks = log_variables(
        batch_scheduler, model, predictions, loss.loss_per_time_step,
        dataset.symb_inputs * 1, dataset.symb_targets * 1,
        dataset.symb_mask * 1)
    if hyperparams['model'] == 'ffnn_regression':
        # Regrouping data into streamlines will only work if the original streamlines were NOT shuffled, resampled or augmented
        timesteps_prediction = ArraySequence()
        timesteps_loss = ArraySequence()
        timesteps_inputs = ArraySequence()
        timesteps_targets = ArraySequence()
        idx = 0
        for length in dataset.streamlines._lengths:
            start = idx
            idx = end = idx + length
            timesteps_prediction.extend(predict[start:end])
            timesteps_loss.extend(timestep_losses[start:end])
            timesteps_inputs.extend(inputs[start:end])
            timesteps_targets.extend(targets[start:end])
    else:
        timesteps_prediction = ArraySequence(
            [p[:int(m.sum())] for p, m in zip(chain(*predict), chain(*masks))])
        timesteps_loss = ArraySequence([
            l[:int(m.sum())]
            for l, m in zip(chain(*timestep_losses), chain(*masks))
        ])
        timesteps_inputs = ArraySequence(
            [i[:int(m.sum())] for i, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([
            np.squeeze(t[:int(m.sum())])
            for t, m in zip(chain(*targets), chain(*masks))
        ])

    # Debug : Print norm stats
    # print("Dataset: {}; # of streamlines: {}".format(dataset.name, len(dataset)))
    # all_predictions = np.array(list(chain(*timesteps_prediction)))
    # prediction_norms = np.linalg.norm(all_predictions, axis=1)
    # print("Prediction norm --- Mean:{}; Max:{}; Min:{}".format(np.mean(prediction_norms), np.max(prediction_norms), np.min(prediction_norms)))
    # all_targets = np.array(list(chain(*timesteps_targets)))
    # target_norms = np.linalg.norm(all_targets, axis=1)
    # print("Target norm --- Mean:{}; Max:{}; Min:{}".format(np.mean(target_norms), np.max(target_norms), np.min(target_norms)))

    # Color is based on timestep loss
    cmap = cm.get_cmap('bwr')
    values = np.concatenate(timesteps_loss)
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin,
                                                            vmax=vmax),
                                   cmap=cmap)

    streamlines = []
    colors = []

    for i, t, p, l in zip(timesteps_inputs, timesteps_targets,
                          timesteps_prediction, timesteps_loss):
        pts = np.r_[i[:, :3], [i[-1, :3] + t[-1]]]

        streamline = np.zeros(((len(pts) - 1) * 3 + 1, 3))
        streamline[::3] = pts
        streamline[1:-1:3] = pts[:-1] + p
        streamline[2:-1:3] = pts[:-1]
        streamlines.append(streamline)

        # Color input streamlines in a uniform color, then color predictions based on L2 error
        color = np.zeros_like(streamline)

        # Base color of streamlines is minimum value (best score)
        color[:] = scalar_map.to_rgba(vmin, bytes=True)[:3]
        color[1:-1:3, :] = scalar_map.to_rgba(l, bytes=True)[:, :3]
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines,
                                            data_per_point={"colors": colors})
    return tractogram
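
The strided assignments above interleave each real point with its prediction and then a jump back to the point, producing a zig-zag polyline that makes the predicted steps visible; a two-point sketch.

import numpy as np

pts = np.array([[0., 0., 0.], [1., 0., 0.]])
p = np.array([[0.5, 0.5, 0.]])                 # one predicted step per segment
zigzag = np.zeros(((len(pts) - 1) * 3 + 1, 3))
zigzag[::3] = pts                              # rows 0 and 3: the real points
zigzag[1:-1:3] = pts[:-1] + p                  # row 1: the predicted point
zigzag[2:-1:3] = pts[:-1]                      # row 2: back to the start point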
Example #15
 def get_array_sequence(self, item=None):
     if item is None:
         streamlines = _load_streamlines_from_hdf(self.hdf_group)
     else:
         streamlines = ArraySequence()
         if isinstance(item, int):
             streamline = self._get_one_streamline(item)
             streamlines.append(streamline)
         elif isinstance(item, list) or isinstance(item, np.ndarray):
             for i in item:
                 streamline = self._get_one_streamline(i)
                 streamlines.append(streamline, cache_build=True)
             streamlines.finalize_append()
         elif isinstance(item, slice):
             offsets = self.hdf_group['offsets'][item]
             lengths = self.hdf_group['lengths'][item]
             for offset, length in zip(offsets, lengths):
                 streamline = self.hdf_group['data'][offset:offset + length]
                 streamlines.append(streamline, cache_build=True)
             streamlines.finalize_append()
         else:
             raise ValueError(
                  'Item should be either an int, list, '
                 'np.ndarray or slice but we received {}'.format(
                     type(item)))
     return streamlines
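
The append(..., cache_build=True) / finalize_append() pairing above is nibabel's fast path for growing an ArraySequence incrementally; a minimal sketch with dummy data.

import numpy as np
from nibabel.streamlines import ArraySequence

seq = ArraySequence()
for n in (3, 7, 5):
    seq.append(np.random.rand(n, 3), cache_build=True)
seq.finalize_append()
print(len(seq))  # 3 streamlines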
Example #16
def dwi_deterministic_tracing(image, bvecs, bvals, wm, seeds, fibers,
                              prune_length=3, plot=False, verbose=False):
    # Pipeline transcribed from:
    #   http://nipy.org/dipy/examples_built/introduction_to_basic_tracking.html
    # Load Images
    dwi_loaded = nib.load(image)
    dwi_data = dwi_loaded.get_data()

    wm_loaded = nib.load(wm)
    wm_data = wm_loaded.get_data()

    seeds_loaded = nib.load(seeds)
    seeds_data = seeds_loaded.get_data()

    # Load B-values & B-vectors
    # NB. Use aligned b-vecs if providing eddy-aligned data
    bvals, bvecs = read_bvals_bvecs(bvals, bvecs)
    gtab = gradient_table(bvals, bvecs)

    # Establish ODF model
    csa_model = CsaOdfModel(gtab, sh_order=6)
    csa_peaks = peaks_from_model(csa_model, dwi_data, default_sphere,
                                 relative_peak_threshold=0.8,
                                 min_separation_angle=45,
                                 mask=wm_data)

    # Classify tissue for high FA and create seeds
    # (Putting this inside a looped try-block to handle fuzzy failures)
    classifier = ThresholdTissueClassifier(csa_peaks.gfa, 0.25)
    seeds = wrap_fuzzy_failures(utils.seeds_from_mask,
                                args=[seeds_data],
                                kwargs={"density": [2, 2, 2],
                                        "affine": np.eye(4)},
                                errortype=ValueError,
                                failure_threshold=5,
                                verbose=verbose)

    # Perform deterministic tracing
    # (Putting this inside a looped try-block to handle fuzzy failures)
    streamlines_generator = wrap_fuzzy_failures(LocalTracking,
                                                args=[csa_peaks,
                                                      classifier,
                                                      seeds],
                                                kwargs={"affine": np.eye(4),
                                                        "step_size": 0.5},
                                                errortype=ValueError,
                                                failure_threshold=5,
                                                verbose=verbose)
    streamlines = wrap_fuzzy_failures(Streamlines,
                                      args=[streamlines_generator],
                                      kwargs={},
                                      errortype=IndexError,
                                      failure_threshold=5,
                                      verbose=verbose)

    # Prune streamlines
    streamlines = ArraySequence([strline
                                 for strline in streamlines
                                 if len(strline) > prune_length])

    # Save streamlines
    save_trk(fibers + ".trk", streamlines, dwi_loaded.affine,
             shape=wm_data.shape, vox_size=wm_loaded.header.get_zooms())

    # Visualize fibers
    if plot and have_fury:
        from dipy.viz import window, actor, colormap as cmap

        color = cmap.line_colors(streamlines)
        streamlines_actor = actor.line(streamlines, color)

        # Create the 3D display.
        r = window.Renderer()
        r.add(streamlines_actor)

        # Save still image.
        window.record(r, n_frames=1, out_path=fibers + ".png",
                      size=(800, 800))
Example #17
def prediction_tractogram(hyperparams, model, dataset, batch_size_override, prediction_method):
    loss = loss_factory(hyperparams, model, dataset, loss_type=prediction_method)
    batch_scheduler = batch_scheduler_factory(hyperparams, dataset, train_mode=False, batch_size_override=batch_size_override, use_data_augment=False)

    _ = loss.losses  # Hack to generate update dict in loss :(
    predictions = loss.samples

    predict, timestep_losses, inputs, targets, masks = log_variables(batch_scheduler,
                                                                     model,
                                                                     predictions,
                                                                     loss.loss_per_time_step,
                                                                     dataset.symb_inputs * 1,
                                                                     dataset.symb_targets * 1,
                                                                     dataset.symb_mask * 1)
    if hyperparams['model'] == 'ffnn_regression':
        # Regrouping data into streamlines will only work if the original streamlines were NOT shuffled, resampled or augmented
        timesteps_prediction = ArraySequence()
        timesteps_loss = ArraySequence()
        timesteps_inputs = ArraySequence()
        timesteps_targets = ArraySequence()
        idx = 0
        for length in dataset.streamlines._lengths:
            start = idx
            idx = end = idx+length
            timesteps_prediction.extend(predict[start:end])
            timesteps_loss.extend(timestep_losses[start:end])
            timesteps_inputs.extend(inputs[start:end])
            timesteps_targets.extend(targets[start:end])
    else:
        timesteps_prediction = ArraySequence([p[:int(m.sum())] for p, m in zip(chain(*predict), chain(*masks))])
        timesteps_loss = ArraySequence([l[:int(m.sum())] for l, m in zip(chain(*timestep_losses), chain(*masks))])
        timesteps_inputs = ArraySequence([i[:int(m.sum())] for i, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([np.squeeze(t[:int(m.sum())]) for t, m in zip(chain(*targets), chain(*masks))])

    # Debug : Print norm stats
    # print("Dataset: {}; # of streamlines: {}".format(dataset.name, len(dataset)))
    # all_predictions = np.array(list(chain(*timesteps_prediction)))
    # prediction_norms = np.linalg.norm(all_predictions, axis=1)
    # print("Prediction norm --- Mean:{}; Max:{}; Min:{}".format(np.mean(prediction_norms), np.max(prediction_norms), np.min(prediction_norms)))
    # all_targets = np.array(list(chain(*timesteps_targets)))
    # target_norms = np.linalg.norm(all_targets, axis=1)
    # print("Target norm --- Mean:{}; Max:{}; Min:{}".format(np.mean(target_norms), np.max(target_norms), np.min(target_norms)))

    # Color is based on timestep loss
    cmap = cm.get_cmap('bwr')
    values = np.concatenate(timesteps_loss)
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)

    streamlines = []
    colors = []

    for i, t, p, l in zip(timesteps_inputs, timesteps_targets, timesteps_prediction, timesteps_loss):
        pts = np.r_[i[:, :3], [i[-1, :3] + t[-1]]]

        streamline = np.zeros(((len(pts) - 1) * 3 + 1, 3))
        streamline[::3] = pts
        streamline[1:-1:3] = pts[:-1] + p
        streamline[2:-1:3] = pts[:-1]
        streamlines.append(streamline)

        # Color input streamlines in a uniform color, then color predictions based on L2 error
        color = np.zeros_like(streamline)

        # Base color of streamlines is minimum value (best score)
        color[:] = scalar_map.to_rgba(vmin, bytes=True)[:3]
        color[1:-1:3, :] = scalar_map.to_rgba(l, bytes=True)[:, :3]
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines, data_per_point={"colors": colors})
    return tractogram
Example #18
def get_ismrm_seeds(data_dir, source, keep, weighted, threshold, voxel):

    trk_dir = os.path.join(data_dir, "bundles")

    if source in ["wm", "trk"]:
        anat_path = os.path.join(data_dir, "masks", "wm.nii.gz")
        resized_path = os.path.join(data_dir, "masks",
                                    "wm_{}.nii.gz".format(voxel))
    elif source == "brain":
        anat_path = os.path.join("subjects", "ismrm_gt",
                                 "dwi_brain_mask.nii.gz")
        resized_path = os.path.join("subjects", "ismrm_gt",
                                    "dwi_brain_mask_125.nii.gz")

    sp.call([
        "mrresize", "-voxel", "{:1.2f}".format(voxel / 100), anat_path,
        resized_path
    ])

    if source == "trk":

        print("Running Tractconverter...")
        sp.call([
            "python", "tractconverter/scripts/WalkingTractConverter.py", "-i",
            trk_dir, "-a", resized_path, "-vtk2trk"
        ])

        print("Loading seed bundles...")
        seed_bundles = []
        for i, trk_path in enumerate(glob.glob(os.path.join(trk_dir,
                                                            "*.trk"))):
            trk_file = nib.streamlines.load(trk_path)
            endpoints = []
            for fiber in trk_file.tractogram.streamlines:
                endpoints.append(fiber[0])
                endpoints.append(fiber[-1])
            seed_bundles.append(endpoints)
            if i == 0:
                header = trk_file.header

        n_seeds = sum([len(b) for b in seed_bundles])
        n_bundles = len(seed_bundles)

        print("Loaded {} seeds from {} bundles.".format(n_seeds, n_bundles))

        seeds = np.array([[seed] for bundle in seed_bundles
                          for seed in bundle])

        if keep < 1:
            if weighted:
                p = np.zeros(n_seeds)
                offset = 0
                for b in seed_bundles:
                    l = len(b)
                    p[offset:offset + l] = 1 / (l * n_bundles)
                    offset += l
            else:
                p = np.ones(n_seeds) / n_seeds

    elif source in ["brain", "wm"]:

        weighted = False

        wm_file = nib.load(resized_path)
        wm_img = wm_file.get_fdata()

        seeds = np.argwhere(wm_img > threshold)
        seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

        seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

        n_seeds = len(seeds)

        if keep < 1:
            p = np.ones(n_seeds) / n_seeds

        header = TrkFile.create_empty_header()

        header["voxel_to_rasmm"] = wm_file.affine
        header["dimensions"] = wm_file.header["dim"][1:4]
        header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
        header["voxel_order"] = get_reference_info(wm_file)[3]

    if keep < 1:
        keep_n = int(keep * n_seeds)
        print("Subsampling from {} seeds to {} seeds".format(n_seeds, keep_n))
        np.random.seed(42)
        keep_idx = np.random.choice(len(seeds),
                                    size=keep_n,
                                    replace=False,
                                    p=p)
        seeds = seeds[keep_idx]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_dir = os.path.join(data_dir, "seeds")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, "seeds_from_{}_{}_vox{:03d}.trk")
    save_path = save_path.format(
        source, "W" + str(int(100 * keep)) if weighted else "all", voxel)

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)

    os.remove(resized_path)
    for file in glob.glob(os.path.join(trk_dir, "*.trk")):
        os.remove(file)
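
The bundle-weighted subsampling above gives each bundle the same total probability mass regardless of its size; a toy sketch of how those weights behave.

import numpy as np

bundle_sizes = [3, 6]
n_bundles = len(bundle_sizes)
p = np.concatenate([np.full(n, 1.0 / (n * n_bundles)) for n in bundle_sizes])
keep_idx = np.random.choice(p.size, size=4, replace=False, p=p)  # p sums to 1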
Example #19
def evaluation_tractogram(hyperparams, model, dataset, batch_size_override, metric):
    loss = loss_factory(hyperparams, model, dataset, loss_type=None)
    batch_scheduler = batch_scheduler_factory(hyperparams, dataset, train_mode=False, batch_size_override=batch_size_override, use_data_augment=False)

    _ = loss.losses  # Hack to generate update dict in loss :(

    if hyperparams['model'] == 'ffnn_regression':
        timestep_losses, inputs, targets = log_variables(batch_scheduler,
                                                         model,
                                                         loss.loss_per_time_step,
                                                         dataset.symb_inputs * 1,
                                                         dataset.symb_targets * 1)
        # Regrouping data into streamlines will only work if the original streamlines were NOT shuffled, resampled or augmented
        timesteps_loss = ArraySequence()
        seq_loss = []
        timesteps_inputs = ArraySequence()
        timesteps_targets = ArraySequence()
        idx = 0
        for length in dataset.streamlines._lengths:
            start = idx
            idx = end = idx+length
            timesteps_loss.extend(timestep_losses[start:end])
            seq_loss.append(np.mean(timestep_losses[start:end]))  # one scalar per streamline; extend() would fail on a scalar
            timesteps_inputs.extend(inputs[start:end])
            timesteps_targets.extend(targets[start:end])
    else:

        timestep_losses, seq_losses, inputs, targets, masks = log_variables(batch_scheduler,
                                                                            model,
                                                                            loss.loss_per_time_step,
                                                                            loss.loss_per_seq,
                                                                            dataset.symb_inputs * 1,
                                                                            dataset.symb_targets * 1,
                                                                            dataset.symb_mask * 1)
        timesteps_loss = ArraySequence([l[:int(m.sum())] for l, m in zip(chain(*timestep_losses), chain(*masks))])
        seq_loss = np.array(list(chain(*seq_losses)))
        timesteps_inputs = ArraySequence([i[:int(m.sum())] for i, m in zip(chain(*inputs), chain(*masks))])
        # Use np.squeeze in case gru_multistep is used to remove the empty k=1 dimension
        timesteps_targets = ArraySequence([np.squeeze(t[:int(m.sum())]) for t, m in zip(chain(*targets), chain(*masks))])

    if metric == 'sequence':
        # Color is based on sequence loss
        values = seq_loss
    elif metric == 'timestep' or metric == 'cumul_avg':
        # Color is based on timestep loss
        values = np.concatenate(timesteps_loss)
    else:
        raise ValueError("Unrecognized metric: {}".format(metric))

    cmap = cm.get_cmap('bwr')
    vmin = np.percentile(values, 5)
    vmax = np.percentile(values, 95)
    scalar_map = cm.ScalarMappable(norm=mplcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap)

    streamlines = []
    colors = []

    for i, t, l, seq_l in zip(timesteps_inputs, timesteps_targets, timesteps_loss, seq_loss):
        pts = np.r_[i[:, :3], [i[-1, :3] + t[-1]]]

        color = np.zeros_like(pts)
        if metric == 'sequence':
            # Streamline color is based on sequence loss
            color[:, :] = scalar_map.to_rgba(seq_l, bytes=True)[:3]
        elif metric == 'timestep':
            # Streamline color is based on timestep loss
            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(l, bytes=True)[:, :3]
        elif metric == 'cumul_avg':
            # Streamline color is based on timestep loss

            # Compute cumulative average
            cumul_avg = np.cumsum(l) / np.arange(1, len(l) + 1)

            # Identify first point with green
            color[0, :] = [0, 255, 0]
            color[1:, :] = scalar_map.to_rgba(cumul_avg, bytes=True)[:, :3]
        else:
            raise ValueError("Unrecognized metric: {}".format(metric))

        streamlines.append(pts)
        colors.append(color)

    tractogram = nib.streamlines.Tractogram(streamlines, data_per_point={"colors": colors})
    return tractogram