def test_density_sampling():
    """Visually inspect intensity-density estimation and its spline fit.

    Exploratory plotting routine: it contains no assertions. It loads the
    "strip" variant of IBSR image 1 (``info.get_ibsr(1, "strip")``) and a
    warped image read from a hard-coded file name relative to the current
    working directory, then:

    1. estimates densities with ``compute_densities`` (100 bins, masked by
       the nonzero voxels of the IBSR image) and shows them as an image;
    2. builds four images from the densities with the ``create_ss_de``,
       ``create_ss_mode``, ``create_ss_median`` and ``create_ss_mean``
       estimators, overlays each one on the IBSR image, and shows the
       middle slice along axis 1 of each in a 2x2 grid;
    3. fits a ``CubicSpline`` to row 100 of the densities and plots that row
       against the fit, and against the spline evaluated with an extra
       argument of 1 (labelled "derivative" below);
    4. fits every row of the densities and shows the original and fitted
       matrices side by side.

    NOTE(review): ``figure``, ``imshow``, ``plot``, ``plt`` and ``cm`` are not
    bound anywhere in this function. Presumably they come from a module-level
    pylab/matplotlib import outside this view — confirm. Otherwise the first
    ``figure()`` call raises NameError.
    """
    # NOTE(review): os, os.path, metrics, imwarp, VerbosityLevels,
    # getBaseFileName, decompose_path, readAntsAffine, arg and
    # sample_from_density are not used in this function.
    import os
    import os.path
    import numpy as np
    import experiments.registration.dataset_info as info
    import nibabel as nib
    import dipy.align.metrics as metrics
    import dipy.align.imwarp as imwarp
    from dipy.align import VerbosityLevels
    from experiments.registration.rcommon import getBaseFileName, decompose_path, readAntsAffine
    from dipy.fixes import argparse as arg
    from experiments.registration.evaluation import (
        compute_densities,
        sample_from_density,
        create_ss_de,
        create_ss_mode,
        create_ss_median,
        create_ss_mean,
    )
    from experiments.registration.splines import CubicSpline
    import dipy.viz.regtools as rt

    i1_name = info.get_ibsr(1, "strip")
    i1_nib = nib.load(i1_name)
    i1 = i1_nib.get_data().squeeze()
    # Nonzero voxels of i1 (as int32); used as the mask for compute_densities.
    mask = (i1 > 0).astype(np.int32)

    # Hard-coded file name, resolved relative to the current working directory.
    wt1_name = "warpedDiff_brainweb_t2_strip_IBSR_01_ana_strip.nii.gz"
    wt1_nib = nib.load(wt1_name)
    wt1 = wt1_nib.get_data().squeeze()

    rt.overlay_slices(i1, wt1)
    nbins = 100
    # i1 is passed as int32 and wt1 as float64.
    densities = compute_densities(i1.astype(np.int32), wt1.astype(np.float64), nbins, mask)
    figure()
    imshow(densities)

    # Compare different estimators
    ss_sampled = create_ss_de(i1.astype(np.int32), densities)
    ss_mode = create_ss_mode(i1.astype(np.int32), densities)
    ss_median = create_ss_median(i1.astype(np.int32), densities)
    ss_mean = create_ss_mean(i1.astype(np.int32), densities)

    rt.overlay_slices(ss_sampled, i1)
    rt.overlay_slices(ss_mode, i1)
    rt.overlay_slices(ss_median, i1)
    rt.overlay_slices(ss_mean, i1)

    # Middle slice along axis 1 of each estimate, transposed for display.
    s1 = ss_sampled[:, ss_sampled.shape[1] // 2, :].T
    s2 = ss_mode[:, ss_mode.shape[1] // 2, :].T
    s3 = ss_median[:, ss_median.shape[1] // 2, :].T
    s4 = ss_mean[:, ss_mean.shape[1] // 2, :].T
    slices = [[s1, s2], [s3, s4]]
    titles = [["Sampled", "Mode"], ["Median", "Mean"]]

    fig, ax = plt.subplots(2, 2)
    fig.set_facecolor("white")
    for ii, a_row in enumerate(ax):
        for jj, a in enumerate(a_row):
            a.set_axis_off()
            a.imshow(slices[ii][jj], cmap=cm.gray, origin="lower")
            a.set_title(titles[ii][jj])

    # Fit densities with splines in a regular grid
    # NOTE(review): row 100 is an arbitrary choice; this assumes densities has
    # more than 100 rows — confirm.
    f = densities[100]
    kspacing = 1  # Number of grid cells between spline knots
    spline = CubicSpline(kspacing)
    coef = spline.fit_to_data(f)
    # Check fit
    fit = spline.evaluate(coef, nbins)
    figure()
    plot(f)
    plot(fit)
    # And the derivative (third argument presumably the derivative order — confirm)
    df = spline.evaluate(coef, nbins, 1)
    figure()
    plot(f)
    plot(df)

    # Fit each row independently and store the fits in an array shaped like densities.
    fit = np.zeros_like(densities)
    for i in range(densities.shape[0]):
        coef = spline.fit_to_data(densities[i])
        fit[i, :] = spline.evaluate(coef, nbins)
    fig = plt.figure()
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(densities)
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(fit)
def _is_analyze(fname):
    """Return True when *fname* ends in "img" (treated as an Analyze image)."""
    return fname[-3:] == "img"


def _output_ext(fname):
    """Return the output extension matching *fname*'s format: ".img" or ".nii.gz"."""
    return ".img" if _is_analyze(fname) else ".nii.gz"


def _coordinate_system(fname):
    """Return the coordinate-system code passed to readAntsAffine for *fname*.

    "LAS" for Analyze files and "LPS" otherwise.
    """
    return "LAS" if _is_analyze(fname) else "LPS"


def _load_volume(fname):
    """Load *fname* and return ``(img, data, affine)``.

    ``data`` is ``img.get_data().squeeze()`` and ``affine`` is
    ``img.get_affine()``. For Analyze files, the affine's translation is
    shifted in place by ``affine[:3, :3]`` applied to half the volume shape.
    This moves the reference from the volume center to the corner.
    """
    import numpy as np
    import nibabel as nib

    img = nib.load(fname)
    affine = img.get_affine()
    data = img.get_data().squeeze()
    if _is_analyze(fname):
        affine[:3, 3] += affine[:3, :3].dot(np.array(data.shape) // 2)
    return img, data, affine


def create_semi_synthetic(params):
    r"""Create semi-synthetic images using real_mod1 as anatomy and tmp_mod2 templates as intensity models.

    Template tmp_mod1 is registered towards real_mod1 (assumed to be of the
    same modality) using SyN with the CC metric. The resulting transformation
    is applied to each template tmp_mod2, which is assumed to be perfectly
    aligned with tmp_mod1. A transfer function is then computed from
    real_mod1 to the warped tmp_mod2 and applied to real_mod1.

    Parameters
    ----------
    params : object
        Attribute container (presumably an argparse namespace). The function
        reads these attributes:

        * ``params.real`` -- the real_mod1 file name;
        * ``params.template`` -- the tmp_mod1 file name;
        * ``params.prealign`` -- an ANTs affine file name, where a falsy value
          means identity;
        * ``params.warp_dir`` -- a directory whose entries are the tmp_mod2
          templates. tmp_mod1 itself is also processed as a template.

    Notes
    -----
    Files are written to the current working directory. ``<ext>`` is ".img"
    when real_mod1 ends in "img", and ".nii.gz" otherwise.

    * ``mask_<fixed>.nii.gz`` -- tmp_mod1's nonzero voxels warped onto real_mod1;
    * ``warpedDiff_<moving>_<fixed><ext>`` -- each warped template;
    * ``ssds_<moving>_<fixed><ext>`` -- each semi-synthetic image.

    Registration is skipped, and the saved mask and warped templates are
    reused, when the mask and every ``warpedDiff_`` file already exist.
    """
    # Local imports keep this function self-contained; in the visible file
    # these names are otherwise only bound inside test_density_sampling.
    import os
    import numpy as np
    import nibabel as nib
    import dipy.align.metrics as metrics
    import dipy.align.imwarp as imwarp
    from dipy.align import VerbosityLevels
    from experiments.registration.rcommon import getBaseFileName, readAntsAffine
    from experiments.registration.evaluation import compute_densities, create_ss_de

    real_mod1 = params.real
    tmp_mod1 = params.template
    prealign_name = params.prealign
    base_fixed = getBaseFileName(real_mod1)
    out_ext = _output_ext(real_mod1)
    mask_name = "mask_" + base_fixed + ".nii.gz"

    # tmp_mod1 comes first, followed by the directory entries. They are sorted
    # because os.listdir returns entries in arbitrary order.
    tmp_mod2_list = [tmp_mod1] + [
        os.path.join(params.warp_dir, name) for name in sorted(os.listdir(params.warp_dir))
    ]

    def out_name(moving):
        """Return "<moving base>_<fixed base><ext>" for template *moving*."""
        return getBaseFileName(moving) + "_" + base_fixed + out_ext

    # Registration can be skipped only if the mask and every warped template exist.
    warp_done = os.path.isfile(mask_name) and all(
        os.path.isfile("warpedDiff_" + out_name(tmp_mod2)) for tmp_mod2 in tmp_mod2_list
    )

    # Load input images
    real_nib, real, real_aff = _load_volume(real_mod1)
    _, t_mod1, t_mod1_aff = _load_volume(tmp_mod1)

    # Load pre-align matrix
    print("Pre-align:", prealign_name)
    if not prealign_name:
        prealign = np.eye(4)
    else:
        prealign = readAntsAffine(prealign_name, _coordinate_system(real_mod1), _coordinate_system(tmp_mod1))

    # Configure CC metric
    sigma_diff = 1.7
    radius = 4
    similarity_metric = metrics.CCMetric(3, sigma_diff, radius)

    # Configure optimizer
    opt_iter = [100, 100, 50]
    step_length = 0.25
    opt_tol = 1e-5
    inv_iter = 20
    inv_tol = 1e-3
    ss_sigma_factor = 0.2

    mapping = None  # set, and used, only when registration runs
    if not warp_done:
        syn = imwarp.SymmetricDiffeomorphicRegistration(
            similarity_metric, opt_iter, step_length, ss_sigma_factor, opt_tol, inv_iter, inv_tol, callback=None
        )
        # Run registration
        syn.verbosity = VerbosityLevels.DEBUG
        mapping = syn.optimize(real, t_mod1, real_aff, t_mod1_aff, prealign)
        # Save the warped template with real_mod1's header so the registration
        # result can be checked visually. This presumably relies on `real`
        # being a view of real_nib's cached data.
        # NOTE(review): this file is overwritten in the loop below when tmp_mod1
        # is processed as the first tmp_mod2 template.
        warped = mapping.transform(t_mod1)
        real[...] = warped[...]
        real_nib.to_filename("warpedDiff_" + out_name(tmp_mod1))
        mask = (t_mod1 > 0).astype(np.int32)
        wmask = mapping.transform(mask, "nearest")
        # The warped mask shares real's voxel grid (it is combined with `real`
        # in compute_densities below), so save it with real's affine.
        nib.Nifti1Image(wmask, real_aff).to_filename(mask_name)
    else:
        wmask = nib.load(mask_name).get_data().squeeze()

    # Compute and save the semi-synthetic images in different modalities
    use_density_estimation = True
    nbins = 100
    for tmp_mod2 in tmp_mod2_list:
        print("Warping: " + tmp_mod2)
        oname = out_name(tmp_mod2)

        if not warp_done:
            # Save the warped image; like the mask, it lives on real's voxel grid.
            t_mod2 = nib.load(tmp_mod2).get_data().squeeze()
            warped = mapping.transform(t_mod2)
            nib.Nifti1Image(warped, real_aff).to_filename("warpedDiff_" + oname)
        else:
            warped = nib.load("warpedDiff_" + oname).get_data().squeeze()

        # Reload the real image: its data was overwritten earlier (during
        # registration or in the previous iteration).
        real_nib = nib.load(real_mod1)
        real = real_nib.get_data().squeeze()

        if use_density_estimation:
            print("Using density sampling.")
            oname = "ssds_" + oname
            # Compute marginal distributions
            densities = np.array(compute_densities(real.astype(np.int32), warped.astype(np.float64), nbins, wmask))
            # Sample the marginal distributions
            real[...] = create_ss_de(real.astype(np.int32), densities)
        else:
            print("Using mean transfer.")
            oname = "ssmt_" + oname
            # NOTE(review): get_mean_transfer is not imported in the visible code.
            # This branch is unreachable while use_density_estimation is True.
            means, _variances = get_mean_transfer(real, warped)
            # Apply transfer to real
            real[...] = means[real]

        # Save the semi-synthetic image (with real_mod1's header)
        real_nib.to_filename(oname)