Пример #1
0
def objective_function(optimise_array, static_image, dynamic_path, dvf_path,
                       weighted_normalise, dynamic_data_magnitude):
    """Evaluate the weighted least-squares data mismatch for a candidate image.

    The flat ``optimise_array`` is reshaped to the static image's shape and
    written into ``static_image``, which is therefore modified in place.  For
    every frame ``k``:

    - the static image is warped with the k-th deformation (cubic spline);
    - the warp is scaled by ``nansum(dynamic_k) / dynamic_data_magnitude``;
    - the NaN-ignoring sum of squared residuals against the k-th dynamic
      image is weighted by ``weighted_normalise[k]`` and accumulated.

    Args:
        optimise_array: Flattened image values proposed by the optimiser.
        static_image: Image object that receives the candidate values.
        dynamic_path: Paths of the dynamic (per-frame) images.
        dvf_path: Paths of the matching deformation fields, indexed like
            ``dynamic_path``.
        weighted_normalise: Per-frame weights, indexed like ``dynamic_path``.
        dynamic_data_magnitude: Divisor of the per-frame intensity scale.

    Returns:
        The accumulated objective value (0.0 when ``dynamic_path`` is empty).
    """
    target_shape = static_image.as_array().astype(np.double).shape
    static_image.fill(np.reshape(optimise_array, target_shape))

    total = 0.0

    for frame, frame_path in enumerate(dynamic_path):
        dynamic_image = reg.NiftiImageData(frame_path)
        deformation = reg.NiftiImageData3DDeformation(dvf_path[frame])

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(dynamic_image)
        resampler.add_transformation(deformation)
        resampler.set_interpolation_type_to_cubic_spline()

        measured = dynamic_image.as_array().astype(np.double)
        # Scale the warped estimate to this frame's total intensity.
        intensity_scale = (np.nansum(measured, dtype=np.double) /
                           dynamic_data_magnitude)
        residual = measured - (
            intensity_scale * warp_image_forward(resampler, static_image))

        frame_cost = np.nansum(np.square(residual, dtype=np.double),
                               dtype=np.double)
        total = total + (frame_cost * weighted_normalise[frame])

    print("Objective function value: {0}".format(str(total)))

    return total
Пример #2
0
def get_resamplers(static_image, dynamic_array, dvf_array, output_path):
    """Build one linear-interpolation resampler per dynamic frame.

    Each resampler uses ``static_image`` as reference image, the j-th dynamic
    image as floating image and ``dvf_array[j]`` as transformation.  All
    images are first written to temporary files in ``output_path``
    (``temp_static.nii``, ``temp_dynamic.nii``, ``temp_dvf.nii``) and
    re-read, so every resampler holds objects loaded from disk.  The
    temporary files are left in place and are overwritten on every call.

    Args:
        static_image: Reference image (must provide ``write``).
        dynamic_array: Sequence of dynamic images (must provide ``write``).
        dvf_array: Deformations indexed like ``dynamic_array``.
        output_path: Directory that receives the temporary files.

    Returns:
        A list with one configured ``reg.NiftyResample`` per dynamic frame;
        empty (and nothing written) when ``dynamic_array`` is empty.
    """
    resamplers = []

    if len(dynamic_array) == 0:
        return resamplers

    static_image_path = "{0}/temp_static.nii".format(output_path)
    dynamic_array_path = "{0}/temp_dynamic.nii".format(output_path)
    dvf_array_path = "{0}/temp_dvf.nii".format(output_path)

    # The static image is the same for every frame, so write it only once
    # instead of on every iteration.
    static_image.write(static_image_path)

    for j, dynamic_image in enumerate(dynamic_array):
        dynamic_image.write(dynamic_array_path)
        dvf_array[j].write(dvf_array_path)

        temp_static = reg.NiftiImageData(static_image_path)
        temp_dynamic = reg.NiftiImageData(dynamic_array_path)
        temp_dvf = reg.NiftiImageData3DDeformation(dvf_array_path)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(temp_static)
        resampler.set_floating_image(temp_dynamic)
        resampler.add_transformation(temp_dvf)
        resampler.set_interpolation_type_to_linear()

        resamplers.append(resampler)

    return resamplers
Пример #3
0
def output_input(static_image, dynamic_path, dvf_path, output_path):
    """Copy the optimisation inputs into ``output_path`` for inspection.

    Writes ``static_image.nii`` and, for each index ``k``,
    ``dynamic_image_k.nii`` and ``dvf_image_k.nii``.  These are loaded through
    ``reg.NiftiImageData(dynamic_path[k])`` and
    ``reg.NiftiImageData3DDeformation(dvf_path[k])`` respectively.

    Returns:
        Always ``True``.
    """
    static_image.write("{0}/static_image.nii".format(output_path))

    for index, source in enumerate(dynamic_path):
        frame = reg.NiftiImageData(source)
        deformation = reg.NiftiImageData3DDeformation(dvf_path[index])

        suffix = str(index)
        frame.write("{0}/dynamic_image_{1}.nii".format(output_path, suffix))
        deformation.write("{0}/dvf_image_{1}.nii".format(output_path, suffix))

    return True
Пример #4
0
def test_for_adj(static_image, dvf_array, output_path):
    """Write forward/adjoint warps of the static image for every deformation.

    For each deformation in ``dvf_array``, a linear-interpolation resampler is
    built with the static image as both reference and floating image.  Four
    images are written to ``output_path`` for visual inspection:

    - ``warp_forward_<k>.nii`` (the forward warp);
    - ``warp_forward_difference_<k>.nii`` (static minus forward warp);
    - ``warp_adjoint_<k>.nii`` (the adjoint warp);
    - ``warp_adjoint_difference_<k>.nii`` (static minus adjoint warp).

    The static image and each deformation are round-tripped through
    ``temp_static.nii`` / ``temp_dvf.nii`` in ``output_path``.

    Returns:
        Always ``True``.
    """
    static_image_path = "{0}/temp_static.nii".format(output_path)
    dvf_array_path = "{0}/temp_dvf.nii".format(output_path)

    if len(dvf_array) == 0:
        return True

    # The static image does not change between iterations: write it once.
    static_image.write(static_image_path)

    def write_warp(template, warp, warp_pattern, difference_pattern, index):
        """Write ``warp`` and ``template - warp`` using ``template``'s header."""
        warped_image = template.clone()
        warped_image.fill(warp)
        warped_image.write(warp_pattern.format(output_path, str(index)))

        difference = template.as_array().astype(np.double) - warp

        difference_image = template.clone()
        difference_image.fill(difference)
        difference_image.write(
            difference_pattern.format(output_path, str(index)))

    for i, dvf in enumerate(dvf_array):
        dvf.write(dvf_array_path)

        temp_static = reg.NiftiImageData(static_image_path)
        temp_dvf = reg.NiftiImageData3DDeformation(dvf_array_path)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(temp_static)
        resampler.set_floating_image(temp_static)
        resampler.add_transformation(temp_dvf)
        resampler.set_interpolation_type_to_linear()

        # All outputs use the re-read temp_static as template (the forward
        # warp previously cloned static_image instead, unlike the others).
        write_warp(temp_static, warp_image_forward(resampler, temp_static),
                   "{0}/warp_forward_{1}.nii",
                   "{0}/warp_forward_difference_{1}.nii", i)
        write_warp(temp_static, warp_image_adjoint(resampler, temp_static),
                   "{0}/warp_adjoint_{1}.nii",
                   "{0}/warp_adjoint_difference_{1}.nii", i)

    return True
Пример #5
0
def gradient_function(optimise_array, static_image, dynamic_path, dvf_path,
                      weighted_normalise, dynamic_data_magnitude):
    """Return the flattened gradient of ``objective_function``.

    The objective for candidate image ``x`` is
    ``sum_k w_k * ||d_k - s_k * W_k x||^2``, where:

    - ``d_k`` is the k-th dynamic image;
    - ``W_k`` is the forward warp;
    - ``s_k = nansum(d_k) / dynamic_data_magnitude``;
    - ``w_k = weighted_normalise[k]``.

    Its gradient is ``sum_k 2 * s_k * w_k * W_k^T (s_k * W_k x - d_k)``.  This
    assumes ``warp_image_adjoint`` is the adjoint of ``warp_image_forward``
    (the ``test_for_adj`` helpers exist to check that).  Interpolation is
    linear in the image intensities, so ``W_k`` is a linear operator.

    ``optimise_array`` is reshaped and written into ``static_image``, which is
    modified in place.

    Args:
        optimise_array: Flattened image values proposed by the optimiser.
        static_image: Image object that receives the candidate values.
        dynamic_path: Paths of the dynamic (per-frame) images.
        dvf_path: Paths of the matching deformation fields.
        weighted_normalise: Per-frame weights.
        dynamic_data_magnitude: Divisor of the per-frame intensity scale.

    Returns:
        1-D ``np.double`` array with one gradient entry per voxel.
    """
    static_image.fill(
        np.reshape(optimise_array,
                   static_image.as_array().astype(np.double).shape))

    gradient_value = static_image.clone()
    gradient_value.fill(0.0)

    adjoint_image = static_image.clone()

    for i, frame_path in enumerate(dynamic_path):
        dynamic_image = reg.NiftiImageData(frame_path)
        dvf_image = reg.NiftiImageData3DDeformation(dvf_path[i])

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(dynamic_image)
        resampler.add_transformation(dvf_image)
        resampler.set_interpolation_type_to_cubic_spline()

        dynamic_data = dynamic_image.as_array().astype(np.double)
        scale = (np.nansum(dynamic_data, dtype=np.double) /
                 dynamic_data_magnitude)

        # Residual r_k = s_k * W_k x - d_k.
        adjoint_image.fill(
            scale * warp_image_forward(resampler, static_image) -
            dynamic_data)

        # Chain rule: d/dx [w_k * ||d_k - s_k W_k x||^2]
        #           = 2 * s_k * w_k * W_k^T r_k.
        # The factor 2 * s_k was previously missing, which made the
        # gradient inconsistent with objective_function.
        gradient_value.fill(
            gradient_value.as_array().astype(np.double) +
            (2.0 * scale * weighted_normalise[i]) *
            warp_image_adjoint(resampler, adjoint_image))

    gradient_array = gradient_value.as_array().astype(np.double)

    print(
        "Max gradient value: {0}, Mean gradient value: {1}, Gradient norm: {2}"
        .format(
            str(np.amax(gradient_array)),
            str(np.nanmean(np.abs(gradient_array, dtype=np.double))),
            str(np.linalg.norm(gradient_array))))

    return np.ravel(gradient_array).astype(np.double)
Пример #6
0
def test_for_adj(static_image, dvf_path, output_path):
    """Write cubic-spline forward/adjoint warps of ``static_image`` per DVF.

    For each deformation file in ``dvf_path``, a resampler is built with
    ``static_image`` as both reference and floating image.  Four images are
    written to ``output_path``:

    - ``warp_forward_<k>.nii`` (the forward warp);
    - ``warp_forward_difference_<k>.nii`` (static minus forward warp);
    - ``warp_adjoint_<k>.nii`` (the adjoint warp);
    - ``warp_adjoint_difference_<k>.nii`` (static minus adjoint warp).

    Returns:
        Always ``True``.
    """
    def save_pair(warp, warp_pattern, difference_pattern, index):
        # Both outputs take their geometry from static_image via clone().
        warped = static_image.clone()
        warped.fill(warp)
        warped.write(warp_pattern.format(output_path, str(index)))

        residual = static_image.as_array().astype(np.double) - warp
        residual_image = static_image.clone()
        residual_image.fill(residual)
        residual_image.write(
            difference_pattern.format(output_path, str(index)))

    for index, path in enumerate(dvf_path):
        deformation = reg.NiftiImageData3DDeformation(path)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(static_image)
        resampler.add_transformation(deformation)
        resampler.set_interpolation_type_to_cubic_spline()

        save_pair(warp_image_forward(resampler, static_image),
                  "{0}/warp_forward_{1}.nii",
                  "{0}/warp_forward_difference_{1}.nii", index)
        save_pair(warp_image_adjoint(resampler, static_image),
                  "{0}/warp_adjoint_{1}.nii",
                  "{0}/warp_adjoint_difference_{1}.nii", index)

    return True
Пример #7
0
def resample_attn_image(image):
    """Resample the attenuation image.

    The transformation comes from the module-level ``trans``, interpreted
    according to the module-level ``trans_type``:

    - ``'tm'``: affine transformation;
    - ``'disp'``: displacement field;
    - ``'def'``: deformation field.

    Neither global is defined in this function; presumably they are set from
    the command line.  The image is resampled onto itself with linear
    interpolation and zero padding.

    Raises:
        ValueError: if ``trans_type`` is not one of the supported values.
    """
    # Lazily build the transformation so only the selected type is touched.
    builders = {
        'tm': lambda: reg.AffineTransformation(trans),
        'disp': lambda: reg.NiftiImageData3DDisplacement(trans),
        'def': lambda: reg.NiftiImageData3DDeformation(trans),
    }
    build = builders.get(trans_type)
    if build is None:
        raise ValueError("Unknown transformation type.")
    transformation = build()

    attn_resampler = reg.NiftyResample()
    attn_resampler.set_reference_image(image)
    attn_resampler.set_floating_image(image)
    attn_resampler.set_interpolation_type_to_linear()
    attn_resampler.set_padding_value(0.0)
    attn_resampler.add_transformation(transformation)
    return attn_resampler.forward(image)
Пример #8
0
def back_warp(static_path, dvf_path, output_path):
    """Warp the static image with every deformation and write the results.

    For each deformation file ``dvf_path[k]``:

    - the static image at ``static_path`` is freshly loaded;
    - it is warped onto itself with cubic-spline interpolation;
    - the result is written to ``<output_path>/back_warped_<k>.nii``.

    ``output_path`` is created (mode 0o770) if it does not exist.

    Args:
        static_path: Path of the static image.
        dvf_path: Sequence of deformation-field paths.
        output_path: Output directory.

    Returns:
        ``True`` once every deformation has been processed.
    """
    os.makedirs(output_path, mode=0o770, exist_ok=True)

    for i, deformation_path in enumerate(dvf_path):
        # Reload every time: the image object is overwritten with the warp.
        static_image = reg.NiftiImageData(static_path)
        dvf_image = reg.NiftiImageData3DDeformation(deformation_path)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(static_image)
        resampler.add_transformation(dvf_image)
        resampler.set_interpolation_type_to_cubic_spline()

        warped_static_image = warp_image_forward(resampler, static_image)

        static_image.fill(warped_static_image)
        static_image.write("{0}/back_warped_{1}.nii".format(
            output_path, str(i)))

    # Return after the loop; returning inside it processed only the first DVF.
    return True
Пример #9
0
def main():
    """Estimate a motion-compensated static image for each configured dataset.

    All settings are read with ``parser.parser`` from the parameter file given
    as ``sys.argv[1]``, and each setting is indexed per dataset ``i``.  For
    every dataset:

    - the dynamic frames are summed into an initial static image;
    - deformation fields are obtained (by registration or from disk);
    - optional operator and adjointness tests are run;
    - the static image is optimised with L-BFGS-B;
    - the optimised image and its difference from the initial estimate are
      written to ``output_path[i]``.
    """
    # file paths to data
    input_data_path = parser.parser(sys.argv[1], "data_path:=")
    data_split = parser.parser(sys.argv[1], "data_split:=")
    input_dvf_path = parser.parser(sys.argv[1], "dvf_path:=")
    dvf_split = parser.parser(sys.argv[1], "dvf_split:=")
    output_path = parser.parser(sys.argv[1], "output_path:=")
    do_op_test = parser.parser(sys.argv[1], "do_op_test:=")
    do_reg = parser.parser(sys.argv[1], "do_reg:=")
    do_test_for_adj = parser.parser(sys.argv[1], "do_test_for_adj:=")

    for i in range(len(input_data_path)):
        os.makedirs(output_path[i], mode=0o770, exist_ok=True)

        new_dvf_path = "{0}/new_dvfs/".format(output_path[i])
        os.makedirs(new_dvf_path, mode=0o770, exist_ok=True)

        # get dynamic paths
        dynamic_path = get_data_path(input_data_path[i], data_split[i])

        # load dynamic objects
        dynamic_array = [reg.NiftiImageData(path) for path in dynamic_path]

        static_path = "{0}/static_path.nii".format(output_path[i])

        # initial static image: voxel-wise sum of all dynamic frames
        static_image = reg.NiftiImageData(dynamic_path[0])

        for j in range(1, len(dynamic_path)):
            static_image.fill(static_image.as_array().astype(np.double) +
                              dynamic_array[j].as_array().astype(np.double))

        static_image.write(static_path)

        if bool(distutils.util.strtobool(do_op_test[i])):
            op_test(static_image, output_path[i])

        # register to compute the dvfs, otherwise load precomputed ones
        if bool(distutils.util.strtobool(do_reg[i])):
            dvf_path = register_data(static_path, dynamic_path, output_path[i])
        else:
            dvf_path = get_dvf_path(input_dvf_path[i], dvf_split[i])

        # fix dvf header and load dvf objects
        dvf_path = edit_header(dvf_path, new_dvf_path)

        dvf_array = [
            reg.NiftiImageData3DDeformation(path) for path in dvf_path
        ]

        # create object to get forward and adj
        resamplers = get_resamplers(static_image, dynamic_array, dvf_array,
                                    output_path[i])

        # test for adj
        if bool(distutils.util.strtobool(do_test_for_adj[i])):
            test_for_adj(static_image, dvf_array, output_path[i])
            # NOTE(review): loaded images are passed where output_input's
            # parameter names (dynamic_path, dvf_path) suggest paths — confirm.
            output_input(static_image, dynamic_array, dvf_array,
                         output_path[i])

        # initial static image
        initial_static_image = static_image.clone()

        # array to optimise
        optimise_array = static_image.as_array().astype(np.double)

        # one unbounded (lower, upper) pair per voxel
        bounds = [(-np.inf, np.inf)] * optimise_array.size

        # optimise
        # NOTE(review): confirm these ``args`` match the signatures of the
        # objective_function / gradient_function used with this script.
        result = scipy.optimize.minimize(
            objective_function,
            np.ravel(optimise_array).astype(np.double),
            args=(resamplers, dynamic_array, static_image, output_path[i]),
            method="L-BFGS-B",
            jac=gradient_function,
            bounds=bounds,
            tol=0.0000000001,
            options={"disp": True})
        optimise_array = np.reshape(result.x, optimise_array.shape)

        # output
        static_image.fill(optimise_array)
        static_image.write("{0}/optimiser_output_{1}.nii".format(
            output_path[i], str(i)))

        difference = static_image.as_array().astype(
            np.double) - initial_static_image.as_array().astype(np.double)

        difference_image = initial_static_image.clone()
        difference_image.fill(difference)

        # Write the difference image itself (the optimised image used to be
        # written here by mistake, duplicating optimiser_output_<i>.nii).
        difference_image.write(
            "{0}/optimiser_output_difference_{1}.nii".format(
                output_path[i], str(i)))
Пример #10
0
def main():
    """Run motion-compensated PET reconstruction with PDHG.

    Inputs are located through module-level settings that are not defined in
    this function (presumably parsed from the command line).  These include
    ``trans_pattern``, ``sino_pattern``, ``attn_pattern``, ``rand_pattern``,
    ``trans_type``, ``args``, ``use_gpu`` and ``regularisation``.

    Steps:

    - read the gated sinograms and per-gate transformations, plus optional
      attenuation, randoms and normalisation data;
    - build one acquisition model per motion state and compose it with that
      state's resampler;
    - run (optionally preconditioned) PDHG with either a non-negativity
      constraint or FGP-TV regularisation, saving the image every
      ``save_interval`` iterations.

    Raises:
        AssertionError: if required inputs are missing or file counts do not
            match the number of sinograms.
    """

    ###########################################################################
    # Parse input files
    ###########################################################################

    if trans_pattern is None:
        raise AssertionError("--trans missing")
    if sino_pattern is None:
        raise AssertionError("--sino missing")
    trans_files = sorted(glob(trans_pattern))
    sino_files = sorted(glob(sino_pattern))
    attn_files = sorted(glob(attn_pattern))
    rand_files = sorted(glob(rand_pattern))

    num_ms = len(sino_files)
    # Check some sinograms found
    if num_ms == 0:
        raise AssertionError("No sinograms found!")
    # Should have as many trans as sinos
    if num_ms != len(trans_files):
        raise AssertionError("#trans should match #sinos. "
                             "#sinos = " + str(num_ms) + ", #trans = " +
                             str(len(trans_files)))
    # If any rand, check num == num_ms
    if len(rand_files) > 0 and len(rand_files) != num_ms:
        raise AssertionError("#rand should match #sinos. "
                             "#sinos = " + str(num_ms) + ", #rand = " +
                             str(len(rand_files)))

    # For attn, there should be 0, 1 or num_ms images
    if len(attn_files) > 1 and len(attn_files) != num_ms:
        raise AssertionError("#attn should be 0, 1 or #sinos")

    ###########################################################################
    # Read input
    ###########################################################################

    if trans_type == "tm":
        trans = [reg.AffineTransformation(file) for file in trans_files]
    elif trans_type == "disp":
        trans = [
            reg.NiftiImageData3DDisplacement(file) for file in trans_files
        ]
    elif trans_type == "def":
        trans = [reg.NiftiImageData3DDeformation(file) for file in trans_files]
    else:
        raise error("Unknown transformation type")

    sinos_raw = [pet.AcquisitionData(file) for file in sino_files]
    attns = [pet.ImageData(file) for file in attn_files]
    rands = [pet.AcquisitionData(file) for file in rand_files]

    # Rebinning factors (1 = no rebinning); the same for every sinogram, so
    # parse them once rather than inside the loop.
    segs_to_combine = 1
    if args['--numSegsToCombine']:
        segs_to_combine = int(args['--numSegsToCombine'])
    views_to_combine = 1
    if args['--numViewsToCombine']:
        views_to_combine = int(args['--numViewsToCombine'])

    # Loop over all sinograms
    sinos = [0] * num_ms
    for ind in range(num_ms):
        # If any sinograms contain negative values
        # (shouldn't be the case), set them to 0
        sino_arr = sinos_raw[ind].as_array()
        if (sino_arr < 0).any():
            print("Input sinogram " + str(ind) +
                  " contains -ve elements. Setting to 0...")
            sinos[ind] = sinos_raw[ind].clone()
            sino_arr[sino_arr < 0] = 0
            sinos[ind].fill(sino_arr)
        else:
            sinos[ind] = sinos_raw[ind]
        # If rebinning is desired
        if segs_to_combine * views_to_combine > 1:
            sinos[ind] = sinos[ind].rebin(segs_to_combine, views_to_combine)
            # only print first time
            if ind == 0:
                print(f"Rebinned sino dimensions: {sinos[ind].dimensions()}")

    ###########################################################################
    # Initialise recon image
    ###########################################################################

    if initial_estimate:
        image = pet.ImageData(initial_estimate)
    else:
        # Create image based on ProjData
        image = sinos[0].create_uniform_image(0.0, (nxny, nxny))
        # If using GPU, need to make sure that image is right size.
        if use_gpu:
            dim = (127, 320, 320)
            spacing = (2.03125, 2.08626, 2.08626)
        # elif non-default spacing desired
        elif args['--dxdy']:
            dim = image.dimensions()
            dxdy = float(args['--dxdy'])
            spacing = (image.voxel_sizes()[0], dxdy, dxdy)
        if use_gpu or args['--dxdy']:
            image.initialise(dim=dim, vsize=spacing)
            image.fill(0.0)

    ###########################################################################
    # Set up resamplers
    ###########################################################################

    resamplers = [get_resampler(image, trans=tran) for tran in trans]

    ###########################################################################
    # Resample attenuation images (if necessary)
    ###########################################################################

    resampled_attns = None
    if len(attns) > 0:
        resampled_attns = [0] * num_ms
        # if using GPU, dimensions of attn and recon images have to match
        ref = image if use_gpu else None
        # Fill one entry per motion state.  Looping over len(attns) left
        # every gate but the first unset when only one attn image is given.
        for i in range(num_ms):
            # if we only have 1 attn image, then we need to resample into
            # space of each gate. However, if we have num_ms attn images, then
            # assume they are already in the correct position, so use None as
            # transformation.
            tran = trans[i] if len(attns) == 1 else None
            # If only 1 attn image, then resample that. If we have num_ms attn
            # images, then use each attn image of each frame.
            attn = attns[0] if len(attns) == 1 else attns[i]
            resam = get_resampler(attn, ref=ref, trans=tran)
            resampled_attns[i] = resam.forward(attn)

    ###########################################################################
    # Set up acquisition models
    ###########################################################################

    print("Setting up acquisition models...")
    # Create a distinct model per motion state.  ``num_ms * [Model()]`` would
    # repeat a single shared object, so all gates would share one model whose
    # state reflects whichever gate was configured last.
    if not use_gpu:
        acq_models = [
            pet.AcquisitionModelUsingRayTracingMatrix() for _ in range(num_ms)
        ]
    else:
        acq_models = [
            pet.AcquisitionModelUsingNiftyPET() for _ in range(num_ms)
        ]
        for acq_model in acq_models:
            acq_model.set_use_truncation(True)
            acq_model.set_cuda_verbosity(verbosity)

    # If present, create ASM from ECAT8 normalisation data
    asm_norm = None
    if norm_file:
        asm_norm = pet.AcquisitionSensitivityModel(norm_file)

    # Loop over each motion state
    for ind in range(num_ms):
        # Create attn ASM if necessary (use this gate's attenuation image,
        # not the stale index left over from the resampling loop).
        asm_attn = None
        if resampled_attns:
            asm_attn = get_asm_attn(sinos[ind], resampled_attns[ind],
                                    acq_models[ind])

        # Get ASM dependent on attn and/or norm
        asm = None
        if asm_norm and asm_attn:
            if ind == 0:
                print("ASM contains norm and attenuation...")
            asm = pet.AcquisitionSensitivityModel(asm_norm, asm_attn)
        elif asm_norm:
            if ind == 0:
                print("ASM contains norm...")
            asm = asm_norm
        elif asm_attn:
            if ind == 0:
                print("ASM contains attenuation...")
            asm = asm_attn
        if asm:
            acq_models[ind].set_acquisition_sensitivity(asm)

        if len(rands) > 0:
            acq_models[ind].set_background_term(rands[ind])

        # Set up
        acq_models[ind].set_up(sinos[ind], image)

    ###########################################################################
    # Set up reconstructor
    ###########################################################################

    print("Setting up reconstructor...")

    # Create composition operators containing acquisition models and resamplers
    C = [
        CompositionOperator(am, res, preallocate=True)
        for am, res in zip(*(acq_models, resamplers))
    ]

    # Configure the PDHG algorithm
    if args['--normK'] and not args['--onlyNormK']:
        normK = float(args['--normK'])
    else:
        K = BlockOperator(*C)
        # Calculate normK
        print("Calculating norm of the block operator...")
        normK = K.norm(iterations=10)
        print("Norm of the BlockOperator ", normK)
        if args['--onlyNormK']:
            exit(0)

    # Optionally rescale sinograms and BlockOperator using normK
    scale_factor = 1. / normK if args['--normaliseDataAndBlock'] else 1.0
    kl = [
        KullbackLeibler(b=sino * scale_factor, eta=(sino * 0 + 1e-5))
        for sino in sinos
    ]
    f = BlockFunction(*kl)
    K = BlockOperator(*C) * scale_factor

    # If preconditioned
    if precond:

        def get_nonzero_recip(data):
            """Get the reciprocal of a datacontainer. Voxels where input == 0
            will have their reciprocal set to 1 (instead of infinity)"""
            inv_np = data.as_array()
            inv_np[inv_np == 0] = 1
            inv_np = 1. / inv_np
            data.fill(inv_np)

        tau = K.adjoint(K.range_geometry().allocate(1))
        get_nonzero_recip(tau)

        tmp_sigma = K.direct(K.domain_geometry().allocate(1))
        # NOTE(review): sigma is 0 * tmp_sigma, so get_nonzero_recip turns
        # sigma[0] into all ones and leaves any other blocks at zero —
        # confirm this is intended rather than the reciprocal of tmp_sigma.
        sigma = 0. * tmp_sigma
        get_nonzero_recip(sigma[0])

        def precond_proximal(self, x, tau, out=None):
            """Modify proximal method to work with preconditioned tau"""
            pars = {
                'algorithm':
                FGP_TV,
                'input':
                np.asarray(x.as_array() / tau.as_array(), dtype=np.float32),
                'regularization_parameter':
                self.lambdaReg,
                'number_of_iterations':
                self.iterationsTV,
                'tolerance_constant':
                self.tolerance,
                'methodTV':
                self.methodTV,
                'nonneg':
                self.nonnegativity,
                'printingOut':
                self.printing
            }

            res, info = regularisers.FGP_TV(pars['input'],
                                            pars['regularization_parameter'],
                                            pars['number_of_iterations'],
                                            pars['tolerance_constant'],
                                            pars['methodTV'], pars['nonneg'],
                                            self.device)
            if out is not None:
                out.fill(res)
            else:
                out = x.copy()
                out.fill(res)
            out *= tau
            return out

        FGP_TV.proximal = precond_proximal
        print("Will run proximal with preconditioned tau...")

    # If not preconditioned
    else:
        sigma = float(args['--sigma'])
        # If we need to calculate default tau
        if args['--tau']:
            tau = float(args['--tau'])
        else:
            tau = 1 / (sigma * normK**2)

    if regularisation == 'none':
        G = IndicatorBox(lower=0)
    elif regularisation == 'FGP_TV':
        r_iterations = float(args['--reg_iters'])
        r_tolerance = 1e-7
        r_iso = 0
        r_nonneg = 1
        r_printing = 0
        device = 'gpu' if use_gpu else 'cpu'
        G = FGP_TV(r_alpha, r_iterations, r_tolerance, r_iso, r_nonneg,
                   r_printing, device)
    else:
        raise error("Unknown regularisation")

    if precond:

        def PDHG_new_update(self):
            """Modify the PDHG update to allow preconditioning"""
            # save previous iteration
            self.x_old.fill(self.x)
            self.y_old.fill(self.y)

            # Gradient ascent for the dual variable
            self.operator.direct(self.xbar, out=self.y_tmp)
            self.y_tmp *= self.sigma
            self.y_tmp += self.y_old

            self.f.proximal_conjugate(self.y_tmp, self.sigma, out=self.y)

            # Gradient descent for the primal variable
            self.operator.adjoint(self.y, out=self.x_tmp)
            self.x_tmp *= -1 * self.tau
            self.x_tmp += self.x_old

            self.g.proximal(self.x_tmp, self.tau, out=self.x)

            # Update
            self.x.subtract(self.x_old, out=self.xbar)
            self.xbar *= self.theta
            self.xbar += self.x

        PDHG.update = PDHG_new_update

    # Get filename
    outp_file = outp_prefix
    if descriptive_fname:
        if len(attn_files) > 0:
            outp_file += "_wAC"
        if norm_file:
            outp_file += "_wNorm"
        if use_gpu:
            outp_file += "_wGPU"
        outp_file += "_Reg-" + regularisation
        if regularisation == 'FGP_TV':
            outp_file += "-alpha" + str(r_alpha)
            outp_file += "-riters" + str(r_iterations)
        if args['--normK']:
            outp_file += '_userNormK' + str(normK)
        else:
            outp_file += '_calcNormK' + str(normK)
        if args['--normaliseDataAndBlock']:
            outp_file += '_wDataScale'
        else:
            outp_file += '_noDataScale'
        if not precond:
            outp_file += "_sigma" + str(sigma)
            outp_file += "_tau" + str(tau)
        else:
            outp_file += "_wPrecond"
        outp_file += "_nGates" + str(len(sino_files))
        if resamplers is None:
            outp_file += "_noMotion"

    pdhg = PDHG(f=f,
                g=G,
                operator=K,
                sigma=sigma,
                tau=tau,
                max_iteration=num_iters,
                update_objective_interval=update_obj_fn_interval,
                x_init=image,
                log_file=outp_file + ".log")

    def callback_save(iteration, objective_value, solution):
        """Callback function to save images"""
        if (iteration + 1) % save_interval == 0:
            out = solution if not nifti else reg.NiftiImageData(solution)
            out.write(outp_file + "_iters" + str(iteration + 1))

    pdhg.run(iterations=num_iters,
             callback=callback_save,
             verbose=True,
             very_verbose=True)

    if visualisations:
        # show reconstructed image
        out = pdhg.get_output()
        out_arr = out.as_array()
        z = out_arr.shape[0] // 2
        show_2D_array('Reconstructed image', out.as_array()[z, :, :])
        pylab.show()