def objective_function(optimise_array, static_image, dynamic_path, dvf_path,
                       weighted_normalise, dynamic_data_magnitude):
    """Return the weighted sum of squared residuals for the current estimate.

    The flat ``optimise_array`` is reshaped to the shape of ``static_image``
    and written into it (``static_image`` is modified in place). For every
    dynamic image, the static image is passed through ``warp_image_forward``
    with a resampler built from the matching deformation field, scaled by the
    ratio of the dynamic image's sum to ``dynamic_data_magnitude``, and
    compared against the dynamic image.

    Parameters:
        optimise_array: values of the static image being optimised.
        static_image: image object that receives ``optimise_array``.
        dynamic_path: sources accepted by ``reg.NiftiImageData``.
        dvf_path: sources accepted by ``reg.NiftiImageData3DDeformation``,
            indexed in step with ``dynamic_path``.
        weighted_normalise: per-image weights applied to each squared error.
        dynamic_data_magnitude: divisor used to scale the warped image.

    Returns:
        The accumulated objective value (also printed).
    """
    image_shape = static_image.as_array().astype(np.double).shape
    static_image.fill(np.reshape(optimise_array, image_shape))

    objective_value = 0.0

    for index, dynamic_source in enumerate(dynamic_path):
        dynamic_image = reg.NiftiImageData(dynamic_source)
        dvf_image = reg.NiftiImageData3DDeformation(dvf_path[index])

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(dynamic_image)
        resampler.add_transformation(dvf_image)
        resampler.set_interpolation_type_to_cubic_spline()

        measured = dynamic_image.as_array().astype(np.double)
        scale = np.nansum(measured, dtype=np.double) / dynamic_data_magnitude
        estimate = scale * warp_image_forward(resampler, static_image)

        # NaNs are ignored when summing the squared residuals.
        squared_error = np.nansum(
            np.square(measured - estimate, dtype=np.double), dtype=np.double)

        objective_value = objective_value + (
            squared_error * weighted_normalise[index])

    print("Objective function value: {0}".format(str(objective_value)))

    return objective_value
def get_resamplers(static_image, dynamic_array, dvf_array, output_path):
    """Build one linear-interpolation resampler per dynamic image.

    Each image is written to a temporary file under ``output_path`` and read
    back, so every resampler is configured with freshly loaded objects rather
    than the caller's instances. The same three temporary file names are
    reused (overwritten) on every iteration.

    Parameters:
        static_image: reference image; must provide ``write``.
        dynamic_array: floating images; each must provide ``write``.
        dvf_array: deformation fields indexed in step with ``dynamic_array``.
        output_path: directory that receives the temporary files.

    Returns:
        List of configured ``reg.NiftyResample`` objects, one per entry of
        ``dynamic_array``.
    """
    static_file = "{0}/temp_static.nii".format(output_path)
    dynamic_file = "{0}/temp_dynamic.nii".format(output_path)
    dvf_file = "{0}/temp_dvf.nii".format(output_path)

    built = []

    for index, dynamic_image in enumerate(dynamic_array):
        resampler = reg.NiftyResample()

        # Round-trip through disk to obtain independent copies.
        static_image.write(static_file)
        dynamic_image.write(dynamic_file)
        dvf_array[index].write(dvf_file)

        reference = reg.NiftiImageData(static_file)
        floating = reg.NiftiImageData(dynamic_file)
        deformation = reg.NiftiImageData3DDeformation(dvf_file)

        resampler.set_reference_image(reference)
        resampler.set_floating_image(floating)
        resampler.add_transformation(deformation)
        resampler.set_interpolation_type_to_linear()

        built.append(resampler)

    return built
def output_input(static_image, dynamic_path, dvf_path, output_path):
    """Write the optimiser inputs to ``output_path`` for inspection.

    Parameters:
        static_image: image written as ``static_image.nii``.
        dynamic_path: sources accepted by ``reg.NiftiImageData``; entry ``i``
            is written as ``dynamic_image_<i>.nii``.
        dvf_path: sources accepted by ``reg.NiftiImageData3DDeformation``;
            entry ``i`` is written as ``dvf_image_<i>.nii``.
        output_path: destination directory.

    Returns:
        Always ``True``.
    """
    static_image.write("{0}/static_image.nii".format(output_path))

    for index, dynamic_source in enumerate(dynamic_path):
        dynamic_image = reg.NiftiImageData(dynamic_source)
        dvf_image = reg.NiftiImageData3DDeformation(dvf_path[index])

        suffix = str(index)
        dynamic_image.write(
            "{0}/dynamic_image_{1}.nii".format(output_path, suffix))
        dvf_image.write("{0}/dvf_image_{1}.nii".format(output_path, suffix))

    return True
def test_for_adj(static_image, dvf_array, output_path):
    """Write forward/adjoint warps of the static image and their differences.

    NOTE(review): a second function named ``test_for_adj`` is defined later in
    this file; at import time that later definition replaces this one.

    For each deformation field, the static image and the field are written to
    temporary files and read back, a linear resampler is built from the
    reloaded copies, and four images are written to ``output_path``:
    ``warp_forward_<i>``, ``warp_forward_difference_<i>``,
    ``warp_adjoint_<i>`` and ``warp_adjoint_difference_<i>``.

    Parameters:
        static_image: image to warp; must provide ``write`` and ``clone``.
        dvf_array: deformation field objects; each must provide ``write``.
        output_path: directory for temporary and result files.

    Returns:
        Always ``True``.
    """
    static_file = "{0}/temp_static.nii".format(output_path)
    dvf_file = "{0}/temp_dvf.nii".format(output_path)

    def _save_like(template, values, name_pattern, index):
        # Clone the template, overwrite its voxels and write it out.
        image = template.clone()
        image.fill(values)
        image.write(name_pattern.format(output_path, str(index)))

    for index, dvf_image in enumerate(dvf_array):
        static_image.write(static_file)
        dvf_image.write(dvf_file)

        temp_static = reg.NiftiImageData(static_file)
        temp_dvf = reg.NiftiImageData3DDeformation(dvf_file)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(temp_static)
        resampler.set_floating_image(temp_static)
        resampler.add_transformation(temp_dvf)
        resampler.set_interpolation_type_to_linear()

        forward = warp_image_forward(resampler, temp_static)
        # The forward warp is stored in a clone of the caller's image, all
        # other outputs in clones of the reloaded copy.
        _save_like(static_image, forward, "{0}/warp_forward_{1}.nii", index)
        _save_like(temp_static,
                   temp_static.as_array().astype(np.double) - forward,
                   "{0}/warp_forward_difference_{1}.nii", index)

        adjoint = warp_image_adjoint(resampler, temp_static)
        _save_like(temp_static, adjoint, "{0}/warp_adjoint_{1}.nii", index)
        _save_like(temp_static,
                   temp_static.as_array().astype(np.double) - adjoint,
                   "{0}/warp_adjoint_difference_{1}.nii", index)

    return True
def gradient_function(optimise_array, static_image, dynamic_path, dvf_path,
                      weighted_normalise, dynamic_data_magnitude):
    """Return the flattened gradient image for the current estimate.

    The flat ``optimise_array`` is reshaped to the shape of ``static_image``
    and written into it (``static_image`` is modified in place). For every
    dynamic image the residual ``scale * forward(static) - dynamic`` is
    formed, passed through ``warp_image_adjoint`` with the same resampler,
    multiplied by the image's weight and accumulated.

    Parameters:
        optimise_array: values of the static image being optimised.
        static_image: image object that receives ``optimise_array``.
        dynamic_path: sources accepted by ``reg.NiftiImageData``.
        dvf_path: sources accepted by ``reg.NiftiImageData3DDeformation``,
            indexed in step with ``dynamic_path``.
        weighted_normalise: per-image weights.
        dynamic_data_magnitude: divisor used to scale the warped image.

    Returns:
        One-dimensional ``np.double`` array holding the accumulated gradient.
    """
    image_shape = static_image.as_array().astype(np.double).shape
    static_image.fill(np.reshape(optimise_array, image_shape))

    gradient_value = static_image.clone()
    gradient_value.fill(0.0)

    adjoint_image = static_image.clone()

    for index, dynamic_source in enumerate(dynamic_path):
        dynamic_image = reg.NiftiImageData(dynamic_source)
        dvf_image = reg.NiftiImageData3DDeformation(dvf_path[index])

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(dynamic_image)
        resampler.add_transformation(dvf_image)
        resampler.set_interpolation_type_to_cubic_spline()

        measured = dynamic_image.as_array().astype(np.double)
        scale = np.nansum(measured, dtype=np.double) / dynamic_data_magnitude
        estimate = scale * warp_image_forward(resampler, static_image)
        adjoint_image.fill(estimate - measured)

        running_total = gradient_value.as_array().astype(np.double)
        contribution = (warp_image_adjoint(resampler, adjoint_image)
                        * weighted_normalise[index])
        gradient_value.fill(running_total + contribution)

    gradient_array = gradient_value.as_array().astype(np.double)

    max_text = str(np.amax(gradient_array))
    mean_text = str(np.nanmean(np.abs(gradient_array, dtype=np.double)))
    norm_text = str(np.linalg.norm(gradient_array))
    print(
        "Max gradient value: {0}, Mean gradient value: {1}, Gradient norm: {2}"
        .format(max_text, mean_text, norm_text))

    return np.ravel(gradient_array).astype(np.double)
def test_for_adj(static_image, dvf_path, output_path):
    """Write forward/adjoint warps of the static image and their differences.

    For each deformation field source, a cubic-spline resampler is built with
    ``static_image`` as both reference and floating image, and four images
    are written to ``output_path``: ``warp_forward_<i>``,
    ``warp_forward_difference_<i>``, ``warp_adjoint_<i>`` and
    ``warp_adjoint_difference_<i>``. The differences are the static image
    minus the respective warp.

    Parameters:
        static_image: image to warp; must provide ``clone`` and ``as_array``.
        dvf_path: sources accepted by ``reg.NiftiImageData3DDeformation``.
        output_path: directory for the result files.

    Returns:
        Always ``True``.
    """

    def _save_like_static(values, name_pattern, index):
        # Clone the static image, overwrite its voxels and write it out.
        image = static_image.clone()
        image.fill(values)
        image.write(name_pattern.format(output_path, str(index)))

    for index, dvf_source in enumerate(dvf_path):
        dvf_image = reg.NiftiImageData3DDeformation(dvf_source)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(static_image)
        resampler.add_transformation(dvf_image)
        resampler.set_interpolation_type_to_cubic_spline()

        forward = warp_image_forward(resampler, static_image)
        _save_like_static(forward, "{0}/warp_forward_{1}.nii", index)
        _save_like_static(static_image.as_array().astype(np.double) - forward,
                          "{0}/warp_forward_difference_{1}.nii", index)

        adjoint = warp_image_adjoint(resampler, static_image)
        _save_like_static(adjoint, "{0}/warp_adjoint_{1}.nii", index)
        _save_like_static(static_image.as_array().astype(np.double) - adjoint,
                          "{0}/warp_adjoint_difference_{1}.nii", index)

    return True
def resample_attn_image(image):
    """Resample the attenuation image with the module-level transformation.

    The transformation is built from the module-level ``trans`` according to
    the module-level ``trans_type`` ('tm', 'disp' or 'def'). ``image`` is
    used as both the reference and the floating image; interpolation is
    linear and the padding value is 0.

    Parameters:
        image: the attenuation image to resample.

    Returns:
        The result of the resampler's ``forward`` applied to ``image``.

    Raises:
        ValueError: if ``trans_type`` is not one of the known types.
    """
    builders = {
        'tm': reg.AffineTransformation,
        'disp': reg.NiftiImageData3DDisplacement,
        'def': reg.NiftiImageData3DDeformation,
    }
    if trans_type not in builders:
        raise ValueError("Unknown transformation type.")
    transformation = builders[trans_type](trans)

    resampler = reg.NiftyResample()
    resampler.set_reference_image(image)
    resampler.set_floating_image(image)
    resampler.set_interpolation_type_to_linear()
    resampler.set_padding_value(0.0)
    resampler.add_transformation(transformation)

    return resampler.forward(image)
def back_warp(static_path, dvf_path, output_path):
    """Warp the static image with each deformation field and save the results.

    The static image is reloaded from ``static_path`` for every deformation
    field because the loaded object is overwritten with the warped values
    before being written as ``back_warped_<i>.nii``.

    Parameters:
        static_path: source accepted by ``reg.NiftiImageData``.
        dvf_path: sources accepted by ``reg.NiftiImageData3DDeformation``.
        output_path: destination directory; created (mode 0o770) if missing.

    Returns:
        Always ``True``.
    """
    if not os.path.exists(output_path):
        os.makedirs(output_path, mode=0o770)

    for index, dvf_source in enumerate(dvf_path):
        static_image = reg.NiftiImageData(static_path)
        dvf_image = reg.NiftiImageData3DDeformation(dvf_source)

        resampler = reg.NiftyResample()
        resampler.set_reference_image(static_image)
        resampler.set_floating_image(static_image)
        resampler.add_transformation(dvf_image)
        resampler.set_interpolation_type_to_cubic_spline()

        warped = warp_image_forward(resampler, static_image)
        static_image.fill(warped)

        target = "{0}/back_warped_{1}.nii".format(output_path, str(index))
        static_image.write(target)

    return True
def main():
    """Run the static-image optimisation for every configured data set.

    Settings are read with ``parser.parser`` from the file named by
    ``sys.argv[1]``. For each data set the function:

    1. creates the output directories,
    2. loads the dynamic images and sums them into an initial static image,
    3. optionally runs ``op_test``,
    4. obtains deformation fields (by registration or from disk) and passes
       them through ``edit_header``,
    5. optionally runs ``test_for_adj`` and writes the inputs with
       ``output_input``,
    6. minimises ``objective_function`` with L-BFGS-B, using
       ``gradient_function`` as the Jacobian,
    7. writes the optimised image and its difference from the initial image.

    NOTE(review): another function named ``main`` is defined later in this
    file; at import time that later definition replaces this one.
    """
    # file paths to data
    config_file = sys.argv[1]

    input_data_path = parser.parser(config_file, "data_path:=")
    data_split = parser.parser(config_file, "data_split:=")
    input_dvf_path = parser.parser(config_file, "dvf_path:=")
    dvf_split = parser.parser(config_file, "dvf_split:=")
    output_path = parser.parser(config_file, "output_path:=")
    do_op_test = parser.parser(config_file, "do_op_test:=")
    do_reg = parser.parser(config_file, "do_reg:=")
    do_test_for_adj = parser.parser(config_file, "do_test_for_adj:=")

    for i in range(len(input_data_path)):
        if not os.path.exists(output_path[i]):
            os.makedirs(output_path[i], mode=0o770)

        new_dvf_path = "{0}/new_dvfs/".format(output_path[i])

        if not os.path.exists(new_dvf_path):
            os.makedirs(new_dvf_path, mode=0o770)

        # get static and dynamic paths
        dynamic_path = get_data_path(input_data_path[i], data_split[i])

        # load dynamic objects
        dynamic_array = [reg.NiftiImageData(path) for path in dynamic_path]

        static_path = "{0}/static_path.nii".format(output_path[i])

        # load static objects: start from the first dynamic image and add the
        # remaining ones voxel-wise
        static_image = reg.NiftiImageData(dynamic_path[0])

        for j in range(1, len(dynamic_path)):
            static_image.fill(static_image.as_array().astype(np.double) +
                              dynamic_array[j].as_array().astype(np.double))

        static_image.write(static_path)

        if bool(distutils.util.strtobool(do_op_test[i])):
            op_test(static_image, output_path[i])

        # if do_reg is set calculate the dvfs, otherwise load them
        if bool(distutils.util.strtobool(do_reg[i])):
            dvf_path = register_data(static_path, dynamic_path,
                                     output_path[i])
        else:
            dvf_path = get_dvf_path(input_dvf_path[i], dvf_split[i])

        # fix dvf header and load dvf objects
        dvf_path = edit_header(dvf_path, new_dvf_path)

        dvf_array = [
            reg.NiftiImageData3DDeformation(path) for path in dvf_path
        ]

        # create objects to get forward and adjoint
        resamplers = get_resamplers(static_image, dynamic_array, dvf_array,
                                    output_path[i])

        # test for adjoint
        if bool(distutils.util.strtobool(do_test_for_adj[i])):
            test_for_adj(static_image, dvf_array, output_path[i])

        output_input(static_image, dynamic_array, dvf_array, output_path[i])

        # initial static image, kept to report the change made by the
        # optimiser
        initial_static_image = static_image.clone()

        # array to optimise
        optimise_array = static_image.as_array().astype(np.double)

        # array bounds: every voxel is unbounded
        bounds = [(-np.inf, np.inf)] * optimise_array.size

        # optimise
        # NOTE(review): the extra arguments passed here do not match the
        # objective_function / gradient_function signatures visible in this
        # file (static_image, dynamic_path, dvf_path, weighted_normalise,
        # dynamic_data_magnitude) -- confirm which versions are intended.
        result = scipy.optimize.minimize(
            objective_function,
            np.ravel(optimise_array).astype(np.double),
            args=(resamplers, dynamic_array, static_image, output_path[i]),
            method="L-BFGS-B",
            jac=gradient_function,
            bounds=bounds,
            tol=1e-10,
            options={"disp": True})

        optimise_array = np.reshape(result.x, optimise_array.shape)

        # output
        static_image.fill(optimise_array)
        static_image.write("{0}/optimiser_output_{1}.nii".format(
            output_path[i], str(i)))

        difference = static_image.as_array().astype(
            np.double) - initial_static_image.as_array().astype(np.double)

        difference_image = initial_static_image.clone()
        difference_image.fill(difference)
        # Write the difference image itself; previously the optimiser output
        # was written a second time under this name.
        difference_image.write(
            "{0}/optimiser_output_difference_{1}.nii".format(
                output_path[i], str(i)))
def main():
    """Run a motion-compensated PDHG reconstruction.

    Configuration is read from module-level names (``args``,
    ``trans_pattern``, ``sino_pattern``, ``attn_pattern``, ``rand_pattern``,
    ``trans_type``, ``use_gpu``, ``precond``, ``regularisation`` and others).
    The function:

    1. globs and validates the transformation, sinogram, attenuation and
       randoms files,
    2. reads them, zeroing negative sinogram values and optionally rebinning,
    3. creates the reconstruction image and one resampler per motion state,
    4. resamples the attenuation image(s) and sets up one acquisition model
       per motion state,
    5. builds the block operator and KL data terms, optionally
       preconditioned, and the regulariser,
    6. runs PDHG, saving intermediate images via a callback.

    Raises:
        AssertionError: if required patterns are missing or file counts are
            inconsistent.
    """
    ###########################################################################
    # Parse input files
    ###########################################################################

    if trans_pattern is None:
        raise AssertionError("--trans missing")
    if sino_pattern is None:
        raise AssertionError("--sino missing")
    trans_files = sorted(glob(trans_pattern))
    sino_files = sorted(glob(sino_pattern))
    attn_files = sorted(glob(attn_pattern))
    rand_files = sorted(glob(rand_pattern))

    num_ms = len(sino_files)
    # Check some sinograms found
    if num_ms == 0:
        raise AssertionError("No sinograms found!")
    # Should have as many trans as sinos
    if num_ms != len(trans_files):
        raise AssertionError("#trans should match #sinos. "
                             "#sinos = " + str(num_ms) + ", #trans = " +
                             str(len(trans_files)))
    # If any rand, check num == num_ms
    if len(rand_files) > 0 and len(rand_files) != num_ms:
        raise AssertionError("#rand should match #sinos. "
                             "#sinos = " + str(num_ms) + ", #rand = " +
                             str(len(rand_files)))

    # For attn, there should be 0, 1 or num_ms images
    if len(attn_files) > 1 and len(attn_files) != num_ms:
        raise AssertionError("#attn should be 0, 1 or #sinos")

    ###########################################################################
    # Read input
    ###########################################################################

    if trans_type == "tm":
        trans = [reg.AffineTransformation(file) for file in trans_files]
    elif trans_type == "disp":
        trans = [
            reg.NiftiImageData3DDisplacement(file) for file in trans_files
        ]
    elif trans_type == "def":
        trans = [reg.NiftiImageData3DDeformation(file) for file in trans_files]
    else:
        raise error("Unknown transformation type")

    sinos_raw = [pet.AcquisitionData(file) for file in sino_files]
    attns = [pet.ImageData(file) for file in attn_files]
    rands = [pet.AcquisitionData(file) for file in rand_files]

    # Rebinning options do not depend on the sinogram, so read them once
    segs_to_combine = 1
    if args['--numSegsToCombine']:
        segs_to_combine = int(args['--numSegsToCombine'])
    views_to_combine = 1
    if args['--numViewsToCombine']:
        views_to_combine = int(args['--numViewsToCombine'])

    # Loop over all sinograms
    sinos = [0] * num_ms
    for ind in range(num_ms):
        # If any sinograms contain negative values
        # (shouldn't be the case), set them to 0
        sino_arr = sinos_raw[ind].as_array()
        if (sino_arr < 0).any():
            print("Input sinogram " + str(ind) +
                  " contains -ve elements. Setting to 0...")
            sinos[ind] = sinos_raw[ind].clone()
            sino_arr[sino_arr < 0] = 0
            sinos[ind].fill(sino_arr)
        else:
            sinos[ind] = sinos_raw[ind]

        # If rebinning is desired
        if segs_to_combine * views_to_combine > 1:
            sinos[ind] = sinos[ind].rebin(segs_to_combine, views_to_combine)
            # only print first time
            if ind == 0:
                print(f"Rebinned sino dimensions: {sinos[ind].dimensions()}")

    ###########################################################################
    # Initialise recon image
    ###########################################################################

    if initial_estimate:
        image = pet.ImageData(initial_estimate)
    else:
        # Create image based on ProjData
        image = sinos[0].create_uniform_image(0.0, (nxny, nxny))
        # If using GPU, need to make sure that image is right size.
        if use_gpu:
            dim = (127, 320, 320)
            spacing = (2.03125, 2.08626, 2.08626)
        # elif non-default spacing desired
        elif args['--dxdy']:
            dim = image.dimensions()
            dxdy = float(args['--dxdy'])
            spacing = (image.voxel_sizes()[0], dxdy, dxdy)
        if use_gpu or args['--dxdy']:
            image.initialise(dim=dim, vsize=spacing)
            image.fill(0.0)

    ###########################################################################
    # Set up resamplers
    ###########################################################################

    resamplers = [get_resampler(image, trans=tran) for tran in trans]

    ###########################################################################
    # Resample attenuation images (if necessary)
    ###########################################################################

    resampled_attns = None
    if len(attns) > 0:
        resampled_attns = [0] * num_ms
        # if using GPU, dimensions of attn and recon images have to match
        ref = image if use_gpu else None
        # Iterate over motion states (not over attenuation images) so that a
        # single attenuation image is resampled into every gate.
        for i in range(num_ms):
            # if we only have 1 attn image, then we need to resample into
            # space of each gate. However, if we have num_ms attn images, then
            # assume they are already in the correct position, so use None as
            # transformation.
            tran = trans[i] if len(attns) == 1 else None
            # If only 1 attn image, then resample that. If we have num_ms attn
            # images, then use each attn image of each frame.
            attn = attns[0] if len(attns) == 1 else attns[i]
            resam = get_resampler(attn, ref=ref, trans=tran)
            resampled_attns[i] = resam.forward(attn)

    ###########################################################################
    # Set up acquisition models
    ###########################################################################

    print("Setting up acquisition models...")
    # Build a distinct model per motion state: `num_ms * [model]` would repeat
    # one shared instance, so per-state sensitivity/background settings would
    # overwrite each other.
    if not use_gpu:
        acq_models = [
            pet.AcquisitionModelUsingRayTracingMatrix() for _ in range(num_ms)
        ]
    else:
        acq_models = [
            pet.AcquisitionModelUsingNiftyPET() for _ in range(num_ms)
        ]
        for acq_model in acq_models:
            acq_model.set_use_truncation(True)
            acq_model.set_cuda_verbosity(verbosity)

    # If present, create ASM from ECAT8 normalisation data
    asm_norm = None
    if norm_file:
        asm_norm = pet.AcquisitionSensitivityModel(norm_file)

    # Loop over each motion state
    for ind in range(num_ms):
        # Create attn ASM if necessary
        asm_attn = None
        if resampled_attns:
            # Use this motion state's attenuation image
            asm_attn = get_asm_attn(sinos[ind], resampled_attns[ind],
                                    acq_models[ind])

        # Get ASM dependent on attn and/or norm
        asm = None
        if asm_norm and asm_attn:
            if ind == 0:
                print("ASM contains norm and attenuation...")
            asm = pet.AcquisitionSensitivityModel(asm_norm, asm_attn)
        elif asm_norm:
            if ind == 0:
                print("ASM contains norm...")
            asm = asm_norm
        elif asm_attn:
            if ind == 0:
                print("ASM contains attenuation...")
            asm = asm_attn
        if asm:
            acq_models[ind].set_acquisition_sensitivity(asm)

        if len(rands) > 0:
            acq_models[ind].set_background_term(rands[ind])

        # Set up
        acq_models[ind].set_up(sinos[ind], image)

    ###########################################################################
    # Set up reconstructor
    ###########################################################################

    print("Setting up reconstructor...")

    # Create composition operators containing acquisition models and resamplers
    C = [
        CompositionOperator(am, res, preallocate=True)
        for am, res in zip(*(acq_models, resamplers))
    ]

    # Configure the PDHG algorithm
    if args['--normK'] and not args['--onlyNormK']:
        normK = float(args['--normK'])
    else:
        kl = [KullbackLeibler(b=sino, eta=(sino * 0 + 1e-5)) for sino in sinos]
        f = BlockFunction(*kl)
        K = BlockOperator(*C)
        # Calculate normK
        print("Calculating norm of the block operator...")
        normK = K.norm(iterations=10)
        print("Norm of the BlockOperator ", normK)
        if args['--onlyNormK']:
            exit(0)

    # Optionally rescale sinograms and BlockOperator using normK
    scale_factor = 1. / normK if args['--normaliseDataAndBlock'] else 1.0
    kl = [
        KullbackLeibler(b=sino * scale_factor, eta=(sino * 0 + 1e-5))
        for sino in sinos
    ]
    f = BlockFunction(*kl)
    K = BlockOperator(*C) * scale_factor

    # If preconditioned
    if precond:

        def get_nonzero_recip(data):
            """Get the reciprocal of a datacontainer. Voxels where input == 0
            will have their reciprocal set to 1 (instead of infinity)"""
            inv_np = data.as_array()
            inv_np[inv_np == 0] = 1
            inv_np = 1. / inv_np
            data.fill(inv_np)

        tau = K.adjoint(K.range_geometry().allocate(1))
        get_nonzero_recip(tau)

        tmp_sigma = K.direct(K.domain_geometry().allocate(1))
        sigma = 0. * tmp_sigma
        # NOTE(review): only sigma[0] is passed to get_nonzero_recip; confirm
        # this is intended when there is more than one motion state.
        get_nonzero_recip(sigma[0])

        def precond_proximal(self, x, tau, out=None):
            """Modify proximal method to work with preconditioned tau"""
            pars = {
                'algorithm': FGP_TV,
                'input': np.asarray(x.as_array() / tau.as_array(),
                                    dtype=np.float32),
                'regularization_parameter': self.lambdaReg,
                'number_of_iterations': self.iterationsTV,
                'tolerance_constant': self.tolerance,
                'methodTV': self.methodTV,
                'nonneg': self.nonnegativity,
                'printingOut': self.printing
            }

            res, info = regularisers.FGP_TV(pars['input'],
                                            pars['regularization_parameter'],
                                            pars['number_of_iterations'],
                                            pars['tolerance_constant'],
                                            pars['methodTV'], pars['nonneg'],
                                            self.device)
            if out is not None:
                out.fill(res)
            else:
                out = x.copy()
                out.fill(res)
            out *= tau
            return out

        # Replace the class-level proximal so the regulariser uses the
        # element-wise tau
        FGP_TV.proximal = precond_proximal
        print("Will run proximal with preconditioned tau...")

    # If not preconditioned
    else:
        sigma = float(args['--sigma'])
        # If we need to calculate default tau
        if args['--tau']:
            tau = float(args['--tau'])
        else:
            tau = 1 / (sigma * normK**2)

    if regularisation == 'none':
        G = IndicatorBox(lower=0)
    elif regularisation == 'FGP_TV':
        r_iterations = float(args['--reg_iters'])
        r_tolerance = 1e-7
        r_iso = 0
        r_nonneg = 1
        r_printing = 0
        device = 'gpu' if use_gpu else 'cpu'
        G = FGP_TV(r_alpha, r_iterations, r_tolerance, r_iso, r_nonneg,
                   r_printing, device)
    else:
        raise error("Unknown regularisation")

    if precond:

        def PDHG_new_update(self):
            """Modify the PDHG update to allow preconditioning"""
            # save previous iteration
            self.x_old.fill(self.x)
            self.y_old.fill(self.y)

            # Gradient ascent for the dual variable
            self.operator.direct(self.xbar, out=self.y_tmp)
            self.y_tmp *= self.sigma
            self.y_tmp += self.y_old

            self.f.proximal_conjugate(self.y_tmp, self.sigma, out=self.y)

            # Gradient descent for the primal variable
            self.operator.adjoint(self.y, out=self.x_tmp)
            self.x_tmp *= -1 * self.tau
            self.x_tmp += self.x_old

            self.g.proximal(self.x_tmp, self.tau, out=self.x)

            # Update
            self.x.subtract(self.x_old, out=self.xbar)
            self.xbar *= self.theta
            self.xbar += self.x

        # Replace the class-level update step
        PDHG.update = PDHG_new_update

    # Get filename
    outp_file = outp_prefix
    if descriptive_fname:
        if len(attn_files) > 0:
            outp_file += "_wAC"
        if norm_file:
            outp_file += "_wNorm"
        if use_gpu:
            outp_file += "_wGPU"
        outp_file += "_Reg-" + regularisation
        if regularisation == 'FGP_TV':
            outp_file += "-alpha" + str(r_alpha)
            outp_file += "-riters" + str(r_iterations)
        if args['--normK']:
            outp_file += '_userNormK' + str(normK)
        else:
            outp_file += '_calcNormK' + str(normK)
        if args['--normaliseDataAndBlock']:
            outp_file += '_wDataScale'
        else:
            outp_file += '_noDataScale'
        if not precond:
            outp_file += "_sigma" + str(sigma)
            outp_file += "_tau" + str(tau)
        else:
            outp_file += "_wPrecond"
        outp_file += "_nGates" + str(len(sino_files))
        if resamplers is None:
            outp_file += "_noMotion"

    pdhg = PDHG(f=f,
                g=G,
                operator=K,
                sigma=sigma,
                tau=tau,
                max_iteration=num_iters,
                update_objective_interval=update_obj_fn_interval,
                x_init=image,
                log_file=outp_file + ".log")

    def callback_save(iteration, objective_value, solution):
        """Callback function to save images"""
        if (iteration + 1) % save_interval == 0:
            out = solution if not nifti else reg.NiftiImageData(solution)
            out.write(outp_file + "_iters" + str(iteration + 1))

    pdhg.run(iterations=num_iters,
             callback=callback_save,
             verbose=True,
             very_verbose=True)

    if visualisations:
        # show reconstructed image (central slice along the first axis)
        out = pdhg.get_output()
        out_arr = out.as_array()
        z = out_arr.shape[0] // 2
        show_2D_array('Reconstructed image', out.as_array()[z, :, :])
        pylab.show()