Example #1
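All of the snippets below come from tractography code built on nibabel's streamlines API. As a hedged reconstruction (each source project may organize these differently), they assume roughly the following imports, plus project-specific helpers (maybe_get_a_gpu, add_tangent, resample_tractogram, Terminator, Prior, ...) that are not shown here:

import glob
import os
import subprocess
import numpy as np
import nibabel as nib
from nibabel.streamlines import ArraySequence, Tractogram, TrkFile
# get_reference_info is assumed to come from dipy
from dipy.io.utils import get_reference_info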
def get_seeds_from_wm(wm_path, threshold=0):

    wm_file = nib.load(wm_path)
    wm_img = wm_file.get_fdata()

    seeds = np.argwhere(wm_img > threshold)
    seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

    seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

    n_seeds = len(seeds)

    header = TrkFile.create_empty_header()

    header["voxel_to_rasmm"] = wm_file.affine
    header["dimensions"] = wm_file.header["dim"][1:4]
    header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
    header["voxel_order"] = get_reference_info(wm_file)[3]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_path = os.path.join(os.path.dirname(wm_path), "seeds_from_wm.trk")

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)
Example #2
def calculate_tract_disconnection(trk_file_path, lesion):
    '''Calculate the percent disconnection of the tract defined in
    trk_file_path for a given lesion.

    trk_file_path (String)
    lesion (ndarray)
    '''
    # load the trk
    try:
        trk = TrkFile.load(trk_file_path)
    except Exception:
        print('Error during lesion analysis when loading file: ',
              trk_file_path)
        return 0, 0, 0
    streamlines = trk.tractogram.streamlines  # ArraySequence

    # cumulative offsets marking where each streamline ends in the
    # concatenated coordinate array (used as split points below)
    start_indices = []
    for i, streamline in enumerate(streamlines):
        start_indices.append((start_indices[i - 1] if start_indices else 0) +
                             len(streamline))

    # Concatenate the streamline coords, then floor and cast to int:
    # first shift by 0.5 voxel to undo nibabel's default half-voxel shift,
    # then scale by 0.5 because the coords are in 1 mm voxel space and the
    # lesion is in 2 mm voxel space.
    coords = np.floor(0.5 * (np.vstack(streamlines) + 0.5)).astype('int16')

    # get value of lesion voxels at streamline coords
    overlap = lesion[coords[:, 0], coords[:, 1], coords[:, 2]]

    # split so values are grouped by streamline; the last offset equals the
    # full array length, so the trailing empty chunk is dropped
    overlap = np.split(overlap, start_indices)[:-1]

    # see if streamlines pass through lesion
    overlap = np.array([np.any(i) for i in overlap])

    num_streamlines = len(streamlines)
    disconnected_streamlines = np.count_nonzero(overlap)

    return num_streamlines, disconnected_streamlines, 100 * (
        disconnected_streamlines / num_streamlines)
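The voxel-space conversion above is the subtle part, so here is the same arithmetic on a single made-up point (values purely illustrative):

import numpy as np
pt_1mm = np.array([[63.2, 40.7, 12.1]])  # streamline point in 1 mm voxel space
# +0.5 undoes nibabel's half-voxel shift, *0.5 rescales 1 mm -> 2 mm voxels
idx_2mm = np.floor(0.5 * (pt_1mm + 0.5)).astype('int16')
print(idx_2mm)  # [[31 20  6]]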
Example #3
def maybe_add_tangent(trk_path, min_length=0, max_length=1000):

    cache_path = trk_path[:-4] + "_t.trk"

    print("check if resampled files already in directory: {0}".format(
        cache_path))
    if os.path.exists(cache_path):
        print("Resampled fibers found in directory :) ")
        trk_file = nib.streamlines.load(cache_path)
        return trk_file.tractogram
    else:
        trk_file = nib.streamlines.load(trk_path)
        tractogram = trk_file.tractogram

        if "t" not in tractogram.data_per_point:
            tractogram = add_tangent(tractogram,
                                     min_length=min_length,
                                     max_length=max_length)
            TrkFile(tractogram, trk_file.header).save(cache_path)

        return tractogram
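A minimal usage sketch (hypothetical path; add_tangent is a helper from the same project):

tractogram = maybe_add_tangent("/data/subject01/fibers.trk",
                               min_length=30, max_length=200)
tangents = tractogram.data_per_point["t"]  # cached in fibers_t.trk for next time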
Example #4
def trim(trk_path, min_length, max_length, fast=True, overwrite=False):

    if overwrite:
        trimmed_path = trk_path
    else:
        trimmed_path = (
            trk_path[:-4] + "_{:2.0f}mm{:3.0f}.trk".format(
                min_length, max_length)
        )

    print("Loading fibers for trimming ...")
    trk_file = nib.streamlines.load(trk_path)
    tractogram = trk_file.tractogram

    print("Trimming fibers ...")
    if fast and len(tractogram.data_per_point) == 0:
        cmd = [
            "track_vis", trk_path, "-nr", "-l", str(min_length),
            str(max_length), "-o", trimmed_path
        ]
        # pass the argument list directly; no shell is needed
        subprocess.run(cmd)
    else:
        pool = ProcessPool(nodes=10)

        def has_ok_length(f):
            length = np.linalg.norm(f[1:] - f[:-1], axis=1).sum()
            return min_length <= length <= max_length

        is_ok = pool.map(has_ok_length, tractogram.streamlines)

        TrkFile(tractogram[is_ok], trk_file.header).save(trimmed_path)

    print("Saving trimmed fibers to {}".format(trimmed_path))
    return trimmed_path
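A minimal usage sketch; the fast branch shells out to track_vis, so use fast=False if TrackVis is not installed (path hypothetical):

trimmed_path = trim("/data/subject01/fibers.trk",
                    min_length=30, max_length=200, fast=False)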
Example #5
def resample(trk_path,
             npts,
             smoothing,
             out_dir,
             min_length=0,
             max_length=1000):

    trk_file = nib.streamlines.load(trk_path)

    tractogram = resample_tractogram(trk_file.tractogram, npts, smoothing,
                                     min_length, max_length)

    if out_dir is None:
        out_dir = os.path.dirname(trk_path)

    basename = os.path.basename(trk_path).split(".")[0]
    save_path = os.path.join(
        out_dir, "{}_s={}_n={}.trk".format(basename, smoothing, npts))

    os.makedirs(out_dir, exist_ok=True)
    print("Saving {}".format(save_path))
    TrkFile(tractogram, trk_file.header).save(save_path)

    return tractogram
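A minimal usage sketch (hypothetical path; resample_tractogram is a helper from the same project):

tractogram = resample("/data/subject01/fibers.trk",
                      npts=15, smoothing=5, out_dir=None)
# -> saves /data/subject01/fibers_s=5_n=15.trk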
Example #6
def mark(config, gpu_queue=None):

    gpu_idx = -1
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
    except Exception as e:
        print(str(e))
    print("Loading DWI data ...")

    dwi_img = nib.load(config["dwi_path"])
    dwi_img = nib.funcs.as_closest_canonical(dwi_img)
    dwi_aff = dwi_img.affine
    dwi_affi = np.linalg.inv(dwi_aff)
    dwi = dwi_img.get_fdata()  # get_data() is deprecated/removed in newer nibabel

    def xyz2ijk(coords, snap=False):

        ijk = (coords.T).copy()

        ijk = np.vstack([ijk, np.ones([1, ijk.shape[1]])])

        dwi_affi.dot(ijk, out=ijk)

        if snap:
            return (np.round(ijk, out=ijk).astype(int, copy=False).T)[:, :4]
        else:
            return (ijk.T)[:, :4]

    # ==========================================================================

    print("Loading fibers ...")

    trk_file = nib.streamlines.load(config["trk_path"])
    tractogram = trk_file.tractogram

    if "t" in tractogram.data_per_point:
        print("Fibers are already resampled")
        tangents = tractogram.data_per_point["t"]
    else:
        print("Fibers are not resampled. Resampling now ...")
        tractogram = maybe_add_tangent(config["trk_path"],
                                       min_length=30,
                                       max_length=200)
        tangents = tractogram.data_per_point["t"]

    n_fibers = len(tractogram)
    fiber_lengths = np.array([len(t.streamline) for t in tractogram])
    max_length = fiber_lengths.max()
    n_pts = fiber_lengths.sum()

    # ==========================================================================

    print("Loading model ...")
    model_name = config['model_name']

    if hasattr(MODELS[model_name], "custom_objects"):
        model = load_model(config["model_path"],
                           custom_objects=MODELS[model_name].custom_objects,
                           compile=False)
    else:
        model = load_model(config["model_path"], compile=False)

    block_size = get_blocksize(model, dwi.shape[-1])

    d = np.zeros([n_fibers, dwi.shape[-1] * block_size**3 + 1])

    inputs = np.zeros([n_fibers, max_length, 3])

    print("Writing to input array ...")

    for i, fiber_t in enumerate(tangents):
        inputs[i, :fiber_lengths[i], :] = fiber_t

    outputs = np.zeros([n_fibers, max_length, 4])

    print("Starting iteration ...")

    step = 0
    while step < max_length:
        t0 = time()

        xyz = inputs[:, step, :]
        ijk = xyz2ijk(xyz, snap=True)

        r = block_size // 2
        for ii, idx in enumerate(ijk):
            try:
                # read a cubic neighborhood of size block_size around the voxel
                d[ii, :-1] = dwi[idx[0] - r:idx[0] + r + 1,
                                 idx[1] - r:idx[1] + r + 1,
                                 idx[2] - r:idx[2] + r + 1,
                                 :].flatten()  # flatten returns a copy
            except (IndexError, ValueError):
                pass

        d[:, -1] = np.linalg.norm(d[:, :-1], axis=1) + 10**-2

        d[:, :-1] /= d[:, -1].reshape(-1, 1)

        if step == 0:
            # no previous segment at the seed; use the reversed tangents instead
            vin = -inputs[:, step + 1, :]
            vout = -inputs[:, step, :]
        else:
            vin = inputs[:, step - 1, :]
            vout = inputs[:, step, :]

        model_inputs = np.hstack([vin, d])
        chunk = 2**15  # 32768
        n_chunks = np.ceil(n_fibers / chunk).astype(int)
        for c in range(n_chunks):

            fvm_pred, kappa_pred = model(model_inputs[c * chunk:(c + 1) *
                                                      chunk])

            log1p_kappa_pred = np.log1p(kappa_pred)

            log_prob_pred = fvm_pred.log_prob(vout[c * chunk:(c + 1) * chunk])

            log_prob_map_pred = fvm_pred._log_normalization() + kappa_pred

            outputs[c * chunk:(c + 1) * chunk, step, 0] = kappa_pred
            outputs[c * chunk:(c + 1) * chunk, step, 1] = log1p_kappa_pred
            outputs[c * chunk:(c + 1) * chunk, step, 2] = log_prob_pred
            outputs[c * chunk:(c + 1) * chunk, step, 3] = log_prob_map_pred

        print("Step {:3d}/{:3d}, ETA: {:4.0f} min".format(
            step, max_length, (max_length - step) * (time() - t0) / 60),
              end="\r")

        step += 1

    if gpu_queue is not None:
        gpu_queue.put(gpu_idx)

    kappa = [
        outputs[i, :fiber_lengths[i], 0].reshape(-1, 1)
        for i in range(n_fibers)
    ]
    log1p_kappa = [
        outputs[i, :fiber_lengths[i], 1].reshape(-1, 1)
        for i in range(n_fibers)
    ]
    log_prob = [
        outputs[i, :fiber_lengths[i], 2].reshape(-1, 1)
        for i in range(n_fibers)
    ]
    log_prob_map = [
        outputs[i, :fiber_lengths[i], 3].reshape(-1, 1)
        for i in range(n_fibers)
    ]

    log_prob_sum = [
        np.ones_like(log_prob[i]) * (log_prob[i].sum() / log_prob_map[i].sum())
        for i in range(n_fibers)
    ]
    log_prob_ratio = [
        np.ones_like(log_prob[i]) * (log_prob[i] - log_prob_map[i]).mean()
        for i in range(n_fibers)
    ]

    other_data = {}
    for key in list(trk_file.tractogram.data_per_point.keys()):
        if key not in [
                "kappa", "log1p_kappa", "log_prob", "log_prob_map",
                "log_prob_sum", "log_prob_ratio"
        ]:
            other_data[key] = trk_file.tractogram.data_per_point[key]

    data_per_point = PerArraySequenceDict(n_rows=n_pts,
                                          kappa=kappa,
                                          log_prob=log_prob,
                                          log_prob_sum=log_prob_sum,
                                          log_prob_ratio=log_prob_ratio,
                                          **other_data)
    tractogram = Tractogram(streamlines=tractogram.streamlines,
                            data_per_point=data_per_point,
                            affine_to_rasmm=np.eye(4))
    out_dir = os.path.join(os.path.dirname(config["dwi_path"]),
                           "marked_fibers", timestamp())
    os.makedirs(out_dir, exist_ok=True)

    marked_path = os.path.join(out_dir, "marked.trk")
    TrkFile(tractogram, trk_file.header).save(marked_path)

    config["out_dir"] = out_dir

    configs.save(config)
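The inner loop above reads a cubic block of diffusion data around each fiber head; the same indexing in isolation, with toy shapes and no model involved (all values made up):

import numpy as np
dwi = np.random.rand(10, 10, 10, 6)  # toy 4D volume with 6 diffusion channels
block_size = 3
r = block_size // 2
idx = np.array([5, 5, 5])  # a voxel index well inside the volume
block = dwi[idx[0] - r:idx[0] + r + 1,
            idx[1] - r:idx[1] + r + 1,
            idx[2] - r:idx[2] + r + 1, :].flatten()
print(block.shape)  # (162,) == 6 * 3**3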
Example #7
def merge_trks(trk_dir, keep, weighted, out_dir):
    """
    WARNING: Alignment between trk files is not checked, but assumed the same!
    """
    bundles = []
    for i, trk_path in enumerate(glob.glob(os.path.join(trk_dir, "*.trk"))):
        print("Loading {:.<20}".format(os.path.basename(trk_path)), end="\r")
        trk_file = nib.streamlines.load(trk_path)
        bundles.append(trk_file.tractogram)
        if i == 0:
            header = trk_file.header

    n_fibers = sum([len(b.streamlines) for b in bundles])
    n_bundles = len(bundles)

    print("Loaded {} fibers from {} bundles.".format(n_fibers, n_bundles))

    merged_bundles = bundles[0].copy()
    for b in bundles[1:]:
        merged_bundles.extend(b)

    if keep < 1:
        if weighted:
            p = np.zeros(n_fibers)
            offset = 0
            for b in bundles:
                l = len(b.streamlines)
                p[offset:offset+l] = 1 / (l * n_bundles)
                offset += l
        else:
            p = np.ones(n_fibers) / n_fibers

        keep_n = int(keep * n_fibers)
        print("Subsampling {} fibers".format(keep_n))

        np.random.seed(42)
        # sample indices rather than streamlines directly: np.random.choice
        # needs 1-D input, not a ragged sequence of arrays
        keep_idx = np.random.choice(n_fibers, size=keep_n, replace=False, p=p)
        subsample = [merged_bundles.streamlines[i] for i in keep_idx]

        tractogram = Tractogram(streamlines=subsample,
                                affine_to_rasmm=np.eye(4))
    else:
        tractogram = merged_bundles

    if out_dir is None:
        out_dir = os.path.dirname(trk_dir)
        out_dir = os.path.join(out_dir, "merged_tracts")

    os.makedirs(out_dir, exist_ok=True)

    if weighted:
        save_path = os.path.join(
            out_dir, "merged_W{:04d}.trk".format(int(1000 * keep)))
    else:
        save_path = os.path.join(
            out_dir, "merged_{:04d}.trk".format(int(1000 * keep)))

    print("Saving {}".format(save_path))

    TrkFile(tractogram, header).save(save_path)
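In the weighted branch every bundle receives the same total sampling probability regardless of its size; a quick standalone check of that construction with two toy bundles:

import numpy as np
sizes = [100, 10]  # fiber counts of two toy bundles
n_bundles, n_fibers = len(sizes), sum(sizes)
p = np.zeros(n_fibers)
offset = 0
for size in sizes:
    p[offset:offset + size] = 1 / (size * n_bundles)
    offset += size
print(p[:100].sum(), p[100:].sum())  # 0.5 0.5 -> equal mass per bundle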
Example #8
def run_rf_inference(config=None, gpu_queue=None):
    """"""
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
    except Exception as e:
        print(str(e))

    print("Loading DWI...")

    dwi_img = nib.load(config['dwi_path'])
    dwi_img = nib.funcs.as_closest_canonical(dwi_img)
    dwi_aff = dwi_img.affine
    dwi_affi = np.linalg.inv(dwi_aff)
    dwi = dwi_img.get_fdata()  # get_data() is deprecated/removed in newer nibabel

    def xyz2ijk(coords, snap=False):
        ijk = (coords.T).copy()
        dwi_affi.dot(ijk, out=ijk)
        if snap:
            return np.round(ijk, out=ijk).astype(int, copy=False).T
        else:
            return ijk.T

    with open(os.path.join(config['model_dir'], 'model'), 'rb') as f:
        model = pickle.load(f)

    train_config_file = os.path.join(config['model_dir'], 'config.yml')
    bvec_path = configs.load(train_config_file, 'bvecs')
    _, bvecs = read_bvals_bvecs(None, bvec_path)

    terminator = Terminator(config['term_path'], config['thresh'])

    prior = Prior(config['prior_path'])

    print("Initializing Fibers...")

    seed_file = nib.streamlines.load(config['seed_path'])
    xyz = seed_file.tractogram.streamlines.data
    n_seeds = 2 * len(xyz)
    xyz = np.vstack([xyz, xyz])  # Duplicate seeds for both directions
    xyz = np.hstack([xyz, np.ones([n_seeds, 1])])  # add affine dimension
    xyz = xyz.reshape(-1, 1, 4)  # (fiber, segment, coord)

    fiber_idx = np.hstack([
        np.arange(n_seeds // 2, dtype="int32"),
        np.arange(n_seeds // 2, dtype="int32")
    ])
    fibers = [[] for _ in range(n_seeds // 2)]

    print("Starting iteration...")

    input_shape = model.n_features_
    block_size = int(np.cbrt(input_shape / dwi.shape[-1]))

    d = np.zeros([n_seeds, dwi.shape[-1] * block_size**3])
    dnorm = np.zeros([n_seeds, 1])
    vout = np.zeros([n_seeds, 3])
    for i in range(config['max_steps']):
        t0 = time()

        # Get coords of the latest segment of each fiber
        ijk = xyz2ijk(xyz[:, -1, :], snap=True)

        n_ongoing = len(ijk)

        r = block_size // 2
        for ii, idx in enumerate(ijk):
            # read a cubic neighborhood of size block_size around the voxel
            d[ii] = dwi[idx[0] - r:idx[0] + r + 1,
                        idx[1] - r:idx[1] + r + 1,
                        idx[2] - r:idx[2] + r + 1,
                        :].flatten()  # flatten returns a copy
            dnorm[ii] = np.linalg.norm(d[ii])
            d[ii] /= dnorm[ii]

        if i == 0:
            inputs = np.hstack(
                [prior(xyz[:, 0, :]), d[:n_ongoing], dnorm[:n_ongoing]])
        else:
            inputs = np.hstack(
                [vout[:n_ongoing], d[:n_ongoing], dnorm[:n_ongoing]])

        chunk = 2**15  # 32768
        n_chunks = np.ceil(n_ongoing / chunk).astype(int)
        for c in range(n_chunks):

            outputs = model.predict(inputs[c * chunk:(c + 1) * chunk])
            v = bvecs[outputs, ...]
            vout[c * chunk:(c + 1) * chunk] = v

        rout = xyz[:, -1, :3] + config['step_size'] * vout
        rout = np.hstack([rout, np.ones((n_ongoing, 1))]).reshape(-1, 1, 4)

        xyz = np.concatenate([xyz, rout], axis=1)

        terminal_indices = terminator(xyz[:, -1, :])

        for idx in terminal_indices:
            gidx = fiber_idx[idx]
            # Other end not yet added
            if not fibers[gidx]:
                fibers[gidx].append(np.copy(xyz[idx, :, :3]))
            # Other end already added
            else:
                this_end = xyz[idx, :, :3]
                other_end = fibers[gidx][0]
                merged_fiber = np.vstack(
                    [np.flip(this_end[1:], axis=0),
                     other_end])  # stitch ends together
                fibers[gidx] = [merged_fiber]

        xyz = np.delete(xyz, terminal_indices, axis=0)
        vout = np.delete(vout, terminal_indices, axis=0)
        fiber_idx = np.delete(fiber_idx, terminal_indices)

        print(
            "Iter {:4d}/{}, finished {:5d}/{:5d} ({:3.0f}%) of all seeds with"
            " {:6.0f} steps/sec".format(
                (i + 1), config['max_steps'], n_seeds - n_ongoing, n_seeds,
                100 * (1 - n_ongoing / n_seeds), n_ongoing / (time() - t0)),
            end="\r")

        if n_ongoing == 0:
            break

        gc.collect()

    # Keep only finished fibers; indices still listed in fiber_idx never
    # terminated and are dropped:

    fibers = [
        fibers[gidx] for gidx in range(len(fibers)) if gidx not in fiber_idx
    ]
    # Save Result

    fibers = [f[0] for f in fibers]

    tractogram = Tractogram(streamlines=ArraySequence(fibers),
                            affine_to_rasmm=np.eye(4))

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    out_dir = os.path.join(os.path.dirname(config["dwi_path"]),
                           "predicted_fibers", timestamp)

    configs.deep_update(config, {"out_dir": out_dir})

    os.makedirs(out_dir, exist_ok=True)

    fiber_path = os.path.join(out_dir, timestamp + ".trk")
    print("\nSaving {}".format(fiber_path))
    TrkFile(tractogram, seed_file.header).save(fiber_path)

    config_path = os.path.join(out_dir, "config.yml")
    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    if config["score"]:
        score_on_tm(fiber_path)

    return tractogram
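Each seed is duplicated so the tracker grows a fiber in both directions; when the second half terminates, the two halves are stitched at the seed by flipping one of them. The same stitch on toy arrays:

import numpy as np
this_end = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])  # grown one way
other_end = np.array([[0., 0., 0.], [-1., 0., 0.]])  # grown the opposite way
merged = np.vstack([np.flip(this_end[1:], axis=0), other_end])
print(merged)  # runs from (2,0,0) through the shared seed to (-1,0,0)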
Example #9
def get_ismrm_seeds(data_dir, source, keep, weighted, threshold, voxel):

    trk_dir = os.path.join(data_dir, "bundles")

    if source in ["wm", "trk"]:
        anat_path = os.path.join(data_dir, "masks", "wm.nii.gz")
        resized_path = os.path.join(data_dir, "masks",
                                    "wm_{}.nii.gz".format(voxel))
    elif source == "brain":
        anat_path = os.path.join("subjects", "ismrm_gt",
                                 "dwi_brain_mask.nii.gz")
        resized_path = os.path.join("subjects", "ismrm_gt",
                                    "dwi_brain_mask_125.nii.gz")

    sp.call([
        "mrresize", "-voxel", "{:1.2f}".format(voxel / 100), anat_path,
        resized_path
    ])

    if source == "trk":

        print("Running Tractconverter...")
        sp.call([
            "python", "tractconverter/scripts/WalkingTractConverter.py", "-i",
            trk_dir, "-a", resized_path, "-vtk2trk"
        ])

        print("Loading seed bundles...")
        seed_bundles = []
        for i, trk_path in enumerate(glob.glob(os.path.join(trk_dir,
                                                            "*.trk"))):
            trk_file = nib.streamlines.load(trk_path)
            endpoints = []
            for fiber in trk_file.tractogram.streamlines:
                endpoints.append(fiber[0])
                endpoints.append(fiber[-1])
            seed_bundles.append(endpoints)
            if i == 0:
                header = trk_file.header

        n_seeds = sum([len(b) for b in seed_bundles])
        n_bundles = len(seed_bundles)

        print("Loaded {} seeds from {} bundles.".format(n_seeds, n_bundles))

        seeds = np.array([[seed] for bundle in seed_bundles
                          for seed in bundle])

        if keep < 1:
            if weighted:
                p = np.zeros(n_seeds)
                offset = 0
                for b in seed_bundles:
                    l = len(b)
                    p[offset:offset + l] = 1 / (l * n_bundles)
                    offset += l
            else:
                p = np.ones(n_seeds) / n_seeds

    elif source in ["brain", "wm"]:

        weighted = False

        wm_file = nib.load(resized_path)
        wm_img = wm_file.get_fdata()

        seeds = np.argwhere(wm_img > threshold)
        seeds = np.hstack([seeds, np.ones([len(seeds), 1])])

        seeds = (wm_file.affine.dot(seeds.T).T)[:, :3].reshape(-1, 1, 3)

        n_seeds = len(seeds)

        if keep < 1:
            p = np.ones(n_seeds) / n_seeds

        header = TrkFile.create_empty_header()

        header["voxel_to_rasmm"] = wm_file.affine
        header["dimensions"] = wm_file.header["dim"][1:4]
        header["voxel_sizes"] = wm_file.header["pixdim"][1:4]
        header["voxel_order"] = get_reference_info(wm_file)[3]

    if keep < 1:
        keep_n = int(keep * n_seeds)
        print("Subsampling from {} seeds to {} seeds".format(n_seeds, keep_n))
        np.random.seed(42)
        keep_idx = np.random.choice(len(seeds),
                                    size=keep_n,
                                    replace=False,
                                    p=p)
        seeds = seeds[keep_idx]

    tractogram = Tractogram(streamlines=ArraySequence(seeds),
                            affine_to_rasmm=np.eye(4))

    save_dir = os.path.join(data_dir, "seeds")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "seeds_from_{}_{}_vox{:03d}.trk")
    save_path = save_path.format(
        source, "W" + str(int(100 * keep)) if weighted else "all", voxel)

    print("Saving {}".format(save_path))
    TrkFile(tractogram, header).save(save_path)

    os.remove(resized_path)
    for file in glob.glob(os.path.join(trk_dir, "*.trk")):
        os.remove(file)
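A minimal usage sketch (hypothetical data layout; mrresize from MRtrix and, for source="trk", the TractConverter script must be available):

get_ismrm_seeds("/data/ismrm", source="wm", keep=0.1,
                weighted=False, threshold=0.5, voxel=125)
# -> saves /data/ismrm/seeds/seeds_from_wm_all_vox125.trk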