Example #1
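# Imports assumed by this snippet (not shown in the listing); get_gpus,
# run_inference, parse, group_by_model, save and timestamp are
# project-internal helpers.
import glob
import os
from multiprocessing import Manager, Process, SimpleQueue
from time import sleep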
def find_optimal_temperature(config):
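    """Sketch of intent (inferred from the code): run every model matched by
    config["model_glob"] on test and retest DWI data, one child process per
    free GPU, then group the predictions into pairs and save them with the
    config."""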

    model_paths = glob.glob(config["model_glob"])

    dwi_path_1 = config["inference"]["dwi_path"].format("")
    dwi_path_2 = config["inference"]["dwi_path"].format("retest")

    gpu_queue = SimpleQueue()
    for idx in get_gpus():
        gpu_queue.put(str(idx))

    procs = []
    pred_manager = Manager()
    predictions = pred_manager.dict()
    try:
        for mp in model_paths:

            model_config = config["inference"].copy()
            model_config["model_path"] = mp

            for j in [0, 1]:
                run_config = model_config.copy()
                parse(run_config, "dwi_path", j)
                parse(run_config, "prior_path", j)
                parse(run_config, "term_path", j)
                parse(run_config, "seed_path", j)
                while gpu_queue.empty():
                    sleep(10)

                p = Process(target=run_inference,
                            args=(run_config, gpu_queue, predictions))
                p.start()
                procs.append(p)
                print("Launched {}: {}".format(mp.split("/")[-1], j))
                sleep(10)

    except KeyboardInterrupt:
        pass
    finally:
        for p in procs:
            p.join()  # join() without a timeout blocks until the child has exited

    pred_pairs = group_by_model(predictions)
    config["pred_pairs"] = pred_pairs

    save(config,
         name="opT_{}.yml".format(timestamp()),
         out_dir=os.path.dirname(config["model_glob"]))
    """
Example #2
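# Imports assumed by this snippet (not shown in the listing); `configs` is a
# project-internal module.
import json
import os

from flask import Response, request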
def set_config():
    try:
        config = configs.from_request(request)
        result = configs.save(config, os.environ.get('DAC_CONFIG_PATH'))
    except Exception as e:
        return Response(str(e))

    get_settings.cache_clear()
    get_price_settings.cache_clear()

    return Response(json.dumps(result))
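# Typical use (sketch): registered as a Flask route handler; it saves the
# posted config and clears the settings caches (get_settings and
# get_price_settings are presumably functools.lru_cache-wrapped) so later
# reads pick up the new values.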
Example #3
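# Imports assumed by this snippet; MODELS, configs, parse_callbacks,
# maybe_get_a_gpu and timestamp are project-internal helpers, and
# keras_optimizers is presumably the Keras optimizers module
# (e.g. `from keras import optimizers as keras_optimizers`).
import os
import shutil

import git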
def train(config=None, gpu_queue=None):
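    """Sketch of intent (inferred from the code): claim a GPU, build the model
    and data sequences from `config`, train with Keras, and archive or clean
    up the output directory depending on success."""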

    gpu_idx = -1  # fallback so the finally block never sees an unbound name
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
    except Exception as e:
        print(str(e))

    day, hour = timestamp(separate=True)

    out_dir = os.path.join("models", config["model_name"],
                           config.get("model_type", ""), day, hour)

    os.makedirs(out_dir, exist_ok=True)
    configs.deep_update(config, {"out_dir": out_dir})

    configs.add(config, to=".running")

    model = MODELS[config["model_name"]](config)

    no_exception = True  # defined before the try so the finally block can always read it
    try:
        train_seq = model.get_sequence(config)
        eval_seq = model.get_sequence(config, istraining=False)
        configs.deep_update(config, {
            "train_seq": train_seq,
            "eval_seq": eval_seq
        })

        checkpoints = out_dir + '/inter_model_{epoch:02d}-{val_loss:.4f}.h5'
        if "RNN" in config["model_name"]:
            configs.deep_update(config,
                                {"reset_batches": train_seq.reset_batches})
            configs.deep_update(config, {"filepath": checkpoints})
        if 'Trackifier' in config['model_name']:
            configs.deep_update(config, {"filepath": checkpoints})

        callbacks = parse_callbacks(config["callbacks"])
        optimizer = getattr(keras_optimizers,
                            config["optimizer"])(**config["opt_params"])

        model.compile(optimizer)

        if isinstance(config['train_path'], list):
            for i, subject in enumerate(config['train_path']):
                samples_config = os.path.join(os.path.dirname(subject),
                                              'config.yml')
                samples_config = configs.load(samples_config)
                config['input_samples_config_{0}'.format(i)] = samples_config
        else:
            samples_config = os.path.join(
                os.path.dirname(config['train_path']), 'config.yml')
            samples_config = configs.load(samples_config)
            config['input_samples_config'] = samples_config
        repo = git.Repo(".")
        commit = repo.head.commit
        config['commit'] = str(commit)
        configs.save(config)

        print("\nStart training...")
        model.keras.fit_generator(
            train_seq,
            callbacks=callbacks,
            validation_data=eval_seq,
            epochs=config["epochs"],
            shuffle=config["shuffle"],
            max_queue_size=2000,
            verbose=1,
            workers=5,
            use_multiprocessing=True,
        )
    except KeyboardInterrupt:
        model.stop_training = True
    except Exception as e:
        shutil.rmtree(out_dir)
        no_exception = False
        raise e
    finally:
        configs.remove(config, _from=".running")
        if no_exception:
            configs.add(config, to=".archive")
            model_path = os.path.join(out_dir, "final_model.h5")
            print("\nSaving {}".format(model_path))
            model.keras.save(model_path)
        if gpu_queue is not None:
            gpu_queue.put(gpu_idx)

    return model.keras
Example #4
def save_config(self, event):
    configs.save(self)
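# Presumably an event-handler callback (e.g. bound to a GUI or observer
# event); `event` is unused and the instance itself is persisted.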
Example #5
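# Imports assumed by this snippet; MODELS, configs, maybe_add_tangent,
# get_blocksize and timestamp are project-internal helpers, and load_model is
# presumably keras.models.load_model.
import os
from time import time

import nibabel as nib
import numpy as np
from nibabel.streamlines.tractogram import PerArraySequenceDict, Tractogram
from nibabel.streamlines.trk import TrkFile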
def mark(config, gpu_queue=None):
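    """Sketch of intent (inferred from the code): score every point of every
    fiber in config["trk_path"] with the model's concentration (kappa) and
    log-probability outputs, and save the annotated tractogram."""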

    gpu_idx = -1
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
    except Exception as e:
        print(str(e))
    print("Loading DWI data ...")

    dwi_img = nib.load(config["dwi_path"])
    dwi_img = nib.funcs.as_closest_canonical(dwi_img)
    dwi_aff = dwi_img.affine
    dwi_affi = np.linalg.inv(dwi_aff)
    dwi = dwi_img.get_data()

    def xyz2ijk(coords, snap=False):
        # Map world (xyz) coordinates to voxel indices by applying the inverse
        # affine in homogeneous coordinates. The homogeneous row is kept, so
        # callers read only the first three entries of each returned row.
        ijk = (coords.T).copy()
        ijk = np.vstack([ijk, np.ones([1, ijk.shape[1]])])
        dwi_affi.dot(ijk, out=ijk)
        if snap:
            return (np.round(ijk, out=ijk).astype(int, copy=False).T)[:, :4]
        return (ijk.T)[:, :4]

    # ==========================================================================

    print("Loading fibers ...")

    trk_file = nib.streamlines.load(config["trk_path"])
    tractogram = trk_file.tractogram

    if "t" in tractogram.data_per_point:
        print("Fibers are already resampled")
        tangents = tractogram.data_per_point["t"]
    else:
        print("Fibers are not resampled. Resampling now ...")
        tractogram = maybe_add_tangent(config["trk_path"],
                                       min_length=30,
                                       max_length=200)
        tangents = tractogram.data_per_point["t"]

    n_fibers = len(tractogram)
    fiber_lengths = np.array([len(t.streamline) for t in tractogram])
    max_length = fiber_lengths.max()
    n_pts = fiber_lengths.sum()

    # ==========================================================================

    print("Loading model ...")
    model_name = config['model_name']

    if hasattr(MODELS[model_name], "custom_objects"):
        model = load_model(config["model_path"],
                           custom_objects=MODELS[model_name].custom_objects,
                           compile=False)
    else:
        model = load_model(config["model_path"], compile=False)

    block_size = get_blocksize(model, dwi.shape[-1])

    d = np.zeros([n_fibers, dwi.shape[-1] * block_size**3 + 1])
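    # d holds one row per fiber: the flattened DWI block (block_size**3 voxels
    # x channels) plus one trailing column that is overwritten with the
    # block's norm at every step.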

    inputs = np.zeros([n_fibers, max_length, 3])

    print("Writing to input array ...")

    for i, fiber_t in enumerate(tangents):
        inputs[i, :fiber_lengths[i], :] = fiber_t

    outputs = np.zeros([n_fibers, max_length, 4])

    print("Starting iteration ...")

    step = 0
    while step < max_length:
        t0 = time()

        xyz = inputs[:, step, :]
        ijk = xyz2ijk(xyz, snap=True)

        for ii, idx in enumerate(ijk):
            try:
                d[ii, :-1] = dwi[idx[0] - (block_size // 2):idx[0] +
                                 (block_size // 2) + 1,
                                 idx[1] - (block_size // 2):idx[1] +
                                 (block_size // 2) + 1, idx[2] -
                                 (block_size // 2):idx[2] + (block_size // 2) +
                                 1, :].flatten()  # returns copy
            except (IndexError, ValueError):
                pass

        d[:, -1] = np.linalg.norm(d[:, :-1], axis=1) + 10**-2

        d[:, :-1] /= d[:, -1].reshape(-1, 1)

        if step == 0:
            vin = -inputs[:, step + 1, :]
            vout = -inputs[:, step, :]
        else:
            vin = inputs[:, step - 1, :]
            vout = inputs[:, step, :]

        model_inputs = np.hstack([vin, d])
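        # Predict in fixed-size chunks of fibers, presumably to bound peak
        # GPU memory use per forward pass.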
        chunk = 2**15  # 32768
        n_chunks = np.ceil(n_fibers / chunk).astype(int)
        for c in range(n_chunks):

            fvm_pred, kappa_pred = model(model_inputs[c * chunk:(c + 1) *
                                                      chunk])

            log1p_kappa_pred = np.log1p(kappa_pred)

            log_prob_pred = fvm_pred.log_prob(vout[c * chunk:(c + 1) * chunk])

            log_prob_map_pred = fvm_pred._log_normalization() + kappa_pred

            outputs[c * chunk:(c + 1) * chunk, step, 0] = kappa_pred
            outputs[c * chunk:(c + 1) * chunk, step, 1] = log1p_kappa_pred
            outputs[c * chunk:(c + 1) * chunk, step, 2] = log_prob_pred
            outputs[c * chunk:(c + 1) * chunk, step, 3] = log_prob_map_pred

        print("Step {:3d}/{:3d}, ETA: {:4.0f} min".format(
            step, max_length, (max_length - step) * (time() - t0) / 60),
              end="\r")

        step += 1

    if gpu_queue is not None:
        gpu_queue.put(gpu_idx)

    kappa = [
        outputs[i, :fiber_lengths[i], 0].reshape(-1, 1)
        for i in range(n_fibers)
    ]
    log1p_kappa = [
        outputs[i, :fiber_lengths[i], 1].reshape(-1, 1)
        for i in range(n_fibers)
    ]
    log_prob = [
        outputs[i, :fiber_lengths[i], 2].reshape(-1, 1)
        for i in range(n_fibers)
    ]
    log_prob_map = [
        outputs[i, :fiber_lengths[i], 3].reshape(-1, 1)
        for i in range(n_fibers)
    ]

    log_prob_sum = [
        np.ones_like(log_prob[i]) * (log_prob[i].sum() / log_prob_map[i].sum())
        for i in range(n_fibers)
    ]
    log_prob_ratio = [
        np.ones_like(log_prob[i]) * (log_prob[i] - log_prob_map[i]).mean()
        for i in range(n_fibers)
    ]

    other_data = {}
    for key in list(trk_file.tractogram.data_per_point.keys()):
        if key not in [
                "kappa", "log1p_kappa", "log_prob", "log_prob_map",
                "log_prob_sum", "log_prob_ratio"
        ]:
            other_data[key] = trk_file.tractogram.data_per_point[key]

    data_per_point = PerArraySequenceDict(n_rows=n_pts,
                                          kappa=kappa,
                                          log1p_kappa=log1p_kappa,
                                          log_prob=log_prob,
                                          log_prob_map=log_prob_map,
                                          log_prob_sum=log_prob_sum,
                                          log_prob_ratio=log_prob_ratio,
                                          **other_data)
    tractogram = Tractogram(streamlines=tractogram.streamlines,
                            data_per_point=data_per_point,
                            affine_to_rasmm=np.eye(4))
    out_dir = os.path.join(os.path.dirname(config["dwi_path"]),
                           "marked_fibers", timestamp())
    os.makedirs(out_dir, exist_ok=True)

    marked_path = os.path.join(out_dir, "marked.trk")
    TrkFile(tractogram, trk_file.header).save(marked_path)

    config["out_dir"] = out_dir

    configs.save(config)
Example #6
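# Imports assumed by this snippet (paths as in classic DIPY releases);
# maybe_get_a_gpu, load_model, maybe_add_tangent, bundle_map, cluster_fixels,
# get_blocksize, agreement_for, ideal_agreement and save are project-internal
# helpers.
import os
import re

import nibabel as nib
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import AveragePointwiseEuclideanMetric, ResampleFeature
from keras import backend as K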
def agreement(model_path,
              dwi_path_1,
              trk_path_1,
              dwi_path_2,
              trk_path_2,
              wm_path,
              fixel_cnt_path,
              cluster_thresh,
              centroid_size,
              fixel_thresh,
              bundle_min_cnt,
              gpu_queue=None):
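    """Sketch of intent (inferred from the code): cluster test and retest
    streamlines into bundles, derive per-voxel fixels from the bundle
    direction maps, and score test-retest agreement under the given model."""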

    gpu_idx = -1  # fallback so gpu_queue.put(gpu_idx) below never sees an unbound name
    try:
        gpu_idx = maybe_get_a_gpu() if gpu_queue is None else gpu_queue.get()
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
    except Exception as e:
        print(str(e))

    temperature = np.round(float(re.findall(r"T=(.*)\.h5", model_path)[0]), 6)
    model = load_model(model_path)

    print("Load data ...")

    dwi_img_1 = nib.load(dwi_path_1)
    dwi_img_1 = nib.funcs.as_closest_canonical(dwi_img_1)
    affine_1 = dwi_img_1.affine
    dwi_1 = dwi_img_1.get_data()

    dwi_img_2 = nib.load(dwi_path_2)
    dwi_img_2 = nib.funcs.as_closest_canonical(dwi_img_2)
    affine_2 = dwi_img_2.affine
    dwi_2 = dwi_img_2.get_data()

    wm_img = nib.load(wm_path)
    wm_data = wm_img.get_data()
    n_wm = (wm_data > 0).sum()

    fixel_cnt = nib.load(fixel_cnt_path).get_data()[:, :, :, 0]
    fixel_cnt = fixel_cnt[wm_data > 0]

    k_fixels = np.unique(fixel_cnt)
    max_fixels = k_fixels.max()
    n_fixels_gt = sum(k * (fixel_cnt == k).sum() for k in k_fixels)

    img_shape = dwi_1.shape[:-1]

    #---------------------------------------------------------------------------

    tractogram_1 = maybe_add_tangent(trk_path_1)
    tractogram_2 = maybe_add_tangent(trk_path_2)

    streamlines_1 = tractogram_1.streamlines
    streamlines_2 = tractogram_2.streamlines

    n_streamlines_1 = len(streamlines_1)
    n_streamlines_2 = len(streamlines_2)

    tractogram_1.extend(tractogram_2)

    ############################################################################

    print("Clustering streamlines.")

    feature = ResampleFeature(nb_points=centroid_size)

    qb = QuickBundles(threshold=cluster_thresh,
                      metric=AveragePointwiseEuclideanMetric(feature))

    bundles = qb.cluster(streamlines_1)
    bundles.refdata = tractogram_1

    n_bundles = len(bundles)

    print("Found {} bundles.".format(n_bundles))

    print("Computing bundle masks...")

    direction_masks_1 = np.zeros((n_bundles, ) + img_shape + (3, ), np.float16)
    direction_masks_2 = np.zeros((n_bundles, ) + img_shape + (3, ), np.float16)
    count_masks_1 = np.zeros((n_bundles, ) + img_shape, np.uint16)
    count_masks_2 = np.zeros((n_bundles, ) + img_shape, np.uint16)

    marginal_bundles = 0
    for i, b in enumerate(bundles.clusters):

        idx_b = np.asarray(b.indices)
        is_from_1 = np.flatnonzero(idx_b < n_streamlines_1).tolist()
        is_from_2 = np.flatnonzero(idx_b >= n_streamlines_1).tolist()

        # Require a minimum number of streamlines from each session; the
        # comparison is on counts, and flatnonzero always yields a 1-D index
        # array even for single-element bundles.
        if (len(is_from_1) > bundle_min_cnt
                and len(is_from_2) > bundle_min_cnt):

            bundle_map(b[is_from_1],
                       affine_1,
                       img_shape,
                       dir_out=direction_masks_1[i],
                       cnt_out=count_masks_1[i])

            bundle_map(b[is_from_2],
                       affine_2,
                       img_shape,
                       dir_out=direction_masks_2[i],
                       cnt_out=count_masks_2[i])
        else:
            marginal_bundles += 1

        assert direction_masks_1.dtype.name == "float16"
        assert direction_masks_2.dtype.name == "float16"
        assert count_masks_1.dtype.name == "uint16"
        assert count_masks_2.dtype.name == "uint16"

        print("Computed bundle {:3d}.".format(i), end="\r")

        #gc.collect()

    overlap = ((count_masks_1 > 0) * (count_masks_2 > 0) *
               np.expand_dims(wm_data > 0, 0))

    print("Calculating Fixels...")

    fixel_directions_1 = []
    fixel_directions_2 = []
    fixel_cnts_1 = []
    fixel_cnts_2 = []
    fixel_ijk = []
    n_fixels = []
    no_overlap = 0
    for vox in np.argwhere(wm_data > 0):

        matched = overlap[:, vox[0], vox[1], vox[2]] > 0

        if matched.sum() > 0:

            dir_1 = direction_masks_1[matched, vox[0], vox[1], vox[2], :]
            cnts_1 = count_masks_1[matched, vox[0], vox[1], vox[2]]

            dir_2 = direction_masks_2[matched, vox[0], vox[1], vox[2], :]
            cnts_2 = count_masks_2[matched, vox[0], vox[1], vox[2]]

            fixels1, fixels2, f_cnts_1, f_cnts_2 = cluster_fixels(
                dir_1,
                dir_2,
                cnts_1,
                cnts_2,
                threshold=np.cos(np.pi / fixel_thresh))

            n_f = len(fixels1)

            fixel_directions_1.append(fixels1)
            fixel_directions_2.append(fixels2)

            fixel_cnts_1.append(f_cnts_1)
            fixel_cnts_2.append(f_cnts_2)

            fixel_ijk.append(np.tile(vox, (n_f, 1)))

            n_fixels.append(n_f)
        else:
            no_overlap += 1

    fixel_directions_1 = np.vstack(fixel_directions_1)
    fixel_directions_2 = np.vstack(fixel_directions_2)
    fixel_cnts_1 = np.vstack(fixel_cnts_1).reshape(-1)
    fixel_cnts_2 = np.vstack(fixel_cnts_2).reshape(-1)
    fixel_ijk = np.vstack(fixel_ijk)

    #gc.collect()

    ############################################################################

    print("Computing agreement ...")

    n_fixels_sum = np.sum(n_fixels)

    block_size = get_blocksize(model, dwi_1.shape[-1])

    d_1 = np.zeros(
        [n_fixels_sum, block_size, block_size, block_size, dwi_1.shape[-1]])
    d_2 = np.zeros(
        [n_fixels_sum, block_size, block_size, block_size, dwi_1.shape[-1]])
    i, j, k = fixel_ijk.T
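    # NOTE: the fixed "- 1" offsets below center the block only for
    # block_size == 3; a general version would subtract block_size // 2.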
    for idx in range(block_size**3):
        ii, jj, kk = np.unravel_index(idx,
                                      (block_size, block_size, block_size))
        d_1[:, ii, jj, kk, :] = dwi_1[i + ii - 1, j + jj - 1, k + kk - 1, :]
        d_2[:, ii, jj, kk, :] = dwi_2[i + ii - 1, j + jj - 1, k + kk - 1, :]

    d_1 = d_1.reshape(-1, dwi_1.shape[-1] * block_size**3)
    d_2 = d_2.reshape(-1, dwi_2.shape[-1] * block_size**3)

    dnorm_1 = np.linalg.norm(d_1, axis=1, keepdims=True) + 10**-2
    dnorm_2 = np.linalg.norm(d_2, axis=1, keepdims=True) + 10**-2

    d_1 /= dnorm_1
    d_2 /= dnorm_2

    # Each row is [fixel direction (3), normalized DWI block, block norm],
    # presumably matching the input layout used at training time.
    model_inputs_1 = np.hstack([fixel_directions_1, d_1, dnorm_1])
    model_inputs_2 = np.hstack([fixel_directions_2, d_2, dnorm_2])

    fixel_agreements, fixel_kappa_1, fixel_kappa_2, fixel_mu_1, fixel_mu_2 = \
    agreement_for(
        model,
        model_inputs_1,
        model_inputs_2,
        fixel_cnts_1,
        fixel_cnts_2
    )

    agreement = {"temperature": temperature}
    agreement["model_path"] = model_path
    agreement["n_bundles"] = n_bundles
    agreement["value"] = fixel_agreements.sum() / n_fixels_gt
    agreement["min"] = fixel_agreements.min()
    agreement["mean"] = fixel_agreements.mean()
    agreement["max"] = fixel_agreements.max()
    agreement["std"] = fixel_agreements.std()
    agreement["n_fixels_sum"] = n_fixels_sum
    agreement["n_wm"] = n_wm
    agreement["n_fixels_gt"] = n_fixels_gt
    agreement["marginal_bundles"] = marginal_bundles
    agreement["no_overlap"] = no_overlap
    agreement["dwi_1"] = dwi_path_1
    agreement["trk_1"] = trk_path_1
    agreement["dwi_2"] = dwi_path_2
    agreement["trk_2"] = trk_path_2
    agreement["fixel_cnt_path"] = fixel_cnt_path
    agreement["cluster_thresh"] = cluster_thresh
    agreement["centroid_size"] = centroid_size
    agreement["fixel_thresh"] = fixel_thresh
    agreement["bundle_min_cnt"] = bundle_min_cnt
    agreement["wm_path"] = wm_path
    agreement["ideal"] = ideal_agreement(temperature)

    for k, cnt in zip(*np.unique(n_fixels, return_counts=True)):
        agreement["n_vox_with_{}_fixels".format(k)] = cnt

    for i in [1, 5, 10]:
        agreement["le_{}_fibers_per_fixel_1".format(i)] = np.mean(
            fixel_cnts_1 < i)

    agreement["mean_fibers_per_fixel_1"] = np.mean(fixel_cnts_1)
    agreement["median_fibers_per_fixel_1"] = np.median(fixel_cnts_1)
    agreement["mean_fibers_per_fixel_2"] = np.mean(fixel_cnts_2)
    agreement["median_fibers_per_fixel_2"] = np.median(fixel_cnts_2)
    agreement["std_fibers_per_fixel"] = np.std(fixel_cnts_1)
    agreement["max_fibers_per_fixel"] = np.max(fixel_cnts_1)
    agreement["min_fibers_per_fixel"] = np.min(fixel_cnts_1)

    fixel_angles = (fixel_directions_1 * fixel_directions_2).sum(axis=1)
    agreement["mean_fixel_angle"] = fixel_angles.mean()
    agreement["median_fixel_angle"] = np.median(fixel_angles)
    agreement["std_fixel_angle"] = fixel_angles.std()
    agreement["negative_fixel_angles"] = (fixel_angles < 0).mean()

    save(agreement, "agreement_T={}.yml".format(temperature),
         os.path.dirname(model_path))

    np.savez(
        os.path.join(os.path.dirname(model_path),
                     "data_T={}".format(temperature)),
        fixel_cnts_1=fixel_cnts_1,
        fixel_cnts_2=fixel_cnts_2,
        fixel_mu_1=fixel_mu_1,
        fixel_mu_2=fixel_mu_2,
        fixel_kappa_1=fixel_kappa_1,
        fixel_kappa_2=fixel_kappa_2,
        fixel_directions_1=fixel_directions_1,
        fixel_directions_2=fixel_directions_2,
        fixel_agreements=fixel_agreements,
    )

    K.clear_session()
    if gpu_queue is not None:
        gpu_queue.put(gpu_idx)