def gather_file(args):
    """Stack per-file volumes into one tensor and save it with ``torch.save``.

    Args:
        args: parsed options. Reads ``args.files`` (paths, stacked in order
            along a new leading axis), ``args.momentum`` (true: files are ITK
            vector fields loaded with ``common.LoadITKField``; false: images
            loaded with nibabel) and ``args.output`` (destination path).

    For momentum fields the per-file shape is taken from the first file, and
    the stacked tensor is permuted with ``(0, 4, 1, 2, 3)`` so the last axis of
    each field becomes axis 1 ("the training format"). Image files are stacked
    into a fixed ``(15, 15, 15, 1, 3)`` per-file shape.

    Raises:
        ValueError: if ``args.files`` is empty.
    """
    if not args.files:
        raise ValueError("gather_file: args.files is empty")

    if args.momentum:
        first = common.AsNPCopy(common.LoadITKField(args.files[0], ca.MEM_HOST))
        # Size the stack from the real field shape. The fixed image shape below
        # has a different rank and would break the axis-4 permute further down.
        slice_shape = tuple(first.shape)
    else:
        # NOTE(review): fixed per-file shape for image data — presumably
        # 15^3 patches with 3 components; confirm against the producer.
        slice_shape = (15, 15, 15, 1, 3)

    data = torch.zeros((len(args.files),) + slice_shape)
    for i, path in enumerate(args.files):
        if args.momentum:
            cur_slice = common.AsNPCopy(common.LoadITKField(path, ca.MEM_HOST))
        else:
            # np.asanyarray(img.dataobj) is the supported replacement for the
            # deprecated (and, in nibabel >= 5, removed) img.get_data().
            cur_slice = np.asanyarray(nib.load(path).dataobj)
            if cur_slice.size == 15 ** 3:
                # A volume holding only 15^3 values is replaced by zeros.
                # NOTE(review): intent not visible here — presumably marks
                # missing/placeholder data; confirm.
                cur_slice = np.zeros(slice_shape)
        # Convert to native-endian float32 so torch.from_numpy accepts any
        # on-disk dtype; `data` is float32 anyway.
        data[i] = torch.from_numpy(np.asarray(cur_slice, dtype=np.float32))

    if args.momentum:
        # transpose the dataset to fit the training format
        data = data.permute(0, 4, 1, 2, 3).contiguous()
    torch.save(data, args.output)
def predict_each_datapart(args, net, network_config, input_batch, datapart_idx,
                          batch_size, patch_size, predict_transform_space):
    """Replace each target slice by its warped-back version and save residuals.

    For data part ``datapart_idx`` this loads the moving-image, target-image
    and momentum stacks from ``args.moving_image_dataset``,
    ``args.target_image_dataset`` and ``args.deformation_parameter``.

    For every index along dim 0 it:
      * predicts a momentum with ``net``,
      * shoots it on ``ca.MEM_DEVICE``,
      * stores ``I1_inv`` back into the target stack,
      * subtracts the prediction-space momentum from the loaded momentum.

    The updated stacks are saved to ``args.warped_back_target_output`` and
    ``args.momentum_residual`` (both indexed by ``datapart_idx``).

    When ``predict_transform_space`` is true, slices are converted to
    registration space before prediction. ``I1_inv`` is converted back to
    prediction space before being stored.
    """
    dev = ca.MEM_DEVICE
    part = datapart_idx
    moving_stack = torch.load(args.moving_image_dataset[part])
    target_stack = torch.load(args.target_image_dataset[part])
    momentum_stack = torch.load(args.deformation_parameter[part])

    for idx in range(moving_stack.size()[0]):
        print(idx)

        moving_np = moving_stack[idx].numpy()
        target_np = target_stack[idx].numpy()
        if predict_transform_space:
            moving_np = util.convert_to_registration_space(moving_np)
            target_np = util.convert_to_registration_space(target_np)

        prediction = util.predict_momentum(
            moving_np, target_np, input_batch, batch_size, patch_size,
            net, predict_transform_space)

        m0_field = common.FieldFromNPArr(prediction['image_space'], dev)
        moving_ca = common.ImFromNPArr(moving_np, dev)
        target_ca = common.ImFromNPArr(target_np, dev)
        shot = registration_methods.geodesic_shooting(
            moving_ca, target_ca, m0_field, args.shoot_steps, dev,
            network_config)

        warped_back = common.AsNPCopy(shot['I1_inv'])
        print(warped_back.shape)
        if predict_transform_space:
            warped_back = util.convert_to_predict_space(warped_back)
        print(warped_back.shape)

        target_stack[idx] = torch.from_numpy(warped_back)
        residual = momentum_stack[idx] - torch.from_numpy(
            prediction['prediction_space'])
        momentum_stack[idx] = residual

    torch.save(target_stack, args.warped_back_target_output[part])
    torch.save(momentum_stack, args.momentum_residual[part])
def preprocess_image(image_pyca, histeq):
    """Return a normalized NumPy copy of a PyCA image.

    NaNs are replaced by 0 and the volume is divided by its maximum value.
    If *histeq* is true, ``exposure.equalize_hist`` is applied to the non-zero
    voxels only, so the zero background is left untouched.

    Args:
        image_pyca: image accepted by ``common.AsNPCopy``.
        histeq: whether to histogram-equalize the non-zero voxels.

    Returns:
        The processed NumPy array (a copy; *image_pyca* is not modified).
    """
    image_np = common.AsNPCopy(image_pyca)
    image_np[np.isnan(image_np)] = 0

    max_val = np.amax(image_np)
    # An all-zero (or all-NaN) volume has max 0; dividing would turn every
    # voxel into NaN, so leave such a volume as-is.
    if max_val != 0:
        image_np /= max_val

    # perform histogram equalization if needed
    if histeq:
        foreground = image_np != 0
        # equalize_hist cannot build a histogram from an empty selection.
        if foreground.any():
            image_np[foreground] = exposure.equalize_hist(image_np[foreground])
    return image_np
def intensity_normalization_histeq(args):
    """Normalize (and optionally histogram-equalize) images, overwriting them.

    For each index ``i`` in ``range(len(args.input_images))`` this:
      * loads ``args.output_images[i]`` on the host,
      * processes it with ``preprocess_image`` (NaN -> 0, divide by max,
        optional histogram equalization when ``args.histeq``),
      * restores the image's original grid,
      * saves the result back to the same path.

    Args:
        args: parsed options; reads ``input_images`` (only its length),
            ``output_images`` and ``histeq``.
    """
    for i in range(len(args.input_images)):
        # NOTE(review): reads output_images[i], not input_images[i] — looks
        # like this post-processes files written by an earlier step; confirm.
        path = args.output_images[i]
        image = common.LoadITKImage(path, ca.MEM_HOST)
        grid = image.grid()

        # Reuse the shared helper instead of duplicating its normalization.
        image_np = preprocess_image(image, args.histeq)

        image_result = common.ImFromNPArr(image_np, ca.MEM_HOST)
        image_result.setGrid(grid)
        common.SaveITKImage(image_result, path)
def predict_image(args, moving_images, target_images, output_prefixes):
    """Predict momenta for image pairs, optionally correct them, then shoot.

    NOTE(review): this file also contains a later ``def predict_image(args)``.
    If both live in the same module, that later definition shadows this one.

    Args:
        args: parsed options. Reads ``use_CPU_for_shooting``,
            ``prediction_parameter``, ``batch_size``, ``use_correction``,
            ``correction_parameter``, ``affine_align``, ``atlas``, ``histeq``
            and ``shoot_steps``.
        moving_images: paths of moving images, paired by index with
            ``target_images``.
        target_images: paths of target images.
        output_prefixes: per-pair output prefix. Its directory is created,
            affine outputs (if any) are written under it, and the result goes
            through ``write_result``.

    Raises:
        RuntimeError: if ``reg_aladin`` exits with a non-zero status.
    """
    mType = ca.MEM_HOST if args.use_CPU_for_shooting else ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)
    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size).cuda()
    # Indicating whether we are using the old parameter files for the
    # Neuroimage experiments (use .t7 files from matlab .h5 format).
    predict_transform_space = 'matlab_t7' in predict_network_config

    # use correction network if required
    if args.use_correction:
        correction_network_config = torch.load(args.correction_parameter)
        correction_net = create_net(args, correction_network_config)
        correct_transform_space = 'matlab_t7' in correction_network_config
    else:
        correction_net = None

    # start prediction
    for i in range(len(moving_images)):
        prefix = output_prefixes[i]
        common.Mkdir_p(os.path.dirname(prefix))
        if args.affine_align:
            # Affinely register both the moving and the target image to the
            # ICBM152 atlas space, using NiftyReg.
            for role, path in (("moving", moving_images[i]),
                               ("target", target_images[i])):
                status = call(["reg_aladin", "-noSym", "-speeeeed",
                               "-ref", args.atlas, "-flo", path,
                               "-res", prefix + role + "_affine.nii",
                               "-aff", prefix + role + "_affine_transform.txt"])
                # Without this check a failed run would make the loads below
                # fail obscurely or pick up stale files from an earlier run.
                if status != 0:
                    raise RuntimeError("reg_aladin failed (exit %d) for %s"
                                       % (status, path))
            moving_image = common.LoadITKImage(prefix + "moving_affine.nii", mType)
            target_image = common.LoadITKImage(prefix + "target_affine.nii", mType)
        else:
            moving_image = common.LoadITKImage(moving_images[i], mType)
            target_image = common.LoadITKImage(target_images[i], mType)

        # preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        # NOTE(review): these calls set the grid on the *unprocessed* images
        # (the target gets the moving image's grid). The *_processed images
        # keep whatever grid ImFromNPArr assigns — confirm this is intended.
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        # run actual prediction
        prediction_result = util.predict_momentum(
            moving_image_np, target_image_np, input_batch, batch_size,
            patch_size, prediction_net, predict_transform_space)
        m0 = prediction_result['image_space']
        m0_reg = common.FieldFromNPArr(m0, mType)

        # perform correction
        if args.use_correction:
            # Shoot with the initial prediction, then predict a correction
            # momentum between the moving image and the resulting I1_inv.
            registration_result = registration_methods.geodesic_shooting(
                moving_image_processed, target_image_processed, m0_reg,
                args.shoot_steps, mType, predict_network_config)
            target_inv_np = common.AsNPCopy(registration_result['I1_inv'])
            correction_result = util.predict_momentum(
                moving_image_np, target_inv_np, input_batch, batch_size,
                patch_size, correction_net, correct_transform_space)
            m0 += correction_result['image_space']
            m0_reg = common.FieldFromNPArr(m0, mType)

        # The final shooting uses the original (non-preprocessed) images.
        registration_result = registration_methods.geodesic_shooting(
            moving_image, target_image, m0_reg, args.shoot_steps, mType,
            predict_network_config)
        write_result(registration_result, prefix)
def predict_image(args):
    """Sample momentum predictions repeatedly and save mean/variance results.

    For each pair ``(args.moving_image[i], args.target_image[i])``:
      * the images are optionally affine-aligned to ``args.atlas``, then
        preprocessed;
      * ``util.predict_momentum`` is called ``args.samples`` times, and each
        prediction is shot to collect its ``phiinv``;
      * the momenta are averaged and shot once more.

    NOTE(review): repeated calls differ only if the network is stochastic at
    inference (e.g. dropout) — presumably the intent; confirm.

    Files written under ``args.output_prefix[i]``:
      * ``I1.mhd`` — ``I1`` from shooting the mean momentum;
      * ``phiinv_mean.mhd`` — ``phiinv`` from shooting the mean momentum;
      * ``phiinv_var.mhd`` — element-wise variance of ``phiinv`` over the
        samples.

    Raises:
        ValueError: if ``args.samples`` is smaller than 1.
        RuntimeError: if ``reg_aladin`` exits with a non-zero status.
    """
    if args.samples < 1:
        # Otherwise the averages below divide by zero and save NaN/inf fields.
        raise ValueError("args.samples must be >= 1, got %r" % (args.samples,))

    mType = ca.MEM_HOST if args.use_CPU_for_shooting else ca.MEM_DEVICE

    # load the prediction network
    predict_network_config = torch.load(args.prediction_parameter)
    prediction_net = create_net(args, predict_network_config)
    batch_size = args.batch_size
    patch_size = predict_network_config['patch_size']
    input_batch = torch.zeros(batch_size, 2, patch_size, patch_size,
                              patch_size).cuda()
    # Old .t7 parameter files (flagged by 'matlab_t7') predict in the
    # transformed space.
    predict_transform_space = 'matlab_t7' in predict_network_config

    def shoot(m0_np):
        # Geodesic shooting of the preprocessed pair with momentum m0_np.
        return registration_methods.geodesic_shooting(
            moving_image_processed, target_image_processed,
            common.FieldFromNPArr(m0_np, mType), args.shoot_steps, mType,
            predict_network_config)

    # start prediction
    for i in range(len(args.moving_image)):
        prefix = args.output_prefix[i]
        common.Mkdir_p(os.path.dirname(prefix))
        if args.affine_align:
            # Affinely register both the moving and the target image to the
            # ICBM152 atlas space, using NiftyReg.
            for role, path in (("moving", args.moving_image[i]),
                               ("target", args.target_image[i])):
                status = call(["reg_aladin", "-noSym", "-speeeeed",
                               "-ref", args.atlas, "-flo", path,
                               "-res", prefix + role + "_affine.nii",
                               "-aff", prefix + role + "_affine_transform.txt"])
                if status != 0:
                    raise RuntimeError("reg_aladin failed (exit %d) for %s"
                                       % (status, path))
            moving_image = common.LoadITKImage(prefix + "moving_affine.nii", mType)
            target_image = common.LoadITKImage(prefix + "target_affine.nii", mType)
        else:
            moving_image = common.LoadITKImage(args.moving_image[i], mType)
            target_image = common.LoadITKImage(args.target_image[i], mType)

        # preprocessing of the image
        moving_image_np = preprocess_image(moving_image, args.histeq)
        target_image_np = preprocess_image(target_image, args.histeq)

        grid = moving_image.grid()
        moving_image_processed = common.ImFromNPArr(moving_image_np, mType)
        target_image_processed = common.ImFromNPArr(target_image_np, mType)
        moving_image.setGrid(grid)
        target_image.setGrid(grid)

        # First sample.
        prediction_result = util.predict_momentum(
            moving_image_np, target_image_np, input_batch, batch_size,
            patch_size, prediction_net, predict_transform_space)
        # Accumulate into a private copy so the running sum never aliases an
        # array owned by a prediction result.
        m0_sum = np.copy(prediction_result['image_space'])
        registration_result = shoot(prediction_result['image_space'])
        phi_sum = common.AsNPCopy(registration_result['phiinv'])
        phi_sq_sum = np.power(phi_sum, 2)

        # Remaining samples.
        for sample_iter in range(1, args.samples):
            print(sample_iter)
            prediction_result = util.predict_momentum(
                moving_image_np, target_image_np, input_batch, batch_size,
                patch_size, prediction_net, predict_transform_space)
            m0_sum += prediction_result['image_space']
            registration_result = shoot(prediction_result['image_space'])
            phi_sample = common.AsNPCopy(registration_result['phiinv'])
            phi_sum += phi_sample
            phi_sq_sum += np.power(phi_sample, 2)

        # Shoot the mean momentum for the saved image / mean deformation.
        registration_result = shoot(np.divide(m0_sum, args.samples))
        phi_mean = registration_result['phiinv']

        # Var = E[phi^2] - E[phi]^2. Round-off can push this slightly below
        # zero when the true variance is tiny, so clamp to a valid variance.
        phi_var = np.divide(phi_sq_sum, args.samples) - np.power(
            np.divide(phi_sum, args.samples), 2)
        phi_var = np.maximum(phi_var, 0)

        # save result
        common.SaveITKImage(registration_result['I1'], prefix + "I1.mhd")
        common.SaveITKField(phi_mean, prefix + "phiinv_mean.mhd")
        common.SaveITKField(common.FieldFromNPArr(phi_var, mType),
                            prefix + "phiinv_var.mhd")